|
#!/usr/bin/env python3
|
|
|
|
import asyncio
|
|
import base64
|
|
from collections import defaultdict
|
|
import json
|
|
import logging
|
|
import ssl
|
|
import sys
|
|
import time
|
|
from typing import Dict, Optional
|
|
|
|
# --- Listener and rate-limit defaults ---------------------------------------
DEFAULT_PORT = 8585
DEFAULT_REQUESTS_PER_SECOND = 2
DEFAULT_BURST_SIZE = 5
TOKEN_COST = 1.0  # tokens deducted from a client's bucket per allowed request

# --- Status codes and reason phrases of proxy-generated error responses -----
HTTP_TOO_MANY_REQUESTS = 429
HTTP_NOT_FOUND = 404
HTTP_BAD_GATEWAY = 502

MSG_TOO_MANY_REQUESTS = "Too Many Requests"
MSG_NOT_FOUND = "Not Found"
MSG_BAD_GATEWAY = "Bad Gateway"

# --- Byte sequences that terminate an HTTP header section -------------------
DELIMITER_CRLF = b'\r\n\r\n'
DELIMITER_LF = b'\n\n'

# --- Networking and header literals -----------------------------------------
HOST_ALL = '0.0.0.0'  # bind address: every IPv4 interface
CONTENT_TYPE_TEXT = "text/plain"
CONNECTION_CLOSE = "close"
UPGRADE_WEBSOCKET = "websocket"
AUTH_BASIC = "Basic"

# Root logger: INFO and above, timestamped "time - LEVEL - message" lines.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
|
|
|
|
class RateLimiter:
    """Token-bucket rate limiter with one bucket per client key.

    A bucket starts full (``capacity`` tokens) and refills continuously at
    ``rate`` tokens per second, never exceeding ``capacity``.  Each allowed
    request removes ``token_cost`` tokens; a request is refused when the
    bucket holds fewer than that.

    Timestamps stored in ``buckets[...]['last_update']`` come from
    ``time.monotonic()`` so wall-clock adjustments cannot drain or refill a
    bucket.
    """

    # Lower bound, in seconds, between two sweeps for idle buckets.
    _MIN_PRUNE_INTERVAL = 1.0

    def __init__(self, rate: float, capacity: int, token_cost: Optional[float] = None):
        """Create a limiter.

        Args:
            rate: Tokens added to every bucket per second.
            capacity: Maximum (and initial) number of tokens in a bucket.
            token_cost: Tokens consumed by one allowed request.  ``None``
                (the default) uses the module-level ``TOKEN_COST``.
        """
        self.rate = float(rate)
        self.capacity = float(capacity)
        # The module constant is only looked up when no explicit cost is given.
        self.token_cost = float(TOKEN_COST if token_cost is None else token_cost)
        self.buckets: Dict[str, Dict[str, float]] = defaultdict(self._new_bucket)
        # -inf makes the very first is_allowed() call run a (cheap) sweep.
        self._next_prune = float('-inf')

    def _new_bucket(self) -> Dict[str, float]:
        """Return a full bucket stamped with the current monotonic time."""
        return {'tokens': self.capacity, 'last_update': time.monotonic()}

    def _prune(self, now: float) -> None:
        """Forget buckets that have been idle long enough to be full again.

        Such a bucket is indistinguishable from a freshly created one, so
        dropping it changes no decision; it only stops ``buckets`` from
        growing without bound when keys are client-controlled.
        """
        if self.rate <= 0:
            # Buckets never refill: forgetting one would hand out new tokens.
            self._next_prune = now + self._MIN_PRUNE_INTERVAL
            return
        refill_time = self.capacity / self.rate
        stale = [
            key for key, bucket in self.buckets.items()
            if now - bucket['last_update'] >= refill_time
        ]
        for key in stale:
            del self.buckets[key]
        self._next_prune = now + max(refill_time, self._MIN_PRUNE_INTERVAL)

    def is_allowed(self, client_ip: str) -> bool:
        """Charge one request to ``client_ip``'s bucket.

        Args:
            client_ip: Key identifying the client (any string).

        Returns:
            ``True`` if the bucket held at least ``token_cost`` tokens (they
            are deducted), ``False`` otherwise.
        """
        now = time.monotonic()
        if now >= self._next_prune:
            self._prune(now)

        bucket = self.buckets[client_ip]

        # A bucket created by the lookup above is stamped *after* ``now`` was
        # read, so the difference can be slightly negative.  Without the clamp
        # that shaved a sliver off a brand-new bucket and, with capacity 1,
        # rejected a client's very first request.
        elapsed = max(0.0, now - bucket['last_update'])
        bucket['tokens'] = min(self.capacity, bucket['tokens'] + elapsed * self.rate)
        bucket['last_update'] = now

        if bucket['tokens'] >= self.token_cost:
            bucket['tokens'] -= self.token_cost
            return True
        return False
