refactor: enhance socket service with granular locks and async optimizations
refactor: rename scheduled list to tasks in rpc view
This commit is contained in:
parent
1d3444d665
commit
572d9610f2
@ -12,6 +12,14 @@
|
||||
|
||||
|
||||
|
||||
|
||||
## Version 1.12.0 - 2025-12-21
|
||||
|
||||
The socket service enhances concurrency and performance through improved locking and asynchronous handling. The RPC view renames the scheduled list to tasks for clearer API terminology.
|
||||
|
||||
**Changes:** 2 files, 243 lines
|
||||
**Languages:** Python (243 lines)
|
||||
|
||||
## Version 1.11.0 - 2025-12-19
|
||||
|
||||
Adds the ability to configure whether user registration is open or closed via system settings. Administrators can now toggle registration openness to control access to the registration form.
|
||||
|
||||
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "Snek"
|
||||
version = "1.11.0"
|
||||
version = "1.12.0"
|
||||
readme = "README.md"
|
||||
#license = { file = "LICENSE", content-type="text/markdown" }
|
||||
description = "Snek Chat Application by Molodetz"
|
||||
|
||||
@ -46,14 +46,14 @@ class SocketService(BaseService):
|
||||
self.subscribed_channels = set()
|
||||
|
||||
async def send_json(self, data):
|
||||
if data is None:
|
||||
return False
|
||||
async with self._lock:
|
||||
if not self.is_connected:
|
||||
return False
|
||||
if not self.ws:
|
||||
self.is_connected = False
|
||||
return False
|
||||
if data is None:
|
||||
return False
|
||||
async with self._lock:
|
||||
try:
|
||||
await self.ws.send_json(data)
|
||||
return True
|
||||
@ -66,17 +66,18 @@ class SocketService(BaseService):
|
||||
return False
|
||||
|
||||
async def close(self):
|
||||
async with self._lock:
|
||||
if not self.is_connected:
|
||||
return True
|
||||
async with self._lock:
|
||||
try:
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
except Exception as ex:
|
||||
logger.debug(f"Socket close failed: {safe_str(ex)}")
|
||||
finally:
|
||||
self.is_connected = False
|
||||
self.subscribed_channels.clear()
|
||||
try:
|
||||
if self.ws and not self.ws.closed:
|
||||
await asyncio.wait_for(self.ws.close(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug("Socket close timed out")
|
||||
except Exception as ex:
|
||||
logger.debug(f"Socket close failed: {safe_str(ex)}")
|
||||
return True
|
||||
|
||||
def __init__(self, app):
|
||||
@ -85,29 +86,54 @@ class SocketService(BaseService):
|
||||
self.users = {}
|
||||
self.subscriptions = {}
|
||||
self.last_update = str(datetime.now())
|
||||
self._lock = asyncio.Lock()
|
||||
self._sockets_lock = asyncio.Lock()
|
||||
self._user_locks = {}
|
||||
self._channel_locks = {}
|
||||
self._registry_lock = asyncio.Lock()
|
||||
|
||||
async def _get_user_lock(self, user_uid):
|
||||
async with self._registry_lock:
|
||||
if user_uid not in self._user_locks:
|
||||
self._user_locks[user_uid] = asyncio.Lock()
|
||||
return self._user_locks[user_uid]
|
||||
|
||||
async def _get_channel_lock(self, channel_uid):
|
||||
async with self._registry_lock:
|
||||
if channel_uid not in self._channel_locks:
|
||||
self._channel_locks[channel_uid] = asyncio.Lock()
|
||||
return self._channel_locks[channel_uid]
|
||||
|
||||
|
||||
async def user_availability_service(self):
|
||||
logger.info("User availability update service started.")
|
||||
while True:
|
||||
try:
|
||||
users_updated = []
|
||||
async with self._sockets_lock:
|
||||
sockets_copy = list(self.sockets)
|
||||
users_to_update = []
|
||||
seen_user_uids = set()
|
||||
for s in sockets_copy:
|
||||
try:
|
||||
if not s or not s.is_connected:
|
||||
continue
|
||||
if not s.user:
|
||||
if not s.user or not s.user_uid:
|
||||
continue
|
||||
if s.user in users_updated:
|
||||
if s.user_uid in seen_user_uids:
|
||||
continue
|
||||
s.user["last_ping"] = now()
|
||||
seen_user_uids.add(s.user_uid)
|
||||
users_to_update.append(s.user_uid)
|
||||
except Exception as ex:
|
||||
logger.debug(f"Failed to check user availability: {safe_str(ex)}")
|
||||
for user_uid in users_to_update:
|
||||
try:
|
||||
if self.app and hasattr(self.app, "services") and self.app.services:
|
||||
await self.app.services.user.save(s.user)
|
||||
users_updated.append(s.user)
|
||||
user = await self.app.services.user.get(uid=user_uid)
|
||||
if user:
|
||||
user["last_ping"] = now()
|
||||
await self.app.services.user.save(user)
|
||||
except Exception as ex:
|
||||
logger.debug(f"Failed to update user availability: {safe_str(ex)}")
|
||||
logger.info(f"Updated user availability for {len(users_updated)} online users.")
|
||||
logger.info(f"Updated user availability for {len(users_to_update)} online users.")
|
||||
except Exception as ex:
|
||||
logger.warning(f"User availability service error: {safe_str(ex)}")
|
||||
try:
|
||||
@ -130,26 +156,29 @@ class SocketService(BaseService):
|
||||
logger.warning(f"User not found for socket add: {user_uid}")
|
||||
return None
|
||||
s = self.Socket(ws, user)
|
||||
async with self._lock:
|
||||
username = safe_get(user, "username", "unknown")
|
||||
nick = safe_get(user, "nick") or username
|
||||
color = s.user_color
|
||||
async with self._sockets_lock:
|
||||
self.sockets.add(s)
|
||||
try:
|
||||
s.user["last_ping"] = now()
|
||||
await self.app.services.user.save(s.user)
|
||||
except Exception as ex:
|
||||
logger.debug(f"Failed to update last_ping: {safe_str(ex)}")
|
||||
username = safe_get(s.user, "username", "unknown")
|
||||
logger.info(f"Added socket for user {username}")
|
||||
is_first_connection = False
|
||||
async with self._lock:
|
||||
user_lock = await self._get_user_lock(user_uid)
|
||||
async with user_lock:
|
||||
if user_uid not in self.users:
|
||||
self.users[user_uid] = set()
|
||||
is_first_connection = True
|
||||
elif len(self.users[user_uid]) == 0:
|
||||
is_first_connection = True
|
||||
self.users[user_uid].add(s)
|
||||
try:
|
||||
fresh_user = await self.app.services.user.get(uid=user_uid)
|
||||
if fresh_user:
|
||||
fresh_user["last_ping"] = now()
|
||||
await self.app.services.user.save(fresh_user)
|
||||
except Exception as ex:
|
||||
logger.debug(f"Failed to update last_ping: {safe_str(ex)}")
|
||||
logger.info(f"Added socket for user {username}")
|
||||
if is_first_connection:
|
||||
nick = safe_get(s.user, "nick") or safe_get(s.user, "username", "")
|
||||
color = s.user_color
|
||||
await self._broadcast_presence("arrived", user_uid, nick, color)
|
||||
return s
|
||||
except Exception as ex:
|
||||
@ -160,15 +189,19 @@ class SocketService(BaseService):
|
||||
if not ws or not channel_uid or not user_uid:
|
||||
return False
|
||||
try:
|
||||
async with self._lock:
|
||||
user_lock = await self._get_user_lock(user_uid)
|
||||
channel_lock = await self._get_channel_lock(channel_uid)
|
||||
existing_socket = None
|
||||
for sock in self.sockets:
|
||||
if sock and sock.ws == ws and sock.user_uid == user_uid:
|
||||
async with user_lock:
|
||||
user_sockets = self.users.get(user_uid, set())
|
||||
for sock in user_sockets:
|
||||
if sock and sock.ws == ws:
|
||||
existing_socket = sock
|
||||
break
|
||||
if not existing_socket:
|
||||
return False
|
||||
existing_socket.subscribed_channels.add(channel_uid)
|
||||
async with channel_lock:
|
||||
if channel_uid not in self.subscriptions:
|
||||
self.subscriptions[channel_uid] = set()
|
||||
self.subscriptions[channel_uid].add(user_uid)
|
||||
@ -182,7 +215,8 @@ class SocketService(BaseService):
|
||||
return 0
|
||||
count = 0
|
||||
try:
|
||||
async with self._lock:
|
||||
user_lock = await self._get_user_lock(user_uid)
|
||||
async with user_lock:
|
||||
user_sockets = list(self.users.get(user_uid, []))
|
||||
for s in user_sockets:
|
||||
if not s:
|
||||
@ -218,38 +252,55 @@ class SocketService(BaseService):
|
||||
except Exception as ex:
|
||||
logger.warning(f"Broadcast db query failed: {safe_str(ex)}")
|
||||
if not user_uids_to_send:
|
||||
async with self._lock:
|
||||
channel_lock = await self._get_channel_lock(channel_uid)
|
||||
async with channel_lock:
|
||||
if channel_uid in self.subscriptions:
|
||||
user_uids_to_send = set(self.subscriptions[channel_uid])
|
||||
logger.debug(f"_broadcast: using {len(user_uids_to_send)} users from subscriptions for channel={channel_uid}")
|
||||
send_tasks = []
|
||||
for user_uid in user_uids_to_send:
|
||||
try:
|
||||
count = await self.send_to_user(user_uid, message)
|
||||
sent += count
|
||||
if count > 0:
|
||||
logger.debug(f"_broadcast: sent to user={user_uid}, sockets={count}")
|
||||
except Exception as ex:
|
||||
logger.debug(f"Failed to send to user {user_uid}: {safe_str(ex)}")
|
||||
send_tasks.append(self._send_to_user_safe(user_uid, message))
|
||||
if send_tasks:
|
||||
results = await asyncio.gather(*send_tasks, return_exceptions=True)
|
||||
for result in results:
|
||||
if isinstance(result, int):
|
||||
sent += result
|
||||
logger.info(f"_broadcast: completed channel={channel_uid}, total_users={len(user_uids_to_send)}, sent={sent}")
|
||||
return True
|
||||
except Exception as ex:
|
||||
logger.warning(f"Broadcast failed: {safe_str(ex)}")
|
||||
return False
|
||||
|
||||
async def _send_to_user_safe(self, user_uid, message):
|
||||
try:
|
||||
return await self.send_to_user(user_uid, message)
|
||||
except Exception as ex:
|
||||
logger.debug(f"Failed to send to user {user_uid}: {safe_str(ex)}")
|
||||
return 0
|
||||
|
||||
async def delete(self, ws):
|
||||
if not ws:
|
||||
return
|
||||
sockets_to_remove = []
|
||||
departures_to_broadcast = []
|
||||
async with self._lock:
|
||||
async with self._sockets_lock:
|
||||
sockets_to_remove = [sock for sock in self.sockets if sock and sock.ws == ws]
|
||||
for s in sockets_to_remove:
|
||||
self.sockets.discard(s)
|
||||
departures_to_broadcast = []
|
||||
channels_to_cleanup = set()
|
||||
for s in sockets_to_remove:
|
||||
user_uid = s.user_uid
|
||||
if user_uid and user_uid in self.users:
|
||||
if not user_uid:
|
||||
continue
|
||||
user_lock = await self._get_user_lock(user_uid)
|
||||
is_last_connection = False
|
||||
async with user_lock:
|
||||
if user_uid in self.users:
|
||||
self.users[user_uid].discard(s)
|
||||
if len(self.users[user_uid]) == 0:
|
||||
del self.users[user_uid]
|
||||
is_last_connection = True
|
||||
if is_last_connection:
|
||||
user_nick = None
|
||||
try:
|
||||
if s.user:
|
||||
@ -259,10 +310,17 @@ class SocketService(BaseService):
|
||||
if user_nick:
|
||||
departures_to_broadcast.append((user_uid, user_nick, s.user_color))
|
||||
for channel_uid in list(s.subscribed_channels):
|
||||
channels_to_cleanup.add((channel_uid, user_uid))
|
||||
for channel_uid, user_uid in channels_to_cleanup:
|
||||
try:
|
||||
channel_lock = await self._get_channel_lock(channel_uid)
|
||||
async with channel_lock:
|
||||
if channel_uid in self.subscriptions:
|
||||
self.subscriptions[channel_uid].discard(user_uid)
|
||||
if len(self.subscriptions[channel_uid]) == 0:
|
||||
del self.subscriptions[channel_uid]
|
||||
except Exception as ex:
|
||||
logger.debug(f"Failed to cleanup channel subscription: {safe_str(ex)}")
|
||||
for s in sockets_to_remove:
|
||||
try:
|
||||
username = safe_get(s.user, "username", "unknown") if s.user else "unknown"
|
||||
@ -292,26 +350,24 @@ class SocketService(BaseService):
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
},
|
||||
}
|
||||
sent_count = 0
|
||||
async with self._lock:
|
||||
async with self._sockets_lock:
|
||||
sockets_copy = list(self.sockets)
|
||||
send_tasks = []
|
||||
for s in sockets_copy:
|
||||
if not s or not s.is_connected:
|
||||
continue
|
||||
if s.user_uid == user_uid:
|
||||
continue
|
||||
try:
|
||||
if await s.send_json(message):
|
||||
sent_count += 1
|
||||
except Exception as ex:
|
||||
logger.debug(f"Failed to send presence to socket: {safe_str(ex)}")
|
||||
send_tasks.append(s.send_json(message))
|
||||
if send_tasks:
|
||||
results = await asyncio.gather(*send_tasks, return_exceptions=True)
|
||||
sent_count = sum(1 for r in results if r is True)
|
||||
logger.info(f"Broadcast presence '{event_type}' for {user_nick} to {sent_count} users")
|
||||
except Exception as ex:
|
||||
logger.warning(f"Broadcast presence failed: {safe_str(ex)}")
|
||||
|
||||
async def get_connected_users(self):
|
||||
try:
|
||||
async with self._lock:
|
||||
return list(self.users.keys())
|
||||
except Exception:
|
||||
return []
|
||||
@ -320,7 +376,8 @@ class SocketService(BaseService):
|
||||
if not user_uid:
|
||||
return 0
|
||||
try:
|
||||
async with self._lock:
|
||||
user_lock = await self._get_user_lock(user_uid)
|
||||
async with user_lock:
|
||||
return len(self.users.get(user_uid, []))
|
||||
except Exception:
|
||||
return 0
|
||||
@ -329,7 +386,8 @@ class SocketService(BaseService):
|
||||
if not user_uid:
|
||||
return False
|
||||
try:
|
||||
async with self._lock:
|
||||
user_lock = await self._get_user_lock(user_uid)
|
||||
async with user_lock:
|
||||
user_sockets = self.users.get(user_uid, set())
|
||||
return any(s.is_connected for s in user_sockets if s)
|
||||
except Exception:
|
||||
|
||||
@ -44,7 +44,7 @@ class RPCView(BaseView):
|
||||
self.services = safe_get(self.app, "services") if self.app else None
|
||||
self.ws = ws
|
||||
self.user_session = {}
|
||||
self._scheduled = []
|
||||
self._scheduled_tasks = []
|
||||
self._finalize_task = None
|
||||
self._is_closed = False
|
||||
|
||||
@ -904,20 +904,13 @@ class RPCView(BaseView):
|
||||
return
|
||||
try:
|
||||
seconds = max(0, min(float(seconds) if seconds else 0, 300))
|
||||
self._scheduled.append(call)
|
||||
await asyncio.sleep(seconds)
|
||||
if self.services and self.user_uid:
|
||||
if self.services and self.user_uid and not self._is_closed:
|
||||
await self.services.socket.send_to_user(self.user_uid, call)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as ex:
|
||||
logger.warning(f"Schedule failed: {safe_str(ex)}")
|
||||
finally:
|
||||
try:
|
||||
if call in self._scheduled:
|
||||
self._scheduled.remove(call)
|
||||
except (ValueError, Exception):
|
||||
pass
|
||||
|
||||
async def ping(self, call_id=None, *args):
|
||||
try:
|
||||
@ -964,13 +957,12 @@ class RPCView(BaseView):
|
||||
async def get(self):
|
||||
ws = None
|
||||
rpc = None
|
||||
scheduled = []
|
||||
scheduled_tasks = []
|
||||
|
||||
async def schedule(uid, seconds, call):
|
||||
if not uid or not call:
|
||||
return
|
||||
try:
|
||||
scheduled.append(call)
|
||||
seconds = max(0, min(float(seconds) if seconds else 0, 300))
|
||||
await asyncio.sleep(seconds)
|
||||
if self.services:
|
||||
@ -979,12 +971,6 @@ class RPCView(BaseView):
|
||||
pass
|
||||
except Exception as ex:
|
||||
logger.debug(f"Schedule failed: {safe_str(ex)}")
|
||||
finally:
|
||||
try:
|
||||
if call in scheduled:
|
||||
scheduled.remove(call)
|
||||
except (ValueError, Exception):
|
||||
pass
|
||||
|
||||
try:
|
||||
ws = web.WebSocketResponse(heartbeat=30.0, autoping=True)
|
||||
@ -1018,17 +1004,19 @@ class RPCView(BaseView):
|
||||
try:
|
||||
app = getattr(self.request, "app", None)
|
||||
uptime_seconds = getattr(app, "uptime_seconds", 999) if app else 999
|
||||
if not scheduled and uptime_seconds < 15 and user_uid:
|
||||
if not scheduled_tasks and uptime_seconds < 15 and user_uid:
|
||||
jitter = random.uniform(0, 3)
|
||||
asyncio.create_task(schedule(
|
||||
task1 = asyncio.create_task(schedule(
|
||||
user_uid, jitter,
|
||||
{"event": "refresh", "data": {"message": "Finishing deployment"}},
|
||||
))
|
||||
scheduled_tasks.append(task1)
|
||||
uptime = getattr(app, "uptime", "") if app else ""
|
||||
asyncio.create_task(schedule(
|
||||
task2 = asyncio.create_task(schedule(
|
||||
user_uid, 15 + jitter,
|
||||
{"event": "deployed", "data": {"uptime": uptime}},
|
||||
))
|
||||
scheduled_tasks.append(task2)
|
||||
except Exception as ex:
|
||||
logger.debug(f"Deployment schedule failed: {safe_str(ex)}")
|
||||
|
||||
@ -1072,6 +1060,23 @@ class RPCView(BaseView):
|
||||
finally:
|
||||
if rpc:
|
||||
rpc._is_closed = True
|
||||
if rpc._finalize_task:
|
||||
try:
|
||||
rpc._finalize_task.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
for task in rpc._scheduled_tasks:
|
||||
try:
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
for task in scheduled_tasks:
|
||||
try:
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self.services and ws:
|
||||
await self.services.socket.delete(ws)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user