203 lines
5.7 KiB
Python
Raw Normal View History

2025-11-04 05:17:27 +01:00
import hashlib
import json
import sqlite3
import time
2025-11-04 08:09:12 +01:00
from typing import Any, Optional, Set
2025-11-04 05:17:27 +01:00
class ToolCache:
    """SQLite-backed cache for the results of tool invocations.

    Only tools named in ``DETERMINISTIC_TOOLS`` are cached. Each entry is keyed
    by a SHA-256 digest of the tool name plus its arguments (serialized as JSON
    with sorted keys, so argument order does not matter), stores the
    JSON-encoded result, and expires ``ttl_seconds`` after it was written.
    Timestamps are whole Unix seconds.

    Every method opens its own short-lived connection to ``db_path`` and closes
    it before returning, including when an error is raised.

    NOTE: because connections are not shared between calls,
    ``db_path=":memory:"`` does not work -- each connection would see its own
    separate, empty in-memory database.
    """

    # Tool names whose results may be cached. Some of these (e.g. "http_fetch",
    # "web_search", "web_search_news") presumably return live data, so a cache
    # hit can be up to ``ttl_seconds`` stale -- confirm that is acceptable.
    DETERMINISTIC_TOOLS: Set[str] = {
        "read_file",
        "list_directory",
        "get_current_directory",
        "db_get",
        "db_query",
        "index_directory",
        "http_fetch",
        "web_search",
        "web_search_news",
        "search_knowledge",
        "get_knowledge_entry",
        "get_knowledge_by_category",
        "get_knowledge_statistics",
    }

    def __init__(self, db_path: str, ttl_seconds: int = 300):
        """Create the cache and make sure its table and indexes exist.

        Args:
            db_path: Path to the SQLite database file.
            ttl_seconds: Lifetime of each entry in seconds, counted from the
                moment it is written. Zero or negative values make entries
                expire immediately.
        """
        self.db_path = db_path
        self.ttl_seconds = ttl_seconds
        self._initialize_cache()

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection to the cache database; the caller must close it."""
        return sqlite3.connect(self.db_path, check_same_thread=False)

    def _execute_write(self, sql: str, params: tuple = ()) -> int:
        """Run a single write statement in its own transaction.

        The transaction is committed on success and rolled back on error, and
        the connection is always closed.

        Returns:
            The statement's ``rowcount`` (number of rows affected).
        """
        conn = self._connect()
        try:
            # ``with conn`` only manages the transaction; it does not close.
            with conn:
                return conn.execute(sql, params).rowcount
        finally:
            conn.close()

    def _initialize_cache(self) -> None:
        """Create the ``tool_cache`` table and its indexes if they are missing."""
        conn = self._connect()
        try:
            with conn:
                conn.execute(
                    """
                    CREATE TABLE IF NOT EXISTS tool_cache (
                        cache_key TEXT PRIMARY KEY,
                        tool_name TEXT NOT NULL,
                        result_data TEXT NOT NULL,
                        created_at INTEGER NOT NULL,
                        expires_at INTEGER NOT NULL,
                        hit_count INTEGER DEFAULT 0
                    )
                    """
                )
                # Supports expiry scans (clear_expired / statistics).
                conn.execute(
                    "CREATE INDEX IF NOT EXISTS idx_tool_expires ON tool_cache(expires_at)"
                )
                # Supports per-tool invalidation and grouping.
                conn.execute(
                    "CREATE INDEX IF NOT EXISTS idx_tool_name ON tool_cache(tool_name)"
                )
        finally:
            conn.close()

    def _generate_cache_key(self, tool_name: str, arguments: dict) -> str:
        """Return a SHA-256 hex digest identifying ``tool_name`` + ``arguments``.

        Raises:
            TypeError: if ``arguments`` is not JSON-serializable.
        """
        cache_data = {"tool": tool_name, "args": arguments}
        # sort_keys makes the key independent of dict insertion order.
        serialized = json.dumps(cache_data, sort_keys=True)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def is_cacheable(self, tool_name: str) -> bool:
        """Return True if results of ``tool_name`` are eligible for caching."""
        return tool_name in self.DETERMINISTIC_TOOLS

    def get(self, tool_name: str, arguments: dict) -> Optional[Any]:
        """Return the cached result for this call, or None on a miss.

        A hit increments the entry's ``hit_count``. Returns None for
        non-cacheable tools and for missing or expired entries. A cached
        result that was itself JSON ``null`` is indistinguishable from a miss.
        """
        if not self.is_cacheable(tool_name):
            return None
        cache_key = self._generate_cache_key(tool_name, arguments)
        now = int(time.time())
        conn = self._connect()
        try:
            with conn:
                row = conn.execute(
                    """
                    SELECT result_data FROM tool_cache
                    WHERE cache_key = ? AND expires_at > ?
                    """,
                    (cache_key, now),
                ).fetchone()
                if row is None:
                    return None
                conn.execute(
                    "UPDATE tool_cache SET hit_count = hit_count + 1 WHERE cache_key = ?",
                    (cache_key,),
                )
        finally:
            conn.close()
        return json.loads(row[0])

    def set(self, tool_name: str, arguments: dict, result: Any) -> None:
        """Store ``result`` for this call, replacing any existing entry.

        Does nothing for non-cacheable tools. The entry's hit count is reset
        to 0 and it expires ``ttl_seconds`` from now.

        Raises:
            TypeError: if ``arguments`` or ``result`` is not JSON-serializable.
        """
        if not self.is_cacheable(tool_name):
            return
        cache_key = self._generate_cache_key(tool_name, arguments)
        # Serialize before touching the database so a TypeError can't leave a
        # connection open.
        payload = json.dumps(result)
        created_at = int(time.time())
        expires_at = created_at + self.ttl_seconds
        self._execute_write(
            """
            INSERT OR REPLACE INTO tool_cache
            (cache_key, tool_name, result_data, created_at, expires_at, hit_count)
            VALUES (?, ?, ?, ?, ?, 0)
            """,
            (cache_key, tool_name, payload, created_at, expires_at),
        )

    def invalidate_tool(self, tool_name: str) -> int:
        """Delete every entry (expired or not) for ``tool_name``; return the count deleted."""
        return self._execute_write(
            "DELETE FROM tool_cache WHERE tool_name = ?", (tool_name,)
        )

    def clear_expired(self) -> int:
        """Delete entries whose expiry time has passed; return the count deleted."""
        now = int(time.time())
        return self._execute_write(
            "DELETE FROM tool_cache WHERE expires_at <= ?", (now,)
        )

    def clear_all(self) -> int:
        """Delete every cache entry; return the count deleted."""
        return self._execute_write("DELETE FROM tool_cache")

    def get_statistics(self) -> dict:
        """Summarize cache contents.

        Returns:
            A dict with ``total_entries``, ``valid_entries``,
            ``expired_entries``, ``total_cache_hits`` (hits on non-expired
            entries only) and ``by_tool``. ``by_tool`` maps each tool name
            with non-expired entries to ``{"cached_entries", "total_hits"}``.
        """
        now = int(time.time())
        conn = self._connect()
        try:
            # One aggregate query so the three totals are mutually consistent.
            total_entries, valid_entries, total_hits = conn.execute(
                """
                SELECT COUNT(*),
                       COALESCE(SUM(CASE WHEN expires_at > ? THEN 1 ELSE 0 END), 0),
                       COALESCE(SUM(CASE WHEN expires_at > ? THEN hit_count END), 0)
                FROM tool_cache
                """,
                (now, now),
            ).fetchone()
            per_tool_rows = conn.execute(
                """
                SELECT tool_name, COUNT(*), SUM(hit_count)
                FROM tool_cache
                WHERE expires_at > ?
                GROUP BY tool_name
                """,
                (now,),
            ).fetchall()
        finally:
            conn.close()

        tool_stats = {
            name: {"cached_entries": count, "total_hits": hits or 0}
            for name, count, hits in per_tool_rows
        }
        return {
            "total_entries": total_entries,
            "valid_entries": valid_entries,
            "expired_entries": total_entries - valid_entries,
            "total_cache_hits": total_hits,
            "by_tool": tool_stats,
        }