|
#!/usr/bin/env python3
|
|
|
|
import asyncio
|
|
import base64
|
|
from collections import defaultdict
|
|
import json
|
|
import logging
|
|
import ssl
|
|
import sys
|
|
import time
|
|
from typing import Dict, Optional
|
|
|
|
# Configure the root logger once at import time: INFO and above, emitted as
# "timestamp - LEVEL - message" lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
|
|
|
|
|
|
class RateLimiter:
    """Per-key token-bucket rate limiter.

    Every key (the proxy passes a client IP) owns a bucket that starts full
    with ``capacity`` tokens and refills continuously at ``rate`` tokens per
    second, never exceeding ``capacity``. Each allowed request consumes one
    whole token.
    """

    def __init__(self, rate: float, capacity: int):
        """
        Args:
            rate: Tokens added to a bucket per second.
            capacity: Maximum tokens a bucket can hold (the burst size).
        """
        self.rate = rate
        self.capacity = capacity
        # Buckets are created lazily (and full) on the first lookup of a key.
        # time.monotonic() is used instead of time.time() so wall-clock
        # adjustments (NTP steps, manual changes) cannot drain or refill buckets.
        self.buckets: Dict[str, Dict[str, float]] = defaultdict(
            lambda: {'tokens': capacity, 'last_update': time.monotonic()}
        )

    def is_allowed(self, client_ip: str) -> bool:
        """Refill ``client_ip``'s bucket and try to consume one token.

        Returns:
            True if a token was available (and consumed), False otherwise.
        """
        now = time.monotonic()
        bucket = self.buckets[client_ip]

        # Clamp to zero as a safety net: a negative interval must never
        # remove tokens from the bucket.
        elapsed = max(0.0, now - bucket['last_update'])
        bucket['tokens'] = min(self.capacity, bucket['tokens'] + elapsed * self.rate)
        bucket['last_update'] = now

        if bucket['tokens'] >= 1.0:
            bucket['tokens'] -= 1.0
            return True
        return False

    def cleanup_old_entries(self, max_age: float = 3600):
        """Drop buckets whose last update is more than ``max_age`` seconds old."""
        now = time.monotonic()
        # Collect first: deleting while iterating a dict raises RuntimeError.
        stale = [
            ip for ip, bucket in self.buckets.items()
            if now - bucket['last_update'] > max_age
        ]
        for ip in stale:
            del self.buckets[ip]
|
|
|
|
|
|
class ProxyConfig:
    """Proxy settings loaded from a JSON file.

    Keys read here: ``port`` (default 8585), ``reverse_proxy`` (a list of
    route objects, each with a ``hostname``) and ``rate_limit`` (``enabled``
    default True, ``requests_per_second`` default 10, ``burst_size``
    default 20). The parsed document is kept in ``self.config``.
    """

    def __init__(self, config_path: str):
        # JSON text is UTF-8 by specification (RFC 8259); do not depend on the
        # platform's locale encoding.
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        self.port = self.config.get('port', 8585)
        self.routes = self._build_routing_table()

        rate_limit_config = self.config.get('rate_limit', {})
        self.rate_limit_enabled = rate_limit_config.get('enabled', True)
        requests_per_second = rate_limit_config.get('requests_per_second', 10)
        burst_size = rate_limit_config.get('burst_size', 20)
        self.rate_limiter = RateLimiter(requests_per_second, burst_size)

    def _build_routing_table(self) -> Dict[str, dict]:
        """Index the ``reverse_proxy`` routes by lower-cased hostname.

        Later routes with the same hostname replace earlier ones.
        """
        return {
            route['hostname'].lower(): route
            for route in self.config.get('reverse_proxy', [])
        }

    def get_route(self, hostname: str) -> Optional[dict]:
        """Return the route for ``hostname`` (case-insensitive), or None."""
        return self.routes.get(hostname.lower())
