refactor: enhance socket service with granular locks and async optimizations

refactor: rename scheduled list to tasks in rpc view
This commit is contained in:
retoor 2025-12-21 00:25:08 +01:00
parent 1d3444d665
commit 572d9610f2
4 changed files with 162 additions and 91 deletions

View File

@ -12,6 +12,14 @@
## Version 1.12.0 - 2025-12-21
The socket service now improves concurrency and performance through more granular locking and asynchronous handling. The RPC view renames its scheduled list to tasks for clearer API terminology.
**Changes:** 2 files, 243 lines
**Languages:** Python (243 lines)
## Version 1.11.0 - 2025-12-19 ## Version 1.11.0 - 2025-12-19
Adds the ability to configure whether user registration is open or closed via system settings. Administrators can now toggle registration openness to control access to the registration form. Adds the ability to configure whether user registration is open or closed via system settings. Administrators can now toggle registration openness to control access to the registration form.

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "Snek" name = "Snek"
version = "1.11.0" version = "1.12.0"
readme = "README.md" readme = "README.md"
#license = { file = "LICENSE", content-type="text/markdown" } #license = { file = "LICENSE", content-type="text/markdown" }
description = "Snek Chat Application by Molodetz" description = "Snek Chat Application by Molodetz"

View File

@ -46,14 +46,14 @@ class SocketService(BaseService):
self.subscribed_channels = set() self.subscribed_channels = set()
async def send_json(self, data): async def send_json(self, data):
if data is None:
return False
async with self._lock:
if not self.is_connected: if not self.is_connected:
return False return False
if not self.ws: if not self.ws:
self.is_connected = False self.is_connected = False
return False return False
if data is None:
return False
async with self._lock:
try: try:
await self.ws.send_json(data) await self.ws.send_json(data)
return True return True
@ -66,17 +66,18 @@ class SocketService(BaseService):
return False return False
async def close(self): async def close(self):
async with self._lock:
if not self.is_connected: if not self.is_connected:
return True return True
async with self._lock:
try:
if self.ws:
await self.ws.close()
except Exception as ex:
logger.debug(f"Socket close failed: {safe_str(ex)}")
finally:
self.is_connected = False self.is_connected = False
self.subscribed_channels.clear() self.subscribed_channels.clear()
try:
if self.ws and not self.ws.closed:
await asyncio.wait_for(self.ws.close(), timeout=5.0)
except asyncio.TimeoutError:
logger.debug("Socket close timed out")
except Exception as ex:
logger.debug(f"Socket close failed: {safe_str(ex)}")
return True return True
def __init__(self, app): def __init__(self, app):
@ -85,29 +86,54 @@ class SocketService(BaseService):
self.users = {} self.users = {}
self.subscriptions = {} self.subscriptions = {}
self.last_update = str(datetime.now()) self.last_update = str(datetime.now())
self._lock = asyncio.Lock() self._sockets_lock = asyncio.Lock()
self._user_locks = {}
self._channel_locks = {}
self._registry_lock = asyncio.Lock()
async def _get_user_lock(self, user_uid):
    """Return the ``asyncio.Lock`` dedicated to ``user_uid``.

    The lock is created on first request and stored in
    ``self._user_locks``; later calls with the same uid return that
    same object. Lookup and insertion both happen while
    ``self._registry_lock`` is held.
    """
    # NOTE(review): this method only ever adds entries to the registry;
    # confirm that locks for departed users are pruned somewhere else.
    async with self._registry_lock:
        try:
            return self._user_locks[user_uid]
        except KeyError:
            created = asyncio.Lock()
            self._user_locks[user_uid] = created
            return created
async def _get_channel_lock(self, channel_uid):
    """Return the ``asyncio.Lock`` dedicated to ``channel_uid``.

    A new lock is created and stored in ``self._channel_locks`` the
    first time a channel uid is seen; every later call for that uid
    returns the stored lock. The registry access is performed while
    ``self._registry_lock`` is held.
    """
    # NOTE(review): entries are only added here, never removed; verify
    # that unused channel locks are cleaned up elsewhere.
    async with self._registry_lock:
        registry = self._channel_locks
        try:
            return registry[channel_uid]
        except KeyError:
            fresh = asyncio.Lock()
            registry[channel_uid] = fresh
            return fresh