|
|
|
|
class ProxyConfig:
    """Proxy settings loaded from a JSON configuration file.

    Attributes set by the constructor:
        config: The parsed JSON document.
        port: ``port`` from the file, or ``DEFAULT_PORT`` when absent.
        routes: ``reverse_proxy`` entries keyed by their lower-cased
            ``hostname``.
        rate_limit_enabled: ``rate_limit.enabled`` (``True`` when absent).
        rate_limiter: A ``RateLimiter`` built from
            ``rate_limit.requests_per_second`` and ``rate_limit.burst_size``,
            falling back to the module defaults.
    """

    def __init__(self, config_path: str):
        """Read and interpret the configuration at ``config_path``.

        Raises:
            OSError: The file cannot be opened.
            json.JSONDecodeError: The file is not valid JSON.
            KeyError: A ``reverse_proxy`` entry has no ``hostname``.
        """
        # Binary mode lets json.load pick the encoding itself (UTF-8 with or
        # without BOM, UTF-16, UTF-32) instead of relying on the locale's
        # default text encoding.
        with open(config_path, 'rb') as f:
            self.config = json.load(f)

        self.port = self.config.get('port', DEFAULT_PORT)

        # ``or`` also covers an explicit ``null`` for these sections.
        route_entries = self.config.get('reverse_proxy') or []
        self.routes = {entry['hostname'].lower(): entry for entry in route_entries}

        rl = self.config.get('rate_limit') or {}
        self.rate_limit_enabled = rl.get('enabled', True)
        self.rate_limiter = RateLimiter(
            rl.get('requests_per_second', DEFAULT_REQUESTS_PER_SECOND),
            rl.get('burst_size', DEFAULT_BURST_SIZE),
        )
