aes-tunnel/aes-tunnel

257 lines
7.0 KiB
Python
Executable File

#!/usr/bin/env python3
import sys
oldexcepthook = sys.excepthook


def newexcepthook(exc_type, exc_value, exc_traceback):
    """Handle an uncaught exception by delegating to the previous hook.

    Parameters are named ``exc_*`` so they no longer shadow the builtin
    ``type`` or the stdlib ``traceback`` module.  The interpreter calls this
    hook positionally, so the rename is transparent.

    Hook point: the original author kept a commented-out
    ``input("Press ENTER to quit.")`` here, presumably to keep a console
    window open after a crash.  Add it after the delegation if that is wanted.
    """
    oldexcepthook(exc_type, exc_value, exc_traceback)


sys.excepthook = newexcepthook
import os
# Terse path helpers kept for convenience.
p, pUp = os.path.join, os.path.dirname
# s: absolute, symlink-resolved path of this program. In a PyInstaller-style
# frozen bundle (sys.frozen set and sys._MEIPASS present) that is the bundled
# executable; otherwise it is this script file.
# sp: the directory that contains s.
s = os.path.realpath(
    sys.executable
    if getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS')
    else __file__
)
sp = pUp(s)
# script start
import socket
import time
import threading
import queue
# Prefer the "Cryptodome" namespace (pycryptodomex) and fall back to the legacy
# "Crypto" namespace (pycryptodome / PyCrypto). Only a failed import triggers
# the fallback; a bare except here also swallowed KeyboardInterrupt and genuine
# errors raised while importing Cryptodome.
try:
    from Cryptodome.Cipher import AES
    from Cryptodome import Random
except ImportError:
    from Crypto.Cipher import AES
    from Crypto import Random
# Maximum number of plaintext bytes packed into one encrypted frame.
# Keep it identical on both tunnel endpoints: the decrypting side rejects
# frames whose declared length exceeds its limit.
bufferMax = 4096
def forceClose(conn):
    """Close *conn*, ignoring errors; best-effort cleanup for error paths.

    ``None`` (a socket that was never created or assigned) is accepted and
    ignored. Only ``OSError`` (what ``socket.close`` raises) is suppressed.
    The previous bare ``except:`` also swallowed KeyboardInterrupt/SystemExit.
    """
    if conn is None:
        return
    try:
        conn.close()
    except OSError:
        pass
def recv(conn, l):
    """Read exactly *l* bytes from socket *conn* and return them as a bytearray.

    The whole read must finish within the socket's own timeout
    (``conn.gettimeout()``). ``None`` means there is no overall deadline.

    Raises:
        ConnectionResetError: the peer closed the connection first.
        TimeoutError: the socket timed out before *any* byte arrived. Nothing
            was consumed, so the caller may safely retry.
        ConnectionError: the read stalled after part of the data had already
            been consumed. Retrying would desynchronise the stream, so this is
            deliberately *not* a TimeoutError.
    """
    timeo = conn.gettimeout()
    # Wall-clock deadline. time.process_time() (used previously) measures CPU
    # time, which barely advances while a thread is blocked in recv().
    deadline = None if timeo is None else time.monotonic() + timeo
    buf = bytearray()
    while len(buf) < l:
        try:
            chunk = conn.recv(l - len(buf))
        except (socket.timeout, TimeoutError):
            if not buf:
                raise
            raise ConnectionError(
                "timed out after receiving %d of %d bytes" % (len(buf), l)
            ) from None
        if not chunk:
            raise ConnectionResetError
        buf += chunk
        if len(buf) < l and deadline is not None and time.monotonic() > deadline:
            raise ConnectionError(
                "timed out after receiving %d of %d bytes" % (len(buf), l)
            )
    return buf
def padData(data, block_size, length=None):
    """Return *data* followed by 1..block_size zero bytes.

    The amount of padding is chosen as if *data* were *length* bytes long
    (``len(data)`` when omitted), so that length + padding is a multiple of
    *block_size*. An already aligned length still gets one full block of
    zeros. padLength mirrors this on the receiving side, so it is part of
    the wire format.
    """
    size = len(data) if length is None else length
    fill = block_size - size % block_size
    return data + b"\x00" * fill
