diff --git a/fstream-server.py b/fstream-server.py index 8f23c32..412d7b0 100644 --- a/fstream-server.py +++ b/fstream-server.py @@ -26,14 +26,18 @@ import configparser outThreads = {} inThreads = {} +connections = {} threadId = 0 threadsLock = threading.Lock() fileLock = threading.Lock() +connectionsLock = threading.Lock() serverAddr = ("127.0.0.1",12000) serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) bufferSize = 1000 # Buffer size in bytes +maxClients = 100 # How many clients can be connected at maximum? +maxClientsPerIP = 5 # How many clients can be connected at maximum, per IP? maxAccumulatedData = 20*1000*1000 # How much data can be in an outbound thread's queue at maximum before the connection is closed? def commandToList(cmd): @@ -69,21 +73,32 @@ def commandToList(cmd): return args class outThread(threading.Thread): - def __init__(self,threadId,connection,address,user): + def __init__(self,threadId,user): threading.Thread.__init__(self) self.threadId = threadId self.queue = queue.Queue() - self.connection = connection - self.address = address self.user = user self.ignore = False + def closeConnection(self): + with connectionsLock: + if str(self.threadId) in connections: + try: + connections[str(self.threadId)][0].close() + except: + print("warning, closing connection failed") + del connections[str(self.threadId)] + + def getConnection(self): + with connectionsLock: + if str(self.threadId) in connections: + return connections[str(self.threadId)] + + return False + def closeThread(self): with threadsLock: - try: - self.connection.close() - except Exception as e: - print("closing a connection failed: " +str(e)) + self.closeConnection() del outThreads[str(self.threadId)] def run(self): @@ -94,18 +109,30 @@ class outThread(threading.Thread): data[0](*data[1],**data[2]) if data[0] == self.closeThread: return continue - self.connection.sendall(data) + self.getConnection()[0].sendall(data) self.closeThread() except: self.closeThread() raise class inThread(threading.Thread): - def __init__(self,threadId,connection,address): + def __init__(self,threadId): threading.Thread.__init__(self) self.threadId = threadId - self.connection = connection - self.address = address + + def closeConnection(self): + with connectionsLock: + if str(self.threadId) in connections: + try: + connections[str(self.threadId)][0].close() + except: + print("warning, closing connection failed") + del connections[str(self.threadId)] + + def getConnection(self): + with connectionsLock: + if str(self.threadId) in connections: + return connections[str(self.threadId)] def closeThread(self,closeConnection = True): with threadsLock: @@ -113,24 +140,16 @@ class inThread(threading.Thread): for thread in outThreads: thread = outThreads[thread] if thread.user == self.user: - try: - thread.queue.put((thread.closeThread,[],{})) - thread.connection.close() - except Exception as e: - print("closing a connection failed: " +str(e)) - pass + thread.queue.put((thread.closeThread,[],{})) - try: - self.connection.close() - except Exception as e: - print("closing a connection failed: " +str(e)) - pass + self.closeConnection() + del inThreads[str(self.threadId)] def run(self): try: global threadId - data = self.connection.recv(1000) + data = self.getConnection()[0].recv(1000) if data == b"": self.closeThread() return @@ -160,7 +179,7 @@ class inThread(threading.Thread): return with threadsLock: - thread = outThread(self.threadId,self.connection,self.address,self.user) + thread = outThread(self.threadId,self.user) outThreads[str(self.threadId)] = thread thread.start() @@ -173,7 +192,7 @@ class inThread(threading.Thread): return while True: - data = self.connection.recv(bufferSize) + data = self.getConnection()[0].recv(bufferSize) if data == b"": self.closeThread() return @@ -202,13 +221,31 @@ class debugThread(threading.Thread): def run(self): while True: with threadsLock: - print("---") + print("\n---\n") print("Threads - IN: " +str(len(inThreads))) print("Threads - OUT: " +str(len(outThreads))) - print("ACCUMULATED DATA:") + print("\nACCUMULATED DATA:") for threadId in outThreads: thread = outThreads[threadId] print(threadId + ": " + str(thread.queue.qsize() * bufferSize)) + + print("\nCONNECTIONS:") + connCount = 0 + connCountIp = {} + with connectionsLock: + for connId in connections: + conn = connections[connId] + ip = conn[1][0] + if not ip in connCountIp: + connCountIp[ip] = 0 + connCountIp[ip] += 1 + connCount += 1 + + for ip in connCountIp: + print(ip+ ": " +str(connCountIp[ip])) + + print("Overall: " +str(connCount)) + time.sleep(1) def main(): @@ -226,10 +263,26 @@ def main(): while True: connection, address = serverSocket.accept() connection.settimeout(15) + + with connectionsLock: + clientCount = 0 + ipClientCount = 0 + for connId in connections: + clientCount += 1 + conn = connections[connId] + if conn[1][0] == address[0]: + ipClientCount += 1 + + if clientCount + 1 > maxClients or ipClientCount + 1 > maxClientsPerIP: + connection.close() + continue + with threadsLock: threadId += 1 + with connectionsLock: + connections[str(threadId)] = (connection,address) - thread = inThread(threadId,connection,address) + thread = inThread(threadId) inThreads[str(threadId)] = thread thread.start()