#!/usr/bin/env python3
"""Host-based HTTP/1.1 reverse proxy with per-client token-bucket rate limiting.

Usage: <this script> CONFIG_PATH

The JSON config file supports:
  port                    listen port (default 8585)
  trust_x_forwarded_for   use the first X-Forwarded-For entry as the
                          rate-limit key instead of the peer address
                          (default true; the header is client-controlled,
                          so disable this when clients connect directly)
  rate_limit              {"enabled", "requests_per_second", "burst_size"}
  reverse_proxy           list of routes: {"hostname", "upstream_host",
                          "upstream_port", "rewrite_host", "use_auth",
                          "username", "password", "use_ssl", "verify_ssl"}
"""
import asyncio
import base64
from collections import defaultdict
from http import HTTPStatus
import json
import logging
import ssl
import sys
import time
from typing import Dict, List, Optional

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')


class RateLimiter:
    """Token-bucket rate limiter keyed by a client identifier string.

    Each key gets a bucket holding at most ``capacity`` tokens that refills
    at ``rate`` tokens per second; every allowed request consumes one token.
    """

    def __init__(self, rate: float, capacity: int):
        self.rate = rate
        self.capacity = capacity
        # Timestamps come from time.monotonic() so wall-clock adjustments
        # cannot drain or overfill buckets.
        self.buckets: Dict[str, Dict[str, float]] = defaultdict(
            lambda: {'tokens': self.capacity, 'last_update': time.monotonic()}
        )

    def is_allowed(self, client_ip: str) -> bool:
        """Consume one token for ``client_ip``; return False if none is left."""
        now = time.monotonic()
        bucket = self.buckets[client_ip]
        elapsed = now - bucket['last_update']
        bucket['tokens'] = min(self.capacity,
                               bucket['tokens'] + elapsed * self.rate)
        bucket['last_update'] = now
        if bucket['tokens'] >= 1.0:
            bucket['tokens'] -= 1.0
            return True
        return False

    def cleanup_old_entries(self, max_age: float = 3600):
        """Drop buckets that have not been touched for ``max_age`` seconds."""
        now = time.monotonic()
        stale = [ip for ip, bucket in self.buckets.items()
                 if now - bucket['last_update'] > max_age]
        for ip in stale:
            del self.buckets[ip]


class ProxyConfig:
    """Proxy settings loaded from a JSON file plus the derived routing table."""

    def __init__(self, config_path: str):
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        self.port = self.config.get('port', 8585)
        self.routes = self._build_routing_table()

        # X-Forwarded-For is supplied by the client, so trusting it lets a
        # client choose its own rate-limit key. Default kept for compatibility
        # with deployments behind a load balancer.
        self.trust_x_forwarded_for = bool(
            self.config.get('trust_x_forwarded_for', True))

        rate_limit_config = self.config.get('rate_limit', {})
        self.rate_limit_enabled = rate_limit_config.get('enabled', True)
        requests_per_second = rate_limit_config.get('requests_per_second', 10)
        burst_size = rate_limit_config.get('burst_size', 20)
        self.rate_limiter = RateLimiter(requests_per_second, burst_size)

    def _build_routing_table(self) -> Dict[str, dict]:
        """Index the ``reverse_proxy`` route entries by lower-cased hostname."""
        return {route['hostname'].lower(): route
                for route in self.config.get('reverse_proxy', [])}

    def get_route(self, hostname: str) -> Optional[dict]:
        """Return the route for ``hostname`` (case-insensitive), or None."""
        return self.routes.get(hostname.lower())


