CSp/CS.py
author hh
Thu, 21 Nov 2019 14:55:10 +0100
changeset 0 5c129dd80d4f
permissions -rwxr-xr-x
--

#!/usr/bin/python3
# coding=utf8

# ring nonssl: approx. 8000  rings/sec*3nodes
# ring    ssl: approx. 2600  rings/sec*3nodes
# mash nonssl: approx. 2000 mashes/sec*3nodes
# mash    ssl: approx.  300 mashes/sec*3nodes
# For the mash, transfer time is proportional to the number of nodes,
#    whereas connect/close time is proportional to the number of connections, i.e. the square of the number of nodes.

import time
import socket
import errno
import os
import signal
import sys
import pickle
import ssl
import select
import random
import multiprocessing
import threading


class Debug():
    """Base class giving leveled, cross-process-serialized logging to stderr and an abort helper."""
    def __init__(self, debid):
        self.debid = debid          # prefix identifying the emitter in every log line
    def log(self, level, *msg):
        """Print *msg* to stderr, prefixed by elapsed time and debid, when level <= maxDebLev.

        log_lock (a multiprocessing lock) keeps lines of different processes from
        interleaving.  The `with` block guarantees the lock is released even if
        printing raises (e.g. broken stderr); a lock left held would deadlock every
        other process on its next log call.
        """
        if level <= maxDebLev:
            with log_lock:
                print("{:10.6f} {}:".format(time.time()-t0, self.debid), *msg, file=sys.stderr)
                sys.stderr.flush()
    def abend(self, s):
        """Log *s* at level 0 and send SIGTERM to every process in the caller's process group."""
        self.log(0, s)
        os.kill(0, signal.SIGTERM)

        