|
|
|
|
|
|
class HTTPProxyProtocol(asyncio.Protocol):
    """Client-facing side of the reverse proxy.

    For each request head received from the client, this protocol applies the
    per-client rate limit and resolves the route from the ``Host`` header. It
    then rewrites the headers (optional ``Host`` rewrite, injected Basic auth)
    and forwards the request to the route's upstream through an
    ``UpstreamProtocol`` connection. That connection is reused for later
    requests on the same client connection that target the same upstream.

    After a head has been forwarded, further client bytes are relayed verbatim
    until ``UpstreamProtocol`` clears ``headers_parsed`` on the first response
    bytes; the next request head is then parsed.

    NOTE(review): request bodies are not framed (no Content-Length / chunked
    tracking). Keep-alive correctness therefore assumes the client sends no
    body bytes after the upstream has started responding.
    """

    MAX_HEADER_SIZE = 16 * 1024

    def __init__(self, config: "ProxyConfig"):
        self.config = config
        self.transport = None
        # Client bytes not yet forwarded: the pending request head, plus any
        # bytes that arrive while the upstream connection is being opened.
        self.buffer = b''
        self.upstream_transport = None
        self.upstream_protocol = None
        # True once a request head has been forwarded (or its upstream connect
        # started); UpstreamProtocol resets it when the upstream responds.
        self.headers_parsed = False
        self.route = None
        self.peername = None
        # (upstream_host, upstream_port, use_ssl) that upstream_transport is
        # connected to; used to refuse cross-route reuse of that connection.
        self._upstream_target = None
        # Strong reference to the pending connect task: the event loop only
        # keeps weak references to tasks.
        self._connect_task = None
        self._client_closed = False

    def connection_made(self, transport):
        self.transport = transport
        self.peername = transport.get_extra_info('peername')
        logging.info("Connection from %s", self.peername)

    def data_received(self, data: bytes):
        if self.headers_parsed:
            self._relay_body(data)
            return

        # RFC 9112 section 2.2: ignore empty lines before a request line (some
        # clients send a stray CRLF after a request body).
        self.buffer = (self.buffer + data).lstrip(b'\r\n')

        delimiter = self._find_header_delimiter()
        if delimiter is None:
            # Only reject once the head itself is known to be oversized; a
            # small head followed by a large body in one read is legitimate.
            if len(self.buffer) > self.MAX_HEADER_SIZE:
                self._send_error(431, "Request Header Fields Too Large")
            return
        self._parse_and_route(delimiter)

    def _relay_body(self, data: bytes):
        """Forward client bytes that follow an already-forwarded request head."""
        if self.upstream_transport is None:
            # Upstream connect still in flight; _connect_upstream flushes these.
            self.buffer += data
        elif not self.upstream_transport.is_closing():
            # Written straight through and NOT kept in the buffer, otherwise the
            # buffer grows without bound and pollutes the next request head.
            self.upstream_transport.write(data)

    def _find_header_delimiter(self) -> Optional[bytes]:
        """Return whichever blank-line terminator occurs first in the buffer."""
        best = None
        best_pos = -1
        for candidate in (b'\r\n\r\n', b'\n\n'):
            pos = self.buffer.find(candidate)
            if pos != -1 and (best is None or pos < best_pos):
                best, best_pos = candidate, pos
        return best

    def _parse_and_route(self, delimiter: bytes):
        """Parse the buffered request head, route it and forward it upstream.

        Sends an error response (which closes the client connection) for
        oversized (431) or malformed (400) heads, rate-limited clients (429),
        unknown hosts (404), requests for an upstream other than the one this
        connection is bound to (421), and unexpected failures (500).
        """
        try:
            header_end = self.buffer.index(delimiter) + len(delimiter)
            if header_end > self.MAX_HEADER_SIZE:
                self._send_error(431, "Request Header Fields Too Large")
                return

            request_line, header_pairs = self._parse_head(self.buffer[:header_end])
            if not request_line:
                self._send_error(400, "Bad Request")
                return
            # Lookup view keyed by lower-cased name; the last duplicate wins.
            headers = {name.lower(): value for name, value in header_pairs}

            client_ip = self._client_ip(headers)
            if self.config.rate_limit_enabled and not self.config.rate_limiter.is_allowed(client_ip):
                logging.warning("Rate limit exceeded for %s", client_ip)
                self._send_error(429, "Too Many Requests")
                return

            host = self._strip_port(headers.get('host', ''))
            if not host:
                self._send_error(400, "Bad Request: Missing Host header")
                return

            route = self.config.get_route(host)
            if not route:
                self._send_error(404, f"Not Found: No route for {host}")
                return

            reuse = (
                self.upstream_transport is not None
                and not self.upstream_transport.is_closing()
            )
            if reuse and self._target_of(route) != self._upstream_target:
                # The open upstream belongs to another route; sending this
                # request (and its injected credentials) there would be wrong.
                self._send_error(421, "Misdirected Request")
                return

            self.route = route
            modified_headers = self._rewrite_headers(header_pairs, request_line)
            self.headers_parsed = True

            if reuse:
                self.upstream_transport.write(modified_headers)
                if len(self.buffer) > header_end:
                    self.upstream_transport.write(self.buffer[header_end:])
                self.buffer = b''
            else:
                # Any stale, closing upstream must not receive relayed bytes.
                self.upstream_transport = None
                self.upstream_protocol = None
                self._connect_task = asyncio.create_task(
                    self._connect_upstream(modified_headers, header_end)
                )
        except Exception:
            logging.exception("Error parsing request")
            self._send_error(500, "Internal Server Error")

    @staticmethod
    def _parse_head(head: bytes):
        """Split a raw request head into ``(request_line, [(name, value), ...])``.

        Accepts CRLF or bare-LF line endings. Bare CRs inside a line are
        replaced with spaces (RFC 9112 section 2.2) so they cannot be forwarded
        as line breaks. Lines without a colon are ignored.
        """
        lines = [line.rstrip(b'\r').replace(b'\r', b' ') for line in head.split(b'\n')]
        request_line = lines[0].decode('utf-8', errors='ignore') if lines else ''
        pairs = []
        for raw in lines[1:]:
            if b':' not in raw:
                continue
            key, value = raw.split(b':', 1)
            name = key.decode('utf-8', errors='ignore').strip()
            if name:
                pairs.append((name, value.decode('utf-8', errors='ignore').strip()))
        return request_line, pairs

    @staticmethod
    def _strip_port(host_header: str) -> str:
        """Return the host part of a Host value, handling ``[IPv6]:port``."""
        value = host_header.strip()
        if value.startswith('['):
            end = value.find(']')
            return value[:end + 1] if end != -1 else ''
        return value.split(':')[0]

    @staticmethod
    def _target_of(route: dict) -> tuple:
        """Identify the upstream endpoint a route points at."""
        return (
            route.get('upstream_host'),
            route.get('upstream_port'),
            bool(route.get('use_ssl', False)),
        )

    def _trust_forwarded_for(self) -> bool:
        """Whether ``rate_limit.trust_x_forwarded_for`` is enabled in the config."""
        raw = getattr(self.config, 'config', None)
        if not isinstance(raw, dict):
            return False
        rate_cfg = raw.get('rate_limit', {})
        return isinstance(rate_cfg, dict) and bool(rate_cfg.get('trust_x_forwarded_for', False))

    def _client_ip(self, headers: dict) -> str:
        """Return the key used for rate limiting.

        Uses the TCP peer address. ``X-Forwarded-For`` is only honoured when
        the config opts in: it is client-controlled, so trusting it blindly
        lets any client evade the limit and create unbounded limiter buckets.
        """
        client_ip = self.peername[0] if self.peername else 'unknown'
        if self._trust_forwarded_for():
            forwarded = headers.get('x-forwarded-for', '').split(',')[0].strip()
            if forwarded:
                client_ip = forwarded
        return client_ip

    def _rewrite_headers(self, headers, request_line: str) -> bytes:
        """Build the head sent upstream for ``self.route``.

        ``headers`` is a dict or an iterable of ``(name, value)`` pairs; names
        are emitted in Title-Case. ``Host`` is replaced when the route sets
        ``rewrite_host``. With ``use_auth``, any client ``Authorization`` header
        is dropped and replaced by the route's Basic credentials.
        """
        items = headers.items() if isinstance(headers, dict) else headers
        use_auth = bool(self.route.get('use_auth'))

        lines = [request_line.encode('utf-8')]
        for key, value in items:
            lowered = key.lower()
            if use_auth and lowered == 'authorization':
                continue
            if lowered == 'host' and self.route.get('rewrite_host'):
                value = f"{self.route['upstream_host']}:{self.route['upstream_port']}"
            lines.append(f"{key.title()}: {value}".encode('utf-8'))

        if use_auth:
            username = self.route.get('username', '')
            password = self.route.get('password', '')
            credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
            lines.append(f"Authorization: Basic {credentials}".encode('utf-8'))

        # Two empty entries yield the terminating CRLF CRLF after joining.
        lines.append(b'')
        lines.append(b'')
        return b'\r\n'.join(lines)

    async def _connect_upstream(self, modified_headers: bytes, header_end: int):
        """Open the upstream connection for ``self.route`` and flush the request.

        ``header_end`` is the length of the original head in ``self.buffer``.
        Everything after it (body bytes, including those received while
        connecting) is forwarded after ``modified_headers``. On failure, a 502
        is sent to the client.
        """
        route = self.route
        try:
            upstream_host = route['upstream_host']
            upstream_port = route['upstream_port']
            use_ssl = route.get('use_ssl', False)

            ssl_context = None
            if use_ssl:
                ssl_context = ssl.create_default_context()
                if not route.get('verify_ssl', False):
                    # NOTE(review): verification stays disabled by default for
                    # backward compatibility; set "verify_ssl": true on the
                    # route to validate the upstream certificate and hostname.
                    ssl_context.check_hostname = False
                    ssl_context.verify_mode = ssl.CERT_NONE

            loop = asyncio.get_running_loop()
            transport, protocol = await loop.create_connection(
                lambda: UpstreamProtocol(self),
                upstream_host,
                upstream_port,
                ssl=ssl_context,
            )
        except Exception as exc:
            logging.error("Error connecting to upstream: %s", exc)
            self._send_error(502, "Bad Gateway")
            return

        if self._client_closed:
            # The client went away while connecting; don't leak the socket.
            transport.close()
            return

        self.upstream_transport = transport
        self.upstream_protocol = protocol
        self._upstream_target = self._target_of(route)
        transport.write(modified_headers)
        if len(self.buffer) > header_end:
            transport.write(self.buffer[header_end:])
        self.buffer = b''

    def _send_error(self, code: int, message: str):
        """Send a plain-text error response and close the client connection."""
        # The reason phrase must not contain CR/LF or other control characters.
        reason = ''.join(ch if ch.isprintable() else ' ' for ch in message)
        body = message.encode('utf-8')
        head = (
            f"HTTP/1.1 {code} {reason}\r\n"
            "Content-Type: text/plain; charset=utf-8\r\n"
            # Length of the encoded body in bytes, not characters.
            f"Content-Length: {len(body)}\r\n"
            "Connection: close\r\n\r\n"
        ).encode('utf-8')
        if self.transport and not self.transport.is_closing():
            self.transport.write(head + body)
            self.transport.close()

    def connection_lost(self, exc):
        self._client_closed = True
        if self._connect_task is not None and not self._connect_task.done():
            self._connect_task.cancel()
        if self.upstream_transport:
            self.upstream_transport.close()
