import hashlib
import json
import sqlite3
import time
from contextlib import closing
from typing import Any, Optional, Set


class ToolCache:
    """SQLite-backed cache for the results of selected tool calls.

    Entries are keyed by a SHA-256 digest of the tool name and its
    arguments, stored as JSON text, and considered valid until their
    ``expires_at`` timestamp (creation time plus ``ttl_seconds``).
    Only tools whose name is listed in ``DETERMINISTIC_TOOLS`` are read
    from or written to the cache.
    """

    # Names of the tools whose results may be cached.
    DETERMINISTIC_TOOLS: Set[str] = {
        "read_file",
        "list_directory",
        "get_current_directory",
        "db_get",
        "db_query",
        "index_directory",
        "http_fetch",
        "web_search",
        "web_search_news",
        "search_knowledge",
        "get_knowledge_entry",
        "get_knowledge_by_category",
        "get_knowledge_statistics",
    }

    def __init__(self, db_path: str, ttl_seconds: int = 300):
        """Remember the settings and make sure the cache table exists.

        Args:
            db_path: Path handed to ``sqlite3.connect``.
            ttl_seconds: Lifetime, in seconds, given to entries written
                by ``set``.
        """
        self.db_path = db_path
        self.ttl_seconds = ttl_seconds
        self._initialize_cache()

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection to the cache database."""
        return sqlite3.connect(self.db_path, check_same_thread=False)

    def _initialize_cache(self) -> None:
        """Create the ``tool_cache`` table and its indexes if missing."""
        # closing() guarantees the connection is released even when a
        # statement raises; sqlite3's own context manager only manages
        # the transaction, it does not close the connection.
        with closing(self._connect()) as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS tool_cache (
                    cache_key TEXT PRIMARY KEY,
                    tool_name TEXT NOT NULL,
                    result_data TEXT NOT NULL,
                    created_at INTEGER NOT NULL,
                    expires_at INTEGER NOT NULL,
                    hit_count INTEGER DEFAULT 0
                )
                """
            )
            cursor.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_tool_expires
                ON tool_cache(expires_at)
                """
            )
            cursor.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_tool_name
                ON tool_cache(tool_name)
                """
            )
            conn.commit()

    def _generate_cache_key(self, tool_name: str, arguments: dict) -> str:
        """Return the SHA-256 hex digest identifying a tool call.

        The arguments are serialized with sorted keys, so two dicts with
        the same content produce the same key regardless of key order.

        Raises:
            TypeError: If ``arguments`` is not JSON-serializable.
        """
        cache_data = {"tool": tool_name, "args": arguments}
        serialized = json.dumps(cache_data, sort_keys=True)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def is_cacheable(self, tool_name: str) -> bool:
        """Return True when ``tool_name`` is in ``DETERMINISTIC_TOOLS``."""
        return tool_name in self.DETERMINISTIC_TOOLS

    def get(self, tool_name: str, arguments: dict) -> Optional[Any]:
        """Return the cached result for a tool call, or None on a miss.

        A miss is reported when the tool is not cacheable, when no entry
        exists, or when the entry has expired. A hit increments the
        entry's ``hit_count``.

        Note that a cached JSON ``null`` is indistinguishable from a
        miss, because both are returned as None.
        """
        if not self.is_cacheable(tool_name):
            return None

        cache_key = self._generate_cache_key(tool_name, arguments)
        current_time = int(time.time())

        with closing(self._connect()) as conn:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT result_data FROM tool_cache
                WHERE cache_key = ? AND expires_at > ?
                """,
                (cache_key, current_time),
            )
            row = cursor.fetchone()
            if row is None:
                return None

            cursor.execute(
                """
                UPDATE tool_cache SET hit_count = hit_count + 1
                WHERE cache_key = ?
                """,
                (cache_key,),
            )
            conn.commit()

        return json.loads(row[0])

    def set(self, tool_name: str, arguments: dict, result: Any) -> None:
        """Store ``result`` for a tool call; no-op for uncacheable tools.

        An existing entry with the same key is replaced and its hit
        count reset to zero.

        Raises:
            TypeError: If ``arguments`` or ``result`` is not
                JSON-serializable.
        """
        if not self.is_cacheable(tool_name):
            return

        cache_key = self._generate_cache_key(tool_name, arguments)
        # Serialize before touching the database so an unserializable
        # result fails without ever opening a connection.
        result_data = json.dumps(result)
        current_time = int(time.time())
        expires_at = current_time + self.ttl_seconds

        with closing(self._connect()) as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO tool_cache
                (cache_key, tool_name, result_data, created_at, expires_at, hit_count)
                VALUES (?, ?, ?, ?, ?, 0)
                """,
                (cache_key, tool_name, result_data, current_time, expires_at),
            )
            conn.commit()

    def invalidate_tool(self, tool_name: str) -> int:
        """Delete every entry of ``tool_name``; return the rows removed."""
        with closing(self._connect()) as conn:
            cursor = conn.cursor()
            cursor.execute(
                "DELETE FROM tool_cache WHERE tool_name = ?", (tool_name,)
            )
            deleted_count = cursor.rowcount
            conn.commit()
        return deleted_count

    def clear_expired(self) -> int:
        """Delete entries whose expiry has passed; return the rows removed."""
        current_time = int(time.time())
        with closing(self._connect()) as conn:
            cursor = conn.cursor()
            cursor.execute(
                "DELETE FROM tool_cache WHERE expires_at <= ?", (current_time,)
            )
            deleted_count = cursor.rowcount
            conn.commit()
        return deleted_count

    def clear_all(self) -> int:
        """Delete every cache entry; return the rows removed."""
        with closing(self._connect()) as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM tool_cache")
            deleted_count = cursor.rowcount
            conn.commit()
        return deleted_count

    def get_statistics(self) -> dict:
        """Return entry and hit counts for the cache.

        Returns:
            A dict with the keys ``total_entries``, ``valid_entries``,
            ``expired_entries``, ``total_cache_hits`` and ``by_tool``.
            ``by_tool`` maps each tool name that has unexpired entries
            to ``{"cached_entries": ..., "total_hits": ...}``. Hit
            totals only count unexpired entries.
        """
        current_time = int(time.time())

        with closing(self._connect()) as conn:
            cursor = conn.cursor()

            cursor.execute("SELECT COUNT(*) FROM tool_cache")
            total_entries = cursor.fetchone()[0]

            cursor.execute(
                "SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?",
                (current_time,),
            )
            valid_entries = cursor.fetchone()[0]

            cursor.execute(
                "SELECT SUM(hit_count) FROM tool_cache WHERE expires_at > ?",
                (current_time,),
            )
            # SUM() yields NULL when no row matches.
            total_hits = cursor.fetchone()[0] or 0

            cursor.execute(
                """
                SELECT tool_name, COUNT(*), SUM(hit_count)
                FROM tool_cache
                WHERE expires_at > ?
                GROUP BY tool_name
                """,
                (current_time,),
            )
            tool_stats = {
                name: {"cached_entries": entries, "total_hits": hits or 0}
                for name, entries, hits in cursor.fetchall()
            }

        return {
            "total_entries": total_entries,
            "valid_entries": valid_entries,
            "expired_entries": total_entries - valid_entries,
            "total_cache_hits": total_hits,
            "by_tool": tool_stats,
        }