async def user_availability_service(self): async def user_availability_service(self):
logger.info("User availability update service started.") logger.info("User availability update service started.")
while True: while True:
try: try:
users_updated = [] async with self._sockets_lock:
sockets_copy = list(self.sockets) sockets_copy = list(self.sockets)
users_to_update = []
seen_user_uids = set()
for s in sockets_copy: for s in sockets_copy:
try: try:
if not s or not s.is_connected: if not s or not s.is_connected:
continue continue
if not s.user: if not s.user or not s.user_uid:
continue continue
if s.user in users_updated: if s.user_uid in seen_user_uids:
continue continue
s.user["last_ping"] = now() seen_user_uids.add(s.user_uid)
users_to_update.append(s.user_uid)
except Exception as ex:
logger.debug(f"Failed to check user availability: {safe_str(ex)}")
for user_uid in users_to_update:
try:
if self.app and hasattr(self.app, "services") and self.app.services: if self.app and hasattr(self.app, "services") and self.app.services:
await self.app.services.user.save(s.user) user = await self.app.services.user.get(uid=user_uid)
users_updated.append(s.user) if user:
user["last_ping"] = now()
await self.app.services.user.save(user)
except Exception as ex: except Exception as ex:
logger.debug(f"Failed to update user availability: {safe_str(ex)}") logger.debug(f"Failed to update user availability: {safe_str(ex)}")
logger.info(f"Updated user availability for {len(users_updated)} online users.") logger.info(f"Updated user availability for {len(users_to_update)} online users.")
except Exception as ex: except Exception as ex:
logger.warning(f"User availability service error: {safe_str(ex)}") logger.warning(f"User availability service error: {safe_str(ex)}")
try: try:
@ -130,26 +156,29 @@ class SocketService(BaseService):
logger.warning(f"User not found for socket add: {user_uid}") logger.warning(f"User not found for socket add: {user_uid}")
return None return None
s = self.Socket(ws, user) s = self.Socket(ws, user)
async with self._lock: username = safe_get(user, "username", "unknown")
nick = safe_get(user, "nick") or username
color = s.user_color
async with self._sockets_lock:
self.sockets.add(s) self.sockets.add(s)
try:
s.user["last_ping"] = now()
await self.app.services.user.save(s.user)
except Exception as ex:
logger.debug(f"Failed to update last_ping: {safe_str(ex)}")
username = safe_get(s.user, "username", "unknown")
logger.info(f"Added socket for user {username}")
is_first_connection = False is_first_connection = False
async with self._lock: user_lock = await self._get_user_lock(user_uid)
async with user_lock:
if user_uid not in self.users: if user_uid not in self.users:
self.users[user_uid] = set() self.users[user_uid] = set()
is_first_connection = True is_first_connection = True
elif len(self.users[user_uid]) == 0: elif len(self.users[user_uid]) == 0:
is_first_connection = True is_first_connection = True
self.users[user_uid].add(s) self.users[user_uid].add(s)
try:
fresh_user = await self.app.services.user.get(uid=user_uid)
if fresh_user:
fresh_user["last_ping"] = now()
await self.app.services.user.save(fresh_user)
except Exception as ex:
logger.debug(f"Failed to update last_ping: {safe_str(ex)}")
logger.info(f"Added socket for user {username}")
if is_first_connection: if is_first_connection:
nick = safe_get(s.user, "nick") or safe_get(s.user, "username", "")
color = s.user_color
await self._broadcast_presence("arrived", user_uid, nick, color) await self._broadcast_presence("arrived", user_uid, nick, color)
return s return s
except Exception as ex: except Exception as ex:
@ -160,15 +189,19 @@ class SocketService(BaseService):
if not ws or not channel_uid or not user_uid: if not ws or not channel_uid or not user_uid:
return False return False
try: try:
async with self._lock: user_lock = await self._get_user_lock(user_uid)
channel_lock = await self._get_channel_lock(channel_uid)
existing_socket = None existing_socket = None
for sock in self.sockets: async with user_lock:
if sock and sock.ws == ws and sock.user_uid == user_uid: user_sockets = self.users.get(user_uid, set())
for sock in user_sockets:
if sock and sock.ws == ws:
existing_socket = sock existing_socket = sock
break break
if not existing_socket: if not existing_socket:
return False return False
existing_socket.subscribed_channels.add(channel_uid) existing_socket.subscribed_channels.add(channel_uid)
async with channel_lock:
if channel_uid not in self.subscriptions: if channel_uid not in self.subscriptions:
self.subscriptions[channel_uid] = set() self.subscriptions[channel_uid] = set()
self.subscriptions[channel_uid].add(user_uid) self.subscriptions[channel_uid].add(user_uid)
@ -182,7 +215,8 @@ class SocketService(BaseService):
return 0 return 0
count = 0 count = 0
try: try:
async with self._lock: user_lock = await self._get_user_lock(user_uid)
async with user_lock:
user_sockets = list(self.users.get(user_uid, [])) user_sockets = list(self.users.get(user_uid, []))
for s in user_sockets: for s in user_sockets:
if not s: if not s:
@ -218,38 +252,55 @@ class SocketService(BaseService):
except Exception as ex: except Exception as ex:
logger.warning(f"Broadcast db query failed: {safe_str(ex)}") logger.warning(f"Broadcast db query failed: {safe_str(ex)}")
if not user_uids_to_send: if not user_uids_to_send:
async with self._lock: channel_lock = await self._get_channel_lock(channel_uid)
async with channel_lock:
if channel_uid in self.subscriptions: if channel_uid in self.subscriptions:
user_uids_to_send = set(self.subscriptions[channel_uid]) user_uids_to_send = set(self.subscriptions[channel_uid])
logger.debug(f"_broadcast: using {len(user_uids_to_send)} users from subscriptions for channel={channel_uid}") logger.debug(f"_broadcast: using {len(user_uids_to_send)} users from subscriptions for channel={channel_uid}")
send_tasks = []
for user_uid in user_uids_to_send: for user_uid in user_uids_to_send:
try: send_tasks.append(self._send_to_user_safe(user_uid, message))
count = await self.send_to_user(user_uid, message) if send_tasks:
sent += count results = await asyncio.gather(*send_tasks, return_exceptions=True)
if count > 0: for result in results:
logger.debug(f"_broadcast: sent to user={user_uid}, sockets={count}") if isinstance(result, int):
except Exception as ex: sent += result
logger.debug(f"Failed to send to user {user_uid}: {safe_str(ex)}")
logger.info(f"_broadcast: completed channel={channel_uid}, total_users={len(user_uids_to_send)}, sent={sent}") logger.info(f"_broadcast: completed channel={channel_uid}, total_users={len(user_uids_to_send)}, sent={sent}")
return True return True
except Exception as ex: except Exception as ex:
logger.warning(f"Broadcast failed: {safe_str(ex)}") logger.warning(f"Broadcast failed: {safe_str(ex)}")
return False return False
async def _send_to_user_safe(self, user_uid, message):
    """Call ``send_to_user`` and absorb any ``Exception`` it raises.

    Returns whatever ``self.send_to_user(user_uid, message)`` returns.
    If that call raises an ``Exception``, the failure is logged at
    debug level and 0 is returned instead.
    """
    delivered = 0
    try:
        delivered = await self.send_to_user(user_uid, message)
    except Exception as ex:
        logger.debug(f"Failed to send to user {user_uid}: {safe_str(ex)}")
    return delivered
