Add connection tracker and limits

This commit is contained in:
Fierelier 2021-04-14 18:28:54 +02:00
parent ecc73b49de
commit 03ae40f615

View File

@ -26,14 +26,18 @@ import configparser
outThreads = {}
inThreads = {}
connections = {}
threadId = 0
threadsLock = threading.Lock()
fileLock = threading.Lock()
connectionsLock = threading.Lock()
serverAddr = ("127.0.0.1",12000)
serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
bufferSize = 1000 # Buffer size in bytes
maxClients = 100 # How many clients can be connected at maximum?
maxClientsPerIP = 5 # How many clients can be connected at maximum, per IP?
maxAccumulatedData = 20*1000*1000 # How much data can be in an outbound thread's queue at maximum before the connection is closed?
def commandToList(cmd):
@ -69,21 +73,32 @@ def commandToList(cmd):
return args
class outThread(threading.Thread):
def __init__(self,threadId,connection,address,user):
def __init__(self,threadId,user):
threading.Thread.__init__(self)
self.threadId = threadId
self.queue = queue.Queue()
self.connection = connection
self.address = address
self.user = user
self.ignore = False
def closeConnection(self):
with connectionsLock:
if str(self.threadId) in connections:
try:
connections[str(self.threadId)][0].close()
except:
print("warning, closing connection failed")
del connections[str(self.threadId)]
def getConnection(self):
with connectionsLock:
if str(self.threadId) in connections:
return connections[str(self.threadId)]
return False
def closeThread(self):
with threadsLock:
try:
self.connection.close()
except Exception as e:
print("closing a connection failed: " +str(e))
self.closeConnection()
del outThreads[str(self.threadId)]
def run(self):
@ -94,18 +109,30 @@ class outThread(threading.Thread):
data[0](*data[1],**data[2])
if data[0] == self.closeThread: return
continue
self.connection.sendall(data)
self.getConnection()[0].sendall(data)
self.closeThread()
except:
self.closeThread()
raise
class inThread(threading.Thread):
def __init__(self,threadId,connection,address):
def __init__(self,threadId):
threading.Thread.__init__(self)
self.threadId = threadId
self.connection = connection
self.address = address
def closeConnection(self):
with connectionsLock:
if str(self.threadId) in connections:
try:
connections[str(self.threadId)][0].close()
except:
print("warning, closing connection failed")
del connections[str(self.threadId)]
def getConnection(self):
with connectionsLock:
if str(self.threadId) in connections:
return connections[str(self.threadId)]
def closeThread(self,closeConnection = True):
with threadsLock:
@ -113,24 +140,16 @@ class inThread(threading.Thread):
for thread in outThreads:
thread = outThreads[thread]
if thread.user == self.user:
try:
thread.queue.put((thread.closeThread,[],{}))
thread.connection.close()
except Exception as e:
print("closing a connection failed: " +str(e))
pass
thread.queue.put((thread.closeThread,[],{}))
self.closeConnection()
try:
self.connection.close()
except Exception as e:
print("closing a connection failed: " +str(e))
pass
del inThreads[str(self.threadId)]
def run(self):
try:
global threadId
data = self.connection.recv(1000)
data = self.getConnection()[0].recv(1000)
if data == b"":
self.closeThread()
return
@ -160,7 +179,7 @@ class inThread(threading.Thread):
return
with threadsLock:
thread = outThread(self.threadId,self.connection,self.address,self.user)
thread = outThread(self.threadId,self.user)
outThreads[str(self.threadId)] = thread
thread.start()
@ -173,7 +192,7 @@ class inThread(threading.Thread):
return
while True:
data = self.connection.recv(bufferSize)
data = self.getConnection()[0].recv(bufferSize)
if data == b"":
self.closeThread()
return
@ -202,13 +221,31 @@ class debugThread(threading.Thread):
def run(self):
while True:
with threadsLock:
print("---")
print("\n---\n")
print("Threads - IN: " +str(len(inThreads)))
print("Threads - OUT: " +str(len(outThreads)))
print("ACCUMULATED DATA:")
print("\nACCUMULATED DATA:")
for threadId in outThreads:
thread = outThreads[threadId]
print(threadId + ": " + str(thread.queue.qsize() * bufferSize))
print("\nCONNECTIONS:")
connCount = 0
connCountIp = {}
with connectionsLock:
for connId in connections:
conn = connections[connId]
ip = conn[1][0]
if not ip in connCountIp:
connCountIp[ip] = 0
connCountIp[ip] += 1
connCount += 1
for ip in connCountIp:
print(ip+ ": " +str(connCountIp[ip]))
print("Overall: " +str(connCount))
time.sleep(1)
def main():
@ -226,10 +263,26 @@ def main():
while True:
connection, address = serverSocket.accept()
connection.settimeout(15)
with connectionsLock:
clientCount = 0
ipClientCount = 0
for connId in connections:
clientCount += 1
conn = connections[connId]
if conn[1][0] == address[0]:
ipClientCount += 1
if clientCount + 1 > maxClients or ipClientCount + 1 > maxClientsPerIP:
connection.close()
continue
with threadsLock:
threadId += 1
with connectionsLock:
connections[str(threadId)] = (connection,address)
thread = inThread(threadId,connection,address)
thread = inThread(threadId)
inThreads[str(threadId)] = thread
thread.start()