|
|
|
|
class HTTPProxyProtocol(asyncio.Protocol):
    """Client-facing side of the reverse proxy (one instance per connection).

    Incoming bytes are buffered until a complete request head is present.
    The head is parsed, rate limited, matched to a route through its ``Host``
    header, rewritten and sent to that route's upstream; bytes that follow the
    head are forwarded unchanged.  After a request carrying
    ``Upgrade: websocket`` the connection becomes a raw byte pipe.

    Known limitation: request headers are collected in a dict, so when a
    header name is repeated only its last value is forwarded.
    """

    # Response sent when a request head cannot be processed.
    _BAD_REQUEST = (400, "Bad Request")

    def __init__(self, config: "ProxyConfig"):
        """Store the shared configuration and reset per-connection state."""
        self.config = config
        self.transport = None
        self.upstream_transport = None
        self.peername = None
        self.buffer = b''
        self.headers_parsed = False
        self.is_websocket = False
        self.route = None
        # Strong reference to the in-flight connect task: the event loop only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._connect_task = None

    def connection_made(self, transport):
        """Remember the client transport and the peer address it reports."""
        self.transport = transport
        self.peername = transport.get_extra_info('peername')

    def data_received(self, data: bytes):
        """Handle bytes from the client.

        In WebSocket mode the bytes go straight to the upstream.  Otherwise
        they are appended to the buffer, complete request heads are processed,
        and remaining bytes (request bodies) are forwarded.  Whenever the
        upstream connection is not established yet the bytes stay buffered;
        ``_connect`` flushes them once it is.
        """
        if self.is_websocket:
            if self.upstream_transport:
                self.upstream_transport.write(data)
            else:
                self.buffer += data
            return

        self.buffer += data

        while not self.is_websocket and not self.headers_parsed and self.buffer:
            delimiter = self._find_delimiter()
            if not delimiter:
                break
            if not self._process_headers(delimiter):
                break

        if self.headers_parsed and not self.is_websocket and self.buffer:
            if self.upstream_transport:
                self.upstream_transport.write(self.buffer)
                self.buffer = b''

    def _find_delimiter(self) -> Optional[bytes]:
        """Return the header terminator that occurs first in the buffer.

        Choosing by position (rather than always preferring CRLF) keeps an
        LF-terminated head from swallowing a body that happens to contain
        ``\\r\\n\\r\\n``.  Returns ``None`` when neither terminator is present.
        """
        first = None
        first_pos = -1
        for candidate in (DELIMITER_CRLF, DELIMITER_LF):
            pos = self.buffer.find(candidate)
            if pos != -1 and (first is None or pos < first_pos):
                first, first_pos = candidate, pos
        return first

    def _peer_ip(self) -> str:
        """Return the socket peer's address, or ``'unknown'`` if unavailable."""
        peer = self.peername
        if isinstance(peer, (tuple, list)) and peer:
            return str(peer[0])
        return 'unknown'

    def _process_headers(self, delimiter: bytes) -> bool:
        """Consume one request head from the buffer and dispatch it.

        Args:
            delimiter: Terminator (present in the buffer) that ends the head.

        Returns:
            ``True`` when the head was accepted and handed to the upstream
            (or a connection attempt was started); ``False`` when the request
            was answered with an error, which also closes the connection.
        """
        try:
            end = self.buffer.index(delimiter) + len(delimiter)
            head, self.buffer = self.buffer[:end], self.buffer[end:]

            lines = head.splitlines()
            if not lines:
                return False

            # Strict decode up front: a request line that is not valid UTF-8
            # is rejected before any connection state is modified.
            request_line = lines[0].decode()

            headers = {}
            for line in lines[1:]:
                if b':' in line:
                    k, v = line.split(b':', 1)
                    headers[k.decode(errors='ignore').strip().lower()] = v.decode(errors='ignore').strip()

            # NOTE(review): X-Forwarded-For is supplied by the client, so a
            # direct client can pick its own rate-limit key.  Only safe when a
            # trusted proxy in front sets the header - confirm deployment.
            ip = headers.get('x-forwarded-for', '').split(',')[0].strip() or self._peer_ip()

            if self.config.rate_limit_enabled:
                if not self.config.rate_limiter.is_allowed(ip):
                    logging.info(f"Rate Limiter BLOCKING: {ip}")
                    self._send_error(HTTP_TOO_MANY_REQUESTS, MSG_TOO_MANY_REQUESTS)
                    return False

            host = headers.get('host', '').split(':')[0].lower()
            self.route = self.config.routes.get(host)
            if not self.route:
                self._send_error(HTTP_NOT_FOUND, MSG_NOT_FOUND)
                return False

            if headers.get('upgrade', '').lower() == UPGRADE_WEBSOCKET:
                self.is_websocket = True

            self.headers_parsed = True
            mod_headers = self._rewrite(headers, request_line)

            if self.upstream_transport:
                self.upstream_transport.write(mod_headers)
            else:
                self._connect_task = asyncio.create_task(self._connect(mod_headers))
            return True
        except Exception as e:
            logging.error(f"Parser error: {e}")
            # Answer and close instead of leaving the client waiting forever.
            self._send_error(*self._BAD_REQUEST)
            return False

    def _rewrite(self, headers: dict, req_line: str) -> bytes:
        """Serialize the request head to send upstream.

        Args:
            headers: Header name -> value for the request.
            req_line: The request line, without line terminator.

        Returns:
            The head as CRLF-separated bytes ending in a blank line.  The
            ``Host`` value is replaced by ``upstream_host:upstream_port`` when
            the route sets ``rewrite_host``.  When the route sets ``use_auth``
            any client ``Authorization`` header is dropped and a Basic one
            built from the route's ``username``/``password`` is appended, so
            the upstream never receives two ``Authorization`` headers.
        """
        inject_auth = bool(self.route.get('use_auth'))
        out = [req_line.encode()]
        for k, v in headers.items():
            name = k.lower()
            if inject_auth and name == 'authorization':
                continue
            if name == 'host' and self.route.get('rewrite_host'):
                v = f"{self.route['upstream_host']}:{self.route['upstream_port']}"
            out.append(f"{k.title()}: {v}".encode())
        if inject_auth:
            creds = base64.b64encode(f"{self.route['username']}:{self.route['password']}".encode()).decode()
            out.append(f"Authorization: {AUTH_BASIC} {creds}".encode())
        return b'\r\n'.join(out) + b'\r\n\r\n'

    async def _connect(self, mod_headers: bytes):
        """Open the upstream connection, then send the head and buffered bytes.

        A failed connection attempt is logged and answered with 502.  If the
        client went away while the attempt was in progress the new upstream
        connection is closed immediately instead of being leaked.
        """
        route = self.route
        try:
            ctx = None
            if route.get('use_ssl'):
                ctx = ssl.create_default_context()
                if not route.get('verify_ssl', False):
                    # NOTE(review): certificate and hostname checks are off
                    # unless the route sets "verify_ssl" - this keeps the
                    # previous behaviour but does not authenticate the
                    # upstream.  Consider enabling it per route.
                    ctx.check_hostname = False
                    ctx.verify_mode = ssl.CERT_NONE

            upstream, _ = await asyncio.get_running_loop().create_connection(
                lambda: UpstreamProtocol(self),
                route['upstream_host'],
                route['upstream_port'],
                ssl=ctx,
            )
        except Exception as exc:
            logging.error(f"Upstream connect failed: {exc}")
            self._send_error(HTTP_BAD_GATEWAY, MSG_BAD_GATEWAY)
            return

        if self.transport is None or self.transport.is_closing():
            # connection_lost() ran before upstream_transport existed, so
            # nobody else would ever close this connection.
            upstream.close()
            return

        self.upstream_transport = upstream
        upstream.write(mod_headers)
        if self.buffer:
            upstream.write(self.buffer)
            self.buffer = b''

    def _send_error(self, code: int, msg: str):
        """Send a minimal plain-text error response and close the client."""
        if self.transport and not self.transport.is_closing():
            self.transport.write(f"HTTP/1.1 {code} {msg}\r\nContent-Type: {CONTENT_TYPE_TEXT}\r\nConnection: {CONNECTION_CLOSE}\r\n\r\n{msg}".encode())
            self.transport.close()

    def connection_lost(self, exc):
        """Close the upstream connection when the client disconnects."""
        if self.upstream_transport:
            self.upstream_transport.close()
