perf: remove async locks from balancer, socket service, and cache

feat: add db connection injection for testing in app
build: update pytest configuration
This commit is contained in:
retoor 2025-12-24 18:03:21 +01:00
parent 572d9610f2
commit 401e558011
7 changed files with 110 additions and 191 deletions

View File

@ -13,6 +13,14 @@
## Version 1.13.0 - 2025-12-24
Improves performance in the balancer, socket service, and cache by removing async locks. Adds database connection injection for testing in the app and updates pytest configuration.
**Changes:** 5 files, 291 lines
**Languages:** Other (6 lines), Python (285 lines)
## Version 1.12.0 - 2025-12-21
The socket service enhances concurrency and performance through improved locking and asynchronous handling. The RPC view renames the scheduled list to tasks for clearer API terminology.

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "Snek"
version = "1.12.0"
version = "1.13.0"
readme = "README.md"
#license = { file = "LICENSE", content-type="text/markdown" }
description = "Snek Chat Application by Molodetz"

View File

@ -1,11 +1,10 @@
[tool:pytest]
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts =
--strict-markers
--strict-config
--disable-warnings
--tb=short
-v
@ -14,4 +13,3 @@ markers =
integration: Integration tests
slow: Slow running tests
asyncio_mode = auto
asyncio_default_fixture_loop_scope = function

View File

@ -148,11 +148,12 @@ class Application(BaseApplication):
created_by_uid=admin_user["uid"],
)
def __init__(self, *args, **kwargs):
def __init__(self, *args, db_connection=None, **kwargs):
middlewares = [
cors_middleware,
csp_middleware,
]
self._test_db = db_connection
self.template_path = pathlib.Path(__file__).parent.joinpath("templates")
self.static_path = pathlib.Path(__file__).parent.joinpath("static")
super().__init__(
@ -196,6 +197,16 @@ class Application(BaseApplication):
self.on_startup.append(self.prepare_database)
self.on_startup.append(self.create_default_forum)
@property
def db(self):
if self._test_db is not None:
return self._test_db
return self._db
@db.setter
def db(self, value):
self._db = value
@property
def uptime_seconds(self):
return (datetime.now() - self.time_start).total_seconds()

View File

@ -9,7 +9,6 @@ class LoadBalancer:
self.backend_ports = backend_ports
self.backend_processes = []
self.client_counts = [0] * len(backend_ports)
self.lock = asyncio.Lock()
async def start_backend_servers(self, port, workers):
for x in range(workers):
@ -29,10 +28,9 @@ class LoadBalancer:
)
async def handle_client(self, reader, writer):
async with self.lock:
min_clients = min(self.client_counts)
server_index = self.client_counts.index(min_clients)
self.client_counts[server_index] += 1
min_clients = min(self.client_counts)
server_index = self.client_counts.index(min_clients)
self.client_counts[server_index] += 1
backend = ("127.0.0.1", self.backend_ports[server_index])
try:
backend_reader, backend_writer = await asyncio.open_connection(*backend)
@ -57,8 +55,7 @@ class LoadBalancer:
print(f"Error: {e}")
finally:
writer.close()
async with self.lock:
self.client_counts[server_index] -= 1
self.client_counts[server_index] -= 1
async def monitor(self):
while True:

View File

