|
import asyncio
|
|
import json
|
|
import base64
|
|
import hashlib
|
|
import ssl
|
|
import logging
|
|
import time
|
|
import os
|
|
from typing import Optional, Dict, Any, Callable, Union
|
|
from urllib.parse import urlparse
|
|
from enum import Enum
|
|
from collections import defaultdict
|
|
from contextlib import asynccontextmanager
|
|
import uuid
|
|
|
|
|
|
class ConnectionState(Enum):
    """Lifecycle phase of a WebSocket client connection.

    Every member's value is simply its own name in lowercase.
    """

    DISCONNECTED = "disconnected"  # no usable socket
    CONNECTING = "connecting"      # TCP/TLS setup or opening handshake underway
    CONNECTED = "connected"        # handshake done; frames may be exchanged
    CLOSING = "closing"            # closing handshake in progress
    CLOSED = "closed"              # closed on purpose by the client
|
|
|
|
|
|
class WebSocketException(Exception):
    """Common ancestor of every error this WebSocket module raises."""
|
|
|
|
|
|
class ConnectionError(WebSocketException):
    """Signals that a connection could not be set up or is not open.

    NOTE: within this module this name shadows the built-in
    ``ConnectionError``. It derives only from ``WebSocketException``.
    """
|
|
|
|
|
|
class ProtocolError(WebSocketException):
    """Signals that the peer broke the WebSocket protocol (handshake or framing)."""
