|
from app.app import Application as BaseApplication, get_timestamp
|
|
import asyncio
|
|
from concurrent.futures import ThreadPoolExecutor as Executor
|
|
import time
|
|
|
|
# Worker threads for the event loop's default executor (installed by Application.serve_async).
ZAMENYAT_THREAD_COUNT = 500

# Default chunk size, in bytes, for socket reads and writes (8 KiB).
ZAMENYAT_BUFFER_SIZE = 8 * 1024

# Byte limit used to reject oversized HTTP header blocks (8 KiB).
ZAMENYAT_HEADER_MAX_LENGTH = 8 * 1024
|
|
|
|
class AsyncWriter:
    """Wrap a stream writer so payloads are written in bounded slices.

    ``drain``, ``close`` and ``wait_closed`` are re-exported unchanged from the
    wrapped writer.
    """

    def __init__(self, writer, buffer_size=ZAMENYAT_BUFFER_SIZE):
        self.writer = writer
        self.buffer_size = buffer_size  # maximum bytes handed to writer.write() at once
        self.drain = self.writer.drain
        self.close = self.writer.close
        self.wait_closed = self.writer.wait_closed

    async def write(self, data):
        """Write ``data`` in slices of at most ``buffer_size`` bytes, awaiting drain() after each.

        Awaiting ``drain()`` between slices lets the transport apply flow control.
        Slices are taken by offset, so total copying is linear in ``len(data)``
        (re-slicing the remainder each time was quadratic). A non-positive
        ``buffer_size`` writes everything in one slice instead of looping forever.
        """
        if not data:
            return
        step = self.buffer_size if self.buffer_size > 0 else len(data)
        for offset in range(0, len(data), step):
            self.writer.write(data[offset:offset + step])
            await self.writer.drain()
|
|
|
|
class AsyncReader:
    """Wrap a stream reader with a push-back buffer.

    Bytes given to ``unread`` (or read past what a caller asked for) are
    returned by later ``read`` calls before anything new is taken from the
    wrapped reader.
    """

    def __init__(self, reader):
        self.reader = reader
        self.buffer = b''  # pushed-back / over-read bytes, served first by read()

    async def read(self, buffer_size=ZAMENYAT_BUFFER_SIZE, exact=False):
        """Return up to ``buffer_size`` bytes.

        If ``exact`` is false, bytes already buffered are returned immediately
        (at most ``buffer_size`` of them). The wrapped reader is consulted only
        when the buffer is empty, and then just once.
        If ``exact`` is true, reads repeat until ``buffer_size`` bytes are available.

        Returns None at end of stream when nothing could be returned (non-exact),
        or when fewer than ``buffer_size`` bytes could be collected (exact; the
        partial data stays buffered for a later call).
        """
        while len(self.buffer) < buffer_size:
            if self.buffer and not exact:
                # Serve what is already here rather than blocking on the peer,
                # which may itself be waiting for us to respond.
                break
            chunk = await self.reader.read(buffer_size - len(self.buffer))
            if not chunk:
                return None
            self.buffer += chunk
        data = self.buffer[:buffer_size]
        self.buffer = self.buffer[buffer_size:]
        return data

    async def unread(self, data):
        """Push ``data`` back so the next read() returns it first.

        Empty input is ignored; text (anything with ``.encode``) is encoded with
        the default UTF-8 codec before being buffered.
        """
        if not data:
            return
        if hasattr(data, 'encode'):
            data = data.encode()
        self.buffer = data + self.buffer
|
|
|
|
class Socket:
    """Bundle an AsyncReader and an AsyncWriter over one reader/writer pair.

    ``read``/``unread`` come from the reader side and ``write``/``drain``/
    ``close``/``wait_closed`` from the writer side.
    """

    def __init__(self, reader, writer, buffer_size):
        self.reader = AsyncReader(reader)
        # buffer_size bounds each outgoing write slice; it used to be accepted but ignored.
        self.writer = AsyncWriter(writer, buffer_size)
        self.buffer_size = buffer_size
        self.read = self.reader.read
        self.unread = self.reader.unread
        self.write = self.writer.write
        self.drain = self.writer.drain
        self.close = self.writer.close
        self.wait_closed = self.writer.wait_closed
