207 lines
7.4 KiB
Python
Raw Normal View History

2026-01-27 20:06:10 +01:00
#!/usr/bin/env python3
import asyncio
import base64
from collections import defaultdict
import json
import logging
import ssl
import sys
import time
from typing import Dict, Optional
2026-01-27 22:10:48 +01:00
# --- Listener defaults ------------------------------------------------------
DEFAULT_PORT = 8585
HOST_ALL = '0.0.0.0'  # bind address passed to create_server (all IPv4 interfaces)

# --- Rate limiting (token bucket, see RateLimiter) -------------------------
DEFAULT_REQUESTS_PER_SECOND = 2  # refill rate used when the config omits one
DEFAULT_BURST_SIZE = 5           # bucket capacity used when the config omits one
TOKEN_COST = 1.0                 # tokens charged per request

# --- Error responses: (status code, reason phrase) -------------------------
HTTP_TOO_MANY_REQUESTS, MSG_TOO_MANY_REQUESTS = 429, "Too Many Requests"
HTTP_NOT_FOUND, MSG_NOT_FOUND = 404, "Not Found"
HTTP_BAD_GATEWAY, MSG_BAD_GATEWAY = 502, "Bad Gateway"

# --- Header parsing and rewriting ------------------------------------------
DELIMITER_CRLF = b'\r\n\r\n'  # end of a header block with CRLF line endings
DELIMITER_LF = b'\n\n'        # end of a header block with bare-LF line endings
CONTENT_TYPE_TEXT = "text/plain"
CONNECTION_CLOSE = "close"
UPGRADE_WEBSOCKET = "websocket"
AUTH_BASIC = "Basic"


logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
2026-01-27 20:06:10 +01:00
class RateLimiter:
    """Per-client token-bucket rate limiter.

    Each key (the client IP in this proxy) owns a bucket holding at most
    ``capacity`` tokens, refilled continuously at ``rate`` tokens per second.
    A request is allowed when its bucket holds at least ``cost`` tokens, which
    are then deducted.

    Args:
        rate: Refill rate in tokens per second. ``<= 0`` means no refill.
        capacity: Maximum tokens per bucket (the burst size).
        clock: Zero-argument callable returning seconds. Defaults to
            ``time.monotonic``, which cannot jump backwards the way wall-clock
            time can.
    """

    def __init__(self, rate: float, capacity: int, clock=time.monotonic):
        self.rate = float(rate)
        self.capacity = float(capacity)
        self._clock = clock
        # New keys start with a full bucket stamped with the current time.
        self.buckets: Dict[str, Dict[str, float]] = defaultdict(
            lambda: {'tokens': self.capacity, 'last_update': self._clock()}
        )
        # After capacity/rate idle seconds a bucket has refilled completely and
        # is indistinguishable from a fresh one, so it can be dropped. Keys come
        # from client-controlled headers, so without this the dict grows forever.
        self._idle_ttl: Optional[float] = self.capacity / self.rate if self.rate > 0 else None
        self._last_prune = clock()

    def is_allowed(self, client_ip: str, cost: Optional[float] = None) -> bool:
        """Charge ``cost`` tokens (default ``TOKEN_COST``) to ``client_ip``'s bucket.

        Returns:
            True if the bucket had enough tokens (and they were deducted),
            False otherwise (the bucket is left unchanged apart from its refill).
        """
        if cost is None:
            cost = TOKEN_COST
        now = self._clock()
        self._prune(now)

        bucket = self.buckets[client_ip]
        # Clamp so a clock that steps backwards can never drain tokens.
        elapsed = max(0.0, now - bucket['last_update'])
        bucket['tokens'] = min(self.capacity, bucket['tokens'] + elapsed * self.rate)
        bucket['last_update'] = now

        if bucket['tokens'] >= cost:
            bucket['tokens'] -= cost
            return True
        return False

    def _prune(self, now: float) -> None:
        """Drop buckets idle long enough to be full again (at most once per interval)."""
        ttl = self._idle_ttl
        if ttl is None or now - self._last_prune < max(ttl, 1.0):
            return
        self._last_prune = now
        stale = [key for key, bucket in self.buckets.items() if now - bucket['last_update'] >= ttl]
        for key in stale:
            del self.buckets[key]