@ -42,42 +42,39 @@ class SocketService(BaseService):
self.user = user
self.user_uid = safe_get(user, "uid") if user else None
self.user_color = safe_get(user, "color") if user else None
self._lock = asyncio.Lock()
self.subscribed_channels = set()
async def send_json(self, data):
if data is None:
return False
async with self._lock:
if not self.is_connected:
return False
if not self.ws:
self.is_connected = False
return False
try:
await self.ws.send_json(data)
return True
except ConnectionResetError:
self.is_connected = False
logger.debug("Connection reset during send_json")
except Exception as ex:
self.is_connected = False
logger.debug(f"send_json failed: {safe_str(ex)}")
if not self.is_connected:
return False
if not self.ws:
self.is_connected = False
return False
try:
await self.ws.send_json(data)
return True
except ConnectionResetError:
self.is_connected = False
logger.debug("Connection reset during send_json")
except Exception as ex:
self.is_connected = False
logger.debug(f"send_json failed: {safe_str(ex)}")
return False
async def close(self):
async with self._lock:
if not self.is_connected:
return True
self.is_connected = False
self.subscribed_channels.clear()
try:
if self.ws and not self.ws.closed:
await asyncio.wait_for(self.ws.close(), timeout=5.0)
except asyncio.TimeoutError:
logger.debug("Socket close timed out")
except Exception as ex:
logger.debug(f"Socket close failed: {safe_str(ex)}")
if not self.is_connected:
return True
self.is_connected = False
self.subscribed_channels.clear()
try:
if self.ws and not self.ws.closed:
await asyncio.wait_for(self.ws.close(), timeout=5.0)
except asyncio.TimeoutError:
logger.debug("Socket close timed out")
except Exception as ex:
logger.debug(f"Socket close failed: {safe_str(ex)}")
return True
def __init__(self, app):
@ -86,30 +83,13 @@ class SocketService(BaseService):
self.users = {}
self.subscriptions = {}
self.last_update = str(datetime.now())
self._sockets_lock = asyncio.Lock()
self._user_locks = {}
self._channel_locks = {}
self._registry_lock = asyncio.Lock()
async def _get_user_lock(self, user_uid):
async with self._registry_lock:
if user_uid not in self._user_locks:
self._user_locks[user_uid] = asyncio.Lock()
return self._user_locks[user_uid]
async def _get_channel_lock(self, channel_uid):
async with self._registry_lock:
if channel_uid not in self._channel_locks:
self._channel_locks[channel_uid] = asyncio.Lock()
return self._channel_locks[channel_uid]
async def user_availability_service(self):
logger.info("User availability update service started.")
while True:
try:
async with self._sockets_lock:
sockets_copy = list(self.sockets)
sockets_copy = list(self.sockets)
users_to_update = []
seen_user_uids = set()
for s in sockets_copy:
@ -159,17 +139,14 @@ class SocketService(BaseService):
username = safe_get(user, "username", "unknown")
nick = safe_get(user, "nick") or username
color = s.user_color
async with self._sockets_lock:
self.sockets.add(s)
self.sockets.add(s)
is_first_connection = False
user_lock = await self._get_user_lock(user_uid)
async with user_lock:
if user_uid not in self.users:
self.users[user_uid] = set()
is_first_connection = True
elif len(self.users[user_uid]) == 0:
is_first_connection = True
self.users[user_uid].add(s)
if user_uid not in self.users:
self.users[user_uid] = set()
is_first_connection = True
elif len(self.users[user_uid]) == 0:
is_first_connection = True
self.users[user_uid].add(s)
try:
fresh_user = await self.app.services.user.get(uid=user_uid)
if fresh_user:
@ -189,22 +166,18 @@ class SocketService(BaseService):
if not ws or not channel_uid or not user_uid:
return False
try:
user_lock = await self._get_user_lock(user_uid)
channel_lock = await self._get_channel_lock(channel_uid)
existing_socket = None
async with user_lock:
user_sockets = self.users.get(user_uid, set())
for sock in user_sockets:
if sock and sock.ws == ws:
existing_socket = sock
break
if not existing_socket:
return False
existing_socket.subscribed_channels.add(channel_uid)
async with channel_lock:
if channel_uid not in self.subscriptions:
self.subscriptions[channel_uid] = set()
self.subscriptions[channel_uid].add(user_uid)
user_sockets = self.users.get(user_uid, set())
for sock in user_sockets:
if sock and sock.ws == ws:
existing_socket = sock
break
if not existing_socket:
return False
existing_socket.subscribed_channels.add(channel_uid)
if channel_uid not in self.subscriptions:
self.subscriptions[channel_uid] = set()
self.subscriptions[channel_uid].add(user_uid)
return True
except Exception as ex:
logger.warning(f"Failed to subscribe: {safe_str(ex)}")
@ -215,9 +188,7 @@ class SocketService(BaseService):
return 0
count = 0
try:
user_lock = await self._get_user_lock(user_uid)
async with user_lock:
user_sockets = list(self.users.get(user_uid, []))
user_sockets = list(self.users.get(user_uid, []))
for s in user_sockets:
if not s:
continue
@ -252,10 +223,8 @@ class SocketService(BaseService):
except Exception as ex:
logger.warning(f"Broadcast db query failed: {safe_str(ex)}")
if not user_uids_to_send:
channel_lock = await self._get_channel_lock(channel_uid)
async with channel_lock:
if channel_uid in self.subscriptions:
user_uids_to_send = set(self.subscriptions[channel_uid])
if channel_uid in self.subscriptions:
user_uids_to_send = set(self.subscriptions[channel_uid])
logger.debug(f"_broadcast: using {len(user_uids_to_send)} users from subscriptions for channel={channel_uid}")
send_tasks = []
for user_uid in user_uids_to_send:
@ -281,25 +250,21 @@ class SocketService(BaseService):
async def delete(self, ws):
if not ws:
return
sockets_to_remove = []
async with self._sockets_lock:
sockets_to_remove = [sock for sock in self.sockets if sock and sock.ws == ws]
for s in sockets_to_remove:
self.sockets.discard(s)
sockets_to_remove = [sock for sock in self.sockets if sock and sock.ws == ws]
for s in sockets_to_remove:
self.sockets.discard(s)
departures_to_broadcast = []
channels_to_cleanup = set()
for s in sockets_to_remove:
user_uid = s.user_uid
if not user_uid:
continue
user_lock = await self._get_user_lock(user_uid)
is_last_connection = False
async with user_lock:
if user_uid in self.users:
self.users[user_uid].discard(s)
if len(self.users[user_uid]) == 0:
del self.users[user_uid]
is_last_connection = True
if user_uid in self.users:
self.users[user_uid].discard(s)
if len(self.users[user_uid]) == 0:
del self.users[user_uid]
is_last_connection = True
if is_last_connection:
user_nick = None
try:
@ -313,12 +278,10 @@ class SocketService(BaseService):
channels_to_cleanup.add((channel_uid, user_uid))
for channel_uid, user_uid in channels_to_cleanup:
try:
channel_lock = await self._get_channel_lock(channel_uid)
async with channel_lock:
if channel_uid in self.subscriptions:
self.subscriptions[channel_uid].discard(user_uid)
if len(self.subscriptions[channel_uid]) == 0:
del self.subscriptions[channel_uid]
if channel_uid in self.subscriptions:
self.subscriptions[channel_uid].discard(user_uid)
if len(self.subscriptions[channel_uid]) == 0:
del self.subscriptions[channel_uid]
except Exception as ex:
logger.debug(f"Failed to cleanup channel subscription: {safe_str(ex)}")
for s in sockets_to_remove:
@ -350,8 +313,7 @@ class SocketService(BaseService):
"timestamp": datetime.now().isoformat(),
},
}
async with self._sockets_lock:
sockets_copy = list(self.sockets)
sockets_copy = list(self.sockets)
send_tasks = []
for s in sockets_copy:
if not s or not s.is_connected:
@ -376,9 +338,7 @@ class SocketService(BaseService):
if not user_uid:
return 0
try:
user_lock = await self._get_user_lock(user_uid)
async with user_lock:
return len(self.users.get(user_uid, []))
return len(self.users.get(user_uid, []))
except Exception:
return 0
@ -386,9 +346,7 @@ class SocketService(BaseService):
if not user_uid:
return False
try:
user_lock = await self._get_user_lock(user_uid)
async with user_lock:
user_sockets = self.users.get(user_uid, set())
return any(s.is_connected for s in user_sockets if s)
user_sockets = self.users.get(user_uid, set())
return any(s.is_connected for s in user_sockets if s)
except Exception:
return False

