Use with lock: instead of lock.acquire() and lock.release()

This commit is contained in:
Fierelier 2021-04-09 16:45:32 +02:00
parent 6591a01459
commit 7a1ec5d0ba
4 changed files with 69 additions and 89 deletions

View File

@ -61,9 +61,8 @@ def getModlist(path):
return modList return modList
def triggerEvent(event,*args,**kwargs): def triggerEvent(event,*args,**kwargs):
eventHandlersLock.acquire() with eventHandlersLock:
handlers = eventHandlers.copy() handlers = eventHandlers.copy()
eventHandlersLock.release()
if not event in handlers: return if not event in handlers: return
for func in handlers[event]: for func in handlers[event]:
@ -73,10 +72,9 @@ def triggerEvent(event,*args,**kwargs):
return False return False
def addEventHandler(event,func): def addEventHandler(event,func):
eventHandlersLock.acquire() with eventHandlersLock:
if not event in eventHandlers: eventHandlers[event] = [] if not event in eventHandlers: eventHandlers[event] = []
eventHandlers[event].append(func) eventHandlers[event].append(func)
eventHandlersLock.release()
def sendResponse(connection,data): def sendResponse(connection,data):
connection.sendall(len(data).to_bytes(4,"big") + data) connection.sendall(len(data).to_bytes(4,"big") + data)
@ -94,25 +92,21 @@ class connectionThread(threading.Thread):
self.lock = threading.Lock() self.lock = threading.Lock()
def closeThread(self): def closeThread(self):
self.lock.acquire() with self.lock, threadsLock:
threadsLock.acquire() try:
try: self.connection.close()
self.connection.close() except:
except: print("failed to close connection, ignoring.")
print("failed to close connection, ignoring.") pass
pass
del threads[str(self.threadId)] del threads[str(self.threadId)]
print("thread closed: " +str(self.threadId)+ " (open: " +str(len(threads))+ ")") print("thread closed: " +str(self.threadId)+ " (open: " +str(len(threads))+ ")")
self.closed = True self.closed = True
threadsLock.release()
self.lock.release()
def run(self): def run(self):
self.lock.acquire()
# inform about connection # inform about connection
print("thread opened: " +", ".join((str(self.threadId),str(self.address)))) with self.lock:
self.lock.release() print("thread opened: " +", ".join((str(self.threadId),str(self.address))))
while True: while True:
try: try:
@ -128,23 +122,18 @@ class connectionThread(threading.Thread):
# inform about request # inform about request
cancel = triggerEvent("onPreRequest",self,requestLength) cancel = triggerEvent("onPreRequest",self,requestLength)
self.lock.acquire() with self.lock:
if self.closed: if self.closed:
self.lock.release() return
return
self.lock.release()
if cancel: continue if cancel: continue
# process request # process request
cancel = triggerEvent("onRequest",self,requestLength) cancel = triggerEvent("onRequest",self,requestLength)
self.lock.acquire() with self.lock:
if self.closed: if self.closed:
self.lock.release() return
return
self.lock.release()
if cancel: continue if cancel: continue
except Exception as e: except Exception as e:
#self.lock.release() - fix this
cancel = False cancel = False
try: try:
cancel = triggerEvent("onException",self,e) cancel = triggerEvent("onException",self,e)
@ -190,20 +179,19 @@ def main():
global close global close
while True: while True:
connection, address = socketServer.accept() connection, address = socketServer.accept()
threadsLock.acquire() with threadsLock:
if close: threadsLock.release(); break if close: break
cancel = triggerEvent("onConnect",connection,address) cancel = triggerEvent("onConnect",connection,address)
if close: threadsLock.release(); break if close: break
if cancel: continue if cancel: continue
threadId += 1
while str(threadId) in threads:
threadId += 1 threadId += 1
while str(threadId) in threads:
threadId += 1
thread = connectionThread(threadId,connection,address) thread = connectionThread(threadId,connection,address)
threads[str(threadId)] = thread threads[str(threadId)] = thread
thread.start() thread.start()
threadsLock.release()
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -5,9 +5,8 @@ textCommandsLock = threading.Lock()
global textCommandRun global textCommandRun
def textCommandRun(self,args): def textCommandRun(self,args):
textCommandsLock.acquire() with textCommandsLock:
commands = textCommands.copy() commands = textCommands.copy()
textCommandsLock.release()
if not args[0] in commands: if not args[0] in commands:
return ["error","nonfatal","command_not_found"] return ["error","nonfatal","command_not_found"]
@ -16,9 +15,8 @@ def textCommandRun(self,args):
global textCommandAddHandler global textCommandAddHandler
def textCommandAddHandler(command,function): def textCommandAddHandler(command,function):
textCommandsLock.acquire() with textCommandsLock:
textCommands[command] = function textCommands[command] = function
textCommandsLock.release()
global textCommandToList global textCommandToList
def textCommandToList(cmd): def textCommandToList(cmd):