def padLength(block_size, length):
    """Return the smallest multiple of *block_size* strictly greater than *length*.

    This is the size padData produces, so an aligned *length* grows by a
    whole block.
    """
    return (length // block_size + 1) * block_size
def getTimespan(curTime, lastTime):
    """Return the elapsed time ``curTime - lastTime``, never below 0.0.

    Clamping guards against the wall clock stepping backwards between samples.
    """
    span = curTime - lastTime
    return 0.0 if span < 0.0 else span
class _RelayThread(threading.Thread):
    """Shared plumbing for one direction of a tunnelled connection.

    Reads from ``c_conn_recv``, transforms the data in ``run`` (provided by a
    subclass) and writes it to ``c_conn_send``. ``c_shared`` is the dict
    shared with the thread serving the opposite direction. Its ``"lock"``
    guards ``"activity"``, the time.time() of the last data movement. A socket
    timeout only ends the connection once neither direction has moved data for
    longer than the module-level ``timeout``.
    """

    def __init__(self):
        super().__init__()
        self.c_conn_recv = None  # socket this thread reads from
        self.c_conn_send = None  # socket this thread writes to
        self.c_shared = None  # {"lock": Lock, "activity": float}, shared with the peer thread

    def _touch(self):
        """Record that data just moved through this connection."""
        with self.c_shared["lock"]:
            self.c_shared["activity"] = time.time()

    def _idle_too_long(self):
        """True once neither direction has moved data for ``timeout`` seconds."""
        with self.c_shared["lock"]:
            return getTimespan(time.time(), self.c_shared["activity"]) > timeout

    def recv(self, length):
        """Return up to *length* bytes (b"" on EOF).

        Keeps waiting through socket timeouts while the connection is active.
        Raises TimeoutError once the connection has been idle too long.
        socket.timeout is caught explicitly because it is only an alias of
        TimeoutError from Python 3.10 on.
        """
        while True:
            try:
                data = self.c_conn_recv.recv(length)
            except (socket.timeout, TimeoutError):
                if self._idle_too_long():
                    raise TimeoutError
                continue
            self._touch()
            return data

    def recvall(self, length):
        """Return exactly *length* bytes as a bytearray.

        Bytes already received are kept across socket timeouts. Previously a
        retry after a timeout discarded them, desynchronising the encrypted
        stream.

        Raises ConnectionResetError if the peer closes first, and TimeoutError
        once the connection has been idle for longer than ``timeout``.
        """
        buf = bytearray()
        while len(buf) < length:
            try:
                chunk = self.c_conn_recv.recv(length - len(buf))
            except (socket.timeout, TimeoutError):
                if self._idle_too_long():
                    raise TimeoutError
                continue
            if not chunk:
                raise ConnectionResetError
            buf += chunk
            self._touch()
        return buf

    def sendall(self, data):
        """Send all of *data* to ``c_conn_send`` and record the activity."""
        self.c_conn_send.sendall(data)
        self._touch()

    def stop(self):
        """Close both sockets; the peer thread's blocked I/O then fails too."""
        forceClose(self.c_conn_recv)
        forceClose(self.c_conn_send)


class server_thread_encrypt(_RelayThread):
    """Read plaintext from c_conn_recv and forward it encrypted to c_conn_send.

    One frame per recv() of up to ``bufferMax`` bytes:
        IV (AES.block_size random bytes) ||
        AES-CBC(key=pw, iv=IV)( 4-byte big-endian length || zero fill to one
                                block || data || zero padding )
    NOTE(security): frames are encrypted but not authenticated (no MAC). The
    only integrity check is the zero fill in the header block. Changing this
    requires updating both tunnel ends.
    """

    def run(self):
        try:
            while True:
                data = self.recv(bufferMax)
                length = len(data)
                if length == 0: raise Exception("Connection closed: no data")
                iv = Random.get_random_bytes(AES.block_size)
                cipher = AES.new(pw, AES.MODE_CBC, iv=iv)
                header = padData(length.to_bytes(4, "big"), AES.block_size, 4)
                body = cipher.encrypt(
                    padData(header + data, AES.block_size, length + AES.block_size)
                )
                # One write per frame instead of two (same byte stream).
                self.sendall(iv + body)
        except Exception:
            self.stop()
            raise


class server_thread_decrypt(_RelayThread):
    """Read encrypted frames from c_conn_recv and forward the plaintext.

    Inverse of server_thread_encrypt: it reads the IV, then decrypts the
    header block to get the payload length. It then reads padLength(length)
    bytes of ciphertext and forwards the first *length* bytes of plaintext.
    """

    def run(self):
        try:
            while True:
                iv = self.recvall(AES.block_size)
                cipher = AES.new(pw, AES.MODE_CBC, iv=bytes(iv))
                header = cipher.decrypt(self.recvall(AES.block_size))
                # Check the zero fill first: a wrong password almost always
                # fails here. Previously it was reported as "buffer too big".
                if header[4:AES.block_size] != (b"\x00" * (AES.block_size - 4)):
                    raise Exception("Connection closed: buffer length mangled (wrong password?)")
                length = int.from_bytes(header[:4], "big")
                if length == 0: raise Exception("Connection closed: no data")
                # Same limit the encrypting side uses when building frames.
                if length > bufferMax: raise Exception("Connection closed: buffer too big")
                data = cipher.decrypt(self.recvall(padLength(AES.block_size, length)))[:length]
                self.sendall(data)
        except Exception:
            self.stop()
            raise
class dispatch_thread(threading.Thread):
    """Serve one accepted client: connect upstream, then start two relay threads.

    Assign ``c_client_conn`` (and optionally ``c_client_addr``) before
    ``start()``. Reads the module-level settings ``out_host``, ``out_port``,
    ``nodelay``, ``timeout`` and ``direction``. With ``direction == 1`` the
    client side is plaintext and traffic towards upstream is encrypted.
    Otherwise traffic towards upstream is decrypted.
    """

    def __init__(self):
        super().__init__()
        self.c_client_conn = None  # socket returned by server.accept()
        self.c_client_addr = None  # peer address from accept(); not used here

    def run(self):
        client_conn = self.c_client_conn
        server_conn = None
        try:
            server_conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            if nodelay:
                server_conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            server_conn.settimeout(timeout)
            server_conn.connect((out_host, out_port))
        except Exception:
            # Previously a failed upstream connect leaked both sockets and
            # left the client connected to nothing.
            if server_conn is not None:
                forceClose(server_conn)
            forceClose(client_conn)
            raise
        if direction == 1:
            upstream_cls, downstream_cls = server_thread_encrypt, server_thread_decrypt
        else:
            upstream_cls, downstream_cls = server_thread_decrypt, server_thread_encrypt
        # Shared idle tracking: the connection stays up while either direction moves data.
        shared = {"lock": threading.Lock(), "activity": time.time()}
        thread_in = upstream_cls()  # client -> upstream
        thread_in.c_conn_recv = client_conn
        thread_in.c_conn_send = server_conn
        thread_in.c_shared = shared
        thread_out = downstream_cls()  # upstream -> client
        thread_out.c_conn_recv = server_conn
        thread_out.c_conn_send = client_conn
        thread_out.c_shared = shared
        thread_in.start()
        thread_out.start()
# Configuration comes from the environment and is kept in module-level globals
# (pw, out_host, out_port, nodelay, direction, timeout) read by the threads.
# Everything is validated before the listening socket is bound.
pw = os.environ["TUNNEL_ENC_PASS"].encode("utf-8")
if not pw:
    # The doubling loop below never terminates on an empty password.
    raise Exception("TUNNEL_ENC_PASS must not be empty")
# NOTE(security): the AES-256 key is the password repeated/truncated to 32
# bytes, not a proper KDF. It is kept because both tunnel ends must derive the
# same key.
while len(pw) < 32:
    pw = pw + pw
pw = pw[:32]
out_host = os.environ["TUNNEL_OUT_HOST"]
out_port = int(os.environ["TUNNEL_OUT_PORT"])
nodelay = ("TUNNEL_NODELAY" in os.environ)
encrypt_mode = os.environ.get("TUNNEL_ENCRYPT")
if encrypt_mode == "out":
    direction = 1
elif encrypt_mode == "in":
    direction = 0
else:
    raise Exception("TUNNEL_ENCRYPT needs to be set to either out or in")
timeout = float(os.environ.get("TUNNEL_TIMEOUT", 15.0))
if timeout <= 0:
    # settimeout(0) would make every socket non-blocking; negatives are rejected.
    raise Exception("TUNNEL_TIMEOUT must be a positive number of seconds")
in_host = os.environ["TUNNEL_IN_HOST"]
in_port = os.environ["TUNNEL_IN_PORT"]

server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if nodelay:
    server.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
server.bind((in_host, int(in_port)))
server.listen(65535)
print("Ready to take connections on " + in_host + ":" + in_port + "!")
if direction == 1:
    print("* In: Plain")
    print("* Out: Encrypted")
else:
    print("* In: Encrypted")
    print("* Out: Plain")
if nodelay: print("TCP_NODELAY is enabled.")

while True:
    client_conn, client_addr = server.accept()
    try:
        client_conn.settimeout(timeout)
        if nodelay:
            # Set explicitly. Whether accepted sockets inherit the option from
            # the listening socket is platform-dependent.
            client_conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        thr_dispatch = dispatch_thread()
        thr_dispatch.c_client_conn = client_conn
        thr_dispatch.c_client_addr = client_addr
        thr_dispatch.start()
    except Exception:
        # Don't leak the accepted socket if per-connection setup fails.
        forceClose(client_conn)
        raise