class HTTPProxyProtocol(asyncio.Protocol):
    """Client-facing side of one proxied connection.

    Parses the request head, applies rate limiting and Host-based routing,
    rewrites the headers and forwards the request to the route's upstream.
    A later request on the same client connection reuses the upstream
    connection only while it maps to the same route.

    NOTE: request boundaries are detected heuristically. After the upstream
    sends any response bytes, the next client bytes are parsed as a new
    request head (UpstreamProtocol resets ``headers_parsed``). A request body
    still being uploaded at that moment would be misparsed.
    """

    MAX_HEADER_SIZE = 16 * 1024

    def __init__(self, config: ProxyConfig):
        self.config = config
        self.transport = None
        self.buffer = b''           # client bytes not yet forwarded upstream
        self.upstream_transport = None
        self.upstream_protocol = None
        self.upstream_route = None  # route served by the current upstream link
        self.headers_parsed = False
        self.route = None
        self.peername = None
        self._connect_task = None   # strong ref so the task is not GC'd
        self._closed = False        # set once the client connection is lost

    def connection_made(self, transport):
        self.transport = transport
        self.peername = transport.get_extra_info('peername')
        logging.info("Connection from %s", self.peername)

    def data_received(self, data: bytes):
        if self.headers_parsed:
            # Remaining bytes of a request whose head was already routed.
            if self.upstream_transport is not None:
                # Forward directly; keeping them in the buffer would corrupt
                # the parse of the next keep-alive request.
                self.upstream_transport.write(data)
            else:
                # Upstream connect still pending; _connect_upstream flushes it.
                self.buffer += data
            return

        self.buffer += data
        delimiter = self._find_header_delimiter()
        if delimiter is None:
            # Only an unterminated head can exceed the limit here; a small
            # head followed by a large body in one read is legitimate.
            if len(self.buffer) > self.MAX_HEADER_SIZE:
                self._send_error(431, "Request Header Fields Too Large")
            return
        self._parse_and_route(delimiter)

    def _find_header_delimiter(self) -> Optional[bytes]:
        """Return whichever blank-line terminator (CRLF or bare LF) comes first."""
        crlf = self.buffer.find(b'\r\n\r\n')
        lf = self.buffer.find(b'\n\n')
        if crlf < 0 and lf < 0:
            return None
        if lf < 0 or 0 <= crlf < lf:
            return b'\r\n\r\n'
        return b'\n\n'

    @staticmethod
    def _split_header_lines(header_data: bytes) -> List[bytes]:
        """Split a request head into lines, accepting CRLF or bare-LF endings."""
        return [line[:-1] if line.endswith(b'\r') else line
                for line in header_data.split(b'\n')]

    @staticmethod
    def _parse_header_fields(lines: List[bytes]) -> Dict[str, str]:
        """Map lower-cased header names to stripped values; later duplicates win."""
        headers = {}
        for line in lines:
            name, sep, value = line.partition(b':')
            if sep:
                headers[name.decode('utf-8', errors='ignore').strip().lower()] = \
                    value.decode('utf-8', errors='ignore').strip()
        return headers

    def _client_ip(self, headers: Dict[str, str]) -> str:
        """Return the rate-limit key: the peer address, or the first
        X-Forwarded-For entry when the config trusts that header."""
        client_ip = self.peername[0] if self.peername else 'unknown'
        if self.config.trust_x_forwarded_for:
            forwarded = headers.get('x-forwarded-for', '').split(',')[0].strip()
            if forwarded:
                client_ip = forwarded
        return client_ip

    def _parse_and_route(self, delimiter: bytes):
        """Parse the buffered request head, then route and forward it."""
        try:
            header_end = self.buffer.index(delimiter) + len(delimiter)
            if header_end > self.MAX_HEADER_SIZE:
                self._send_error(431, "Request Header Fields Too Large")
                return

            lines = self._split_header_lines(self.buffer[:header_end])
            # A bare CR inside a line may be read as a line break upstream,
            # letting a client smuggle headers past this proxy.
            if not lines[0] or any(b'\r' in line for line in lines):
                self._send_error(400, "Bad Request")
                return

            request_line = lines[0].decode('utf-8', errors='ignore')
            headers = self._parse_header_fields(lines[1:])

            if self.config.rate_limit_enabled:
                client_ip = self._client_ip(headers)
                if not self.config.rate_limiter.is_allowed(client_ip):
                    logging.warning("Rate limit exceeded for %s", client_ip)
                    self._send_error(429, "Too Many Requests")
                    return

            host = headers.get('host', '').split(':')[0]
            if not host:
                self._send_error(400, "Bad Request: Missing Host header")
                return

            self.route = self.config.get_route(host)
            if not self.route:
                self._send_error(404, f"Not Found: No route for {host}")
                return

            modified_headers = self._rewrite_headers(headers, request_line)
            self.headers_parsed = True

            if self._can_reuse_upstream():
                self.upstream_transport.write(modified_headers)
                if len(self.buffer) > header_end:
                    self.upstream_transport.write(self.buffer[header_end:])
                self.buffer = b''
            else:
                # Never send a request for one route down another route's
                # connection: it would carry this route's credentials.
                self._drop_upstream()
                self._connect_task = asyncio.create_task(
                    self._connect_upstream(modified_headers, header_end))
        except Exception as e:
            logging.exception("Error parsing request: %s", e)
            self._send_error(500, "Internal Server Error")

    def _can_reuse_upstream(self) -> bool:
        """True if an open upstream connection exists for ``self.route``."""
        return (self.upstream_transport is not None
                and not self.upstream_transport.is_closing()
                and self.upstream_route is self.route)

    def _drop_upstream(self):
        """Detach and close the current upstream connection, if any.

        The upstream protocol is detached first so its connection_lost does
        not close the client connection we are keeping.
        """
        if self.upstream_protocol is not None:
            self.upstream_protocol.proxy_protocol = None
        if self.upstream_transport is not None:
            self.upstream_transport.close()
        self.upstream_transport = None
        self.upstream_protocol = None
        self.upstream_route = None

    def _rewrite_headers(self, headers: dict, request_line: str) -> bytes:
        """Rebuild the request head for the upstream of ``self.route``.

        Optionally rewrites Host to the upstream address and injects Basic
        credentials. When credentials are injected, any client-supplied
        Authorization header is dropped so it cannot override them.
        """
        route = self.route
        inject_auth = bool(route.get('use_auth'))
        lines = [request_line.encode('utf-8')]
        for key, value in headers.items():
            if inject_auth and key == 'authorization':
                continue
            if key == 'host' and route.get('rewrite_host'):
                value = f"{route['upstream_host']}:{route['upstream_port']}"
            lines.append(f"{key.title()}: {value}".encode('utf-8'))

        if inject_auth:
            username = route.get('username', '')
            password = route.get('password', '')
            credentials = base64.b64encode(
                f"{username}:{password}".encode()).decode()
            lines.append(f"Authorization: Basic {credentials}".encode('utf-8'))

        # Two empty entries produce the terminating blank line.
        lines.append(b'')
        lines.append(b'')
        return b'\r\n'.join(lines)

    async def _connect_upstream(self, modified_headers: bytes, header_end: int):
        """Connect to the upstream of the current route and flush the request.

        Sends 502 to the client if the connection cannot be established.
        """
        route = self.route
        try:
            ssl_context = None
            if route.get('use_ssl', False):
                ssl_context = ssl.create_default_context()
                # Certificate checks stay off unless the route opts in with
                # "verify_ssl" (the previous behavior disabled them always).
                if not route.get('verify_ssl', False):
                    ssl_context.check_hostname = False
                    ssl_context.verify_mode = ssl.CERT_NONE
            loop = asyncio.get_running_loop()
            transport, protocol = await loop.create_connection(
                lambda: UpstreamProtocol(self),
                route['upstream_host'], route['upstream_port'],
                ssl=ssl_context,
            )
        except Exception as e:
            logging.error("Error connecting to upstream: %s", e)
            self._send_error(502, "Bad Gateway")
            return

        if self._closed or self.transport is None or self.transport.is_closing():
            # The client went away while we were connecting; don't leak it.
            protocol.proxy_protocol = None
            transport.close()
            return

        self.upstream_transport = transport
        self.upstream_protocol = protocol
        self.upstream_route = route
        transport.write(modified_headers)
        if len(self.buffer) > header_end:
            transport.write(self.buffer[header_end:])
        self.buffer = b''

    def _send_error(self, code: int, message: str):
        """Send a plain-text error response and close the client connection.

        ``message`` goes only in the body. The status line uses the standard
        reason phrase, so request-derived text (e.g. the Host value) never
        reaches the status line. Content-Length counts encoded bytes.
        """
        if self.transport is None or self.transport.is_closing():
            return
        try:
            reason = HTTPStatus(code).phrase
        except ValueError:
            reason = "Error"
        body = message.encode('utf-8')
        head = (
            f"HTTP/1.1 {code} {reason}\r\n"
            "Content-Type: text/plain; charset=utf-8\r\n"
            f"Content-Length: {len(body)}\r\n"
            "Connection: close\r\n\r\n"
        ).encode('ascii')
        self.transport.write(head + body)
        self.transport.close()

    def connection_lost(self, exc):
        self._closed = True
        if self.upstream_transport is not None:
            self.upstream_transport.close()