class ProxyConfig:
    """Proxy settings loaded from a JSON file.

    Top-level keys read here:
      * ``port`` -- listening port (default ``DEFAULT_PORT``).
      * ``reverse_proxy`` -- list of route objects, each with a ``hostname``;
        routes are indexed by lower-cased hostname.
      * ``rate_limit`` -- object with ``enabled`` (default True),
        ``requests_per_second`` and ``burst_size``.

    A section explicitly set to JSON ``null`` is treated as absent.
    """

    def __init__(self, config_path: str):
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default encoding.
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        self.port = self.config.get('port', DEFAULT_PORT)

        routes = self.config.get('reverse_proxy') or []
        self.routes = {route['hostname'].lower(): route for route in routes}

        rl = self.config.get('rate_limit') or {}
        self.rate_limit_enabled = rl.get('enabled', True)
        self.rate_limiter = RateLimiter(
            rl.get('requests_per_second', DEFAULT_REQUESTS_PER_SECOND),
            rl.get('burst_size', DEFAULT_BURST_SIZE),
        )
2026-01-27 20:06:10 +01:00
class HTTPProxyProtocol(asyncio.Protocol):
    """Client-facing side of the reverse proxy (one instance per client connection).

    Buffers client bytes until a complete header block has arrived. It then
    applies the per-client rate limit, selects a route by the ``Host`` header,
    rewrites the headers for that route and sends them upstream, opening the
    upstream connection on first use. Bytes following the header block are
    relayed as-is. After an ``Upgrade: websocket`` request every client byte
    is relayed verbatim.

    Route dicts are read for: ``upstream_host``, ``upstream_port``,
    ``use_ssl``, ``ssl_verify`` (optional, default False), ``rewrite_host``,
    ``use_auth``, ``username``, ``password``.

    NOTE(review): request bodies are not framed (no Content-Length / chunked
    parsing). ``headers_parsed`` is only cleared by ``UpstreamProtocol`` when
    response data arrives, so a pipelined request sent before then is
    forwarded raw, without rate limiting, routing or header rewriting.
    """

    # Upper bound on a header block; beyond it the client gets 431 instead of
    # the proxy buffering without limit.
    MAX_HEADER_BYTES = 64 * 1024
    HTTP_BAD_REQUEST = 400
    MSG_BAD_REQUEST = "Bad Request"
    HTTP_MISDIRECTED_REQUEST = 421
    MSG_MISDIRECTED_REQUEST = "Misdirected Request"
    HTTP_HEADERS_TOO_LARGE = 431
    MSG_HEADERS_TOO_LARGE = "Request Header Fields Too Large"

    def __init__(self, config: ProxyConfig):
        self.config = config
        self.transport = None
        self.peername = None
        self.upstream_transport = None
        self.buffer = b''  # client bytes not yet parsed or not yet sent upstream
        self.headers_parsed = False
        self.is_websocket = False
        self.route = None
        # The event loop keeps only weak references to tasks, so hold the
        # pending upstream-connect task here to keep it from being collected.
        self._connect_task = None

    def connection_made(self, transport):
        self.transport = transport
        self.peername = transport.get_extra_info('peername')

    def data_received(self, data: bytes):
        """Buffer client bytes, parse a header block when complete, relay the rest."""
        if self.is_websocket:
            if self.upstream_transport:
                self._flush_buffer()  # bytes that arrived with the upgrade request go first
                self.upstream_transport.write(data)
            else:
                # Upstream still connecting: keep the bytes; _connect flushes them.
                self.buffer += data
            return

        self.buffer += data

        while not self.is_websocket and not self.headers_parsed and self.buffer:
            delimiter, end = self._find_delimiter()
            header_size = len(self.buffer) if delimiter is None else end
            if header_size > self.MAX_HEADER_BYTES:
                logging.info(f"Header block too large from {self.peername}")
                self._send_error(self.HTTP_HEADERS_TOO_LARGE, self.MSG_HEADERS_TOO_LARGE)
                break
            if delimiter is None:
                break  # header block incomplete; wait for more data
            if not self._process_headers(delimiter):
                break

        # Relay bytes that followed the header block (a request body, or frames
        # sent right after a WebSocket upgrade) once an upstream exists; until
        # then they stay buffered and _connect sends them.
        if self.headers_parsed and self.upstream_transport:
            self._flush_buffer()

    def _flush_buffer(self):
        """Send any buffered client bytes upstream and clear the buffer."""
        if self.buffer:
            self.upstream_transport.write(self.buffer)
            self.buffer = b''

    def _find_delimiter(self):
        """Return ``(delimiter, index)`` of the earliest header terminator in the
        buffer, or ``(None, -1)`` when neither terminator is present yet."""
        best, best_pos = None, -1
        for candidate in (DELIMITER_CRLF, DELIMITER_LF):
            pos = self.buffer.find(candidate)
            if pos != -1 and (best is None or pos < best_pos):
                best, best_pos = candidate, pos
        return best, best_pos

    def _process_headers(self, delimiter: bytes) -> bool:
        """Consume one header block ending in *delimiter* from the buffer.

        Returns:
            True if the request was forwarded (or an upstream connect was
            started); False if an error response was sent and the client
            connection closed.
        """
        idx = self.buffer.index(delimiter) + len(delimiter)
        header_chunk = self.buffer[:idx]
        self.buffer = self.buffer[idx:]

        lines = header_chunk.splitlines()
        try:
            request_line = lines[0].decode()
        except (IndexError, UnicodeDecodeError) as e:
            # Malformed request: answer instead of leaving the client waiting.
            logging.error(f"Parser error: {e}")
            self._send_error(self.HTTP_BAD_REQUEST, self.MSG_BAD_REQUEST)
            return False

        headers = {}
        for line in lines[1:]:
            if b':' in line:
                k, v = line.split(b':', 1)
                headers[k.decode(errors='ignore').strip().lower()] = v.decode(errors='ignore').strip()

        try:
            return self._dispatch(headers, request_line)
        except Exception as e:
            # Boundary handler (e.g. a route missing a required key): report it
            # and fail the request rather than silently stalling the client.
            logging.exception(f"Request handling error: {e}")
            self._send_error(HTTP_BAD_GATEWAY, MSG_BAD_GATEWAY)
            return False

    def _dispatch(self, headers: dict, request_line: str) -> bool:
        """Rate-limit and route one parsed request, then forward its rewritten headers."""
        # NOTE(review): X-Forwarded-For is client-controlled unless a trusted
        # proxy sits in front of this one, so clients can choose their own
        # rate-limit key by sending it.
        forwarded = headers.get('x-forwarded-for', '').split(',')[0].strip()
        ip = forwarded or (self.peername[0] if self.peername else '')
        if self.config.rate_limit_enabled and not self.config.rate_limiter.is_allowed(ip):
            logging.info(f"Rate Limiter BLOCKING: {ip}")
            self._send_error(HTTP_TOO_MANY_REQUESTS, MSG_TOO_MANY_REQUESTS)
            return False

        host = headers.get('host', '').split(':')[0].lower()
        route = self.config.routes.get(host)
        if not route:
            self._send_error(HTTP_NOT_FOUND, MSG_NOT_FOUND)
            return False

        # An open upstream belongs to the route of the first request on this
        # connection. Sending a request for another host over it would reach the
        # wrong server (and, with use_auth, carry the other route's credentials).
        if self.upstream_transport is not None and route is not self.route:
            logging.info(f"Misdirected request for {host} on a connection bound to another route")
            self._send_error(self.HTTP_MISDIRECTED_REQUEST, self.MSG_MISDIRECTED_REQUEST)
            return False
        self.route = route

        mod_headers = self._rewrite(headers, request_line)
        if headers.get('upgrade', '').lower() == UPGRADE_WEBSOCKET:
            self.is_websocket = True
        self.headers_parsed = True

        if self.upstream_transport:
            self.upstream_transport.write(mod_headers)
        else:
            self._connect_task = asyncio.create_task(self._connect(mod_headers))
        return True

    def _rewrite(self, headers: dict, req_line: str) -> bytes:
        """Serialise the request line and headers for the current route.

        Header names are emitted title-cased. With ``rewrite_host`` the Host
        header points at the upstream. With ``use_auth`` any client-supplied
        Authorization header is dropped and replaced by HTTP Basic credentials
        from the route, so the upstream never receives two Authorization headers.
        """
        use_auth = self.route.get('use_auth')
        out = [req_line.encode()]
        for k, v in headers.items():
            if use_auth and k == 'authorization':
                continue
            if k == 'host' and self.route.get('rewrite_host'):
                v = f"{self.route['upstream_host']}:{self.route['upstream_port']}"
            out.append(f"{k.title()}: {v}".encode())
        if use_auth:
            creds = base64.b64encode(f"{self.route['username']}:{self.route['password']}".encode()).decode()
            out.append(f"Authorization: {AUTH_BASIC} {creds}".encode())
        return b'\r\n'.join(out) + b'\r\n\r\n'

    async def _connect(self, mod_headers: bytes):
        """Open the upstream connection, then send the headers and any buffered bytes.

        On failure the client gets 502. If the client disconnected while the
        connection was being opened, the new upstream is closed immediately.
        """
        route = self.route
        try:
            ctx = None
            if route.get('use_ssl'):
                ctx = ssl.create_default_context()
                if not route.get('ssl_verify', False):
                    # Upstream certificates are NOT verified unless the route sets
                    # "ssl_verify": true (kept as default for existing configs).
                    ctx.check_hostname = False
                    ctx.verify_mode = ssl.CERT_NONE
            upstream, _ = await asyncio.get_running_loop().create_connection(
                lambda: UpstreamProtocol(self), route['upstream_host'], route['upstream_port'], ssl=ctx
            )
        except Exception as e:
            logging.error(
                f"Upstream connect to {route.get('upstream_host')}:{route.get('upstream_port')} failed: {e}"
            )
            self._send_error(HTTP_BAD_GATEWAY, MSG_BAD_GATEWAY)
            return

        if self.transport is None or self.transport.is_closing():
            upstream.close()  # client went away while we were connecting
            return
        self.upstream_transport = upstream
        upstream.write(mod_headers)
        self._flush_buffer()

    def _send_error(self, code: int, msg: str):
        """Write a minimal plain-text error response and close the client connection."""
        self.buffer = b''  # nothing buffered will be forwarded any more
        if self.transport and not self.transport.is_closing():
            self.transport.write(f"HTTP/1.1 {code} {msg}\r\nContent-Type: {CONTENT_TYPE_TEXT}\r\nConnection: {CONNECTION_CLOSE}\r\n\r\n{msg}".encode())
            self.transport.close()

    def connection_lost(self, exc):
        """Close the upstream side when the client goes away.

        A still-pending upstream connect notices the closed client in
        _connect and closes the new upstream itself.
        """
        if self.upstream_transport:
            self.upstream_transport.close()
