180 lines
5.4 KiB
Python
180 lines
5.4 KiB
Python
|
|
import hashlib
import json
import sqlite3
import time
from contextlib import contextmanager
from typing import Any, Iterator, Optional, Set
|
||
|
|
|
||
|
|
class ToolCache:
    """Persistent, TTL-based cache for the results of deterministic tool calls.

    Entries live in a ``tool_cache`` table inside the SQLite database at
    ``db_path``. Each entry is keyed by a SHA-256 digest of the tool name and
    its arguments, stores the result as JSON text, and stops being served
    ``ttl_seconds`` after it was written. Only tools named in
    ``DETERMINISTIC_TOOLS`` are ever read from or written to the cache.

    Every operation opens its own connection and always closes it, even when
    the operation raises.
    """

    # Tool names whose results may be cached. This is a mutable class
    # attribute, so changes made through any instance affect all instances.
    DETERMINISTIC_TOOLS: Set[str] = {
        'read_file',
        'list_directory',
        'get_current_directory',
        'db_get',
        'db_query',
        'index_directory',
    }

    def __init__(self, db_path: str, ttl_seconds: int = 300):
        """Create the cache and make sure its table and indexes exist.

        Args:
            db_path: Database path passed to ``sqlite3.connect``.
            ttl_seconds: How long, in seconds, an entry is served after it is
                stored. A value <= 0 makes new entries expire immediately.
        """
        self.db_path = db_path
        self.ttl_seconds = ttl_seconds
        self._initialize_cache()

    @contextmanager
    def _connect(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection that is committed on success and always closed.

        If the body raises, the pending transaction is rolled back before the
        exception propagates. A failed operation therefore never leaves an
        open transaction (and its lock) on an unclosed connection.
        """
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        try:
            # ``with conn`` commits on normal exit and rolls back on error;
            # it does NOT close the connection, hence the outer ``finally``.
            with conn:
                yield conn
        finally:
            conn.close()

    def _initialize_cache(self) -> None:
        """Create the ``tool_cache`` table and its indexes if missing."""
        with self._connect() as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS tool_cache (
                    cache_key TEXT PRIMARY KEY,
                    tool_name TEXT NOT NULL,
                    result_data TEXT NOT NULL,
                    created_at INTEGER NOT NULL,
                    expires_at INTEGER NOT NULL,
                    hit_count INTEGER DEFAULT 0
                )
            ''')
            # Indexes on the two columns that later queries filter by.
            conn.execute(
                'CREATE INDEX IF NOT EXISTS idx_tool_expires '
                'ON tool_cache(expires_at)'
            )
            conn.execute(
                'CREATE INDEX IF NOT EXISTS idx_tool_name '
                'ON tool_cache(tool_name)'
            )

    def _generate_cache_key(self, tool_name: str, arguments: dict) -> str:
        """Return a hex SHA-256 digest identifying ``tool_name(arguments)``.

        Keys are serialized with ``sort_keys=True``, so the insertion order
        of ``arguments`` does not change the key. ``json.dumps`` raises
        ``TypeError`` if ``arguments`` contains values it cannot serialize.
        """
        cache_data = {
            'tool': tool_name,
            'args': arguments,
        }
        serialized = json.dumps(cache_data, sort_keys=True)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def is_cacheable(self, tool_name: str) -> bool:
        """Return True if results of ``tool_name`` may be cached."""
        return tool_name in self.DETERMINISTIC_TOOLS

    def get(self, tool_name: str, arguments: dict) -> Optional[Any]:
        """Return the cached result for this call, or None on a miss.

        A hit is an entry whose ``expires_at`` is strictly later than the
        current time. Each hit increments the entry's ``hit_count``. Tools
        that are not cacheable always miss.

        Note: a result that was cached as ``None`` is indistinguishable from
        a miss.
        """
        if not self.is_cacheable(tool_name):
            return None

        cache_key = self._generate_cache_key(tool_name, arguments)
        now = int(time.time())

        with self._connect() as conn:
            row = conn.execute(
                'SELECT result_data FROM tool_cache '
                'WHERE cache_key = ? AND expires_at > ?',
                (cache_key, now),
            ).fetchone()
            if row is None:
                return None
            conn.execute(
                'UPDATE tool_cache SET hit_count = hit_count + 1 '
                'WHERE cache_key = ?',
                (cache_key,),
            )

        return json.loads(row[0])

    def set(self, tool_name: str, arguments: dict, result: Any) -> None:
        """Store ``result`` for this call, replacing any existing entry.

        The entry expires ``ttl_seconds`` from now and its ``hit_count`` is
        reset to 0. Calls for tools that are not cacheable are ignored.

        Raises:
            TypeError: if ``result`` (or ``arguments``) is not JSON-serializable.
        """
        if not self.is_cacheable(tool_name):
            return

        cache_key = self._generate_cache_key(tool_name, arguments)
        # Serialize before touching the database, so a bad result fails
        # without opening a connection.
        payload = json.dumps(result)
        now = int(time.time())
        expires_at = now + self.ttl_seconds

        with self._connect() as conn:
            conn.execute('''
                INSERT OR REPLACE INTO tool_cache
                (cache_key, tool_name, result_data, created_at, expires_at, hit_count)
                VALUES (?, ?, ?, ?, ?, 0)
            ''', (cache_key, tool_name, payload, now, expires_at))

    def invalidate_tool(self, tool_name: str) -> int:
        """Delete every entry for ``tool_name``; return how many were deleted."""
        with self._connect() as conn:
            cursor = conn.execute(
                'DELETE FROM tool_cache WHERE tool_name = ?', (tool_name,)
            )
            return cursor.rowcount

    def clear_expired(self) -> int:
        """Delete entries with ``expires_at <= now``; return how many."""
        now = int(time.time())
        with self._connect() as conn:
            cursor = conn.execute(
                'DELETE FROM tool_cache WHERE expires_at <= ?', (now,)
            )
            return cursor.rowcount

    def clear_all(self) -> int:
        """Delete every entry; return how many were deleted."""
        with self._connect() as conn:
            cursor = conn.execute('DELETE FROM tool_cache')
            return cursor.rowcount

    def get_statistics(self) -> dict:
        """Summarize the cache contents.

        Returns a dict with:
            total_entries: all stored rows, including expired ones not yet purged.
            valid_entries: rows with ``expires_at`` later than now.
            expired_entries: ``total_entries - valid_entries``.
            total_cache_hits: sum of ``hit_count`` over valid rows.
            by_tool: ``{tool_name: {'cached_entries', 'total_hits'}}``, covering
                only tools that have at least one valid row.

        All figures come from a single query, so they describe one consistent
        snapshot even while other connections are writing.
        """
        now = int(time.time())
        with self._connect() as conn:
            rows = conn.execute('''
                SELECT tool_name,
                       COUNT(*),
                       SUM(expires_at > ?),
                       SUM(CASE WHEN expires_at > ?
                                THEN COALESCE(hit_count, 0) ELSE 0 END)
                FROM tool_cache
                GROUP BY tool_name
                ORDER BY tool_name
            ''', (now, now)).fetchall()

        total_entries = sum(row[1] for row in rows)
        valid_entries = sum(row[2] for row in rows)
        total_hits = sum(row[3] for row in rows)
        tool_stats = {
            name: {'cached_entries': valid, 'total_hits': hits}
            for name, _total, valid, hits in rows
            if valid  # only tools that still have unexpired entries
        }

        return {
            'total_entries': total_entries,
            'valid_entries': valid_entries,
            'expired_entries': total_entries - valid_entries,
            'total_cache_hits': total_hits,
            'by_tool': tool_stats,
        }
|