# Listing metadata from the repository web view: 593 lines, 20 KiB, Python; history entry dated 2025-07-20 20:18:24 +02:00.
import asyncio
import json
import base64
import hashlib
import ssl
import logging
import time
import os
from typing import Optional, Dict, Any, Callable, Union
from urllib.parse import urlparse
from enum import Enum
from collections import defaultdict
from contextlib import asynccontextmanager
import uuid
class ConnectionState(Enum):
    """Lifecycle phases of a WebSocket connection.

    Each member's value is the lowercase string form of its name.
    """

    # No transport is open.
    DISCONNECTED = "disconnected"
    # TCP connect / opening handshake is in progress.
    CONNECTING = "connecting"
    # Handshake finished; frames may be exchanged.
    CONNECTED = "connected"
    # A close has been started but is not finished yet.
    CLOSING = "closing"
    # The connection was shut down.
    CLOSED = "closed"
class WebSocketException(Exception):
    """Common base class for the WebSocket-specific exceptions in this module."""
# NOTE: within this module the name below shadows the built-in ConnectionError.
class ConnectionError(WebSocketException):
    """Raised when a connection cannot be established or is not available."""
class ProtocolError(WebSocketException):
    """Raised when the peer violates the WebSocket protocol."""
class WebSocketClient:
    """
    Asyncio WebSocket client with automatic reconnection, JSON request/response
    correlation, event callbacks and buffering of requests made while offline.

    Events passed to handlers registered with :meth:`on`:

    * ``'message'``          - decoded JSON value (or the raw text if it is not JSON)
    * ``'binary'``           - payload bytes of a binary message
    * ``'disconnect'``       - ``None``, after the connection was lost
    * ``'reconnect'``        - number of the attempt that succeeded
    * ``'reconnect_failed'`` - number of attempts made before giving up

    Note: ``ping_timeout`` is stored on the instance but is not enforced by
    this class; pong frames are accepted without being tracked.
    """

    # WebSocket opcodes
    OPCODE_CONTINUATION = 0x0
    OPCODE_TEXT = 0x1
    OPCODE_BINARY = 0x2
    OPCODE_CLOSE = 0x8
    OPCODE_PING = 0x9
    OPCODE_PONG = 0xa

    # Fixed GUID that is appended to the client key when computing the
    # expected Sec-WebSocket-Accept value.
    _HANDSHAKE_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

    def __init__(
        self,
        url: str,
        ssl_context: Optional[ssl.SSLContext] = None,
        ping_interval: float = 30.0,
        ping_timeout: float = 10.0,
        reconnect_interval: float = 5.0,
        max_reconnect_attempts: int = 10,
        logger: Optional[logging.Logger] = None
    ):
        """
        Args:
            url: ``ws://`` or ``wss://`` endpoint.
            ssl_context: TLS context used for ``wss`` URLs; a default context
                is created when omitted.
            ping_interval: Seconds between keep-alive ping frames.
            ping_timeout: Stored only; not used by this class.
            reconnect_interval: Base delay (seconds) of the reconnect backoff.
            max_reconnect_attempts: Attempts made before ``'reconnect_failed'``.
            logger: Logger to use; defaults to this module's logger.
        """
        self.url = urlparse(url)
        self.ssl_context = ssl_context or ssl.create_default_context()
        self.ping_interval = ping_interval
        self.ping_timeout = ping_timeout
        self.reconnect_interval = reconnect_interval
        self.max_reconnect_attempts = max_reconnect_attempts
        self.logger = logger or logging.getLogger(__name__)
        # Connection state
        self.state = ConnectionState.DISCONNECTED
        self.reader: Optional[asyncio.StreamReader] = None
        self.writer: Optional[asyncio.StreamWriter] = None
        # Message handling
        self._response_futures: Dict[str, asyncio.Future] = {}
        self._event_handlers: Dict[str, list] = defaultdict(list)
        self._receive_task: Optional[asyncio.Task] = None
        self._ping_task: Optional[asyncio.Task] = None
        # Reconnection
        self._reconnect_attempts = 0
        self._reconnect_lock = asyncio.Lock()
        self._reconnect_task: Optional[asyncio.Task] = None
        self._closing = False
        # Message queue for buffering during reconnection
        self._message_queue = asyncio.Queue(maxsize=1000)

    # ------------------------------------------------------------------
    # Small internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _apply_mask(data: bytes, mask_key: bytes) -> bytes:
        """XOR ``data`` with the 4-byte ``mask_key`` repeated over its length.

        Masking and unmasking are the same operation. The XOR is done on two
        big integers instead of byte-by-byte in a Python loop.
        """
        size = len(data)
        if size == 0:
            return b''
        repeated_key = (mask_key * (size // 4 + 1))[:size]
        mixed = int.from_bytes(data, 'big') ^ int.from_bytes(repeated_key, 'big')
        return mixed.to_bytes(size, 'big')

    @staticmethod
    def _cancel_task(task: Optional[asyncio.Task]) -> None:
        """Cancel ``task`` unless it is missing, finished, or the caller itself."""
        if task is None or task.done():
            return
        if task is asyncio.current_task():
            return
        task.cancel()

    def _drop_transport(self) -> None:
        """Forget the current streams and close the underlying transport."""
        writer = self.writer
        self.reader = None
        self.writer = None
        if writer is None:
            return
        try:
            writer.close()
        except Exception as e:
            self.logger.debug(f"Error while closing transport: {e}")

    def _abort_connection(self) -> None:
        """Roll back a failed or cancelled connection attempt."""
        self.state = ConnectionState.DISCONNECTED
        self._cancel_task(self._receive_task)
        self._cancel_task(self._ping_task)
        self._drop_transport()

    # ------------------------------------------------------------------
    # Connection establishment
    # ------------------------------------------------------------------
    async def connect(self) -> None:
        """Open the TCP/TLS connection, perform the handshake and start the
        receive and ping tasks. Returns immediately when already connected.

        Raises:
            ConnectionError: If connecting or the handshake fails.
        """
        async with self._reconnect_lock:
            if self.state == ConnectionState.CONNECTED:
                return
            # A previous close() leaves the flag set; clear it so the
            # background loops of the new connection actually run.
            self._closing = False
            self.state = ConnectionState.CONNECTING
            self.logger.info(f"Connecting to {self.url.geturl()}")
            try:
                # Establish TCP connection
                use_tls = self.url.scheme == 'wss'
                port = self.url.port or (443 if use_tls else 80)
                self.reader, self.writer = await asyncio.open_connection(
                    self.url.hostname,
                    port,
                    ssl=self.ssl_context if use_tls else None
                )
                # Perform WebSocket handshake
                await self._handshake()
                # A ping loop left over from an earlier connection may still
                # be sleeping; stop it so ping loops do not pile up.
                self._cancel_task(self._ping_task)
                self.state = ConnectionState.CONNECTED
                self._reconnect_attempts = 0
                # Start background tasks
                self._receive_task = asyncio.create_task(self._receive_loop())
                self._ping_task = asyncio.create_task(self._ping_loop())
                self.logger.info("WebSocket connection established")
                # Process any queued messages
                await self._process_queued_messages()
            except asyncio.CancelledError:
                # Do not leak the socket when the attempt is cancelled.
                self._abort_connection()
                raise
            except Exception as e:
                self._abort_connection()
                self.logger.error(f"Connection failed: {e}")
                raise ConnectionError(f"Failed to connect: {e}") from e

    async def _handshake(self) -> None:
        """Send the HTTP upgrade request and validate the server's answer.

        Raises:
            ProtocolError: If the status code is not 101 or the
                ``Sec-WebSocket-Accept`` header is missing or wrong.
        """
        # Generate WebSocket key
        key = base64.b64encode(os.urandom(16)).decode('ascii')
        # Build request target
        path = self.url.path or '/'
        if self.url.query:
            path += '?' + self.url.query
        # Build Host header: include the port when it is not the scheme default
        host = self.url.hostname or ''
        if ':' in host:
            # urlparse strips the brackets of IPv6 literals
            host = f"[{host}]"
        default_port = 443 if self.url.scheme == 'wss' else 80
        if self.url.port and self.url.port != default_port:
            host = f"{host}:{self.url.port}"
        handshake = (
            f"GET {path} HTTP/1.1\r\n"
            f"Host: {host}\r\n"
            f"Upgrade: websocket\r\n"
            f"Connection: Upgrade\r\n"
            f"Sec-WebSocket-Key: {key}\r\n"
            f"Sec-WebSocket-Version: 13\r\n"
            f"User-Agent: Python-WebSocket-Client\r\n"
            f"\r\n"
        )
        self.writer.write(handshake.encode())
        await self.writer.drain()
        # Read response
        response = await self.reader.readuntil(b'\r\n\r\n')
        status_line, _, header_block = response.partition(b'\r\n')
        # Validate the status code only; the reason phrase is free text
        status_parts = status_line.split(None, 2)
        if (
            len(status_parts) < 2
            or not status_parts[0].startswith(b'HTTP/')
            or status_parts[1] != b'101'
        ):
            raise ProtocolError(
                f"Invalid handshake response: {response.decode(errors='replace')}"
            )
        # Header names are case-insensitive, so normalise before lookup
        headers = {}
        for raw_line in header_block.split(b'\r\n'):
            name, separator, value = raw_line.partition(b':')
            if separator:
                headers[name.strip().lower()] = value.strip()
        # Verify Sec-WebSocket-Accept
        expected_accept = base64.b64encode(
            hashlib.sha1((key + self._HANDSHAKE_GUID).encode()).digest()
        )
        if headers.get(b'sec-websocket-accept') != expected_accept:
            # RFC 6455 section 4.1: the client must fail the connection here
            raise ProtocolError("Sec-WebSocket-Accept header mismatch")

    # ------------------------------------------------------------------
    # Receiving
    # ------------------------------------------------------------------
    async def _receive_loop(self) -> None:
        """Read frames until the connection ends, reassembling fragmented
        messages and dispatching complete ones to the handlers."""
        # Opcode and payload pieces of a message whose final frame has not
        # arrived yet (None / empty when no message is in progress).
        pending_opcode = None
        pending_parts = []
        try:
            while self.state == ConnectionState.CONNECTED and not self._closing:
                try:
                    frame = await self._receive_message()
                    opcode = frame['opcode']
                    if opcode == self.OPCODE_CLOSE:
                        await self._handle_close_frame(frame['data'])
                        break
                    if opcode == self.OPCODE_PING:
                        await self._send_frame(self.OPCODE_PONG, frame['data'])
                        continue
                    if opcode == self.OPCODE_PONG:
                        continue  # Pong received, connection is alive
                    if opcode == self.OPCODE_CONTINUATION:
                        if pending_opcode is None:
                            raise ProtocolError("Continuation frame without a message in progress")
                        pending_parts.append(frame['data'])
                    elif opcode in (self.OPCODE_TEXT, self.OPCODE_BINARY):
                        if pending_opcode is not None:
                            raise ProtocolError("New data frame before the previous message finished")
                        pending_opcode = opcode
                        pending_parts = [frame['data']]
                    else:
                        continue  # Unknown opcode: ignored
                    if not frame['fin']:
                        continue  # More fragments follow
                    payload = b''.join(pending_parts)
                    message_opcode = pending_opcode
                    pending_opcode = None
                    pending_parts = []
                    if message_opcode == self.OPCODE_TEXT:
                        await self._handle_text_message(payload.decode('utf-8'))
                    else:
                        await self._handle_binary_message(payload)
                except asyncio.TimeoutError:
                    self.logger.warning("Receive timeout")
                except Exception as e:
                    self.logger.error(f"Error in receive loop: {e}")
                    break
        except Exception as e:
            self.logger.error(f"Fatal error in receive loop: {e}")
        finally:
            if not self._closing:
                await self._handle_disconnect()

    async def _receive_message(self) -> Dict[str, Any]:
        """Read one frame from the stream.

        Returns:
            Dict with ``'fin'`` (bool), ``'opcode'`` (int) and ``'data'``
            (unmasked payload bytes).
        """
        # Read frame header
        header = await self.reader.readexactly(2)
        fin = (header[0] & 0x80) != 0
        opcode = header[0] & 0x0f
        masked = (header[1] & 0x80) != 0
        payload_length = header[1] & 0x7f
        # Read extended payload length if needed
        if payload_length == 126:
            payload_length = int.from_bytes(await self.reader.readexactly(2), 'big')
        elif payload_length == 127:
            payload_length = int.from_bytes(await self.reader.readexactly(8), 'big')
        # Read masking key if present (server shouldn't send masked frames)
        mask_key = await self.reader.readexactly(4) if masked else None
        # Read payload
        payload = await self.reader.readexactly(payload_length)
        if masked:
            payload = self._apply_mask(payload, mask_key)
        return {
            'fin': fin,
            'opcode': opcode,
            'data': payload
        }

    # ------------------------------------------------------------------
    # Sending
    # ------------------------------------------------------------------
    async def _send_frame(self, opcode: int, data: bytes) -> None:
        """Send a single unfragmented frame.

        Raises:
            ConnectionError: If the client is not in the CONNECTED state.
        """
        if self.state != ConnectionState.CONNECTED:
            raise ConnectionError("Not connected")
        await self._write_frame(opcode, data)

    async def _write_frame(self, opcode: int, data: bytes) -> None:
        """Encode and write a masked frame without checking the state.

        Used directly for close frames, which have to go out while the
        client is already in the CLOSING state.

        Raises:
            ConnectionError: If there is no open transport.
        """
        writer = self.writer
        if writer is None:
            raise ConnectionError("Not connected")
        # FIN = 1, RSV = 0, Opcode
        frame = bytearray([0x80 | opcode])
        # Mask bit = 1 (client must mask), payload length
        length = len(data)
        if length <= 125:
            frame.append(0x80 | length)
        elif length <= 65535:
            frame.append(0x80 | 126)
            frame.extend(length.to_bytes(2, 'big'))
        else:
            frame.append(0x80 | 127)
            frame.extend(length.to_bytes(8, 'big'))
        # Masking key followed by the masked payload
        mask_key = os.urandom(4)
        frame.extend(mask_key)
        frame.extend(self._apply_mask(data, mask_key))
        writer.write(frame)
        await writer.drain()

    # ------------------------------------------------------------------
    # Incoming message handling
    # ------------------------------------------------------------------
    async def _handle_text_message(self, message: str) -> None:
        """Decode a text message as JSON and route it.

        A JSON object whose ``callId`` matches a pending request resolves
        that request with its ``data`` member (or the whole object when
        ``data`` is absent). Everything else is emitted as ``'message'``.
        """
        try:
            data = json.loads(message)
        except json.JSONDecodeError:
            self.logger.error(f"Invalid JSON message: {message}")
            await self._emit_event('message', message)
            return
        # Only JSON objects can be responses; scalars and arrays are events.
        call_id = data.get('callId') if isinstance(data, dict) else None
        if isinstance(call_id, str) and call_id in self._response_futures:
            future = self._response_futures.pop(call_id)
            if not future.done():
                future.set_result(data.get('data', data))
        else:
            # Emit as event
            await self._emit_event('message', data)

    async def _handle_binary_message(self, data: bytes) -> None:
        """Emit an incoming binary message as a ``'binary'`` event."""
        await self._emit_event('binary', data)

    async def _handle_close_frame(self, data: bytes) -> None:
        """Log a received close frame and, when the server started the
        close, answer with a close frame carrying the same status code."""
        code = int.from_bytes(data[:2], 'big') if len(data) >= 2 else 1000
        reason = data[2:].decode('utf-8', errors='ignore') if len(data) > 2 else ''
        self.logger.info(f"Received close frame: {code} {reason}")
        if self._closing:
            # This is the reply to the close frame sent by close().
            return
        self.state = ConnectionState.CLOSING
        # Send close frame back
        await self._write_frame(self.OPCODE_CLOSE, data[:2])

    async def _handle_disconnect(self) -> None:
        """Clean up after the connection ended without close() being called
        and schedule reconnection."""
        self.state = ConnectionState.DISCONNECTED
        self.logger.warning("WebSocket disconnected")
        # Fail pending requests
        for future in self._response_futures.values():
            if not future.done():
                future.set_exception(ConnectionError("Disconnected"))
        self._response_futures.clear()
        # Release the resources of the dead connection
        self._cancel_task(self._ping_task)
        self._drop_transport()
        # Emit disconnect event
        await self._emit_event('disconnect', None)
        # Attempt reconnection if not closing; keep a reference so the task
        # cannot be garbage-collected and so close() is able to cancel it.
        if not self._closing:
            self._reconnect_task = asyncio.create_task(self._reconnect())

    async def _reconnect(self) -> None:
        """Retry connect() with exponential backoff (capped at 60 seconds)."""
        while self._reconnect_attempts < self.max_reconnect_attempts and not self._closing:
            self._reconnect_attempts += 1
            # connect() resets the counter on success, so remember it here.
            attempt = self._reconnect_attempts
            wait_time = min(self.reconnect_interval * (2 ** (attempt - 1)), 60)
            self.logger.info(f"Reconnecting in {wait_time}s (attempt {attempt}/{self.max_reconnect_attempts})")
            await asyncio.sleep(wait_time)
            if self._closing:
                return
            try:
                await self.connect()
            except Exception as e:
                self.logger.error(f"Reconnection failed: {e}")
                continue
            await self._emit_event('reconnect', attempt)
            return
        if self._closing:
            # Stopped by close(), not by running out of attempts.
            return
        self.logger.error("Max reconnection attempts reached")
        await self._emit_event('reconnect_failed', self._reconnect_attempts)

    async def _ping_loop(self) -> None:
        """Send periodic ping frames to keep connection alive"""
        try:
            while self.state == ConnectionState.CONNECTED and not self._closing:
                await asyncio.sleep(self.ping_interval)
                try:
                    await self._send_frame(self.OPCODE_PING, b'')
                except Exception as e:
                    self.logger.error(f"Ping failed: {e}")
                    break
        except asyncio.CancelledError:
            pass

    async def _process_queued_messages(self) -> None:
        """Send the requests that were queued while disconnected."""
        while not self._message_queue.empty():
            try:
                message = self._message_queue.get_nowait()
                await self.send_json(message)
            except Exception as e:
                self.logger.error(f"Failed to send queued message: {e}")

    # ------------------------------------------------------------------
    # Public sending API
    # ------------------------------------------------------------------
    async def send(self, message: Union[str, bytes]) -> None:
        """Send a message: ``str`` as a text frame, anything else as binary."""
        if isinstance(message, str):
            await self._send_frame(self.OPCODE_TEXT, message.encode('utf-8'))
        else:
            await self._send_frame(self.OPCODE_BINARY, message)

    async def send_json(self, data: Dict[str, Any]) -> None:
        """Send JSON-encoded message"""
        await self.send(json.dumps(data))

    async def request(self, method: str, *args, **kwargs) -> Any:
        """
        Send a request and wait for response.

        The request is sent as ``{"method": ..., "args": [...], "callId": ...}``.
        When the client is not connected the request is queued and sent after
        the next successful connect.

        Args:
            method: The method name to call
            *args: Positional arguments for the method
            **kwargs: Client-side options. ``no_wait=True`` sends without a
                ``callId`` and returns ``True`` immediately; ``timeout``
                (seconds, default 30.0) bounds the wait for the response.
                Other keyword arguments are not transmitted.
        Returns:
            The response data, or ``True`` when ``no_wait`` is set.
        Raises:
            TimeoutError: If no response arrives within ``timeout``.
            ConnectionError: If the queue is full or the connection is lost
                while waiting.
        """
        no_wait = kwargs.pop('no_wait', False)
        timeout = kwargs.pop('timeout', 30.0)
        # Generate call ID
        call_id = None if no_wait else str(uuid.uuid4())
        # Prepare request
        request = {
            "method": method,
            "args": args
        }
        future = None
        if call_id:
            request["callId"] = call_id
            # Create future for response
            future = asyncio.get_running_loop().create_future()
            self._response_futures[call_id] = future
        try:
            if self.state == ConnectionState.CONNECTED:
                await self.send_json(request)
            elif not self._message_queue.full():
                # Queue message if not connected
                await self._message_queue.put(request)
            else:
                raise ConnectionError("Message queue full")
            # Wait for response if needed
            if future is None:
                return True
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError:
            self._response_futures.pop(call_id, None)
            raise TimeoutError(f"Request timeout for {method}")
        except Exception:
            if call_id:
                self._response_futures.pop(call_id, None)
            raise

    # ------------------------------------------------------------------
    # Events
    # ------------------------------------------------------------------
    def on(self, event: str, handler: Callable) -> None:
        """Register an event handler"""
        self._event_handlers[event].append(handler)

    def off(self, event: str, handler: Callable) -> None:
        """Unregister an event handler"""
        handlers = self._event_handlers.get(event)
        if handlers and handler in handlers:
            handlers.remove(handler)

    async def _emit_event(self, event: str, data: Any) -> None:
        """Call every handler registered for ``event`` with ``data``.

        Coroutines or futures returned by a handler are awaited. Exceptions
        raised by a handler are logged and do not stop the other handlers.
        """
        # Iterate over a snapshot so handlers may call on()/off() safely.
        for handler in list(self._event_handlers.get(event, ())):
            try:
                outcome = handler(data)
                if asyncio.iscoroutine(outcome) or asyncio.isfuture(outcome):
                    await outcome
            except Exception as e:
                self.logger.error(f"Error in event handler for {event}: {e}")

    # ------------------------------------------------------------------
    # Shutdown
    # ------------------------------------------------------------------
    async def close(self) -> None:
        """Close the WebSocket connection gracefully.

        Sends a close frame (code 1000), waits up to five seconds for the
        receive loop to finish, stops all background tasks including a
        pending reconnection, and closes the transport. Does nothing when
        there is no connection and no reconnection in progress.
        """
        reconnect_task = self._reconnect_task
        reconnecting = reconnect_task is not None and not reconnect_task.done()
        idle = self.state in (ConnectionState.DISCONNECTED, ConnectionState.CLOSED)
        if idle and not reconnecting:
            return
        self._closing = True
        was_connected = self.state == ConnectionState.CONNECTED
        self.state = ConnectionState.CLOSING
        # Stop a pending reconnection before it can open a new connection.
        self._cancel_task(reconnect_task)
        current = asyncio.current_task()
        if was_connected:
            try:
                # Send close frame
                close_data = (1000).to_bytes(2, 'big') + b'Normal closure'
                await self._write_frame(self.OPCODE_CLOSE, close_data)
                # Wait for close response or timeout
                if self._receive_task is not None and self._receive_task is not current:
                    await asyncio.wait_for(
                        asyncio.shield(self._receive_task),
                        timeout=5.0
                    )
            except asyncio.TimeoutError:
                self.logger.debug("Timed out waiting for the server's close frame")
            except Exception as e:
                self.logger.debug(f"Close handshake failed: {e}")
        # Cancel tasks and wait until they have really stopped
        background = [
            task
            for task in (self._receive_task, self._ping_task, reconnect_task)
            if task is not None and task is not current
        ]
        for task in background:
            task.cancel()
        if background:
            await asyncio.gather(*background, return_exceptions=True)
        # Close connection
        writer = self.writer
        self.reader = None
        self.writer = None
        if writer is not None:
            try:
                writer.close()
                await writer.wait_closed()
            except Exception as e:
                self.logger.debug(f"Error while closing transport: {e}")
        self.state = ConnectionState.CLOSED
        self.logger.info("WebSocket connection closed")

    @asynccontextmanager
    async def session(self):
        """Async context manager that connects on entry and closes on exit."""
        try:
            await self.connect()
            yield self
        finally:
            await self.close()

    # Convenience methods for common operations
    async def login(self, username: str, password: str) -> Any:
        """Login to the WebSocket service"""
        return await self.request("login", username, password)

    async def get_channels(self) -> Any:
        """Get available channels"""
        return await self.request("get_channels")

    async def get_messages(self, channel_uid: str) -> Any:
        """Get messages from a channel"""
        return await self.request("get_messages", channel_uid)

    async def send_message(self, channel_uid: str, message: str, **kwargs) -> Any:
        """Send a message to a channel"""
        return await self.request("send_message", channel_uid, message, **kwargs)
# Example usage
async def main():
    """Demo session: log in, list channels, post one message, then idle for
    a minute while events are printed."""
    # Configure logging
    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )

    # Create WebSocket client
    client = WebSocketClient(
        "wss://molodetz.online/rpc.ws",
        ping_interval=30.0,
        reconnect_interval=5.0,
        max_reconnect_attempts=10,
    )

    # Event callbacks simply report what happened on stdout
    def on_message(data):
        print(f"Received message: {data}")

    def on_disconnect(data):
        print("Disconnected from server")

    def on_reconnect(attempts):
        print(f"Reconnected after {attempts} attempts")

    subscriptions = (
        ('message', on_message),
        ('disconnect', on_disconnect),
        ('reconnect', on_reconnect),
    )
    for event_name, callback in subscriptions:
        client.on(event_name, callback)

    # The session context manager connects now and closes on exit
    async with client.session():
        login_result = await client.login("username", "password")
        print(f"Login result: {login_result}")

        available = await client.get_channels()
        print(f"Available channels: {len(available)}")

        # Fire-and-forget message to the first channel, if there is one
        if available:
            await client.send_message(
                available[0]['uid'],
                "Hello from Python!",
                no_wait=True,
            )

        # Keep connection alive
        await asyncio.sleep(60)
# Load testing example
async def load_test():
    """Run ten WebSocket clients concurrently against the same endpoint."""

    async def run_client(client_id: int):
        # Each client gets its own logger so its output can be told apart
        client = WebSocketClient(
            "wss://molodetz.online/rpc.ws",
            logger=logging.getLogger(f"client-{client_id}"),
        )
        try:
            async with client.session():
                await client.login(f"user{client_id}", "password")
                # Simulate activity: ten channel listings, one second apart
                for _ in range(10):
                    await client.get_channels()
                    await asyncio.sleep(1)
        except Exception as e:
            print(f"Client {client_id} error: {e}")

    # Run 10 clients concurrently
    await asyncio.gather(
        *(run_client(client_id) for client_id in range(10)),
        return_exceptions=True,
    )
# Run the example session only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())