Add connection tracker and limits

This commit is contained in:
Fierelier 2021-04-14 18:28:54 +02:00
parent ecc73b49de
commit 03ae40f615

View File

@ -26,14 +26,18 @@ import configparser
outThreads = {} outThreads = {}
inThreads = {} inThreads = {}
connections = {}
threadId = 0 threadId = 0
threadsLock = threading.Lock() threadsLock = threading.Lock()
fileLock = threading.Lock() fileLock = threading.Lock()
connectionsLock = threading.Lock()
serverAddr = ("127.0.0.1",12000) serverAddr = ("127.0.0.1",12000)
serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
bufferSize = 1000 # Buffer size in bytes bufferSize = 1000 # Buffer size in bytes
maxClients = 100 # How many clients can be connected at maximum?
maxClientsPerIP = 5 # How many clients can be connected at maximum, per IP?
maxAccumulatedData = 20*1000*1000 # How much data can be in an outbound thread's queue at maximum before the connection is closed? maxAccumulatedData = 20*1000*1000 # How much data can be in an outbound thread's queue at maximum before the connection is closed?
def commandToList(cmd): def commandToList(cmd):
@ -69,21 +73,32 @@ def commandToList(cmd):
return args return args
class outThread(threading.Thread): class outThread(threading.Thread):
def __init__(self,threadId,connection,address,user): def __init__(self,threadId,user):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.threadId = threadId self.threadId = threadId
self.queue = queue.Queue() self.queue = queue.Queue()
self.connection = connection
self.address = address
self.user = user self.user = user
self.ignore = False self.ignore = False
def closeConnection(self):
with connectionsLock:
if str(self.threadId) in connections:
try:
connections[str(self.threadId)][0].close()
except:
print("warning, closing connection failed")
del connections[str(self.threadId)]
def getConnection(self):
with connectionsLock:
if str(self.threadId) in connections:
return connections[str(self.threadId)]
return False
def closeThread(self): def closeThread(self):
with threadsLock: with threadsLock:
try: self.closeConnection()
self.connection.close()
except Exception as e:
print("closing a connection failed: " +str(e))
del outThreads[str(self.threadId)] del outThreads[str(self.threadId)]
def run(self): def run(self):
@ -94,18 +109,30 @@ class outThread(threading.Thread):
data[0](*data[1],**data[2]) data[0](*data[1],**data[2])
if data[0] == self.closeThread: return if data[0] == self.closeThread: return
continue continue
self.connection.sendall(data) self.getConnection()[0].sendall(data)
self.closeThread() self.closeThread()
except: except:
self.closeThread() self.closeThread()
raise raise
class inThread(threading.Thread): class inThread(threading.Thread):
def __init__(self,threadId,connection,address): def __init__(self,threadId):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.threadId = threadId self.threadId = threadId
self.connection = connection
self.address = address def closeConnection(self):
with connectionsLock:
if str(self.threadId) in connections:
try:
connections[str(self.threadId)][0].close()
except:
print("warning, closing connection failed")
del connections[str(self.threadId)]
def getConnection(self):
with connectionsLock:
if str(self.threadId) in connections:
return connections[str(self.threadId)]
def closeThread(self,closeConnection = True): def closeThread(self,closeConnection = True):
with threadsLock: with threadsLock:
@ -113,24 +140,16 @@ class inThread(threading.Thread):
for thread in outThreads: for thread in outThreads:
thread = outThreads[thread] thread = outThreads[thread]
if thread.user == self.user: if thread.user == self.user:
try:
thread.queue.put((thread.closeThread,[],{})) thread.queue.put((thread.closeThread,[],{}))
thread.connection.close()
except Exception as e:
print("closing a connection failed: " +str(e))
pass
try: self.closeConnection()
self.connection.close()
except Exception as e:
print("closing a connection failed: " +str(e))
pass
del inThreads[str(self.threadId)] del inThreads[str(self.threadId)]
def run(self): def run(self):
try: try:
global threadId global threadId
data = self.connection.recv(1000) data = self.getConnection()[0].recv(1000)
if data == b"": if data == b"":
self.closeThread() self.closeThread()
return return
@ -160,7 +179,7 @@ class inThread(threading.Thread):
return return
with threadsLock: with threadsLock:
thread = outThread(self.threadId,self.connection,self.address,self.user) thread = outThread(self.threadId,self.user)
outThreads[str(self.threadId)] = thread outThreads[str(self.threadId)] = thread
thread.start() thread.start()
@ -173,7 +192,7 @@ class inThread(threading.Thread):
return return
while True: while True:
data = self.connection.recv(bufferSize) data = self.getConnection()[0].recv(bufferSize)
if data == b"": if data == b"":
self.closeThread() self.closeThread()
return return
@ -202,13 +221,31 @@ class debugThread(threading.Thread):
def run(self): def run(self):
while True: while True:
with threadsLock: with threadsLock:
print("---") print("\n---\n")
print("Threads - IN: " +str(len(inThreads))) print("Threads - IN: " +str(len(inThreads)))
print("Threads - OUT: " +str(len(outThreads))) print("Threads - OUT: " +str(len(outThreads)))
print("ACCUMULATED DATA:") print("\nACCUMULATED DATA:")
for threadId in outThreads: for threadId in outThreads:
thread = outThreads[threadId] thread = outThreads[threadId]
print(threadId + ": " + str(thread.queue.qsize() * bufferSize)) print(threadId + ": " + str(thread.queue.qsize() * bufferSize))
print("\nCONNECTIONS:")
connCount = 0
connCountIp = {}
with connectionsLock:
for connId in connections:
conn = connections[connId]
ip = conn[1][0]
if not ip in connCountIp:
connCountIp[ip] = 0
connCountIp[ip] += 1
connCount += 1
for ip in connCountIp:
print(ip+ ": " +str(connCountIp[ip]))
print("Overall: " +str(connCount))
time.sleep(1) time.sleep(1)
def main(): def main():
@ -226,10 +263,26 @@ def main():
while True: while True:
connection, address = serverSocket.accept() connection, address = serverSocket.accept()
connection.settimeout(15) connection.settimeout(15)
with connectionsLock:
clientCount = 0
ipClientCount = 0
for connId in connections:
clientCount += 1
conn = connections[connId]
if conn[1][0] == address[0]:
ipClientCount += 1
if clientCount + 1 > maxClients or ipClientCount + 1 > maxClientsPerIP:
connection.close()
continue
with threadsLock: with threadsLock:
threadId += 1 threadId += 1
with connectionsLock:
connections[str(threadId)] = (connection,address)
thread = inThread(threadId,connection,address) thread = inThread(threadId)
inThreads[str(threadId)] = thread inThreads[str(threadId)] = thread
thread.start() thread.start()