Add connection tracker and limits
This commit is contained in:
parent
ecc73b49de
commit
03ae40f615
@ -26,14 +26,18 @@ import configparser
|
||||
|
||||
outThreads = {}
|
||||
inThreads = {}
|
||||
connections = {}
|
||||
threadId = 0
|
||||
threadsLock = threading.Lock()
|
||||
fileLock = threading.Lock()
|
||||
connectionsLock = threading.Lock()
|
||||
|
||||
serverAddr = ("127.0.0.1",12000)
|
||||
serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
|
||||
bufferSize = 1000 # Buffer size in bytes
|
||||
maxClients = 100 # How many clients can be connected at maximum?
|
||||
maxClientsPerIP = 5 # How many clients can be connected at maximum, per IP?
|
||||
maxAccumulatedData = 20*1000*1000 # How much data can be in an outbound thread's queue at maximum before the connection is closed?
|
||||
|
||||
def commandToList(cmd):
|
||||
@ -69,21 +73,32 @@ def commandToList(cmd):
|
||||
return args
|
||||
|
||||
class outThread(threading.Thread):
|
||||
def __init__(self,threadId,connection,address,user):
|
||||
def __init__(self,threadId,user):
|
||||
threading.Thread.__init__(self)
|
||||
self.threadId = threadId
|
||||
self.queue = queue.Queue()
|
||||
self.connection = connection
|
||||
self.address = address
|
||||
self.user = user
|
||||
self.ignore = False
|
||||
|
||||
def closeConnection(self):
|
||||
with connectionsLock:
|
||||
if str(self.threadId) in connections:
|
||||
try:
|
||||
connections[str(self.threadId)][0].close()
|
||||
except:
|
||||
print("warning, closing connection failed")
|
||||
del connections[str(self.threadId)]
|
||||
|
||||
def getConnection(self):
|
||||
with connectionsLock:
|
||||
if str(self.threadId) in connections:
|
||||
return connections[str(self.threadId)]
|
||||
|
||||
return False
|
||||
|
||||
def closeThread(self):
|
||||
with threadsLock:
|
||||
try:
|
||||
self.connection.close()
|
||||
except Exception as e:
|
||||
print("closing a connection failed: " +str(e))
|
||||
self.closeConnection()
|
||||
del outThreads[str(self.threadId)]
|
||||
|
||||
def run(self):
|
||||
@ -94,18 +109,30 @@ class outThread(threading.Thread):
|
||||
data[0](*data[1],**data[2])
|
||||
if data[0] == self.closeThread: return
|
||||
continue
|
||||
self.connection.sendall(data)
|
||||
self.getConnection()[0].sendall(data)
|
||||
self.closeThread()
|
||||
except:
|
||||
self.closeThread()
|
||||
raise
|
||||
|
||||
class inThread(threading.Thread):
|
||||
def __init__(self,threadId,connection,address):
|
||||
def __init__(self,threadId):
|
||||
threading.Thread.__init__(self)
|
||||
self.threadId = threadId
|
||||
self.connection = connection
|
||||
self.address = address
|
||||
|
||||
def closeConnection(self):
|
||||
with connectionsLock:
|
||||
if str(self.threadId) in connections:
|
||||
try:
|
||||
connections[str(self.threadId)][0].close()
|
||||
except:
|
||||
print("warning, closing connection failed")
|
||||
del connections[str(self.threadId)]
|
||||
|
||||
def getConnection(self):
|
||||
with connectionsLock:
|
||||
if str(self.threadId) in connections:
|
||||
return connections[str(self.threadId)]
|
||||
|
||||
def closeThread(self,closeConnection = True):
|
||||
with threadsLock:
|
||||
@ -113,24 +140,16 @@ class inThread(threading.Thread):
|
||||
for thread in outThreads:
|
||||
thread = outThreads[thread]
|
||||
if thread.user == self.user:
|
||||
try:
|
||||
thread.queue.put((thread.closeThread,[],{}))
|
||||
thread.connection.close()
|
||||
except Exception as e:
|
||||
print("closing a connection failed: " +str(e))
|
||||
pass
|
||||
thread.queue.put((thread.closeThread,[],{}))
|
||||
|
||||
try:
|
||||
self.connection.close()
|
||||
except Exception as e:
|
||||
print("closing a connection failed: " +str(e))
|
||||
pass
|
||||
self.closeConnection()
|
||||
|
||||
del inThreads[str(self.threadId)]
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
global threadId
|
||||
data = self.connection.recv(1000)
|
||||
data = self.getConnection()[0].recv(1000)
|
||||
if data == b"":
|
||||
self.closeThread()
|
||||
return
|
||||
@ -160,7 +179,7 @@ class inThread(threading.Thread):
|
||||
return
|
||||
|
||||
with threadsLock:
|
||||
thread = outThread(self.threadId,self.connection,self.address,self.user)
|
||||
thread = outThread(self.threadId,self.user)
|
||||
outThreads[str(self.threadId)] = thread
|
||||
thread.start()
|
||||
|
||||
@ -173,7 +192,7 @@ class inThread(threading.Thread):
|
||||
return
|
||||
|
||||
while True:
|
||||
data = self.connection.recv(bufferSize)
|
||||
data = self.getConnection()[0].recv(bufferSize)
|
||||
if data == b"":
|
||||
self.closeThread()
|
||||
return
|
||||
@ -202,13 +221,31 @@ class debugThread(threading.Thread):
|
||||
def run(self):
|
||||
while True:
|
||||
with threadsLock:
|
||||
print("---")
|
||||
print("\n---\n")
|
||||
print("Threads - IN: " +str(len(inThreads)))
|
||||
print("Threads - OUT: " +str(len(outThreads)))
|
||||
print("ACCUMULATED DATA:")
|
||||
print("\nACCUMULATED DATA:")
|
||||
for threadId in outThreads:
|
||||
thread = outThreads[threadId]
|
||||
print(threadId + ": " + str(thread.queue.qsize() * bufferSize))
|
||||
|
||||
print("\nCONNECTIONS:")
|
||||
connCount = 0
|
||||
connCountIp = {}
|
||||
with connectionsLock:
|
||||
for connId in connections:
|
||||
conn = connections[connId]
|
||||
ip = conn[1][0]
|
||||
if not ip in connCountIp:
|
||||
connCountIp[ip] = 0
|
||||
connCountIp[ip] += 1
|
||||
connCount += 1
|
||||
|
||||
for ip in connCountIp:
|
||||
print(ip+ ": " +str(connCountIp[ip]))
|
||||
|
||||
print("Overall: " +str(connCount))
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
def main():
|
||||
@ -226,10 +263,26 @@ def main():
|
||||
while True:
|
||||
connection, address = serverSocket.accept()
|
||||
connection.settimeout(15)
|
||||
|
||||
with connectionsLock:
|
||||
clientCount = 0
|
||||
ipClientCount = 0
|
||||
for connId in connections:
|
||||
clientCount += 1
|
||||
conn = connections[connId]
|
||||
if conn[1][0] == address[0]:
|
||||
ipClientCount += 1
|
||||
|
||||
if clientCount + 1 > maxClients or ipClientCount + 1 > maxClientsPerIP:
|
||||
connection.close()
|
||||
continue
|
||||
|
||||
with threadsLock:
|
||||
threadId += 1
|
||||
with connectionsLock:
|
||||
connections[str(threadId)] = (connection,address)
|
||||
|
||||
thread = inThread(threadId,connection,address)
|
||||
thread = inThread(threadId)
|
||||
inThreads[str(threadId)] = thread
|
||||
thread.start()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user