async def delete(self, ws): async def delete(self, ws):
if not ws: if not ws:
return return
sockets_to_remove = [] sockets_to_remove = []
departures_to_broadcast = [] async with self._sockets_lock:
async with self._lock:
sockets_to_remove = [sock for sock in self.sockets if sock and sock.ws == ws] sockets_to_remove = [sock for sock in self.sockets if sock and sock.ws == ws]
for s in sockets_to_remove: for s in sockets_to_remove:
self.sockets.discard(s) self.sockets.discard(s)
departures_to_broadcast = []
channels_to_cleanup = set()
for s in sockets_to_remove:
user_uid = s.user_uid user_uid = s.user_uid
if user_uid and user_uid in self.users: if not user_uid:
continue
user_lock = await self._get_user_lock(user_uid)
is_last_connection = False
async with user_lock:
if user_uid in self.users:
self.users[user_uid].discard(s) self.users[user_uid].discard(s)
if len(self.users[user_uid]) == 0: if len(self.users[user_uid]) == 0:
del self.users[user_uid] del self.users[user_uid]
is_last_connection = True
if is_last_connection:
user_nick = None user_nick = None
try: try:
if s.user: if s.user:
@ -259,10 +310,17 @@ class SocketService(BaseService):
if user_nick: if user_nick:
departures_to_broadcast.append((user_uid, user_nick, s.user_color)) departures_to_broadcast.append((user_uid, user_nick, s.user_color))
for channel_uid in list(s.subscribed_channels): for channel_uid in list(s.subscribed_channels):
channels_to_cleanup.add((channel_uid, user_uid))
for channel_uid, user_uid in channels_to_cleanup:
try:
channel_lock = await self._get_channel_lock(channel_uid)
async with channel_lock:
if channel_uid in self.subscriptions: if channel_uid in self.subscriptions:
self.subscriptions[channel_uid].discard(user_uid) self.subscriptions[channel_uid].discard(user_uid)
if len(self.subscriptions[channel_uid]) == 0: if len(self.subscriptions[channel_uid]) == 0:
del self.subscriptions[channel_uid] del self.subscriptions[channel_uid]
except Exception as ex:
logger.debug(f"Failed to cleanup channel subscription: {safe_str(ex)}")
for s in sockets_to_remove: for s in sockets_to_remove:
try: try:
username = safe_get(s.user, "username", "unknown") if s.user else "unknown" username = safe_get(s.user, "username", "unknown") if s.user else "unknown"
@ -292,26 +350,24 @@ class SocketService(BaseService):
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
}, },
} }
sent_count = 0 async with self._sockets_lock:
async with self._lock:
sockets_copy = list(self.sockets) sockets_copy = list(self.sockets)
send_tasks = []
for s in sockets_copy: for s in sockets_copy:
if not s or not s.is_connected: if not s or not s.is_connected:
continue continue
if s.user_uid == user_uid: if s.user_uid == user_uid:
continue continue
try: send_tasks.append(s.send_json(message))
if await s.send_json(message): if send_tasks:
sent_count += 1 results = await asyncio.gather(*send_tasks, return_exceptions=True)
except Exception as ex: sent_count = sum(1 for r in results if r is True)
logger.debug(f"Failed to send presence to socket: {safe_str(ex)}")
logger.info(f"Broadcast presence '{event_type}' for {user_nick} to {sent_count} users") logger.info(f"Broadcast presence '{event_type}' for {user_nick} to {sent_count} users")
except Exception as ex: except Exception as ex:
logger.warning(f"Broadcast presence failed: {safe_str(ex)}") logger.warning(f"Broadcast presence failed: {safe_str(ex)}")
async def get_connected_users(self): async def get_connected_users(self):
try: try:
async with self._lock:
return list(self.users.keys()) return list(self.users.keys())
except Exception: except Exception:
return [] return []
@ -320,7 +376,8 @@ class SocketService(BaseService):
if not user_uid: if not user_uid:
return 0 return 0
try: try:
async with self._lock: user_lock = await self._get_user_lock(user_uid)
async with user_lock:
return len(self.users.get(user_uid, [])) return len(self.users.get(user_uid, []))
except Exception: except Exception:
return 0 return 0
@ -329,7 +386,8 @@ class SocketService(BaseService):
if not user_uid: if not user_uid:
return False return False
try: try:
async with self._lock: user_lock = await self._get_user_lock(user_uid)
async with user_lock:
user_sockets = self.users.get(user_uid, set()) user_sockets = self.users.get(user_uid, set())
return any(s.is_connected for s in user_sockets if s) return any(s.is_connected for s in user_sockets if s)
except Exception: except Exception:

