From b710008dbea0079761d1d6ee9d73fc5c557c9edc Mon Sep 17 00:00:00 2001 From: retoor Date: Thu, 18 Dec 2025 22:04:01 +0100 Subject: [PATCH] refactor: enhance socket service robustness with locks and error handling fix: add null checks and safe access utilities to prevent crashes refactor: restructure socket methods for better concurrency and logging --- CHANGELOG.md | 8 + pyproject.toml | 2 +- src/snek/service/socket.py | 368 ++++++--- src/snek/static/event-handler.js | 181 +++- src/snek/static/socket.js | 423 ++++++++-- src/snek/view/rpc.py | 1313 +++++++++++++++++++----------- 6 files changed, 1618 insertions(+), 677 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 88a4c3b..1fa2872 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,14 @@ + +## Version 1.8.0 - 2025-12-18 + +The socket service now handles errors more robustly and prevents crashes through improved safety checks. Socket methods support better concurrency and provide enhanced logging for developers. + +**Changes:** 4 files, 2279 lines +**Languages:** JavaScript (592 lines), Python (1687 lines) + ## Version 1.7.0 - 2025-12-17 Fixes socket cleanup in the websocket handler to prevent resource leaks and improve connection stability. diff --git a/pyproject.toml b/pyproject.toml index 75440eb..2db0d0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "Snek" -version = "1.7.0" +version = "1.8.0" readme = "README.md" #license = { file = "LICENSE", content-type="text/markdown" } description = "Snek Chat Application by Molodetz" diff --git a/src/snek/service/socket.py b/src/snek/service/socket.py index f1a3eaf..251511f 100644 --- a/src/snek/service/socket.py +++ b/src/snek/service/socket.py @@ -1,12 +1,34 @@ +# retoor + import asyncio import logging from datetime import datetime - from snek.model.user import UserModel from snek.system.service import BaseService +from snek.system.model import now logger = logging.getLogger(__name__) -from snek.system.model import now + + +def safe_get(obj, key, default=None): + if obj is None: + return default + try: + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) + except Exception: + return default + + +def safe_str(obj): + if obj is None: + return "" + try: + return str(obj) + except Exception: + return "" + class SocketService(BaseService): @@ -15,28 +37,43 @@ class SocketService(BaseService): self.ws = ws self.is_connected = True self.user = user - self.user_uid = user["uid"] if user else None - self.user_color = user["color"] if user else None + self.user_uid = safe_get(user, "uid") if user else None + self.user_color = safe_get(user, "color") if user else None + self._lock = asyncio.Lock() + self.subscribed_channels = set() async def send_json(self, data): if not self.is_connected: return False - try: - await self.ws.send_json(data) - except Exception: + if not self.ws: self.is_connected = False - return self.is_connected + return False + if data is None: + return False + async with self._lock: + try: + await self.ws.send_json(data) + return True + except ConnectionResetError: + self.is_connected = False + logger.debug("Connection reset during send_json") + except Exception as ex: + self.is_connected = False + logger.debug(f"send_json failed: {safe_str(ex)}") + return False async def close(self): if not self.is_connected: return True - - try: - await self.ws.close() - except Exception: - pass - self.is_connected = False - + async with self._lock: + try: + if self.ws: + await self.ws.close() + except Exception as ex: + logger.debug(f"Socket close failed: {safe_str(ex)}") + finally: + self.is_connected = False + self.subscribed_channels.clear() return True def __init__(self, app): @@ -45,124 +82,245 @@ class SocketService(BaseService): self.users = {} self.subscriptions = {} self.last_update = str(datetime.now()) + self._lock = asyncio.Lock() async def user_availability_service(self): logger.info("User availability update service started.") - logger.debug("Entering the main loop.") while True: - logger.info("Updating user availability...") - logger.debug("Initializing users_updated list.") - users_updated = [] - logger.debug("Iterating over sockets.") - for s in self.sockets: - logger.debug(f"Checking connection status for socket: {s}.") - if not s.is_connected: - logger.debug("Socket is not connected, continuing to next socket.") - continue - logger.debug(f"Checking if user {s.user} is already updated.") - if s.user not in users_updated: - logger.debug(f"Updating last_ping for user: {s.user}.") - s.user["last_ping"] = now() - logger.debug(f"Saving user {s.user} to the database.") - await self.app.services.user.save(s.user) - logger.debug(f"Adding user {s.user} to users_updated list.") - users_updated.append(s.user) - logger.info( - f"Updated user availability for {len(users_updated)} online users." - ) - logger.debug("Sleeping for 60 seconds before the next update.") - await asyncio.sleep(60) + try: + users_updated = [] + sockets_copy = list(self.sockets) + for s in sockets_copy: + try: + if not s or not s.is_connected: + continue + if not s.user: + continue + if s.user in users_updated: + continue + s.user["last_ping"] = now() + if self.app and hasattr(self.app, "services") and self.app.services: + await self.app.services.user.save(s.user) + users_updated.append(s.user) + except Exception as ex: + logger.debug(f"Failed to update user availability: {safe_str(ex)}") + logger.info(f"Updated user availability for {len(users_updated)} online users.") + except Exception as ex: + logger.warning(f"User availability service error: {safe_str(ex)}") + try: + await asyncio.sleep(60) + except asyncio.CancelledError: + logger.info("User availability service cancelled") + break async def add(self, ws, user_uid): + if not ws: + return None if not user_uid: return None - user = await self.app.services.user.get(uid=user_uid) - if not user: + try: + if not self.app or not hasattr(self.app, "services") or not self.app.services: + logger.warning("Services not available for socket add") + return None + user = await self.app.services.user.get(uid=user_uid) + if not user: + logger.warning(f"User not found for socket add: {user_uid}") + return None + s = self.Socket(ws, user) + async with self._lock: + self.sockets.add(s) + try: + s.user["last_ping"] = now() + await self.app.services.user.save(s.user) + except Exception as ex: + logger.debug(f"Failed to update last_ping: {safe_str(ex)}") + username = safe_get(s.user, "username", "unknown") + logger.info(f"Added socket for user {username}") + is_first_connection = False + async with self._lock: + if user_uid not in self.users: + self.users[user_uid] = set() + is_first_connection = True + elif len(self.users[user_uid]) == 0: + is_first_connection = True + self.users[user_uid].add(s) + if is_first_connection: + nick = safe_get(s.user, "nick") or safe_get(s.user, "username", "") + color = s.user_color + await self._broadcast_presence("arrived", user_uid, nick, color) + return s + except Exception as ex: + logger.warning(f"Failed to add socket: {safe_str(ex)}") return None - s = self.Socket(ws, user) - self.sockets.add(s) - s.user["last_ping"] = now() - await self.app.services.user.save(s.user) - logger.info(f"Added socket for user {s.user['username']}") - - is_first_connection = False - if not self.users.get(user_uid): - self.users[user_uid] = set() - is_first_connection = True - elif len(self.users[user_uid]) == 0: - is_first_connection = True - self.users[user_uid].add(s) - - if is_first_connection: - await self._broadcast_presence("arrived", user_uid, s.user["nick"] or s.user["username"], s.user_color) - - return s async def subscribe(self, ws, channel_uid, user_uid): - if channel_uid not in self.subscriptions: - self.subscriptions[channel_uid] = set() - s = self.Socket(ws, await self.app.services.user.get(uid=user_uid)) - self.subscriptions[channel_uid].add(s) + if not ws or not channel_uid or not user_uid: + return False + try: + async with self._lock: + existing_socket = None + for sock in self.sockets: + if sock and sock.ws == ws and sock.user_uid == user_uid: + existing_socket = sock + break + if not existing_socket: + return False + existing_socket.subscribed_channels.add(channel_uid) + if channel_uid not in self.subscriptions: + self.subscriptions[channel_uid] = set() + self.subscriptions[channel_uid].add(user_uid) + return True + except Exception as ex: + logger.warning(f"Failed to subscribe: {safe_str(ex)}") + return False async def send_to_user(self, user_uid, message): + if not user_uid or message is None: + return 0 count = 0 - for s in list(self.users.get(user_uid, [])): - if await s.send_json(message): - count += 1 + try: + async with self._lock: + user_sockets = list(self.users.get(user_uid, [])) + for s in user_sockets: + if not s: + continue + try: + if await s.send_json(message): + count += 1 + except Exception as ex: + logger.debug(f"Failed to send to user socket: {safe_str(ex)}") + except Exception as ex: + logger.warning(f"send_to_user failed: {safe_str(ex)}") return count async def broadcast(self, channel_uid, message): - await self._broadcast(channel_uid, message) + if not channel_uid or message is None: + return False + return await self._broadcast(channel_uid, message) async def _broadcast(self, channel_uid, message): + if not channel_uid or message is None: + return False sent = 0 + user_uids_to_send = set() try: - async for user_uid in self.services.channel_member.get_user_uids( - channel_uid - ): - sent += await self.send_to_user(user_uid, message) + if self.services: + try: + async for user_uid in self.services.channel_member.get_user_uids(channel_uid): + if user_uid: + user_uids_to_send.add(user_uid) + except Exception as ex: + logger.warning(f"Broadcast db query failed: {safe_str(ex)}") + if not user_uids_to_send: + async with self._lock: + if channel_uid in self.subscriptions: + user_uids_to_send = set(self.subscriptions[channel_uid]) + for user_uid in user_uids_to_send: + try: + sent += await self.send_to_user(user_uid, message) + except Exception as ex: + logger.debug(f"Failed to send to user {user_uid}: {safe_str(ex)}") + logger.debug(f"Broadcasted a message to {sent} users.") + return True except Exception as ex: - print(ex, flush=True) - logger.info(f"Broadcasted a message to {sent} users.") - return True + logger.warning(f"Broadcast failed: {safe_str(ex)}") + return False async def delete(self, ws): - for s in [sock for sock in self.sockets if sock.ws == ws]: - await s.close() - user_uid = s.user_uid - user_nick = (s.user["nick"] or s.user["username"]) if s.user else None - user_color = s.user_color - logger.info(f"Removed socket for user {s.user['username'] if s.user else 'unknown'}") - self.sockets.discard(s) - - if user_uid: - if user_uid in self.users: + if not ws: + return + sockets_to_remove = [] + departures_to_broadcast = [] + async with self._lock: + sockets_to_remove = [sock for sock in self.sockets if sock and sock.ws == ws] + for s in sockets_to_remove: + self.sockets.discard(s) + user_uid = s.user_uid + if user_uid and user_uid in self.users: self.users[user_uid].discard(s) if len(self.users[user_uid]) == 0: - await self._broadcast_presence("departed", user_uid, user_nick, user_color) + del self.users[user_uid] + user_nick = None + try: + if s.user: + user_nick = safe_get(s.user, "nick") or safe_get(s.user, "username") + except Exception: + pass + if user_nick: + departures_to_broadcast.append((user_uid, user_nick, s.user_color)) + for channel_uid in list(s.subscribed_channels): + if channel_uid in self.subscriptions: + self.subscriptions[channel_uid].discard(user_uid) + if len(self.subscriptions[channel_uid]) == 0: + del self.subscriptions[channel_uid] + for s in sockets_to_remove: + try: + username = safe_get(s.user, "username", "unknown") if s.user else "unknown" + logger.info(f"Removed socket for user {username}") + await s.close() + except Exception as ex: + logger.warning(f"Socket close failed: {safe_str(ex)}") + for user_uid, user_nick, user_color in departures_to_broadcast: + try: + await self._broadcast_presence("departed", user_uid, user_nick, user_color) + except Exception as ex: + logger.debug(f"Failed to broadcast departure: {safe_str(ex)}") async def _broadcast_presence(self, event_type, user_uid, user_nick, user_color): if not user_uid or not user_nick: return - message = { - "event": "user_presence", - "data": { - "type": event_type, - "user_uid": user_uid, - "user_nick": user_nick, - "user_color": user_color, - "timestamp": datetime.now().isoformat() + if not event_type or event_type not in ("arrived", "departed"): + return + try: + message = { + "event": "user_presence", + "data": { + "type": event_type, + "user_uid": user_uid, + "user_nick": user_nick, + "user_color": user_color, + "timestamp": datetime.now().isoformat(), + }, } - } - sent_count = 0 - for s in list(self.sockets): - if not s.is_connected: - continue - if s.user_uid == user_uid: - continue - try: - if await s.send_json(message): - sent_count += 1 - except Exception: - pass - logger.info(f"Broadcast presence '{event_type}' for {user_nick} to {sent_count} users") + sent_count = 0 + async with self._lock: + sockets_copy = list(self.sockets) + for s in sockets_copy: + if not s or not s.is_connected: + continue + if s.user_uid == user_uid: + continue + try: + if await s.send_json(message): + sent_count += 1 + except Exception as ex: + logger.debug(f"Failed to send presence to socket: {safe_str(ex)}") + logger.info(f"Broadcast presence '{event_type}' for {user_nick} to {sent_count} users") + except Exception as ex: + logger.warning(f"Broadcast presence failed: {safe_str(ex)}") + + async def get_connected_users(self): + try: + async with self._lock: + return list(self.users.keys()) + except Exception: + return [] + + async def get_user_socket_count(self, user_uid): + if not user_uid: + return 0 + try: + async with self._lock: + return len(self.users.get(user_uid, [])) + except Exception: + return 0 + + async def is_user_online(self, user_uid): + if not user_uid: + return False + try: + async with self._lock: + user_sockets = self.users.get(user_uid, set()) + return any(s.is_connected for s in user_sockets if s) + except Exception: + return False diff --git a/src/snek/static/event-handler.js b/src/snek/static/event-handler.js index 7299982..4b359a7 100644 --- a/src/snek/static/event-handler.js +++ b/src/snek/static/event-handler.js @@ -1,33 +1,178 @@ +// retoor + export class EventHandler { constructor() { this.subscribers = {}; + this._maxListeners = 100; + this._warnOnMaxListeners = true; } - addEventListener(type, handler, { once = false } = {}) { - if (!this.subscribers[type]) this.subscribers[type] = []; - if (once) { - const originalHandler = handler; - handler = (...args) => { - originalHandler(...args); - this.removeEventListener(type, handler); - }; + addEventListener(type, handler, options = {}) { + if (!type || typeof type !== "string") { + console.warn("EventHandler: Invalid event type"); + return false; + } + if (!handler || typeof handler !== "function") { + console.warn("EventHandler: Invalid handler"); + return false; + } + try { + const once = options && options.once === true; + if (!this.subscribers[type]) { + this.subscribers[type] = []; + } + if (this._warnOnMaxListeners && this.subscribers[type].length >= this._maxListeners) { + console.warn(`EventHandler: Max listeners (${this._maxListeners}) reached for event "${type}"`); + } + let wrappedHandler = handler; + if (once) { + const originalHandler = handler; + wrappedHandler = (...args) => { + try { + originalHandler(...args); + } catch (e) { + console.error(`EventHandler: Error in once handler for "${type}":`, e); + } finally { + this.removeEventListener(type, wrappedHandler); + } + }; + wrappedHandler._original = originalHandler; + } + this.subscribers[type].push(wrappedHandler); + return true; + } catch (e) { + console.error("EventHandler: addEventListener error:", e); + return false; } - this.subscribers[type].push(handler); } emit(type, ...data) { - if (this.subscribers[type]) - this.subscribers[type].forEach((handler) => handler(...data)); + if (!type || typeof type !== "string") { + return false; + } + try { + const handlers = this.subscribers[type]; + if (!handlers || !Array.isArray(handlers) || handlers.length === 0) { + return false; + } + const handlersCopy = [...handlers]; + for (const handler of handlersCopy) { + if (typeof handler !== "function") { + continue; + } + try { + handler(...data); + } catch (e) { + console.error(`EventHandler: Error in handler for "${type}":`, e); + } + } + return true; + } catch (e) { + console.error("EventHandler: emit error:", e); + return false; + } } removeEventListener(type, handler) { - if (!this.subscribers[type]) return; - this.subscribers[type] = this.subscribers[type].filter( - (h) => h !== handler - ); - - if (this.subscribers[type].length === 0) { - delete this.subscribers[type]; + if (!type || typeof type !== "string") { + return false; + } + try { + if (!this.subscribers[type]) { + return false; + } + if (!handler) { + delete this.subscribers[type]; + return true; + } + const originalLength = this.subscribers[type].length; + this.subscribers[type] = this.subscribers[type].filter((h) => { + if (h === handler) return false; + if (h._original && h._original === handler) return false; + return true; + }); + if (this.subscribers[type].length === 0) { + delete this.subscribers[type]; + } + return this.subscribers[type] ? this.subscribers[type].length < originalLength : true; + } catch (e) { + console.error("EventHandler: removeEventListener error:", e); + return false; } } + + removeAllEventListeners(type) { + try { + if (type) { + if (this.subscribers[type]) { + delete this.subscribers[type]; + return true; + } + return false; + } + this.subscribers = {}; + return true; + } catch (e) { + console.error("EventHandler: removeAllEventListeners error:", e); + return false; + } + } + + hasEventListener(type, handler) { + if (!type || typeof type !== "string") { + return false; + } + try { + if (!this.subscribers[type]) { + return false; + } + if (!handler) { + return this.subscribers[type].length > 0; + } + return this.subscribers[type].some((h) => { + if (h === handler) return true; + if (h._original && h._original === handler) return true; + return false; + }); + } catch (e) { + return false; + } + } + + getEventListenerCount(type) { + try { + if (type) { + return this.subscribers[type] ? this.subscribers[type].length : 0; + } + let count = 0; + for (const key in this.subscribers) { + if (Object.prototype.hasOwnProperty.call(this.subscribers, key)) { + count += this.subscribers[key].length; + } + } + return count; + } catch (e) { + return 0; + } + } + + getEventTypes() { + try { + return Object.keys(this.subscribers); + } catch (e) { + return []; + } + } + + once(type, handler) { + return this.addEventListener(type, handler, { once: true }); + } + + off(type, handler) { + return this.removeEventListener(type, handler); + } + + on(type, handler) { + return this.addEventListener(type, handler); + } } diff --git a/src/snek/static/socket.js b/src/snek/static/socket.js index ce90201..363dabd 100644 --- a/src/snek/static/socket.js +++ b/src/snek/static/socket.js @@ -1,110 +1,270 @@ +// retoor + import { EventHandler } from "./event-handler.js"; +function createPromiseWithResolvers() { + let resolve, reject; + const promise = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + return { promise, resolve, reject, resolved: false }; +} + export class Socket extends EventHandler { - /** - * @type {URL} - */ - url; - /** - * @type {WebSocket|null} - */ + url = null; ws = null; - - /** - * @type {null|PromiseWithResolvers&{resolved?:boolean}} - */ connection = null; - shouldReconnect = true; - _debug = false; + _reconnectAttempts = 0; + _maxReconnectAttempts = 50; + _reconnectDelay = 4000; + _pendingCalls = new Map(); + _callTimeout = 30000; + _isDestroyed = false; get isConnected() { - return this.ws && this.ws.readyState === WebSocket.OPEN; + try { + return this.ws !== null && this.ws.readyState === WebSocket.OPEN; + } catch (e) { + return false; + } } get isConnecting() { - return this.ws && this.ws.readyState === WebSocket.CONNECTING; + try { + return this.ws !== null && this.ws.readyState === WebSocket.CONNECTING; + } catch (e) { + return false; + } } constructor() { super(); - - this.url = new URL("/rpc.ws", window.location.origin); - this.url.protocol = this.url.protocol.replace("http", "ws"); - - this.connect(); + try { + this.url = new URL("/rpc.ws", window.location.origin); + this.url.protocol = this.url.protocol.replace("http", "ws"); + this.connect(); + } catch (e) { + console.error("Socket initialization failed:", e); + } } connect() { - if (this.ws) { - return this.connection.promise; + if (this._isDestroyed) { + return Promise.reject(new Error("Socket destroyed")); } - - if (!this.connection || this.connection.resolved) { - this.connection = Promise.withResolvers(); + if (this.ws && (this.isConnected || this.isConnecting)) { + return this.connection ? this.connection.promise : Promise.resolve(this); } - - this.ws = new WebSocket(this.url); - this.ws.addEventListener("open", () => { - this.connection.resolved = true; - this.connection.resolve(this); - this.emit("connected"); - }); - - this.ws.addEventListener("close", () => { - console.log("Connection closed"); - this.disconnect(); - }); - this.ws.addEventListener("error", (e) => { - console.error("Connection error", e); - this.disconnect(); - }); - this.ws.addEventListener("message", (e) => { - if (e.data instanceof Blob || e.data instanceof ArrayBuffer) { - console.error("Binary data not supported"); - } else { - try { - this.onData(JSON.parse(e.data)); - } catch (e) { - console.error("Failed to parse message", e); - } + try { + this._cleanup(); + if (!this.connection || this.connection.resolved) { + this.connection = createPromiseWithResolvers(); } - }); + if (!this.url) { + this.connection.reject(new Error("URL not initialized")); + return this.connection.promise; + } + this.ws = new WebSocket(this.url); + this.ws.addEventListener("open", () => { + try { + this._reconnectAttempts = 0; + if (this.connection && !this.connection.resolved) { + this.connection.resolved = true; + this.connection.resolve(this); + } + this.emit("connected"); + } catch (e) { + console.error("Open handler error:", e); + } + }); + this.ws.addEventListener("close", (event) => { + try { + const reason = event.reason || "Connection closed"; + console.log("Connection closed:", reason); + this._handleDisconnect(); + } catch (e) { + console.error("Close handler error:", e); + } + }); + this.ws.addEventListener("error", (e) => { + try { + console.error("Connection error:", e); + this._handleDisconnect(); + } catch (ex) { + console.error("Error handler error:", ex); + } + }); + this.ws.addEventListener("message", (e) => { + this._handleMessage(e); + }); + return this.connection.promise; + } catch (e) { + console.error("Connect failed:", e); + return Promise.reject(e); + } + } + + _handleMessage(e) { + if (!e || !e.data) { + return; + } + try { + if (e.data instanceof Blob || e.data instanceof ArrayBuffer) { + console.warn("Binary data not supported"); + return; + } + let data; + try { + data = JSON.parse(e.data); + } catch (parseError) { + console.error("Failed to parse message:", parseError); + return; + } + if (data) { + this.onData(data); + } + } catch (error) { + console.error("Message handling error:", error); + } } onData(data) { - if (data.success !== undefined && !data.success) { - console.error(data); + if (!data || typeof data !== "object") { + return; } - if (data.callId) { - this.emit(data.callId, data.data); + try { + if (data.success !== undefined && !data.success) { + console.error("RPC error:", data); + } + if (data.callId) { + try { + const response = { + data: data.data, + success: data.success !== false, + error: data.success === false ? (data.data && data.data.error ? data.data.error : "Unknown error") : null + }; + this.emit(data.callId, response); + if (this._pendingCalls.has(data.callId)) { + clearTimeout(this._pendingCalls.get(data.callId)); + this._pendingCalls.delete(data.callId); + } + } catch (e) { + console.error("CallId emit error:", e); + } + } + if (data.channel_uid) { + try { + this.emit(data.channel_uid, data.data); + if (!data.event) { + this.emit("channel-message", data); + } + } catch (e) { + console.error("Channel emit error:", e); + } + } + try { + this.emit("data", data.data); + } catch (e) { + console.error("Data emit error:", e); + } + if (data.event) { + try { + if (this._debug) { + console.info([data.event, data.data]); + } + this.emit(data.event, data.data); + } catch (e) { + console.error("Event emit error:", e); + } + } + } catch (error) { + console.error("onData error:", error); } - if (data.channel_uid) { - this.emit(data.channel_uid, data.data); - if (!data["event"]) this.emit("channel-message", data); + } + + _cleanup() { + try { + if (this.ws) { + try { + this.ws.close(); + } catch (e) { + // ignore + } + this.ws = null; + } + } catch (e) { + console.error("Cleanup error:", e); } - this.emit("data", data.data); - if (data["event"]) { - console.info([data.event,data.data]) - this.emit(data.event, data.data); + } + + _handleDisconnect() { + this._cleanup(); + this._rejectPendingCalls("Connection lost"); + if (this.connection && !this.connection.resolved) { + this.connection.resolved = true; + this.connection.reject(new Error("Connection failed")); + } + if (this.shouldReconnect && !this._isDestroyed) { + this._reconnectAttempts++; + if (this._reconnectAttempts <= this._maxReconnectAttempts) { + const delay = Math.min( + this._reconnectDelay * Math.pow(1.5, Math.min(this._reconnectAttempts - 1, 10)), + 30000 + ); + setTimeout(() => { + if (!this._isDestroyed && this.shouldReconnect) { + console.log(`Reconnecting (attempt ${this._reconnectAttempts}/${this._maxReconnectAttempts})`); + this.emit("reconnecting"); + this.connect(); + } + }, delay); + } else { + console.error("Max reconnection attempts reached"); + this.emit("reconnect_failed"); + } + } + } + + _rejectPendingCalls(reason) { + try { + for (const [callId, timeoutId] of this._pendingCalls) { + try { + clearTimeout(timeoutId); + this.emit(callId, { data: null, error: reason, success: false }); + } catch (e) { + // ignore + } + } + this._pendingCalls.clear(); + } catch (e) { + console.error("Failed to reject pending calls:", e); } } disconnect() { - this.ws?.close(); - this.ws = null; + this.shouldReconnect = false; + this._cleanup(); + this._rejectPendingCalls("Disconnected"); + } - if (this.shouldReconnect) - setTimeout(() => { - console.log("Reconnecting"); - this.emit("reconnecting"); - return this.connect(); - }, 4000); + destroy() { + this._isDestroyed = true; + this.disconnect(); + this.removeAllEventListeners(); } _camelToSnake(str) { - return str.replace(/([a-z])([A-Z])/g, "$1_$2").toLowerCase(); + if (!str || typeof str !== "string") { + return ""; + } + try { + return str.replace(/([a-z])([A-Z])/g, "$1_$2").toLowerCase(); + } catch (e) { + return str; + } } get client() { @@ -113,39 +273,128 @@ export class Socket extends EventHandler { {}, { get(_, prop) { + if (!prop || typeof prop !== "string") { + return () => Promise.reject(new Error("Invalid method name")); + } return (...args) => { - const functionName = me._camelToSnake(prop); - if(me._debug){ - const call = {} - call[functionName] = args - console.debug(call) + try { + const functionName = me._camelToSnake(prop); + if (me._debug) { + const call = {}; + call[functionName] = args; + console.debug(call); } return me.call(functionName, ...args); + } catch (e) { + console.error("Client call error:", e); + return Promise.reject(e); + } }; }, - }, + } ); } generateCallId() { - return self.crypto.randomUUID(); + try { + if (typeof crypto !== "undefined" && crypto.randomUUID) { + return crypto.randomUUID(); + } + return "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx".replace(/[xy]/g, (c) => { + const r = (Math.random() * 16) | 0; + const v = c === "x" ? r : (r & 0x3) | 0x8; + return v.toString(16); + }); + } catch (e) { + return Date.now().toString(36) + Math.random().toString(36).substring(2); + } } async sendJson(data) { - await this.connect().then((api) => { - api.ws.send(JSON.stringify(data)); - }); + if (this._isDestroyed) { + throw new Error("Socket destroyed"); + } + if (!data) { + throw new Error("No data to send"); + } + try { + await this.connect(); + if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { + throw new Error("WebSocket not open"); + } + const jsonStr = JSON.stringify(data); + this.ws.send(jsonStr); + } catch (e) { + console.error("sendJson error:", e); + throw e; + } } async call(method, ...args) { - const call = { - callId: this.generateCallId(), + if (this._isDestroyed) { + return Promise.reject(new Error("Socket destroyed")); + } + if (!method || typeof method !== "string") { + return Promise.reject(new Error("Invalid method name")); + } + const callId = this.generateCallId(); + const callData = { + callId, method, - args, + args: args || [], }; - return new Promise((resolve) => { - this.addEventListener(call.callId, (data) => resolve(data), { once: true}); - this.sendJson(call); + return new Promise((resolve, reject) => { + const timeoutId = setTimeout(() => { + try { + this._pendingCalls.delete(callId); + this.removeEventListener(callId, handler); + reject(new Error(`RPC call timeout: ${method}`)); + } catch (e) { + reject(e); + } + }, this._callTimeout); + this._pendingCalls.set(callId, timeoutId); + const handler = (response) => { + try { + clearTimeout(timeoutId); + this._pendingCalls.delete(callId); + if (response && !response.success && response.error) { + reject(new Error(response.error)); + } else { + resolve(response ? response.data : null); + } + } catch (e) { + reject(e); + } + }; + try { + this.addEventListener(callId, handler, { once: true }); + this.sendJson(callData).catch((e) => { + clearTimeout(timeoutId); + this._pendingCalls.delete(callId); + this.removeEventListener(callId, handler); + reject(e); + }); + } catch (e) { + clearTimeout(timeoutId); + this._pendingCalls.delete(callId); + reject(e); + } }); } + + async callWithRetry(method, args = [], maxRetries = 3) { + let lastError; + for (let i = 0; i < maxRetries; i++) { + try { + return await this.call(method, ...args); + } catch (e) { + lastError = e; + if (i < maxRetries - 1) { + await new Promise((r) => setTimeout(r, 1000 * (i + 1))); + } + } + } + throw lastError; + } } diff --git a/src/snek/view/rpc.py b/src/snek/view/rpc.py index ec80ef9..ce52abd 100644 --- a/src/snek/view/rpc.py +++ b/src/snek/view/rpc.py @@ -1,669 +1,1050 @@ -# Written by retoor@molodetz.nl - -# This source code implements a WebSocket-based RPC (Remote Procedure Call) view that uses asynchronous methods to facilitate real-time communication and services for an authenticated user session in a web application. The class handles WebSocket events, user authentication, and various RPC interactions such as login, message retrieval, and more. - -# External imports are used from the aiohttp library for the WebSocket response handling and the snek.system view for the BaseView class. - -# MIT License: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions. +# retoor import asyncio import json import logging -import traceback import random +import traceback +import time from aiohttp import web - from snek.system.model import now -from snek.system.profiler import Profiler from snek.system.view import BaseView -import time + logger = logging.getLogger(__name__) +def safe_get(obj, key, default=None): + if obj is None: + return default + try: + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) + except Exception: + return default + + +def safe_str(obj): + if obj is None: + return "" + try: + return str(obj) + except Exception: + return "" + + class RPCView(BaseView): class RPCApi: def __init__(self, view, ws): self.view = view - self.app = self.view.app - self.services = self.app.services + self.app = safe_get(view, "app") + self.services = safe_get(self.app, "services") if self.app else None self.ws = ws self.user_session = {} self._scheduled = [] self._finalize_task = None + self._is_closed = False async def _session_ensure(self): - uid = await self.view.session_get("uid") - if uid not in self.user_session: - self.user_session[uid] = { - "said_hello": False, - } + try: + uid = await self.view.session_get("uid") if self.view else None + if uid and uid not in self.user_session: + self.user_session[uid] = {"said_hello": False} + except Exception as ex: + logger.warning(f"Failed to ensure session: {safe_str(ex)}") - async def session_get(self, key, default): - await self._session_ensure() - return self.user_session[self.user_uid].get(key, default) + async def session_get(self, key, default=None): + try: + await self._session_ensure() + uid = self.user_uid + if not uid or uid not in self.user_session: + return default + return self.user_session[uid].get(key, default) + except Exception as ex: + logger.warning(f"Failed to get session key {key}: {safe_str(ex)}") + return default async def session_set(self, key, value): - await self._session_ensure() - self.user_session[self.user_uid][key] = value - return True + try: + await self._session_ensure() + uid = self.user_uid + if not uid: + return False + if uid not in self.user_session: + self.user_session[uid] = {} + self.user_session[uid][key] = value + return True + except Exception as ex: + logger.warning(f"Failed to set session key {key}: {safe_str(ex)}") + return False async def db_insert(self, table_name, record): self._require_login() - + self._require_services() + if not table_name or not isinstance(table_name, str): + raise ValueError("Invalid table name") + if not record or not isinstance(record, dict): + raise ValueError("Invalid record") return await self.services.db.insert(self.user_uid, table_name, record) async def db_update(self, table_name, record): self._require_login() + self._require_services() + if not table_name or not isinstance(table_name, str): + raise ValueError("Invalid table name") + if not record or not isinstance(record, dict): + raise ValueError("Invalid record") return await self.services.db.update(self.user_uid, table_name, record) async def set_typing(self, channel_uid, color=None): self._require_login() - user = await self.services.user.get(self.user_uid) - if not color: - color = user["color"] - return await self.services.socket.broadcast( - channel_uid, - { - "channel_uid": channel_uid, - "event": "set_typing", - "data": { - "event": "set_typing", - "user_uid": user["uid"], - "username": user["username"], - "nick": user["nick"], + self._require_services() + if not channel_uid or not isinstance(channel_uid, str): + return False + try: + user = await self.services.user.get(uid=self.user_uid) + if not user: + return False + if not color: + color = safe_get(user, "color", "#000000") + return await self.services.socket.broadcast( + channel_uid, + { "channel_uid": channel_uid, - "color": color, + "event": "set_typing", + "data": { + "event": "set_typing", + "user_uid": safe_get(user, "uid", ""), + "username": safe_get(user, "username", ""), + "nick": safe_get(user, "nick", ""), + "channel_uid": channel_uid, + "color": color, + }, }, - }, - ) + ) + except Exception as ex: + logger.warning(f"Failed to set typing: {safe_str(ex)}") + return False async def db_delete(self, table_name, record): self._require_login() + self._require_services() + if not table_name or not isinstance(table_name, str): + raise ValueError("Invalid table name") + if not record or not isinstance(record, dict): + raise ValueError("Invalid record") return await self.services.db.delete(self.user_uid, table_name, record) async def db_get(self, table_name, record): self._require_login() + self._require_services() + if not table_name or not isinstance(table_name, str): + raise ValueError("Invalid table name") return await self.services.db.get(self.user_uid, table_name, record) async def db_find(self, table_name, record): self._require_login() + self._require_services() + if not table_name or not isinstance(table_name, str): + raise ValueError("Invalid table name") return await self.services.db.find(self.user_uid, table_name, record) async def db_upsert(self, table_name, record, keys): self._require_login() + self._require_services() + if not table_name or not isinstance(table_name, str): + raise ValueError("Invalid table name") + if not record or not isinstance(record, dict): + raise ValueError("Invalid record") return await self.services.db.upsert( self.user_uid, table_name, record, keys ) async def db_query(self, sql, args): self._require_login() + self._require_services() + if not sql or not isinstance(sql, str): + raise ValueError("Invalid SQL query") return await self.services.db.query(self.user_uid, sql, args) @property def user_uid(self): - return self.view.session.get("uid") + try: + if not self.view or not hasattr(self.view, "session"): + return None + session = self.view.session + if session is None: + return None + if hasattr(session, "get"): + return session.get("uid") + return None + except Exception: + return None @property def request(self): - return self.view.request + return safe_get(self.view, "request") def _require_login(self): if not self.is_logged_in: - raise Exception("Not logged in") + raise PermissionError("Not logged in") + + def _require_services(self): + if not self.services: + raise RuntimeError("Services not available") @property def is_logged_in(self): - return self.view.session.get("logged_in", False) + try: + if not self.view or not hasattr(self.view, "session"): + return False + session = self.view.session + if session is None: + return False + return bool(session.get("logged_in", False)) + except Exception: + return False async def mark_as_read(self, channel_uid): self._require_login() - await self.services.channel_member.mark_as_read(channel_uid, self.user_uid) - return True + self._require_services() + if not channel_uid or not isinstance(channel_uid, str): + return False + try: + await self.services.channel_member.mark_as_read(channel_uid, self.user_uid) + return True + except Exception as ex: + logger.warning(f"Failed to mark as read: {safe_str(ex)}") + return False async def login(self, username, password): - success = await self.services.user.validate_login(username, password) - if not success: - raise Exception("Invalid username or password") - user = await self.services.user.get(username=username) - self.view.session["uid"] = user["uid"] - self.view.session["logged_in"] = True - self.view.session["username"] = user["username"] - self.view.session["user_nick"] = user["nick"] - record = user.record - del record["password"] - del record["deleted_at"] - await self.services.socket.add( - self.ws, self.view.request.session.get("uid") - ) - async for subscription in self.services.channel_member.find( - user_uid=self.view.request.session.get("uid"), - deleted_at=None, - is_banned=False, - ): - await self.services.socket.subscribe( - self.ws, - subscription["channel_uid"], - self.view.request.session.get("uid"), - ) - return record + self._require_services() + if not username or not isinstance(username, str): + raise ValueError("Invalid username") + if not password or not isinstance(password, str): + raise ValueError("Invalid password") + try: + success = await self.services.user.validate_login(username, password) + if not success: + raise PermissionError("Invalid username or password") + user = await self.services.user.get(username=username) + if not user: + raise PermissionError("User not found") + if not self.view or not hasattr(self.view, "session") or self.view.session is None: + raise RuntimeError("Session not available") + self.view.session["uid"] = safe_get(user, "uid") + self.view.session["logged_in"] = True + self.view.session["username"] = safe_get(user, "username", "") + self.view.session["user_nick"] = safe_get(user, "nick", "") + record = dict(user.record) if hasattr(user, "record") else {} + record.pop("password", None) + record.pop("deleted_at", None) + user_uid = None + if self.view and hasattr(self.view, "session") and self.view.session: + user_uid = self.view.session.get("uid") if hasattr(self.view.session, "get") else None + if user_uid: + await self.services.socket.add(self.ws, user_uid) + try: + async for subscription in self.services.channel_member.find( + user_uid=user_uid, + deleted_at=None, + is_banned=False, + ): + if subscription and safe_get(subscription, "channel_uid"): + await self.services.socket.subscribe( + self.ws, + subscription["channel_uid"], + user_uid, + ) + except Exception as ex: + logger.warning(f"Failed to subscribe to channels: {safe_str(ex)}") + return record + except (PermissionError, ValueError): + raise + except Exception as ex: + logger.error(f"Login error: {safe_str(ex)}") + raise PermissionError("Login failed") async def search_user(self, query): self._require_login() - return [user["username"] for user in await self.services.user.search(query)] + self._require_services() + if not query or not isinstance(query, str): + return [] + try: + results = await self.services.user.search(query) + return [safe_get(user, "username", "") for user in (results or []) if user] + except Exception as ex: + logger.warning(f"Failed to search user: {safe_str(ex)}") + return [] - async def get_user(self, user_uid): + async def get_user(self, user_uid=None): self._require_login() - if not user_uid: - user_uid = self.user_uid - user = await self.services.user.get(uid=user_uid) - record = user.record - del record["password"] - del record["deleted_at"] - if user_uid != user["uid"]: - del record["email"] - return record + self._require_services() + try: + if not user_uid: + user_uid = self.user_uid + if not user_uid: + return None + user = await self.services.user.get(uid=user_uid) + if not user: + return None + record = dict(user.record) if hasattr(user, "record") else {} + record.pop("password", None) + record.pop("deleted_at", None) + if user_uid != self.user_uid: + record.pop("email", None) + return record + except Exception as ex: + logger.warning(f"Failed to get user: {safe_str(ex)}") + return None async def get_messages(self, channel_uid, offset=0, timestamp=None): self._require_login() - messages = [] - for message in await self.services.channel_message.offset( - channel_uid, offset or 0, timestamp or None - ): - extended_dict = await self.services.channel_message.to_extended_dict( - message + self._require_services() + if not channel_uid or not isinstance(channel_uid, str): + return [] + try: + offset = int(offset) if offset else 0 + if offset < 0: + offset = 0 + messages = [] + result = await self.services.channel_message.offset( + channel_uid, offset, timestamp or None ) - messages.append(extended_dict) - return messages + for message in (result or []): + if message: + try: + extended_dict = await self.services.channel_message.to_extended_dict(message) + if extended_dict: + messages.append(extended_dict) + except Exception as ex: + logger.warning(f"Failed to extend message: {safe_str(ex)}") + return messages + except Exception as ex: + logger.warning(f"Failed to get messages: {safe_str(ex)}") + return [] async def get_first_unread_message_uid(self, channel_uid): self._require_login() - channel_member = await self.services.channel_member.get( - channel_uid=channel_uid, user_uid=self.user_uid - ) - if not channel_member: + self._require_services() + if not channel_uid or not isinstance(channel_uid, str): return None - - last_read_at = channel_member["last_read_at"] - if not last_read_at: + try: + channel_member = await self.services.channel_member.get( + channel_uid=channel_uid, user_uid=self.user_uid + ) + if not channel_member: + return None + last_read_at = safe_get(channel_member, "last_read_at") + if not last_read_at: + return None + async for message in self.services.channel_message.query( + "SELECT uid FROM channel_message WHERE channel_uid=:channel_uid AND created_at > :last_read_at AND deleted_at IS NULL ORDER BY created_at ASC LIMIT 1", + {"channel_uid": channel_uid, "last_read_at": last_read_at} + ): + return safe_get(message, "uid") + return None + except Exception as ex: + logger.warning(f"Failed to get first unread message: {safe_str(ex)}") return None - - async for message in self.services.channel_message.query( - "SELECT uid FROM channel_message WHERE channel_uid=:channel_uid AND created_at > :last_read_at AND deleted_at IS NULL ORDER BY created_at ASC LIMIT 1", - {"channel_uid": channel_uid, "last_read_at": last_read_at} - ): - return message["uid"] - - return None async def get_channels(self): self._require_login() + self._require_services() channels = [] - async for subscription in self.services.channel_member.find( - user_uid=self.user_uid, is_banned=False - ): - channel = await self.services.channel.get( - uid=subscription["channel_uid"] - ) - last_message = await channel.get_last_message() - color = None - if last_message: - last_message_user = await last_message.get_user() - color = last_message_user["color"] - channels.append( - { - "name": subscription["label"], - "uid": subscription["channel_uid"], - "tag": channel["tag"], - "new_count": subscription["new_count"], - "is_moderator": subscription["is_moderator"], - "is_read_only": subscription["is_read_only"], - "new_count": subscription["new_count"], - "color": color, - } - ) - return channels - - async def clear_channel(self, channel_uid): + try: + async for subscription in self.services.channel_member.find( + user_uid=self.user_uid, is_banned=False + ): + if not subscription: + continue + try: + channel_uid = safe_get(subscription, "channel_uid") + if not channel_uid: + continue + channel = await self.services.channel.get(uid=channel_uid) + if not channel: + continue + color = None + try: + last_message = await channel.get_last_message() + if last_message: + last_message_user = await last_message.get_user() + color = safe_get(last_message_user, "color") if last_message_user else None + except Exception: + pass + channels.append({ + "name": safe_get(subscription, "label", ""), + "uid": channel_uid, + "tag": safe_get(channel, "tag", ""), + "new_count": safe_get(subscription, "new_count", 0), + "is_moderator": safe_get(subscription, "is_moderator", False), + "is_read_only": safe_get(subscription, "is_read_only", False), + "color": color, + }) + except Exception as ex: + logger.warning(f"Failed to process channel subscription: {safe_str(ex)}") + return channels + except Exception as ex: + logger.warning(f"Failed to get channels: {safe_str(ex)}") + return [] + + async def write_container(self, channel_uid, content, timeout=3): self._require_login() - user = await self.services.user.get(uid=self.user_uid) - if not user["is_admin"]: - raise Exception("Not allowed") - return await self.services.channel_message.clear(channel_uid) + self._require_services() + if not channel_uid or not isinstance(channel_uid, str): + raise ValueError("Invalid channel_uid") + try: + channel_member = await self.services.channel_member.get( + channel_uid=channel_uid, user_uid=self.user_uid + ) + if not channel_member: + raise PermissionError("Not allowed") + container_name = await self.services.container.get_container_name(channel_uid) + if not container_name: + raise RuntimeError("Container not found") + timeout = max(1, min(int(timeout) if timeout else 3, 30)) - async def write_container(self, channel_uid, content,timeout=3): - self._require_login() - channel_member = await self.services.channel_member.get( - channel_uid=channel_uid, user_uid=self.user_uid - ) - if not channel_member: - raise Exception("Not allowed") - - container_name = await self.services.container.get_container_name(channel_uid) + class SessionCall: + def __init__(self, app, channel_uid, container_name): + self.app = app + self.channel_uid = channel_uid + self.container_name = container_name + self.time_last_output = time.time() + self.output = b"" - class SessionCall: + async def stdout_event_handler(self, data): + try: + self.time_last_output = time.time() + if data: + self.output += data if isinstance(data, bytes) else data.encode("utf-8", "ignore") + return True + except Exception: + return False - def __init__(self, app,channel_uid_uid, container_name): - self.app = app - self.channel_uid = channel_uid - self.container_name = container_name - self.time_last_output = time.time() - self.output = b'' - - async def stdout_event_handler(self, data): - self.time_last_output = time.time() - self.output += data - return True + async def communicate(self, content, timeout=3): + try: + if self.app and hasattr(self.app, "services") and self.app.services: + await self.app.services.container.add_event_listener( + self.container_name, "stdout", self.stdout_event_handler + ) + await self.app.services.container.write_stdin(self.channel_uid, content) + start_time = time.time() + while time.time() - self.time_last_output < timeout: + if time.time() - start_time > timeout * 2: + break + await asyncio.sleep(0.1) + await self.app.services.container.remove_event_listener( + self.container_name, "stdout", self.stdout_event_handler + ) + return self.output + except Exception as ex: + logger.warning(f"Container communication error: {safe_str(ex)}") + return self.output - async def communicate(self,content, timeout=3): - await self.app.services.container.add_event_listener(self.container_name, "stdout", self.stdout_event_handler) - await self.app.services.container.write_stdin(self.channel_uid, content) - - while time.time() - self.time_last_output < timeout: - await asyncio.sleep(0.1) - await self.app.services.container.remove_event_listener(self.container_name, "stdout", self.stdout_event_handler) - return self.output - - sc = SessionCall(self, channel_uid,container_name) - return (await sc.communicate(content)).decode("utf-8","ignore") + sc = SessionCall(self.app, channel_uid, container_name) + result = await sc.communicate(content, timeout) + return result.decode("utf-8", "ignore") if result else "" + except (PermissionError, ValueError): + raise + except Exception as ex: + logger.warning(f"Failed to write container: {safe_str(ex)}") + return "" async def get_container(self, channel_uid): self._require_login() - channel_member = await self.services.channel_member.get( - channel_uid=channel_uid, user_uid=self.user_uid - ) - if not channel_member: - raise Exception("Not allowed") - container = await self.services.container.get(channel_uid) - result = None - if container: - result = { + self._require_services() + if not channel_uid or not isinstance(channel_uid, str): + return None + try: + channel_member = await self.services.channel_member.get( + channel_uid=channel_uid, user_uid=self.user_uid + ) + if not channel_member: + raise PermissionError("Not allowed") + container = await self.services.container.get(channel_uid) + if not container: + return None + deploy = safe_get(container, "deploy", {}) + resources = safe_get(deploy, "resources", {}) + limits = safe_get(resources, "limits", {}) + return { "name": await self.services.container.get_container_name(channel_uid), - "cpus": container["deploy"]["resources"]["limits"]["cpus"], - "memory": container["deploy"]["resources"]["limits"]["memory"], + "cpus": safe_get(limits, "cpus", "0.5"), + "memory": safe_get(limits, "memory", "256M"), "image": "ubuntu:latest", "volumes": [], - "status": container["status"] + "status": safe_get(container, "status", "unknown"), } - return result + except PermissionError: + raise + except Exception as ex: + logger.warning(f"Failed to get container: {safe_str(ex)}") + return None - - async def _finalize_message_task(self,message_uid): + async def _finalize_message_task(self, message_uid): + if not message_uid: + return False try: for _ in range(7): await asyncio.sleep(1) if not self._finalize_task: - return - except asyncio.exceptions.CancelledError: - print("Finalization cancelled.") - return - - self._finalize_task = None - message = await self.services.channel_message.get(message_uid, is_final=False) - if not message: + return False + except asyncio.CancelledError: + logger.debug("Finalization cancelled") + return False + except Exception as ex: + logger.warning(f"Finalization wait error: {safe_str(ex)}") + return False + self._finalize_task = None + try: + if not self.services: + return False + message = await self.services.channel_message.get(message_uid, is_final=False) + if not message: + return False + if safe_get(message, "user_uid") != self.user_uid: + return False + if safe_get(message, "is_final", False): + return True + await self.services.chat.finalize(safe_get(message, "uid")) + return True + except Exception as ex: + logger.warning(f"Failed to finalize message: {safe_str(ex)}") return False - if message["user_uid"] != self.user_uid: - raise Exception("Not allowed") - - if message["is_final"]: - return True - - if not message["is_final"]: - await self.services.chat.finalize(message["uid"]) - - - - return True - async def _queue_finalize_message(self, message_uid): - if self._finalize_task: - self._finalize_task.cancel() - self._finalize_task = asyncio.create_task( - self._finalize_message_task(message_uid) - ) + if not message_uid: + return + try: + if self._finalize_task: + self._finalize_task.cancel() + try: + await self._finalize_task + except asyncio.CancelledError: + pass + except Exception: + pass + self._finalize_task = asyncio.create_task( + self._finalize_message_task(message_uid) + ) + except Exception as ex: + logger.warning(f"Failed to queue finalize: {safe_str(ex)}") async def send_message(self, channel_uid, message, is_final=True): self._require_login() - - message = message.strip() - - if not is_final: - if_final = False - check_message = await self.services.channel_message.get(channel_uid=channel_uid, user_uid=self.user_uid,is_final=False,deleted_at=None) - if check_message: - await self._queue_finalize_message(check_message["uid"]) - return await self.update_message_text(check_message["uid"], message) - - if not message: - return - - message = await self.services.chat.send( - self.user_uid, channel_uid, message, is_final - ) - if not message["is_final"]: - await self._queue_finalize_message(message["uid"]) - elif self._finalize_task: - self._finalize_task.cancel() - - return message["uid"] - + self._require_services() + if not channel_uid or not isinstance(channel_uid, str): + return None + if message is None: + return None + try: + message_text = safe_str(message).strip() + if not is_final: + check_message = await self.services.channel_message.get( + channel_uid=channel_uid, user_uid=self.user_uid, is_final=False, deleted_at=None + ) + if check_message: + await self._queue_finalize_message(safe_get(check_message, "uid")) + return await self.update_message_text(safe_get(check_message, "uid"), message_text) + if not message_text: + return None + result = await self.services.chat.send( + self.user_uid, channel_uid, message_text, is_final + ) + if not result: + return None + if not safe_get(result, "is_final", True): + await self._queue_finalize_message(safe_get(result, "uid")) + elif self._finalize_task: + self._finalize_task.cancel() + return safe_get(result, "uid") + except Exception as ex: + logger.warning(f"Failed to send message: {safe_str(ex)}") + return None async def start_container(self, channel_uid): self._require_login() + self._require_services() + if not channel_uid or not isinstance(channel_uid, str): + raise ValueError("Invalid channel_uid") channel_member = await self.services.channel_member.get( channel_uid=channel_uid, user_uid=self.user_uid ) if not channel_member: - raise Exception("Not allowed") + raise PermissionError("Not allowed") return await self.services.container.start(channel_uid) async def stop_container(self, channel_uid): self._require_login() + self._require_services() + if not channel_uid or not isinstance(channel_uid, str): + raise ValueError("Invalid channel_uid") channel_member = await self.services.channel_member.get( channel_uid=channel_uid, user_uid=self.user_uid ) if not channel_member: - raise Exception("Not allowed") + raise PermissionError("Not allowed") return await self.services.container.stop(channel_uid) async def get_container_status(self, channel_uid): self._require_login() - channel_member = await self.services.channel_member.get( - channel_uid=channel_uid, user_uid=self.user_uid - ) - if not channel_member: - raise Exception("Not allowed") - return await self.services.container.get_status(channel_uid) + self._require_services() + if not channel_uid or not isinstance(channel_uid, str): + return None + try: + channel_member = await self.services.channel_member.get( + channel_uid=channel_uid, user_uid=self.user_uid + ) + if not channel_member: + raise PermissionError("Not allowed") + return await self.services.container.get_status(channel_uid) + except PermissionError: + raise + except Exception as ex: + logger.warning(f"Failed to get container status: {safe_str(ex)}") + return None async def finalize_message(self, message_uid): self._require_login() - message = await self.services.channel_message.get(message_uid) - if not message: + self._require_services() + if not message_uid or not isinstance(message_uid, str): + return False + try: + message = await self.services.channel_message.get(message_uid) + if not message: + return False + if safe_get(message, "user_uid") != self.user_uid: + raise PermissionError("Not allowed") + if not safe_get(message, "is_final", False): + await self.services.chat.finalize(safe_get(message, "uid")) + return True + except PermissionError: + raise + except Exception as ex: + logger.warning(f"Failed to finalize message: {safe_str(ex)}") return False - if message["user_uid"] != self.user_uid: - raise Exception("Not allowed") - - if not message["is_final"]: - await self.services.chat.finalize(message["uid"]) - - return True - async def update_message_text(self, message_uid, text): - - #async with self.app.no_save(): self._require_login() - - message = await self.services.channel_message.get(message_uid) - if not message: - return { - "error": "Message not found", - "success": False, - } - if message["is_final"]: - return { - "error": "Message is final", - "success": False, - } - if message["deleted_at"]: - return { - "error": "Message is deleted", - "success": False, - } - if message["user_uid"] != self.user_uid: - raise Exception("Not allowed") - - if message.get_seconds_since_last_update() > 8: - return { - "error": "Message too old", - "seconds_since_last_update": message.get_seconds_since_last_update(), - "success": False, - } - - - if not text: - message["deleted_at"] = now() - else: - message["deleted_at"] = None - - message["message"] = text - - await self.services.channel_message.save(message) - data = message.record - data["text"] = message["message"] - data["message_uid"] = message_uid - - await self.services.socket.broadcast( - message["channel_uid"], - { - "channel_uid": message["channel_uid"], - "event": "update_message_text", - "data": message.record, - }, - ) - - return {"success": True} - - - + self._require_services() + if not message_uid or not isinstance(message_uid, str): + return {"error": "Invalid message_uid", "success": False} + try: + message = await self.services.channel_message.get(message_uid) + if not message: + return {"error": "Message not found", "success": False} + if safe_get(message, "is_final", False): + return {"error": "Message is final", "success": False} + if safe_get(message, "deleted_at"): + return {"error": "Message is deleted", "success": False} + if safe_get(message, "user_uid") != self.user_uid: + raise PermissionError("Not allowed") + seconds_since_update = 0 + try: + if hasattr(message, "get_seconds_since_last_update"): + seconds_since_update = message.get_seconds_since_last_update() + except Exception: + pass + if seconds_since_update > 8: + return { + "error": "Message too old", + "seconds_since_last_update": seconds_since_update, + "success": False, + } + text_str = safe_str(text) if text else "" + if not text_str: + message["deleted_at"] = now() + else: + message["deleted_at"] = None + message["message"] = text_str + await self.services.channel_message.save(message) + record = dict(message.record) if hasattr(message, "record") else {} + record["text"] = safe_get(message, "message", "") + record["message_uid"] = message_uid + channel_uid = safe_get(message, "channel_uid") + if channel_uid: + try: + await self.services.socket.broadcast( + channel_uid, + { + "channel_uid": channel_uid, + "event": "update_message_text", + "data": record, + }, + ) + except Exception as ex: + logger.warning(f"Failed to broadcast message update: {safe_str(ex)}") + return {"success": True} + except PermissionError: + raise + except Exception as ex: + logger.warning(f"Failed to update message text: {safe_str(ex)}") + return {"error": "Update failed", "success": False} async def clear_channel(self, channel_uid): self._require_login() + self._require_services() + if not channel_uid or not isinstance(channel_uid, str): + raise ValueError("Invalid channel_uid") user = await self.services.user.get(uid=self.user_uid) - if not user["is_admin"]: - raise Exception("Not allowed") + if not user or not safe_get(user, "is_admin", False): + raise PermissionError("Not allowed") channel = await self.services.channel.get(uid=channel_uid) if not channel: - raise Exception("Channel not found") - channel['history_start'] = now() + raise ValueError("Channel not found") + channel["history_start"] = now() await self.services.channel.save(channel) return await self.services.channel_message.clear(channel_uid) - async def echo(self, *args): - self._require_login() return args async def query(self, *args): self._require_login() - query = args[0] - lowercase = query.lower() - if ( - any( - keyword in lowercase - for keyword in [ - "drop", - "alter", - "update", - "delete", - "replace", - "insert", - "truncate", - ] - ) - and "select" not in lowercase - ): - raise Exception("Not allowed") - records = [ - dict(record) async for record in self.services.channel.query(args[0]) - ] - for record in records: - try: - del record["email"] - except KeyError: - pass - try: - del record["password"] - except KeyError: - pass - try: - del record["message"] - except: - pass - try: - del record["html"] - except: - pass - return [ - dict(record) async for record in self.services.channel.query(args[0]) - ] + self._require_services() + if not args or len(args) == 0: + return [] + query_str = args[0] + if not query_str or not isinstance(query_str, str): + return [] + lowercase = query_str.lower() + forbidden_keywords = ["drop", "alter", "update", "delete", "replace", "insert", "truncate"] + if any(keyword in lowercase for keyword in forbidden_keywords) and "select" not in lowercase: + raise PermissionError("Not allowed") + try: + results = [] + async for record in self.services.channel.query(query_str): + if record: + record_dict = dict(record) + record_dict.pop("email", None) + record_dict.pop("password", None) + record_dict.pop("message", None) + record_dict.pop("html", None) + results.append(record_dict) + return results + except Exception as ex: + logger.warning(f"Query failed: {safe_str(ex)}") + return [] async def __call__(self, data): + call_id = None try: + if self._is_closed: + return + if not data or not isinstance(data, dict): + logger.warning("Invalid RPC data received") + return call_id = data.get("callId") method_name = data.get("method") + if not method_name or not isinstance(method_name, str): + await self._send_json({"callId": call_id, "success": False, "data": "Invalid method"}) + return if method_name.startswith("_"): - raise Exception("Not allowed") - args = data.get("args") or [] - if hasattr(super(), method_name) or not hasattr(self, method_name): - return await self._send_json( - {"callId": call_id, "data": "Not allowed"} - ) - method = getattr(self, method_name.replace(".", "_"), None) - if not method: - raise Exception("Method not found") + await self._send_json({"callId": call_id, "success": False, "data": "Not allowed"}) + return + args = data.get("args") + if args is None: + args = [] + elif not isinstance(args, (list, tuple)): + args = [args] + safe_method_name = method_name.replace(".", "_") + if not hasattr(self, safe_method_name): + await self._send_json({"callId": call_id, "success": False, "data": "Method not found"}) + return + method = getattr(self, safe_method_name, None) + if not method or not callable(method): + await self._send_json({"callId": call_id, "success": False, "data": "Method not callable"}) + return success = True + result = None try: result = await method(*args) - except Exception as ex: - result = {"exception": str(ex), "traceback": traceback.format_exc()} + except (PermissionError, ValueError) as ex: + result = {"error": safe_str(ex)} success = False - logger.exception(ex) + except Exception as ex: + result = {"error": "Internal error"} + success = False + logger.exception(f"RPC method {method_name} failed: {safe_str(ex)}") if result != "noresponse": - await self._send_json( - {"callId": call_id, "success": success, "data": result} - ) + await self._send_json({"callId": call_id, "success": success, "data": result}) except Exception as ex: - print(str(ex), flush=True) - logger.exception(ex) - await self._send_json( - {"callId": call_id, "success": False, "data": str(ex)} - ) + logger.exception(f"RPC call failed: {safe_str(ex)}") + try: + await self._send_json({"callId": call_id, "success": False, "data": "Internal error"}) + except Exception: + pass async def _send_json(self, obj): + if self._is_closed: + return False + if not self.ws: + return False try: - await self.ws.send_str(json.dumps(obj, default=str)) + json_str = json.dumps(obj, default=str) + await self.ws.send_str(json_str) + return True + except ConnectionResetError: + self._is_closed = True + logger.debug("Connection reset during send") except Exception as ex: - print("THIS IS THE DeAL>",str(ex), flush=True) - await self.services.socket.delete(self.ws) - + self._is_closed = True + logger.warning(f"Failed to send JSON: {safe_str(ex)}") + try: + if self.services and hasattr(self.services, "socket"): + await self.services.socket.delete(self.ws) + except Exception: + pass + return False + async def get_online_users(self, channel_uid): self._require_login() + self._require_services() + if not channel_uid or not isinstance(channel_uid, str): + return [] + try: + results = [] + async for record in self.services.channel.get_recent_users(channel_uid): + if record: + results.append(record) + return sorted(results, key=lambda x: safe_get(x, "nick", "")) + except Exception as ex: + logger.warning(f"Failed to get online users: {safe_str(ex)}") + return [] - results = [ - record - async for record in self.services.channel.get_recent_users(channel_uid) - ] - results = sorted(results, key=lambda x: x["nick"]) - return results - - async def echo(self, obj): - await self.ws.send_json(obj) + async def echo_raw(self, obj): + if not self.ws: + return "noresponse" + try: + await self.ws.send_json(obj if obj else {}) + except Exception as ex: + logger.warning(f"Failed to echo: {safe_str(ex)}") return "noresponse" async def get_recent_users(self, channel_uid): self._require_login() - - return [ - { - "uid": record["uid"], - "username": record["username"], - "nick": record["nick"], - "last_ping": record["last_ping"], - } - async for record in self.services.channel.get_recent_users(channel_uid) - ] + self._require_services() + if not channel_uid or not isinstance(channel_uid, str): + return [] + try: + results = [] + async for record in self.services.channel.get_recent_users(channel_uid): + if record: + results.append({ + "uid": safe_get(record, "uid", ""), + "username": safe_get(record, "username", ""), + "nick": safe_get(record, "nick", ""), + "last_ping": safe_get(record, "last_ping"), + }) + return results + except Exception as ex: + logger.warning(f"Failed to get recent users: {safe_str(ex)}") + return [] async def get_users(self, channel_uid): self._require_login() - - return [ - { - "uid": record["uid"], - "username": record["username"], - "nick": record["nick"], - "last_ping": record["last_ping"], - } - async for record in self.services.channel.get_users(channel_uid) - ] + self._require_services() + if not channel_uid or not isinstance(channel_uid, str): + return [] + try: + results = [] + async for record in self.services.channel.get_users(channel_uid): + if record: + results.append({ + "uid": safe_get(record, "uid", ""), + "username": safe_get(record, "username", ""), + "nick": safe_get(record, "nick", ""), + "last_ping": safe_get(record, "last_ping"), + }) + return results + except Exception as ex: + logger.warning(f"Failed to get users: {safe_str(ex)}") + return [] async def _schedule(self, seconds, call): - self._scheduled.append(call) - await asyncio.sleep(seconds) - await self.services.socket.send_to_user(self.user_uid, call) - self._scheduled.remove(call) + if not call: + return + try: + seconds = max(0, min(float(seconds) if seconds else 0, 300)) + self._scheduled.append(call) + await asyncio.sleep(seconds) + if self.services and self.user_uid: + await self.services.socket.send_to_user(self.user_uid, call) + except asyncio.CancelledError: + pass + except Exception as ex: + logger.warning(f"Schedule failed: {safe_str(ex)}") + finally: + try: + if call in self._scheduled: + self._scheduled.remove(call) + except (ValueError, Exception): + pass - async def ping(self, callId, *args): - if self.user_uid: - user = await self.services.user.get(uid=self.user_uid) - user["last_ping"] = now() - await self.services.user.save(user) - return {"pong": args} + async def ping(self, call_id=None, *args): + try: + if self.user_uid and self.services: + user = await self.services.user.get(uid=self.user_uid) + if user: + user["last_ping"] = now() + await self.services.user.save(user) + return {"pong": args if args else []} + except Exception as ex: + logger.warning(f"Ping failed: {safe_str(ex)}") + return {"pong": []} async def stars_render(self, channel_uid, message): - - for user in await self.get_online_users(channel_uid): - try: - await self.services.socket.send_to_user( - user["uid"], - { - "event": "stars_render", - "data": {"channel_uid": channel_uid, "message": message}, - }, - ) - except Exception as ex: - print(ex) + self._require_services() + if not channel_uid or not isinstance(channel_uid, str): + return False + try: + users = await self.get_online_users(channel_uid) + for user in (users or []): + if not user: + continue + user_uid = safe_get(user, "uid") + if not user_uid: + continue + try: + await self.services.socket.send_to_user( + user_uid, + { + "event": "stars_render", + "data": { + "channel_uid": channel_uid, + "message": safe_str(message) if message else "", + }, + }, + ) + except Exception as ex: + logger.debug(f"Failed to send stars_render to user: {safe_str(ex)}") + return True + except Exception as ex: + logger.warning(f"stars_render failed: {safe_str(ex)}") + return False async def get(self): + ws = None + rpc = None scheduled = [] async def schedule(uid, seconds, call): - scheduled.append(call) - await asyncio.sleep(seconds) - await self.services.socket.send_to_user(uid, call) - scheduled.remove(call) + if not uid or not call: + return + try: + scheduled.append(call) + seconds = max(0, min(float(seconds) if seconds else 0, 300)) + await asyncio.sleep(seconds) + if self.services: + await self.services.socket.send_to_user(uid, call) + except asyncio.CancelledError: + pass + except Exception as ex: + logger.debug(f"Schedule failed: {safe_str(ex)}") + finally: + try: + if call in scheduled: + scheduled.remove(call) + except (ValueError, Exception): + pass - ws = web.WebSocketResponse() - await ws.prepare(self.request) - if self.request.session.get("logged_in"): - await self.services.socket.add(ws, self.request.session.get("uid")) - async for subscription in self.services.channel_member.find( - user_uid=self.request.session.get("uid"), - deleted_at=None, - is_banned=False, - ): - await self.services.socket.subscribe( - ws, subscription["channel_uid"], self.request.session.get("uid") - ) - if not scheduled and self.request.app.uptime_seconds < 5: - await schedule( - self.request.session.get("uid"), - 0, - {"event": "refresh", "data": {"message": "Finishing deployment"}}, - ) - await schedule( - self.request.session.get("uid"), - 15, - {"event": "deployed", "data": {"uptime": self.request.app.uptime}}, - ) + try: + ws = web.WebSocketResponse(heartbeat=30.0, autoping=True) + await ws.prepare(self.request) + except Exception as ex: + logger.warning(f"Failed to prepare WebSocket: {safe_str(ex)}") + return web.Response(status=500, text="WebSocket initialization failed") + + user_uid = None + try: + session = getattr(self.request, "session", None) + if session and session.get("logged_in"): + user_uid = session.get("uid") + if user_uid and self.services: + await self.services.socket.add(ws, user_uid) + try: + async for subscription in self.services.channel_member.find( + user_uid=user_uid, + deleted_at=None, + is_banned=False, + ): + if subscription: + channel_uid = safe_get(subscription, "channel_uid") + if channel_uid: + await self.services.socket.subscribe(ws, channel_uid, user_uid) + except Exception as ex: + logger.warning(f"Failed to subscribe to channels: {safe_str(ex)}") + except Exception as ex: + logger.warning(f"Session initialization failed: {safe_str(ex)}") + + try: + app = getattr(self.request, "app", None) + uptime_seconds = getattr(app, "uptime_seconds", 999) if app else 999 + if not scheduled and uptime_seconds < 240 and user_uid: + jitter = random.uniform(0, 3) + asyncio.create_task(schedule( + user_uid, jitter, + {"event": "refresh", "data": {"message": "Finishing deployment"}}, + )) + uptime = getattr(app, "uptime", "") if app else "" + asyncio.create_task(schedule( + user_uid, 15 + jitter, + {"event": "deployed", "data": {"uptime": uptime}}, + )) + except Exception as ex: + logger.debug(f"Deployment schedule failed: {safe_str(ex)}") rpc = RPCView.RPCApi(self, ws) try: async for msg in ws: - if msg.type == web.WSMsgType.TEXT: - try: - await rpc(msg.json()) - except Exception as ex: - print("XXXXXXXXXX Deleting socket", ex, flush=True) - logger.exception(ex) - await self.services.socket.delete(ws) + if msg is None: + continue + try: + if msg.type == web.WSMsgType.TEXT: + try: + data = msg.json() + if data: + await rpc(data) + except json.JSONDecodeError as ex: + logger.warning(f"Invalid JSON received: {safe_str(ex)}") + except Exception as ex: + logger.warning(f"RPC call failed: {safe_str(ex)}") + elif msg.type == web.WSMsgType.BINARY: + logger.debug("Binary message received, ignoring") + elif msg.type == web.WSMsgType.ERROR: + logger.debug(f"WebSocket error: {safe_str(ws.exception())}") break - elif msg.type == web.WSMsgType.ERROR: - pass - elif msg.type == web.WSMsgType.CLOSE: - pass + elif msg.type == web.WSMsgType.CLOSE: + logger.debug("WebSocket close message received") + break + elif msg.type == web.WSMsgType.CLOSING: + logger.debug("WebSocket closing") + break + elif msg.type == web.WSMsgType.CLOSED: + logger.debug("WebSocket closed") + break + except Exception as ex: + logger.warning(f"Message processing error: {safe_str(ex)}") + except asyncio.CancelledError: + logger.debug("WebSocket handler cancelled") + except ConnectionResetError: + logger.debug("Connection reset by peer") + except Exception as ex: + logger.warning(f"WebSocket loop error: {safe_str(ex)}") finally: - await self.services.socket.delete(ws) + if rpc: + rpc._is_closed = True + try: + if self.services and ws: + await self.services.socket.delete(ws) + except Exception as ex: + logger.debug(f"Socket cleanup failed: {safe_str(ex)}") + try: + if ws and not ws.closed: + await ws.close() + except Exception: + pass return ws