Proxy.
This commit is contained in:
parent
d5e24dc329
commit
37f708c064
@ -10,16 +10,32 @@ import sys
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
|
||||
# --- Tunable defaults (overridable via the JSON config file) ---
DEFAULT_PORT = 8585
DEFAULT_REQUESTS_PER_SECOND = 2
DEFAULT_BURST_SIZE = 5
TOKEN_COST = 1.0  # tokens consumed per request by the rate limiter

# --- HTTP status codes / reason phrases used by the proxy ---
HTTP_TOO_MANY_REQUESTS = 429
HTTP_NOT_FOUND = 404
HTTP_BAD_GATEWAY = 502
MSG_TOO_MANY_REQUESTS = "Too Many Requests"
MSG_NOT_FOUND = "Not Found"
MSG_BAD_GATEWAY = "Bad Gateway"

# Header-block terminators: strict CRLF per RFC, bare LF as a lenient fallback.
DELIMITER_CRLF = b'\r\n\r\n'
DELIMITER_LF = b'\n\n'

HOST_ALL = '0.0.0.0'
CONTENT_TYPE_TEXT = "text/plain"
CONNECTION_CLOSE = "close"
UPGRADE_WEBSOCKET = "websocket"
AUTH_BASIC = "Basic"

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
||||
class RateLimiter:
    """Per-client token-bucket rate limiter.

    Each client IP gets its own bucket that starts full (``capacity`` tokens)
    and refills continuously at ``rate`` tokens per second.
    """

    def __init__(self, rate: float, capacity: int):
        self.rate = float(rate)          # tokens added per second
        self.capacity = float(capacity)  # maximum bucket size (burst allowance)
        # Buckets are created lazily and full on a client's first request.
        self.buckets: Dict[str, Dict[str, float]] = defaultdict(
            lambda: {'tokens': self.capacity, 'last_update': time.time()}
        )

    def is_allowed(self, client_ip: str) -> bool:
        """Consume one token for *client_ip*; True if the request may proceed."""
        now = time.time()
        bucket = self.buckets[client_ip]

        # Refill proportionally to elapsed time, capped at capacity.
        refill = (now - bucket['last_update']) * self.rate
        bucket['tokens'] = min(self.capacity, bucket['tokens'] + refill)
        bucket['last_update'] = now

        if bucket['tokens'] >= TOKEN_COST:
            bucket['tokens'] -= TOKEN_COST
            return True
        return False

    def cleanup_old_entries(self, max_age: float = 3600):
        """Drop buckets that have been idle for more than *max_age* seconds."""
        now = time.time()
        stale = [
            ip for ip, bucket in self.buckets.items()
            if now - bucket['last_update'] > max_age
        ]
        for ip in stale:
            del self.buckets[ip]
|
||||
class ProxyConfig:
    """Proxy settings loaded from a JSON config file.

    Exposes the listen port, a hostname -> route table built from the
    ``reverse_proxy`` section, and a shared :class:`RateLimiter`.
    """

    def __init__(self, config_path: str):
        with open(config_path, 'r') as f:
            self.config = json.load(f)

        self.port = self.config.get('port', DEFAULT_PORT)
        # Routing table keyed by lowercase hostname for case-insensitive lookup.
        self.routes: Dict[str, dict] = {
            r['hostname'].lower(): r for r in self.config.get('reverse_proxy', [])
        }

        rl = self.config.get('rate_limit', {})
        self.rate_limit_enabled = rl.get('enabled', True)
        self.rate_limiter = RateLimiter(
            rl.get('requests_per_second', DEFAULT_REQUESTS_PER_SECOND),
            rl.get('burst_size', DEFAULT_BURST_SIZE),
        )

    def get_route(self, hostname: str) -> Optional[dict]:
        """Case-insensitive route lookup; None when the hostname is unknown."""
        return self.routes.get(hostname.lower())
