import asyncio
import json
import base64
import hashlib
import inspect
import ssl
import logging
import time
import os
from typing import Optional, Dict, Any, Callable, Union
from urllib.parse import urlparse
from enum import Enum
from collections import defaultdict
from contextlib import asynccontextmanager
import uuid


class ConnectionState(Enum):
    """WebSocket connection states"""
    DISCONNECTED = "disconnected"
    CONNECTING = "connecting"
    CONNECTED = "connected"
    CLOSING = "closing"
    CLOSED = "closed"


class WebSocketException(Exception):
    """Base exception for WebSocket errors"""
    pass


# NOTE: this name shadows the builtin ``ConnectionError`` inside this module.
# It is kept because renaming it would break code importing it by this name.
class ConnectionError(WebSocketException):
    """Connection-related errors"""
    pass


class ProtocolError(WebSocketException):
    """WebSocket protocol errors"""
    pass


class WebSocketClient:
    """
    Production-ready WebSocket client with automatic reconnection,
    proper error handling, and message queuing.

    Implements the client side of RFC 6455 directly on asyncio streams.
    It provides:

    - masked outgoing frames;
    - reassembly of fragmented incoming messages;
    - ping/pong keep-alive with a pong timeout;
    - a JSON request/response layer that matches replies to requests by a
      ``callId`` field.

    Events delivered to handlers registered with :meth:`on`:

    - ``message``: the decoded JSON, or the raw text if it is not JSON;
    - ``binary``: raw bytes;
    - ``disconnect``: ``None``;
    - ``reconnect``: the attempt number that succeeded;
    - ``reconnect_failed``: the number of attempts made.
    """

    # WebSocket opcodes
    OPCODE_CONTINUATION = 0x0
    OPCODE_TEXT = 0x1
    OPCODE_BINARY = 0x2
    OPCODE_CLOSE = 0x8
    OPCODE_PING = 0x9
    OPCODE_PONG = 0xa

    # GUID appended to the client key to derive Sec-WebSocket-Accept (RFC 6455).
    _WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

    def __init__(
        self,
        url: str,
        ssl_context: Optional[ssl.SSLContext] = None,
        ping_interval: float = 30.0,
        ping_timeout: float = 10.0,
        reconnect_interval: float = 5.0,
        max_reconnect_attempts: int = 10,
        logger: Optional[logging.Logger] = None,
        request_timeout: float = 30.0,
    ):
        """
        Args:
            url: ``ws://`` or ``wss://`` URL to connect to.
            ssl_context: TLS context for ``wss``. Defaults to
                ``ssl.create_default_context()``.
            ping_interval: Seconds between keep-alive pings.
            ping_timeout: Seconds to wait for a pong before the connection is
                treated as dead and dropped.
            reconnect_interval: Base delay for exponential reconnect backoff.
            max_reconnect_attempts: Attempts made after each disconnection.
            logger: Logger to use. Defaults to this module's logger.
            request_timeout: Seconds :meth:`request` waits for a reply.
        """
        self.url = urlparse(url)
        self.ssl_context = ssl_context or ssl.create_default_context()
        self.ping_interval = ping_interval
        self.ping_timeout = ping_timeout
        self.reconnect_interval = reconnect_interval
        self.max_reconnect_attempts = max_reconnect_attempts
        self.request_timeout = request_timeout
        self.logger = logger or logging.getLogger(__name__)

        # Connection state
        self.state = ConnectionState.DISCONNECTED
        self.reader: Optional[asyncio.StreamReader] = None
        self.writer: Optional[asyncio.StreamWriter] = None

        # Message handling
        self._response_futures: Dict[str, asyncio.Future] = {}
        self._event_handlers: Dict[str, list] = defaultdict(list)
        self._receive_task: Optional[asyncio.Task] = None
        self._ping_task: Optional[asyncio.Task] = None
        # Set by the receive loop whenever a pong arrives; awaited by the ping loop.
        self._pong_received = asyncio.Event()

        # Reassembly state for a fragmented incoming message (RFC 6455 §5.4).
        self._fragment_opcode: Optional[int] = None
        self._fragment_buffer = bytearray()

        # Reconnection
        self._reconnect_attempts = 0
        self._reconnect_lock = asyncio.Lock()
        self._reconnect_task: Optional[asyncio.Task] = None
        self._closing = False

        # Message queue for buffering during reconnection
        self._message_queue = asyncio.Queue(maxsize=1000)

    async def connect(self) -> None:
        """Establish the WebSocket connection with the opening handshake.

        Also clears the effect of an earlier :meth:`close`, so a closed
        client can be connected again.

        Raises:
            ConnectionError: if the TCP/TLS connection or the handshake fails.
        """
        self._closing = False
        await self._establish()

    async def _establish(self) -> None:
        """Open the socket, perform the handshake and start the background tasks.

        This is shared by :meth:`connect` and the reconnection loop. Unlike
        ``connect`` it does not clear a pending :meth:`close`.
        """
        async with self._reconnect_lock:
            if self.state == ConnectionState.CONNECTED:
                return

            self.state = ConnectionState.CONNECTING
            self.logger.info(f"Connecting to {self.url.geturl()}")

            writer: Optional[asyncio.StreamWriter] = None
            try:
                secure = self.url.scheme == 'wss'
                port = self.url.port or (443 if secure else 80)
                reader, writer = await asyncio.open_connection(
                    self.url.hostname,
                    port,
                    ssl=self.ssl_context if secure else None,
                )
                self.reader, self.writer = reader, writer
                await self._handshake()
                if self._closing:
                    raise ConnectionError("close() was called while connecting")
            except asyncio.CancelledError:
                # Cancelled (e.g. by close()) mid-connect: don't leak the socket.
                self._abort_writer(writer)
                self.state = ConnectionState.DISCONNECTED
                raise
            except Exception as e:
                # A failed handshake must not leave the socket open.
                self._abort_writer(writer)
                self.state = ConnectionState.DISCONNECTED
                self.logger.error(f"Connection failed: {e}")
                raise ConnectionError(f"Failed to connect: {e}") from e

            # A fresh connection starts with no partially received message.
            self._fragment_opcode = None
            self._fragment_buffer = bytearray()

            # Start background tasks
            self._receive_task = asyncio.create_task(self._receive_loop())
            self._ping_task = asyncio.create_task(self._ping_loop())

            self.state = ConnectionState.CONNECTED
            self._reconnect_attempts = 0
            self.logger.info("WebSocket connection established")

            # Process any queued messages
            await self._process_queued_messages()

    def _host_header(self) -> str:
        """Return the ``Host`` header value.

        IPv6 literals are bracketed, and the port is included only when it
        differs from the scheme default.
        """
        host = self.url.hostname or ''
        if ':' in host:
            host = f'[{host}]'
        default_port = 443 if self.url.scheme == 'wss' else 80
        if self.url.port and self.url.port != default_port:
            host = f'{host}:{self.url.port}'
        return host

    async def _handshake(self) -> None:
        """Perform the HTTP Upgrade handshake on the freshly opened stream.

        Raises:
            ProtocolError: if the server does not answer with status 101.
        """
        key = base64.b64encode(os.urandom(16)).decode('utf-8')

        path = self.url.path or '/'
        if self.url.query:
            path += '?' + self.url.query

        handshake = (
            f"GET {path} HTTP/1.1\r\n"
            f"Host: {self._host_header()}\r\n"
            f"Upgrade: websocket\r\n"
            f"Connection: Upgrade\r\n"
            f"Sec-WebSocket-Key: {key}\r\n"
            f"Sec-WebSocket-Version: 13\r\n"
            f"User-Agent: Python-WebSocket-Client\r\n"
            f"\r\n"
        )
        self.writer.write(handshake.encode())
        await self.writer.drain()

        response = await self.reader.readuntil(b'\r\n\r\n')

        # latin-1 never fails to decode, so header parsing cannot crash here.
        status_line, _, header_block = response.decode('latin-1').partition('\r\n')
        status_parts = status_line.split(' ', 2)
        if len(status_parts) < 2 or status_parts[1] != '101':
            raise ProtocolError(
                f"Invalid handshake response: {response.decode('utf-8', errors='replace')}"
            )

        # HTTP header names are case-insensitive.
        headers: Dict[str, str] = {}
        for line in header_block.split('\r\n'):
            name, sep, value = line.partition(':')
            if sep:
                headers[name.strip().lower()] = value.strip()

        expected_accept = base64.b64encode(
            hashlib.sha1((key + self._WS_GUID).encode()).digest()
        ).decode()
        if headers.get('sec-websocket-accept') != expected_accept:
            # RFC 6455 §4.1 says the client MUST fail here. This client only
            # warns, to stay lenient with non-compliant servers.
            self.logger.warning("Sec-WebSocket-Accept header mismatch")

    async def _receive_loop(self) -> None:
        """Read messages until the connection ends and dispatch each by opcode.

        On exit (error, EOF or server close) this triggers
        :meth:`_handle_disconnect`, unless :meth:`close` is in progress.
        """
        try:
            while self.state == ConnectionState.CONNECTED and not self._closing:
                try:
                    message = await self._receive_message()
                    opcode = message['opcode']

                    if opcode == self.OPCODE_TEXT:
                        await self._handle_text_message(message['data'].decode('utf-8'))
                    elif opcode == self.OPCODE_BINARY:
                        await self._handle_binary_message(message['data'])
                    elif opcode == self.OPCODE_CLOSE:
                        await self._handle_close_frame(message['data'])
                        break
                    elif opcode == self.OPCODE_PING:
                        await self._send_frame(self.OPCODE_PONG, message['data'])
                    elif opcode == self.OPCODE_PONG:
                        # The connection is alive; wake the ping loop waiting for this.
                        self._pong_received.set()

                except asyncio.TimeoutError:
                    self.logger.warning("Receive timeout")
                except Exception as e:
                    self.logger.error(f"Error in receive loop: {e}")
                    break

        except Exception as e:
            self.logger.error(f"Fatal error in receive loop: {e}")
        finally:
            if not self._closing:
                await self._handle_disconnect()

    @staticmethod
    def _apply_mask(data: bytes, mask_key: bytes) -> bytes:
        """XOR *data* with the repeating 4-byte *mask_key* (RFC 6455 §5.3).

        The operation is its own inverse. It is done as a single big-integer
        XOR instead of a per-byte Python loop.
        """
        length = len(data)
        if length == 0:
            return b''
        repeated = (mask_key * (length // 4 + 1))[:length]
        masked = int.from_bytes(data, 'big') ^ int.from_bytes(repeated, 'big')
        return masked.to_bytes(length, 'big')

    async def _read_frame(self) -> Dict[str, Any]:
        """Read one raw frame and return ``{'fin', 'opcode', 'data'}``."""
        header = await self.reader.readexactly(2)
        fin = (header[0] & 0x80) != 0
        opcode = header[0] & 0x0f
        masked = (header[1] & 0x80) != 0
        payload_length = header[1] & 0x7f

        # 126 / 127 signal a 16-bit / 64-bit extended length.
        if payload_length == 126:
            payload_length = int.from_bytes(await self.reader.readexactly(2), 'big')
        elif payload_length == 127:
            payload_length = int.from_bytes(await self.reader.readexactly(8), 'big')

        # Servers shouldn't send masked frames, but accept them if they do.
        mask_key = await self.reader.readexactly(4) if masked else None

        payload = await self.reader.readexactly(payload_length)
        if mask_key is not None:
            payload = self._apply_mask(payload, mask_key)

        return {'fin': fin, 'opcode': opcode, 'data': payload}

    async def _receive_message(self) -> Dict[str, Any]:
        """Receive one complete message, reassembling fragmented data frames.

        Control frames (opcode >= 0x8) are returned as soon as they arrive,
        even in the middle of a fragmented message. The partial message is
        kept on the instance and completed on a later call.

        Raises:
            ProtocolError: on an out-of-sequence continuation frame, or a new
                data frame before the previous fragmented message finished.
        """
        while True:
            frame = await self._read_frame()
            opcode = frame['opcode']

            if opcode & 0x8:
                return frame

            if opcode == self.OPCODE_CONTINUATION:
                if self._fragment_opcode is None:
                    raise ProtocolError("Unexpected continuation frame")
                self._fragment_buffer.extend(frame['data'])
                if not frame['fin']:
                    continue
                message = {
                    'fin': True,
                    'opcode': self._fragment_opcode,
                    'data': bytes(self._fragment_buffer),
                }
                self._fragment_opcode = None
                self._fragment_buffer = bytearray()
                return message

            if self._fragment_opcode is not None:
                raise ProtocolError("New data frame received before fragmented message completed")
            if frame['fin']:
                return frame
            # First fragment of a multi-frame message: start buffering.
            self._fragment_opcode = opcode
            self._fragment_buffer = bytearray(frame['data'])

    async def _send_frame(self, opcode: int, data: bytes) -> None:
        """Send a single, final, client-masked frame.

        Raises:
            ConnectionError: if not connected. A close frame is still allowed
                while CLOSING, so :meth:`close` can send its close frame.
        """
        closing_handshake = (
            opcode == self.OPCODE_CLOSE and self.state == ConnectionState.CLOSING
        )
        if (self.state != ConnectionState.CONNECTED and not closing_handshake) or self.writer is None:
            raise ConnectionError("Not connected")

        # FIN = 1, RSV = 0, opcode
        frame = bytearray([0x80 | opcode])

        # Mask bit is always set (clients must mask), followed by the length.
        length = len(data)
        if length <= 125:
            frame.append(0x80 | length)
        elif length <= 0xFFFF:
            frame.append(0x80 | 126)
            frame.extend(length.to_bytes(2, 'big'))
        else:
            frame.append(0x80 | 127)
            frame.extend(length.to_bytes(8, 'big'))

        mask_key = os.urandom(4)
        frame.extend(mask_key)
        frame.extend(self._apply_mask(data, mask_key))

        self.writer.write(frame)
        await self.writer.drain()

    async def _handle_text_message(self, message: str) -> None:
        """Resolve a pending request, or emit ``message`` for anything else.

        A JSON object whose ``callId`` matches a pending request resolves that
        request with its ``data`` field, or with the whole object if there is
        none. Any other valid JSON value is emitted as a ``message`` event.
        Invalid JSON is logged and emitted as the raw string.
        """
        try:
            data = json.loads(message)
        except json.JSONDecodeError:
            self.logger.error(f"Invalid JSON message: {message}")
            await self._emit_event('message', message)
            return

        # Only JSON objects can carry a callId; lists and scalars are plain events.
        if isinstance(data, dict):
            call_id = data.get('callId')
            if isinstance(call_id, str) and call_id in self._response_futures:
                future = self._response_futures.pop(call_id)
                if not future.done():
                    future.set_result(data.get('data', data))
                return

        await self._emit_event('message', data)

    async def _handle_binary_message(self, data: bytes) -> None:
        """Handle incoming binary message"""
        await self._emit_event('binary', data)

    async def _handle_close_frame(self, data: bytes) -> None:
        """Handle a close frame from the server.

        If the server started the closing handshake, echo its status code back
        as the reply. If :meth:`close` started it, our close frame was already
        sent, so nothing is echoed. In both cases the state becomes CLOSING.
        """
        code = int.from_bytes(data[:2], 'big') if len(data) >= 2 else 1000
        reason = data[2:].decode('utf-8', errors='ignore') if len(data) > 2 else ''

        self.logger.info(f"Received close frame: {code} {reason}")

        if not self._closing:
            await self._send_frame(self.OPCODE_CLOSE, data[:2])
        self.state = ConnectionState.CLOSING

    def _abort_writer(self, writer: Optional[asyncio.StreamWriter]) -> None:
        """Abort *writer*'s transport immediately, ignoring errors (best effort)."""
        if writer is None:
            return
        try:
            writer.transport.abort()
        except Exception as e:
            self.logger.debug(f"Error aborting transport: {e}")

    def _fail_pending(self, reason: str) -> None:
        """Fail every pending request future with ``ConnectionError(reason)``."""
        pending = list(self._response_futures.values())
        self._response_futures.clear()
        for future in pending:
            if not future.done():
                future.set_exception(ConnectionError(reason))

    async def _handle_disconnect(self) -> None:
        """Clean up after an unexpected disconnection and schedule reconnection."""
        self.state = ConnectionState.DISCONNECTED
        self.logger.warning("WebSocket disconnected")

        # Stop this connection's ping loop, so a reconnect does not end up
        # with two of them, and release the dead socket.
        current = asyncio.current_task()
        if self._ping_task is not None and self._ping_task is not current:
            self._ping_task.cancel()
        self._abort_writer(self.writer)

        self._fail_pending("Disconnected")

        await self._emit_event('disconnect', None)

        if not self._closing:
            # Keep a reference: the event loop holds only weak references to tasks.
            self._reconnect_task = asyncio.create_task(self._reconnect())

    async def _reconnect(self) -> None:
        """Attempt to reconnect with exponential backoff (delay capped at 60 s)."""
        while self._reconnect_attempts < self.max_reconnect_attempts and not self._closing:
            self._reconnect_attempts += 1
            # Capture the attempt number now: a successful connect resets the counter.
            attempt = self._reconnect_attempts
            wait_time = min(self.reconnect_interval * (2 ** (attempt - 1)), 60)

            self.logger.info(f"Reconnecting in {wait_time}s (attempt {attempt}/{self.max_reconnect_attempts})")
            await asyncio.sleep(wait_time)
            if self._closing:
                return

            try:
                await self._establish()
            except Exception as e:
                self.logger.error(f"Reconnection failed: {e}")
                continue

            await self._emit_event('reconnect', attempt)
            return

        if not self._closing:
            self.logger.error("Max reconnection attempts reached")
            await self._emit_event('reconnect_failed', self._reconnect_attempts)

    async def _ping_loop(self) -> None:
        """Send periodic pings and drop the connection if no pong arrives in time.

        Aborting the transport makes the receive loop hit EOF, which runs the
        normal disconnect/reconnect path.
        """
        try:
            while self.state == ConnectionState.CONNECTED and not self._closing:
                await asyncio.sleep(self.ping_interval)
                if self.state != ConnectionState.CONNECTED or self._closing:
                    break

                self._pong_received.clear()
                try:
                    await self._send_frame(self.OPCODE_PING, b'')
                except Exception as e:
                    self.logger.error(f"Ping failed: {e}")
                    break

                try:
                    await asyncio.wait_for(self._pong_received.wait(), timeout=self.ping_timeout)
                except asyncio.TimeoutError:
                    if self.state == ConnectionState.CONNECTED and not self._closing:
                        self.logger.warning(
                            f"No pong received within {self.ping_timeout}s; dropping connection"
                        )
                        self._abort_writer(self.writer)
                    break
        except asyncio.CancelledError:
            pass

    async def _process_queued_messages(self) -> None:
        """Send messages that were queued while disconnected (failures are logged)."""
        while True:
            try:
                message = self._message_queue.get_nowait()
            except asyncio.QueueEmpty:
                return
            try:
                await self.send_json(message)
            except Exception as e:
                self.logger.error(f"Failed to send queued message: {e}")

    async def send(self, message: Union[str, bytes]) -> None:
        """Send a message (text or binary)"""
        if isinstance(message, str):
            await self._send_frame(self.OPCODE_TEXT, message.encode('utf-8'))
        else:
            await self._send_frame(self.OPCODE_BINARY, message)

    async def send_json(self, data: Dict[str, Any]) -> None:
        """Send JSON-encoded message"""
        message = json.dumps(data)
        await self.send(message)

    async def request(self, method: str, *args, **kwargs) -> Any:
        """
        Send a request and wait for response.

        If the client is not connected, the request is queued and sent after
        the next successful (re)connect.

        Args:
            method: The method name to call
            *args: Positional arguments for the method (sent as ``args``)
            **kwargs: Only ``no_wait`` is recognised. When it is true, the
                request is sent without a ``callId`` and ``True`` is returned
                immediately. Other keyword arguments are not part of the wire
                format; they are ignored and a warning is logged.

        Returns:
            The response's ``data`` field (or the whole response if it has
            none), or ``True`` for ``no_wait`` requests.

        Raises:
            ConnectionError: if the offline queue is full, or the connection
                drops or is closed before a reply arrives.
            TimeoutError: if no reply arrives within ``request_timeout`` seconds.
        """
        no_wait = kwargs.pop('no_wait', False)
        if kwargs:
            self.logger.warning(
                f"request({method!r}): keyword arguments {sorted(kwargs)} are not transmitted and were ignored"
            )

        call_id = None if no_wait else str(uuid.uuid4())

        request: Dict[str, Any] = {
            "method": method,
            "args": args
        }

        future: Optional[asyncio.Future] = None
        if call_id is not None:
            request["callId"] = call_id
            future = asyncio.get_running_loop().create_future()
            self._response_futures[call_id] = future

        try:
            if self.state == ConnectionState.CONNECTED:
                await self.send_json(request)
            else:
                try:
                    self._message_queue.put_nowait(request)
                except asyncio.QueueFull:
                    raise ConnectionError("Message queue full") from None

            if future is None:
                return True
            return await asyncio.wait_for(future, timeout=self.request_timeout)

        except asyncio.TimeoutError:
            raise TimeoutError(f"Request timeout for {method}") from None
        finally:
            # Never leave a stale future behind, whatever the outcome.
            if call_id is not None:
                self._response_futures.pop(call_id, None)

    def on(self, event: str, handler: Callable) -> None:
        """Register an event handler (sync function, or function returning an awaitable)."""
        self._event_handlers[event].append(handler)

    def off(self, event: str, handler: Callable) -> None:
        """Unregister an event handler; unknown handlers are ignored."""
        handlers = self._event_handlers.get(event)
        if handlers and handler in handlers:
            handlers.remove(handler)

    async def _emit_event(self, event: str, data: Any) -> None:
        """Call every handler for *event*. Handler errors are logged, not raised."""
        # Iterate over a snapshot so handlers may call on()/off() safely.
        for handler in list(self._event_handlers.get(event, ())):
            try:
                result = handler(data)
                if inspect.isawaitable(result):
                    await result
            except Exception:
                self.logger.exception(f"Error in event handler for {event}")

    async def close(self) -> None:
        """Close the WebSocket connection gracefully.

        Steps:

        1. Stop any pending reconnection.
        2. Send a close frame (code 1000) if still connected.
        3. Wait up to 5 s for the receive loop to finish.
        4. Cancel the background tasks and close the socket.
        5. Fail any still-pending requests with ``ConnectionError``.
        """
        self._closing = True
        current = asyncio.current_task()
        if self._reconnect_task is not None and self._reconnect_task is not current:
            self._reconnect_task.cancel()

        if self.state in (ConnectionState.DISCONNECTED, ConnectionState.CONNECTING,
                          ConnectionState.CLOSED):
            # Nothing open to shut down. An in-flight connection attempt sees
            # ``_closing`` and abandons itself.
            return

        send_close = self.state == ConnectionState.CONNECTED
        self.state = ConnectionState.CLOSING

        try:
            if send_close:
                close_data = (1000).to_bytes(2, 'big') + b'Normal closure'
                await self._send_frame(self.OPCODE_CLOSE, close_data)

            # Wait for the server's close reply (handled by the receive loop).
            receive_task = self._receive_task
            if receive_task is not None and receive_task is not current and not receive_task.done():
                await asyncio.wait_for(asyncio.shield(receive_task), timeout=5.0)
        except Exception as e:
            # Best effort: the peer may already be gone.
            self.logger.debug(f"Closing handshake incomplete: {e}")

        for task in (self._receive_task, self._ping_task):
            if task is not None and task is not current:
                task.cancel()

        if self.writer is not None:
            self.writer.close()
            try:
                await self.writer.wait_closed()
            except Exception as e:
                self.logger.debug(f"Error while waiting for socket close: {e}")

        self._fail_pending("Connection closed")

        self.state = ConnectionState.CLOSED
        self.logger.info("WebSocket connection closed")

    @asynccontextmanager
    async def session(self):
        """Context manager for WebSocket connection"""
        try:
            await self.connect()
            yield self
        finally:
            await self.close()

    # Convenience methods for common operations
    async def login(self, username: str, password: str) -> Any:
        """Login to the WebSocket service"""
        return await self.request("login", username, password)

    async def get_channels(self) -> Any:
        """Get available channels"""
        return await self.request("get_channels")

    async def get_messages(self, channel_uid: str) -> Any:
        """Get messages from a channel"""
        return await self.request("get_messages", channel_uid)

    async def send_message(self, channel_uid: str, message: str, **kwargs) -> Any:
        """Send a message to a channel"""
        return await self.request("send_message", channel_uid, message, **kwargs)


# Example usage
async def main():
    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Create WebSocket client
    url = "wss://molodetz.online/rpc.ws"
    client = WebSocketClient(
        url,
        ping_interval=30.0,
        reconnect_interval=5.0,
        max_reconnect_attempts=10
    )

    # Register event handlers
    def on_message(data):
        print(f"Received message: {data}")

    def on_disconnect(data):
        print("Disconnected from server")

    def on_reconnect(attempts):
        print(f"Reconnected after {attempts} attempts")

    client.on('message', on_message)
    client.on('disconnect', on_disconnect)
    client.on('reconnect', on_reconnect)

    # Use context manager for automatic cleanup
    async with client.session():
        # Login
        result = await client.login("username", "password")
        print(f"Login result: {result}")

        # Get channels
        channels = await client.get_channels()
        print(f"Available channels: {len(channels)}")

        # Send a message without waiting for response
        if channels:
            await client.send_message(
                channels[0]['uid'],
                "Hello from Python!",
                no_wait=True
            )

        # Keep connection alive
        await asyncio.sleep(60)


# Load testing example
async def load_test():
    """Example of running multiple WebSocket clients concurrently"""

    async def run_client(client_id: int):
        url = "wss://molodetz.online/rpc.ws"
        client = WebSocketClient(url, logger=logging.getLogger(f"client-{client_id}"))

        try:
            async with client.session():
                await client.login(f"user{client_id}", "password")

                # Simulate activity
                for _ in range(10):
                    await client.get_channels()
                    await asyncio.sleep(1)

        except Exception as e:
            print(f"Client {client_id} error: {e}")

    # Run 10 clients concurrently
    tasks = [run_client(i) for i in range(10)]
    await asyncio.gather(*tasks, return_exceptions=True)


if __name__ == "__main__":
    asyncio.run(main())