This commit is contained in:
retoor 2026-01-27 20:06:10 +01:00
parent e0fb68d659
commit d5e24dc329
2 changed files with 294 additions and 1 deletions

286
proxy_with_ratelimiter.py Normal file
View File

@ -0,0 +1,286 @@
#!/usr/bin/env python3
import asyncio
import base64
from collections import defaultdict
import json
import logging
import ssl
import sys
import time
from typing import Dict, Optional
# Configure the root logger once at import time; all log lines carry a
# timestamp and level prefix.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class RateLimiter:
    """Per-client token-bucket rate limiter.

    Each client IP gets its own bucket that starts full (``capacity``
    tokens) and refills continuously at ``rate`` tokens per second.
    """

    def __init__(self, rate: float, capacity: int):
        self.rate = rate          # tokens added per second
        self.capacity = capacity  # maximum burst size
        # Buckets are created full on first sight of a client IP.
        self.buckets: Dict[str, Dict[str, float]] = defaultdict(
            lambda: {'tokens': capacity, 'last_update': time.time()}
        )

    def is_allowed(self, client_ip: str) -> bool:
        """Consume one token for *client_ip*; return False when the bucket is empty."""
        bucket = self.buckets[client_ip]
        now = time.time()
        elapsed = now - bucket['last_update']
        # Refill proportionally to elapsed time, capped at capacity.
        refilled = bucket['tokens'] + elapsed * self.rate
        bucket['tokens'] = refilled if refilled < self.capacity else self.capacity
        bucket['last_update'] = now
        if bucket['tokens'] < 1.0:
            return False
        bucket['tokens'] -= 1.0
        return True

    def cleanup_old_entries(self, max_age: float = 3600):
        """Drop buckets that have been idle for longer than *max_age* seconds."""
        cutoff = time.time() - max_age
        stale = [ip for ip, bucket in self.buckets.items() if bucket['last_update'] < cutoff]
        for ip in stale:
            del self.buckets[ip]
class ProxyConfig:
    """Proxy settings loaded from a JSON configuration file.

    Exposes the listening port, a hostname routing table, and a
    configured RateLimiter instance.
    """

    def __init__(self, config_path: str):
        with open(config_path, 'r') as f:
            self.config = json.load(f)
        self.port = self.config.get('port', 8585)
        self.routes = self._build_routing_table()
        rl_cfg = self.config.get('rate_limit', {})
        self.rate_limit_enabled = rl_cfg.get('enabled', True)
        self.rate_limiter = RateLimiter(
            rl_cfg.get('requests_per_second', 10),
            rl_cfg.get('burst_size', 20),
        )

    def _build_routing_table(self) -> Dict[str, dict]:
        """Map lowercase hostname -> route entry from the config."""
        return {
            entry['hostname'].lower(): entry
            for entry in self.config.get('reverse_proxy', [])
        }

    def get_route(self, hostname: str) -> Optional[dict]:
        """Case-insensitive lookup of the route for *hostname* (None if absent)."""
        return self.routes.get(hostname.lower())