2026-01-27 20:06:10 +01:00
class UpstreamProtocol(asyncio.Protocol):
    """Upstream side of one proxied connection.

    Relays every byte received from the upstream server to the owning
    HTTPProxyProtocol's client transport, and closes the client when the
    upstream connection goes away.
    """

    def __init__(self, proxy: HTTPProxyProtocol):
        self.proxy = proxy
        self.transport = None
        # Last few bytes of the previous chunk, so a blank line split across
        # two reads is still detected.
        self._tail = b''

    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data: bytes):
        """Forward upstream bytes to the client and re-arm request header parsing."""
        client = self.proxy.transport
        if not client:
            return
        client.write(data)
        if self.proxy.is_websocket:
            return
        # Heuristic: a blank line in the response means its header block has
        # been seen, so the next client bytes are parsed as a new request.
        # NOTE(review): also fires on 1xx interim responses (e.g. 100 Continue)
        # and on blank lines inside response bodies; a real fix needs response
        # framing.
        window = self._tail + data
        if DELIMITER_CRLF in window or DELIMITER_LF in window:
            self.proxy.headers_parsed = False
            self._tail = b''  # don't let this blank line match again next read
        else:
            # A delimiter that spans reads has at most len-1 bytes in this chunk.
            self._tail = window[-(len(DELIMITER_CRLF) - 1):]

    def connection_lost(self, exc):
        if self.proxy.transport:
            self.proxy.transport.close()
2026-01-27 20:06:10 +01:00
2026-01-27 22:10:48 +01:00
async def main(path: str):
    """Load the configuration at *path* and serve proxy connections until cancelled.

    Every accepted client connection gets its own HTTPProxyProtocol, all
    sharing one ProxyConfig (and therefore one RateLimiter).
    """
    cfg = ProxyConfig(path)
    # Inside a coroutine the running loop is the one to use; get_event_loop()
    # is discouraged here by the asyncio docs.
    loop = asyncio.get_running_loop()
    server = await loop.create_server(lambda: HTTPProxyProtocol(cfg), HOST_ALL, cfg.port)
    logging.info(f"Proxy running on {cfg.port}")
    async with server:
        await server.serve_forever()
2026-01-27 20:06:10 +01:00
if __name__ == '__main__':
    # One required argument: the path of the JSON config consumed by main().
    if len(sys.argv) < 2:
        # Previously this exited silently with status 0; report usage instead.
        sys.exit(f"usage: {sys.argv[0]} CONFIG_PATH")
    try:
        asyncio.run(main(sys.argv[1]))
    except KeyboardInterrupt:
        logging.info("Proxy stopped")
2026-01-27 20:06:10 +01:00