class UpstreamProtocol(asyncio.Protocol):
    """Upstream side of a proxied connection; relays responses to the client.

    ``proxy_protocol`` is set to None when the owning client connection
    detaches this upstream, after which events are ignored.
    """

    def __init__(self, proxy_protocol: HTTPProxyProtocol):
        self.proxy_protocol = proxy_protocol
        self.transport = None

    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data: bytes):
        proxy = self.proxy_protocol
        if proxy is None:
            return
        if proxy.transport:
            proxy.transport.write(data)
        # Keep-alive heuristic: once a response arrives, treat the next
        # client bytes as a new request head.
        proxy.headers_parsed = False

    def connection_lost(self, exc):
        proxy = self.proxy_protocol
        if proxy is not None and proxy.transport:
            proxy.transport.close()


async def cleanup_task(config: ProxyConfig):
    """Every 600 s, prune rate-limit buckets idle for over an hour."""
    while True:
        await asyncio.sleep(600)
        config.rate_limiter.cleanup_old_entries()


async def main(config_path: str):
    """Load the config, start the proxy server and serve until cancelled."""
    config = ProxyConfig(config_path)
    loop = asyncio.get_running_loop()
    server = await loop.create_server(
        lambda: HTTPProxyProtocol(config), '0.0.0.0', config.port)
    logging.info("Reverse proxy (Keep-Alive enabled) listening on 0.0.0.0:%s",
                 config.port)
    # Keep a reference so the background task is not garbage-collected.
    cleanup = asyncio.create_task(cleanup_task(config))
    try:
        async with server:
            await server.serve_forever()
    finally:
        cleanup.cancel()


if __name__ == '__main__':
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} CONFIG_PATH", file=sys.stderr)
        sys.exit(2)
    try:
        asyncio.run(main(sys.argv[1]))
    except KeyboardInterrupt:
        pass