160 lines
4.5 KiB
Python
Raw Normal View History

2025-11-04 05:17:27 +01:00
import hashlib
import json
import sqlite3
import time
from contextlib import contextmanager
from typing import Any, Dict, Iterator, Optional
2025-11-04 05:17:27 +01:00
class APICache:
    """SQLite-backed cache for API responses with a fixed time-to-live.

    Entries are keyed by a SHA-256 digest of the JSON-serialised request
    parameters (model, messages, temperature, max_tokens). Every public
    method opens its own short-lived connection, so an instance holds no
    open database handle between calls.
    """

    def __init__(self, db_path: str, ttl_seconds: int = 3600):
        """Create the cache and ensure its table and index exist.

        Args:
            db_path: Filesystem path of the SQLite database file.
            ttl_seconds: Lifetime, in seconds, of each entry written by ``set``.
        """
        self.db_path = db_path
        self.ttl_seconds = ttl_seconds
        self._initialize_cache()

    @contextmanager
    def _connect(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection that commits on success, rolls back on error,
        and is always closed.

        ``sqlite3.Connection`` used as a context manager only manages the
        transaction; it does not close the connection, hence the ``finally``.
        """
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    def _initialize_cache(self):
        """Create the ``api_cache`` table and its expiry index if missing."""
        with self._connect() as conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS api_cache (
                    cache_key TEXT PRIMARY KEY,
                    response_data TEXT NOT NULL,
                    created_at INTEGER NOT NULL,
                    expires_at INTEGER NOT NULL,
                    model TEXT,
                    token_count INTEGER
                )
                """
            )
            # Supports the expires_at filters in get/clear_expired/get_statistics.
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_expires_at ON api_cache(expires_at)"
            )

    def _generate_cache_key(
        self, model: str, messages: list, temperature: float, max_tokens: int
    ) -> str:
        """Return the hex SHA-256 digest identifying one request.

        ``sort_keys=True`` makes the key independent of dict insertion order.
        Raises ``TypeError`` if ``messages`` is not JSON-serialisable.
        """
        cache_data = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        serialized = json.dumps(cache_data, sort_keys=True)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def get(
        self, model: str, messages: list, temperature: float, max_tokens: int
    ) -> Optional[Dict[str, Any]]:
        """Return the cached response for these parameters, or ``None``.

        Only entries whose ``expires_at`` is strictly after the current time
        are returned. Expired rows are not deleted here.
        """
        cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
        current_time = int(time.time())
        with self._connect() as conn:
            row = conn.execute(
                """
                SELECT response_data FROM api_cache
                WHERE cache_key = ? AND expires_at > ?
                """,
                (cache_key, current_time),
            ).fetchone()
        return json.loads(row[0]) if row else None

    def set(
        self,
        model: str,
        messages: list,
        temperature: float,
        max_tokens: int,
        response: Dict[str, Any],
        token_count: int = 0,
    ):
        """Store ``response`` for these parameters, replacing any prior entry.

        The entry expires ``self.ttl_seconds`` seconds from now. Raises
        ``TypeError`` if ``response`` is not JSON-serialisable.
        """
        cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
        # Serialise before touching the database so a bad payload fails fast
        # and never leaves a connection or transaction open.
        payload = json.dumps(response)
        current_time = int(time.time())
        expires_at = current_time + self.ttl_seconds
        with self._connect() as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO api_cache
                (cache_key, response_data, created_at, expires_at, model, token_count)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (cache_key, payload, current_time, expires_at, model, token_count),
            )

    def clear_expired(self):
        """Delete entries whose ``expires_at`` is at or before now.

        Returns:
            The number of rows deleted.
        """
        current_time = int(time.time())
        with self._connect() as conn:
            cursor = conn.execute(
                "DELETE FROM api_cache WHERE expires_at <= ?", (current_time,)
            )
            return cursor.rowcount

    def clear_all(self):
        """Delete every cache entry.

        Returns:
            The number of rows deleted.
        """
        with self._connect() as conn:
            cursor = conn.execute("DELETE FROM api_cache")
            return cursor.rowcount

    def get_statistics(self) -> Dict[str, Any]:
        """Summarise cache contents relative to the current time.

        All figures come from a single aggregate query, so they are mutually
        consistent (``expired_entries`` can never be negative).

        Returns:
            Dict with ``total_entries``, ``valid_entries`` (not yet expired),
            ``expired_entries`` and ``total_cached_tokens`` (sum of
            ``token_count`` over valid entries, 0 when there are none).
        """
        current_time = int(time.time())
        with self._connect() as conn:
            total_entries, valid_entries, total_tokens = conn.execute(
                """
                SELECT
                    COUNT(*),
                    COUNT(CASE WHEN expires_at > ? THEN 1 END),
                    SUM(CASE WHEN expires_at > ? THEN token_count END)
                FROM api_cache
                """,
                (current_time, current_time),
            ).fetchone()
        return {
            "total_entries": total_entries,
            "valid_entries": valid_entries,
            "expired_entries": total_entries - valid_entries,
            # SUM over zero rows (or all-NULL token_count) yields NULL.
            "total_cached_tokens": total_tokens or 0,
        }