# retoor <retoor@molodetz.nl>
import asyncio
import logging
from datetime import datetime
from snek.model.user import UserModel
from snek.system.service import BaseService
from snek.system.model import now
# Module-level logger, named after this module, used by every helper and service below.
logger = logging.getLogger(__name__)
def safe_get(obj, key, default=None):
    """Read ``key`` from ``obj`` without raising for missing data.

    Lookup strategy, in order:

    * ``None`` -> ``default``;
    * ``dict`` -> ``obj.get(key, default)`` (a stored ``None`` is returned as-is);
    * objects exposing both ``fields`` and ``__getitem__`` -> ``obj[key]``,
      with a ``None`` value replaced by ``default``;
    * anything else -> ``getattr(obj, key, default)``.

    ``KeyError``, ``TypeError`` and ``AttributeError`` raised during the lookup
    are swallowed and ``default`` is returned instead.
    """
    if obj is None:
        return default
    try:
        if isinstance(obj, dict):
            return obj.get(key, default)
        is_item_model = hasattr(obj, "fields") and hasattr(obj, "__getitem__")
        if not is_item_model:
            return getattr(obj, key, default)
        value = obj[key]
    except (KeyError, TypeError, AttributeError):
        return default
    # Only the item-style branch reaches this point.
    return default if value is None else value
def safe_str(obj):
    """Return ``str(obj)``, or an empty string for ``None`` or a failing ``str()``."""
    text = ""
    if obj is not None:
        try:
            text = str(obj)
        except Exception:
            # A broken __str__ must never break the caller (used mainly in log calls).
            text = ""
    return text
