CSp/CS.py
changeset 0 5c129dd80d4f
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/CSp/CS.py	Thu Nov 21 14:55:10 2019 +0100
@@ -0,0 +1,315 @@
+#!/usr/bin/python3
+# coding=utf8
+
+# ring nonssl: cca 8000  rings/sec*3nodes
+# ring    ssl: cca 2600  rings/sec*3nodes
+# mash nonssl: cca 2000 mashes/sec*3nodes
+# mash    ssl: cca  300 mashes/sec*3nodes
+# u mashe je čas na přenosy úměrný počtu uzlů,
+#    kdežto čas na connect/close je úměrný počtu spojů tj. kvadrátu počtu uzlů
+
+import time
+import socket
+import errno
+import os
+import signal
+import sys
+import pickle
+import ssl
+import select
+import random
+import multiprocessing
+import threading
+
+
+class Debug():
+    def __init__(self, debid):
+        self.debid = debid
+    def log(self, level, *msg):
+        if level <= maxDebLev:
+            log_lock.acquire()
+            print("{:10.6f} {}:".format(time.time()-t0, self.debid), *msg, file=sys.stderr)
+            sys.stderr.flush()
+            log_lock.release()
+    def abend(self, s):
+        self.log(0, s)
+        os.kill(0, signal.SIGTERM)
+
+        
+class Node(Debug):
+    def __init__(self, debid, forwarding, topo, port, p0, pn, issl):
+        Debug.__init__(self, "{} node {}".format(debid, port))
+        self.topo = topo
+        self.locPort = port
+        self.p0 = p0
+        self.pn = pn
+        self.ssc = None
+        self.cli_side = {}
+        self.srv_side = {}
+        self.kicker = (self.locPort == self.pn)
+        self.forwarding = forwarding                # in forwarding kicker task indicates when TTL reached 0
+        self.closing = False
+        self.payload = None 
+        self.issl = issl
+        if issl:                                                                                                                                           
+            self.log(4, "setting SSL context...")  
+            self.sslCert = cePath + "/certs/{}.pem".format(port)
+            self.sslKey = cePath + "/keys/{}.key".format(port)                                                                                
+            try:                                                                                                                                                
+                self.ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)                                                                                                  
+                self.ctx.verify_mode = ssl.CERT_REQUIRED
+                self.log(5, "cert={}, key={}".format(self.sslCert, self.sslKey))                                                                                                        
+                self.ctx.load_cert_chain(self.sslCert, self.sslKey)
+                self.ctx.load_verify_locations(None, caPath)                                                                              
+            except ssl.SSLError as e: self.abend("SSL context: {}".format(str(e)))                                                                                                                            
+    def run(self):        
+        self.bind()
+        self.payload = Data(self.debid, self.locPort)
+        if self.kicker:
+            nxt = self.next_node()
+            self.log(1, "kicker ready to send '{}' to {}".format(self.payload.digest(), nxt))                                                                                                                                                                                                                                                                                       
+            if not nxt in self.cli_side: self.conn(nxt)
+            self.payload.put(self.cli_side[nxt])
+        if self.forwarding[0].value: wait_list = (self.ssc,)
+        while wait_list:
+            self.log(4, "select...")
+            self.log(5, "waitlist:", *(sc.fileno() for sc in wait_list))
+            ready_list = select.select(wait_list, (), (), sel_TO)
+            self.log(5, "readylist:", *(sc.fileno() for sc in ready_list[0]))
+            for sc in ready_list[0]:
+                self.log(5, "sc {} ready...".format(sc.fileno()))
+                if sc == self.ssc: sc = self.acc()                    
+                self.forward(sc)
+            wait_list = ()
+            self.log(4, "forwarding={}".format(self.forwarding[0].value)) 
+            if self.forwarding[0].value: wait_list += (self.ssc,)            # when off, no new connection will come on ssc 
+            else: self.close_cli()
+            for sc in self.srv_side.values(): wait_list += (sc,)
+        self.close_srv()
+        signal.signal(signal.SIGUSR2, sighand)
+        last = 0
+        ctr_lock.acquire()
+        active.value -= 1
+        if active.value == 0: last = 1        
+        ctr_lock.release()
+        if last: os.kill(0, signal.SIGUSR2)
+        else: signal.pause()
+        self.log(2, "ended")
+        os._exit(0)
+    def forward(self, sc):
+        if not self.payload.get(sc):
+            self.log(5, "delete srv_side[{}]".format(sc.fileno()))
+            del self.srv_side[sc]
+            self.log(4, "closing {}...".format(sc.fileno()))
+            sc.close()
+        else:
+            self.log(5, "received data from {}".format(self.payload.rport))
+            ctr_lock.acquire(); 
+            forwards.value += 1; 
+            ctr_lock.release()
+            if self.kicker:
+                self.payload.dttl()
+                self.log(3, "received from {}: {}, ttl={}".format(self.topo, self.payload.digest(), self.payload.ttl))                                       
+                if self.payload.ttl <= 0:                                                                                                                   
+                    self.log(1, "received after passing all {}: {}".\
+                             format("mashes" if self.topo == "mash" else "rings", self.payload.digest()))
+                    self.forwarding[0].value = 0                                                                                                                    
+                    return                                                                                                                     
+            nxt = self.next_node()
+            self.log(5, "forwarding to {}...".format(nxt))
+            if pacing: time.sleep(pace)                                                                                                                                                                                                                                                                                       
+            if not nxt in self.cli_side: self.conn(nxt)
+            self.payload.put(self.cli_side[nxt])
+            self.log(5, "forwarded to {}".format(nxt))
+    def next_node(self):
+        if self.topo == "ring":                                                                                                                                 
+            if self.kicker: nxt = self.p0                                                                                                                      
+            else: nxt = self.locPort + 1                                                                                                                       
+        else:                                                                                                                                                   
+            nxt = self.locPort                                                                                                                                 
+            while nxt == self.locPort:                                                                                                                        
+                nxt = random.randint(self.p0, self.pn)                                                                                                         
+        return nxt        
+    def bind(self):
+            self.log(4, "binding...")
+            try:
+                self.ssc = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+                self.ssc.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+            except Exception as e: 
+                self.abend("ssc alloc: {}".format(e.strerror))
+            if self.issl:
+                self.log(4, "ssc SSL wrap")
+                try: self.ssc = self.ctx.wrap_socket(self.ssc)
+                except ssl.SSLError as e: self.abend("ssc SSL wrap: {}".format(str(e)))
+            try:
+                self.ssc.bind(("127.0.0.1", self.locPort))
+                self.ssc.listen(1)
+            except Exception as e: self.abend("bind: {}".format(e.strerror))
+            self.log(2, "bound")
+    def conn(self, remPort):
+        self.log(4, "connecting to {}...".format(remPort))
+        try: sc = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        except Exception as e: self.abend("socket alloc: {}".format(e.strerror))
+        if self.issl:
+            self.log(4, "sc SSL wrap")
+            try: sc = self.ctx.wrap_socket(sc)
+            except ssl.SSLError as e: self.abend("sc SSL wrap: {}".format(str(e)))
+        retry = connThreshold
+        connected = False
+        while not connected and retry > 0:
+            try:
+                sc.connect(("127.0.0.1", remPort))
+                connected = True
+            except Exception as e:
+                if e.errno == errno.ECONNREFUSED:
+                    retry = retry - 1
+                    time.sleep(conn_TO)
+                else: self.abend("connect: {}".format(str(e)))
+        if retry == 0: self.abend("connection refused, threshold {} reached".format(connThreshold))
+        ctr_lock.acquire();
+        connects.value += 1;
+        ctr_lock.release()
+        try: self.cli_side[remPort] = sc.makefile("wb")
+        except Exception as e: self.abend("client side makefile: {}".format(str(e)))
+        self.log(2, "connected to {} after {} retries".format(remPort, connThreshold - retry))
+    def acc(self):
+        self.log(4, "accepting...")
+        try:
+            ac = self.ssc.accept()
+            sc = ac[0]
+        except Exception as e: self.abend("accept: {}".format(str(e)))
+        try: sc = sc.makefile("rb")
+        except Exception as e: self.abend("srv side makefile: {}".format(str(e)))
+        self.srv_side[sc] = sc
+        self.log(2, "accepted on sc={}, addr={}".format(sc.fileno(), ac[1]))
+        return sc
+    def close_srv(self):
+        self.log(5, "closing ssc...")
+        if self.ssc: self.ssc.close()
+    def close_cli(self):
+        def do_close():
+            self.log(5, "closing clients...")
+            scs = self.cli_side.values()
+            self.cli_side.clear()
+            for sc in scs: sc.close()
+            self.log(4, "all clients closed")
+        if not self.closing:
+            threading.Thread(target = do_close, name = "client {} close".format(self.locPort)).start()
+            self.closing = True
+
+
+class Data(Debug):
+    def __init__(self, debid, port):
+        self.debid = debid + " payload"
+        self.clear()
+        self.ttl = ittl
+        self.lport = port
+        self.rport = 0
+        self.text = ""
+    def clear(self):
+        (self.ttl, self.rport, self.text) = 3 * (None,)
+    def put(self, sc):
+        self.log(5, "sending via {}...".format(sc.fileno()))
+        if self.ttl == None: self.ttl = ittl
+        if self.text == None: self.text = itext
+        try:
+            pickle.dump((self.ttl, self.lport, self.text), sc)
+            sc.flush()
+        except Exception as e: self.abend("send: {}".format(str(e)))
+    def get(self, sc):
+        self.log(5, "reading from {}...".format(sc.fileno()))
+        self.clear()
+        try:
+            (self.ttl, self.rport, self.text) = pickle.load(sc)
+            return True
+        except Exception as e:
+            if isinstance(e, EOFError): return False
+            else: self.abend("receive: {}".format(str(e)))
+    def dttl(self):
+        self.ttl -= 1
+        return self.ttl
+    def digest(self):
+        return self.text if len(self.text) < 24 else self.text[0:8]+"--------"+self.text[-8:]
+    def toString(self):
+        return "ttl={}, from port={}, text={}".\
+            format(str(self.ttl), str(self.rport), self.digest())
+
+class Constellation(Debug):
+    def __init__(self, issl, topo, p0, n):
+        Debug.__init__(self, "{}SSL {}".format("" if issl else "non", topo.upper()))
+    def run(self, issl, topo, p0, n):
+        signal.signal(signal.SIGUSR2, signal.SIG_IGN)
+        forwarding = [multiprocessing.Value('i', 1, lock=False)]                    # list is passed by reference
+        if n == 1:
+            self.log(0, "one-node configuration is not implemented")
+        else:
+            self.log(1, "{} nodes starting...".format(n))
+            p0 += 500 if issl else 0
+            pn = p0 + n - 1
+            for port in range(p0, p0 + n):
+                pid = os.fork()
+                if not pid: Node(self.debid, forwarding, topo, port, p0, pn, issl).run()
+                else: self.log(4, "node {} started in process {}".format(port, pid))
+        self.log(2, "all nodes established")
+        while 1:
+            try:
+                p = os.wait()
+                if p[1] & 255:
+                    self.log(4, "pid {} killed by {}".format(p[0], p[1] & 255))
+                else:
+                    self.log(4, "pid {} returned  {}".format(p[0], p[1] >> 8))
+            except: break
+        os._exit(0)
+
+def ga(key, default):
+    return os.environ[key] if key in os.environ else default
+def gi(key, default):
+    return int(ga(key, default))
+def sighand(signal, frame):
+    pass
+
+debug = Debug("client/server demo")
+log_lock = multiprocessing.Lock()
+ctr_lock = multiprocessing.Lock()
+forwards = multiprocessing.Value('i', 0)
+connects = multiprocessing.Value('i', 0)
+active = multiprocessing.Value('i', 0)
+t0 = time.time()
+maxDebLev = gi('DEB', 0)
+mn = rn = gi('N', 0)
+rp0 = gi('RP0', 11000)
+rn = gi('RN', 3)
+mp0 = gi('MP0', 12000)
+mn = gi('MN', 3)
+itext = ga('T', "bla bla")
+ittl = gi('TTL', 3)
+pace = float(os.environ["P"]) if "P" in os.environ else 0
+pacing = 1 if pace > 0 else 0
+random.seed(gi('RS', 0))
+connThreshold = 77
+conn_TO = 0.01
+sel_TO = 1
+issl = gi('SSL', 0)
+caPath = "/home/local/etc/ssl/certs/"
+sslPathSuff = "/../CS/"
+cePath = os.environ["CEP"] if "CEP" in os.environ else os.path.dirname(sys.argv[0]) + sslPathSuff
+active.value = mn + rn
+if issl > 1: active.value *= 2
+signal.signal(signal.SIGUSR2, signal.SIG_IGN)
+debug.log(1, "pgm={}, ttl={}, pace={}, seed={}, SSL mask={}, debug={}".format(sys.argv[0], ittl, pace, gi('RS', 0), issl, maxDebLev))
+if issl > 0: debug.log(3, "ssl path: {}, CA path: {}".format(cePath, caPath))
+          
+if issl < 2: issl = (issl,)
+else: issl = (0, 1)
+if ittl > 0:
+    for ss in issl:
+        topo = ("mash", "ring")
+        p0 = (mp0, rp0)
+        n = (mn, rn)
+        for p in zip(topo, p0, n): 
+            if not os.fork(): 
+                Constellation(ss, *p).run(ss, *p)
+    while 1:
+        try: os.wait()
+        except: break
+debug.log(1, "final balance: forwards={}, connections={}".format(forwards.value, connects.value))