Rewrite (parrot-server)

This commit is contained in:
Fierelier 2021-05-23 13:04:49 +02:00
parent 0d7d59d7fe
commit 46a32f2255
2 changed files with 160 additions and 201 deletions

View File

@ -19,237 +19,196 @@ sp = pUp(s)
# script start # script start
import threading import threading
import queue
import socket import socket
import struct import traceback
import time import time
import colorama
colorama.init()
addr = ("127.0.0.1",21779) maxConnections = 10000
maxConnectionsPerIp = 10
maxQueueSize = 1000
maxRequestSize = 4096
pauseBetweenCommands = 0.1
threads = {} serverAddr = ("127.0.0.1",21779)
threadId = 0 serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
threadsLock = threading.Lock()
close = False
eventHandlers = {} connectionsLock = threading.Lock()
eventHandlersLock = threading.Lock() connections = {}
connectionsId = 0
def runCode(str, lcs = False, description = "loose-code"): heartbeatTime = 600
if lcs == False: lcs = {}
code = compile(str,description,"exec")
exec(code,globals(),lcs)
return lcs
def runScript(sf, lcs = False): threadCount = 0
if lcs == False: lcs = {} threadCountLock = threading.Lock()
with open(sf) as script:
runCode(script.read(),lcs,sf)
return lcs
def getModlist(path): printLock = threading.Lock()
modList = [] def tprint(st):
for root,dirs,files in os.walk(path): with printLock:
for file in dirs: print(st)
ffile = p(root,file)
lfile = ffile.replace(path + os.path.sep,"",1)
if lfile[0] == "-": continue
if lfile[0] == "[" and lfile[-1] == "]":
modList = modList + sorted(getModlist(ffile))
continue
modList.append(ffile)
break
return modList
def triggerEvent(event,*args,**kwargs): def addThread():
with eventHandlersLock: global threadCount
handlers = eventHandlers.copy() with threadCountLock:
threadCount += 1
if not event in handlers: return tprint(colorama.Fore.YELLOW + colorama.Style.BRIGHT + "Thread opened. Threads: " +str(threadCount)+ " (Actual: " +str(threading.active_count())+ ")" + colorama.Style.RESET_ALL)
for func in handlers[event]:
cancel = func(event,*args,**kwargs)
if cancel: return True
return False
def addEventHandler(event,func): def removeThread():
with eventHandlersLock: global threadCount
if not event in eventHandlers: eventHandlers[event] = [] with threadCountLock:
eventHandlers[event].append(func) threadCount -= 1
tprint(colorama.Fore.YELLOW + colorama.Style.BRIGHT + "Thread closed. Threads: " +str(threadCount)+ " (Actual: " +str(threading.active_count())+ ")" + colorama.Style.RESET_ALL)
def sendResponse(connection,data): def sendResponse(connection,data):
connection.sendall(len(data).to_bytes(4,"big") + data) connection.sendall(len(data).to_bytes(4,"big") + data)
senderThreadSleepMin = 0.0333 def getResponse(connection):
senderThreadSleepMax = 1.0 data = b''
senderThreadSleepIncr = 0.01 data = connection.recv(4)
if not data: return False
requestLength = int.from_bytes(data,"big")
if requestLength > maxRequestSize: raise Exception("security","request_too_large")
return connection.recv(requestLength)
class senderThread(threading.Thread): def closeConnection(connectionId):
def __init__(self,connectionThread): if not connectionId in connections: return False
try:
connections[connectionId]["connection"].close()
except Exception as e:
tprint("Failed to close connection: " +str(e))
pass
try:
connections[connectionId]["threadOut"].queue.put(False)
except:
with printLock:
print(colorama.Fore.GREEN + colorama.Style.BRIGHT)
traceback.print_exc()
print(colorama.Style.RESET_ALL)
del connections[connectionId]
return True
class connectionThreadOut(threading.Thread):
def __init__(self,connectionId):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.lock = threading.Lock() self.queue = queue.Queue()
with self.lock: self.connectionId = connectionId
self.connectionThread = connectionThread
self.queue = []
self.newQueue = False
self.sleep = senderThreadSleepMin
def closeThread(self): def getConnection(self):
with self.lock: with connectionsLock:
self.queue = [["close"]] if self.connectionId in connections:
self.newQueue = True return connections[self.connectionId]["connection"]
return False
def addToQueue(self,entry):
with self.lock:
self.queue.append(entry)
self.newQueue = True
def run(self): def run(self):
sleepTime = 0 try:
while True: while True:
with self.lock: data = self.queue.get(timeout=heartbeatTime)
sleepTime = self.sleep if data == False: return
#print(sleepTime)
time.sleep(sleepTime)
with self.lock:
if not self.newQueue:
if self.sleep < senderThreadSleepMax:
self.sleep += senderThreadSleepIncr * self.sleep
if self.sleep > senderThreadSleepMax: self.sleep = senderThreadSleepMax
continue
for entry in self.queue: connection = self.getConnection()
if entry[0] == "close": return if not connection:
entry[0](*entry[1],**entry[2]) with connectionsLock: closeConnection(self.connectionId)
self.queue = []
self.newQueue = False
self.sleep = senderThreadSleepMin
class connectionThread(threading.Thread):
def __init__(self,threadId,connection,address):
threading.Thread.__init__(self)
self.lock = threading.Lock()
with self.lock:
self.threadId = threadId
self.connection = connection
self.address = address
self.closed = False
self.user = False
self.senderThread = senderThread(self)
self.senderThread.start()
def closeThread(self):
with self.lock, threadsLock:
self.senderThread.closeThread()
try:
self.connection.close()
except:
print("failed to close connection, ignoring.")
pass
del threads[str(self.threadId)]
print("thread closed: " +str(self.threadId)+ " (open: " +str(len(threads))+ ")")
self.closed = True
def sendResponse(self,data,lock = True):
if lock == True:
with self.lock:
self.senderThread.addToQueue([sendResponse,[self.connection,data],{}])
else:
self.senderThread.addToQueue([sendResponse,[self.connection,data],{}])
def run(self):
with self.lock:
print("thread opened: " +", ".join((str(self.threadId),str(self.address))))
while True:
try:
# get request length
data = b''
data = self.connection.recv(4)
if not data:
self.closeThread()
return return
requestLength = int.from_bytes(data,"big") sendResponse(connection,data)
except Exception as e:
# inform about request with connectionsLock: closeConnection(self.connectionId)
cancel = triggerEvent("onPreRequest",self,requestLength) with printLock:
with self.lock: print(colorama.Fore.GREEN + colorama.Style.BRIGHT)
if self.closed: traceback.print_exc()
return print(colorama.Style.RESET_ALL)
if cancel: continue finally:
removeThread()
# process request
cancel = triggerEvent("onRequest",self,requestLength)
with self.lock:
if self.closed:
return
if cancel: continue
except Exception as e:
cancel = False
try:
cancel = triggerEvent("onException",self,e)
except:
self.closeThread()
raise
if cancel: continue
self.closeThread()
raise e
modulesLoaded = [] class connectionThreadIn(threading.Thread):
modulePath = p(sp,"modules") def __init__(self,connectionId):
def moduleRun(localModule): threading.Thread.__init__(self)
if not localModule in modulesLoaded: modulesLoaded.append(localModule) self.connectionId = connectionId
print("> " +localModule+ "...")
runScript(p(modulePath,localModule,"module.py"))
def moduleDepends(localModules):
if type(localModules) == str: localModules = [localModules]
for localModule in localModules: def getConnection(self):
if localModule in modulesLoaded: return with connectionsLock:
print("depend ",end="") if self.connectionId in connections:
moduleRun(localModule) return connections[self.connectionId]["connection"]
return False
def run(self):
try:
while True:
connection = self.getConnection()
if not connection:
with connectionsLock: closeConnection(self.connectionId)
return
data = getResponse(connection)
if data == False:
with connectionsLock: closeConnection(self.connectionId)
return
with connectionsLock:
if self.connectionId in connections:
queue = connections[self.connectionId]["threadOut"].queue
if queue.qsize() >= maxQueueSize:
closeConnection(self.connectionId)
return
queue.put(data)
time.sleep(pauseBetweenCommands)
except Exception as e:
with connectionsLock: closeConnection(self.connectionId)
with printLock:
print(colorama.Fore.GREEN + colorama.Style.BRIGHT)
traceback.print_exc()
print(colorama.Style.RESET_ALL)
finally:
removeThread()
def main(): def main():
print("Loading modules...") global connectionsId
for path in getModlist(modulePath): serverSocket.bind(serverAddr)
if os.path.isfile(p(path,"module.py")): serverSocket.listen(65535)
localModule = path.replace(modulePath + os.path.sep,"",1)
if not localModule in modulesLoaded:
moduleRun(localModule)
print("\nServing on " +":".join(map(str,addr))+ "!")
global socketServer
socketServer = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socketServer.bind(addr)
socketServer.listen(1000)
global threadId
global close
while True: while True:
connection, address = socketServer.accept() connection,address = serverSocket.accept()
connection.settimeout(heartbeatTime)
# inform about connection with connectionsLock:
with threadsLock: # Count connections
if close: break connectionsCount = 0
cancel = triggerEvent("onConnect",connection,address) connectionsCountIp = 0
if close: break for connectionId in connections:
if cancel: continue connectionsCount += 1
if connections[connectionId]["address"][0] == address[0]:
connectionsCountIp += 1
threadId += 1 if connectionsCount >= maxConnections:
while str(threadId) in threads: tprint("Connection closed - too many clients.")
threadId += 1 closeConnection(connectionId)
continue
thread = connectionThread(threadId,connection,address) if connectionsCountIp >= maxConnectionsPerIp:
threads[str(threadId)] = thread tprint("Connection closed - same IP connected too many times.")
thread.start() closeConnection(connectionId)
continue
# Create connection
connectionsId += 1
threadIn = connectionThreadIn(str(connectionsId))
threadOut = connectionThreadOut(str(connectionsId))
connections[str(connectionsId)] = {
"connection": connection,
"address": address,
"threadOut": threadOut,
"threadIn": threadIn,
"user": False
}
addThread()
addThread()
threadOut.start()
threadIn.start()
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -49,7 +49,7 @@ def getResponse(connection):
def main(): def main():
global connection global connection
connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
connection.connect(("127.0.0.1",21779)) connection.connect((sys.argv[1],int(sys.argv[2])))
thread = receiverThread(connection) thread = receiverThread(connection)
thread.start() thread.start()