class Node(Debug):
    """One participant of a constellation, running in its own forked process.

    The node listens on 127.0.0.1:port (ssc), reads pickled Data tokens from the
    connections it accepts (srv_side) and forwards each token to the next node over
    an outgoing connection opened on first use (cli_side).  In a "ring" the next
    node is port + 1 (the kicker wraps around to p0); in a "mash" it is a random
    other port in p0..pn.  The node with port == pn is the kicker: it injects the
    first token, decrements its ttl on every return and, once ttl reaches 0,
    clears the shared forwarding flag so that all nodes wind down.
    """
    def __init__(self, debid, forwarding, topo, port, p0, pn, issl):
        Debug.__init__(self, "{} node {}".format(debid, port))
        self.topo = topo
        self.locPort = port
        self.p0 = p0
        self.pn = pn
        self.ssc = None             # listening socket, created by bind()
        self.cli_side = {}          # remote port -> write file of the outgoing connection
        self.srv_side = {}          # read file of an accepted connection -> the same file
        self.kicker = (self.locPort == self.pn)
        self.forwarding = forwarding    # one-element list holding the shared flag; the kicker clears it when TTL reaches 0
        self.closing = False        # set once close_cli() has started closing the outgoing side
        self.payload = None
        self.issl = issl
        if issl:
            self.log(4, "setting SSL context...")
            self.sslCert = cePath + "/certs/{}.pem".format(port)
            self.sslKey = cePath + "/keys/{}.key".format(port)
            try:
                self.ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
                self.ctx.verify_mode = ssl.CERT_REQUIRED
                self.log(5, "cert={}, key={}".format(self.sslCert, self.sslKey))
                self.ctx.load_cert_chain(self.sslCert, self.sslKey)
                self.ctx.load_verify_locations(None, caPath)
            # OSError also covers missing cert/key files (ssl.SSLError is a subclass).
            except OSError as e: self.abend("SSL context: {}".format(str(e)))
    def run(self):
        """Serve until the token has finished its rounds, wait for all other nodes, then exit."""
        # Block SIGUSR2 in this (main) thread before any helper thread exists, so
        # every thread inherits the mask and the final wake-up stays pending until
        # _wait_all() collects it with sigwait() - no lost wake-up is possible.
        signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGUSR2})
        self.bind()
        self.payload = Data(self.debid, self.locPort)
        if self.kicker:
            nxt = self.next_node()
            self.log(1, "kicker ready to send '{}' to {}".format(self.payload.digest(), nxt))
            if nxt not in self.cli_side: self.conn(nxt)
            self.payload.put(self.cli_side[nxt])
        # A late-starting node may find the flag already cleared; an empty wait
        # list then simply skips the loop (previously wait_list stayed unbound).
        wait_list = (self.ssc,) if self.forwarding[0].value else ()
        while wait_list:
            self.log(4, "select...")
            self.log(5, "waitlist:", *(sc.fileno() for sc in wait_list))
            readable, _, _ = select.select(wait_list, (), (), sel_TO)
            self.log(5, "readylist:", *(sc.fileno() for sc in readable))
            for sc in readable:
                self.log(5, "sc {} ready...".format(sc.fileno()))
                if sc == self.ssc: sc = self.acc()
                self.forward(sc)
            wait_list = ()
            self.log(4, "forwarding={}".format(self.forwarding[0].value))
            if self.forwarding[0].value: wait_list += (self.ssc,)   # when off, no new connection will come on ssc
            else: self.close_cli()
            wait_list += tuple(self.srv_side.values())
        self.close_srv()
        self._wait_all()
        self.log(2, "ended")
        os._exit(0)
    def _wait_all(self):
        """Barrier: decrement the shared active count; the last node wakes all others with SIGUSR2."""
        # A real handler (not SIG_IGN) is required: an ignored signal is discarded
        # instead of being left pending for sigwait().
        signal.signal(signal.SIGUSR2, sighand)
        with ctr_lock:
            active.value -= 1
            last = active.value == 0
        if last: os.kill(0, signal.SIGUSR2)     # whole process group; non-nodes ignore SIGUSR2
        else: signal.sigwait({signal.SIGUSR2})
    def forward(self, sc):
        """Read one token from accepted file *sc* and send it on; drop *sc* on EOF."""
        if not self.payload.get(sc):
            self.log(5, "delete srv_side[{}]".format(sc.fileno()))
            del self.srv_side[sc]
            self.log(4, "closing {}...".format(sc.fileno()))
            sc.close()
            return
        self.log(5, "received data from {}".format(self.payload.rport))
        with ctr_lock:
            forwards.value += 1
        if self.kicker:
            self.payload.dttl()
            self.log(3, "received from {}: {}, ttl={}".format(self.topo, self.payload.digest(), self.payload.ttl))
            if self.payload.ttl <= 0:
                self.log(1, "received after passing all {}: {}".
                         format("mashes" if self.topo == "mash" else "rings", self.payload.digest()))
                self.forwarding[0].value = 0
                return
        nxt = self.next_node()
        self.log(5, "forwarding to {}...".format(nxt))
        if pacing: time.sleep(pace)
        if nxt not in self.cli_side: self.conn(nxt)
        self.payload.put(self.cli_side[nxt])
        self.log(5, "forwarded to {}".format(nxt))
    def next_node(self):
        """Return the port the token goes to next (ring: successor, mash: random other node)."""
        if self.topo == "ring":
            return self.p0 if self.kicker else self.locPort + 1
        nxt = self.locPort
        while nxt == self.locPort:
            nxt = random.randint(self.p0, self.pn)
        return nxt
    def bind(self):
        """Create the listening socket ssc (SSL-wrapped when issl) on 127.0.0.1:locPort."""
        self.log(4, "binding...")
        try:
            self.ssc = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.ssc.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        except Exception as e:
            self.abend("ssc alloc: {}".format(str(e)))     # str(e): not every exception has .strerror
        if self.issl:
            self.log(4, "ssc SSL wrap")
            try: self.ssc = self.ctx.wrap_socket(self.ssc)
            except ssl.SSLError as e: self.abend("ssc SSL wrap: {}".format(str(e)))
        try:
            self.ssc.bind(("127.0.0.1", self.locPort))
            self.ssc.listen(1)
        except Exception as e: self.abend("bind: {}".format(str(e)))
        self.log(2, "bound")
    def conn(self, remPort):
        """Connect to 127.0.0.1:remPort and store the connection's write file in cli_side.

        ECONNREFUSED (peer not listening yet) is retried up to connThreshold times,
        conn_TO seconds apart; any other failure aborts via abend().
        """
        self.log(4, "connecting to {}...".format(remPort))
        try: sc = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        except Exception as e: self.abend("socket alloc: {}".format(str(e)))
        if self.issl:
            self.log(4, "sc SSL wrap")
            try: sc = self.ctx.wrap_socket(sc)
            except ssl.SSLError as e: self.abend("sc SSL wrap: {}".format(str(e)))
        retry = connThreshold
        connected = False
        while not connected and retry > 0:
            try:
                sc.connect(("127.0.0.1", remPort))
                connected = True
            except OSError as e:
                if e.errno == errno.ECONNREFUSED:
                    retry -= 1
                    time.sleep(conn_TO)
                else: self.abend("connect: {}".format(str(e)))
            except Exception as e: self.abend("connect: {}".format(str(e)))
        if not connected: self.abend("connection refused, threshold {} reached".format(connThreshold))
        with ctr_lock:
            connects.value += 1
        try: self.cli_side[remPort] = sc.makefile("wb")
        except Exception as e: self.abend("client side makefile: {}".format(str(e)))
        # Drop our own handle: per socket.makefile() semantics the connection is then
        # really closed as soon as the file object is closed (see close_cli()).
        sc.close()
        self.log(2, "connected to {} after {} retries".format(remPort, connThreshold - retry))
    def acc(self):
        """Accept one connection on ssc, register its read file in srv_side and return it."""
        self.log(4, "accepting...")
        try:
            raw, addr = self.ssc.accept()
        except Exception as e: self.abend("accept: {}".format(str(e)))
        try: sc = raw.makefile("rb")
        except Exception as e: self.abend("srv side makefile: {}".format(str(e)))
        raw.close()     # the file alone now owns the connection; closing it closes the socket
        self.srv_side[sc] = sc
        self.log(2, "accepted on sc={}, addr={}".format(sc.fileno(), addr))
        return sc
    def close_srv(self):
        """Close the listening socket, if one was created."""
        self.log(5, "closing ssc...")
        if self.ssc: self.ssc.close()
    def close_cli(self):
        """Close all outgoing connections once, in a background thread.

        Called on every loop pass after forwarding stops; self.closing makes only
        the first call act.  Peers then read EOF and drop their accepted side.
        """
        if self.closing:
            return
        self.closing = True
        # Snapshot before clearing: dict.values() is a live view, so iterating it
        # after clear() (as before) closed nothing.
        files = list(self.cli_side.values())
        self.cli_side.clear()
        def do_close():
            self.log(5, "closing clients...")
            for f in files: f.close()
            self.log(4, "all clients closed")
        threading.Thread(target = do_close, name = "client {} close".format(self.locPort)).start()


