|
import hashlib
|
|
import json
|
|
import sqlite3
|
|
import time
|
|
from typing import Any, Dict, Optional
|
|
|
|
|
|
class APICache:
    """SQLite-backed cache for API responses.

    Entries are keyed by a SHA-256 digest of the request parameters
    (model, messages, temperature, max_tokens) and expire ``ttl_seconds``
    after they are written. Every operation opens its own connection and
    always closes it, even when a query or (de)serialization raises.
    """

    def __init__(self, db_path: str, ttl_seconds: int = 3600):
        """Create the cache and ensure its schema exists.

        Args:
            db_path: Path of the SQLite database file.
            ttl_seconds: Lifetime, in seconds, of entries written by ``set``.
        """
        self.db_path = db_path
        self.ttl_seconds = ttl_seconds
        self._initialize_cache()

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection to the cache database."""
        return sqlite3.connect(self.db_path, check_same_thread=False)

    def _initialize_cache(self) -> None:
        """Create the table and expiry index if missing; add ``hit_count`` if absent."""
        conn = self._connect()
        try:
            # The connection context manager commits on success and rolls back
            # on error; it does NOT close the connection, hence the finally.
            with conn:
                cursor = conn.cursor()
                cursor.execute(
                    """
                    CREATE TABLE IF NOT EXISTS api_cache (
                        cache_key TEXT PRIMARY KEY,
                        response_data TEXT NOT NULL,
                        created_at INTEGER NOT NULL,
                        expires_at INTEGER NOT NULL,
                        model TEXT,
                        token_count INTEGER,
                        hit_count INTEGER DEFAULT 0
                    )
                    """
                )
                cursor.execute(
                    "CREATE INDEX IF NOT EXISTS idx_expires_at ON api_cache(expires_at)"
                )
                # CREATE TABLE IF NOT EXISTS leaves a pre-existing table untouched,
                # so a table created without hit_count must be migrated explicitly.
                cursor.execute("PRAGMA table_info(api_cache)")
                columns = {row[1] for row in cursor.fetchall()}
                if "hit_count" not in columns:
                    cursor.execute(
                        "ALTER TABLE api_cache ADD COLUMN hit_count INTEGER DEFAULT 0"
                    )
        finally:
            conn.close()

    def _generate_cache_key(
        self, model: str, messages: list, temperature: float, max_tokens: int
    ) -> str:
        """Return the SHA-256 hex digest identifying a request.

        The parameters are serialized as JSON with sorted keys so that
        logically equal requests map to the same key. The exact encoding
        must not change, or existing cache entries become unreachable.

        Raises:
            TypeError: If ``messages`` is not JSON-serializable.
        """
        cache_data = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        serialized = json.dumps(cache_data, sort_keys=True)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def get(
        self, model: str, messages: list, temperature: float, max_tokens: int
    ) -> Optional[Dict[str, Any]]:
        """Return the cached, unexpired response for these parameters.

        On a hit the entry's ``hit_count`` is incremented.

        Returns:
            The decoded response, or ``None`` if there is no unexpired entry.
        """
        cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
        current_time = int(time.time())

        conn = self._connect()
        try:
            with conn:
                row = conn.execute(
                    "SELECT response_data FROM api_cache "
                    "WHERE cache_key = ? AND expires_at > ?",
                    (cache_key, current_time),
                ).fetchone()
                if row is None:
                    return None
                # COALESCE keeps the counter working even if a row holds NULL.
                conn.execute(
                    "UPDATE api_cache SET hit_count = COALESCE(hit_count, 0) + 1 "
                    "WHERE cache_key = ?",
                    (cache_key,),
                )
        finally:
            conn.close()

        return json.loads(row[0])

    def set(
        self,
        model: str,
        messages: list,
        temperature: float,
        max_tokens: int,
        response: Dict[str, Any],
        token_count: int = 0,
    ):
        """Store ``response`` for these parameters, replacing any existing entry.

        The entry expires ``self.ttl_seconds`` from now and its ``hit_count``
        is reset to 0.

        Raises:
            TypeError: If ``messages`` or ``response`` is not JSON-serializable.
        """
        cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
        # Serialize before touching the database so a bad payload cannot
        # leave a connection or transaction open.
        response_data = json.dumps(response)
        created_at = int(time.time())
        expires_at = created_at + self.ttl_seconds

        conn = self._connect()
        try:
            with conn:
                conn.execute(
                    """
                    INSERT OR REPLACE INTO api_cache
                        (cache_key, response_data, created_at, expires_at,
                         model, token_count, hit_count)
                    VALUES (?, ?, ?, ?, ?, ?, 0)
                    """,
                    (cache_key, response_data, created_at, expires_at, model, token_count),
                )
        finally:
            conn.close()

    def clear_expired(self) -> int:
        """Delete entries whose expiry time has passed; return how many were removed."""
        current_time = int(time.time())
        conn = self._connect()
        try:
            with conn:
                cursor = conn.execute(
                    "DELETE FROM api_cache WHERE expires_at <= ?", (current_time,)
                )
                return cursor.rowcount
        finally:
            conn.close()

    def clear_all(self) -> int:
        """Delete every entry; return how many were removed."""
        conn = self._connect()
        try:
            with conn:
                cursor = conn.execute("DELETE FROM api_cache")
                return cursor.rowcount
        finally:
            conn.close()

    def get_statistics(self) -> Dict[str, Any]:
        """Summarize the cache contents.

        Token and hit totals cover unexpired entries only. All figures come
        from one query, so they describe a single consistent snapshot.

        Returns:
            Dict with keys ``total_entries``, ``valid_entries``,
            ``expired_entries``, ``total_cached_tokens`` and ``total_cache_hits``.
        """
        current_time = int(time.time())
        conn = self._connect()
        try:
            total_entries, valid_entries, total_tokens, total_hits = conn.execute(
                """
                SELECT
                    COUNT(*),
                    COALESCE(SUM(CASE WHEN expires_at > ? THEN 1 ELSE 0 END), 0),
                    COALESCE(SUM(CASE WHEN expires_at > ? THEN token_count END), 0),
                    COALESCE(SUM(CASE WHEN expires_at > ? THEN hit_count END), 0)
                FROM api_cache
                """,
                (current_time, current_time, current_time),
            ).fetchone()
        finally:
            conn.close()

        return {
            "total_entries": total_entries,
            "valid_entries": valid_entries,
            "expired_entries": total_entries - valid_entries,
            "total_cached_tokens": total_tokens,
            "total_cache_hits": total_hits,
        }
|