View File

@ -1,6 +1,5 @@
# retoor <retoor@molodetz.nl>
import asyncio
import functools
import json
from collections import OrderedDict
@ -10,113 +9,65 @@ CACHE_MAX_ITEMS_DEFAULT = 5000
class Cache:
"""
An asynchronous, thread-safe, in-memory LRU (Least Recently Used) cache.
This implementation uses an OrderedDict for efficient O(1) time complexity
for its core get, set, and delete operations.
"""
def __init__(self, app, max_items=CACHE_MAX_ITEMS_DEFAULT):
self.app = app
# OrderedDict is the core of the LRU logic. It remembers the order
# in which items were inserted.
self.cache: OrderedDict = OrderedDict()
self.max_items = max_items
self.stats = {}
self.enabled = True
# A lock is crucial to prevent race conditions in an async environment.
self._lock = asyncio.Lock()
self.version = ((42 + 420 + 1984 + 1990 + 10 + 6 + 71 + 3004 + 7245) ^ 1337) + 4
async def get(self, key):
"""
Retrieves an item from the cache. If found, it's marked as recently used.
Returns None if the item is not found or the cache is disabled.
"""
if not self.enabled:
return None
#async with self._lock:
if key not in self.cache:
await self.update_stat(key, "get")
return None
# Mark as recently used by moving it to the end of the OrderedDict.
# This is an O(1) operation.
self.cache.move_to_end(key)
await self.update_stat(key, "get")
return self.cache[key]
async def set(self, key, value):
"""
Adds or updates an item in the cache and marks it as recently used.
If the cache exceeds its maximum size, the least recently used item is evicted.
"""
if not self.enabled:
return
# comment
#async with self._lock:
is_new = key not in self.cache
# Add or update the item. If it exists, it's moved to the end.
self.cache[key] = value
self.cache.move_to_end(key)
await self.update_stat(key, "set")
# Evict the least recently used item if the cache is full.
# This is an O(1) operation.
if len(self.cache) > self.max_items:
# popitem(last=False) removes and returns the first (oldest) item.
evicted_key, _ = self.cache.popitem(last=False)
# Optionally, you could log the evicted key here.
self.cache.popitem(last=False)
if is_new:
self.version += 1
async def delete(self, key):
"""Removes an item from the cache if it exists."""
if not self.enabled:
return
async with self._lock:
if key in self.cache:
await self.update_stat(key, "delete")
# Deleting from OrderedDict is an O(1) operation on average.
del self.cache[key]
if key in self.cache:
await self.update_stat(key, "delete")
del self.cache[key]
async def get_stats(self):
"""Returns statistics for all items currently in the cache."""
async with self._lock:
stats_list = []
# Items are iterated from oldest to newest. We reverse to show
# most recently used items first.
for key in reversed(self.cache):
stat_data = self.stats.get(key, {"set": 0, "get": 0, "delete": 0})
value = self.cache[key]
value_record = value.record if hasattr(value, 'record') else value
stats_list.append({
"key": key,
"set": stat_data.get("set", 0),
"get": stat_data.get("get", 0),
"delete": stat_data.get("delete", 0),
"value": str(self.serialize(value_record)),
})
return stats_list
stats_list = []
for key in reversed(self.cache):
stat_data = self.stats.get(key, {"set": 0, "get": 0, "delete": 0})
value = self.cache[key]
value_record = value.record if hasattr(value, 'record') else value
stats_list.append({
"key": key,
"set": stat_data.get("set", 0),
"get": stat_data.get("get", 0),
"delete": stat_data.get("delete", 0),
"value": str(self.serialize(value_record)),
})
return stats_list
async def update_stat(self, key, action):
"""Updates hit/miss/set counts for a given cache key."""
# This method is already called within a locked context,
# but the lock makes it safe if ever called directly.
async with self._lock:
if key not in self.stats:
self.stats[key] = {"set": 0, "get": 0, "delete": 0}
self.stats[key][action] += 1
if key not in self.stats:
self.stats[key] = {"set": 0, "get": 0, "delete": 0}
self.stats[key][action] += 1
def serialize(self, obj):
"""A synchronous helper to create a serializable representation of an object."""
if not isinstance(obj, dict):
return obj
cpy = obj.copy()
@ -131,8 +82,6 @@ class Cache:
return str(value)
async def create_cache_key(self, args, kwargs):
"""Creates a consistent, hashable cache key from function arguments."""
# security.hash is async, so this method remains async.
return await security.hash(
json.dumps(
{"args": args, "kwargs": kwargs},
@ -142,7 +91,6 @@ class Cache:
)
def async_cache(self, func):
"""Decorator to cache the results of an async function."""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
cache_key = await self.create_cache_key(args, kwargs)
@ -155,7 +103,6 @@ class Cache:
return wrapper
def async_delete_cache(self, func):
"""Decorator to invalidate a cache entry before running an async function."""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
cache_key = await self.create_cache_key(args, kwargs)