#!/usr/bin/env python3
"""Minimal asyncio HTTP/WebSocket reverse proxy with per-client rate limiting.

Usage: ``proxy.py CONFIG_JSON``. Requests are routed to an upstream by their
``Host`` header, using the ``reverse_proxy`` routes in the JSON config.
"""
import asyncio
import base64
from collections import defaultdict
import json
import logging
import ssl
import sys
import time
from typing import Callable, Dict, List, Optional, Tuple

DEFAULT_PORT = 8585
DEFAULT_REQUESTS_PER_SECOND = 2
DEFAULT_BURST_SIZE = 5
DEFAULT_MAX_BUCKETS = 10000
TOKEN_COST = 1.0
HTTP_TOO_MANY_REQUESTS = 429
HTTP_NOT_FOUND = 404
HTTP_MISDIRECTED_REQUEST = 421
HTTP_INTERNAL_ERROR = 500
HTTP_BAD_GATEWAY = 502
MSG_TOO_MANY_REQUESTS = "Too Many Requests"
MSG_NOT_FOUND = "Not Found"
MSG_MISDIRECTED_REQUEST = "Misdirected Request"
MSG_INTERNAL_ERROR = "Internal Server Error"
MSG_BAD_GATEWAY = "Bad Gateway"
DELIMITER_CRLF = b'\r\n\r\n'
DELIMITER_LF = b'\n\n'
HOST_ALL = '0.0.0.0'
CONTENT_TYPE_TEXT = "text/plain"
CONNECTION_CLOSE = "close"
UPGRADE_WEBSOCKET = "websocket"
AUTH_BASIC = "Basic"

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def _find_header_delimiter(buffer: bytes) -> Optional[bytes]:
    """Return the header terminator (CRLF or bare-LF form) that occurs first in *buffer*.

    Returns None when no complete message head has arrived yet.
    """
    best: Optional[bytes] = None
    best_pos = -1
    for delimiter in (DELIMITER_CRLF, DELIMITER_LF):
        pos = buffer.find(delimiter)
        # Earliest terminator wins; preferring CRLF anywhere in the buffer would
        # swallow body bytes into the head of an LF-terminated request.
        if pos != -1 and (best is None or pos < best_pos):
            best, best_pos = delimiter, pos
    return best


def _parse_header_lines(lines: List[bytes]) -> List[Tuple[str, str]]:
    """Parse raw ``Name: value`` lines into ordered (name, value) pairs.

    Duplicate names are kept and name case is preserved. Lines without a colon,
    or with an empty name, are ignored. Decoding uses latin-1, which maps every
    byte to one code point, so it never fails and re-encoding is lossless.
    """
    pairs = []
    for raw in lines:
        name, sep, value = raw.partition(b':')
        name_text = name.decode('latin-1').strip()
        if sep and name_text:
            pairs.append((name_text, value.decode('latin-1').strip()))
    return pairs


def _is_interim_response(data: bytes) -> bool:
    """Return True if *data* is exactly one 1xx (interim) response head, e.g. 100 Continue."""
    parts = data.split(b' ', 2)
    if len(parts) < 2 or not parts[0].startswith(b'HTTP/') or not parts[1].startswith(b'1'):
        return False
    delimiter = _find_header_delimiter(data)
    return delimiter is not None and data.find(delimiter) + len(delimiter) == len(data)


