128 lines
4.1 KiB
Python
128 lines
4.1 KiB
Python
|
|
import hashlib
|
||
|
|
import json
|
||
|
|
import sqlite3
|
||
|
|
import time
|
||
|
|
from typing import Optional, Dict, Any
|
||
|
|
|
||
|
|
class APICache:
    """SQLite-backed cache for API responses with a fixed time-to-live.

    Entries are keyed by a SHA-256 digest of the request parameters
    (model, messages, temperature, max_tokens) and store the response as
    JSON text together with creation/expiry timestamps in Unix seconds.

    Every operation opens its own short-lived connection to ``db_path``
    and always closes it, even when the operation raises. Because no
    connection is kept open between calls, ``db_path`` should name a file:
    an in-memory database (``":memory:"``) would start empty on each call.
    """

    def __init__(self, db_path: str, ttl_seconds: int = 3600):
        """Create the cache and make sure its table and index exist.

        Args:
            db_path: Path of the SQLite database file.
            ttl_seconds: Lifetime, in seconds, given to entries written by
                :meth:`set`.
        """
        self.db_path = db_path
        self.ttl_seconds = ttl_seconds
        self._initialize_cache()

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection to the cache database."""
        return sqlite3.connect(self.db_path, check_same_thread=False)

    def _initialize_cache(self):
        """Create the ``api_cache`` table and its expiry index if missing."""
        conn = self._connect()
        try:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS api_cache (
                    cache_key TEXT PRIMARY KEY,
                    response_data TEXT NOT NULL,
                    created_at INTEGER NOT NULL,
                    expires_at INTEGER NOT NULL,
                    model TEXT,
                    token_count INTEGER
                )
            ''')
            conn.execute('''
                CREATE INDEX IF NOT EXISTS idx_expires_at ON api_cache(expires_at)
            ''')
            conn.commit()
        finally:
            # Close on every path so a failed statement cannot leak the handle.
            conn.close()

    def _generate_cache_key(self, model: str, messages: list, temperature: float, max_tokens: int) -> str:
        """Return the SHA-256 hex digest identifying a request.

        The parameters are serialised as JSON with sorted keys so that
        dictionaries with the same content but different key order map to
        the same cache key.
        """
        cache_data = {
            'model': model,
            'messages': messages,
            'temperature': temperature,
            'max_tokens': max_tokens,
        }
        serialized = json.dumps(cache_data, sort_keys=True)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def get(self, model: str, messages: list, temperature: float, max_tokens: int) -> Optional[Dict[str, Any]]:
        """Look up a cached response that has not expired.

        Returns:
            The decoded response, or ``None`` when there is no entry for
            these parameters or the entry's expiry time has passed.
        """
        cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
        current_time = int(time.time())

        conn = self._connect()
        try:
            row = conn.execute('''
                SELECT response_data FROM api_cache
                WHERE cache_key = ? AND expires_at > ?
            ''', (cache_key, current_time)).fetchone()
        finally:
            conn.close()

        if row:
            return json.loads(row[0])
        return None

    def set(self, model: str, messages: list, temperature: float, max_tokens: int,
            response: Dict[str, Any], token_count: int = 0):
        """Store ``response`` for the given request parameters.

        An existing entry with the same key is replaced. The entry expires
        ``ttl_seconds`` after the moment of this call.

        Args:
            response: JSON-serialisable response payload.
            token_count: Token usage recorded alongside the entry.

        Raises:
            TypeError: If ``response`` cannot be serialised to JSON.
        """
        cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
        # Serialise before touching the database so an unserialisable
        # payload fails without ever opening a connection.
        payload = json.dumps(response)

        current_time = int(time.time())
        expires_at = current_time + self.ttl_seconds

        conn = self._connect()
        try:
            conn.execute('''
                INSERT OR REPLACE INTO api_cache
                (cache_key, response_data, created_at, expires_at, model, token_count)
                VALUES (?, ?, ?, ?, ?, ?)
            ''', (cache_key, payload, current_time, expires_at, model, token_count))
            conn.commit()
        finally:
            conn.close()

    def clear_expired(self):
        """Delete entries whose expiry time has passed.

        Returns:
            The number of rows deleted.
        """
        current_time = int(time.time())

        conn = self._connect()
        try:
            cursor = conn.execute(
                'DELETE FROM api_cache WHERE expires_at <= ?', (current_time,)
            )
            deleted_count = cursor.rowcount
            conn.commit()
        finally:
            conn.close()

        return deleted_count

    def clear_all(self):
        """Delete every entry in the cache.

        Returns:
            The number of rows deleted.
        """
        conn = self._connect()
        try:
            cursor = conn.execute('DELETE FROM api_cache')
            deleted_count = cursor.rowcount
            conn.commit()
        finally:
            conn.close()

        return deleted_count

    def get_statistics(self) -> Dict[str, Any]:
        """Summarise the current contents of the cache.

        Returns:
            A dict with ``total_entries``, ``valid_entries`` (not yet
            expired), ``expired_entries`` and ``total_cached_tokens`` (sum
            of ``token_count`` over valid entries, 0 when there are none).
        """
        current_time = int(time.time())

        conn = self._connect()
        try:
            # One aggregate query instead of three separate SELECTs, so all
            # figures come from the same snapshot even if another writer
            # modifies the table while statistics are being collected.
            total_entries, valid_entries, total_tokens = conn.execute('''
                SELECT COUNT(*),
                       COALESCE(SUM(expires_at > ?), 0),
                       COALESCE(SUM(CASE WHEN expires_at > ? THEN token_count END), 0)
                FROM api_cache
            ''', (current_time, current_time)).fetchone()
        finally:
            conn.close()

        return {
            'total_entries': total_entries,
            'valid_entries': valid_entries,
            'expired_entries': total_entries - valid_entries,
            'total_cached_tokens': total_tokens or 0,
        }
|