import asyncio
import base64
import json
import logging
import ssl
import sys
import time
from collections import defaultdict
from typing import Dict, Optional

# Tunable defaults (overridden by values in the JSON config file).
DEFAULT_PORT = 8585
DEFAULT_REQUESTS_PER_SECOND = 2
DEFAULT_BURST_SIZE = 5
TOKEN_COST = 1.0  # tokens consumed by one allowed request

# Status codes / reason phrases for the proxy's own error responses.
HTTP_TOO_MANY_REQUESTS = 429
HTTP_NOT_FOUND = 404
HTTP_BAD_GATEWAY = 502
MSG_TOO_MANY_REQUESTS = "Too Many Requests"
MSG_NOT_FOUND = "Not Found"
MSG_BAD_GATEWAY = "Bad Gateway"

# Header/body delimiters: standard CRLF-CRLF plus a lenient LF-LF fallback.
DELIMITER_CRLF = b'\r\n\r\n'
DELIMITER_LF = b'\n\n'
HOST_ALL = '0.0.0.0'
CONTENT_TYPE_TEXT = "text/plain"
CONNECTION_CLOSE = "close"
UPGRADE_WEBSOCKET = "websocket"
AUTH_BASIC = "Basic"

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')


class RateLimiter:
    """Per-client token-bucket rate limiter.

    Each client IP owns a bucket holding up to ``capacity`` tokens that
    refills continuously at ``rate`` tokens per second.  Every allowed
    request costs TOKEN_COST tokens, so ``capacity`` is the burst size and
    ``rate`` the sustained requests-per-second limit.
    """

    def __init__(self, rate: float, capacity: int):
        self.rate = float(rate)
        self.capacity = float(capacity)
        # Buckets start full so a brand-new client gets its full burst
        # allowance.  The lambda closes over self.capacity (not the ctor
        # argument) so later tuning of capacity affects new buckets too.
        self.buckets: Dict[str, Dict[str, float]] = defaultdict(
            lambda: {'tokens': self.capacity, 'last_update': time.time()}
        )

    def is_allowed(self, client_ip: str) -> bool:
        """Consume one token for *client_ip*; return True if within limit."""
        now = time.time()
        bucket = self.buckets[client_ip]

        # Lazy refill: credit tokens for the time elapsed since the last
        # request instead of running a background timer per client.
        time_passed = now - bucket['last_update']
        bucket['tokens'] = min(self.capacity,
                               bucket['tokens'] + time_passed * self.rate)
        bucket['last_update'] = now

        if bucket['tokens'] >= TOKEN_COST:
            bucket['tokens'] -= TOKEN_COST
            return True
        return False

    def cleanup_old_entries(self, max_age: float = 3600.0) -> None:
        """Drop buckets idle for more than *max_age* seconds.

        Without periodic cleanup the per-IP dict grows without bound on a
        long-running proxy — one entry per distinct client ever seen.
        """
        now = time.time()
        stale = [ip for ip, bucket in self.buckets.items()
                 if now - bucket['last_update'] > max_age]
        for ip in stale:
            del self.buckets[ip]
class ProxyConfig:
    """Proxy configuration loaded from a JSON file.

    Expected keys: ``port`` (int), ``reverse_proxy`` (list of route dicts
    with at least ``hostname``/``upstream_host``/``upstream_port``), and an
    optional ``rate_limit`` section (``enabled``, ``requests_per_second``,
    ``burst_size``).
    """

    def __init__(self, config_path: str):
        with open(config_path, 'r') as f:
            self.config = json.load(f)

        self.port = self.config.get('port', 8585)
        # Routing table keyed by lowercased hostname (case-insensitive match).
        self.routes = {route['hostname'].lower(): route
                       for route in self.config.get('reverse_proxy', [])}

        rl = self.config.get('rate_limit', {})
        self.rate_limit_enabled = rl.get('enabled', True)
        self.rate_limiter = RateLimiter(rl.get('requests_per_second', 2),
                                        rl.get('burst_size', 5))


