#!/usr/bin/python3
# coding=utf8
# Measured throughput (approx.):
#   ring, non-SSL: approx. 8000 rings/sec * 3 nodes
#   ring, SSL:     approx. 2600 rings/sec * 3 nodes
#   mash, non-SSL: approx. 2000 mashes/sec * 3 nodes
#   mash, SSL:     approx. 300 mashes/sec * 3 nodes
# For the mash, transfer time is proportional to the number of nodes,
# whereas connect/close time is proportional to the number of connections,
# i.e. to the square of the number of nodes.
import time
import socket
import errno
import os
import signal
import sys
import pickle
import ssl
import select
import random
import multiprocessing
import threading
class Debug:
    """Base class providing level-filtered debug logging to stderr.

    Output from all processes is serialized through the module-level
    multiprocessing lock ``log_lock``.
    """

    def __init__(self, debid):
        """Remember *debid*, the prefix printed in front of every message."""
        self.debid = debid

    def log(self, level, *msg):
        """Print *msg* with elapsed time and debid when level <= maxDebLev."""
        if level <= maxDebLev:
            # `with` releases the cross-process lock even if writing to stderr
            # fails; a leaked lock would block logging in every other process.
            with log_lock:
                print("{:10.6f} {}:".format(time.time()-t0, self.debid), *msg, file=sys.stderr)
                sys.stderr.flush()

    def abend(self, s):
        """Log *s* at level 0 and terminate the whole process group.

        Never returns: callers continue as if the failed operation had
        produced a value, so returning would only cause follow-up errors.
        """
        self.log(0, s)
        os.kill(0, signal.SIGTERM)
        # Fallback in case SIGTERM did not terminate this process.
        os._exit(1)
