import hashlib
import json
import sqlite3
import time
from contextlib import contextmanager
from typing import Any, Dict, Iterator, Optional


class APICache:
    """SQLite-backed cache of API responses keyed by request parameters.

    Each entry is stored in the ``api_cache`` table with a creation time and
    an expiry time of ``created_at + ttl_seconds`` (both integer Unix
    seconds). Lookups ignore entries whose expiry time has passed; expired
    rows stay in the table until :meth:`clear_expired` or :meth:`clear_all`
    removes them.

    A fresh connection is opened for every operation and is always closed,
    even when the operation raises.
    """

    def __init__(self, db_path: str, ttl_seconds: int = 3600):
        """Create the cache and ensure its table and index exist.

        Args:
            db_path: Path of the SQLite database file. Note that ``":memory:"``
                will not work, because each operation opens a new connection
                and would therefore see a fresh, empty database.
            ttl_seconds: Lifetime of newly stored entries, in seconds.
        """
        self.db_path = db_path
        self.ttl_seconds = ttl_seconds
        self._initialize_cache()

    @contextmanager
    def _connection(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection that commits on success, rolls back on error,
        and is always closed."""
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        try:
            # ``with conn`` only manages the transaction; it does NOT close
            # the connection, hence the surrounding try/finally.
            with conn:
                yield conn
        finally:
            conn.close()

    def _initialize_cache(self) -> None:
        """Create the ``api_cache`` table and its expiry index if missing."""
        with self._connection() as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS api_cache (
                    cache_key TEXT PRIMARY KEY,
                    response_data TEXT NOT NULL,
                    created_at INTEGER NOT NULL,
                    expires_at INTEGER NOT NULL,
                    model TEXT,
                    token_count INTEGER
                )
            ''')
            # Speeds up the expiry filters used by get/clear_expired/stats.
            conn.execute('''
                CREATE INDEX IF NOT EXISTS idx_expires_at ON api_cache(expires_at)
            ''')

    def _generate_cache_key(self, model: str, messages: list, temperature: float,
                            max_tokens: int) -> str:
        """Return a SHA-256 hex digest identifying a request.

        The parameters are serialized as JSON with sorted keys so that
        dictionary ordering does not affect the key. Raises ``TypeError`` if
        ``messages`` contains values that are not JSON-serializable.

        NOTE: changing this serialization invalidates every existing entry.
        """
        cache_data = {
            'model': model,
            'messages': messages,
            'temperature': temperature,
            'max_tokens': max_tokens
        }
        serialized = json.dumps(cache_data, sort_keys=True)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def get(self, model: str, messages: list, temperature: float,
            max_tokens: int) -> Optional[Dict[str, Any]]:
        """Return the cached response for these parameters.

        Returns:
            The decoded response, or ``None`` if there is no entry or the
            entry has expired.
        """
        cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
        current_time = int(time.time())
        with self._connection() as conn:
            row = conn.execute('''
                SELECT response_data FROM api_cache
                WHERE cache_key = ? AND expires_at > ?
            ''', (cache_key, current_time)).fetchone()
        if row:
            return json.loads(row[0])
        return None

    def set(self, model: str, messages: list, temperature: float, max_tokens: int,
            response: Dict[str, Any], token_count: int = 0) -> None:
        """Store ``response`` for these parameters, replacing any existing entry.

        The entry expires ``self.ttl_seconds`` seconds from now.
        ``token_count`` is recorded for :meth:`get_statistics`.

        Raises:
            TypeError: if ``response`` or ``messages`` is not JSON-serializable.
                This is raised before any database connection is opened.
        """
        cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
        # Serialize before connecting so a bad payload cannot leave a
        # connection behind.
        response_data = json.dumps(response)
        current_time = int(time.time())
        expires_at = current_time + self.ttl_seconds
        with self._connection() as conn:
            conn.execute('''
                INSERT OR REPLACE INTO api_cache
                (cache_key, response_data, created_at, expires_at, model, token_count)
                VALUES (?, ?, ?, ?, ?, ?)
            ''', (cache_key, response_data, current_time, expires_at, model, token_count))

    def clear_expired(self) -> int:
        """Delete entries whose expiry time has passed.

        Returns:
            The number of rows deleted.
        """
        current_time = int(time.time())
        with self._connection() as conn:
            cursor = conn.execute(
                'DELETE FROM api_cache WHERE expires_at <= ?', (current_time,)
            )
            return cursor.rowcount

    def clear_all(self) -> int:
        """Delete every entry, expired or not.

        Returns:
            The number of rows deleted.
        """
        with self._connection() as conn:
            cursor = conn.execute('DELETE FROM api_cache')
            return cursor.rowcount

    def get_statistics(self) -> Dict[str, Any]:
        """Summarize the cache contents.

        All figures come from a single aggregate query, so they describe one
        consistent snapshot of the table.

        Returns:
            A dict with these keys:

            * ``total_entries``: number of rows.
            * ``valid_entries``: number of rows that have not yet expired.
            * ``expired_entries``: ``total_entries - valid_entries``.
            * ``total_cached_tokens``: sum of ``token_count`` over the
              unexpired rows; NULL counts are ignored.
        """
        current_time = int(time.time())
        with self._connection() as conn:
            total_entries, valid_entries, total_tokens = conn.execute('''
                SELECT COUNT(*),
                       SUM(CASE WHEN expires_at > ? THEN 1 ELSE 0 END),
                       SUM(CASE WHEN expires_at > ? THEN token_count ELSE 0 END)
                FROM api_cache
            ''', (current_time, current_time)).fetchone()
        # SUM() yields NULL over an empty table; normalize to 0.
        valid_entries = valid_entries or 0
        total_tokens = total_tokens or 0
        return {
            'total_entries': total_entries,
            'valid_entries': valid_entries,
            'expired_entries': total_entries - valid_entries,
            'total_cached_tokens': total_tokens
        }