import hashlib
import json
import sqlite3
import time
from contextlib import contextmanager
from typing import Optional, Any, Iterator, Set


class ToolCache:
    """SQLite-backed cache for the results of deterministic tool calls.

    Entries are keyed by a SHA-256 digest of the tool name plus its
    JSON-serialized (key-sorted) arguments, stored as JSON text, and are
    valid for ``ttl_seconds`` after they were written. Only tools listed in
    ``DETERMINISTIC_TOOLS`` are cached; ``get``/``set`` ignore all others.

    Every method opens its own short-lived connection, so no connection is
    held between calls. NOTE: as a consequence ``db_path=":memory:"`` does
    not work -- each new connection to ``":memory:"`` is a separate, empty
    database that lacks the table created in ``__init__``.
    """

    # Tool names whose results may be cached; ``is_cacheable`` only checks
    # membership in this set.
    DETERMINISTIC_TOOLS: Set[str] = {
        'read_file',
        'list_directory',
        'get_current_directory',
        'db_get',
        'db_query',
        'index_directory',
    }

    def __init__(self, db_path: str, ttl_seconds: int = 300):
        """Create the cache and make sure its table and indexes exist.

        Args:
            db_path: Filesystem path of the SQLite database.
            ttl_seconds: Lifetime of each entry, in seconds, counted from
                the moment it is stored.
        """
        self.db_path = db_path
        self.ttl_seconds = ttl_seconds
        self._initialize_cache()

    @contextmanager
    def _connect(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection that commits on success and rolls back on error.

        The connection is closed in ``finally`` so an exception raised by a
        query (or by the caller's code inside the ``with``) can never leak
        it or leave a write transaction holding the database lock.
        """
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        try:
            # ``with conn`` commits/rolls back but does NOT close the connection.
            with conn:
                yield conn
        finally:
            conn.close()

    def _initialize_cache(self) -> None:
        """Create the ``tool_cache`` table and its indexes if they are missing."""
        with self._connect() as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS tool_cache (
                    cache_key TEXT PRIMARY KEY,
                    tool_name TEXT NOT NULL,
                    result_data TEXT NOT NULL,
                    created_at INTEGER NOT NULL,
                    expires_at INTEGER NOT NULL,
                    hit_count INTEGER DEFAULT 0
                )
            ''')
            conn.execute('''
                CREATE INDEX IF NOT EXISTS idx_tool_expires ON tool_cache(expires_at)
            ''')
            conn.execute('''
                CREATE INDEX IF NOT EXISTS idx_tool_name ON tool_cache(tool_name)
            ''')

    def _generate_cache_key(self, tool_name: str, arguments: dict) -> str:
        """Return the SHA-256 hex digest identifying ``(tool_name, arguments)``.

        ``sort_keys=True`` makes the key independent of dict insertion order.

        Raises:
            TypeError: if ``arguments`` is not JSON-serializable.
        """
        cache_data = {
            'tool': tool_name,
            'args': arguments,
        }
        serialized = json.dumps(cache_data, sort_keys=True)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def is_cacheable(self, tool_name: str) -> bool:
        """Return True if results of ``tool_name`` may be cached."""
        return tool_name in self.DETERMINISTIC_TOOLS

    def get(self, tool_name: str, arguments: dict) -> Optional[Any]:
        """Return the cached, unexpired result for this call, or None.

        A successful lookup increments the entry's ``hit_count``. Note that
        a cached result that is itself JSON ``null`` is indistinguishable
        from a miss, because both return None.
        """
        if not self.is_cacheable(tool_name):
            return None

        cache_key = self._generate_cache_key(tool_name, arguments)
        current_time = int(time.time())

        with self._connect() as conn:
            row = conn.execute('''
                SELECT result_data FROM tool_cache
                WHERE cache_key = ? AND expires_at > ?
            ''', (cache_key, current_time)).fetchone()
            if row is None:
                return None
            conn.execute('''
                UPDATE tool_cache SET hit_count = hit_count + 1
                WHERE cache_key = ?
            ''', (cache_key,))

        return json.loads(row[0])

    def set(self, tool_name: str, arguments: dict, result: Any) -> None:
        """Store ``result`` for this call, replacing any existing entry.

        The entry expires ``ttl_seconds`` from now and its ``hit_count`` is
        reset to 0. Calls for non-cacheable tools are ignored.

        Raises:
            TypeError: if ``arguments`` or ``result`` is not JSON-serializable.
        """
        if not self.is_cacheable(tool_name):
            return

        cache_key = self._generate_cache_key(tool_name, arguments)
        # Serialize before touching the database so a bad result never
        # opens a connection.
        result_data = json.dumps(result)
        current_time = int(time.time())
        expires_at = current_time + self.ttl_seconds

        with self._connect() as conn:
            conn.execute('''
                INSERT OR REPLACE INTO tool_cache
                (cache_key, tool_name, result_data, created_at, expires_at, hit_count)
                VALUES (?, ?, ?, ?, ?, 0)
            ''', (cache_key, tool_name, result_data, current_time, expires_at))

    def invalidate_tool(self, tool_name: str) -> int:
        """Delete every entry for ``tool_name``; return the number deleted."""
        with self._connect() as conn:
            cursor = conn.execute(
                'DELETE FROM tool_cache WHERE tool_name = ?', (tool_name,)
            )
            return cursor.rowcount

    def clear_expired(self) -> int:
        """Delete entries whose expiry time has passed; return the number deleted."""
        current_time = int(time.time())
        with self._connect() as conn:
            cursor = conn.execute(
                'DELETE FROM tool_cache WHERE expires_at <= ?', (current_time,)
            )
            return cursor.rowcount

    def clear_all(self) -> int:
        """Delete every cache entry; return the number deleted."""
        with self._connect() as conn:
            cursor = conn.execute('DELETE FROM tool_cache')
            return cursor.rowcount

    def get_statistics(self) -> dict:
        """Return entry and hit counts for the cache.

        Returns:
            A dict with ``total_entries``, ``valid_entries``,
            ``expired_entries``, ``total_cache_hits`` (summed over valid
            entries only) and ``by_tool``, which maps each tool name to its
            valid ``cached_entries`` and ``total_hits``.
        """
        # One timestamp for every query so the figures are mutually consistent.
        current_time = int(time.time())

        with self._connect() as conn:
            total_entries = conn.execute(
                'SELECT COUNT(*) FROM tool_cache'
            ).fetchone()[0]
            valid_entries = conn.execute(
                'SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?',
                (current_time,),
            ).fetchone()[0]
            # SUM() yields NULL (None) when no rows match.
            total_hits = conn.execute(
                'SELECT SUM(hit_count) FROM tool_cache WHERE expires_at > ?',
                (current_time,),
            ).fetchone()[0] or 0
            tool_rows = conn.execute('''
                SELECT tool_name, COUNT(*), SUM(hit_count)
                FROM tool_cache
                WHERE expires_at > ?
                GROUP BY tool_name
            ''', (current_time,)).fetchall()

        tool_stats = {
            name: {'cached_entries': count, 'total_hits': hits or 0}
            for name, count, hits in tool_rows
        }

        return {
            'total_entries': total_entries,
            'valid_entries': valid_entries,
            'expired_entries': total_entries - valid_entries,
            'total_cache_hits': total_hits,
            'by_tool': tool_stats,
        }