||||
class HTTPProxyProtocol(asyncio.Protocol):
    """Client-side protocol: parses request headers, enforces rate limits,
    resolves the route by Host header, and relays traffic to the upstream
    (including raw websocket frames after an Upgrade handshake).
    """

    MAX_HEADER_SIZE = 16 * 1024  # reject abusive / runaway header blocks

    def __init__(self, config: ProxyConfig):
        self.config = config
        self.transport = None
        self.upstream_transport = None
        self.upstream_protocol = None
        self.buffer = b''
        self.headers_parsed = False
        self.is_websocket = False
        self.route = None
        self.peername = None

    def connection_made(self, transport):
        self.transport = transport
        self.peername = transport.get_extra_info('peername')
        logging.info(f"Connection from {self.peername}")

    def data_received(self, data: bytes):
        self.buffer += data

        if not self.headers_parsed and len(self.buffer) > self.MAX_HEADER_SIZE:
            self._send_error(431, "Request Header Fields Too Large")
            return

        # Parse complete header blocks (handles pipelined keep-alive requests).
        while not self.is_websocket and not self.headers_parsed and self.buffer:
            delimiter = (
                DELIMITER_CRLF if DELIMITER_CRLF in self.buffer
                else DELIMITER_LF if DELIMITER_LF in self.buffer
                else None
            )
            if not delimiter:
                break
            if not self._process_headers(delimiter):
                break

        # Forward remaining payload: request body bytes OR websocket frames.
        # (Must not exclude websocket traffic here, otherwise frames received
        # after the upgrade would sit in the buffer forever.)
        if self.headers_parsed and self.buffer:
            if self.upstream_transport:
                self.upstream_transport.write(self.buffer)
                self.buffer = b''
            # else: upstream connect is in flight; _connect() flushes buffer.

    def _process_headers(self, delimiter: bytes) -> bool:
        """Parse one header block from the buffer, rate-limit, route and
        forward it upstream.  Returns False when processing must stop
        (error sent, rate-limited, or no route).
        """
        try:
            idx = self.buffer.index(delimiter) + len(delimiter)
            header_chunk = self.buffer[:idx]
            self.buffer = self.buffer[idx:]

            lines = header_chunk.splitlines()
            if not lines:
                return False

            headers = {}
            for line in lines[1:]:
                if b':' in line:
                    k, v = line.split(b':', 1)
                    headers[k.decode(errors='ignore').strip().lower()] = \
                        v.decode(errors='ignore').strip()

            # Prefer X-Forwarded-For when present; peername may be None for
            # some transports, so fall back defensively.
            peer_ip = self.peername[0] if self.peername else 'unknown'
            ip = headers.get('x-forwarded-for', '').split(',')[0].strip() or peer_ip

            if self.config.rate_limit_enabled:
                if not self.config.rate_limiter.is_allowed(ip):
                    logging.info(f"Rate Limiter BLOCKING: {ip}")
                    self._send_error(HTTP_TOO_MANY_REQUESTS, MSG_TOO_MANY_REQUESTS)
                    return False

            host = headers.get('host', '').split(':')[0].lower()
            self.route = self.config.routes.get(host)
            if not self.route:
                self._send_error(HTTP_NOT_FOUND, MSG_NOT_FOUND)
                return False

            if headers.get('upgrade', '').lower() == UPGRADE_WEBSOCKET:
                self.is_websocket = True

            self.headers_parsed = True
            mod_headers = self._rewrite(headers, lines[0].decode())

            if self.upstream_transport:
                # Keep-alive: reuse the existing upstream connection.
                self.upstream_transport.write(mod_headers)
            else:
                asyncio.create_task(self._connect(mod_headers))
            return True
        except Exception as e:
            logging.error(f"Parser error: {e}")
            # Tell the client something went wrong instead of hanging.
            self._send_error(500, "Internal Server Error")
            return False

    def _rewrite(self, headers: dict, req_line: str) -> bytes:
        """Re-serialize the request headers for the upstream, optionally
        rewriting Host and injecting Basic-auth credentials per the route."""
        out = [req_line.encode()]
        for k, v in headers.items():
            if k == 'host' and self.route.get('rewrite_host'):
                v = f"{self.route['upstream_host']}:{self.route['upstream_port']}"
            out.append(f"{k.title()}: {v}".encode())
        if self.route.get('use_auth'):
            # .get(..., '') keeps a partially-configured route from raising.
            creds = base64.b64encode(
                f"{self.route.get('username', '')}:{self.route.get('password', '')}".encode()
            ).decode()
            out.append(f"Authorization: {AUTH_BASIC} {creds}".encode())
        return b'\r\n'.join(out) + b'\r\n\r\n'

    async def _connect(self, mod_headers: bytes):
        """Open the upstream connection, then flush headers + buffered body."""
        try:
            ctx = None
            if self.route.get('use_ssl'):
                # NOTE(review): verification is disabled, presumably for
                # self-signed upstream certs — confirm this is intentional.
                ctx = ssl.create_default_context()
                ctx.check_hostname = False
                ctx.verify_mode = ssl.CERT_NONE

            self.upstream_transport, self.upstream_protocol = \
                await asyncio.get_running_loop().create_connection(
                    lambda: UpstreamProtocol(self),
                    self.route['upstream_host'],
                    self.route['upstream_port'],
                    ssl=ctx,
                )

            self.upstream_transport.write(mod_headers)
            if self.buffer:
                self.upstream_transport.write(self.buffer)
                self.buffer = b''
        except Exception as e:
            logging.error(f"Error connecting to upstream: {e}")
            self._send_error(HTTP_BAD_GATEWAY, MSG_BAD_GATEWAY)

    def _send_error(self, code: int, msg: str):
        """Send a plain-text error response and close the client connection."""
        if self.transport:
            # Content-Length lets well-behaved clients parse the body even
            # though we close the connection immediately afterwards.
            self.transport.write(
                (
                    f"HTTP/1.1 {code} {msg}\r\n"
                    f"Content-Type: {CONTENT_TYPE_TEXT}\r\n"
                    f"Content-Length: {len(msg)}\r\n"
                    f"Connection: {CONNECTION_CLOSE}\r\n\r\n{msg}"
                ).encode()
            )
            self.transport.close()

    def connection_lost(self, exc):
        # Tear down the paired upstream connection with the client.
        if self.upstream_transport:
            self.upstream_transport.close()
