"""SQLite-backed cache for API responses keyed by request parameters."""

import hashlib
import json
import sqlite3
import time
from contextlib import closing
from typing import Any, Dict, Optional


class APICache:
    """Persist JSON-serialisable API responses in SQLite with a fixed TTL.

    Each entry is keyed by a SHA-256 digest of the request parameters
    (model, messages, temperature, max_tokens). Every operation opens its
    own short-lived connection to ``db_path`` and closes it before
    returning, including when the operation raises.
    """

    def __init__(self, db_path: str, ttl_seconds: int = 3600):
        """Store the settings and create the cache table if it is missing.

        Args:
            db_path: Path handed to ``sqlite3.connect``.
            ttl_seconds: Seconds added to the insertion time to compute an
                entry's expiry timestamp.
        """
        self.db_path = db_path
        self.ttl_seconds = ttl_seconds
        self._initialize_cache()

    def _connect(self) -> "closing[sqlite3.Connection]":
        """Return a context manager yielding a connection that is always closed.

        ``sqlite3.Connection`` used directly as a context manager only
        manages the transaction; it does not close the handle, hence the
        ``closing`` wrapper.
        """
        return closing(sqlite3.connect(self.db_path, check_same_thread=False))

    def _initialize_cache(self) -> None:
        """Create the ``api_cache`` table and its expiry index if absent."""
        # Outer ``with`` closes the connection; inner ``with`` commits on
        # success and rolls back on error.
        with self._connect() as conn, conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS api_cache (
                    cache_key TEXT PRIMARY KEY,
                    response_data TEXT NOT NULL,
                    created_at INTEGER NOT NULL,
                    expires_at INTEGER NOT NULL,
                    model TEXT,
                    token_count INTEGER
                )
                """
            )
            conn.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_expires_at
                ON api_cache(expires_at)
                """
            )

    def _generate_cache_key(
        self, model: str, messages: list, temperature: float, max_tokens: int
    ) -> str:
        """Return the SHA-256 hex digest identifying a request.

        The parameters are serialised as JSON with sorted keys so that dict
        key order inside ``messages`` does not affect the digest.

        Raises:
            TypeError: If any parameter is not JSON-serialisable.
        """
        cache_data = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        serialized = json.dumps(cache_data, sort_keys=True)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def get(
        self, model: str, messages: list, temperature: float, max_tokens: int
    ) -> Optional[Dict[str, Any]]:
        """Return the cached response for a request, or ``None``.

        ``None`` is returned when no entry exists for the request or when
        the stored entry's expiry timestamp is not in the future.
        """
        cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
        current_time = int(time.time())
        with self._connect() as conn:
            row = conn.execute(
                """
                SELECT response_data FROM api_cache
                WHERE cache_key = ? AND expires_at > ?
                """,
                (cache_key, current_time),
            ).fetchone()
        if row is None:
            return None
        return json.loads(row[0])

    def set(
        self,
        model: str,
        messages: list,
        temperature: float,
        max_tokens: int,
        response: Dict[str, Any],
        token_count: int = 0,
    ):
        """Insert or overwrite the cached response for a request.

        Args:
            model, messages, temperature, max_tokens: Request parameters
                that form the cache key.
            response: JSON-serialisable payload to store.
            token_count: Value stored in the ``token_count`` column.

        Raises:
            TypeError: If ``response`` (or a key parameter) is not
                JSON-serialisable.
        """
        cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
        # Serialise before touching the database so a bad payload fails
        # without ever opening a connection.
        payload = json.dumps(response)
        current_time = int(time.time())
        expires_at = current_time + self.ttl_seconds
        with self._connect() as conn, conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO api_cache
                (cache_key, response_data, created_at, expires_at, model, token_count)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (cache_key, payload, current_time, expires_at, model, token_count),
            )

    def clear_expired(self) -> int:
        """Delete entries whose expiry is at or before now; return the row count."""
        current_time = int(time.time())
        with self._connect() as conn, conn:
            cursor = conn.execute(
                "DELETE FROM api_cache WHERE expires_at <= ?", (current_time,)
            )
            return cursor.rowcount

    def clear_all(self) -> int:
        """Delete every entry and return the number of rows removed."""
        with self._connect() as conn, conn:
            cursor = conn.execute("DELETE FROM api_cache")
            return cursor.rowcount

    def get_statistics(self) -> Dict[str, Any]:
        """Return entry counts and the token total of unexpired entries.

        Returns:
            A dict with keys ``total_entries``, ``valid_entries``,
            ``expired_entries`` and ``total_cached_tokens``.
        """
        current_time = int(time.time())
        with self._connect() as conn:
            total_entries = conn.execute(
                "SELECT COUNT(*) FROM api_cache"
            ).fetchone()[0]
            valid_entries = conn.execute(
                "SELECT COUNT(*) FROM api_cache WHERE expires_at > ?",
                (current_time,),
            ).fetchone()[0]
            # SUM() yields NULL when no rows match; report that as 0.
            total_tokens = (
                conn.execute(
                    "SELECT SUM(token_count) FROM api_cache WHERE expires_at > ?",
                    (current_time,),
                ).fetchone()[0]
                or 0
            )
        return {
            "total_entries": total_entries,
            "valid_entries": valid_entries,
            "expired_entries": total_entries - valid_entries,
            "total_cached_tokens": total_tokens,
        }