class Node(Debug):
    """One peer of a constellation, run in its own forked process.

    It listens on its own port, accepts inbound connections and forwards the
    circulating payload to the next node. The node whose port equals ``pn`` is
    the "kicker": it injects the payload, decrements its TTL each time the
    payload comes back, and clears the shared ``forwarding`` flag once the
    TTL reaches 0.
    """

    def __init__(self, debid, forwarding, topo, port, p0, pn, issl):
        """Prepare node state; no sockets are opened here.

        debid      -- prefix for debug messages
        forwarding -- one-element list holding a shared multiprocessing.Value,
                      non-zero while the payload is still circulating
        topo       -- "ring" selects the ring successor; anything else
                      ("mash") picks a random other node
        port       -- local listening port
        p0, pn     -- first and last port of the constellation
        issl       -- truthy to wrap every socket in TLS
        """
        Debug.__init__(self, "{} node {}".format(debid, port))
        self.topo = topo
        self.locPort = port
        self.p0 = p0
        self.pn = pn
        self.ssc = None        # listening socket
        self.cli_side = {}     # remote port -> writable file of our outbound connection
        self.srv_side = {}     # readable file of an accepted connection -> itself
        self.kicker = (self.locPort == self.pn)
        self.forwarding = forwarding  # in forwarding kicker task indicates when TTL reached 0
        self.closing = False
        self.payload = None
        self.issl = issl
        if issl:
            self.log(4, "setting SSL context...")
            self.sslCert = cePath + "/certs/{}.pem".format(port)
            self.sslKey = cePath + "/keys/{}.key".format(port)
            try:
                self.ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
                self.ctx.verify_mode = ssl.CERT_REQUIRED
                self.log(5, "cert={}, key={}".format(self.sslCert, self.sslKey))
                self.ctx.load_cert_chain(self.sslCert, self.sslKey)
                self.ctx.load_verify_locations(None, caPath)
            except OSError as e:
                # SSLError is an OSError; a missing cert/key file raises a plain
                # FileNotFoundError, which previously escaped this handler.
                self.abend("SSL context: {}".format(str(e)))

    def run(self):
        """Serve until the payload has finished circulating, wait until every
        node of every constellation is done, then exit the process."""
        # Install a no-op handler (so SIGUSR2 is not left ignored) and block
        # SIGUSR2 before close_cli() can start a thread, so every thread
        # inherits the mask. The final "all nodes done" signal then stays
        # pending until sigwait() takes it. With the former handler + pause(),
        # a signal arriving between the counter update and pause() was
        # consumed by the handler and the process paused forever.
        signal.signal(signal.SIGUSR2, sighand)
        signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGUSR2})
        self.bind()
        self.payload = Data(self.debid, self.locPort)
        if self.kicker:
            nxt = self.next_node()
            self.log(1, "kicker ready to send '{}' to {}".format(self.payload.digest(), nxt))
            if nxt not in self.cli_side:
                self.conn(nxt)
            self.payload.put(self.cli_side[nxt])
        self._serve()
        self.close_srv()
        self._wait_for_all_nodes()
        self.log(2, "ended")
        os._exit(0)

    def _serve(self):
        """select() loop: accept and forward until forwarding is off and every
        accepted connection has reached EOF."""
        # Without the fallback the name was unbound when forwarding had already
        # stopped before this node got here.
        wait_list = (self.ssc,) if self.forwarding[0].value else ()
        while wait_list:
            self.log(4, "select...")
            self.log(5, "waitlist:", *(sc.fileno() for sc in wait_list))
            readable = select.select(wait_list, (), (), sel_TO)[0]
            self.log(5, "readylist:", *(sc.fileno() for sc in readable))
            for sc in readable:
                self.log(5, "sc {} ready...".format(sc.fileno()))
                if sc is self.ssc:
                    sc = self.acc()
                self.forward(sc)
            self.log(4, "forwarding={}".format(self.forwarding[0].value))
            if self.forwarding[0].value:
                wait_list = (self.ssc,)
            else:
                # when off, no new connection will come on ssc
                wait_list = ()
                self.close_cli()
            wait_list += tuple(self.srv_side.values())

    def _wait_for_all_nodes(self):
        """Decrement the global count of active nodes; the last one signals
        the process group, all others block until that signal arrives."""
        with ctr_lock:
            active.value -= 1
            last = active.value == 0
        if last:
            os.kill(0, signal.SIGUSR2)
        else:
            # SIGUSR2 is blocked (see run), so it waits here until taken.
            signal.sigwait({signal.SIGUSR2})

    def forward(self, sc):
        """Read one payload from the accepted connection *sc* and pass it on.

        On EOF the connection is removed from srv_side and closed. The kicker
        decrements the TTL of every payload it receives and, once it is <= 0,
        clears the shared forwarding flag instead of forwarding again.
        """
        if not self.payload.get(sc):
            self.log(5, "delete srv_side[{}]".format(sc.fileno()))
            del self.srv_side[sc]
            self.log(4, "closing {}...".format(sc.fileno()))
            sc.close()
            return
        self.log(5, "received data from {}".format(self.payload.rport))
        with ctr_lock:
            forwards.value += 1
        if self.kicker:
            self.payload.dttl()
            self.log(3, "received from {}: {}, ttl={}".format(self.topo, self.payload.digest(), self.payload.ttl))
            if self.payload.ttl <= 0:
                self.log(1, "received after passing all {}: {}".format(
                    "mashes" if self.topo == "mash" else "rings", self.payload.digest()))
                self.forwarding[0].value = 0
                return
        nxt = self.next_node()
        self.log(5, "forwarding to {}...".format(nxt))
        if pacing:
            time.sleep(pace)
        if nxt not in self.cli_side:
            self.conn(nxt)
        self.payload.put(self.cli_side[nxt])
        self.log(5, "forwarded to {}".format(nxt))

    def next_node(self):
        """Return the port to forward to.

        Ring: the next port, with the kicker wrapping around to p0.
        Other topologies: a random port in [p0, pn] other than our own.
        """
        if self.topo == "ring":
            return self.p0 if self.kicker else self.locPort + 1
        nxt = self.locPort
        while nxt == self.locPort:
            nxt = random.randint(self.p0, self.pn)
        return nxt

    def bind(self):
        """Create the listening socket on 127.0.0.1:locPort (TLS-wrapped if
        enabled); aborts on any error."""
        self.log(4, "binding...")
        try:
            self.ssc = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.ssc.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        except OSError as e:
            self.abend("ssc alloc: {}".format(e.strerror))
        if self.issl:
            self.log(4, "ssc SSL wrap")
            try:
                self.ssc = self.ctx.wrap_socket(self.ssc)
            except ssl.SSLError as e:
                self.abend("ssc SSL wrap: {}".format(str(e)))
        try:
            self.ssc.bind(("127.0.0.1", self.locPort))
            self.ssc.listen(1)
        except OSError as e:
            self.abend("bind: {}".format(e.strerror))
        self.log(2, "bound")

    def _new_client_socket(self):
        """Return a fresh TCP socket, TLS-wrapped when SSL is enabled."""
        try:
            sc = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        except OSError as e:
            self.abend("socket alloc: {}".format(e.strerror))
        if self.issl:
            self.log(4, "sc SSL wrap")
            try:
                sc = self.ctx.wrap_socket(sc)
            except ssl.SSLError as e:
                self.abend("sc SSL wrap: {}".format(str(e)))
        return sc

    def conn(self, remPort):
        """Connect to 127.0.0.1:remPort and store a writable file in cli_side.

        ECONNREFUSED (peer not listening yet) is retried up to connThreshold
        times, conn_TO seconds apart; any other error aborts.
        """
        self.log(4, "connecting to {}...".format(remPort))
        retry = connThreshold
        while retry > 0:
            # POSIX leaves a socket's state unspecified after a failed
            # connect(), so every attempt uses a brand-new socket.
            sc = self._new_client_socket()
            try:
                sc.connect(("127.0.0.1", remPort))
                break
            except OSError as e:
                sc.close()
                if e.errno != errno.ECONNREFUSED:
                    self.abend("connect: {}".format(str(e)))
                retry -= 1
                time.sleep(conn_TO)
        if retry == 0:
            self.abend("connection refused, threshold {} reached".format(connThreshold))
        with ctr_lock:
            connects.value += 1
        try:
            self.cli_side[remPort] = sc.makefile("wb")
        except Exception as e:
            self.abend("client side makefile: {}".format(str(e)))
        # The descriptor stays open for the file object; closing the socket
        # here makes closing that file actually close the connection.
        sc.close()
        self.log(2, "connected to {} after {} retries".format(remPort, connThreshold - retry))

    def acc(self):
        """Accept one inbound connection, register its readable file in
        srv_side and return that file."""
        self.log(4, "accepting...")
        try:
            raw, addr = self.ssc.accept()
        except Exception as e:
            self.abend("accept: {}".format(str(e)))
        try:
            sc = raw.makefile("rb")
        except Exception as e:
            self.abend("srv side makefile: {}".format(str(e)))
        # Let closing the file object (see forward) close the descriptor.
        raw.close()
        self.srv_side[sc] = sc
        self.log(2, "accepted on sc={}, addr={}".format(sc.fileno(), addr))
        return sc

    def close_srv(self):
        """Close the listening socket, if one was created."""
        self.log(5, "closing ssc...")
        if self.ssc is not None:
            self.ssc.close()

    def close_cli(self):
        """Close all outbound connections once, in a background thread."""
        def do_close():
            self.log(5, "closing clients...")
            # Snapshot the values: a dict view would be emptied by clear()
            # and the loop below would close nothing.
            scs = list(self.cli_side.values())
            self.cli_side.clear()
            for sc in scs:
                sc.close()
            self.log(4, "all clients closed")
        if not self.closing:
            threading.Thread(target=do_close, name="client {} close".format(self.locPort)).start()
            self.closing = True