|
||||
class UpstreamProtocol(asyncio.Protocol):
    """Upstream-side protocol: relays server responses back to the client."""

    def __init__(self, proxy: HTTPProxyProtocol):
        self.proxy = proxy
        self.transport = None

    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data: bytes):
        if self.proxy.transport:
            self.proxy.transport.write(data)
        # Keep-alive: once a response header block has been seen, let the
        # client side parse the next request.  Websocket streams stay raw.
        # NOTE(review): delimiter may straddle two reads — crude heuristic.
        if not self.proxy.is_websocket and (DELIMITER_CRLF in data or DELIMITER_LF in data):
            self.proxy.headers_parsed = False

    def connection_lost(self, exc):
        # Upstream went away: close the client connection too.
        if self.proxy.transport:
            self.proxy.transport.close()
||||
|
||||
async def cleanup_task(config: ProxyConfig):
    """Background loop: periodically evict idle rate-limiter buckets."""
    interval = 600  # ten minutes between sweeps
    while True:
        await asyncio.sleep(interval)
        config.rate_limiter.cleanup_old_entries()
||||
|
||||
async def main(path: str):
    """Load the config, start the proxy server and the cleanup loop."""
    cfg = ProxyConfig(path)
    server = await asyncio.get_running_loop().create_server(
        lambda: HTTPProxyProtocol(cfg), HOST_ALL, cfg.port
    )
    logging.info(f"Proxy running on {cfg.port}")
    # Without this task the rate limiter's buckets accumulate forever.
    asyncio.create_task(cleanup_task(cfg))
    async with server:
        await server.serve_forever()


if __name__ == '__main__':
    if len(sys.argv) > 1:
        try:
            asyncio.run(main(sys.argv[1]))
        except KeyboardInterrupt:
            # Clean shutdown on Ctrl-C instead of a traceback.
            pass
Loading…
Reference in New Issue
Block a user