perf: remove async locks from balancer, socket service, and cache

feat: add db connection injection for testing in app
build: update pytest configuration
retoor 2025-12-24 18:03:21 +01:00
parent 572d9610f2
commit 401e558011
7 changed files with 110 additions and 191 deletions

View File

@@ -13,6 +13,14 @@
+## Version 1.13.0 - 2025-12-24
+Improves performance in the balancer, socket service, and cache by removing async locks. Adds database connection injection for testing in the app and updates pytest configuration.
+**Changes:** 5 files, 291 lines
+**Languages:** Other (6 lines), Python (285 lines)
 ## Version 1.12.0 - 2025-12-21
 The socket service enhances concurrency and performance through improved locking and asynchronous handling. The RPC view renames the scheduled list to tasks for clearer API terminology.
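Editor's note on the headline change: dropping `asyncio.Lock` is safe here only because asyncio runs all coroutines on a single thread, so a block of statements with no `await` between its read and its write is already atomic with respect to other coroutines. A minimal standalone sketch of that guarantee (all names hypothetical, standard library only):

```python
import asyncio

# Hypothetical counter mutated by many coroutines. Because increment()
# never awaits between its read and its write, the event loop cannot
# interleave another coroutine inside it, so no asyncio.Lock is needed.
class Counter:
    def __init__(self):
        self.value = 0

    def increment(self):
        self.value += 1  # atomic w.r.t. the event loop: no await here

async def worker(counter, n):
    for _ in range(n):
        counter.increment()
        await asyncio.sleep(0)  # yield control between increments

async def main():
    counter = Counter()
    await asyncio.gather(*(worker(counter, 1000) for _ in range(10)))
    assert counter.value == 10_000  # no lost updates despite no lock

asyncio.run(main())
```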

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "Snek"
-version = "1.12.0"
+version = "1.13.0"
 readme = "README.md"
 #license = { file = "LICENSE", content-type="text/markdown" }
 description = "Snek Chat Application by Molodetz"

View File

@ -1,11 +1,10 @@
[tool:pytest] [pytest]
testpaths = tests testpaths = tests
python_files = test_*.py python_files = test_*.py
python_classes = Test* python_classes = Test*
python_functions = test_* python_functions = test_*
addopts = addopts =
--strict-markers --strict-markers
--strict-config
--disable-warnings --disable-warnings
--tb=short --tb=short
-v -v
@ -13,5 +12,4 @@ markers =
unit: Unit tests unit: Unit tests
integration: Integration tests integration: Integration tests
slow: Slow running tests slow: Slow running tests
asyncio_mode = auto asyncio_mode = auto
asyncio_default_fixture_loop_scope = function
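For context, `[tool:pytest]` is the section name pytest reads from setup.cfg, while a standalone pytest.ini uses `[pytest]`, so the rename matches the file it lives in. A hedged usage sketch of the declared markers (module and test names are hypothetical; assumes pytest-asyncio is installed, as the `asyncio_mode` option implies):

```python
# test_example.py — hypothetical module; collected via python_files = test_*.py
import asyncio

import pytest

@pytest.mark.unit
def test_addition():
    # select only these with: pytest -m unit
    assert 1 + 1 == 2

@pytest.mark.slow
async def test_async_sleep():
    # asyncio_mode = auto lets pytest-asyncio run bare async def tests
    await asyncio.sleep(0)
```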

View File

@@ -148,11 +148,12 @@ class Application(BaseApplication):
                 created_by_uid=admin_user["uid"],
             )

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, db_connection=None, **kwargs):
         middlewares = [
             cors_middleware,
             csp_middleware,
         ]
+        self._test_db = db_connection
         self.template_path = pathlib.Path(__file__).parent.joinpath("templates")
         self.static_path = pathlib.Path(__file__).parent.joinpath("static")
         super().__init__(
@@ -196,6 +197,16 @@ class Application(BaseApplication):
         self.on_startup.append(self.prepare_database)
         self.on_startup.append(self.create_default_forum)

+    @property
+    def db(self):
+        if self._test_db is not None:
+            return self._test_db
+        return self._db
+
+    @db.setter
+    def db(self, value):
+        self._db = value
+
     @property
     def uptime_seconds(self):
         return (datetime.now() - self.time_start).total_seconds()
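A hedged sketch of how the new hook might be used from a test; the import path and the fake's interface are assumptions, not taken from this diff:

```python
# Hypothetical test: inject a stand-in connection instead of the real DB.
from snek.app import Application  # import path assumed

class FakeDB:
    async def query(self, sql, *params):
        return []  # canned rows for tests

def test_db_injection():
    app = Application(db_connection=FakeDB())
    # the db property prefers the injected connection over self._db
    assert isinstance(app.db, FakeDB)
```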

View File

@@ -9,7 +9,6 @@ class LoadBalancer:
         self.backend_ports = backend_ports
         self.backend_processes = []
         self.client_counts = [0] * len(backend_ports)
-        self.lock = asyncio.Lock()

     async def start_backend_servers(self, port, workers):
         for x in range(workers):
@@ -29,10 +28,9 @@ class LoadBalancer:
         )

     async def handle_client(self, reader, writer):
-        async with self.lock:
-            min_clients = min(self.client_counts)
-            server_index = self.client_counts.index(min_clients)
-            self.client_counts[server_index] += 1
+        min_clients = min(self.client_counts)
+        server_index = self.client_counts.index(min_clients)
+        self.client_counts[server_index] += 1
         backend = ("127.0.0.1", self.backend_ports[server_index])
         try:
             backend_reader, backend_writer = await asyncio.open_connection(*backend)
@@ -57,8 +55,7 @@ class LoadBalancer:
             print(f"Error: {e}")
         finally:
             writer.close()
-            async with self.lock:
-                self.client_counts[server_index] -= 1
+            self.client_counts[server_index] -= 1

     async def monitor(self):
         while True:
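The lock could be dropped because `min`, `index`, and the increment now run back to back with no `await` in between, so no other `handle_client` coroutine can observe the counts mid-update. A standalone sketch of the least-connections pick (ports and counts are hypothetical):

```python
# Minimal sketch of the lock-free least-connections selection.
# client_counts mirrors LoadBalancer.client_counts.
client_counts = [3, 1, 2]
backend_ports = [8081, 8082, 8083]

def pick_backend():
    # No await between these statements, so under asyncio the whole
    # read-modify-write is atomic with respect to other coroutines.
    min_clients = min(client_counts)
    server_index = client_counts.index(min_clients)
    client_counts[server_index] += 1
    return ("127.0.0.1", backend_ports[server_index])

print(pick_backend())  # ('127.0.0.1', 8082) — the least-loaded backend
```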

View File

@@ -42,42 +42,39 @@ class SocketService(BaseService):
         self.user = user
         self.user_uid = safe_get(user, "uid") if user else None
         self.user_color = safe_get(user, "color") if user else None
-        self._lock = asyncio.Lock()
         self.subscribed_channels = set()

     async def send_json(self, data):
         if data is None:
             return False
-        async with self._lock:
-            if not self.is_connected:
-                return False
-            if not self.ws:
-                self.is_connected = False
-                return False
-            try:
-                await self.ws.send_json(data)
-                return True
-            except ConnectionResetError:
-                self.is_connected = False
-                logger.debug("Connection reset during send_json")
-            except Exception as ex:
-                self.is_connected = False
-                logger.debug(f"send_json failed: {safe_str(ex)}")
-            return False
+        if not self.is_connected:
+            return False
+        if not self.ws:
+            self.is_connected = False
+            return False
+        try:
+            await self.ws.send_json(data)
+            return True
+        except ConnectionResetError:
+            self.is_connected = False
+            logger.debug("Connection reset during send_json")
+        except Exception as ex:
+            self.is_connected = False
+            logger.debug(f"send_json failed: {safe_str(ex)}")
+        return False

     async def close(self):
-        async with self._lock:
-            if not self.is_connected:
-                return True
-            self.is_connected = False
-            self.subscribed_channels.clear()
-            try:
-                if self.ws and not self.ws.closed:
-                    await asyncio.wait_for(self.ws.close(), timeout=5.0)
-            except asyncio.TimeoutError:
-                logger.debug("Socket close timed out")
-            except Exception as ex:
-                logger.debug(f"Socket close failed: {safe_str(ex)}")
+        if not self.is_connected:
+            return True
+        self.is_connected = False
+        self.subscribed_channels.clear()
+        try:
+            if self.ws and not self.ws.closed:
+                await asyncio.wait_for(self.ws.close(), timeout=5.0)
+        except asyncio.TimeoutError:
+            logger.debug("Socket close timed out")
+        except Exception as ex:
+            logger.debug(f"Socket close failed: {safe_str(ex)}")
         return True

     def __init__(self, app):
@@ -86,30 +83,13 @@ class SocketService(BaseService):
         self.users = {}
         self.subscriptions = {}
         self.last_update = str(datetime.now())
-        self._sockets_lock = asyncio.Lock()
-        self._user_locks = {}
-        self._channel_locks = {}
-        self._registry_lock = asyncio.Lock()
-
-    async def _get_user_lock(self, user_uid):
-        async with self._registry_lock:
-            if user_uid not in self._user_locks:
-                self._user_locks[user_uid] = asyncio.Lock()
-            return self._user_locks[user_uid]
-
-    async def _get_channel_lock(self, channel_uid):
-        async with self._registry_lock:
-            if channel_uid not in self._channel_locks:
-                self._channel_locks[channel_uid] = asyncio.Lock()
-            return self._channel_locks[channel_uid]

     async def user_availability_service(self):
         logger.info("User availability update service started.")
         while True:
             try:
-                async with self._sockets_lock:
-                    sockets_copy = list(self.sockets)
+                sockets_copy = list(self.sockets)
                 users_to_update = []
                 seen_user_uids = set()
                 for s in sockets_copy:
@@ -159,17 +139,14 @@ class SocketService(BaseService):
         username = safe_get(user, "username", "unknown")
         nick = safe_get(user, "nick") or username
         color = s.user_color
-        async with self._sockets_lock:
-            self.sockets.add(s)
+        self.sockets.add(s)
         is_first_connection = False
-        user_lock = await self._get_user_lock(user_uid)
-        async with user_lock:
-            if user_uid not in self.users:
-                self.users[user_uid] = set()
-                is_first_connection = True
-            elif len(self.users[user_uid]) == 0:
-                is_first_connection = True
-            self.users[user_uid].add(s)
+        if user_uid not in self.users:
+            self.users[user_uid] = set()
+            is_first_connection = True
+        elif len(self.users[user_uid]) == 0:
+            is_first_connection = True
+        self.users[user_uid].add(s)
         try:
             fresh_user = await self.app.services.user.get(uid=user_uid)
             if fresh_user:
@@ -189,22 +166,18 @@ class SocketService(BaseService):
         if not ws or not channel_uid or not user_uid:
             return False
         try:
-            user_lock = await self._get_user_lock(user_uid)
-            channel_lock = await self._get_channel_lock(channel_uid)
             existing_socket = None
-            async with user_lock:
-                user_sockets = self.users.get(user_uid, set())
-                for sock in user_sockets:
-                    if sock and sock.ws == ws:
-                        existing_socket = sock
-                        break
-                if not existing_socket:
-                    return False
-                existing_socket.subscribed_channels.add(channel_uid)
-            async with channel_lock:
-                if channel_uid not in self.subscriptions:
-                    self.subscriptions[channel_uid] = set()
-                self.subscriptions[channel_uid].add(user_uid)
+            user_sockets = self.users.get(user_uid, set())
+            for sock in user_sockets:
+                if sock and sock.ws == ws:
+                    existing_socket = sock
+                    break
+            if not existing_socket:
+                return False
+            existing_socket.subscribed_channels.add(channel_uid)
+            if channel_uid not in self.subscriptions:
+                self.subscriptions[channel_uid] = set()
+            self.subscriptions[channel_uid].add(user_uid)
             return True
         except Exception as ex:
             logger.warning(f"Failed to subscribe: {safe_str(ex)}")
@@ -215,9 +188,7 @@ class SocketService(BaseService):
             return 0
         count = 0
         try:
-            user_lock = await self._get_user_lock(user_uid)
-            async with user_lock:
-                user_sockets = list(self.users.get(user_uid, []))
+            user_sockets = list(self.users.get(user_uid, []))
             for s in user_sockets:
                 if not s:
                     continue
@@ -252,10 +223,8 @@ class SocketService(BaseService):
         except Exception as ex:
             logger.warning(f"Broadcast db query failed: {safe_str(ex)}")
         if not user_uids_to_send:
-            channel_lock = await self._get_channel_lock(channel_uid)
-            async with channel_lock:
-                if channel_uid in self.subscriptions:
-                    user_uids_to_send = set(self.subscriptions[channel_uid])
+            if channel_uid in self.subscriptions:
+                user_uids_to_send = set(self.subscriptions[channel_uid])
             logger.debug(f"_broadcast: using {len(user_uids_to_send)} users from subscriptions for channel={channel_uid}")
         send_tasks = []
         for user_uid in user_uids_to_send:
@@ -281,25 +250,21 @@ class SocketService(BaseService):
     async def delete(self, ws):
         if not ws:
             return
-        sockets_to_remove = []
-        async with self._sockets_lock:
-            sockets_to_remove = [sock for sock in self.sockets if sock and sock.ws == ws]
-            for s in sockets_to_remove:
-                self.sockets.discard(s)
+        sockets_to_remove = [sock for sock in self.sockets if sock and sock.ws == ws]
+        for s in sockets_to_remove:
+            self.sockets.discard(s)
         departures_to_broadcast = []
         channels_to_cleanup = set()
         for s in sockets_to_remove:
             user_uid = s.user_uid
             if not user_uid:
                 continue
-            user_lock = await self._get_user_lock(user_uid)
             is_last_connection = False
-            async with user_lock:
-                if user_uid in self.users:
-                    self.users[user_uid].discard(s)
-                    if len(self.users[user_uid]) == 0:
-                        del self.users[user_uid]
-                        is_last_connection = True
+            if user_uid in self.users:
+                self.users[user_uid].discard(s)
+                if len(self.users[user_uid]) == 0:
+                    del self.users[user_uid]
+                    is_last_connection = True
             if is_last_connection:
                 user_nick = None
                 try:
@@ -313,12 +278,10 @@ class SocketService(BaseService):
                     channels_to_cleanup.add((channel_uid, user_uid))
         for channel_uid, user_uid in channels_to_cleanup:
             try:
-                channel_lock = await self._get_channel_lock(channel_uid)
-                async with channel_lock:
-                    if channel_uid in self.subscriptions:
-                        self.subscriptions[channel_uid].discard(user_uid)
-                        if len(self.subscriptions[channel_uid]) == 0:
-                            del self.subscriptions[channel_uid]
+                if channel_uid in self.subscriptions:
+                    self.subscriptions[channel_uid].discard(user_uid)
+                    if len(self.subscriptions[channel_uid]) == 0:
+                        del self.subscriptions[channel_uid]
             except Exception as ex:
                 logger.debug(f"Failed to cleanup channel subscription: {safe_str(ex)}")
         for s in sockets_to_remove:
@@ -350,8 +313,7 @@ class SocketService(BaseService):
                 "timestamp": datetime.now().isoformat(),
             },
         }
-        async with self._sockets_lock:
-            sockets_copy = list(self.sockets)
+        sockets_copy = list(self.sockets)
         send_tasks = []
         for s in sockets_copy:
             if not s or not s.is_connected:
@@ -376,9 +338,7 @@ class SocketService(BaseService):
         if not user_uid:
             return 0
         try:
-            user_lock = await self._get_user_lock(user_uid)
-            async with user_lock:
-                return len(self.users.get(user_uid, []))
+            return len(self.users.get(user_uid, []))
         except Exception:
             return 0
@@ -386,9 +346,7 @@ class SocketService(BaseService):
         if not user_uid:
             return False
         try:
-            user_lock = await self._get_user_lock(user_uid)
-            async with user_lock:
-                user_sockets = self.users.get(user_uid, set())
-                return any(s.is_connected for s in user_sockets if s)
+            user_sockets = self.users.get(user_uid, set())
+            return any(s.is_connected for s in user_sockets if s)
         except Exception:
             return False
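The same reasoning covers `self.users` and `self.subscriptions`: every surviving mutation happens between `await` points. The flip side is worth noting: reintroducing an `await` inside one of these read-modify-write blocks would reopen the races the locks used to close, as this hypothetical sketch demonstrates:

```python
# Hypothetical hazard sketch: an await inside a read-modify-write
# block reopens the race the locks used to prevent.
import asyncio

users = {}

async def register_unsafe(uid, sock):
    if uid not in users:
        await asyncio.sleep(0)  # an await here yields to the loop...
        users[uid] = set()      # ...so two coroutines can both run this
    users[uid].add(sock)

async def main():
    await asyncio.gather(register_unsafe("u1", 1), register_unsafe("u1", 2))
    print(users)  # {'u1': {2}} — the second init clobbered the first add

asyncio.run(main())
```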

View File

@@ -1,6 +1,5 @@
 # retoor <retoor@molodetz.nl>

-import asyncio
 import functools
 import json
 from collections import OrderedDict
@@ -10,113 +9,65 @@ CACHE_MAX_ITEMS_DEFAULT = 5000

 class Cache:
-    """
-    An asynchronous, thread-safe, in-memory LRU (Least Recently Used) cache.
-
-    This implementation uses an OrderedDict for efficient O(1) time complexity
-    for its core get, set, and delete operations.
-    """
-
     def __init__(self, app, max_items=CACHE_MAX_ITEMS_DEFAULT):
         self.app = app
-        # OrderedDict is the core of the LRU logic. It remembers the order
-        # in which items were inserted.
         self.cache: OrderedDict = OrderedDict()
         self.max_items = max_items
         self.stats = {}
         self.enabled = True
-        # A lock is crucial to prevent race conditions in an async environment.
-        self._lock = asyncio.Lock()
         self.version = ((42 + 420 + 1984 + 1990 + 10 + 6 + 71 + 3004 + 7245) ^ 1337) + 4

     async def get(self, key):
-        """
-        Retrieves an item from the cache. If found, it's marked as recently used.
-        Returns None if the item is not found or the cache is disabled.
-        """
         if not self.enabled:
             return None
-        #async with self._lock:
         if key not in self.cache:
             await self.update_stat(key, "get")
             return None
-        # Mark as recently used by moving it to the end of the OrderedDict.
-        # This is an O(1) operation.
         self.cache.move_to_end(key)
         await self.update_stat(key, "get")
         return self.cache[key]

     async def set(self, key, value):
-        """
-        Adds or updates an item in the cache and marks it as recently used.
-        If the cache exceeds its maximum size, the least recently used item is evicted.
-        """
         if not self.enabled:
             return
-        # comment
-        #async with self._lock:
         is_new = key not in self.cache
-        # Add or update the item. If it exists, it's moved to the end.
         self.cache[key] = value
         self.cache.move_to_end(key)
         await self.update_stat(key, "set")
-        # Evict the least recently used item if the cache is full.
-        # This is an O(1) operation.
         if len(self.cache) > self.max_items:
-            # popitem(last=False) removes and returns the first (oldest) item.
-            evicted_key, _ = self.cache.popitem(last=False)
-            # Optionally, you could log the evicted key here.
+            self.cache.popitem(last=False)
         if is_new:
             self.version += 1

     async def delete(self, key):
-        """Removes an item from the cache if it exists."""
         if not self.enabled:
             return
-        async with self._lock:
-            if key in self.cache:
-                await self.update_stat(key, "delete")
-                # Deleting from OrderedDict is an O(1) operation on average.
-                del self.cache[key]
+        if key in self.cache:
+            await self.update_stat(key, "delete")
+            del self.cache[key]

     async def get_stats(self):
-        """Returns statistics for all items currently in the cache."""
-        async with self._lock:
-            stats_list = []
-            # Items are iterated from oldest to newest. We reverse to show
-            # most recently used items first.
-            for key in reversed(self.cache):
-                stat_data = self.stats.get(key, {"set": 0, "get": 0, "delete": 0})
-                value = self.cache[key]
-                value_record = value.record if hasattr(value, 'record') else value
-
-                stats_list.append({
-                    "key": key,
-                    "set": stat_data.get("set", 0),
-                    "get": stat_data.get("get", 0),
-                    "delete": stat_data.get("delete", 0),
-                    "value": str(self.serialize(value_record)),
-                })
-            return stats_list
+        stats_list = []
+        for key in reversed(self.cache):
+            stat_data = self.stats.get(key, {"set": 0, "get": 0, "delete": 0})
+            value = self.cache[key]
+            value_record = value.record if hasattr(value, 'record') else value
+            stats_list.append({
+                "key": key,
+                "set": stat_data.get("set", 0),
+                "get": stat_data.get("get", 0),
+                "delete": stat_data.get("delete", 0),
+                "value": str(self.serialize(value_record)),
+            })
+        return stats_list

     async def update_stat(self, key, action):
-        """Updates hit/miss/set counts for a given cache key."""
-        # This method is already called within a locked context,
-        # but the lock makes it safe if ever called directly.
-        async with self._lock:
-            if key not in self.stats:
-                self.stats[key] = {"set": 0, "get": 0, "delete": 0}
-            self.stats[key][action] += 1
+        if key not in self.stats:
+            self.stats[key] = {"set": 0, "get": 0, "delete": 0}
+        self.stats[key][action] += 1

     def serialize(self, obj):
-        """A synchronous helper to create a serializable representation of an object."""
         if not isinstance(obj, dict):
             return obj
         cpy = obj.copy()
@@ -131,8 +82,6 @@ class Cache:
         return str(value)

     async def create_cache_key(self, args, kwargs):
-        """Creates a consistent, hashable cache key from function arguments."""
-        # security.hash is async, so this method remains async.
         return await security.hash(
             json.dumps(
                 {"args": args, "kwargs": kwargs},
@@ -142,7 +91,6 @@ class Cache:
             )

     def async_cache(self, func):
-        """Decorator to cache the results of an async function."""
         @functools.wraps(func)
         async def wrapper(*args, **kwargs):
             cache_key = await self.create_cache_key(args, kwargs)
@@ -155,7 +103,6 @@ class Cache:
         return wrapper

     def async_delete_cache(self, func):
-        """Decorator to invalidate a cache entry before running an async function."""
         @functools.wraps(func)
         async def wrapper(*args, **kwargs):
             cache_key = await self.create_cache_key(args, kwargs)
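Finally, a hedged usage sketch of the two decorators; the import path and helper bodies are assumptions. Note that cache keys are derived from the call arguments alone, so a delete-decorated function invalidates whatever a cache-decorated function stored for the same arguments:

```python
# Hypothetical usage of Cache.async_cache and async_delete_cache.
from snek.system.cache import Cache  # import path assumed

cache = Cache(app=None)

@cache.async_cache
async def get_profile(user_uid):
    # first call computes; repeats with the same uid hit the LRU cache
    return {"uid": user_uid}

@cache.async_delete_cache
async def touch_profile(user_uid):
    # same arguments -> same key, so this evicts get_profile's entry
    return user_uid
```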