From d5e24dc3295aee52decc3fae7d78644fe86a9163 Mon Sep 17 00:00:00 2001 From: retoor Date: Tue, 27 Jan 2026 20:06:10 +0100 Subject: [PATCH] Update. --- proxy_with_ratelimiter.py | 286 ++++++++++++++++++++++++++++++++++++++ rgithook.py | 9 +- 2 files changed, 294 insertions(+), 1 deletion(-) create mode 100644 proxy_with_ratelimiter.py diff --git a/proxy_with_ratelimiter.py b/proxy_with_ratelimiter.py new file mode 100644 index 0000000..2344500 --- /dev/null +++ b/proxy_with_ratelimiter.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 + +import asyncio +import base64 +from collections import defaultdict +import json +import logging +import ssl +import sys +import time +from typing import Dict, Optional + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + +class RateLimiter: + + def __init__(self, rate: float, capacity: int): + self.rate = rate + self.capacity = capacity + self.buckets: Dict[str, Dict[str, float]] = defaultdict( + lambda: {'tokens': capacity, 'last_update': time.time()} + ) + + def is_allowed(self, client_ip: str) -> bool: + now = time.time() + bucket = self.buckets[client_ip] + + time_passed = now - bucket['last_update'] + bucket['tokens'] = min( + self.capacity, + bucket['tokens'] + time_passed * self.rate + ) + bucket['last_update'] = now + + if bucket['tokens'] >= 1.0: + bucket['tokens'] -= 1.0 + return True + return False + + def cleanup_old_entries(self, max_age: float = 3600): + now = time.time() + to_remove = [ + ip for ip, bucket in self.buckets.items() + if now - bucket['last_update'] > max_age + ] + for ip in to_remove: + del self.buckets[ip] + + +class ProxyConfig: + + def __init__(self, config_path: str): + with open(config_path, 'r') as f: + self.config = json.load(f) + self.port = self.config.get('port', 8585) + self.routes = self._build_routing_table() + + rate_limit_config = self.config.get('rate_limit', {}) + self.rate_limit_enabled = rate_limit_config.get('enabled', True) + requests_per_second = rate_limit_config.get('requests_per_second', 10) + burst_size = rate_limit_config.get('burst_size', 20) + self.rate_limiter = RateLimiter(requests_per_second, burst_size) + + def _build_routing_table(self) -> Dict[str, dict]: + routes = {} + for route in self.config.get('reverse_proxy', []): + hostname = route['hostname'].lower() + routes[hostname] = route + return routes + + def get_route(self, hostname: str) -> Optional[dict]: + return self.routes.get(hostname.lower()) + + +class HTTPProxyProtocol(asyncio.Protocol): + + MAX_HEADER_SIZE = 16 * 1024 + + def __init__(self, config: ProxyConfig): + self.config = config + self.transport = None + self.buffer = b'' + self.upstream_transport = None + self.upstream_protocol = None + self.headers_parsed = False + self.route = None + self.peername = None + + def connection_made(self, transport): + self.transport = transport + self.peername = transport.get_extra_info('peername') + logging.info(f"Connection from {self.peername}") + + def data_received(self, data: bytes): + self.buffer += data + + if not self.headers_parsed: + if len(self.buffer) > self.MAX_HEADER_SIZE: + self._send_error(431, "Request Header Fields Too Large") + return + + + delimiter = b'\r\n\r\n' if b'\r\n\r\n' in self.buffer else b'\n\n' if b'\n\n' in self.buffer else None + + if delimiter: + self._parse_and_route(delimiter) + else: + + if self.upstream_transport: + self.upstream_transport.write(data) + + def _parse_and_route(self, delimiter: bytes): + + try: + header_end = self.buffer.index(delimiter) + len(delimiter) + + if header_end > self.MAX_HEADER_SIZE: + self._send_error(431, "Request Header Fields Too Large") + return + + header_data = self.buffer[:header_end] + lines = header_data.split(b'\r\n') if b'\r\n' in header_data else header_data.split(b'\n') + + if not lines or len(lines[0]) == 0: + self._send_error(400, "Bad Request") + return + + request_line = lines[0].decode('utf-8', errors='ignore') + + headers = {} + for line in lines[1:]: + if b':' in line: + key, value = line.split(b':', 1) + headers[key.decode('utf-8', errors='ignore').strip().lower()] = \ + value.decode('utf-8', errors='ignore').strip() + + + client_ip = self.peername[0] if self.peername else 'unknown' + if 'x-forwarded-for' in headers: + client_ip = headers['x-forwarded-for'].split(',')[0].strip() + + + if self.config.rate_limit_enabled: + if not self.config.rate_limiter.is_allowed(client_ip): + logging.warning(f"Rate limit exceeded for {client_ip}") + self._send_error(429, "Too Many Requests") + return + + host = headers.get('host', '').split(':')[0] + if not host: + self._send_error(400, "Bad Request: Missing Host header") + return + + self.route = self.config.get_route(host) + if not self.route: + self._send_error(404, f"Not Found: No route for {host}") + return + + modified_headers = self._rewrite_headers(headers, request_line) + self.headers_parsed = True + + + if self.upstream_transport and not self.upstream_transport.is_closing(): + self.upstream_transport.write(modified_headers) + + if len(self.buffer) > header_end: + self.upstream_transport.write(self.buffer[header_end:]) + self.buffer = b'' + else: + asyncio.create_task(self._connect_upstream(modified_headers, header_end)) + + except Exception as e: + logging.error(f"Error parsing request: {e}") + self._send_error(500, f"Internal Server Error") + + def _rewrite_headers(self, headers: dict, request_line: str) -> bytes: + + lines = [request_line.encode('utf-8')] + + for key, value in headers.items(): + if key == 'host' and self.route.get('rewrite_host'): + value = f"{self.route['upstream_host']}:{self.route['upstream_port']}" + lines.append(f"{key.title()}: {value}".encode('utf-8')) + + if self.route.get('use_auth'): + username = self.route.get('username', '') + password = self.route.get('password', '') + credentials = base64.b64encode(f"{username}:{password}".encode()).decode() + lines.append(f"Authorization: Basic {credentials}".encode('utf-8')) + + lines.append(b'') + lines.append(b'') + return b'\r\n'.join(lines) + + async def _connect_upstream(self, modified_headers: bytes, header_end: int): + + try: + upstream_host = self.route['upstream_host'] + upstream_port = self.route['upstream_port'] + use_ssl = self.route.get('use_ssl', False) + + loop = asyncio.get_event_loop() + ssl_context = ssl.create_default_context() if use_ssl else None + if ssl_context: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + + self.upstream_transport, self.upstream_protocol = \ + await loop.create_connection( + lambda: UpstreamProtocol(self), + upstream_host, + upstream_port, + ssl=ssl_context + ) + + self.upstream_transport.write(modified_headers) + if len(self.buffer) > header_end: + self.upstream_transport.write(self.buffer[header_end:]) + self.buffer = b'' + + except Exception as e: + logging.error(f"Error connecting to upstream: {e}") + self._send_error(502, f"Bad Gateway") + + def _send_error(self, code: int, message: str): + response = ( + f"HTTP/1.1 {code} {message}\r\n" + f"Content-Type: text/plain\r\n" + f"Content-Length: {len(message)}\r\n" + f"Connection: close\r\n\r\n{message}" + ).encode('utf-8') + if self.transport: + self.transport.write(response) + self.transport.close() + + def connection_lost(self, exc): + if self.upstream_transport: + self.upstream_transport.close() + + +class UpstreamProtocol(asyncio.Protocol): + + def __init__(self, proxy_protocol: HTTPProxyProtocol): + self.proxy_protocol = proxy_protocol + self.transport = None + + def connection_made(self, transport): + self.transport = transport + + def data_received(self, data: bytes): + + if self.proxy_protocol.transport: + self.proxy_protocol.transport.write(data) + + + self.proxy_protocol.headers_parsed = False + + def connection_lost(self, exc): + if self.proxy_protocol.transport: + self.proxy_protocol.transport.close() + + +async def cleanup_task(config: ProxyConfig): + while True: + await asyncio.sleep(600) + config.rate_limiter.cleanup_old_entries() + + +async def main(config_path: str): + config = ProxyConfig(config_path) + loop = asyncio.get_event_loop() + server = await loop.create_server(lambda: HTTPProxyProtocol(config), '0.0.0.0', config.port) + logging.info(f"Reverse proxy (Keep-Alive enabled) listening on 0.0.0.0:{config.port}") + asyncio.create_task(cleanup_task(config)) + async with server: + await server.serve_forever() + + +if __name__ == '__main__': + if len(sys.argv) >= 2: + try: + asyncio.run(main(sys.argv[1])) + except KeyboardInterrupt: + pass + diff --git a/rgithook.py b/rgithook.py index da8b003..198f883 100755 --- a/rgithook.py +++ b/rgithook.py @@ -279,7 +279,14 @@ ANALYSIS GUIDELINES: - Detect performance work by optimization patterns or caching additions OUTPUT: -Provide only the commit message lines, no explanations or additional text.""" +Provide only the commit message lines, no explanations or additional text. + +BAD OUTPUT EXAMPLE (TOO GENERIC): +chore: update html, js files + +GOOD OUTPUT EXAMPLE: +chore: created a new responsive navigation bar +""" message = call_ai(prompt) if not message: