Update, new snek_rpc connection.
This commit is contained in:
parent
5377b3c647
commit
6378c397a0
592
snek_rpc.py
Normal file
592
snek_rpc.py
Normal file
@ -0,0 +1,592 @@
|
||||
import asyncio
|
||||
import json
|
||||
import base64
|
||||
import hashlib
|
||||
import ssl
|
||||
import logging
|
||||
import time
|
||||
import os
|
||||
from typing import Optional, Dict, Any, Callable, Union
|
||||
from urllib.parse import urlparse
|
||||
from enum import Enum
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
import uuid
|
||||
|
||||
|
||||
class ConnectionState(Enum):
|
||||
"""WebSocket connection states"""
|
||||
DISCONNECTED = "disconnected"
|
||||
CONNECTING = "connecting"
|
||||
CONNECTED = "connected"
|
||||
CLOSING = "closing"
|
||||
CLOSED = "closed"
|
||||
|
||||
|
||||
class WebSocketException(Exception):
|
||||
"""Base exception for WebSocket errors"""
|
||||
pass
|
||||
|
||||
|
||||
class ConnectionError(WebSocketException):
|
||||
"""Connection-related errors"""
|
||||
pass
|
||||
|
||||
|
||||
class ProtocolError(WebSocketException):
|
||||
"""WebSocket protocol errors"""
|
||||
pass
|
||||
|
||||
|
||||
class WebSocketClient:
|
||||
"""
|
||||
Production-ready WebSocket client with automatic reconnection,
|
||||
proper error handling, and message queuing.
|
||||
"""
|
||||
|
||||
# WebSocket opcodes
|
||||
OPCODE_CONTINUATION = 0x0
|
||||
OPCODE_TEXT = 0x1
|
||||
OPCODE_BINARY = 0x2
|
||||
OPCODE_CLOSE = 0x8
|
||||
OPCODE_PING = 0x9
|
||||
OPCODE_PONG = 0xa
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
ping_interval: float = 30.0,
|
||||
ping_timeout: float = 10.0,
|
||||
reconnect_interval: float = 5.0,
|
||||
max_reconnect_attempts: int = 10,
|
||||
logger: Optional[logging.Logger] = None
|
||||
):
|
||||
self.url = urlparse(url)
|
||||
self.ssl_context = ssl_context or ssl.create_default_context()
|
||||
self.ping_interval = ping_interval
|
||||
self.ping_timeout = ping_timeout
|
||||
self.reconnect_interval = reconnect_interval
|
||||
self.max_reconnect_attempts = max_reconnect_attempts
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
|
||||
# Connection state
|
||||
self.state = ConnectionState.DISCONNECTED
|
||||
self.reader: Optional[asyncio.StreamReader] = None
|
||||
self.writer: Optional[asyncio.StreamWriter] = None
|
||||
|
||||
# Message handling
|
||||
self._response_futures: Dict[str, asyncio.Future] = {}
|
||||
self._event_handlers: Dict[str, list] = defaultdict(list)
|
||||
self._receive_task: Optional[asyncio.Task] = None
|
||||
self._ping_task: Optional[asyncio.Task] = None
|
||||
|
||||
# Reconnection
|
||||
self._reconnect_attempts = 0
|
||||
self._reconnect_lock = asyncio.Lock()
|
||||
self._closing = False
|
||||
|
||||
# Message queue for buffering during reconnection
|
||||
self._message_queue = asyncio.Queue(maxsize=1000)
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish WebSocket connection with proper handshake"""
|
||||
async with self._reconnect_lock:
|
||||
if self.state == ConnectionState.CONNECTED:
|
||||
return
|
||||
|
||||
self.state = ConnectionState.CONNECTING
|
||||
self.logger.info(f"Connecting to {self.url.geturl()}")
|
||||
|
||||
try:
|
||||
# Establish TCP connection
|
||||
port = self.url.port or (443 if self.url.scheme == 'wss' else 80)
|
||||
self.reader, self.writer = await asyncio.open_connection(
|
||||
self.url.hostname,
|
||||
port,
|
||||
ssl=self.ssl_context if self.url.scheme == 'wss' else None
|
||||
)
|
||||
|
||||
# Perform WebSocket handshake
|
||||
await self._handshake()
|
||||
|
||||
# Start background tasks
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
self._ping_task = asyncio.create_task(self._ping_loop())
|
||||
|
||||
self.state = ConnectionState.CONNECTED
|
||||
self._reconnect_attempts = 0
|
||||
self.logger.info("WebSocket connection established")
|
||||
|
||||
# Process any queued messages
|
||||
await self._process_queued_messages()
|
||||
|
||||
except Exception as e:
|
||||
self.state = ConnectionState.DISCONNECTED
|
||||
self.logger.error(f"Connection failed: {e}")
|
||||
raise ConnectionError(f"Failed to connect: {e}")
|
||||
|
||||
async def _handshake(self) -> None:
|
||||
"""Perform WebSocket handshake"""
|
||||
# Generate WebSocket key
|
||||
key = base64.b64encode(os.urandom(16)).decode('utf-8')
|
||||
|
||||
# Build handshake request
|
||||
path = self.url.path or '/'
|
||||
if self.url.query:
|
||||
path += '?' + self.url.query
|
||||
|
||||
handshake = (
|
||||
f"GET {path} HTTP/1.1\r\n"
|
||||
f"Host: {self.url.hostname}\r\n"
|
||||
f"Upgrade: websocket\r\n"
|
||||
f"Connection: Upgrade\r\n"
|
||||
f"Sec-WebSocket-Key: {key}\r\n"
|
||||
f"Sec-WebSocket-Version: 13\r\n"
|
||||
f"User-Agent: Python-WebSocket-Client\r\n"
|
||||
f"\r\n"
|
||||
)
|
||||
|
||||
self.writer.write(handshake.encode())
|
||||
await self.writer.drain()
|
||||
|
||||
# Read response
|
||||
response = await self.reader.readuntil(b'\r\n\r\n')
|
||||
|
||||
# Validate response
|
||||
if b'101 Switching Protocols' not in response:
|
||||
raise ProtocolError(f"Invalid handshake response: {response.decode()}")
|
||||
|
||||
# Verify Sec-WebSocket-Accept
|
||||
expected_accept = base64.b64encode(
|
||||
hashlib.sha1((key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode()).digest()
|
||||
).decode()
|
||||
|
||||
if f'Sec-WebSocket-Accept: {expected_accept}'.encode() not in response:
|
||||
self.logger.warning("Sec-WebSocket-Accept header mismatch")
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
"""Main loop for receiving and processing messages"""
|
||||
try:
|
||||
while self.state == ConnectionState.CONNECTED and not self._closing:
|
||||
try:
|
||||
message = await self._receive_message()
|
||||
|
||||
if message['opcode'] == self.OPCODE_TEXT:
|
||||
await self._handle_text_message(message['data'].decode('utf-8'))
|
||||
elif message['opcode'] == self.OPCODE_BINARY:
|
||||
await self._handle_binary_message(message['data'])
|
||||
elif message['opcode'] == self.OPCODE_CLOSE:
|
||||
await self._handle_close_frame(message['data'])
|
||||
break
|
||||
elif message['opcode'] == self.OPCODE_PING:
|
||||
await self._send_frame(self.OPCODE_PONG, message['data'])
|
||||
elif message['opcode'] == self.OPCODE_PONG:
|
||||
pass # Pong received, connection is alive
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.warning("Receive timeout")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in receive loop: {e}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Fatal error in receive loop: {e}")
|
||||
finally:
|
||||
if not self._closing:
|
||||
await self._handle_disconnect()
|
||||
|
||||
async def _receive_message(self) -> Dict[str, Any]:
|
||||
"""Receive and parse a WebSocket frame"""
|
||||
# Read frame header
|
||||
header = await self.reader.readexactly(2)
|
||||
|
||||
fin = (header[0] & 0x80) != 0
|
||||
opcode = header[0] & 0x0f
|
||||
masked = (header[1] & 0x80) != 0
|
||||
payload_length = header[1] & 0x7f
|
||||
|
||||
# Read extended payload length if needed
|
||||
if payload_length == 126:
|
||||
payload_length = int.from_bytes(await self.reader.readexactly(2), 'big')
|
||||
elif payload_length == 127:
|
||||
payload_length = int.from_bytes(await self.reader.readexactly(8), 'big')
|
||||
|
||||
# Read masking key if present (server shouldn't send masked frames)
|
||||
mask_key = None
|
||||
if masked:
|
||||
mask_key = await self.reader.readexactly(4)
|
||||
|
||||
# Read payload
|
||||
payload = await self.reader.readexactly(payload_length)
|
||||
|
||||
# Unmask payload if needed
|
||||
if masked:
|
||||
payload = bytes(b ^ mask_key[i % 4] for i, b in enumerate(payload))
|
||||
|
||||
return {
|
||||
'fin': fin,
|
||||
'opcode': opcode,
|
||||
'data': payload
|
||||
}
|
||||
|
||||
async def _send_frame(self, opcode: int, data: bytes) -> None:
|
||||
"""Send a WebSocket frame"""
|
||||
if self.state != ConnectionState.CONNECTED:
|
||||
raise ConnectionError("Not connected")
|
||||
|
||||
# Build frame
|
||||
frame = bytearray()
|
||||
|
||||
# FIN = 1, RSV = 0, Opcode
|
||||
frame.append(0x80 | opcode)
|
||||
|
||||
# Mask bit = 1 (client must mask), payload length
|
||||
length = len(data)
|
||||
if length <= 125:
|
||||
frame.append(0x80 | length)
|
||||
elif length <= 65535:
|
||||
frame.append(0x80 | 126)
|
||||
frame.extend(length.to_bytes(2, 'big'))
|
||||
else:
|
||||
frame.append(0x80 | 127)
|
||||
frame.extend(length.to_bytes(8, 'big'))
|
||||
|
||||
# Masking key
|
||||
mask_key = os.urandom(4)
|
||||
frame.extend(mask_key)
|
||||
|
||||
# Masked payload
|
||||
masked_data = bytes(b ^ mask_key[i % 4] for i, b in enumerate(data))
|
||||
frame.extend(masked_data)
|
||||
|
||||
# Send frame
|
||||
self.writer.write(frame)
|
||||
await self.writer.drain()
|
||||
|
||||
async def _handle_text_message(self, message: str) -> None:
|
||||
"""Handle incoming text message"""
|
||||
try:
|
||||
data = json.loads(message)
|
||||
|
||||
# Check if this is a response to a request
|
||||
if 'callId' in data and data['callId'] in self._response_futures:
|
||||
future = self._response_futures.pop(data['callId'])
|
||||
if not future.done():
|
||||
future.set_result(data.get('data', data))
|
||||
else:
|
||||
# Emit as event
|
||||
await self._emit_event('message', data)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
self.logger.error(f"Invalid JSON message: {message}")
|
||||
await self._emit_event('message', message)
|
||||
|
||||
async def _handle_binary_message(self, data: bytes) -> None:
|
||||
"""Handle incoming binary message"""
|
||||
await self._emit_event('binary', data)
|
||||
|
||||
async def _handle_close_frame(self, data: bytes) -> None:
|
||||
"""Handle close frame from server"""
|
||||
code = int.from_bytes(data[:2], 'big') if len(data) >= 2 else 1000
|
||||
reason = data[2:].decode('utf-8', errors='ignore') if len(data) > 2 else ''
|
||||
|
||||
self.logger.info(f"Received close frame: {code} {reason}")
|
||||
|
||||
# Send close frame back
|
||||
await self._send_frame(self.OPCODE_CLOSE, data[:2])
|
||||
|
||||
self.state = ConnectionState.CLOSING
|
||||
|
||||
async def _handle_disconnect(self) -> None:
|
||||
"""Handle unexpected disconnection"""
|
||||
self.state = ConnectionState.DISCONNECTED
|
||||
self.logger.warning("WebSocket disconnected")
|
||||
|
||||
# Cancel pending futures
|
||||
for future in self._response_futures.values():
|
||||
if not future.done():
|
||||
future.set_exception(ConnectionError("Disconnected"))
|
||||
self._response_futures.clear()
|
||||
|
||||
# Emit disconnect event
|
||||
await self._emit_event('disconnect', None)
|
||||
|
||||
# Attempt reconnection if not closing
|
||||
if not self._closing:
|
||||
asyncio.create_task(self._reconnect())
|
||||
|
||||
async def _reconnect(self) -> None:
|
||||
"""Attempt to reconnect with exponential backoff"""
|
||||
while self._reconnect_attempts < self.max_reconnect_attempts and not self._closing:
|
||||
self._reconnect_attempts += 1
|
||||
wait_time = min(self.reconnect_interval * (2 ** (self._reconnect_attempts - 1)), 60)
|
||||
|
||||
self.logger.info(f"Reconnecting in {wait_time}s (attempt {self._reconnect_attempts}/{self.max_reconnect_attempts})")
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
try:
|
||||
await self.connect()
|
||||
await self._emit_event('reconnect', self._reconnect_attempts)
|
||||
return
|
||||
except Exception as e:
|
||||
self.logger.error(f"Reconnection failed: {e}")
|
||||
|
||||
self.logger.error("Max reconnection attempts reached")
|
||||
await self._emit_event('reconnect_failed', self._reconnect_attempts)
|
||||
|
||||
async def _ping_loop(self) -> None:
|
||||
"""Send periodic ping frames to keep connection alive"""
|
||||
try:
|
||||
while self.state == ConnectionState.CONNECTED and not self._closing:
|
||||
await asyncio.sleep(self.ping_interval)
|
||||
|
||||
try:
|
||||
await self._send_frame(self.OPCODE_PING, b'')
|
||||
except Exception as e:
|
||||
self.logger.error(f"Ping failed: {e}")
|
||||
break
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _process_queued_messages(self) -> None:
|
||||
"""Process messages that were queued during disconnection"""
|
||||
while not self._message_queue.empty():
|
||||
try:
|
||||
message = self._message_queue.get_nowait()
|
||||
await self.send_json(message)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to send queued message: {e}")
|
||||
|
||||
async def send(self, message: Union[str, bytes]) -> None:
|
||||
"""Send a message (text or binary)"""
|
||||
if isinstance(message, str):
|
||||
await self._send_frame(self.OPCODE_TEXT, message.encode('utf-8'))
|
||||
else:
|
||||
await self._send_frame(self.OPCODE_BINARY, message)
|
||||
|
||||
async def send_json(self, data: Dict[str, Any]) -> None:
|
||||
"""Send JSON-encoded message"""
|
||||
message = json.dumps(data)
|
||||
await self.send(message)
|
||||
|
||||
async def request(self, method: str, *args, **kwargs) -> Any:
|
||||
"""
|
||||
Send a request and wait for response.
|
||||
|
||||
Args:
|
||||
method: The method name to call
|
||||
*args: Positional arguments for the method
|
||||
**kwargs: Keyword arguments for the method
|
||||
|
||||
Returns:
|
||||
The response data
|
||||
"""
|
||||
# Handle no_wait flag
|
||||
no_wait = kwargs.pop('no_wait', False)
|
||||
|
||||
# Generate call ID
|
||||
call_id = str(uuid.uuid4()) if not no_wait else None
|
||||
|
||||
# Prepare request
|
||||
request = {
|
||||
"method": method,
|
||||
"args": args
|
||||
}
|
||||
|
||||
if call_id:
|
||||
request["callId"] = call_id
|
||||
|
||||
# Create future for response if waiting
|
||||
if call_id:
|
||||
future = asyncio.Future()
|
||||
self._response_futures[call_id] = future
|
||||
|
||||
try:
|
||||
# Send request
|
||||
if self.state == ConnectionState.CONNECTED:
|
||||
await self.send_json(request)
|
||||
else:
|
||||
# Queue message if not connected
|
||||
if not self._message_queue.full():
|
||||
await self._message_queue.put(request)
|
||||
else:
|
||||
raise ConnectionError("Message queue full")
|
||||
|
||||
# Wait for response if needed
|
||||
if call_id:
|
||||
return await asyncio.wait_for(future, timeout=30.0)
|
||||
else:
|
||||
return True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self._response_futures.pop(call_id, None)
|
||||
raise TimeoutError(f"Request timeout for {method}")
|
||||
except Exception:
|
||||
if call_id:
|
||||
self._response_futures.pop(call_id, None)
|
||||
raise
|
||||
|
||||
def on(self, event: str, handler: Callable) -> None:
|
||||
"""Register an event handler"""
|
||||
self._event_handlers[event].append(handler)
|
||||
|
||||
def off(self, event: str, handler: Callable) -> None:
|
||||
"""Unregister an event handler"""
|
||||
if handler in self._event_handlers[event]:
|
||||
self._event_handlers[event].remove(handler)
|
||||
|
||||
async def _emit_event(self, event: str, data: Any) -> None:
|
||||
"""Emit an event to all registered handlers"""
|
||||
for handler in self._event_handlers[event]:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(handler):
|
||||
await handler(data)
|
||||
else:
|
||||
handler(data)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in event handler for {event}: {e}")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket connection gracefully"""
|
||||
if self.state == ConnectionState.DISCONNECTED:
|
||||
return
|
||||
|
||||
self._closing = True
|
||||
self.state = ConnectionState.CLOSING
|
||||
|
||||
try:
|
||||
# Send close frame
|
||||
close_data = (1000).to_bytes(2, 'big') + b'Normal closure'
|
||||
await self._send_frame(self.OPCODE_CLOSE, close_data)
|
||||
|
||||
# Wait for close response or timeout
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(self._receive_task),
|
||||
timeout=5.0
|
||||
)
|
||||
except (asyncio.TimeoutError, Exception):
|
||||
pass
|
||||
|
||||
# Cancel tasks
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
if self._ping_task:
|
||||
self._ping_task.cancel()
|
||||
|
||||
# Close connection
|
||||
if self.writer:
|
||||
self.writer.close()
|
||||
await self.writer.wait_closed()
|
||||
|
||||
self.state = ConnectionState.CLOSED
|
||||
self.logger.info("WebSocket connection closed")
|
||||
|
||||
@asynccontextmanager
|
||||
async def session(self):
|
||||
"""Context manager for WebSocket connection"""
|
||||
try:
|
||||
await self.connect()
|
||||
yield self
|
||||
finally:
|
||||
await self.close()
|
||||
|
||||
# Convenience methods for common operations
|
||||
async def login(self, username: str, password: str) -> Any:
|
||||
"""Login to the WebSocket service"""
|
||||
return await self.request("login", username, password)
|
||||
|
||||
async def get_channels(self) -> Any:
|
||||
"""Get available channels"""
|
||||
return await self.request("get_channels")
|
||||
|
||||
async def get_messages(self, channel_uid: str) -> Any:
|
||||
"""Get messages from a channel"""
|
||||
return await self.request("get_messages", channel_uid)
|
||||
|
||||
async def send_message(self, channel_uid: str, message: str, **kwargs) -> Any:
|
||||
"""Send a message to a channel"""
|
||||
return await self.request("send_message", channel_uid, message, **kwargs)
|
||||
|
||||
|
||||
# Example usage
|
||||
async def main():
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
# Create WebSocket client
|
||||
url = "wss://molodetz.online/rpc.ws"
|
||||
client = WebSocketClient(
|
||||
url,
|
||||
ping_interval=30.0,
|
||||
reconnect_interval=5.0,
|
||||
max_reconnect_attempts=10
|
||||
)
|
||||
|
||||
# Register event handlers
|
||||
def on_message(data):
|
||||
print(f"Received message: {data}")
|
||||
|
||||
def on_disconnect(data):
|
||||
print("Disconnected from server")
|
||||
|
||||
def on_reconnect(attempts):
|
||||
print(f"Reconnected after {attempts} attempts")
|
||||
|
||||
client.on('message', on_message)
|
||||
client.on('disconnect', on_disconnect)
|
||||
client.on('reconnect', on_reconnect)
|
||||
|
||||
# Use context manager for automatic cleanup
|
||||
async with client.session():
|
||||
# Login
|
||||
result = await client.login("username", "password")
|
||||
print(f"Login result: {result}")
|
||||
|
||||
# Get channels
|
||||
channels = await client.get_channels()
|
||||
print(f"Available channels: {len(channels)}")
|
||||
|
||||
# Send a message without waiting for response
|
||||
if channels:
|
||||
await client.send_message(
|
||||
channels[0]['uid'],
|
||||
"Hello from Python!",
|
||||
no_wait=True
|
||||
)
|
||||
|
||||
# Keep connection alive
|
||||
await asyncio.sleep(60)
|
||||
|
||||
|
||||
# Load testing example
|
||||
async def load_test():
|
||||
"""Example of running multiple WebSocket clients concurrently"""
|
||||
async def run_client(client_id: int):
|
||||
url = "wss://molodetz.online/rpc.ws"
|
||||
client = WebSocketClient(url, logger=logging.getLogger(f"client-{client_id}"))
|
||||
|
||||
try:
|
||||
async with client.session():
|
||||
await client.login(f"user{client_id}", "password")
|
||||
|
||||
# Simulate activity
|
||||
for i in range(10):
|
||||
channels = await client.get_channels()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Client {client_id} error: {e}")
|
||||
|
||||
# Run 10 clients concurrently
|
||||
tasks = [run_client(i) for i in range(10)]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
Loading…
Reference in New Issue
Block a user