|
|
|
|
class UpstreamProtocol(asyncio.Protocol):
    """Upstream-facing side of a proxied connection.

    Relays everything the upstream sends to the client transport held by the
    owning ``HTTPProxyProtocol`` and closes that transport when the upstream
    connection ends.
    """

    def __init__(self, proxy: HTTPProxyProtocol):
        """Bind this protocol to the client-side protocol it serves."""
        self.proxy = proxy

    def connection_made(self, transport):
        """Keep the upstream transport."""
        self.transport = transport

    def data_received(self, data: bytes):
        """Forward upstream bytes to the client.

        Outside WebSocket mode, a chunk containing a header terminator clears
        the proxy's ``headers_parsed`` flag so that the next bytes from the
        client are parsed as a new request head.
        """
        client = self.proxy.transport
        if not client:
            return
        client.write(data)

        if self.proxy.is_websocket:
            return
        if DELIMITER_CRLF in data or DELIMITER_LF in data:
            self.proxy.headers_parsed = False

    def connection_lost(self, exc):
        """Close the client connection once the upstream has gone away."""
        client = self.proxy.transport
        if client:
            client.close()
|
|
|
|
async def main(path: str):
    """Load the configuration at ``path`` and run the proxy server.

    Listens on every IPv4 interface at the configured port and serves until
    the surrounding task is cancelled.
    """
    cfg = ProxyConfig(path)

    def protocol_factory():
        # One protocol instance per accepted connection, sharing ``cfg``.
        return HTTPProxyProtocol(cfg)

    loop = asyncio.get_running_loop()
    server = await loop.create_server(protocol_factory, HOST_ALL, cfg.port)
    logging.info(f"Proxy running on {cfg.port}")

    async with server:
        await server.serve_forever()
|
|
|
|
if __name__ == '__main__':
    # Script entry point: the first command-line argument is the path of the
    # JSON configuration file.
    if len(sys.argv) < 2:
        # A missing argument used to end the script silently with status 0;
        # report the mistake on stderr and exit with a failure status instead.
        sys.exit(f"Usage: {sys.argv[0]} <config.json>")
    asyncio.run(main(sys.argv[1]))
|
|
|