class HTTPProxyProtocol(asyncio.Protocol):
    """Client-side protocol: parses request heads, applies rate limiting,
    rewrites headers per the matched route and pipes bytes upstream."""

    # Reject request heads larger than this instead of buffering forever
    # (guards against slow-loris / oversized-header attacks).
    MAX_HEADER_SIZE = 16 * 1024

    def __init__(self, config: ProxyConfig):
        self.config = config
        self.transport = None
        self.upstream_transport = None
        self.buffer = b''
        self.headers_parsed = False
        self.is_websocket = False
        self.route = None
        self.peername = None  # filled in by connection_made

    def connection_made(self, transport):
        self.transport = transport
        self.peername = transport.get_extra_info('peername')

    def data_received(self, data: bytes):
        # After a WebSocket upgrade the stream is opaque: tunnel verbatim.
        if self.is_websocket:
            if self.upstream_transport:
                self.upstream_transport.write(data)
            return

        self.buffer += data

        # Parse request heads (possibly several pipelined ones) until the
        # buffer no longer starts with a complete head.
        while not self.is_websocket and not self.headers_parsed and self.buffer:
            if b'\r\n\r\n' in self.buffer:
                delimiter = b'\r\n\r\n'
            elif b'\n\n' in self.buffer:
                delimiter = b'\n\n'  # lenient fallback for LF-only clients
            else:
                # Incomplete head: cap how much we are willing to buffer.
                if len(self.buffer) > self.MAX_HEADER_SIZE:
                    self._send_error(431, "Request Header Fields Too Large")
                break
            if not self._process_headers(delimiter):
                break

        # Forward body bytes that followed the parsed head.  Clear the
        # buffer only once written; while the upstream connection is still
        # being established, _connect() flushes the buffer itself.
        if self.headers_parsed and not self.is_websocket and self.buffer:
            if self.upstream_transport:
                self.upstream_transport.write(self.buffer)
                self.buffer = b''

    def _process_headers(self, delimiter: bytes) -> bool:
        """Parse one request head from the buffer and start proxying it.

        Returns True when the head was consumed and dispatched, False when
        processing must stop (an error response was sent, or parsing failed).
        """
        try:
            idx = self.buffer.index(delimiter) + len(delimiter)
            head = self.buffer[:idx]
            self.buffer = self.buffer[idx:]

            lines = head.splitlines()
            if not lines or not lines[0]:
                self._send_error(400, "Bad Request")
                return False

            headers = {}
            for line in lines[1:]:
                if b':' in line:
                    key, value = line.split(b':', 1)
                    headers[key.decode(errors='ignore').strip().lower()] = \
                        value.decode(errors='ignore').strip()

            # Client IP used for rate limiting.  NOTE(review):
            # X-Forwarded-For is client-supplied and trivially spoofable —
            # only meaningful when this proxy sits behind a trusted LB.
            ip = headers.get('x-forwarded-for', '').split(',')[0].strip()
            if not ip:
                ip = self.peername[0] if self.peername else 'unknown'

            if self.config.rate_limit_enabled:
                if not self.config.rate_limiter.is_allowed(ip):
                    logging.info(f"Rate Limiter BLOCKING: {ip}")
                    self._send_error(429, "Too Many Requests")
                    return False

            host = headers.get('host', '').split(':')[0].lower()
            if not host:
                # A request without Host cannot be routed: reject explicitly
                # instead of looking up '' and reporting 404.
                self._send_error(400, "Bad Request")
                return False

            self.route = self.config.routes.get(host)
            if not self.route:
                self._send_error(404, "Not Found")
                return False

            if headers.get('upgrade', '').lower() == 'websocket':
                self.is_websocket = True

            self.headers_parsed = True
            mod_headers = self._rewrite(headers,
                                        lines[0].decode(errors='ignore'))

            if self.upstream_transport:
                # Keep-alive: reuse the already-open upstream connection.
                self.upstream_transport.write(mod_headers)
            else:
                asyncio.create_task(self._connect(mod_headers))
            return True
        except Exception as e:
            logging.error(f"Parser error: {e}")
            # Tell the client instead of leaving the connection hanging.
            self._send_error(400, "Bad Request")
            return False

    def _rewrite(self, headers: dict, req_line: str) -> bytes:
        """Re-serialize the request head, optionally rewriting Host and
        injecting Basic-auth credentials per the matched route.

        Duplicate request headers collapse to the last value because
        *headers* is a dict of the parsed lines.
        """
        out = [req_line.encode()]
        for key, value in headers.items():
            if key == 'host' and self.route.get('rewrite_host'):
                value = f"{self.route['upstream_host']}:{self.route['upstream_port']}"
            out.append(f"{key.title()}: {value}".encode())
        if self.route.get('use_auth'):
            # .get with defaults: a route may enable auth without both
            # fields present; don't crash with KeyError mid-request.
            username = self.route.get('username', '')
            password = self.route.get('password', '')
            creds = base64.b64encode(f"{username}:{password}".encode()).decode()
            out.append(f"Authorization: Basic {creds}".encode())
        return b'\r\n'.join(out) + b'\r\n\r\n'

    async def _connect(self, mod_headers: bytes):
        """Open the upstream connection, send the rewritten head, then
        flush any body bytes buffered while connecting."""
        try:
            ctx = None
            if self.route.get('use_ssl'):
                # NOTE(review): certificate verification is deliberately
                # disabled (internal upstreams) — unsafe for untrusted hosts.
                ctx = ssl.create_default_context()
                ctx.check_hostname = False
                ctx.verify_mode = ssl.CERT_NONE

            loop = asyncio.get_running_loop()
            self.upstream_transport, _ = await loop.create_connection(
                lambda: UpstreamProtocol(self),
                self.route['upstream_host'],
                self.route['upstream_port'],
                ssl=ctx,
            )
            self.upstream_transport.write(mod_headers)
            if self.buffer:
                self.upstream_transport.write(self.buffer)
                self.buffer = b''
        except Exception as e:
            # Log the cause; a bare 502 alone makes outages hard to debug.
            logging.error(f"Error connecting to upstream: {e}")
            self._send_error(502, "Bad Gateway")

    def _send_error(self, code: int, msg: str):
        """Send a minimal plain-text HTTP error response and close."""
        if self.transport:
            response = (
                f"HTTP/1.1 {code} {msg}\r\n"
                f"Content-Type: text/plain\r\n"
                f"Content-Length: {len(msg)}\r\n"
                f"Connection: close\r\n\r\n{msg}"
            ).encode()
            self.transport.write(response)
            self.transport.close()

    def connection_lost(self, exc):
        # Tear down the paired upstream connection with the client side.
        if self.upstream_transport:
            self.upstream_transport.close()


