perf: remove async locks from balancer, socket service, and cache
feat: add db connection injection for testing in app build: update pytest configuration
This commit is contained in:
parent
572d9610f2
commit
401e558011
@ -13,6 +13,14 @@
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Version 1.13.0 - 2025-12-24
|
||||||
|
|
||||||
|
Improves performance in the balancer, socket service, and cache by removing async locks. Adds database connection injection for testing in the app and updates pytest configuration.
|
||||||
|
|
||||||
|
**Changes:** 5 files, 291 lines
|
||||||
|
**Languages:** Other (6 lines), Python (285 lines)
|
||||||
|
|
||||||
## Version 1.12.0 - 2025-12-21
|
## Version 1.12.0 - 2025-12-21
|
||||||
|
|
||||||
The socket service enhances concurrency and performance through improved locking and asynchronous handling. The RPC view renames the scheduled list to tasks for clearer API terminology.
|
The socket service enhances concurrency and performance through improved locking and asynchronous handling. The RPC view renames the scheduled list to tasks for clearer API terminology.
|
||||||
|
|||||||
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "Snek"
|
name = "Snek"
|
||||||
version = "1.12.0"
|
version = "1.13.0"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
#license = { file = "LICENSE", content-type="text/markdown" }
|
#license = { file = "LICENSE", content-type="text/markdown" }
|
||||||
description = "Snek Chat Application by Molodetz"
|
description = "Snek Chat Application by Molodetz"
|
||||||
|
|||||||
@ -1,11 +1,10 @@
|
|||||||
[tool:pytest]
|
[pytest]
|
||||||
testpaths = tests
|
testpaths = tests
|
||||||
python_files = test_*.py
|
python_files = test_*.py
|
||||||
python_classes = Test*
|
python_classes = Test*
|
||||||
python_functions = test_*
|
python_functions = test_*
|
||||||
addopts =
|
addopts =
|
||||||
--strict-markers
|
--strict-markers
|
||||||
--strict-config
|
|
||||||
--disable-warnings
|
--disable-warnings
|
||||||
--tb=short
|
--tb=short
|
||||||
-v
|
-v
|
||||||
@ -14,4 +13,3 @@ markers =
|
|||||||
integration: Integration tests
|
integration: Integration tests
|
||||||
slow: Slow running tests
|
slow: Slow running tests
|
||||||
asyncio_mode = auto
|
asyncio_mode = auto
|
||||||
asyncio_default_fixture_loop_scope = function
|
|
||||||
@ -148,11 +148,12 @@ class Application(BaseApplication):
|
|||||||
created_by_uid=admin_user["uid"],
|
created_by_uid=admin_user["uid"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, db_connection=None, **kwargs):
|
||||||
middlewares = [
|
middlewares = [
|
||||||
cors_middleware,
|
cors_middleware,
|
||||||
csp_middleware,
|
csp_middleware,
|
||||||
]
|
]
|
||||||
|
self._test_db = db_connection
|
||||||
self.template_path = pathlib.Path(__file__).parent.joinpath("templates")
|
self.template_path = pathlib.Path(__file__).parent.joinpath("templates")
|
||||||
self.static_path = pathlib.Path(__file__).parent.joinpath("static")
|
self.static_path = pathlib.Path(__file__).parent.joinpath("static")
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -196,6 +197,16 @@ class Application(BaseApplication):
|
|||||||
self.on_startup.append(self.prepare_database)
|
self.on_startup.append(self.prepare_database)
|
||||||
self.on_startup.append(self.create_default_forum)
|
self.on_startup.append(self.create_default_forum)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def db(self):
|
||||||
|
if self._test_db is not None:
|
||||||
|
return self._test_db
|
||||||
|
return self._db
|
||||||
|
|
||||||
|
@db.setter
|
||||||
|
def db(self, value):
|
||||||
|
self._db = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def uptime_seconds(self):
|
def uptime_seconds(self):
|
||||||
return (datetime.now() - self.time_start).total_seconds()
|
return (datetime.now() - self.time_start).total_seconds()
|
||||||
|
|||||||
@ -9,7 +9,6 @@ class LoadBalancer:
|
|||||||
self.backend_ports = backend_ports
|
self.backend_ports = backend_ports
|
||||||
self.backend_processes = []
|
self.backend_processes = []
|
||||||
self.client_counts = [0] * len(backend_ports)
|
self.client_counts = [0] * len(backend_ports)
|
||||||
self.lock = asyncio.Lock()
|
|
||||||
|
|
||||||
async def start_backend_servers(self, port, workers):
|
async def start_backend_servers(self, port, workers):
|
||||||
for x in range(workers):
|
for x in range(workers):
|
||||||
@ -29,10 +28,9 @@ class LoadBalancer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def handle_client(self, reader, writer):
|
async def handle_client(self, reader, writer):
|
||||||
async with self.lock:
|
min_clients = min(self.client_counts)
|
||||||
min_clients = min(self.client_counts)
|
server_index = self.client_counts.index(min_clients)
|
||||||
server_index = self.client_counts.index(min_clients)
|
self.client_counts[server_index] += 1
|
||||||
self.client_counts[server_index] += 1
|
|
||||||
backend = ("127.0.0.1", self.backend_ports[server_index])
|
backend = ("127.0.0.1", self.backend_ports[server_index])
|
||||||
try:
|
try:
|
||||||
backend_reader, backend_writer = await asyncio.open_connection(*backend)
|
backend_reader, backend_writer = await asyncio.open_connection(*backend)
|
||||||
@ -57,8 +55,7 @@ class LoadBalancer:
|
|||||||
print(f"Error: {e}")
|
print(f"Error: {e}")
|
||||||
finally:
|
finally:
|
||||||
writer.close()
|
writer.close()
|
||||||
async with self.lock:
|
self.client_counts[server_index] -= 1
|
||||||
self.client_counts[server_index] -= 1
|
|
||||||
|
|
||||||
async def monitor(self):
|
async def monitor(self):
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
@ -42,42 +42,39 @@ class SocketService(BaseService):
|
|||||||
self.user = user
|
self.user = user
|
||||||
self.user_uid = safe_get(user, "uid") if user else None
|
self.user_uid = safe_get(user, "uid") if user else None
|
||||||
self.user_color = safe_get(user, "color") if user else None
|
self.user_color = safe_get(user, "color") if user else None
|
||||||
self._lock = asyncio.Lock()
|
|
||||||
self.subscribed_channels = set()
|
self.subscribed_channels = set()
|
||||||
|
|
||||||
async def send_json(self, data):
|
async def send_json(self, data):
|
||||||
if data is None:
|
if data is None:
|
||||||
return False
|
return False
|
||||||
async with self._lock:
|
if not self.is_connected:
|
||||||
if not self.is_connected:
|
|
||||||
return False
|
|
||||||
if not self.ws:
|
|
||||||
self.is_connected = False
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
await self.ws.send_json(data)
|
|
||||||
return True
|
|
||||||
except ConnectionResetError:
|
|
||||||
self.is_connected = False
|
|
||||||
logger.debug("Connection reset during send_json")
|
|
||||||
except Exception as ex:
|
|
||||||
self.is_connected = False
|
|
||||||
logger.debug(f"send_json failed: {safe_str(ex)}")
|
|
||||||
return False
|
return False
|
||||||
|
if not self.ws:
|
||||||
|
self.is_connected = False
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
await self.ws.send_json(data)
|
||||||
|
return True
|
||||||
|
except ConnectionResetError:
|
||||||
|
self.is_connected = False
|
||||||
|
logger.debug("Connection reset during send_json")
|
||||||
|
except Exception as ex:
|
||||||
|
self.is_connected = False
|
||||||
|
logger.debug(f"send_json failed: {safe_str(ex)}")
|
||||||
|
return False
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
async with self._lock:
|
if not self.is_connected:
|
||||||
if not self.is_connected:
|
return True
|
||||||
return True
|
self.is_connected = False
|
||||||
self.is_connected = False
|
self.subscribed_channels.clear()
|
||||||
self.subscribed_channels.clear()
|
try:
|
||||||
try:
|
if self.ws and not self.ws.closed:
|
||||||
if self.ws and not self.ws.closed:
|
await asyncio.wait_for(self.ws.close(), timeout=5.0)
|
||||||
await asyncio.wait_for(self.ws.close(), timeout=5.0)
|
except asyncio.TimeoutError:
|
||||||
except asyncio.TimeoutError:
|
logger.debug("Socket close timed out")
|
||||||
logger.debug("Socket close timed out")
|
except Exception as ex:
|
||||||
except Exception as ex:
|
logger.debug(f"Socket close failed: {safe_str(ex)}")
|
||||||
logger.debug(f"Socket close failed: {safe_str(ex)}")
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def __init__(self, app):
|
def __init__(self, app):
|
||||||
@ -86,30 +83,13 @@ class SocketService(BaseService):
|
|||||||
self.users = {}
|
self.users = {}
|
||||||
self.subscriptions = {}
|
self.subscriptions = {}
|
||||||
self.last_update = str(datetime.now())
|
self.last_update = str(datetime.now())
|
||||||
self._sockets_lock = asyncio.Lock()
|
|
||||||
self._user_locks = {}
|
|
||||||
self._channel_locks = {}
|
|
||||||
self._registry_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
async def _get_user_lock(self, user_uid):
|
|
||||||
async with self._registry_lock:
|
|
||||||
if user_uid not in self._user_locks:
|
|
||||||
self._user_locks[user_uid] = asyncio.Lock()
|
|
||||||
return self._user_locks[user_uid]
|
|
||||||
|
|
||||||
async def _get_channel_lock(self, channel_uid):
|
|
||||||
async with self._registry_lock:
|
|
||||||
if channel_uid not in self._channel_locks:
|
|
||||||
self._channel_locks[channel_uid] = asyncio.Lock()
|
|
||||||
return self._channel_locks[channel_uid]
|
|
||||||
|
|
||||||
|
|
||||||
async def user_availability_service(self):
|
async def user_availability_service(self):
|
||||||
logger.info("User availability update service started.")
|
logger.info("User availability update service started.")
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
async with self._sockets_lock:
|
sockets_copy = list(self.sockets)
|
||||||
sockets_copy = list(self.sockets)
|
|
||||||
users_to_update = []
|
users_to_update = []
|
||||||
seen_user_uids = set()
|
seen_user_uids = set()
|
||||||
for s in sockets_copy:
|
for s in sockets_copy:
|
||||||
@ -159,17 +139,14 @@ class SocketService(BaseService):
|
|||||||
username = safe_get(user, "username", "unknown")
|
username = safe_get(user, "username", "unknown")
|
||||||
nick = safe_get(user, "nick") or username
|
nick = safe_get(user, "nick") or username
|
||||||
color = s.user_color
|
color = s.user_color
|
||||||
async with self._sockets_lock:
|
self.sockets.add(s)
|
||||||
self.sockets.add(s)
|
|
||||||
is_first_connection = False
|
is_first_connection = False
|
||||||
user_lock = await self._get_user_lock(user_uid)
|
if user_uid not in self.users:
|
||||||
async with user_lock:
|
self.users[user_uid] = set()
|
||||||
if user_uid not in self.users:
|
is_first_connection = True
|
||||||
self.users[user_uid] = set()
|
elif len(self.users[user_uid]) == 0:
|
||||||
is_first_connection = True
|
is_first_connection = True
|
||||||
elif len(self.users[user_uid]) == 0:
|
self.users[user_uid].add(s)
|
||||||
is_first_connection = True
|
|
||||||
self.users[user_uid].add(s)
|
|
||||||
try:
|
try:
|
||||||
fresh_user = await self.app.services.user.get(uid=user_uid)
|
fresh_user = await self.app.services.user.get(uid=user_uid)
|
||||||
if fresh_user:
|
if fresh_user:
|
||||||
@ -189,22 +166,18 @@ class SocketService(BaseService):
|
|||||||
if not ws or not channel_uid or not user_uid:
|
if not ws or not channel_uid or not user_uid:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
user_lock = await self._get_user_lock(user_uid)
|
|
||||||
channel_lock = await self._get_channel_lock(channel_uid)
|
|
||||||
existing_socket = None
|
existing_socket = None
|
||||||
async with user_lock:
|
user_sockets = self.users.get(user_uid, set())
|
||||||
user_sockets = self.users.get(user_uid, set())
|
for sock in user_sockets:
|
||||||
for sock in user_sockets:
|
if sock and sock.ws == ws:
|
||||||
if sock and sock.ws == ws:
|
existing_socket = sock
|
||||||
existing_socket = sock
|
break
|
||||||
break
|
if not existing_socket:
|
||||||
if not existing_socket:
|
return False
|
||||||
return False
|
existing_socket.subscribed_channels.add(channel_uid)
|
||||||
existing_socket.subscribed_channels.add(channel_uid)
|
if channel_uid not in self.subscriptions:
|
||||||
async with channel_lock:
|
self.subscriptions[channel_uid] = set()
|
||||||
if channel_uid not in self.subscriptions:
|
self.subscriptions[channel_uid].add(user_uid)
|
||||||
self.subscriptions[channel_uid] = set()
|
|
||||||
self.subscriptions[channel_uid].add(user_uid)
|
|
||||||
return True
|
return True
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.warning(f"Failed to subscribe: {safe_str(ex)}")
|
logger.warning(f"Failed to subscribe: {safe_str(ex)}")
|
||||||
@ -215,9 +188,7 @@ class SocketService(BaseService):
|
|||||||
return 0
|
return 0
|
||||||
count = 0
|
count = 0
|
||||||
try:
|
try:
|
||||||
user_lock = await self._get_user_lock(user_uid)
|
user_sockets = list(self.users.get(user_uid, []))
|
||||||
async with user_lock:
|
|
||||||
user_sockets = list(self.users.get(user_uid, []))
|
|
||||||
for s in user_sockets:
|
for s in user_sockets:
|
||||||
if not s:
|
if not s:
|
||||||
continue
|
continue
|
||||||
@ -252,10 +223,8 @@ class SocketService(BaseService):
|
|||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.warning(f"Broadcast db query failed: {safe_str(ex)}")
|
logger.warning(f"Broadcast db query failed: {safe_str(ex)}")
|
||||||
if not user_uids_to_send:
|
if not user_uids_to_send:
|
||||||
channel_lock = await self._get_channel_lock(channel_uid)
|
if channel_uid in self.subscriptions:
|
||||||
async with channel_lock:
|
user_uids_to_send = set(self.subscriptions[channel_uid])
|
||||||
if channel_uid in self.subscriptions:
|
|
||||||
user_uids_to_send = set(self.subscriptions[channel_uid])
|
|
||||||
logger.debug(f"_broadcast: using {len(user_uids_to_send)} users from subscriptions for channel={channel_uid}")
|
logger.debug(f"_broadcast: using {len(user_uids_to_send)} users from subscriptions for channel={channel_uid}")
|
||||||
send_tasks = []
|
send_tasks = []
|
||||||
for user_uid in user_uids_to_send:
|
for user_uid in user_uids_to_send:
|
||||||
@ -281,25 +250,21 @@ class SocketService(BaseService):
|
|||||||
async def delete(self, ws):
|
async def delete(self, ws):
|
||||||
if not ws:
|
if not ws:
|
||||||
return
|
return
|
||||||
sockets_to_remove = []
|
sockets_to_remove = [sock for sock in self.sockets if sock and sock.ws == ws]
|
||||||
async with self._sockets_lock:
|
for s in sockets_to_remove:
|
||||||
sockets_to_remove = [sock for sock in self.sockets if sock and sock.ws == ws]
|
self.sockets.discard(s)
|
||||||
for s in sockets_to_remove:
|
|
||||||
self.sockets.discard(s)
|
|
||||||
departures_to_broadcast = []
|
departures_to_broadcast = []
|
||||||
channels_to_cleanup = set()
|
channels_to_cleanup = set()
|
||||||
for s in sockets_to_remove:
|
for s in sockets_to_remove:
|
||||||
user_uid = s.user_uid
|
user_uid = s.user_uid
|
||||||
if not user_uid:
|
if not user_uid:
|
||||||
continue
|
continue
|
||||||
user_lock = await self._get_user_lock(user_uid)
|
|
||||||
is_last_connection = False
|
is_last_connection = False
|
||||||
async with user_lock:
|
if user_uid in self.users:
|
||||||
if user_uid in self.users:
|
self.users[user_uid].discard(s)
|
||||||
self.users[user_uid].discard(s)
|
if len(self.users[user_uid]) == 0:
|
||||||
if len(self.users[user_uid]) == 0:
|
del self.users[user_uid]
|
||||||
del self.users[user_uid]
|
is_last_connection = True
|
||||||
is_last_connection = True
|
|
||||||
if is_last_connection:
|
if is_last_connection:
|
||||||
user_nick = None
|
user_nick = None
|
||||||
try:
|
try:
|
||||||
@ -313,12 +278,10 @@ class SocketService(BaseService):
|
|||||||
channels_to_cleanup.add((channel_uid, user_uid))
|
channels_to_cleanup.add((channel_uid, user_uid))
|
||||||
for channel_uid, user_uid in channels_to_cleanup:
|
for channel_uid, user_uid in channels_to_cleanup:
|
||||||
try:
|
try:
|
||||||
channel_lock = await self._get_channel_lock(channel_uid)
|
if channel_uid in self.subscriptions:
|
||||||
async with channel_lock:
|
self.subscriptions[channel_uid].discard(user_uid)
|
||||||
if channel_uid in self.subscriptions:
|
if len(self.subscriptions[channel_uid]) == 0:
|
||||||
self.subscriptions[channel_uid].discard(user_uid)
|
del self.subscriptions[channel_uid]
|
||||||
if len(self.subscriptions[channel_uid]) == 0:
|
|
||||||
del self.subscriptions[channel_uid]
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.debug(f"Failed to cleanup channel subscription: {safe_str(ex)}")
|
logger.debug(f"Failed to cleanup channel subscription: {safe_str(ex)}")
|
||||||
for s in sockets_to_remove:
|
for s in sockets_to_remove:
|
||||||
@ -350,8 +313,7 @@ class SocketService(BaseService):
|
|||||||
"timestamp": datetime.now().isoformat(),
|
"timestamp": datetime.now().isoformat(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
async with self._sockets_lock:
|
sockets_copy = list(self.sockets)
|
||||||
sockets_copy = list(self.sockets)
|
|
||||||
send_tasks = []
|
send_tasks = []
|
||||||
for s in sockets_copy:
|
for s in sockets_copy:
|
||||||
if not s or not s.is_connected:
|
if not s or not s.is_connected:
|
||||||
@ -376,9 +338,7 @@ class SocketService(BaseService):
|
|||||||
if not user_uid:
|
if not user_uid:
|
||||||
return 0
|
return 0
|
||||||
try:
|
try:
|
||||||
user_lock = await self._get_user_lock(user_uid)
|
return len(self.users.get(user_uid, []))
|
||||||
async with user_lock:
|
|
||||||
return len(self.users.get(user_uid, []))
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@ -386,9 +346,7 @@ class SocketService(BaseService):
|
|||||||
if not user_uid:
|
if not user_uid:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
user_lock = await self._get_user_lock(user_uid)
|
user_sockets = self.users.get(user_uid, set())
|
||||||
async with user_lock:
|
return any(s.is_connected for s in user_sockets if s)
|
||||||
user_sockets = self.users.get(user_uid, set())
|
|
||||||
return any(s.is_connected for s in user_sockets if s)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
# retoor <retoor@molodetz.nl>
|
# retoor <retoor@molodetz.nl>
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
@ -10,113 +9,65 @@ CACHE_MAX_ITEMS_DEFAULT = 5000
|
|||||||
|
|
||||||
|
|
||||||
class Cache:
|
class Cache:
|
||||||
"""
|
|
||||||
An asynchronous, thread-safe, in-memory LRU (Least Recently Used) cache.
|
|
||||||
|
|
||||||
This implementation uses an OrderedDict for efficient O(1) time complexity
|
|
||||||
for its core get, set, and delete operations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, app, max_items=CACHE_MAX_ITEMS_DEFAULT):
|
def __init__(self, app, max_items=CACHE_MAX_ITEMS_DEFAULT):
|
||||||
self.app = app
|
self.app = app
|
||||||
# OrderedDict is the core of the LRU logic. It remembers the order
|
|
||||||
# in which items were inserted.
|
|
||||||
self.cache: OrderedDict = OrderedDict()
|
self.cache: OrderedDict = OrderedDict()
|
||||||
self.max_items = max_items
|
self.max_items = max_items
|
||||||
self.stats = {}
|
self.stats = {}
|
||||||
self.enabled = True
|
self.enabled = True
|
||||||
# A lock is crucial to prevent race conditions in an async environment.
|
|
||||||
self._lock = asyncio.Lock()
|
|
||||||
self.version = ((42 + 420 + 1984 + 1990 + 10 + 6 + 71 + 3004 + 7245) ^ 1337) + 4
|
self.version = ((42 + 420 + 1984 + 1990 + 10 + 6 + 71 + 3004 + 7245) ^ 1337) + 4
|
||||||
|
|
||||||
async def get(self, key):
|
async def get(self, key):
|
||||||
"""
|
|
||||||
Retrieves an item from the cache. If found, it's marked as recently used.
|
|
||||||
Returns None if the item is not found or the cache is disabled.
|
|
||||||
"""
|
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
#async with self._lock:
|
|
||||||
if key not in self.cache:
|
if key not in self.cache:
|
||||||
await self.update_stat(key, "get")
|
await self.update_stat(key, "get")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Mark as recently used by moving it to the end of the OrderedDict.
|
|
||||||
# This is an O(1) operation.
|
|
||||||
self.cache.move_to_end(key)
|
self.cache.move_to_end(key)
|
||||||
await self.update_stat(key, "get")
|
await self.update_stat(key, "get")
|
||||||
return self.cache[key]
|
return self.cache[key]
|
||||||
|
|
||||||
async def set(self, key, value):
|
async def set(self, key, value):
|
||||||
"""
|
|
||||||
Adds or updates an item in the cache and marks it as recently used.
|
|
||||||
If the cache exceeds its maximum size, the least recently used item is evicted.
|
|
||||||
"""
|
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return
|
return
|
||||||
# comment
|
|
||||||
#async with self._lock:
|
|
||||||
is_new = key not in self.cache
|
is_new = key not in self.cache
|
||||||
|
|
||||||
# Add or update the item. If it exists, it's moved to the end.
|
|
||||||
self.cache[key] = value
|
self.cache[key] = value
|
||||||
self.cache.move_to_end(key)
|
self.cache.move_to_end(key)
|
||||||
|
|
||||||
await self.update_stat(key, "set")
|
await self.update_stat(key, "set")
|
||||||
|
|
||||||
# Evict the least recently used item if the cache is full.
|
|
||||||
# This is an O(1) operation.
|
|
||||||
if len(self.cache) > self.max_items:
|
if len(self.cache) > self.max_items:
|
||||||
# popitem(last=False) removes and returns the first (oldest) item.
|
self.cache.popitem(last=False)
|
||||||
evicted_key, _ = self.cache.popitem(last=False)
|
|
||||||
# Optionally, you could log the evicted key here.
|
|
||||||
|
|
||||||
if is_new:
|
if is_new:
|
||||||
self.version += 1
|
self.version += 1
|
||||||
|
|
||||||
async def delete(self, key):
|
async def delete(self, key):
|
||||||
"""Removes an item from the cache if it exists."""
|
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return
|
return
|
||||||
|
if key in self.cache:
|
||||||
async with self._lock:
|
await self.update_stat(key, "delete")
|
||||||
if key in self.cache:
|
del self.cache[key]
|
||||||
await self.update_stat(key, "delete")
|
|
||||||
# Deleting from OrderedDict is an O(1) operation on average.
|
|
||||||
del self.cache[key]
|
|
||||||
|
|
||||||
async def get_stats(self):
|
async def get_stats(self):
|
||||||
"""Returns statistics for all items currently in the cache."""
|
stats_list = []
|
||||||
async with self._lock:
|
for key in reversed(self.cache):
|
||||||
stats_list = []
|
stat_data = self.stats.get(key, {"set": 0, "get": 0, "delete": 0})
|
||||||
# Items are iterated from oldest to newest. We reverse to show
|
value = self.cache[key]
|
||||||
# most recently used items first.
|
value_record = value.record if hasattr(value, 'record') else value
|
||||||
for key in reversed(self.cache):
|
stats_list.append({
|
||||||
stat_data = self.stats.get(key, {"set": 0, "get": 0, "delete": 0})
|
"key": key,
|
||||||
value = self.cache[key]
|
"set": stat_data.get("set", 0),
|
||||||
value_record = value.record if hasattr(value, 'record') else value
|
"get": stat_data.get("get", 0),
|
||||||
|
"delete": stat_data.get("delete", 0),
|
||||||
stats_list.append({
|
"value": str(self.serialize(value_record)),
|
||||||
"key": key,
|
})
|
||||||
"set": stat_data.get("set", 0),
|
return stats_list
|
||||||
"get": stat_data.get("get", 0),
|
|
||||||
"delete": stat_data.get("delete", 0),
|
|
||||||
"value": str(self.serialize(value_record)),
|
|
||||||
})
|
|
||||||
return stats_list
|
|
||||||
|
|
||||||
async def update_stat(self, key, action):
|
async def update_stat(self, key, action):
|
||||||
"""Updates hit/miss/set counts for a given cache key."""
|
if key not in self.stats:
|
||||||
# This method is already called within a locked context,
|
self.stats[key] = {"set": 0, "get": 0, "delete": 0}
|
||||||
# but the lock makes it safe if ever called directly.
|
self.stats[key][action] += 1
|
||||||
async with self._lock:
|
|
||||||
if key not in self.stats:
|
|
||||||
self.stats[key] = {"set": 0, "get": 0, "delete": 0}
|
|
||||||
self.stats[key][action] += 1
|
|
||||||
|
|
||||||
def serialize(self, obj):
|
def serialize(self, obj):
|
||||||
"""A synchronous helper to create a serializable representation of an object."""
|
|
||||||
if not isinstance(obj, dict):
|
if not isinstance(obj, dict):
|
||||||
return obj
|
return obj
|
||||||
cpy = obj.copy()
|
cpy = obj.copy()
|
||||||
@ -131,8 +82,6 @@ class Cache:
|
|||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
async def create_cache_key(self, args, kwargs):
|
async def create_cache_key(self, args, kwargs):
|
||||||
"""Creates a consistent, hashable cache key from function arguments."""
|
|
||||||
# security.hash is async, so this method remains async.
|
|
||||||
return await security.hash(
|
return await security.hash(
|
||||||
json.dumps(
|
json.dumps(
|
||||||
{"args": args, "kwargs": kwargs},
|
{"args": args, "kwargs": kwargs},
|
||||||
@ -142,7 +91,6 @@ class Cache:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def async_cache(self, func):
|
def async_cache(self, func):
|
||||||
"""Decorator to cache the results of an async function."""
|
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(*args, **kwargs):
|
async def wrapper(*args, **kwargs):
|
||||||
cache_key = await self.create_cache_key(args, kwargs)
|
cache_key = await self.create_cache_key(args, kwargs)
|
||||||
@ -155,7 +103,6 @@ class Cache:
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
def async_delete_cache(self, func):
|
def async_delete_cache(self, func):
|
||||||
"""Decorator to invalidate a cache entry before running an async function."""
|
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(*args, **kwargs):
|
async def wrapper(*args, **kwargs):
|
||||||
cache_key = await self.create_cache_key(args, kwargs)
|
cache_key = await self.create_cache_key(args, kwargs)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user