class Data(Debug):
    """The circulating payload: TTL, sender port and text.

    Each message is one pickled (ttl, port, text) tuple written to or read
    from a connection's file object.
    """

    def __init__(self, debid, port):
        """Create the payload owned by the node listening on *port*."""
        self.debid = debid + " payload"
        self.clear()
        self.ttl = ittl
        self.lport = port  # local port, sent as the sender of every put()
        self.rport = 0     # sender port of the most recent get()
        # Start from the configured text (T). This used to be "", and since
        # put() only substitutes itext for None, T was never transmitted.
        self.text = itext

    def clear(self):
        """Reset ttl, rport and text to None."""
        self.ttl = self.rport = self.text = None

    def put(self, sc):
        """Pickle (ttl, local port, text) to the writable file *sc* and flush.

        A None ttl/text falls back to the configured ittl/itext. Aborts on error.
        """
        self.log(5, "sending via {}...".format(sc.fileno()))
        if self.ttl is None:
            self.ttl = ittl
        if self.text is None:
            self.text = itext
        try:
            pickle.dump((self.ttl, self.lport, self.text), sc)
            sc.flush()
        except Exception as e:
            self.abend("send: {}".format(str(e)))

    def get(self, sc):
        """Read one message from *sc* into ttl/rport/text.

        Returns True on success and False at end of stream; any other error
        aborts.
        NOTE(review): pickle.load executes code chosen by the sender, so this
        is only acceptable while every peer is trusted (the nodes in this file
        bind and connect on 127.0.0.1).
        """
        self.log(5, "reading from {}...".format(sc.fileno()))
        self.clear()
        try:
            (self.ttl, self.rport, self.text) = pickle.load(sc)
        except EOFError:
            return False
        except Exception as e:
            self.abend("receive: {}".format(str(e)))
        else:
            return True

    def dttl(self):
        """Decrement the TTL and return the new value."""
        self.ttl -= 1
        return self.ttl

    def digest(self):
        """Return the text, abbreviated to its first and last 8 characters
        when it is 24 characters or longer."""
        if len(self.text) < 24:
            return self.text
        return self.text[0:8] + "--------" + self.text[-8:]

    def toString(self):
        """Return a one-line human-readable description of the payload."""
        return "ttl={}, from port={}, text={}".format(str(self.ttl), str(self.rport), self.digest())