View File

@ -28,20 +28,18 @@ def textUserRegister(self,command,args):
return ["error","nonfatal","invalid_name","Allowed characters: " +", ".join([char for char in textUserAllowedCharacters])] return ["error","nonfatal","invalid_name","Allowed characters: " +", ".join([char for char in textUserAllowedCharacters])]
userpath = textUserGetPath(user) userpath = textUserGetPath(user)
fileLock.acquire()
if os.path.isdir(userpath): with fileLock:
fileLock.release() if os.path.isdir(userpath):
return ["error","nonfatal","user_exists"] return ["error","nonfatal","user_exists"]
password = args[1] password = args[1]
os.makedirs(userpath) os.makedirs(userpath)
passFile = open(p(userpath,"pass.txt"),"w") passFile = open(p(userpath,"pass.txt"),"w")
passFile.write(password) passFile.write(password)
passFile.close() passFile.close()
fileLock.release() return ["ok"]
return ["ok"]
textCommandAddHandler("register",textUserRegister) textCommandAddHandler("register",textUserRegister)
global textUserLogin global textUserLogin
@ -55,36 +53,33 @@ def textUserLogin(self,command,args):
for symbol in user: for symbol in user:
if not symbol in textUserAllowedCharacters: if not symbol in textUserAllowedCharacters:
fileLock.release()
return ["error","nonfatal","invalid_name","Allowed characters: " +", ".join([char for char in textUserAllowedCharacters])] return ["error","nonfatal","invalid_name","Allowed characters: " +", ".join([char for char in textUserAllowedCharacters])]
userpath = textUserGetPath(user) userpath = textUserGetPath(user)
fileLock.acquire()
if not os.path.isdir(userpath): with fileLock:
fileLock.release() if not os.path.isdir(userpath):
return ["error","nonfatal","wrong_user_or_password"] return ["error","nonfatal","wrong_user_or_password"]
password = args[1] password = args[1]
passFile = open(p(userpath,"pass.txt"),"r") passFile = open(p(userpath,"pass.txt"),"r")
passw = passFile.read() passw = passFile.read()
passFile.close() passFile.close()
fileLock.release() if password != passw:
if password != passw: return ["error","nonfatal","wrong_user_or_password"]
return ["error","nonfatal","wrong_user_or_password"]
with self.lock:
self.user = user
self.lock.acquire()
self.user = user
self.lock.release()
return ["ok"] return ["ok"]
textCommandAddHandler("login",textUserLogin) textCommandAddHandler("login",textUserLogin)
global textUserGet global textUserGet
def textUserGet(self,command,args): def textUserGet(self,command,args):
self.lock.acquire() with self.lock:
user = self.user user = self.user
self.lock.release()
if not user: if not user:
return ["error","nonfatal","not_logged_in"] return ["error","nonfatal","not_logged_in"]

View File

@ -34,10 +34,9 @@ def textOnRequest(event,self,requestLength):
text = data.decode("utf-8") text = data.decode("utf-8")
print(":".join(map(str,self.address))+ " > " +text) print(":".join(map(str,self.address))+ " > " +text)
if text == "close": if text == "close":
threadsLock.acquire() with threadsLock:
global close global close
close = True close = True
threadsLock.release()
self.closeThread() self.closeThread()
return True return True