View File

@ -44,7 +44,7 @@ class RPCView(BaseView):
self.services = safe_get(self.app, "services") if self.app else None self.services = safe_get(self.app, "services") if self.app else None
self.ws = ws self.ws = ws
self.user_session = {} self.user_session = {}
self._scheduled = [] self._scheduled_tasks = []
self._finalize_task = None self._finalize_task = None
self._is_closed = False self._is_closed = False
@ -904,20 +904,13 @@ class RPCView(BaseView):
return return
try: try:
seconds = max(0, min(float(seconds) if seconds else 0, 300)) seconds = max(0, min(float(seconds) if seconds else 0, 300))
self._scheduled.append(call)
await asyncio.sleep(seconds) await asyncio.sleep(seconds)
if self.services and self.user_uid: if self.services and self.user_uid and not self._is_closed:
await self.services.socket.send_to_user(self.user_uid, call) await self.services.socket.send_to_user(self.user_uid, call)
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except Exception as ex: except Exception as ex:
logger.warning(f"Schedule failed: {safe_str(ex)}") logger.warning(f"Schedule failed: {safe_str(ex)}")
finally:
try:
if call in self._scheduled:
self._scheduled.remove(call)
except (ValueError, Exception):
pass
async def ping(self, call_id=None, *args): async def ping(self, call_id=None, *args):
try: try:
@ -964,13 +957,12 @@ class RPCView(BaseView):
async def get(self): async def get(self):
ws = None ws = None
rpc = None rpc = None
scheduled = [] scheduled_tasks = []
async def schedule(uid, seconds, call): async def schedule(uid, seconds, call):
if not uid or not call: if not uid or not call:
return return
try: try:
scheduled.append(call)
seconds = max(0, min(float(seconds) if seconds else 0, 300)) seconds = max(0, min(float(seconds) if seconds else 0, 300))
await asyncio.sleep(seconds) await asyncio.sleep(seconds)
if self.services: if self.services:
@ -979,12 +971,6 @@ class RPCView(BaseView):
pass pass
except Exception as ex: except Exception as ex:
logger.debug(f"Schedule failed: {safe_str(ex)}") logger.debug(f"Schedule failed: {safe_str(ex)}")
finally:
try:
if call in scheduled:
scheduled.remove(call)
except (ValueError, Exception):
pass
try: try:
ws = web.WebSocketResponse(heartbeat=30.0, autoping=True) ws = web.WebSocketResponse(heartbeat=30.0, autoping=True)
@ -1018,17 +1004,19 @@ class RPCView(BaseView):
try: try:
app = getattr(self.request, "app", None) app = getattr(self.request, "app", None)
uptime_seconds = getattr(app, "uptime_seconds", 999) if app else 999 uptime_seconds = getattr(app, "uptime_seconds", 999) if app else 999
if not scheduled and uptime_seconds < 15 and user_uid: if not scheduled_tasks and uptime_seconds < 15 and user_uid:
jitter = random.uniform(0, 3) jitter = random.uniform(0, 3)
asyncio.create_task(schedule( task1 = asyncio.create_task(schedule(
user_uid, jitter, user_uid, jitter,
{"event": "refresh", "data": {"message": "Finishing deployment"}}, {"event": "refresh", "data": {"message": "Finishing deployment"}},
)) ))
scheduled_tasks.append(task1)
uptime = getattr(app, "uptime", "") if app else "" uptime = getattr(app, "uptime", "") if app else ""
asyncio.create_task(schedule( task2 = asyncio.create_task(schedule(
user_uid, 15 + jitter, user_uid, 15 + jitter,
{"event": "deployed", "data": {"uptime": uptime}}, {"event": "deployed", "data": {"uptime": uptime}},
)) ))
scheduled_tasks.append(task2)
except Exception as ex: except Exception as ex:
logger.debug(f"Deployment schedule failed: {safe_str(ex)}") logger.debug(f"Deployment schedule failed: {safe_str(ex)}")
@ -1072,6 +1060,23 @@ class RPCView(BaseView):
finally: finally:
if rpc: if rpc:
rpc._is_closed = True rpc._is_closed = True
if rpc._finalize_task:
try:
rpc._finalize_task.cancel()
except Exception:
pass
for task in rpc._scheduled_tasks:
try:
if task and not task.done():
task.cancel()
except Exception:
pass
for task in scheduled_tasks:
try:
if task and not task.done():
task.cancel()
except Exception:
pass
try: try:
if self.services and ws: if self.services and ws:
await self.services.socket.delete(ws) await self.services.socket.delete(ws)