class HTTPProxyProtocol(asyncio.Protocol):
    """Client-facing side of the proxy.

    Parses the HTTP request header block, applies per-IP rate limiting,
    picks an upstream route by Host header, and pipes all subsequent
    bytes to the upstream connection.
    """

    # Reject requests whose header section exceeds this many bytes.
    MAX_HEADER_SIZE = 16 * 1024

    def __init__(self, config: "ProxyConfig"):
        self.config = config
        self.transport = None           # client-side transport
        self.buffer = b''               # bytes held until upstream is ready
        self.upstream_transport = None
        self.upstream_protocol = None
        self.headers_parsed = False     # True once the header block was routed
        self.route = None               # matched route dict, set in _parse_and_route
        self.peername = None            # (ip, port) of the connecting client

    def connection_made(self, transport):
        self.transport = transport
        self.peername = transport.get_extra_info('peername')
        logging.info(f"Connection from {self.peername}")

    def data_received(self, data: bytes):
        # Once the headers have been routed, this connection is a plain
        # byte pipe: relay body data without re-buffering it.
        if self.headers_parsed:
            if self.upstream_transport:
                # BUGFIX: the original also appended every chunk to
                # self.buffer here, so the buffer grew without bound for
                # the lifetime of the connection. Forward directly instead.
                self.upstream_transport.write(data)
            else:
                # Upstream connect is still in flight; queue the bytes so
                # _connect_upstream can flush them once connected.
                self.buffer += data
            return
        self.buffer += data
        if len(self.buffer) > self.MAX_HEADER_SIZE:
            self._send_error(431, "Request Header Fields Too Large")
            return
        # Accept both CRLF and bare-LF header terminators.
        delimiter = b'\r\n\r\n' if b'\r\n\r\n' in self.buffer else b'\n\n' if b'\n\n' in self.buffer else None
        if delimiter:
            self._parse_and_route(delimiter)

    def _parse_and_route(self, delimiter: bytes):
        """Parse the buffered header block, enforce the rate limit, pick a
        route by Host header, and start (or reuse) the upstream connection.
        """
        try:
            header_end = self.buffer.index(delimiter) + len(delimiter)
            if header_end > self.MAX_HEADER_SIZE:
                self._send_error(431, "Request Header Fields Too Large")
                return
            header_data = self.buffer[:header_end]
            lines = header_data.split(b'\r\n') if b'\r\n' in header_data else header_data.split(b'\n')
            if not lines or len(lines[0]) == 0:
                self._send_error(400, "Bad Request")
                return
            request_line = lines[0].decode('utf-8', errors='ignore')
            headers = {}
            for line in lines[1:]:
                if b':' in line:
                    key, value = line.split(b':', 1)
                    headers[key.decode('utf-8', errors='ignore').strip().lower()] = \
                        value.decode('utf-8', errors='ignore').strip()
            # Rate-limit by client IP, trusting X-Forwarded-For when present.
            # NOTE(review): clients can spoof this header -- confirm the proxy
            # only sits behind trusted front-ends before relying on it.
            client_ip = self.peername[0] if self.peername else 'unknown'
            if 'x-forwarded-for' in headers:
                client_ip = headers['x-forwarded-for'].split(',')[0].strip()
            if self.config.rate_limit_enabled:
                if not self.config.rate_limiter.is_allowed(client_ip):
                    logging.warning(f"Rate limit exceeded for {client_ip}")
                    self._send_error(429, "Too Many Requests")
                    return
            host = headers.get('host', '').split(':')[0]
            if not host:
                self._send_error(400, "Bad Request: Missing Host header")
                return
            self.route = self.config.get_route(host)
            if not self.route:
                self._send_error(404, f"Not Found: No route for {host}")
                return
            modified_headers = self._rewrite_headers(headers, request_line)
            self.headers_parsed = True
            if self.upstream_transport and not self.upstream_transport.is_closing():
                # Keep-alive: reuse the existing upstream connection.
                self.upstream_transport.write(modified_headers)
                if len(self.buffer) > header_end:
                    self.upstream_transport.write(self.buffer[header_end:])
                self.buffer = b''
            else:
                asyncio.create_task(self._connect_upstream(modified_headers, header_end))
        except Exception as e:
            logging.error(f"Error parsing request: {e}")
            self._send_error(500, "Internal Server Error")

    def _rewrite_headers(self, headers: dict, request_line: str) -> bytes:
        """Serialize headers for the upstream, optionally rewriting Host and
        injecting Basic-auth credentials from the route config.
        """
        lines = [request_line.encode('utf-8')]
        for key, value in headers.items():
            if key == 'host' and self.route.get('rewrite_host'):
                value = f"{self.route['upstream_host']}:{self.route['upstream_port']}"
            lines.append(f"{key.title()}: {value}".encode('utf-8'))
        if self.route.get('use_auth'):
            username = self.route.get('username', '')
            password = self.route.get('password', '')
            credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
            lines.append(f"Authorization: Basic {credentials}".encode('utf-8'))
        lines.append(b'')
        lines.append(b'')  # empty line terminates the header block
        return b'\r\n'.join(lines)

    async def _connect_upstream(self, modified_headers: bytes, header_end: int):
        """Open the upstream connection, then flush the rewritten headers plus
        any body bytes that were buffered while the connect was in flight.
        """
        try:
            upstream_host = self.route['upstream_host']
            upstream_port = self.route['upstream_port']
            use_ssl = self.route.get('use_ssl', False)
            # get_running_loop() is the correct call inside a coroutine;
            # get_event_loop() is deprecated here.
            loop = asyncio.get_running_loop()
            ssl_context = ssl.create_default_context() if use_ssl else None
            if ssl_context:
                # Upstream certificates are not verified -- acceptable only
                # for internal/self-signed upstreams.
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
            self.upstream_transport, self.upstream_protocol = \
                await loop.create_connection(
                    lambda: UpstreamProtocol(self),
                    upstream_host,
                    upstream_port,
                    ssl=ssl_context
                )
            self.upstream_transport.write(modified_headers)
            if len(self.buffer) > header_end:
                self.upstream_transport.write(self.buffer[header_end:])
            self.buffer = b''
        except Exception as e:
            logging.error(f"Error connecting to upstream: {e}")
            self._send_error(502, "Bad Gateway")

    def _send_error(self, code: int, message: str):
        """Write a minimal plain-text error response and close the client."""
        body = message.encode('utf-8')
        response = (
            f"HTTP/1.1 {code} {message}\r\n"
            f"Content-Type: text/plain\r\n"
            # BUGFIX: Content-Length must be the byte length of the encoded
            # body, not the character count (they differ for non-ASCII).
            f"Content-Length: {len(body)}\r\n"
            f"Connection: close\r\n\r\n"
        ).encode('utf-8') + body
        if self.transport:
            self.transport.write(response)
            self.transport.close()

    def connection_lost(self, exc):
        # Tear down the upstream side when the client goes away.
        if self.upstream_transport:
            self.upstream_transport.close()