class UpstreamProtocol(asyncio.Protocol):
    """Upstream-side protocol: relays server bytes back to the client."""

    def __init__(self, proxy: HTTPProxyProtocol):
        self.proxy = proxy
        self.transport = None

    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data: bytes):
        if self.proxy.transport:
            self.proxy.transport.write(data)
        # Keep-alive heuristic: a blank line in the response suggests a
        # response head completed, so re-arm request parsing on the client
        # side.  NOTE(review): this misfires when the delimiter appears
        # inside a response body — confirm acceptable for the upstreams used.
        if not self.proxy.is_websocket and (b'\r\n\r\n' in data or b'\n\n' in data):
            self.proxy.headers_parsed = False

    def connection_lost(self, exc):
        if self.proxy.transport:
            self.proxy.transport.close()


async def _cleanup_loop(config: ProxyConfig, interval: float = 600.0):
    """Periodically prune idle rate-limiter buckets, when supported."""
    cleanup = getattr(config.rate_limiter, 'cleanup_old_entries', None)
    while cleanup is not None:
        await asyncio.sleep(interval)
        cleanup()


async def main(path: str):
    """Load the config at *path* and serve the proxy forever."""
    cfg = ProxyConfig(path)
    loop = asyncio.get_running_loop()
    server = await loop.create_server(
        lambda: HTTPProxyProtocol(cfg), '0.0.0.0', cfg.port)
    logging.info(f"Proxy running on {cfg.port}")
    asyncio.create_task(_cleanup_loop(cfg))
    async with server:
        await server.serve_forever()


if __name__ == '__main__':
    if len(sys.argv) > 1:
        try:
            asyncio.run(main(sys.argv[1]))
        except KeyboardInterrupt:
            pass  # clean shutdown on Ctrl-C
    else:
        print("usage: proxy_with_ratelimiter.py <config.json>",
              file=sys.stderr)
        sys.exit(1)