class SocketService(BaseService):
    """Registry of live websocket connections with per-user and per-channel delivery.

    In-memory state kept by the service:

    * ``sockets``       - set of every registered ``Socket`` wrapper;
    * ``users``         - ``user_uid -> set[Socket]``;
    * ``subscriptions`` - ``channel_uid -> set[user_uid]``.

    ``users`` and ``subscriptions`` are mutated under lazily created per-user
    and per-channel ``asyncio.Lock`` objects; the lock registries themselves
    are guarded by ``_registry_lock``.
    """

    class Socket:
        """A single websocket together with the user it was opened for."""

        def __init__(self, ws, user: UserModel):
            self.ws = ws
            self.is_connected = True
            self.user = user
            # Cached so later code does not need to touch the user object.
            self.user_uid = safe_get(user, "uid") if user else None
            self.user_color = safe_get(user, "color") if user else None
            # Serialises send_json() and close() on this socket.
            self._lock = asyncio.Lock()
            self.subscribed_channels = set()

        async def send_json(self, data):
            """Send ``data`` over the websocket.

            Returns ``True`` on success and ``False`` when ``data`` is ``None``,
            the socket is already marked disconnected, there is no websocket,
            or the send raised. Any send failure marks the socket disconnected.
            """
            if data is None:
                return False
            async with self._lock:
                if not self.is_connected:
                    return False
                if not self.ws:
                    self.is_connected = False
                    return False
                try:
                    await self.ws.send_json(data)
                except ConnectionResetError:
                    self.is_connected = False
                    logger.debug("Connection reset during send_json")
                except Exception as ex:
                    self.is_connected = False
                    logger.debug("send_json failed: %s", safe_str(ex))
                else:
                    return True
            return False

        async def close(self):
            """Mark the socket disconnected and close the websocket (5 s timeout).

            Always returns ``True``. A socket that is already marked
            disconnected (for example by a failed ``send_json``) is left
            untouched: its websocket is not closed again here.
            """
            async with self._lock:
                if not self.is_connected:
                    return True
                self.is_connected = False
                self.subscribed_channels.clear()
                try:
                    if self.ws and not self.ws.closed:
                        await asyncio.wait_for(self.ws.close(), timeout=5.0)
                except asyncio.TimeoutError:
                    logger.debug("Socket close timed out")
                except Exception as ex:
                    logger.debug("Socket close failed: %s", safe_str(ex))
            return True

    def __init__(self, app):
        super().__init__(app)
        self.sockets = set()
        self.users = {}
        self.subscriptions = {}
        self.last_update = str(datetime.now())
        self._sockets_lock = asyncio.Lock()
        self._user_locks = {}
        self._channel_locks = {}
        self._registry_lock = asyncio.Lock()

    async def _lock_for(self, registry, key):
        """Return the lock stored under ``key`` in ``registry``, creating it if missing."""
        async with self._registry_lock:
            lock = registry.get(key)
            if lock is None:
                lock = registry[key] = asyncio.Lock()
            return lock

    async def _get_user_lock(self, user_uid):
        """Return the (lazily created) lock guarding ``self.users[user_uid]``."""
        return await self._lock_for(self._user_locks, user_uid)

    async def _get_channel_lock(self, channel_uid):
        """Return the (lazily created) lock guarding ``self.subscriptions[channel_uid]``."""
        return await self._lock_for(self._channel_locks, channel_uid)

    def _services_available(self):
        """Tell whether ``self.app`` exists and exposes a truthy ``services`` attribute."""
        return bool(self.app and hasattr(self.app, "services") and self.app.services)

    async def _touch_last_ping(self, user_uid):
        """Load the user, stamp ``last_ping`` with ``now()`` and save it.

        Does nothing when the user cannot be loaded. Exceptions propagate so
        each caller can log them with its own message.
        """
        user = await self.app.services.user.get(uid=user_uid)
        if user:
            user["last_ping"] = now()
            await self.app.services.user.save(user)

    async def user_availability_service(self):
        """Refresh ``last_ping`` of every connected user once per 60 seconds.

        Runs until the task is cancelled while sleeping; errors inside one
        round are logged and the loop continues.
        """
        logger.info("User availability update service started.")
        while True:
            try:
                async with self._sockets_lock:
                    sockets_copy = list(self.sockets)
                users_to_update = []
                seen_user_uids = set()
                for s in sockets_copy:
                    try:
                        if not s or not s.is_connected:
                            continue
                        if not s.user or not s.user_uid:
                            continue
                        if s.user_uid in seen_user_uids:
                            continue
                        seen_user_uids.add(s.user_uid)
                        users_to_update.append(s.user_uid)
                    except Exception as ex:
                        logger.debug("Failed to check user availability: %s", safe_str(ex))
                for user_uid in users_to_update:
                    try:
                        if self._services_available():
                            await self._touch_last_ping(user_uid)
                    except Exception as ex:
                        logger.debug("Failed to update user availability: %s", safe_str(ex))
                logger.info(
                    "Updated user availability for %s online users.",
                    len(users_to_update),
                )
            except Exception as ex:
                logger.warning("User availability service error: %s", safe_str(ex))
            try:
                await asyncio.sleep(60)
            except asyncio.CancelledError:
                logger.info("User availability service cancelled")
                break

    async def add(self, ws, user_uid):
        """Register ``ws`` for ``user_uid`` and return the new ``Socket``.

        Returns ``None`` when an argument is falsy, services are unavailable,
        the user cannot be loaded, or anything raises. When this is the user's
        first live socket an ``"arrived"`` presence event is broadcast.
        """
        if not ws or not user_uid:
            return None
        try:
            if not self._services_available():
                logger.warning("Services not available for socket add")
                return None
            user = await self.app.services.user.get(uid=user_uid)
            if not user:
                logger.warning("User not found for socket add: %s", user_uid)
                return None
            s = self.Socket(ws, user)
            username = safe_get(user, "username", "unknown")
            nick = safe_get(user, "nick") or username
            color = s.user_color
            async with self._sockets_lock:
                self.sockets.add(s)
            user_lock = await self._get_user_lock(user_uid)
            async with user_lock:
                user_sockets = self.users.setdefault(user_uid, set())
                # A missing or empty socket set both mean "first connection".
                is_first_connection = not user_sockets
                user_sockets.add(s)
            try:
                await self._touch_last_ping(user_uid)
            except Exception as ex:
                logger.debug("Failed to update last_ping: %s", safe_str(ex))
            logger.info("Added socket for user %s", username)
            if is_first_connection:
                await self._broadcast_presence("arrived", user_uid, nick, color)
            return s
        except Exception as ex:
            logger.warning("Failed to add socket: %s", safe_str(ex))
            return None

    async def subscribe(self, ws, channel_uid, user_uid):
        """Subscribe the user's socket wrapping ``ws`` to ``channel_uid``.

        Returns ``False`` when an argument is falsy, no registered socket of
        the user wraps ``ws``, or anything raises; ``True`` otherwise.
        """
        if not ws or not channel_uid or not user_uid:
            return False
        try:
            user_lock = await self._get_user_lock(user_uid)
            channel_lock = await self._get_channel_lock(channel_uid)
            async with user_lock:
                existing_socket = next(
                    (
                        sock
                        for sock in self.users.get(user_uid, ())
                        if sock and sock.ws == ws
                    ),
                    None,
                )
                if not existing_socket:
                    return False
                existing_socket.subscribed_channels.add(channel_uid)
            async with channel_lock:
                self.subscriptions.setdefault(channel_uid, set()).add(user_uid)
            return True
        except Exception as ex:
            logger.warning("Failed to subscribe: %s", safe_str(ex))
            return False

    async def send_to_user(self, user_uid, message):
        """Send ``message`` to every socket of ``user_uid``.

        Returns the number of sockets the message was successfully sent to.
        """
        if not user_uid or message is None:
            return 0
        count = 0
        try:
            # Users without sockets need no lock: skipping it avoids adding a
            # never-evicted entry to the lock registry for every offline
            # recipient of a broadcast.
            if user_uid not in self.users:
                return 0
            user_lock = await self._get_user_lock(user_uid)
            async with user_lock:
                user_sockets = list(self.users.get(user_uid, []))
            for s in user_sockets:
                if not s:
                    continue
                try:
                    if await s.send_json(message):
                        count += 1
                except Exception as ex:
                    logger.debug("Failed to send to user socket: %s", safe_str(ex))
        except Exception as ex:
            logger.warning("send_to_user failed: %s", safe_str(ex))
        return count

    async def broadcast(self, channel_uid, message):
        """Validate the arguments and broadcast ``message`` to ``channel_uid``.

        Returns ``False`` for invalid arguments, otherwise the result of
        ``_broadcast``.
        """
        if not channel_uid or message is None:
            logger.debug(
                "broadcast: invalid params channel_uid=%s, message=%s",
                channel_uid,
                message is not None,
            )
            return False
        logger.debug("broadcast: starting for channel=%s", channel_uid)
        return await self._broadcast(channel_uid, message)

    async def _broadcast(self, channel_uid, message):
        """Deliver ``message`` to all members of ``channel_uid``.

        Recipients come from ``self.services.channel_member.get_user_uids``;
        when that yields nobody (or fails) the in-memory ``subscriptions``
        entry is used instead. Returns ``True`` once delivery was attempted,
        ``False`` for invalid arguments or an unexpected error.
        """
        if not channel_uid or message is None:
            return False
        sent = 0
        user_uids_to_send = set()
        try:
            if self.services:
                try:
                    async for user_uid in self.services.channel_member.get_user_uids(channel_uid):
                        if user_uid:
                            user_uids_to_send.add(user_uid)
                    logger.debug(
                        "_broadcast: found %s users from db for channel=%s",
                        len(user_uids_to_send),
                        channel_uid,
                    )
                except Exception as ex:
                    logger.warning("Broadcast db query failed: %s", safe_str(ex))
            if not user_uids_to_send:
                channel_lock = await self._get_channel_lock(channel_uid)
                async with channel_lock:
                    if channel_uid in self.subscriptions:
                        user_uids_to_send = set(self.subscriptions[channel_uid])
                logger.debug(
                    "_broadcast: using %s users from subscriptions for channel=%s",
                    len(user_uids_to_send),
                    channel_uid,
                )
            send_tasks = [
                self._send_to_user_safe(user_uid, message)
                for user_uid in user_uids_to_send
            ]
            if send_tasks:
                results = await asyncio.gather(*send_tasks, return_exceptions=True)
                # Exceptions returned by gather() are not ints and are skipped.
                sent = sum(result for result in results if isinstance(result, int))
            logger.info(
                "_broadcast: completed channel=%s, total_users=%s, sent=%s",
                channel_uid,
                len(user_uids_to_send),
                sent,
            )
            return True
        except Exception as ex:
            logger.warning("Broadcast failed: %s", safe_str(ex))
            return False

    async def _send_to_user_safe(self, user_uid, message):
        """Call ``send_to_user`` and turn any exception into a count of ``0``."""
        try:
            return await self.send_to_user(user_uid, message)
        except Exception as ex:
            logger.debug("Failed to send to user %s: %s", user_uid, safe_str(ex))
            return 0

    async def delete(self, ws):
        """Unregister every socket wrapping ``ws`` and close it.

        For a user whose last socket is removed, their subscriptions recorded
        on that socket are dropped and a ``"departed"`` presence event is
        broadcast (only when a nick or username can be resolved).
        """
        if not ws:
            return
        async with self._sockets_lock:
            sockets_to_remove = [sock for sock in self.sockets if sock and sock.ws == ws]
            self.sockets.difference_update(sockets_to_remove)
        departures_to_broadcast = []
        channels_to_cleanup = set()
        for s in sockets_to_remove:
            user_uid = s.user_uid
            if not user_uid:
                continue
            user_lock = await self._get_user_lock(user_uid)
            is_last_connection = False
            async with user_lock:
                if user_uid in self.users:
                    self.users[user_uid].discard(s)
                    if not self.users[user_uid]:
                        del self.users[user_uid]
                        is_last_connection = True
            if is_last_connection:
                user_nick = None
                try:
                    if s.user:
                        user_nick = safe_get(s.user, "nick") or safe_get(s.user, "username")
                except Exception as ex:
                    # Best effort: without a nick the departure is simply not announced.
                    logger.debug("Failed to resolve nick for departure: %s", safe_str(ex))
                if user_nick:
                    departures_to_broadcast.append((user_uid, user_nick, s.user_color))
                # NOTE(review): the original nesting was lost in extraction;
                # cleanup is assumed to run only for the user's last
                # connection - confirm against the upstream file.
                for channel_uid in list(s.subscribed_channels):
                    channels_to_cleanup.add((channel_uid, user_uid))
        for channel_uid, user_uid in channels_to_cleanup:
            try:
                channel_lock = await self._get_channel_lock(channel_uid)
                async with channel_lock:
                    if channel_uid in self.subscriptions:
                        self.subscriptions[channel_uid].discard(user_uid)
                        if not self.subscriptions[channel_uid]:
                            del self.subscriptions[channel_uid]
            except Exception as ex:
                logger.debug("Failed to cleanup channel subscription: %s", safe_str(ex))
        for s in sockets_to_remove:
            try:
                username = safe_get(s.user, "username", "unknown") if s.user else "unknown"
                logger.info("Removed socket for user %s", username)
                await s.close()
            except Exception as ex:
                logger.warning("Socket close failed: %s", safe_str(ex))
        for user_uid, user_nick, user_color in departures_to_broadcast:
            try:
                await self._broadcast_presence("departed", user_uid, user_nick, user_color)
            except Exception as ex:
                logger.debug("Failed to broadcast departure: %s", safe_str(ex))

    async def _broadcast_presence(self, event_type, user_uid, user_nick, user_color):
        """Send a ``user_presence`` event to every connected socket of other users.

        ``event_type`` must be ``"arrived"`` or ``"departed"``; calls with any
        other value, or without ``user_uid`` / ``user_nick``, are ignored.
        """
        if not user_uid or not user_nick:
            return
        if not event_type or event_type not in ("arrived", "departed"):
            return
        try:
            message = {
                "event": "user_presence",
                "data": {
                    "type": event_type,
                    "user_uid": user_uid,
                    "user_nick": user_nick,
                    "user_color": user_color,
                    "timestamp": datetime.now().isoformat(),
                },
            }
            async with self._sockets_lock:
                sockets_copy = list(self.sockets)
            send_tasks = [
                s.send_json(message)
                for s in sockets_copy
                if s and s.is_connected and s.user_uid != user_uid
            ]
            if send_tasks:
                results = await asyncio.gather(*send_tasks, return_exceptions=True)
                sent_count = sum(1 for r in results if r is True)
                logger.info(
                    "Broadcast presence '%s' for %s to %s users",
                    event_type,
                    user_nick,
                    sent_count,
                )
        except Exception as ex:
            logger.warning("Broadcast presence failed: %s", safe_str(ex))

    async def get_connected_users(self):
        """Return the uids of all users that currently have a registered socket."""
        try:
            return list(self.users)
        except Exception:
            return []

    async def get_user_socket_count(self, user_uid):
        """Return how many sockets are registered for ``user_uid`` (``0`` on error)."""
        if not user_uid:
            return 0
        try:
            # Unknown users have no sockets; answer without allocating a lock.
            if user_uid not in self.users:
                return 0
            user_lock = await self._get_user_lock(user_uid)
            async with user_lock:
                return len(self.users.get(user_uid, []))
        except Exception:
            return 0

    async def is_user_online(self, user_uid):
        """Tell whether ``user_uid`` has at least one socket still marked connected."""
        if not user_uid:
            return False
        try:
            # Unknown users are offline; answer without allocating a lock.
            if user_uid not in self.users:
                return False
            user_lock = await self._get_user_lock(user_uid)
            async with user_lock:
                user_sockets = self.users.get(user_uid, set())
                return any(s.is_connected for s in user_sockets if s)
        except Exception:
            return False