class Data(Debug):
    """The token forwarded between nodes: ttl, sender port and a text payload.

    On the wire it is the pickled tuple (ttl, sender port, text).
    """
    def __init__(self, debid, port):
        self.debid = debid + " payload"
        self.clear()
        self.ttl = ittl         # remaining rounds, decremented via dttl()
        self.lport = port       # own port, sent as the sender field by put()
        self.rport = 0          # sender port of the last token read by get()
        # Start with the configured text (env T).  This used to be "", which is not
        # None, so the itext fallback in put() never fired and T was never sent.
        self.text = itext
    def clear(self):
        """Reset ttl, rport and text to None."""
        self.ttl = self.rport = self.text = None
    def put(self, sc):
        """Pickle (ttl, lport, text) into the writable file *sc* and flush it."""
        self.log(5, "sending via {}...".format(sc.fileno()))
        if self.ttl is None: self.ttl = ittl
        if self.text is None: self.text = itext
        try:
            pickle.dump((self.ttl, self.lport, self.text), sc)
            sc.flush()
        except Exception as e: self.abend("send: {}".format(str(e)))
    def get(self, sc):
        """Read one token from file *sc*; return True on success, False on EOF (peer closed)."""
        self.log(5, "reading from {}...".format(sc.fileno()))
        self.clear()
        try:
            # NOTE(security): pickle.load executes whatever the peer sends; this is only
            # tolerable because nodes talk over 127.0.0.1 to each other - never expose
            # these ports to untrusted senders.
            self.ttl, self.rport, self.text = pickle.load(sc)
        except EOFError:
            return False
        except Exception as e:
            self.abend("receive: {}".format(str(e)))
            return False
        return True
    def dttl(self):
        """Decrement ttl and return the new value."""
        self.ttl -= 1
        return self.ttl
    def digest(self):
        """Return text, shortened to first 8 + '--------' + last 8 chars when 24+ chars long."""
        return self.text if len(self.text) < 24 else self.text[0:8]+"--------"+self.text[-8:]
    def toString(self):
        """Return a one-line human-readable description of the token."""
        return "ttl={}, from port={}, text={}".\
            format(str(self.ttl), str(self.rport), self.digest())