|
|
|
|
|
|
class WebSocketClient:
    """
    Production-ready WebSocket client with automatic reconnection,
    proper error handling, and message queuing.

    Implements the client side of RFC 6455 on asyncio streams. It covers:
    the opening handshake, masked outgoing frames, reassembly of fragmented
    incoming messages, ping/pong keep-alive with a pong timeout, and the
    closing handshake. On top of that it provides a JSON request/response
    layer whose replies are matched by ``callId``.

    Events delivered to handlers registered with :meth:`on`:
    ``'message'`` (decoded JSON, or the raw text if it is not JSON),
    ``'binary'`` (bytes), ``'disconnect'`` (None), ``'reconnect'`` (the
    attempt number that succeeded) and ``'reconnect_failed'`` (attempt count).
    """

    # WebSocket opcodes
    OPCODE_CONTINUATION = 0x0
    OPCODE_TEXT = 0x1
    OPCODE_BINARY = 0x2
    OPCODE_CLOSE = 0x8
    OPCODE_PING = 0x9
    OPCODE_PONG = 0xa

    # GUID appended to Sec-WebSocket-Key to derive Sec-WebSocket-Accept (RFC 6455 §1.3).
    _WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

    # Seconds to wait for the server's close frame and for transport shutdown.
    CLOSE_TIMEOUT = 5.0

    def __init__(
        self,
        url: str,
        ssl_context: Optional[ssl.SSLContext] = None,
        ping_interval: float = 30.0,
        ping_timeout: float = 10.0,
        reconnect_interval: float = 5.0,
        max_reconnect_attempts: int = 10,
        logger: Optional[logging.Logger] = None,
        request_timeout: float = 30.0,
    ):
        """
        Args:
            url: ``ws://`` or ``wss://`` endpoint URL.
            ssl_context: TLS context used for ``wss://``. Defaults to
                ``ssl.create_default_context()``.
            ping_interval: Seconds between keep-alive pings.
            ping_timeout: Seconds to wait for a pong after each ping before
                the connection is considered dead and dropped (which
                triggers reconnection). ``None`` disables the check.
            reconnect_interval: Base delay of the exponential reconnect
                backoff. The delay is capped at 60 seconds.
            max_reconnect_attempts: Attempts made after an unexpected
                disconnect before ``'reconnect_failed'`` is emitted.
            logger: Logger to use. Defaults to this module's logger.
            request_timeout: Seconds :meth:`request` waits for a response.
        """
        self.url = urlparse(url)
        self.ssl_context = ssl_context or ssl.create_default_context()
        self.ping_interval = ping_interval
        self.ping_timeout = ping_timeout
        self.reconnect_interval = reconnect_interval
        self.max_reconnect_attempts = max_reconnect_attempts
        self.request_timeout = request_timeout
        self.logger = logger or logging.getLogger(__name__)

        # Connection state
        self.state = ConnectionState.DISCONNECTED
        self.reader: Optional[asyncio.StreamReader] = None
        self.writer: Optional[asyncio.StreamWriter] = None

        # Message handling
        self._response_futures: Dict[str, asyncio.Future] = {}
        self._event_handlers: Dict[str, list] = defaultdict(list)
        self._receive_task: Optional[asyncio.Task] = None
        self._ping_task: Optional[asyncio.Task] = None
        # Set by the receive loop whenever a pong arrives. The ping loop waits on it.
        self._pong_received = asyncio.Event()

        # Reconnection
        self._reconnect_attempts = 0
        self._reconnect_lock = asyncio.Lock()
        self._closing = False
        # Strong references to fire-and-forget tasks (e.g. reconnect loops).
        # The event loop itself only keeps weak references to tasks.
        self._background_tasks: set = set()

        # Message queue for buffering during reconnection
        self._message_queue = asyncio.Queue(maxsize=1000)

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    async def connect(self) -> None:
        """Open the connection (TCP/TLS plus the WebSocket opening handshake).

        Does nothing if the client is already connected. Calling connect()
        after close() re-enables the client, including automatic
        reconnection.

        Raises:
            ConnectionError: the TCP connection or the handshake failed.
        """
        # An explicit connect() reverses a previous close().
        self._closing = False
        await self._open_connection()

    async def _open_connection(self) -> None:
        """Connect without touching ``_closing``. The reconnect loop uses this directly."""
        async with self._reconnect_lock:
            if self.state == ConnectionState.CONNECTED:
                return

            self.state = ConnectionState.CONNECTING
            self.logger.info(f"Connecting to {self.url.geturl()}")

            try:
                if not self.url.hostname:
                    raise ProtocolError("URL has no host")
                secure = self.url.scheme == 'wss'
                port = self.url.port or (443 if secure else 80)
                self.reader, self.writer = await asyncio.open_connection(
                    self.url.hostname,
                    port,
                    ssl=self.ssl_context if secure else None
                )
                await self._handshake()
            except asyncio.CancelledError:
                # E.g. close() cancelled a reconnect attempt: do not leak the socket.
                self.state = ConnectionState.DISCONNECTED
                if self.writer is not None:
                    self.writer.close()
                self.reader = None
                self.writer = None
                raise
            except Exception as e:
                self.state = ConnectionState.DISCONNECTED
                await self._close_writer()
                self.logger.error(f"Connection failed: {e}")
                raise ConnectionError(f"Failed to connect: {e}") from e

            # Mark the connection open before the background tasks first run,
            # because both loops check the state on entry.
            self.state = ConnectionState.CONNECTED
            self._reconnect_attempts = 0
            self._receive_task = asyncio.create_task(self._receive_loop())
            self._ping_task = asyncio.create_task(self._ping_loop())
            self.logger.info("WebSocket connection established")

            # Process any queued messages
            await self._process_queued_messages()

    def _host_header(self) -> str:
        """Return the Host header value, including the port when it is not the scheme default."""
        host = self.url.hostname or ''
        if ':' in host:
            # IPv6 literal: urlparse strips the brackets, but HTTP needs them.
            host = f'[{host}]'
        default_port = 443 if self.url.scheme == 'wss' else 80
        port = self.url.port
        if port is not None and port != default_port:
            host = f'{host}:{port}'
        return host

    async def _handshake(self) -> None:
        """Send the HTTP Upgrade request and validate the server's 101 response.

        Raises:
            ProtocolError: the status code is not 101, or Sec-WebSocket-Accept
                does not match the key that was sent (RFC 6455 §4.1 requires
                the client to fail the connection in that case).
        """
        key = base64.b64encode(os.urandom(16)).decode('ascii')

        path = self.url.path or '/'
        if self.url.query:
            path += '?' + self.url.query

        handshake = (
            f"GET {path} HTTP/1.1\r\n"
            f"Host: {self._host_header()}\r\n"
            "Upgrade: websocket\r\n"
            "Connection: Upgrade\r\n"
            f"Sec-WebSocket-Key: {key}\r\n"
            "Sec-WebSocket-Version: 13\r\n"
            "User-Agent: Python-WebSocket-Client\r\n"
            "\r\n"
        )
        self.writer.write(handshake.encode())
        await self.writer.drain()

        response = await self.reader.readuntil(b'\r\n\r\n')

        # HTTP/1.1 header bytes are ISO-8859-1, and latin-1 decoding cannot fail.
        status_line, _, header_block = response.decode('latin-1').partition('\r\n')
        status_parts = status_line.split(None, 2)
        if len(status_parts) < 2 or status_parts[1] != '101':
            raise ProtocolError(
                f"Invalid handshake response: {response.decode(errors='replace')}"
            )

        # Header names are case-insensitive.
        headers: Dict[str, str] = {}
        for line in header_block.split('\r\n'):
            name, sep, value = line.partition(':')
            if sep:
                headers[name.strip().lower()] = value.strip()

        expected_accept = base64.b64encode(
            hashlib.sha1((key + self._WS_GUID).encode()).digest()
        ).decode()
        if headers.get('sec-websocket-accept') != expected_accept:
            raise ProtocolError("Sec-WebSocket-Accept header mismatch")

    # ------------------------------------------------------------------
    # Receiving
    # ------------------------------------------------------------------

    async def _receive_loop(self) -> None:
        """Read frames until the connection ends, dispatching complete messages.

        Fragmented data messages are reassembled. Control frames, which may
        arrive between fragments, are handled immediately: pings are
        answered, pongs are recorded and close frames end the loop. When the
        loop ends without close() having been requested, the
        disconnect/reconnect logic runs.
        """
        fragments: list = []
        fragment_opcode: Optional[int] = None
        try:
            while self.state in (ConnectionState.CONNECTED, ConnectionState.CLOSING):
                frame = await self._receive_message()
                opcode = frame['opcode']
                data = frame['data']

                if opcode & 0x8:  # control frame
                    if opcode == self.OPCODE_CLOSE:
                        await self._handle_close_frame(data)
                        break
                    if opcode == self.OPCODE_PING:
                        if self.state == ConnectionState.CONNECTED:
                            await self._send_frame(self.OPCODE_PONG, data)
                    elif opcode == self.OPCODE_PONG:
                        self._pong_received.set()  # connection is alive
                    continue

                if opcode == self.OPCODE_CONTINUATION:
                    if fragment_opcode is None:
                        raise ProtocolError("Unexpected continuation frame")
                    fragments.append(data)
                elif opcode in (self.OPCODE_TEXT, self.OPCODE_BINARY):
                    if fragment_opcode is not None:
                        raise ProtocolError("Expected continuation frame")
                    fragment_opcode = opcode
                    fragments = [data]
                else:
                    self.logger.warning(f"Ignoring frame with unknown opcode {opcode:#x}")
                    continue

                if not frame['fin']:
                    continue  # wait for the remaining fragments

                payload = b''.join(fragments)
                message_opcode = fragment_opcode
                fragments, fragment_opcode = [], None

                if message_opcode == self.OPCODE_TEXT:
                    await self._handle_text_message(payload.decode('utf-8'))
                else:
                    await self._handle_binary_message(payload)

        except Exception as e:
            self.logger.error(f"Error in receive loop: {e}")
        finally:
            if not self._closing:
                await self._handle_disconnect()

    async def _receive_message(self) -> Dict[str, Any]:
        """Read and decode a single WebSocket frame.

        Returns:
            ``{'fin': bool, 'opcode': int, 'data': bytes}`` for one frame.
            Fragment reassembly is done by the caller.
        """
        header = await self.reader.readexactly(2)

        fin = bool(header[0] & 0x80)
        opcode = header[0] & 0x0f
        masked = bool(header[1] & 0x80)
        payload_length = header[1] & 0x7f

        # Extended payload length
        if payload_length == 126:
            payload_length = int.from_bytes(await self.reader.readexactly(2), 'big')
        elif payload_length == 127:
            payload_length = int.from_bytes(await self.reader.readexactly(8), 'big')

        # Servers should not mask frames, but tolerate it.
        mask_key = await self.reader.readexactly(4) if masked else None

        payload = await self.reader.readexactly(payload_length)
        if mask_key is not None:
            payload = self._apply_mask(payload, mask_key)

        return {
            'fin': fin,
            'opcode': opcode,
            'data': payload
        }

    # ------------------------------------------------------------------
    # Sending
    # ------------------------------------------------------------------

    @staticmethod
    def _apply_mask(data: bytes, mask_key: bytes) -> bytes:
        """XOR *data* with the repeating 4-byte *mask_key*.

        This is the RFC 6455 §5.3 masking operation; it is its own inverse.
        The XOR is done on whole integers instead of byte by byte, which
        yields the same bytes much faster.
        """
        length = len(data)
        if length == 0:
            return b''
        key_stream = (mask_key * (length // 4 + 1))[:length]
        return (
            int.from_bytes(data, 'big') ^ int.from_bytes(key_stream, 'big')
        ).to_bytes(length, 'big')

    async def _write_frame(self, opcode: int, data: bytes) -> None:
        """Encode and write one final (FIN=1), masked frame, without a state check.

        Used directly for close frames, which must be sendable while CLOSING.

        Raises:
            ConnectionError: there is no open transport.
        """
        if self.writer is None:
            raise ConnectionError("Not connected")

        length = len(data)
        frame = bytearray([0x80 | opcode])  # FIN = 1, RSV = 0

        # Mask bit = 1 (clients must mask every frame), then payload length.
        if length <= 125:
            frame.append(0x80 | length)
        elif length <= 0xFFFF:
            frame.append(0x80 | 126)
            frame += length.to_bytes(2, 'big')
        else:
            frame.append(0x80 | 127)
            frame += length.to_bytes(8, 'big')

        mask_key = os.urandom(4)
        frame += mask_key
        frame += self._apply_mask(data, mask_key)

        self.writer.write(frame)
        await self.writer.drain()

    async def _send_frame(self, opcode: int, data: bytes) -> None:
        """Send one frame. This is only allowed while the connection is open.

        Raises:
            ConnectionError: the client is not in the CONNECTED state.
        """
        if self.state != ConnectionState.CONNECTED:
            raise ConnectionError("Not connected")
        await self._write_frame(opcode, data)

    # ------------------------------------------------------------------
    # Incoming message handlers
    # ------------------------------------------------------------------

    async def _handle_text_message(self, message: str) -> None:
        """Resolve a pending request by ``callId``, or emit the message as an event.

        Non-JSON text is logged and emitted as the raw string.
        """
        try:
            data = json.loads(message)
        except json.JSONDecodeError:
            self.logger.error(f"Invalid JSON message: {message}")
            await self._emit_event('message', message)
            return

        # Only JSON objects can carry a callId. Our ids are always str, which
        # also keeps unhashable values out of the dict lookup.
        call_id = data.get('callId') if isinstance(data, dict) else None
        future = self._response_futures.pop(call_id, None) if isinstance(call_id, str) else None

        if future is None:
            await self._emit_event('message', data)
        elif not future.done():
            future.set_result(data.get('data', data))

    async def _handle_binary_message(self, data: bytes) -> None:
        """Emit a complete binary message as a ``'binary'`` event."""
        await self._emit_event('binary', data)

    async def _handle_close_frame(self, data: bytes) -> None:
        """Handle a close frame from the server and enter the CLOSING state."""
        code = int.from_bytes(data[:2], 'big') if len(data) >= 2 else 1000
        reason = data[2:].decode('utf-8', errors='ignore') if len(data) > 2 else ''

        self.logger.info(f"Received close frame: {code} {reason}")

        # Echo the status code only when the server started the closing
        # handshake. If close() already sent our close frame, this is the
        # reply, and a second close frame must not be sent (RFC 6455 §5.5.1).
        if self.state == ConnectionState.CONNECTED:
            await self._write_frame(self.OPCODE_CLOSE, data[:2])

        self.state = ConnectionState.CLOSING

    # ------------------------------------------------------------------
    # Disconnect / reconnect
    # ------------------------------------------------------------------

    @staticmethod
    async def _cancel_task(task: Optional[asyncio.Task]) -> None:
        """Cancel *task* and wait for it to finish. The current task is skipped, to avoid self-deadlock."""
        if task is None or task.done() or task is asyncio.current_task():
            return
        task.cancel()
        await asyncio.wait({task})

    def _spawn(self, coro) -> asyncio.Task:
        """Start a background task and keep a strong reference until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def _close_writer(self) -> None:
        """Close and forget the current stream pair, logging (not raising) shutdown errors."""
        writer = self.writer
        self.reader = None
        self.writer = None
        if writer is None:
            return
        writer.close()
        try:
            await asyncio.wait_for(writer.wait_closed(), timeout=self.CLOSE_TIMEOUT)
        except Exception as e:
            self.logger.debug(f"Error while closing transport: {e}")

    def _fail_pending_requests(self, reason: str) -> None:
        """Fail every request still waiting for a response with ``ConnectionError(reason)``."""
        for future in list(self._response_futures.values()):
            if not future.done():
                future.set_exception(ConnectionError(reason))
        self._response_futures.clear()

    async def _handle_disconnect(self) -> None:
        """Tear down after an unexpected disconnection and schedule a reconnect."""
        self.state = ConnectionState.DISCONNECTED
        self.logger.warning("WebSocket disconnected")

        # Stop the old keep-alive loop. Otherwise it could survive a quick
        # reconnect and run alongside the new one.
        await self._cancel_task(self._ping_task)
        self._ping_task = None
        await self._close_writer()

        self._fail_pending_requests("Disconnected")

        await self._emit_event('disconnect', None)

        if not self._closing:
            self._spawn(self._reconnect())

    async def _reconnect(self) -> None:
        """Attempt to reconnect with exponential backoff (capped at 60s)."""
        while self._reconnect_attempts < self.max_reconnect_attempts and not self._closing:
            self._reconnect_attempts += 1
            attempt = self._reconnect_attempts
            # Bounding the exponent keeps very large attempt counts from
            # overflowing float conversion. The delay is capped at 60s anyway.
            wait_time = min(self.reconnect_interval * (2 ** min(attempt - 1, 30)), 60)

            self.logger.info(f"Reconnecting in {wait_time}s (attempt {attempt}/{self.max_reconnect_attempts})")
            await asyncio.sleep(wait_time)
            if self._closing:
                return  # close() was called while we slept

            try:
                await self._open_connection()
            except Exception as e:
                self.logger.error(f"Reconnection failed: {e}")
                continue

            # _open_connection() resets the counter, so report the value captured above.
            await self._emit_event('reconnect', attempt)
            return

        if not self._closing:
            self.logger.error("Max reconnection attempts reached")
            await self._emit_event('reconnect_failed', self._reconnect_attempts)

    async def _ping_loop(self) -> None:
        """Send periodic pings and drop the connection if no pong arrives within ``ping_timeout``."""
        try:
            while self.state == ConnectionState.CONNECTED and not self._closing:
                await asyncio.sleep(self.ping_interval)
                if self.state != ConnectionState.CONNECTED or self._closing:
                    break

                self._pong_received.clear()
                try:
                    await self._send_frame(self.OPCODE_PING, b'')
                except Exception as e:
                    self.logger.error(f"Ping failed: {e}")
                    break

                try:
                    await asyncio.wait_for(self._pong_received.wait(), timeout=self.ping_timeout)
                except asyncio.TimeoutError:
                    self.logger.warning("No pong received within ping_timeout; dropping connection")
                    # Aborting the transport makes the receive loop hit EOF,
                    # which runs the normal disconnect/reconnect path.
                    if self.writer is not None:
                        self.writer.transport.abort()
                    break

        except asyncio.CancelledError:
            pass

    async def _process_queued_messages(self) -> None:
        """Send the messages that were queued while disconnected.

        Draining stops if the connection drops. The messages that remain stay
        queued for the next connection.
        """
        while self.state == ConnectionState.CONNECTED and not self._message_queue.empty():
            message = self._message_queue.get_nowait()
            try:
                await self.send_json(message)
            except Exception as e:
                self.logger.error(f"Failed to send queued message: {e}")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def send(self, message: Union[str, bytes]) -> None:
        """Send *message* as a text frame if it is ``str``, otherwise as a binary frame."""
        if isinstance(message, str):
            await self._send_frame(self.OPCODE_TEXT, message.encode('utf-8'))
        else:
            await self._send_frame(self.OPCODE_BINARY, message)

    async def send_json(self, data: Dict[str, Any]) -> None:
        """Send *data* JSON-encoded in a text frame."""
        await self.send(json.dumps(data))

    async def request(self, method: str, *args, **kwargs) -> Any:
        """
        Send a request and wait for its response.

        The wire message is ``{"method": method, "args": [...], "callId": id}``.
        The reply is matched by ``callId``, and its ``data`` field is
        returned (the whole reply if that field is absent). While the client
        is not connected, the request is queued and sent after the next
        successful connection.

        Args:
            method: The method name to call
            *args: Positional arguments for the method
            **kwargs: ``no_wait=True`` sends without a ``callId`` and returns
                ``True`` immediately. Other keyword arguments are NOT sent,
                because the wire format has no slot for them; a warning is
                logged when any are given.

        Returns:
            The response data, or ``True`` when ``no_wait`` is set.

        Raises:
            ConnectionError: the queue is full, sending failed, or the
                connection dropped or was closed before a response arrived.
            TimeoutError: no response arrived within ``request_timeout`` seconds.
        """
        no_wait = kwargs.pop('no_wait', False)
        if kwargs:
            self.logger.warning(
                f"Ignoring unsupported keyword arguments for {method}: {sorted(kwargs)}"
            )

        request: Dict[str, Any] = {
            "method": method,
            "args": args
        }

        call_id: Optional[str] = None
        future: Optional[asyncio.Future] = None
        if not no_wait:
            call_id = str(uuid.uuid4())
            request["callId"] = call_id
            future = asyncio.get_running_loop().create_future()
            self._response_futures[call_id] = future

        try:
            if self.state == ConnectionState.CONNECTED:
                await self.send_json(request)
            else:
                # Queue message if not connected
                try:
                    self._message_queue.put_nowait(request)
                except asyncio.QueueFull:
                    raise ConnectionError("Message queue full") from None

            if future is None:
                return True
            return await asyncio.wait_for(future, timeout=self.request_timeout)

        except asyncio.TimeoutError:
            raise TimeoutError(f"Request timeout for {method}") from None
        finally:
            # The response handler normally pops the entry itself. This covers
            # timeout and error paths.
            if call_id is not None:
                self._response_futures.pop(call_id, None)

    def on(self, event: str, handler: Callable) -> None:
        """Register *handler* for *event*. It may be a plain function or a coroutine function."""
        self._event_handlers[event].append(handler)

    def off(self, event: str, handler: Callable) -> None:
        """Unregister *handler* from *event*. Unknown handlers are ignored."""
        handlers = self._event_handlers.get(event)
        if handlers and handler in handlers:
            handlers.remove(handler)

    async def _emit_event(self, event: str, data: Any) -> None:
        """Call every handler for *event*, awaiting coroutine results.

        Handler exceptions are logged and never propagate.
        """
        # Iterate over a snapshot, since a handler may call on()/off() on this event.
        for handler in list(self._event_handlers.get(event, ())):
            try:
                result = handler(data)
                if asyncio.iscoroutine(result):
                    await result
            except Exception as e:
                self.logger.exception(f"Error in event handler for {event}: {e}")

    async def close(self) -> None:
        """Close the connection gracefully and stop automatic reconnection.

        Sends a close frame (code 1000) and waits up to ``CLOSE_TIMEOUT``
        seconds for the server's reply. It then stops the background tasks,
        closes the transport and fails any requests still waiting for a
        response. It is safe to call repeatedly and during a reconnection.
        """
        self._closing = True
        # Stop pending reconnect loops so they cannot reopen the socket.
        for task in list(self._background_tasks):
            await self._cancel_task(task)

        if self.state in (ConnectionState.DISCONNECTED, ConnectionState.CLOSED):
            self._fail_pending_requests("Connection closed")
            return

        was_open = self.state == ConnectionState.CONNECTED
        self.state = ConnectionState.CLOSING

        if was_open:
            try:
                close_data = (1000).to_bytes(2, 'big') + b'Normal closure'
                # _write_frame, not _send_frame: the state is already CLOSING.
                await self._write_frame(self.OPCODE_CLOSE, close_data)
            except Exception as e:
                self.logger.debug(f"Could not send close frame: {e}")
            else:
                receive_task = self._receive_task
                if receive_task is not None and receive_task is not asyncio.current_task():
                    # The receive loop ends when the server's close frame arrives.
                    await asyncio.wait({receive_task}, timeout=self.CLOSE_TIMEOUT)

        await self._cancel_task(self._receive_task)
        await self._cancel_task(self._ping_task)
        self._receive_task = None
        self._ping_task = None

        await self._close_writer()
        self._fail_pending_requests("Connection closed")

        self.state = ConnectionState.CLOSED
        self.logger.info("WebSocket connection closed")

    @asynccontextmanager
    async def session(self):
        """Async context manager: connect on entry, always close on exit."""
        try:
            await self.connect()
            yield self
        finally:
            await self.close()

    # Convenience methods for common operations
    async def login(self, username: str, password: str) -> Any:
        """Login to the WebSocket service"""
        return await self.request("login", username, password)

    async def get_channels(self) -> Any:
        """Get available channels"""
        return await self.request("get_channels")

    async def get_messages(self, channel_uid: str) -> Any:
        """Get messages from a channel"""
        return await self.request("get_messages", channel_uid)

    async def send_message(self, channel_uid: str, message: str, **kwargs) -> Any:
        """Send a message to a channel (``no_wait=True`` skips waiting for the reply)."""
        return await self.request("send_message", channel_uid, message, **kwargs)
|
|
|
|
|
|
# Example usage
|
|
async def main():
    """Demo: connect, log in, list channels, post a message, then idle for a minute."""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    client = WebSocketClient(
        "wss://molodetz.online/rpc.ws",
        ping_interval=30.0,
        reconnect_interval=5.0,
        max_reconnect_attempts=10
    )

    # One printing callback per event, registered in this order.
    callbacks = {
        'message': lambda data: print(f"Received message: {data}"),
        'disconnect': lambda _data: print("Disconnected from server"),
        'reconnect': lambda attempts: print(f"Reconnected after {attempts} attempts"),
    }
    for event_name, callback in callbacks.items():
        client.on(event_name, callback)

    # The session context manager closes the client however the block exits.
    async with client.session():
        login_result = await client.login("username", "password")
        print(f"Login result: {login_result}")

        channels = await client.get_channels()
        print(f"Available channels: {len(channels)}")

        if channels:
            # Fire-and-forget: no_wait returns without waiting for a reply.
            first_uid = channels[0]['uid']
            await client.send_message(first_uid, "Hello from Python!", no_wait=True)

        # Hold the connection open so incoming events can be observed.
        await asyncio.sleep(60)
|
|
|
|
|
|
# Load testing example
|
|
async def load_test(num_clients: int = 10, iterations: int = 10) -> None:
    """Example of running multiple WebSocket clients concurrently.

    Each client logs in as ``user<id>`` and then calls ``get_channels`` once
    per second for *iterations* rounds. Per-client failures are printed and
    never propagate.

    Args:
        num_clients: Number of concurrent clients (default 10).
        iterations: Number of ``get_channels`` calls per client (default 10).
    """
    url = "wss://molodetz.online/rpc.ws"

    async def run_client(client_id: int) -> None:
        client = WebSocketClient(url, logger=logging.getLogger(f"client-{client_id}"))

        try:
            async with client.session():
                await client.login(f"user{client_id}", "password")

                # Simulate activity. The returned channel list is not needed here.
                for _ in range(iterations):
                    await client.get_channels()
                    await asyncio.sleep(1)

        except Exception as e:
            print(f"Client {client_id} error: {e}")

    await asyncio.gather(*(run_client(i) for i in range(num_clients)), return_exceptions=True)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the demo coroutine on a fresh event loop.
    asyncio.run(main())
|
|
|