|
|
|
|
|
|
class UpstreamProtocol(asyncio.Protocol):
    """Upstream-facing half of a proxied connection.

    Copies every byte received from the upstream server to the owning
    client-side protocol's transport, and closes that client transport when
    the upstream connection ends.
    """

    def __init__(self, proxy_protocol: HTTPProxyProtocol):
        self.transport = None
        # The client-side protocol this upstream connection serves.
        self.proxy_protocol = proxy_protocol

    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data: bytes):
        client = self.proxy_protocol
        client_transport = client.transport
        if client_transport:
            client_transport.write(data)
        # Response bytes mean the upstream has taken the request; let the
        # client side parse the next request head on its keep-alive connection.
        client.headers_parsed = False

    def connection_lost(self, exc):
        client_transport = self.proxy_protocol.transport
        if client_transport:
            client_transport.close()
|
|
|
|
|
|
async def cleanup_task(config: "ProxyConfig", interval: float = 600, max_age: float = 3600):
    """Periodically evict idle rate-limiter buckets; runs until cancelled.

    Args:
        config: Object exposing ``rate_limiter.cleanup_old_entries(max_age)``.
        interval: Seconds to sleep before each cleanup pass.
        max_age: Buckets idle for longer than this many seconds are dropped.
    """
    while True:
        await asyncio.sleep(interval)
        try:
            config.rate_limiter.cleanup_old_entries(max_age)
        except Exception:
            # One failed pass must not silently end periodic cleanup.
            logging.exception("Rate limiter cleanup failed")
|
|
|
|
|
|
async def main(config_path: str):
    """Load the config, start the proxy server and serve until cancelled."""
    config = ProxyConfig(config_path)
    loop = asyncio.get_running_loop()
    server = await loop.create_server(lambda: HTTPProxyProtocol(config), '0.0.0.0', config.port)
    logging.info("Reverse proxy (Keep-Alive enabled) listening on 0.0.0.0:%s", config.port)

    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task can be garbage-collected mid-flight.
    cleaner = asyncio.create_task(cleanup_task(config))
    try:
        async with server:
            await server.serve_forever()
    finally:
        cleaner.cancel()
|
|
|
|
|
|
if __name__ == '__main__':
    # Usage: <script> CONFIG_JSON
    if len(sys.argv) < 2:
        # Previously the script exited silently with status 0 here.
        print(f"Usage: {sys.argv[0]} CONFIG_JSON", file=sys.stderr)
        sys.exit(2)
    try:
        asyncio.run(main(sys.argv[1]))
    except KeyboardInterrupt:
        pass
|
|
|