class Constellation(Debug):
    """Launcher for one topology ("ring" or "mash") of n nodes, each in its own forked process."""
    def __init__(self, issl, topo, p0, n):
        Debug.__init__(self, "{}SSL {}".format("" if issl else "non", topo.upper()))
    def run(self, issl, topo, p0, n):
        """Fork one Node per port, reap all children, then exit this process.

        SSL constellations use ports shifted by 500, so plain and SSL variants of
        the same topology do not collide.
        """
        signal.signal(signal.SIGUSR2, signal.SIG_IGN)
        forwarding = [multiprocessing.Value('i', 1, lock=False)]   # shared flag, inherited by the forked nodes
        if n == 1:
            self.log(0, "one-node configuration is not implemented")
            # This node was counted in `active` but will never run; release its slot,
            # otherwise the end-of-run barrier of all other nodes never completes.
            with ctr_lock:
                active.value -= 1
                last = active.value == 0
            if last: os.kill(0, signal.SIGUSR2)
        else:
            self.log(1, "{} nodes starting...".format(n))
            p0 += 500 if issl else 0
            pn = p0 + n - 1
            for port in range(p0, p0 + n):
                pid = os.fork()
                if not pid: Node(self.debid, forwarding, topo, port, p0, pn, issl).run()   # never returns
                else: self.log(4, "node {} started in process {}".format(port, pid))
        self.log(2, "all nodes established")
        while True:
            try:
                pid, status = os.wait()
            except ChildProcessError:       # no children left
                break
            if os.WIFSIGNALED(status):
                self.log(4, "pid {} killed by {}".format(pid, os.WTERMSIG(status)))
            else:
                self.log(4, "pid {} returned  {}".format(pid, os.WEXITSTATUS(status)))
        os._exit(0)

def ga(key, default):
    """Return the environment variable *key*, or *default* when it is not set."""
    return os.environ.get(key, default)
def gi(key, default):
    """Integer variant of ga(): fetch *key* from the environment (or *default*) and int() it."""
    value = ga(key, default)
    return int(value)
def sighand(signal, frame):
    """Do-nothing signal handler.

    Installing it for SIGUSR2 (instead of SIG_IGN) makes the signal actually reach
    a node waiting at the end of its run.  The `signal` parameter shadows the
    module inside this function only; the body does not use it.
    """

debug = Debug("client/server demo")
log_lock = multiprocessing.Lock()          # serializes log lines of all processes
ctr_lock = multiprocessing.Lock()          # guards the shared counters below
forwards = multiprocessing.Value('i', 0)   # tokens received, summed over all nodes
connects = multiprocessing.Value('i', 0)   # outgoing connections opened, summed over all nodes
active = multiprocessing.Value('i', 0)     # nodes still running (end-of-run barrier counter)
t0 = time.time()                           # time origin for log timestamps
maxDebLev = gi('DEB', 0)
# N is a common default node count for both topologies; RN / MN override it.
# (It used to be assigned to mn/rn and then immediately overwritten, i.e. ignored.)
n_all = gi('N', 3)
rp0 = gi('RP0', 11000)
rn = gi('RN', n_all)
mp0 = gi('MP0', 12000)
mn = gi('MN', n_all)
itext = ga('T', "bla bla")
ittl = gi('TTL', 3)
pace = float(os.environ["P"]) if "P" in os.environ else 0
pacing = 1 if pace > 0 else 0
random.seed(gi('RS', 0))
connThreshold = 77      # connect attempts on ECONNREFUSED before giving up
conn_TO = 0.01          # seconds between connect attempts
sel_TO = 1              # select() timeout in seconds
issl = gi('SSL', 0)     # 0: plain TCP, 1: SSL, 2 or more: run both variants
caPath = "/home/local/etc/ssl/certs/"
sslPathSuff = "/../CS/"
cePath = os.environ["CEP"] if "CEP" in os.environ else os.path.dirname(sys.argv[0]) + sslPathSuff
active.value = mn + rn
if issl > 1: active.value *= 2
signal.signal(signal.SIGUSR2, signal.SIG_IGN)
debug.log(1, "pgm={}, ttl={}, pace={}, seed={}, SSL mask={}, debug={}".format(sys.argv[0], ittl, pace, gi('RS', 0), issl, maxDebLev))
if issl > 0: debug.log(3, "ssl path: {}, CA path: {}".format(cePath, caPath))

ssl_modes = (issl,) if issl < 2 else (0, 1)
if ittl > 0:
    topo = ("mash", "ring")
    p0 = (mp0, rp0)
    n = (mn, rn)
    for ss in ssl_modes:
        for p in zip(topo, p0, n):
            if not os.fork():
                Constellation(ss, *p).run(ss, *p)     # never returns: run() ends with os._exit()
    while True:
        try: os.wait()
        except ChildProcessError: break             # all constellations reaped
debug.log(1, "final balance: forwards={}, connections={}".format(forwards.value, connects.value))