class RateLimiter:
    """Token-bucket rate limiter keyed by a client identifier (normally an IP).

    Each key owns a bucket that starts full with ``capacity`` tokens. It refills
    continuously at ``rate`` tokens per second, never exceeding ``capacity``.
    Every allowed request spends ``TOKEN_COST`` tokens.

    Args:
        rate: Refill rate in tokens per second.
        capacity: Maximum tokens a bucket can hold (the burst size).
        clock: Time source in seconds; defaults to ``time.monotonic`` so wall
            clock adjustments cannot affect refills. Injectable for tests.
        max_buckets: Bucket count above which fully refilled buckets are pruned.
    """

    def __init__(self, rate: float, capacity: int, *,
                 clock: Callable[[], float] = time.monotonic,
                 max_buckets: int = DEFAULT_MAX_BUCKETS):
        self.rate = float(rate)
        self.capacity = float(capacity)
        self.max_buckets = max_buckets
        self._clock = clock
        self._prune_at = max_buckets
        self.buckets: Dict[str, Dict[str, float]] = defaultdict(
            lambda: {'tokens': self.capacity, 'last_update': self._clock()}
        )

    def _refilled(self, bucket: Dict[str, float], now: float) -> float:
        """Return the bucket's token level at ``now`` without mutating it."""
        # Clamp so a clock that steps backwards can never drain tokens.
        elapsed = max(0.0, now - bucket['last_update'])
        return min(self.capacity, bucket['tokens'] + elapsed * self.rate)

    def is_allowed(self, client_ip: str) -> bool:
        """Refill the client's bucket, then spend one request's tokens if available.

        Returns True if the request is allowed, False if it should be throttled.
        """
        now = self._clock()
        bucket = self.buckets[client_ip]
        bucket['tokens'] = self._refilled(bucket, now)
        bucket['last_update'] = now
        allowed = bucket['tokens'] >= TOKEN_COST
        if allowed:
            bucket['tokens'] -= TOKEN_COST
        if len(self.buckets) > self._prune_at:
            self._prune(now)
        return allowed

    def _prune(self, now: float) -> None:
        """Drop buckets that have refilled to capacity.

        A full bucket is indistinguishable from the fresh one the defaultdict
        would create, so removing it cannot change any future decision.
        """
        idle = [key for key, bucket in self.buckets.items()
                if self._refilled(bucket, now) >= self.capacity]
        for key in idle:
            del self.buckets[key]
        # Back off geometrically so a large, fully active population does not
        # trigger an O(n) scan on every request.
        self._prune_at = max(self.max_buckets, 2 * len(self.buckets))


class ProxyConfig:
    """Proxy settings loaded from a JSON file.

    Keys read: ``port``, ``reverse_proxy`` (a list of route dicts, each with a
    ``hostname``), and ``rate_limit`` (``enabled``, ``requests_per_second``,
    ``burst_size``, ``trust_x_forwarded_for``).
    """

    def __init__(self, config_path: str):
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)
        self.port = self.config.get('port', DEFAULT_PORT)
        self.routes = {r['hostname'].lower(): r for r in self.config.get('reverse_proxy', [])}
        rl = self.config.get('rate_limit', {})
        self.rate_limit_enabled = rl.get('enabled', True)
        # Defaults to True for backward compatibility. Set it to false when clients
        # connect directly: otherwise any client can evade the limiter (and create
        # arbitrary buckets) by sending a fabricated X-Forwarded-For header.
        self.trust_forwarded_for = rl.get('trust_x_forwarded_for', True)
        self.rate_limiter = RateLimiter(
            rl.get('requests_per_second', DEFAULT_REQUESTS_PER_SECOND),
            rl.get('burst_size', DEFAULT_BURST_SIZE),
        )