class Constellation(Debug):
    """Parent process of one (SSL mode, topology) run: forks one Node process
    per port in [p0, p0 + n) and reaps them."""

    def __init__(self, issl, topo, p0, n):
        """Derive only the debug id; run() receives the parameters again."""
        Debug.__init__(self, "{}SSL {}".format("" if issl else "non", topo.upper()))

    def run(self, issl, topo, p0, n):
        """Start the nodes, wait for them all, then exit this process.

        issl -- truthy for TLS; the port range is then shifted by +500
        topo -- topology name handed to every Node
        p0   -- first port before the SSL shift
        n    -- number of nodes
        """
        signal.signal(signal.SIGUSR2, signal.SIG_IGN)
        # Shared-memory flag read by all forked nodes; the kicker clears it
        # once the payload's TTL is used up.
        forwarding = [multiprocessing.Value('i', 1, lock=False)]
        if n == 1:
            self.log(0, "one-node configuration is not implemented")
            # The main script counted this node in `active`. Release its slot,
            # otherwise the count never reaches 0 and every other node waits
            # forever for the final SIGUSR2.
            with ctr_lock:
                active.value -= 1
                last = active.value == 0
            if last:
                os.kill(0, signal.SIGUSR2)
        else:
            self.log(1, "{} nodes starting...".format(n))
            p0 += 500 if issl else 0
            pn = p0 + n - 1
            for port in range(p0, p0 + n):
                pid = os.fork()
                if not pid:
                    Node(self.debid, forwarding, topo, port, p0, pn, issl).run()
                else:
                    self.log(4, "node {} started in process {}".format(port, pid))
            self.log(2, "all nodes established")
            while True:
                try:
                    pid, status = os.wait()
                except ChildProcessError:  # no children left
                    break
                if os.WIFSIGNALED(status):
                    self.log(4, "pid {} killed by {}".format(pid, os.WTERMSIG(status)))
                else:
                    self.log(4, "pid {} returned {}".format(pid, os.WEXITSTATUS(status)))
        os._exit(0)
