diff --git a/src/snek/system/cache.py b/src/snek/system/cache.py index 8f8cdc3..9bbed53 100644 --- a/src/snek/system/cache.py +++ b/src/snek/system/cache.py @@ -1,52 +1,93 @@ import functools import json +import asyncio +from collections import OrderedDict # Use OrderedDict for O(1) LRU management +# Assuming snek.system.security exists and security.hash is an async function from snek.system import security - -cache = functools.cache - CACHE_MAX_ITEMS_DEFAULT = 5000 class Cache: def __init__(self, app, max_items=CACHE_MAX_ITEMS_DEFAULT): self.app = app - self.cache = {} + self.cache = OrderedDict() self.max_items = max_items self.stats = {} - self.enabled = False - self.lru = [] + self.enabled = True + self.lru = [] + self._lock = asyncio.Lock() self.version = ((42 + 420 + 1984 + 1990 + 10 + 6 + 71 + 3004 + 7245) ^ 1337) + 4 + async def get(self, args): - if not self.enabled: - return None - await self.update_stat(args, "get") - try: - self.lru.pop(self.lru.index(args)) - except: - # print("Cache miss!", args, flush=True) - return None - self.lru.insert(0, args) - while len(self.lru) > self.max_items: - self.cache.pop(self.lru[-1]) - self.lru.pop() - # print("Cache hit!", args, flush=True) - return self.cache[args] + async with self._lock: + if not self.enabled: + return None + + if args not in self.cache: + await self.update_stat(args, "get") + return None + + await self.update_stat(args, "get") + + value = self.cache.pop(args) + self.cache[args] = value + + return value + + async def set(self, args, result): + + async with self._lock: + if not self.enabled: + return + + is_new = args not in self.cache + + self.cache[args] = result + + self.cache.move_to_end(args) + + await self.update_stat(args, "set") + + if len(self.cache) > self.max_items: + evicted_key, _ = self.cache.popitem(last=False) + + if is_new: + self.version += 1 + + async def delete(self, args): + async with self._lock: + if not self.enabled: + return + + if args in self.cache: + await self.update_stat(args, "delete") + del self.cache[args] async def get_stats(self): - all_ = [] - for key in self.lru: - all_.append( - { - "key": key, - "set": self.stats[key]["set"], - "get": self.stats[key]["get"], - "delete": self.stats[key]["delete"], - "value": str(self.serialize(self.cache[key].record)), - } - ) - return all_ + async with self._lock: + all_ = [] + lru_keys = list(self.cache.keys()) + lru_keys.reverse() + self.lru = lru_keys + + for key in self.lru: + if key not in self.stats: + self.stats[key] = {"set": 0, "get": 0, "delete": 0} + + if key in self.cache: + value_record = self.cache[key].record if hasattr(self.cache.get(key), 'record') else self.cache[key] + all_.append( + { + "key": key, + "set": self.stats[key]["set"], + "get": self.stats[key]["get"], + "delete": self.stats[key]["delete"], + "value": str(self.serialize(value_record)), + } + ) + return all_ def serialize(self, obj): cpy = obj.copy() @@ -57,18 +98,18 @@ class Cache: return cpy async def update_stat(self, key, action): - if key not in self.stats: - self.stats[key] = {"set": 0, "get": 0, "delete": 0} - self.stats[key][action] = self.stats[key][action] + 1 - + async with self._lock: + if key not in self.stats: + self.stats[key] = {"set": 0, "get": 0, "delete": 0} + self.stats[key][action] = self.stats[key][action] + 1 + def json_default(self, value): - # if hasattr(value, "to_json"): - # return value.to_json() try: return json.dumps(value.__dict__, default=str) except: return str(value) + # Retained async due to the call to await security.hash() async def create_cache_key(self, args, kwargs): return await security.hash( json.dumps( @@ -78,37 +119,6 @@ class Cache: ) ) - async def set(self, args, result): - if not self.enabled: - return - is_new = args not in self.cache - self.cache[args] = result - await self.update_stat(args, "set") - try: - self.lru.pop(self.lru.index(args)) - except (ValueError, IndexError): - pass - self.lru.insert(0, args) - - while len(self.lru) > self.max_items: - self.cache.pop(self.lru[-1]) - self.lru.pop() - - if is_new: - self.version += 1 - # print(f"Cache store! {len(self.lru)} items. New version:", self.version, flush=True) - - async def delete(self, args): - if not self.enabled: - return - await self.update_stat(args, "delete") - if args in self.cache: - try: - self.lru.pop(self.lru.index(args)) - except IndexError: - pass - del self.cache[args] - def async_cache(self, func): @functools.wraps(func) async def wrapper(*args, **kwargs): @@ -119,21 +129,15 @@ class Cache: result = await func(*args, **kwargs) await self.set(cache_key, result) return result - return wrapper def async_delete_cache(self, func): @functools.wraps(func) async def wrapper(*args, **kwargs): cache_key = await self.create_cache_key(args, kwargs) - if cache_key in self.cache: - try: - self.lru.pop(self.lru.index(cache_key)) - except IndexError: - pass - del self.cache[cache_key] + # Use the fixed self.delete method + await self.delete(cache_key) return await func(*args, **kwargs) - return wrapper @@ -149,3 +153,4 @@ def async_cache(func): return result return wrapper +