class HTTPProxyProtocol(asyncio.Protocol):
    """Client-facing side of the proxy.

    Parses each request head, applies rate limiting and Host-based routing,
    rewrites headers, and relays bytes to the upstream connection. After a
    WebSocket upgrade it relays raw bytes in both directions.
    """

    def __init__(self, config: ProxyConfig):
        self.config = config
        self.transport: Optional[asyncio.BaseTransport] = None
        self.upstream_transport: Optional[asyncio.BaseTransport] = None
        self.peername = None
        self.buffer = b''
        self.headers_parsed = False
        self.is_websocket = False
        self.route: Optional[dict] = None
        # Route the upstream connection was opened for. Later requests on the
        # same client connection must target the same route.
        self._upstream_route: Optional[dict] = None
        # The event loop keeps only weak references to tasks; hold a strong one.
        self._connect_task: Optional[asyncio.Task] = None
        self._closed = False

    def connection_made(self, transport):
        self.transport = transport
        self.peername = transport.get_extra_info('peername')

    def data_received(self, data: bytes):
        if self.headers_parsed and self.upstream_transport is not None and not self.buffer:
            # Fast path: the head is already forwarded (or this is a WebSocket), so relay as-is.
            self.upstream_transport.write(data)
            return
        # Otherwise buffer; this also holds body/WebSocket bytes while the upstream is connecting.
        self.buffer += data
        while not self.headers_parsed:
            # RFC 9112 §2.2: ignore empty lines received before a request-line.
            self.buffer = self.buffer.lstrip(b'\r\n')
            delimiter = _find_header_delimiter(self.buffer)
            if delimiter is None or not self._process_headers(delimiter):
                break
        if self.headers_parsed:
            self._flush_buffer()

    def _flush_buffer(self):
        """Send any buffered bytes upstream once an upstream connection exists."""
        if self.upstream_transport is not None and self.buffer:
            self.upstream_transport.write(self.buffer)
            self.buffer = b''

    def _client_ip(self, headers: Dict[str, str]) -> str:
        """Return the rate-limit key: the first X-Forwarded-For hop if trusted, else the peer IP."""
        forwarded = headers.get('x-forwarded-for', '') if self.config.trust_forwarded_for else ''
        ip = forwarded.split(',')[0].strip()
        if ip:
            return ip
        return self.peername[0] if self.peername else ''

    def _process_headers(self, delimiter: bytes) -> bool:
        """Consume one request head ending in *delimiter* from the buffer and dispatch it.

        Returns True when the request was forwarded or an upstream connection
        was started. Returns False when it was rejected with an error response
        or could not be handled.
        """
        try:
            end = self.buffer.index(delimiter) + len(delimiter)
            head, self.buffer = self.buffer[:end], self.buffer[end:]
            lines = head.splitlines()
            if not lines:
                return False
            # latin-1 never fails to decode, and _rewrite re-encodes it byte-for-byte.
            request_line = lines[0].decode('latin-1')
            header_pairs = _parse_header_lines(lines[1:])
            headers = {name.lower(): value for name, value in header_pairs}

            ip = self._client_ip(headers)
            if self.config.rate_limit_enabled and not self.config.rate_limiter.is_allowed(ip):
                logging.info("Rate Limiter BLOCKING: %s", ip)
                self._send_error(HTTP_TOO_MANY_REQUESTS, MSG_TOO_MANY_REQUESTS)
                return False

            host = headers.get('host', '').split(':')[0].lower()
            route = self.config.routes.get(host)
            if not route:
                self._send_error(HTTP_NOT_FOUND, MSG_NOT_FOUND)
                return False
            if self._upstream_route is not None and route is not self._upstream_route:
                # The open upstream belongs to a different route. Forwarding would
                # deliver this request, and this route's credentials, to the wrong backend.
                self._send_error(HTTP_MISDIRECTED_REQUEST, MSG_MISDIRECTED_REQUEST)
                return False

            self.route = route
            mod_headers = self._rewrite(header_pairs, request_line)
            self.is_websocket = headers.get('upgrade', '').lower() == UPGRADE_WEBSOCKET
            self.headers_parsed = True
            if self.upstream_transport is not None:
                self.upstream_transport.write(mod_headers)
            else:
                self._upstream_route = route
                self._connect_task = asyncio.get_running_loop().create_task(self._connect(mod_headers))
            return True
        except Exception as e:
            # Top-level protocol boundary: log and answer rather than leave the client hanging.
            logging.exception("Parser error: %s", e)
            self._send_error(HTTP_INTERNAL_ERROR, MSG_INTERNAL_ERROR)
            return False

    def _rewrite(self, headers, req_line: str) -> bytes:
        """Build the request head sent upstream for ``self.route``.

        Args:
            headers: Ordered (name, value) pairs, or a mapping of name to value.
            req_line: The original request line.

        The Host header is replaced with ``upstream_host:upstream_port`` when
        the route sets ``rewrite_host``. When the route sets ``use_auth``, any
        client Authorization header is dropped and replaced with the route's
        Basic credentials.
        """
        pairs = headers.items() if hasattr(headers, 'items') else headers
        route = self.route
        inject_auth = bool(route.get('use_auth'))
        out = [req_line]
        for name, value in pairs:
            key = name.lower()
            if inject_auth and key == 'authorization':
                continue  # superseded by the route's credentials below
            if key == 'host' and route.get('rewrite_host'):
                value = f"{route['upstream_host']}:{route['upstream_port']}"
            out.append(f"{name}: {value}")
        if inject_auth:
            creds = base64.b64encode(f"{route['username']}:{route['password']}".encode()).decode()
            out.append(f"Authorization: {AUTH_BASIC} {creds}")
        return ('\r\n'.join(out) + '\r\n\r\n').encode('latin-1')

    async def _connect(self, mod_headers: bytes):
        """Open the upstream connection for ``self.route`` and send the first request head."""
        route = self.route
        try:
            ctx = None
            if route.get('use_ssl'):
                ctx = ssl.create_default_context()
                # NOTE(security): upstream certificates are NOT verified unless the
                # route sets "verify_ssl": true. Off by default for compatibility.
                if not route.get('verify_ssl', False):
                    ctx.check_hostname = False
                    ctx.verify_mode = ssl.CERT_NONE
            transport, _ = await asyncio.get_running_loop().create_connection(
                lambda: UpstreamProtocol(self), route['upstream_host'], route['upstream_port'], ssl=ctx
            )
        except Exception as e:
            logging.error("Upstream connection to %s failed: %s", route.get('upstream_host'), e)
            self._send_error(HTTP_BAD_GATEWAY, MSG_BAD_GATEWAY)
            return
        if self._closed:
            # The client went away while we were connecting; don't leak the upstream socket.
            transport.close()
            return
        self.upstream_transport = transport
        transport.write(mod_headers)
        self._flush_buffer()

    def _send_error(self, code: int, msg: str):
        """Write a plain-text error response to the client and close the connection."""
        if self.transport and not self.transport.is_closing():
            self.transport.write(f"HTTP/1.1 {code} {msg}\r\nContent-Type: {CONTENT_TYPE_TEXT}\r\nConnection: {CONNECTION_CLOSE}\r\n\r\n{msg}".encode())
            self.transport.close()

    def connection_lost(self, exc):
        self._closed = True
        if self.upstream_transport:
            self.upstream_transport.close()


class UpstreamProtocol(asyncio.Protocol):
    """Upstream-facing side: relays response bytes back to the owning client protocol."""

    def __init__(self, proxy: HTTPProxyProtocol):
        self.proxy = proxy
        self.transport = None

    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data: bytes):
        proxy = self.proxy
        if proxy.transport:
            proxy.transport.write(data)
        if proxy.is_websocket:
            return
        # A complete response head means the client may send its next request.
        # Interim 1xx responses (e.g. 100 Continue) are excluded: after one, the
        # client is still sending the body of the current request.
        if _find_header_delimiter(data) is not None and not _is_interim_response(data):
            proxy.headers_parsed = False

    def connection_lost(self, exc):
        if self.proxy.transport:
            self.proxy.transport.close()


async def main(path: str):
    """Load the config at *path* and serve the proxy until cancelled."""
    cfg = ProxyConfig(path)
    loop = asyncio.get_running_loop()
    server = await loop.create_server(lambda: HTTPProxyProtocol(cfg), HOST_ALL, cfg.port)
    logging.info("Proxy running on %s", cfg.port)
    async with server:
        await server.serve_forever()


if __name__ == '__main__':
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} CONFIG_JSON")
    asyncio.run(main(sys.argv[1]))