|
|
|
|
class Application:
    """Reverse proxy relaying each accepted client connection to one upstream server.

    Every accepted client connection gets its own upstream connection. HTTP
    messages are parsed just far enough to find the header block and a
    ``Content-Length`` body, then forwarded. When the client requests and the
    upstream answers with ``Upgrade: websocket`` the connection switches to a
    raw full-duplex byte relay.
    """

    def __init__(self, upstream_host, upstream_port, *args, **kwargs):
        """Store the upstream address and default tuning values.

        Extra positional/keyword arguments are forwarded to ``super().__init__``.
        """
        self.upstream_host = upstream_host
        self.upstream_port = upstream_port
        self.server = None    # asyncio server, set by serve_async()
        self.host = None
        self.port = None
        self.executor = None  # ThreadPoolExecutor, set by upgrade_executor()
        self.buffer_size = ZAMENYAT_BUFFER_SIZE
        self.header_max_length = ZAMENYAT_HEADER_MAX_LENGTH
        self.connection_count = 0        # client connections currently open
        self.total_connection_count = 0  # client connections accepted so far
        # NOTE(review): this class does not derive from the imported BaseApplication,
        # so super() resolves to object and any extra args/kwargs raise TypeError.
        # Confirm whether `class Application(BaseApplication)` was intended.
        super().__init__(*args, **kwargs)

    async def get_headers(self, reader):
        """Read and parse one HTTP start line plus header block from ``reader``.

        ``reader`` must provide ``async read(n)`` and ``async unread(data)``.
        Bytes received after the blank line that ends the headers are pushed
        back with ``unread`` so the caller can read them as the body.

        Returns ``(start_line, headers)``:
        - ``start_line`` is the request/status line as ``str``.
        - ``headers`` maps each header name to its value. Values consisting only
          of decimal digits become ``int``. A name repeated on several lines maps
          to a list of its values in order.

        Returns ``(None, None)`` on EOF before a complete header block, when the
        header block is longer than ``self.header_max_length`` bytes, or when it
        is empty.

        Text is decoded as UTF-8 with ``surrogateescape``, so arbitrary bytes
        survive a round trip through header_dict_to_bytes() unchanged.
        """
        data = b''
        while True:
            headers_end = data.find(b'\r\n\r\n')
            if headers_end != -1:  # find() returns -1, which is truthy, when absent
                break
            if len(data) > self.header_max_length:
                return None, None
            chunk = await reader.read(self.buffer_size)
            if not chunk:
                return None, None
            data += chunk

        # Only the header block itself is limited; body bytes in the same chunk are fine.
        if headers_end > self.header_max_length:
            return None, None

        # Whatever follows the blank line belongs to the body or the next message.
        await reader.unread(data[headers_end + 4:])
        head = data[:headers_end]
        if not head:
            return None, None

        start_line, *lines = head.split(b'\r\n')
        header_dict = {}
        for line in lines:
            name, _, raw_value = line.partition(b':')
            key = name.decode('utf-8', 'surrogateescape')
            value = raw_value.strip().decode('utf-8', 'surrogateescape')
            # isdecimal() is exactly the set int() accepts (isdigit() also admits e.g. '²').
            parsed = int(value) if value.isdecimal() else value
            if key not in header_dict:
                header_dict[key] = parsed
            elif isinstance(header_dict[key], list):
                header_dict[key].append(parsed)
            else:
                header_dict[key] = [header_dict[key], parsed]
        return start_line.decode('utf-8', 'surrogateescape'), header_dict

    def header_dict_to_bytes(self, req_resp, headers):
        """Serialise a start line and header mapping into a raw HTTP header block.

        List values are emitted as one ``name: value`` line per element. The
        result ends with the blank line that terminates the header section.
        Encoding uses UTF-8 with ``surrogateescape``, the inverse of get_headers().
        """
        lines = [req_resp]
        for key, value in headers.items():
            values = value if isinstance(value, list) else [value]
            lines.extend("{}: {}".format(key, item) for item in values)
        lines.append("\r\n")
        return "\r\n".join(lines).encode('utf-8', 'surrogateescape')

    @staticmethod
    def _find_header(headers, name):
        """Return the value of header ``name`` matched case-insensitively, or None."""
        wanted = name.lower()
        for key, value in headers.items():
            if key.lower() == wanted:
                return value
        return None

    async def stream(self, reader, writer, is_websocket=False):
        """Forward one HTTP message (or, with ``is_websocket``, one raw chunk).

        ``reader`` may be a raw stream reader, or an already wrapped reader that
        has ``unread`` (e.g. AsyncReader). Pass the wrapped form to keep bytes
        that arrive after this message (pipelined requests, early WebSocket
        frames) for the next call.

        HTTP mode:
        - The header block is re-serialised, printed and written to ``writer``.
        - Exactly ``Content-Length`` body bytes are then streamed through chunk
          by chunk. If the peer closes early, the part received is forwarded.
        - Returns the parsed header dict, or None when no complete header block
          arrived.

        ``is_websocket=True``: forwards one raw chunk (if any) and returns None.
        """
        source = reader if hasattr(reader, 'unread') else AsyncReader(reader)
        sink = AsyncWriter(writer, self.buffer_size)

        if is_websocket:
            data = await source.read(self.buffer_size)
            if data:
                await sink.write(data)
            return None

        req_resp, headers = await self.get_headers(source)
        if req_resp is None:
            return None

        head = self.header_dict_to_bytes(req_resp, headers)
        print(head.decode('utf-8', 'replace'))
        await sink.write(head)

        # NOTE(review): only an integer Content-Length delimits a body here;
        # chunked transfer-encoding and read-until-close bodies are not forwarded.
        remaining = self._find_header(headers, 'Content-Length')
        if isinstance(remaining, int):
            while remaining > 0:
                chunk = await source.read(min(self.buffer_size, remaining))
                if not chunk:
                    break
                await sink.write(chunk)
                remaining -= len(chunk)
        return headers

    async def _pipe(self, source, sink):
        """Copy bytes from wrapped reader ``source`` to stream writer ``sink`` until EOF."""
        if source.buffer:
            # Flush bytes already read past the previous HTTP message first.
            pending = await source.read(len(source.buffer))
            if pending:
                sink.write(pending)
                await sink.drain()
        while True:
            chunk = await source.read(self.buffer_size)
            if not chunk:
                break
            sink.write(chunk)
            await sink.drain()

    async def _relay_duplex(self, client_in, client_out, upstream_in, upstream_out):
        """Relay raw bytes in both directions at once until either side reaches EOF or fails."""
        tasks = [
            asyncio.ensure_future(self._pipe(client_in, upstream_out)),
            asyncio.ensure_future(self._pipe(upstream_in, client_out)),
        ]
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                task.cancel()
            # Consume results so a dropped peer's error is not reported as never retrieved.
            await asyncio.gather(*tasks, return_exceptions=True)

    @staticmethod
    async def _close_writer(stream_writer):
        """Close ``stream_writer`` and wait for it, ignoring errors from an already-broken connection."""
        stream_writer.close()
        try:
            await stream_writer.wait_closed()
        except OSError:
            pass

    async def handle_client(self, reader, writer):
        """Serve one accepted client connection (``asyncio.start_server`` callback).

        Behaviour:
        - Opens a dedicated upstream connection.
        - Relays request/response pairs for as long as the client sends
          ``Connection: keep-alive``.
        - Switches to a full-duplex relay after a WebSocket upgrade.
        - Always closes both connections and decrements ``connection_count``
          on the way out.
        """
        self.connection_count += 1
        self.total_connection_count += 1
        connection_nr = self.connection_count
        total_nr = self.total_connection_count  # captured so log lines name this connection
        upstream_writer = None
        try:
            upstream_reader, upstream_writer = await asyncio.open_connection(self.upstream_host, self.upstream_port)
            # Persistent wrappers keep bytes read past one message for the next one.
            client_in = AsyncReader(reader)
            upstream_in = AsyncReader(upstream_reader)
            while True:
                time_start = time.time()
                print(f"Connected to upstream #{total_nr} server {self.upstream_host}:{self.upstream_port} #{connection_nr} Time: {get_timestamp()}")
                request_headers = await self.stream(client_in, upstream_writer)
                if request_headers is None:
                    break  # client closed, or sent a malformed/oversized header block
                response_headers = await self.stream(upstream_in, writer)
                time_duration = time.time() - time_start
                print(f"Disconnected upstream #{total_nr} server {self.upstream_host}:{self.upstream_port} #{connection_nr} Duration: {time_duration:.5f}s")
                if response_headers is None:
                    break
                upgrade_requested = str(self._find_header(request_headers, 'Upgrade')).lower() == 'websocket'
                upgrade_accepted = str(self._find_header(response_headers, 'Upgrade')).lower() == 'websocket'
                if upgrade_requested and upgrade_accepted:
                    await self._relay_duplex(client_in, writer, upstream_in, upstream_writer)
                    break
                if str(self._find_header(request_headers, 'Connection')).lower() != 'keep-alive':
                    break
        except ConnectionError as exc:
            print(f"Connection #{connection_nr} aborted: {exc!r}")
        finally:
            self.connection_count -= 1
            await self._close_writer(writer)
            if upstream_writer is not None:
                await self._close_writer(upstream_writer)

    def upgrade_executor(self, thread_count):
        """Install a ThreadPoolExecutor with ``thread_count`` workers as the running loop's default executor.

        Must be called from inside a running event loop. Returns the executor.
        """
        self.executor = Executor(max_workers=thread_count)
        loop = asyncio.get_running_loop()
        loop.set_default_executor(self.executor)
        return self.executor

    async def serve_async(self, host, port):
        """Listen on ``host``:``port`` and serve client connections until cancelled."""
        self.upgrade_executor(ZAMENYAT_THREAD_COUNT)
        self.host = host
        self.port = port
        self.server = await asyncio.start_server(self.handle_client, self.host, self.port)
        async with self.server:
            await self.server.serve_forever()

    def serve(self, host, port):
        """Run serve_async() in a fresh event loop; Ctrl+C stops it."""
        try:
            asyncio.run(self.serve_async(host, port))
        except KeyboardInterrupt:
            print("Shutted down server")
|