perf: remove async locks from balancer, socket service, and cache

feat: add db connection injection for testing in app
build: update pytest configuration
retoor 2025-12-24 18:03:21 +01:00
parent 572d9610f2
commit 401e558011
7 changed files with 110 additions and 191 deletions

View File

@@ -13,6 +13,14 @@
+## Version 1.13.0 - 2025-12-24
+
+Improves performance in the balancer, socket service, and cache by removing async locks. Adds database connection injection for testing in the app and updates pytest configuration.
+
+**Changes:** 5 files, 291 lines
+**Languages:** Other (6 lines), Python (285 lines)
+
 ## Version 1.12.0 - 2025-12-21
 The socket service enhances concurrency and performance through improved locking and asynchronous handling. The RPC view renames the scheduled list to tasks for clearer API terminology.

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "Snek"
-version = "1.12.0"
+version = "1.13.0"
 readme = "README.md"
 #license = { file = "LICENSE", content-type="text/markdown" }
 description = "Snek Chat Application by Molodetz"

View File

@@ -1,11 +1,10 @@
-[tool:pytest]
+[pytest]
 testpaths = tests
 python_files = test_*.py
 python_classes = Test*
 python_functions = test_*
 addopts =
     --strict-markers
-    --strict-config
     --disable-warnings
     --tb=short
     -v
@@ -14,4 +13,3 @@ markers =
     integration: Integration tests
     slow: Slow running tests
 asyncio_mode = auto
-asyncio_default_fixture_loop_scope = function
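
Note: pytest only reads the [pytest] section from a standalone pytest.ini; [tool:pytest] is the section name for setup.cfg, so the old header left these options unread. For reference, the resulting pytest.ini reconstructed from the two hunks above (blank-line placement assumed):

[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts =
    --strict-markers
    --disable-warnings
    --tb=short
    -v
markers =
    integration: Integration tests
    slow: Slow running tests
asyncio_mode = auto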

View File

@@ -148,11 +148,12 @@ class Application(BaseApplication):
                 created_by_uid=admin_user["uid"],
             )

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, db_connection=None, **kwargs):
         middlewares = [
             cors_middleware,
             csp_middleware,
         ]
+        self._test_db = db_connection
         self.template_path = pathlib.Path(__file__).parent.joinpath("templates")
         self.static_path = pathlib.Path(__file__).parent.joinpath("static")
         super().__init__(
@@ -196,6 +197,16 @@ class Application(BaseApplication):
         self.on_startup.append(self.prepare_database)
         self.on_startup.append(self.create_default_forum)

+    @property
+    def db(self):
+        if self._test_db is not None:
+            return self._test_db
+        return self._db
+
+    @db.setter
+    def db(self, value):
+        self._db = value
+
     @property
     def uptime_seconds(self):
         return (datetime.now() - self.time_start).total_seconds()
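
For context, a minimal sketch of how the new injection point could be used in a test. The import path, the fake connection, and the bare constructor call are assumptions for illustration, not part of this commit:

# Hypothetical test usage of db_connection (names and paths assumed).
from snek.app import Application

class FakeDB:
    async def execute(self, query, *params):
        return []  # pretend every query returns no rows

def test_injected_db_is_preferred():
    # The real constructor may require more setup than shown here.
    app = Application(db_connection=FakeDB())
    # The new db property returns the injected connection when one is set.
    assert isinstance(app.db, FakeDB)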

View File

@@ -9,7 +9,6 @@ class LoadBalancer:
         self.backend_ports = backend_ports
         self.backend_processes = []
         self.client_counts = [0] * len(backend_ports)
-        self.lock = asyncio.Lock()

     async def start_backend_servers(self, port, workers):
         for x in range(workers):
@@ -29,7 +28,6 @@ class LoadBalancer:
         )

     async def handle_client(self, reader, writer):
-        async with self.lock:
         min_clients = min(self.client_counts)
         server_index = self.client_counts.index(min_clients)
         self.client_counts[server_index] += 1
@@ -57,7 +55,6 @@ class LoadBalancer:
             print(f"Error: {e}")
         finally:
             writer.close()
-            async with self.lock:
             self.client_counts[server_index] -= 1

     async def monitor(self):
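
Dropping the lock is sound here because asyncio schedules tasks cooperatively on a single thread: a stretch of code containing no await can never be interleaved with another task. A self-contained sketch of the same lock-free counting pattern (illustrative, not from this commit):

import asyncio

client_counts = [0, 0, 0]

async def handle_client():
    # No await between the read and the write, so under asyncio's
    # single-threaded scheduling this block is effectively atomic.
    idx = client_counts.index(min(client_counts))
    client_counts[idx] += 1

async def main():
    await asyncio.gather(*(handle_client() for _ in range(90)))
    print(client_counts)  # [30, 30, 30]: balanced without a lock

asyncio.run(main())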

View File

@@ -42,13 +42,11 @@ class SocketService(BaseService):
         self.user = user
         self.user_uid = safe_get(user, "uid") if user else None
         self.user_color = safe_get(user, "color") if user else None
-        self._lock = asyncio.Lock()
         self.subscribed_channels = set()

     async def send_json(self, data):
         if data is None:
             return False
-        async with self._lock:
         if not self.is_connected:
             return False
         if not self.ws:
@@ -66,7 +64,6 @@ class SocketService(BaseService):
         return False

     async def close(self):
-        async with self._lock:
         if not self.is_connected:
             return True
         self.is_connected = False
@@ -86,29 +83,12 @@ class SocketService(BaseService):
         self.users = {}
         self.subscriptions = {}
         self.last_update = str(datetime.now())
-        self._sockets_lock = asyncio.Lock()
-        self._user_locks = {}
-        self._channel_locks = {}
-        self._registry_lock = asyncio.Lock()
-
-    async def _get_user_lock(self, user_uid):
-        async with self._registry_lock:
-            if user_uid not in self._user_locks:
-                self._user_locks[user_uid] = asyncio.Lock()
-            return self._user_locks[user_uid]
-
-    async def _get_channel_lock(self, channel_uid):
-        async with self._registry_lock:
-            if channel_uid not in self._channel_locks:
-                self._channel_locks[channel_uid] = asyncio.Lock()
-            return self._channel_locks[channel_uid]

     async def user_availability_service(self):
         logger.info("User availability update service started.")
         while True:
             try:
-                async with self._sockets_lock:
                 sockets_copy = list(self.sockets)
                 users_to_update = []
                 seen_user_uids = set()
@@ -159,11 +139,8 @@ class SocketService(BaseService):
             username = safe_get(user, "username", "unknown")
             nick = safe_get(user, "nick") or username
             color = s.user_color
-            async with self._sockets_lock:
             self.sockets.add(s)
             is_first_connection = False
-            user_lock = await self._get_user_lock(user_uid)
-            async with user_lock:
             if user_uid not in self.users:
                 self.users[user_uid] = set()
                 is_first_connection = True
@@ -189,10 +166,7 @@ class SocketService(BaseService):
         if not ws or not channel_uid or not user_uid:
             return False
         try:
-            user_lock = await self._get_user_lock(user_uid)
-            channel_lock = await self._get_channel_lock(channel_uid)
             existing_socket = None
-            async with user_lock:
             user_sockets = self.users.get(user_uid, set())
             for sock in user_sockets:
                 if sock and sock.ws == ws:
@@ -201,7 +175,6 @@ class SocketService(BaseService):
             if not existing_socket:
                 return False
             existing_socket.subscribed_channels.add(channel_uid)
-            async with channel_lock:
             if channel_uid not in self.subscriptions:
                 self.subscriptions[channel_uid] = set()
             self.subscriptions[channel_uid].add(user_uid)
@@ -215,8 +188,6 @@ class SocketService(BaseService):
             return 0
         count = 0
         try:
-            user_lock = await self._get_user_lock(user_uid)
-            async with user_lock:
             user_sockets = list(self.users.get(user_uid, []))
             for s in user_sockets:
                 if not s:
@@ -252,8 +223,6 @@ class SocketService(BaseService):
         except Exception as ex:
             logger.warning(f"Broadcast db query failed: {safe_str(ex)}")
         if not user_uids_to_send:
-            channel_lock = await self._get_channel_lock(channel_uid)
-            async with channel_lock:
             if channel_uid in self.subscriptions:
                 user_uids_to_send = set(self.subscriptions[channel_uid])
                 logger.debug(f"_broadcast: using {len(user_uids_to_send)} users from subscriptions for channel={channel_uid}")
@@ -281,8 +250,6 @@ class SocketService(BaseService):
     async def delete(self, ws):
         if not ws:
             return
-        sockets_to_remove = []
-        async with self._sockets_lock:
         sockets_to_remove = [sock for sock in self.sockets if sock and sock.ws == ws]
         for s in sockets_to_remove:
             self.sockets.discard(s)
@@ -292,9 +259,7 @@ class SocketService(BaseService):
             user_uid = s.user_uid
             if not user_uid:
                 continue
-            user_lock = await self._get_user_lock(user_uid)
             is_last_connection = False
-            async with user_lock:
             if user_uid in self.users:
                 self.users[user_uid].discard(s)
                 if len(self.users[user_uid]) == 0:
@@ -313,8 +278,6 @@ class SocketService(BaseService):
                     channels_to_cleanup.add((channel_uid, user_uid))
         for channel_uid, user_uid in channels_to_cleanup:
             try:
-                channel_lock = await self._get_channel_lock(channel_uid)
-                async with channel_lock:
                 if channel_uid in self.subscriptions:
                     self.subscriptions[channel_uid].discard(user_uid)
                     if len(self.subscriptions[channel_uid]) == 0:
@@ -350,7 +313,6 @@ class SocketService(BaseService):
                 "timestamp": datetime.now().isoformat(),
             },
         }
-        async with self._sockets_lock:
         sockets_copy = list(self.sockets)
         send_tasks = []
         for s in sockets_copy:
@@ -376,8 +338,6 @@ class SocketService(BaseService):
         if not user_uid:
             return 0
         try:
-            user_lock = await self._get_user_lock(user_uid)
-            async with user_lock:
             return len(self.users.get(user_uid, []))
         except Exception:
             return 0
@@ -386,8 +346,6 @@ class SocketService(BaseService):
         if not user_uid:
             return False
         try:
-            user_lock = await self._get_user_lock(user_uid)
-            async with user_lock:
             user_sockets = self.users.get(user_uid, set())
             return any(s.is_connected for s in user_sockets if s)
         except Exception:
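
The same single-threaded reasoning covers the registry above: plain set and dict mutations need no lock when no await separates a check from its matching update, and snapshots such as list(self.sockets) keep iteration safe across awaits. A small sketch of that snapshot pattern (names are illustrative stand-ins for the service's attributes):

sockets = set()  # stand-in for SocketService.sockets

async def broadcast(payload):
    # Snapshot first: each send below may await, and another task could
    # add or remove sockets while this coroutine is suspended.
    for sock in list(sockets):
        try:
            await sock.send_json(payload)
        except Exception:
            sockets.discard(sock)  # drop sockets that fail to send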

View File

@@ -1,6 +1,5 @@
 # retoor <retoor@molodetz.nl>
-import asyncio
 import functools
 import json
 from collections import OrderedDict
@@ -10,93 +9,50 @@ CACHE_MAX_ITEMS_DEFAULT = 5000
 class Cache:
-    """
-    An asynchronous, thread-safe, in-memory LRU (Least Recently Used) cache.
-    This implementation uses an OrderedDict for efficient O(1) time complexity
-    for its core get, set, and delete operations.
-    """
     def __init__(self, app, max_items=CACHE_MAX_ITEMS_DEFAULT):
         self.app = app
-        # OrderedDict is the core of the LRU logic. It remembers the order
-        # in which items were inserted.
         self.cache: OrderedDict = OrderedDict()
         self.max_items = max_items
         self.stats = {}
         self.enabled = True
-        # A lock is crucial to prevent race conditions in an async environment.
-        self._lock = asyncio.Lock()
         self.version = ((42 + 420 + 1984 + 1990 + 10 + 6 + 71 + 3004 + 7245) ^ 1337) + 4

     async def get(self, key):
-        """
-        Retrieves an item from the cache. If found, it's marked as recently used.
-        Returns None if the item is not found or the cache is disabled.
-        """
         if not self.enabled:
             return None
-        #async with self._lock:
         if key not in self.cache:
             await self.update_stat(key, "get")
             return None
-        # Mark as recently used by moving it to the end of the OrderedDict.
-        # This is an O(1) operation.
         self.cache.move_to_end(key)
         await self.update_stat(key, "get")
         return self.cache[key]

     async def set(self, key, value):
-        """
-        Adds or updates an item in the cache and marks it as recently used.
-        If the cache exceeds its maximum size, the least recently used item is evicted.
-        """
         if not self.enabled:
             return
-        # comment
-        #async with self._lock:
         is_new = key not in self.cache
-        # Add or update the item. If it exists, it's moved to the end.
         self.cache[key] = value
         self.cache.move_to_end(key)
         await self.update_stat(key, "set")
-        # Evict the least recently used item if the cache is full.
-        # This is an O(1) operation.
         if len(self.cache) > self.max_items:
-            # popitem(last=False) removes and returns the first (oldest) item.
-            evicted_key, _ = self.cache.popitem(last=False)
-            # Optionally, you could log the evicted key here.
+            self.cache.popitem(last=False)
         if is_new:
             self.version += 1

     async def delete(self, key):
-        """Removes an item from the cache if it exists."""
         if not self.enabled:
             return
-        async with self._lock:
         if key in self.cache:
             await self.update_stat(key, "delete")
-            # Deleting from OrderedDict is an O(1) operation on average.
             del self.cache[key]

     async def get_stats(self):
-        """Returns statistics for all items currently in the cache."""
-        async with self._lock:
         stats_list = []
-        # Items are iterated from oldest to newest. We reverse to show
-        # most recently used items first.
         for key in reversed(self.cache):
             stat_data = self.stats.get(key, {"set": 0, "get": 0, "delete": 0})
             value = self.cache[key]
             value_record = value.record if hasattr(value, 'record') else value
             stats_list.append({
                 "key": key,
                 "set": stat_data.get("set", 0),
@@ -107,16 +63,11 @@ class Cache:
         return stats_list

     async def update_stat(self, key, action):
-        """Updates hit/miss/set counts for a given cache key."""
-        # This method is already called within a locked context,
-        # but the lock makes it safe if ever called directly.
-        async with self._lock:
         if key not in self.stats:
             self.stats[key] = {"set": 0, "get": 0, "delete": 0}
         self.stats[key][action] += 1

     def serialize(self, obj):
-        """A synchronous helper to create a serializable representation of an object."""
         if not isinstance(obj, dict):
             return obj
         cpy = obj.copy()
@@ -131,8 +82,6 @@ class Cache:
             return str(value)

     async def create_cache_key(self, args, kwargs):
-        """Creates a consistent, hashable cache key from function arguments."""
-        # security.hash is async, so this method remains async.
         return await security.hash(
             json.dumps(
                 {"args": args, "kwargs": kwargs},
@@ -142,7 +91,6 @@ class Cache:
         )

     def async_cache(self, func):
-        """Decorator to cache the results of an async function."""
         @functools.wraps(func)
         async def wrapper(*args, **kwargs):
             cache_key = await self.create_cache_key(args, kwargs)
@@ -155,7 +103,6 @@ class Cache:
         return wrapper

     def async_delete_cache(self, func):
-        """Decorator to invalidate a cache entry before running an async function."""
         @functools.wraps(func)
         async def wrapper(*args, **kwargs):
             cache_key = await self.create_cache_key(args, kwargs)