def ga(key, default):
    """Return environment variable *key* as a string, or *default* if unset."""
    return os.environ.get(key, default)
def gi(key, default):
    """Like ga(), but convert the result (environment value or default) to int."""
    raw = ga(key, default)
    return int(raw)
def sighand(signal, frame):
    """No-op signal handler; installed so SIGUSR2 neither kills nor is ignored."""
# Script body: configuration from environment variables, shared counters, and
# the fork tree  main -> one Constellation per (SSL mode, topology) -> nodes.
#
# Environment variables (defaults in parentheses):
#   DEB        highest debug level printed (0)
#   N          common default for RN and MN (3)
#   RN / MN    node count of the ring / mash constellation (N)
#   RP0 / MP0  first port of the ring / mash constellation (11000 / 12000)
#   T          payload text ("bla bla")
#   TTL        returns of the payload to the kicker; <= 0 runs nothing (3)
#   P          pause in seconds before each forward (0 = no pacing)
#   RS         random seed (0)
#   SSL        0 = non-SSL run, 1 = SSL run, > 1 = both (0)
#   CEP        directory holding certs/ and keys/ (<script dir>/../CS/)
debug = Debug("client/server demo")
log_lock = multiprocessing.Lock()
ctr_lock = multiprocessing.Lock()
forwards = multiprocessing.Value('i', 0)  # payloads received, summed over all nodes
connects = multiprocessing.Value('i', 0)  # outbound connections established
active = multiprocessing.Value('i', 0)    # nodes that have not finished yet
t0 = time.time()
maxDebLev = gi('DEB', 0)
# N used to be assigned to mn/rn and then unconditionally overwritten by
# MN/RN, i.e. silently ignored; it now provides their common default.
defNodes = gi('N', 3)
rp0 = gi('RP0', 11000)
rn = gi('RN', defNodes)
mp0 = gi('MP0', 12000)
mn = gi('MN', defNodes)
itext = ga('T', "bla bla")
ittl = gi('TTL', 3)
pace = float(os.environ["P"]) if "P" in os.environ else 0
pacing = 1 if pace > 0 else 0
rseed = gi('RS', 0)
random.seed(rseed)
connThreshold = 77  # connect attempts allowed while the peer refuses
conn_TO = 0.01      # seconds between connect attempts
sel_TO = 1          # select() timeout in seconds
issl = gi('SSL', 0)
caPath = "/home/local/etc/ssl/certs/"
sslPathSuff = "/../CS/"
cePath = ga("CEP", os.path.dirname(sys.argv[0]) + sslPathSuff)
active.value = mn + rn
if issl > 1:
    active.value *= 2  # both the non-SSL and the SSL runs are started
signal.signal(signal.SIGUSR2, signal.SIG_IGN)
debug.log(1, "pgm={}, ttl={}, pace={}, seed={}, SSL mask={}, debug={}".format(sys.argv[0], ittl, pace, rseed, issl, maxDebLev))
if issl > 0:
    debug.log(3, "ssl path: {}, CA path: {}".format(cePath, caPath))
issl = (issl,) if issl < 2 else (0, 1)
if ittl > 0:
    topo = ("mash", "ring")
    p0 = (mp0, rp0)
    n = (mn, rn)
    for ss in issl:
        for p in zip(topo, p0, n):
            if not os.fork():
                Constellation(ss, *p).run(ss, *p)
# Reap the Constellation processes; os.wait raises ChildProcessError when none remain.
while True:
    try:
        os.wait()
    except ChildProcessError:
        break
debug.log(1, "final balance: forwards={}, connections={}".format(forwards.value, connects.value))