From 6378c397a0d0e4cd7fe688df4bd4c8c957e32d17 Mon Sep 17 00:00:00 2001 From: retoor Date: Sun, 20 Jul 2025 20:18:24 +0200 Subject: [PATCH] Update, new snek_rpc connection. --- snek_rpc.py | 592 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 592 insertions(+) create mode 100644 snek_rpc.py diff --git a/snek_rpc.py b/snek_rpc.py new file mode 100644 index 0000000..ff58414 --- /dev/null +++ b/snek_rpc.py @@ -0,0 +1,592 @@ +import asyncio +import json +import base64 +import hashlib +import ssl +import logging +import time +import os +from typing import Optional, Dict, Any, Callable, Union +from urllib.parse import urlparse +from enum import Enum +from collections import defaultdict +from contextlib import asynccontextmanager +import uuid + + +class ConnectionState(Enum): + """WebSocket connection states""" + DISCONNECTED = "disconnected" + CONNECTING = "connecting" + CONNECTED = "connected" + CLOSING = "closing" + CLOSED = "closed" + + +class WebSocketException(Exception): + """Base exception for WebSocket errors""" + pass + + +class ConnectionError(WebSocketException): + """Connection-related errors""" + pass + + +class ProtocolError(WebSocketException): + """WebSocket protocol errors""" + pass + + +class WebSocketClient: + """ + Production-ready WebSocket client with automatic reconnection, + proper error handling, and message queuing. + """ + + # WebSocket opcodes + OPCODE_CONTINUATION = 0x0 + OPCODE_TEXT = 0x1 + OPCODE_BINARY = 0x2 + OPCODE_CLOSE = 0x8 + OPCODE_PING = 0x9 + OPCODE_PONG = 0xa + + def __init__( + self, + url: str, + ssl_context: Optional[ssl.SSLContext] = None, + ping_interval: float = 30.0, + ping_timeout: float = 10.0, + reconnect_interval: float = 5.0, + max_reconnect_attempts: int = 10, + logger: Optional[logging.Logger] = None + ): + self.url = urlparse(url) + self.ssl_context = ssl_context or ssl.create_default_context() + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.reconnect_interval = reconnect_interval + self.max_reconnect_attempts = max_reconnect_attempts + self.logger = logger or logging.getLogger(__name__) + + # Connection state + self.state = ConnectionState.DISCONNECTED + self.reader: Optional[asyncio.StreamReader] = None + self.writer: Optional[asyncio.StreamWriter] = None + + # Message handling + self._response_futures: Dict[str, asyncio.Future] = {} + self._event_handlers: Dict[str, list] = defaultdict(list) + self._receive_task: Optional[asyncio.Task] = None + self._ping_task: Optional[asyncio.Task] = None + + # Reconnection + self._reconnect_attempts = 0 + self._reconnect_lock = asyncio.Lock() + self._closing = False + + # Message queue for buffering during reconnection + self._message_queue = asyncio.Queue(maxsize=1000) + + async def connect(self) -> None: + """Establish WebSocket connection with proper handshake""" + async with self._reconnect_lock: + if self.state == ConnectionState.CONNECTED: + return + + self.state = ConnectionState.CONNECTING + self.logger.info(f"Connecting to {self.url.geturl()}") + + try: + # Establish TCP connection + port = self.url.port or (443 if self.url.scheme == 'wss' else 80) + self.reader, self.writer = await asyncio.open_connection( + self.url.hostname, + port, + ssl=self.ssl_context if self.url.scheme == 'wss' else None + ) + + # Perform WebSocket handshake + await self._handshake() + + # Start background tasks + self._receive_task = asyncio.create_task(self._receive_loop()) + self._ping_task = asyncio.create_task(self._ping_loop()) + + self.state = ConnectionState.CONNECTED + self._reconnect_attempts = 0 + self.logger.info("WebSocket connection established") + + # Process any queued messages + await self._process_queued_messages() + + except Exception as e: + self.state = ConnectionState.DISCONNECTED + self.logger.error(f"Connection failed: {e}") + raise ConnectionError(f"Failed to connect: {e}") + + async def _handshake(self) -> None: + """Perform WebSocket handshake""" + # Generate WebSocket key + key = base64.b64encode(os.urandom(16)).decode('utf-8') + + # Build handshake request + path = self.url.path or '/' + if self.url.query: + path += '?' + self.url.query + + handshake = ( + f"GET {path} HTTP/1.1\r\n" + f"Host: {self.url.hostname}\r\n" + f"Upgrade: websocket\r\n" + f"Connection: Upgrade\r\n" + f"Sec-WebSocket-Key: {key}\r\n" + f"Sec-WebSocket-Version: 13\r\n" + f"User-Agent: Python-WebSocket-Client\r\n" + f"\r\n" + ) + + self.writer.write(handshake.encode()) + await self.writer.drain() + + # Read response + response = await self.reader.readuntil(b'\r\n\r\n') + + # Validate response + if b'101 Switching Protocols' not in response: + raise ProtocolError(f"Invalid handshake response: {response.decode()}") + + # Verify Sec-WebSocket-Accept + expected_accept = base64.b64encode( + hashlib.sha1((key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode()).digest() + ).decode() + + if f'Sec-WebSocket-Accept: {expected_accept}'.encode() not in response: + self.logger.warning("Sec-WebSocket-Accept header mismatch") + + async def _receive_loop(self) -> None: + """Main loop for receiving and processing messages""" + try: + while self.state == ConnectionState.CONNECTED and not self._closing: + try: + message = await self._receive_message() + + if message['opcode'] == self.OPCODE_TEXT: + await self._handle_text_message(message['data'].decode('utf-8')) + elif message['opcode'] == self.OPCODE_BINARY: + await self._handle_binary_message(message['data']) + elif message['opcode'] == self.OPCODE_CLOSE: + await self._handle_close_frame(message['data']) + break + elif message['opcode'] == self.OPCODE_PING: + await self._send_frame(self.OPCODE_PONG, message['data']) + elif message['opcode'] == self.OPCODE_PONG: + pass # Pong received, connection is alive + + except asyncio.TimeoutError: + self.logger.warning("Receive timeout") + except Exception as e: + self.logger.error(f"Error in receive loop: {e}") + break + + except Exception as e: + self.logger.error(f"Fatal error in receive loop: {e}") + finally: + if not self._closing: + await self._handle_disconnect() + + async def _receive_message(self) -> Dict[str, Any]: + """Receive and parse a WebSocket frame""" + # Read frame header + header = await self.reader.readexactly(2) + + fin = (header[0] & 0x80) != 0 + opcode = header[0] & 0x0f + masked = (header[1] & 0x80) != 0 + payload_length = header[1] & 0x7f + + # Read extended payload length if needed + if payload_length == 126: + payload_length = int.from_bytes(await self.reader.readexactly(2), 'big') + elif payload_length == 127: + payload_length = int.from_bytes(await self.reader.readexactly(8), 'big') + + # Read masking key if present (server shouldn't send masked frames) + mask_key = None + if masked: + mask_key = await self.reader.readexactly(4) + + # Read payload + payload = await self.reader.readexactly(payload_length) + + # Unmask payload if needed + if masked: + payload = bytes(b ^ mask_key[i % 4] for i, b in enumerate(payload)) + + return { + 'fin': fin, + 'opcode': opcode, + 'data': payload + } + + async def _send_frame(self, opcode: int, data: bytes) -> None: + """Send a WebSocket frame""" + if self.state != ConnectionState.CONNECTED: + raise ConnectionError("Not connected") + + # Build frame + frame = bytearray() + + # FIN = 1, RSV = 0, Opcode + frame.append(0x80 | opcode) + + # Mask bit = 1 (client must mask), payload length + length = len(data) + if length <= 125: + frame.append(0x80 | length) + elif length <= 65535: + frame.append(0x80 | 126) + frame.extend(length.to_bytes(2, 'big')) + else: + frame.append(0x80 | 127) + frame.extend(length.to_bytes(8, 'big')) + + # Masking key + mask_key = os.urandom(4) + frame.extend(mask_key) + + # Masked payload + masked_data = bytes(b ^ mask_key[i % 4] for i, b in enumerate(data)) + frame.extend(masked_data) + + # Send frame + self.writer.write(frame) + await self.writer.drain() + + async def _handle_text_message(self, message: str) -> None: + """Handle incoming text message""" + try: + data = json.loads(message) + + # Check if this is a response to a request + if 'callId' in data and data['callId'] in self._response_futures: + future = self._response_futures.pop(data['callId']) + if not future.done(): + future.set_result(data.get('data', data)) + else: + # Emit as event + await self._emit_event('message', data) + + except json.JSONDecodeError: + self.logger.error(f"Invalid JSON message: {message}") + await self._emit_event('message', message) + + async def _handle_binary_message(self, data: bytes) -> None: + """Handle incoming binary message""" + await self._emit_event('binary', data) + + async def _handle_close_frame(self, data: bytes) -> None: + """Handle close frame from server""" + code = int.from_bytes(data[:2], 'big') if len(data) >= 2 else 1000 + reason = data[2:].decode('utf-8', errors='ignore') if len(data) > 2 else '' + + self.logger.info(f"Received close frame: {code} {reason}") + + # Send close frame back + await self._send_frame(self.OPCODE_CLOSE, data[:2]) + + self.state = ConnectionState.CLOSING + + async def _handle_disconnect(self) -> None: + """Handle unexpected disconnection""" + self.state = ConnectionState.DISCONNECTED + self.logger.warning("WebSocket disconnected") + + # Cancel pending futures + for future in self._response_futures.values(): + if not future.done(): + future.set_exception(ConnectionError("Disconnected")) + self._response_futures.clear() + + # Emit disconnect event + await self._emit_event('disconnect', None) + + # Attempt reconnection if not closing + if not self._closing: + asyncio.create_task(self._reconnect()) + + async def _reconnect(self) -> None: + """Attempt to reconnect with exponential backoff""" + while self._reconnect_attempts < self.max_reconnect_attempts and not self._closing: + self._reconnect_attempts += 1 + wait_time = min(self.reconnect_interval * (2 ** (self._reconnect_attempts - 1)), 60) + + self.logger.info(f"Reconnecting in {wait_time}s (attempt {self._reconnect_attempts}/{self.max_reconnect_attempts})") + await asyncio.sleep(wait_time) + + try: + await self.connect() + await self._emit_event('reconnect', self._reconnect_attempts) + return + except Exception as e: + self.logger.error(f"Reconnection failed: {e}") + + self.logger.error("Max reconnection attempts reached") + await self._emit_event('reconnect_failed', self._reconnect_attempts) + + async def _ping_loop(self) -> None: + """Send periodic ping frames to keep connection alive""" + try: + while self.state == ConnectionState.CONNECTED and not self._closing: + await asyncio.sleep(self.ping_interval) + + try: + await self._send_frame(self.OPCODE_PING, b'') + except Exception as e: + self.logger.error(f"Ping failed: {e}") + break + + except asyncio.CancelledError: + pass + + async def _process_queued_messages(self) -> None: + """Process messages that were queued during disconnection""" + while not self._message_queue.empty(): + try: + message = self._message_queue.get_nowait() + await self.send_json(message) + except Exception as e: + self.logger.error(f"Failed to send queued message: {e}") + + async def send(self, message: Union[str, bytes]) -> None: + """Send a message (text or binary)""" + if isinstance(message, str): + await self._send_frame(self.OPCODE_TEXT, message.encode('utf-8')) + else: + await self._send_frame(self.OPCODE_BINARY, message) + + async def send_json(self, data: Dict[str, Any]) -> None: + """Send JSON-encoded message""" + message = json.dumps(data) + await self.send(message) + + async def request(self, method: str, *args, **kwargs) -> Any: + """ + Send a request and wait for response. + + Args: + method: The method name to call + *args: Positional arguments for the method + **kwargs: Keyword arguments for the method + + Returns: + The response data + """ + # Handle no_wait flag + no_wait = kwargs.pop('no_wait', False) + + # Generate call ID + call_id = str(uuid.uuid4()) if not no_wait else None + + # Prepare request + request = { + "method": method, + "args": args + } + + if call_id: + request["callId"] = call_id + + # Create future for response if waiting + if call_id: + future = asyncio.Future() + self._response_futures[call_id] = future + + try: + # Send request + if self.state == ConnectionState.CONNECTED: + await self.send_json(request) + else: + # Queue message if not connected + if not self._message_queue.full(): + await self._message_queue.put(request) + else: + raise ConnectionError("Message queue full") + + # Wait for response if needed + if call_id: + return await asyncio.wait_for(future, timeout=30.0) + else: + return True + + except asyncio.TimeoutError: + self._response_futures.pop(call_id, None) + raise TimeoutError(f"Request timeout for {method}") + except Exception: + if call_id: + self._response_futures.pop(call_id, None) + raise + + def on(self, event: str, handler: Callable) -> None: + """Register an event handler""" + self._event_handlers[event].append(handler) + + def off(self, event: str, handler: Callable) -> None: + """Unregister an event handler""" + if handler in self._event_handlers[event]: + self._event_handlers[event].remove(handler) + + async def _emit_event(self, event: str, data: Any) -> None: + """Emit an event to all registered handlers""" + for handler in self._event_handlers[event]: + try: + if asyncio.iscoroutinefunction(handler): + await handler(data) + else: + handler(data) + except Exception as e: + self.logger.error(f"Error in event handler for {event}: {e}") + + async def close(self) -> None: + """Close the WebSocket connection gracefully""" + if self.state == ConnectionState.DISCONNECTED: + return + + self._closing = True + self.state = ConnectionState.CLOSING + + try: + # Send close frame + close_data = (1000).to_bytes(2, 'big') + b'Normal closure' + await self._send_frame(self.OPCODE_CLOSE, close_data) + + # Wait for close response or timeout + await asyncio.wait_for( + asyncio.shield(self._receive_task), + timeout=5.0 + ) + except (asyncio.TimeoutError, Exception): + pass + + # Cancel tasks + if self._receive_task: + self._receive_task.cancel() + if self._ping_task: + self._ping_task.cancel() + + # Close connection + if self.writer: + self.writer.close() + await self.writer.wait_closed() + + self.state = ConnectionState.CLOSED + self.logger.info("WebSocket connection closed") + + @asynccontextmanager + async def session(self): + """Context manager for WebSocket connection""" + try: + await self.connect() + yield self + finally: + await self.close() + + # Convenience methods for common operations + async def login(self, username: str, password: str) -> Any: + """Login to the WebSocket service""" + return await self.request("login", username, password) + + async def get_channels(self) -> Any: + """Get available channels""" + return await self.request("get_channels") + + async def get_messages(self, channel_uid: str) -> Any: + """Get messages from a channel""" + return await self.request("get_messages", channel_uid) + + async def send_message(self, channel_uid: str, message: str, **kwargs) -> Any: + """Send a message to a channel""" + return await self.request("send_message", channel_uid, message, **kwargs) + + +# Example usage +async def main(): + # Configure logging + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + + # Create WebSocket client + url = "wss://molodetz.online/rpc.ws" + client = WebSocketClient( + url, + ping_interval=30.0, + reconnect_interval=5.0, + max_reconnect_attempts=10 + ) + + # Register event handlers + def on_message(data): + print(f"Received message: {data}") + + def on_disconnect(data): + print("Disconnected from server") + + def on_reconnect(attempts): + print(f"Reconnected after {attempts} attempts") + + client.on('message', on_message) + client.on('disconnect', on_disconnect) + client.on('reconnect', on_reconnect) + + # Use context manager for automatic cleanup + async with client.session(): + # Login + result = await client.login("username", "password") + print(f"Login result: {result}") + + # Get channels + channels = await client.get_channels() + print(f"Available channels: {len(channels)}") + + # Send a message without waiting for response + if channels: + await client.send_message( + channels[0]['uid'], + "Hello from Python!", + no_wait=True + ) + + # Keep connection alive + await asyncio.sleep(60) + + +# Load testing example +async def load_test(): + """Example of running multiple WebSocket clients concurrently""" + async def run_client(client_id: int): + url = "wss://molodetz.online/rpc.ws" + client = WebSocketClient(url, logger=logging.getLogger(f"client-{client_id}")) + + try: + async with client.session(): + await client.login(f"user{client_id}", "password") + + # Simulate activity + for i in range(10): + channels = await client.get_channels() + await asyncio.sleep(1) + + except Exception as e: + print(f"Client {client_id} error: {e}") + + # Run 10 clients concurrently + tasks = [run_client(i) for i in range(10)] + await asyncio.gather(*tasks, return_exceptions=True) + + +if __name__ == "__main__": + asyncio.run(main()) +