class UpstreamProtocol(asyncio.Protocol):
    """Upstream side of the pipe: relays server bytes back to the client."""

    def __init__(self, proxy_protocol: HTTPProxyProtocol):
        self.proxy_protocol = proxy_protocol
        self.transport = None

    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data: bytes):
        client = self.proxy_protocol.transport
        if client:
            client.write(data)
            # Re-arm header parsing so a keep-alive client can send
            # another request on the same connection.
            self.proxy_protocol.headers_parsed = False

    def connection_lost(self, exc):
        client = self.proxy_protocol.transport
        if client:
            client.close()
async def cleanup_task(config: ProxyConfig):
    """Periodically purge idle rate-limiter buckets (runs forever)."""
    sweep_interval = 600  # seconds between cleanup sweeps
    while True:
        await asyncio.sleep(sweep_interval)
        config.rate_limiter.cleanup_old_entries()
async def main(config_path: str):
    """Load the config, start the listening server, and serve forever.

    Args:
        config_path: Path to the JSON configuration file.
    """
    config = ProxyConfig(config_path)
    # get_running_loop() replaces the deprecated get_event_loop() call --
    # we are always inside asyncio.run() here.
    loop = asyncio.get_running_loop()
    server = await loop.create_server(lambda: HTTPProxyProtocol(config), '0.0.0.0', config.port)
    logging.info(f"Reverse proxy (Keep-Alive enabled) listening on 0.0.0.0:{config.port}")
    # Keep a reference to the task so it is not garbage-collected mid-flight,
    # and cancel it when the server stops.
    cleanup = asyncio.create_task(cleanup_task(config))
    try:
        async with server:
            await server.serve_forever()
    finally:
        cleanup.cancel()
if __name__ == '__main__':
    if len(sys.argv) >= 2:
        try:
            asyncio.run(main(sys.argv[1]))
        except KeyboardInterrupt:
            pass
    else:
        # Previously the script exited silently when no config path was
        # given; report usage and a non-zero status instead.
        print(f"Usage: {sys.argv[0]} <config.json>", file=sys.stderr)
        sys.exit(1)

View File

@ -279,7 +279,14 @@ ANALYSIS GUIDELINES:
- Detect performance work by optimization patterns or caching additions
OUTPUT:
Provide only the commit message lines, no explanations or additional text."""
Provide only the commit message lines, no explanations or additional text.
BAD OUTPUT EXAMPLE (TOO GENERIC):
chore: update html, js files
GOOD OUTPUT EXAMPLE:
chore: created a new responsive navigation bar
"""
message = call_ai(prompt)
if not message: