From 401e5580115443bee6e6e12a71a030c051a1fed8 Mon Sep 17 00:00:00 2001 From: retoor Date: Wed, 24 Dec 2025 18:03:21 +0100 Subject: [PATCH] perf: remove async locks from balancer, socket service, and cache feat: add db connection injection for testing in app build: update pytest configuration --- CHANGELOG.md | 8 ++ pyproject.toml | 2 +- pytest.ini | 6 +- src/snek/app.py | 13 ++- src/snek/balancer.py | 11 +-- src/snek/service/socket.py | 168 ++++++++++++++----------------------- src/snek/system/cache.py | 93 +++++--------------- 7 files changed, 110 insertions(+), 191 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e4aba67..45a742d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,14 @@ + +## Version 1.13.0 - 2025-12-24 + +Improves performance in the balancer, socket service, and cache by removing async locks. Adds database connection injection for testing in the app and updates pytest configuration. + +**Changes:** 5 files, 291 lines +**Languages:** Other (6 lines), Python (285 lines) + ## Version 1.12.0 - 2025-12-21 The socket service enhances concurrency and performance through improved locking and asynchronous handling. The RPC view renames the scheduled list to tasks for clearer API terminology. diff --git a/pyproject.toml b/pyproject.toml index 2ff7225..ba01daa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "Snek" -version = "1.12.0" +version = "1.13.0" readme = "README.md" #license = { file = "LICENSE", content-type="text/markdown" } description = "Snek Chat Application by Molodetz" diff --git a/pytest.ini b/pytest.ini index dc3bca0..716765e 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,11 +1,10 @@ -[tool:pytest] +[pytest] testpaths = tests python_files = test_*.py python_classes = Test* python_functions = test_* addopts = --strict-markers - --strict-config --disable-warnings --tb=short -v @@ -13,5 +12,4 @@ markers = unit: Unit tests integration: Integration tests slow: Slow running tests -asyncio_mode = auto -asyncio_default_fixture_loop_scope = function \ No newline at end of file +asyncio_mode = auto \ No newline at end of file diff --git a/src/snek/app.py b/src/snek/app.py index 9b93842..b98525f 100644 --- a/src/snek/app.py +++ b/src/snek/app.py @@ -148,11 +148,12 @@ class Application(BaseApplication): created_by_uid=admin_user["uid"], ) - def __init__(self, *args, **kwargs): + def __init__(self, *args, db_connection=None, **kwargs): middlewares = [ cors_middleware, csp_middleware, ] + self._test_db = db_connection self.template_path = pathlib.Path(__file__).parent.joinpath("templates") self.static_path = pathlib.Path(__file__).parent.joinpath("static") super().__init__( @@ -196,6 +197,16 @@ class Application(BaseApplication): self.on_startup.append(self.prepare_database) self.on_startup.append(self.create_default_forum) + @property + def db(self): + if self._test_db is not None: + return self._test_db + return self._db + + @db.setter + def db(self, value): + self._db = value + @property def uptime_seconds(self): return (datetime.now() - self.time_start).total_seconds() diff --git a/src/snek/balancer.py b/src/snek/balancer.py index 758cd60..874737f 100644 --- a/src/snek/balancer.py +++ b/src/snek/balancer.py @@ -9,7 +9,6 @@ class LoadBalancer: self.backend_ports = backend_ports self.backend_processes = [] self.client_counts = [0] * len(backend_ports) - self.lock = asyncio.Lock() async def start_backend_servers(self, port, workers): for x in range(workers): @@ -29,10 +28,9 @@ class LoadBalancer: ) async def handle_client(self, reader, writer): - async with self.lock: - min_clients = min(self.client_counts) - server_index = self.client_counts.index(min_clients) - self.client_counts[server_index] += 1 + min_clients = min(self.client_counts) + server_index = self.client_counts.index(min_clients) + self.client_counts[server_index] += 1 backend = ("127.0.0.1", self.backend_ports[server_index]) try: backend_reader, backend_writer = await asyncio.open_connection(*backend) @@ -57,8 +55,7 @@ class LoadBalancer: print(f"Error: {e}") finally: writer.close() - async with self.lock: - self.client_counts[server_index] -= 1 + self.client_counts[server_index] -= 1 async def monitor(self): while True: diff --git a/src/snek/service/socket.py b/src/snek/service/socket.py index 782bad6..47c4212 100644 --- a/src/snek/service/socket.py +++ b/src/snek/service/socket.py @@ -42,42 +42,39 @@ class SocketService(BaseService): self.user = user self.user_uid = safe_get(user, "uid") if user else None self.user_color = safe_get(user, "color") if user else None - self._lock = asyncio.Lock() self.subscribed_channels = set() async def send_json(self, data): if data is None: return False - async with self._lock: - if not self.is_connected: - return False - if not self.ws: - self.is_connected = False - return False - try: - await self.ws.send_json(data) - return True - except ConnectionResetError: - self.is_connected = False - logger.debug("Connection reset during send_json") - except Exception as ex: - self.is_connected = False - logger.debug(f"send_json failed: {safe_str(ex)}") + if not self.is_connected: return False + if not self.ws: + self.is_connected = False + return False + try: + await self.ws.send_json(data) + return True + except ConnectionResetError: + self.is_connected = False + logger.debug("Connection reset during send_json") + except Exception as ex: + self.is_connected = False + logger.debug(f"send_json failed: {safe_str(ex)}") + return False async def close(self): - async with self._lock: - if not self.is_connected: - return True - self.is_connected = False - self.subscribed_channels.clear() - try: - if self.ws and not self.ws.closed: - await asyncio.wait_for(self.ws.close(), timeout=5.0) - except asyncio.TimeoutError: - logger.debug("Socket close timed out") - except Exception as ex: - logger.debug(f"Socket close failed: {safe_str(ex)}") + if not self.is_connected: + return True + self.is_connected = False + self.subscribed_channels.clear() + try: + if self.ws and not self.ws.closed: + await asyncio.wait_for(self.ws.close(), timeout=5.0) + except asyncio.TimeoutError: + logger.debug("Socket close timed out") + except Exception as ex: + logger.debug(f"Socket close failed: {safe_str(ex)}") return True def __init__(self, app): @@ -86,30 +83,13 @@ class SocketService(BaseService): self.users = {} self.subscriptions = {} self.last_update = str(datetime.now()) - self._sockets_lock = asyncio.Lock() - self._user_locks = {} - self._channel_locks = {} - self._registry_lock = asyncio.Lock() - - async def _get_user_lock(self, user_uid): - async with self._registry_lock: - if user_uid not in self._user_locks: - self._user_locks[user_uid] = asyncio.Lock() - return self._user_locks[user_uid] - - async def _get_channel_lock(self, channel_uid): - async with self._registry_lock: - if channel_uid not in self._channel_locks: - self._channel_locks[channel_uid] = asyncio.Lock() - return self._channel_locks[channel_uid] async def user_availability_service(self): logger.info("User availability update service started.") while True: try: - async with self._sockets_lock: - sockets_copy = list(self.sockets) + sockets_copy = list(self.sockets) users_to_update = [] seen_user_uids = set() for s in sockets_copy: @@ -159,17 +139,14 @@ class SocketService(BaseService): username = safe_get(user, "username", "unknown") nick = safe_get(user, "nick") or username color = s.user_color - async with self._sockets_lock: - self.sockets.add(s) + self.sockets.add(s) is_first_connection = False - user_lock = await self._get_user_lock(user_uid) - async with user_lock: - if user_uid not in self.users: - self.users[user_uid] = set() - is_first_connection = True - elif len(self.users[user_uid]) == 0: - is_first_connection = True - self.users[user_uid].add(s) + if user_uid not in self.users: + self.users[user_uid] = set() + is_first_connection = True + elif len(self.users[user_uid]) == 0: + is_first_connection = True + self.users[user_uid].add(s) try: fresh_user = await self.app.services.user.get(uid=user_uid) if fresh_user: @@ -189,22 +166,18 @@ class SocketService(BaseService): if not ws or not channel_uid or not user_uid: return False try: - user_lock = await self._get_user_lock(user_uid) - channel_lock = await self._get_channel_lock(channel_uid) existing_socket = None - async with user_lock: - user_sockets = self.users.get(user_uid, set()) - for sock in user_sockets: - if sock and sock.ws == ws: - existing_socket = sock - break - if not existing_socket: - return False - existing_socket.subscribed_channels.add(channel_uid) - async with channel_lock: - if channel_uid not in self.subscriptions: - self.subscriptions[channel_uid] = set() - self.subscriptions[channel_uid].add(user_uid) + user_sockets = self.users.get(user_uid, set()) + for sock in user_sockets: + if sock and sock.ws == ws: + existing_socket = sock + break + if not existing_socket: + return False + existing_socket.subscribed_channels.add(channel_uid) + if channel_uid not in self.subscriptions: + self.subscriptions[channel_uid] = set() + self.subscriptions[channel_uid].add(user_uid) return True except Exception as ex: logger.warning(f"Failed to subscribe: {safe_str(ex)}") @@ -215,9 +188,7 @@ class SocketService(BaseService): return 0 count = 0 try: - user_lock = await self._get_user_lock(user_uid) - async with user_lock: - user_sockets = list(self.users.get(user_uid, [])) + user_sockets = list(self.users.get(user_uid, [])) for s in user_sockets: if not s: continue @@ -252,10 +223,8 @@ class SocketService(BaseService): except Exception as ex: logger.warning(f"Broadcast db query failed: {safe_str(ex)}") if not user_uids_to_send: - channel_lock = await self._get_channel_lock(channel_uid) - async with channel_lock: - if channel_uid in self.subscriptions: - user_uids_to_send = set(self.subscriptions[channel_uid]) + if channel_uid in self.subscriptions: + user_uids_to_send = set(self.subscriptions[channel_uid]) logger.debug(f"_broadcast: using {len(user_uids_to_send)} users from subscriptions for channel={channel_uid}") send_tasks = [] for user_uid in user_uids_to_send: @@ -281,25 +250,21 @@ class SocketService(BaseService): async def delete(self, ws): if not ws: return - sockets_to_remove = [] - async with self._sockets_lock: - sockets_to_remove = [sock for sock in self.sockets if sock and sock.ws == ws] - for s in sockets_to_remove: - self.sockets.discard(s) + sockets_to_remove = [sock for sock in self.sockets if sock and sock.ws == ws] + for s in sockets_to_remove: + self.sockets.discard(s) departures_to_broadcast = [] channels_to_cleanup = set() for s in sockets_to_remove: user_uid = s.user_uid if not user_uid: continue - user_lock = await self._get_user_lock(user_uid) is_last_connection = False - async with user_lock: - if user_uid in self.users: - self.users[user_uid].discard(s) - if len(self.users[user_uid]) == 0: - del self.users[user_uid] - is_last_connection = True + if user_uid in self.users: + self.users[user_uid].discard(s) + if len(self.users[user_uid]) == 0: + del self.users[user_uid] + is_last_connection = True if is_last_connection: user_nick = None try: @@ -313,12 +278,10 @@ class SocketService(BaseService): channels_to_cleanup.add((channel_uid, user_uid)) for channel_uid, user_uid in channels_to_cleanup: try: - channel_lock = await self._get_channel_lock(channel_uid) - async with channel_lock: - if channel_uid in self.subscriptions: - self.subscriptions[channel_uid].discard(user_uid) - if len(self.subscriptions[channel_uid]) == 0: - del self.subscriptions[channel_uid] + if channel_uid in self.subscriptions: + self.subscriptions[channel_uid].discard(user_uid) + if len(self.subscriptions[channel_uid]) == 0: + del self.subscriptions[channel_uid] except Exception as ex: logger.debug(f"Failed to cleanup channel subscription: {safe_str(ex)}") for s in sockets_to_remove: @@ -350,8 +313,7 @@ class SocketService(BaseService): "timestamp": datetime.now().isoformat(), }, } - async with self._sockets_lock: - sockets_copy = list(self.sockets) + sockets_copy = list(self.sockets) send_tasks = [] for s in sockets_copy: if not s or not s.is_connected: @@ -376,9 +338,7 @@ class SocketService(BaseService): if not user_uid: return 0 try: - user_lock = await self._get_user_lock(user_uid) - async with user_lock: - return len(self.users.get(user_uid, [])) + return len(self.users.get(user_uid, [])) except Exception: return 0 @@ -386,9 +346,7 @@ class SocketService(BaseService): if not user_uid: return False try: - user_lock = await self._get_user_lock(user_uid) - async with user_lock: - user_sockets = self.users.get(user_uid, set()) - return any(s.is_connected for s in user_sockets if s) + user_sockets = self.users.get(user_uid, set()) + return any(s.is_connected for s in user_sockets if s) except Exception: return False diff --git a/src/snek/system/cache.py b/src/snek/system/cache.py index 522351f..a2c0b7b 100644 --- a/src/snek/system/cache.py +++ b/src/snek/system/cache.py @@ -1,6 +1,5 @@ # retoor -import asyncio import functools import json from collections import OrderedDict @@ -10,113 +9,65 @@ CACHE_MAX_ITEMS_DEFAULT = 5000 class Cache: - """ - An asynchronous, thread-safe, in-memory LRU (Least Recently Used) cache. - - This implementation uses an OrderedDict for efficient O(1) time complexity - for its core get, set, and delete operations. - """ def __init__(self, app, max_items=CACHE_MAX_ITEMS_DEFAULT): self.app = app - # OrderedDict is the core of the LRU logic. It remembers the order - # in which items were inserted. self.cache: OrderedDict = OrderedDict() self.max_items = max_items self.stats = {} self.enabled = True - # A lock is crucial to prevent race conditions in an async environment. - self._lock = asyncio.Lock() self.version = ((42 + 420 + 1984 + 1990 + 10 + 6 + 71 + 3004 + 7245) ^ 1337) + 4 async def get(self, key): - """ - Retrieves an item from the cache. If found, it's marked as recently used. - Returns None if the item is not found or the cache is disabled. - """ if not self.enabled: return None - - #async with self._lock: if key not in self.cache: await self.update_stat(key, "get") return None - - # Mark as recently used by moving it to the end of the OrderedDict. - # This is an O(1) operation. self.cache.move_to_end(key) await self.update_stat(key, "get") return self.cache[key] async def set(self, key, value): - """ - Adds or updates an item in the cache and marks it as recently used. - If the cache exceeds its maximum size, the least recently used item is evicted. - """ if not self.enabled: return - # comment - #async with self._lock: is_new = key not in self.cache - - # Add or update the item. If it exists, it's moved to the end. self.cache[key] = value self.cache.move_to_end(key) - await self.update_stat(key, "set") - - # Evict the least recently used item if the cache is full. - # This is an O(1) operation. if len(self.cache) > self.max_items: - # popitem(last=False) removes and returns the first (oldest) item. - evicted_key, _ = self.cache.popitem(last=False) - # Optionally, you could log the evicted key here. - + self.cache.popitem(last=False) if is_new: self.version += 1 async def delete(self, key): - """Removes an item from the cache if it exists.""" if not self.enabled: return - - async with self._lock: - if key in self.cache: - await self.update_stat(key, "delete") - # Deleting from OrderedDict is an O(1) operation on average. - del self.cache[key] + if key in self.cache: + await self.update_stat(key, "delete") + del self.cache[key] async def get_stats(self): - """Returns statistics for all items currently in the cache.""" - async with self._lock: - stats_list = [] - # Items are iterated from oldest to newest. We reverse to show - # most recently used items first. - for key in reversed(self.cache): - stat_data = self.stats.get(key, {"set": 0, "get": 0, "delete": 0}) - value = self.cache[key] - value_record = value.record if hasattr(value, 'record') else value - - stats_list.append({ - "key": key, - "set": stat_data.get("set", 0), - "get": stat_data.get("get", 0), - "delete": stat_data.get("delete", 0), - "value": str(self.serialize(value_record)), - }) - return stats_list + stats_list = [] + for key in reversed(self.cache): + stat_data = self.stats.get(key, {"set": 0, "get": 0, "delete": 0}) + value = self.cache[key] + value_record = value.record if hasattr(value, 'record') else value + stats_list.append({ + "key": key, + "set": stat_data.get("set", 0), + "get": stat_data.get("get", 0), + "delete": stat_data.get("delete", 0), + "value": str(self.serialize(value_record)), + }) + return stats_list async def update_stat(self, key, action): - """Updates hit/miss/set counts for a given cache key.""" - # This method is already called within a locked context, - # but the lock makes it safe if ever called directly. - async with self._lock: - if key not in self.stats: - self.stats[key] = {"set": 0, "get": 0, "delete": 0} - self.stats[key][action] += 1 + if key not in self.stats: + self.stats[key] = {"set": 0, "get": 0, "delete": 0} + self.stats[key][action] += 1 def serialize(self, obj): - """A synchronous helper to create a serializable representation of an object.""" if not isinstance(obj, dict): return obj cpy = obj.copy() @@ -131,8 +82,6 @@ class Cache: return str(value) async def create_cache_key(self, args, kwargs): - """Creates a consistent, hashable cache key from function arguments.""" - # security.hash is async, so this method remains async. return await security.hash( json.dumps( {"args": args, "kwargs": kwargs}, @@ -142,7 +91,6 @@ class Cache: ) def async_cache(self, func): - """Decorator to cache the results of an async function.""" @functools.wraps(func) async def wrapper(*args, **kwargs): cache_key = await self.create_cache_key(args, kwargs) @@ -155,7 +103,6 @@ class Cache: return wrapper def async_delete_cache(self, func): - """Decorator to invalidate a cache entry before running an async function.""" @functools.wraps(func) async def wrapper(*args, **kwargs): cache_key = await self.create_cache_key(args, kwargs)