From 572d9610f20e3974ec306236b2d940a143990353 Mon Sep 17 00:00:00 2001 From: retoor Date: Sun, 21 Dec 2025 00:25:08 +0100 Subject: [PATCH] refactor: enhance socket service with granular locks and async optimizations refactor: rename scheduled list to tasks in rpc view --- CHANGELOG.md | 8 ++ pyproject.toml | 2 +- src/snek/service/socket.py | 198 ++++++++++++++++++++++++------------- src/snek/view/rpc.py | 45 +++++---- 4 files changed, 162 insertions(+), 91 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e635a39..e4aba67 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,14 @@ + +## Version 1.12.0 - 2025-12-21 + +The socket service enhances concurrency and performance through improved locking and asynchronous handling. The RPC view renames the scheduled list to tasks for clearer API terminology. + +**Changes:** 2 files, 243 lines +**Languages:** Python (243 lines) + ## Version 1.11.0 - 2025-12-19 Adds the ability to configure whether user registration is open or closed via system settings. Administrators can now toggle registration openness to control access to the registration form. diff --git a/pyproject.toml b/pyproject.toml index 458c82e..2ff7225 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "Snek" -version = "1.11.0" +version = "1.12.0" readme = "README.md" #license = { file = "LICENSE", content-type="text/markdown" } description = "Snek Chat Application by Molodetz" diff --git a/src/snek/service/socket.py b/src/snek/service/socket.py index 1cf4b9c..782bad6 100644 --- a/src/snek/service/socket.py +++ b/src/snek/service/socket.py @@ -46,14 +46,14 @@ class SocketService(BaseService): self.subscribed_channels = set() async def send_json(self, data): - if not self.is_connected: - return False - if not self.ws: - self.is_connected = False - return False if data is None: return False async with self._lock: + if not self.is_connected: + return False + if not self.ws: + self.is_connected = False + return False try: await self.ws.send_json(data) return True @@ -66,17 +66,18 @@ class SocketService(BaseService): return False async def close(self): - if not self.is_connected: - return True async with self._lock: + if not self.is_connected: + return True + self.is_connected = False + self.subscribed_channels.clear() try: - if self.ws: - await self.ws.close() + if self.ws and not self.ws.closed: + await asyncio.wait_for(self.ws.close(), timeout=5.0) + except asyncio.TimeoutError: + logger.debug("Socket close timed out") except Exception as ex: logger.debug(f"Socket close failed: {safe_str(ex)}") - finally: - self.is_connected = False - self.subscribed_channels.clear() return True def __init__(self, app): @@ -85,29 +86,54 @@ class SocketService(BaseService): self.users = {} self.subscriptions = {} self.last_update = str(datetime.now()) - self._lock = asyncio.Lock() + self._sockets_lock = asyncio.Lock() + self._user_locks = {} + self._channel_locks = {} + self._registry_lock = asyncio.Lock() + + async def _get_user_lock(self, user_uid): + async with self._registry_lock: + if user_uid not in self._user_locks: + self._user_locks[user_uid] = asyncio.Lock() + return self._user_locks[user_uid] + + async def _get_channel_lock(self, channel_uid): + async with self._registry_lock: + if channel_uid not in self._channel_locks: + self._channel_locks[channel_uid] = asyncio.Lock() + return self._channel_locks[channel_uid] + async def user_availability_service(self): logger.info("User availability update service started.") while True: try: - users_updated = [] - sockets_copy = list(self.sockets) + async with self._sockets_lock: + sockets_copy = list(self.sockets) + users_to_update = [] + seen_user_uids = set() for s in sockets_copy: try: if not s or not s.is_connected: continue - if not s.user: + if not s.user or not s.user_uid: continue - if s.user in users_updated: + if s.user_uid in seen_user_uids: continue - s.user["last_ping"] = now() + seen_user_uids.add(s.user_uid) + users_to_update.append(s.user_uid) + except Exception as ex: + logger.debug(f"Failed to check user availability: {safe_str(ex)}") + for user_uid in users_to_update: + try: if self.app and hasattr(self.app, "services") and self.app.services: - await self.app.services.user.save(s.user) - users_updated.append(s.user) + user = await self.app.services.user.get(uid=user_uid) + if user: + user["last_ping"] = now() + await self.app.services.user.save(user) except Exception as ex: logger.debug(f"Failed to update user availability: {safe_str(ex)}") - logger.info(f"Updated user availability for {len(users_updated)} online users.") + logger.info(f"Updated user availability for {len(users_to_update)} online users.") except Exception as ex: logger.warning(f"User availability service error: {safe_str(ex)}") try: @@ -130,26 +156,29 @@ class SocketService(BaseService): logger.warning(f"User not found for socket add: {user_uid}") return None s = self.Socket(ws, user) - async with self._lock: + username = safe_get(user, "username", "unknown") + nick = safe_get(user, "nick") or username + color = s.user_color + async with self._sockets_lock: self.sockets.add(s) - try: - s.user["last_ping"] = now() - await self.app.services.user.save(s.user) - except Exception as ex: - logger.debug(f"Failed to update last_ping: {safe_str(ex)}") - username = safe_get(s.user, "username", "unknown") - logger.info(f"Added socket for user {username}") is_first_connection = False - async with self._lock: + user_lock = await self._get_user_lock(user_uid) + async with user_lock: if user_uid not in self.users: self.users[user_uid] = set() is_first_connection = True elif len(self.users[user_uid]) == 0: is_first_connection = True self.users[user_uid].add(s) + try: + fresh_user = await self.app.services.user.get(uid=user_uid) + if fresh_user: + fresh_user["last_ping"] = now() + await self.app.services.user.save(fresh_user) + except Exception as ex: + logger.debug(f"Failed to update last_ping: {safe_str(ex)}") + logger.info(f"Added socket for user {username}") if is_first_connection: - nick = safe_get(s.user, "nick") or safe_get(s.user, "username", "") - color = s.user_color await self._broadcast_presence("arrived", user_uid, nick, color) return s except Exception as ex: @@ -160,15 +189,19 @@ class SocketService(BaseService): if not ws or not channel_uid or not user_uid: return False try: - async with self._lock: - existing_socket = None - for sock in self.sockets: - if sock and sock.ws == ws and sock.user_uid == user_uid: + user_lock = await self._get_user_lock(user_uid) + channel_lock = await self._get_channel_lock(channel_uid) + existing_socket = None + async with user_lock: + user_sockets = self.users.get(user_uid, set()) + for sock in user_sockets: + if sock and sock.ws == ws: existing_socket = sock break if not existing_socket: return False existing_socket.subscribed_channels.add(channel_uid) + async with channel_lock: if channel_uid not in self.subscriptions: self.subscriptions[channel_uid] = set() self.subscriptions[channel_uid].add(user_uid) @@ -182,7 +215,8 @@ class SocketService(BaseService): return 0 count = 0 try: - async with self._lock: + user_lock = await self._get_user_lock(user_uid) + async with user_lock: user_sockets = list(self.users.get(user_uid, [])) for s in user_sockets: if not s: @@ -218,51 +252,75 @@ class SocketService(BaseService): except Exception as ex: logger.warning(f"Broadcast db query failed: {safe_str(ex)}") if not user_uids_to_send: - async with self._lock: + channel_lock = await self._get_channel_lock(channel_uid) + async with channel_lock: if channel_uid in self.subscriptions: user_uids_to_send = set(self.subscriptions[channel_uid]) logger.debug(f"_broadcast: using {len(user_uids_to_send)} users from subscriptions for channel={channel_uid}") + send_tasks = [] for user_uid in user_uids_to_send: - try: - count = await self.send_to_user(user_uid, message) - sent += count - if count > 0: - logger.debug(f"_broadcast: sent to user={user_uid}, sockets={count}") - except Exception as ex: - logger.debug(f"Failed to send to user {user_uid}: {safe_str(ex)}") + send_tasks.append(self._send_to_user_safe(user_uid, message)) + if send_tasks: + results = await asyncio.gather(*send_tasks, return_exceptions=True) + for result in results: + if isinstance(result, int): + sent += result logger.info(f"_broadcast: completed channel={channel_uid}, total_users={len(user_uids_to_send)}, sent={sent}") return True except Exception as ex: logger.warning(f"Broadcast failed: {safe_str(ex)}") return False + async def _send_to_user_safe(self, user_uid, message): + try: + return await self.send_to_user(user_uid, message) + except Exception as ex: + logger.debug(f"Failed to send to user {user_uid}: {safe_str(ex)}") + return 0 + async def delete(self, ws): if not ws: return sockets_to_remove = [] - departures_to_broadcast = [] - async with self._lock: + async with self._sockets_lock: sockets_to_remove = [sock for sock in self.sockets if sock and sock.ws == ws] for s in sockets_to_remove: self.sockets.discard(s) - user_uid = s.user_uid - if user_uid and user_uid in self.users: + departures_to_broadcast = [] + channels_to_cleanup = set() + for s in sockets_to_remove: + user_uid = s.user_uid + if not user_uid: + continue + user_lock = await self._get_user_lock(user_uid) + is_last_connection = False + async with user_lock: + if user_uid in self.users: self.users[user_uid].discard(s) if len(self.users[user_uid]) == 0: del self.users[user_uid] - user_nick = None - try: - if s.user: - user_nick = safe_get(s.user, "nick") or safe_get(s.user, "username") - except Exception: - pass - if user_nick: - departures_to_broadcast.append((user_uid, user_nick, s.user_color)) - for channel_uid in list(s.subscribed_channels): + is_last_connection = True + if is_last_connection: + user_nick = None + try: + if s.user: + user_nick = safe_get(s.user, "nick") or safe_get(s.user, "username") + except Exception: + pass + if user_nick: + departures_to_broadcast.append((user_uid, user_nick, s.user_color)) + for channel_uid in list(s.subscribed_channels): + channels_to_cleanup.add((channel_uid, user_uid)) + for channel_uid, user_uid in channels_to_cleanup: + try: + channel_lock = await self._get_channel_lock(channel_uid) + async with channel_lock: if channel_uid in self.subscriptions: self.subscriptions[channel_uid].discard(user_uid) if len(self.subscriptions[channel_uid]) == 0: del self.subscriptions[channel_uid] + except Exception as ex: + logger.debug(f"Failed to cleanup channel subscription: {safe_str(ex)}") for s in sockets_to_remove: try: username = safe_get(s.user, "username", "unknown") if s.user else "unknown" @@ -292,27 +350,25 @@ class SocketService(BaseService): "timestamp": datetime.now().isoformat(), }, } - sent_count = 0 - async with self._lock: + async with self._sockets_lock: sockets_copy = list(self.sockets) + send_tasks = [] for s in sockets_copy: if not s or not s.is_connected: continue if s.user_uid == user_uid: continue - try: - if await s.send_json(message): - sent_count += 1 - except Exception as ex: - logger.debug(f"Failed to send presence to socket: {safe_str(ex)}") - logger.info(f"Broadcast presence '{event_type}' for {user_nick} to {sent_count} users") + send_tasks.append(s.send_json(message)) + if send_tasks: + results = await asyncio.gather(*send_tasks, return_exceptions=True) + sent_count = sum(1 for r in results if r is True) + logger.info(f"Broadcast presence '{event_type}' for {user_nick} to {sent_count} users") except Exception as ex: logger.warning(f"Broadcast presence failed: {safe_str(ex)}") async def get_connected_users(self): try: - async with self._lock: - return list(self.users.keys()) + return list(self.users.keys()) except Exception: return [] @@ -320,7 +376,8 @@ class SocketService(BaseService): if not user_uid: return 0 try: - async with self._lock: + user_lock = await self._get_user_lock(user_uid) + async with user_lock: return len(self.users.get(user_uid, [])) except Exception: return 0 @@ -329,7 +386,8 @@ class SocketService(BaseService): if not user_uid: return False try: - async with self._lock: + user_lock = await self._get_user_lock(user_uid) + async with user_lock: user_sockets = self.users.get(user_uid, set()) return any(s.is_connected for s in user_sockets if s) except Exception: diff --git a/src/snek/view/rpc.py b/src/snek/view/rpc.py index a105bdc..af63702 100644 --- a/src/snek/view/rpc.py +++ b/src/snek/view/rpc.py @@ -44,7 +44,7 @@ class RPCView(BaseView): self.services = safe_get(self.app, "services") if self.app else None self.ws = ws self.user_session = {} - self._scheduled = [] + self._scheduled_tasks = [] self._finalize_task = None self._is_closed = False @@ -904,20 +904,13 @@ class RPCView(BaseView): return try: seconds = max(0, min(float(seconds) if seconds else 0, 300)) - self._scheduled.append(call) await asyncio.sleep(seconds) - if self.services and self.user_uid: + if self.services and self.user_uid and not self._is_closed: await self.services.socket.send_to_user(self.user_uid, call) except asyncio.CancelledError: pass except Exception as ex: logger.warning(f"Schedule failed: {safe_str(ex)}") - finally: - try: - if call in self._scheduled: - self._scheduled.remove(call) - except (ValueError, Exception): - pass async def ping(self, call_id=None, *args): try: @@ -964,13 +957,12 @@ class RPCView(BaseView): async def get(self): ws = None rpc = None - scheduled = [] + scheduled_tasks = [] async def schedule(uid, seconds, call): if not uid or not call: return try: - scheduled.append(call) seconds = max(0, min(float(seconds) if seconds else 0, 300)) await asyncio.sleep(seconds) if self.services: @@ -979,12 +971,6 @@ class RPCView(BaseView): pass except Exception as ex: logger.debug(f"Schedule failed: {safe_str(ex)}") - finally: - try: - if call in scheduled: - scheduled.remove(call) - except (ValueError, Exception): - pass try: ws = web.WebSocketResponse(heartbeat=30.0, autoping=True) @@ -1018,17 +1004,19 @@ class RPCView(BaseView): try: app = getattr(self.request, "app", None) uptime_seconds = getattr(app, "uptime_seconds", 999) if app else 999 - if not scheduled and uptime_seconds < 15 and user_uid: + if not scheduled_tasks and uptime_seconds < 15 and user_uid: jitter = random.uniform(0, 3) - asyncio.create_task(schedule( + task1 = asyncio.create_task(schedule( user_uid, jitter, {"event": "refresh", "data": {"message": "Finishing deployment"}}, )) + scheduled_tasks.append(task1) uptime = getattr(app, "uptime", "") if app else "" - asyncio.create_task(schedule( + task2 = asyncio.create_task(schedule( user_uid, 15 + jitter, {"event": "deployed", "data": {"uptime": uptime}}, )) + scheduled_tasks.append(task2) except Exception as ex: logger.debug(f"Deployment schedule failed: {safe_str(ex)}") @@ -1072,6 +1060,23 @@ class RPCView(BaseView): finally: if rpc: rpc._is_closed = True + if rpc._finalize_task: + try: + rpc._finalize_task.cancel() + except Exception: + pass + for task in rpc._scheduled_tasks: + try: + if task and not task.done(): + task.cancel() + except Exception: + pass + for task in scheduled_tasks: + try: + if task and not task.done(): + task.cancel() + except Exception: + pass try: if self.services and ws: await self.services.socket.delete(ws)