|
import hashlib
|
|
import json
|
|
import logging
|
|
import sqlite3
|
|
import time
|
|
from typing import Any, Dict, Optional
|
|
|
|
from rp.config import CACHE_PREFIX_MIN_LENGTH, PRICING_CACHED, PRICING_INPUT
|
|
|
|
# Logger for the "rp" namespace. logging.getLogger returns the same object for
# the same name, so every module that asks for "rp" shares this logger.
logger = logging.getLogger("rp")
|
|
|
|
|
|
class PromptPrefixCache:
    """SQLite-backed cache of prompt prefixes (system prompt + tool definitions).

    Each prefix is keyed by the SHA-256 of its sorted-key JSON serialization
    and expires ``CACHE_TTL`` seconds after it was (re)cached. The instance
    also keeps in-memory, per-instance counters (hits, misses, tokens saved,
    estimated cost saved) that are never written to the database.

    Every operation opens its own short-lived connection and always closes
    it, even if a statement raises.
    """

    # Lifetime of a cached prefix, in seconds.
    CACHE_TTL = 3600

    def __init__(self, db_path: str):
        """Ensure the cache schema exists in the SQLite file at *db_path*."""
        self.db_path = db_path
        self._init_cache()
        # Session counters for this instance only; not persisted.
        self.stats = {
            'hits': 0,
            'misses': 0,
            'tokens_saved': 0,
            'cost_saved': 0.0
        }

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection to the cache database; the caller must close it."""
        return sqlite3.connect(self.db_path, check_same_thread=False)

    def _init_cache(self):
        """Create the ``prefix_cache`` table and its expiry index if missing."""
        conn = self._connect()
        try:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS prefix_cache (
                    prefix_hash TEXT PRIMARY KEY,
                    prefix_content TEXT NOT NULL,
                    token_count INTEGER NOT NULL,
                    created_at INTEGER NOT NULL,
                    expires_at INTEGER NOT NULL,
                    hit_count INTEGER DEFAULT 0,
                    last_used INTEGER
                )
            """)
            # Lookups and purges all filter on expires_at.
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_prefix_expires ON prefix_cache(expires_at)
            """)
            conn.commit()
        finally:
            conn.close()

    def _generate_prefix_key(self, system_prompt: str, tool_definitions: list) -> str:
        """Return a hex SHA-256 key identifying this prompt/tools combination.

        ``sort_keys=True`` makes the key independent of dict insertion order.

        Raises:
            TypeError: if ``tool_definitions`` is not JSON-serializable.
        """
        cache_data = {
            'system_prompt': system_prompt,
            'tools': tool_definitions
        }
        serialized = json.dumps(cache_data, sort_keys=True)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def get_cached_prefix(
        self,
        system_prompt: str,
        tool_definitions: list
    ) -> Optional[Dict[str, Any]]:
        """Look up a non-expired cached prefix.

        On a hit, bumps the row's ``hit_count``/``last_used`` and the session
        hit, token, and cost counters. On a miss or expiry, bumps ``misses``.

        Returns:
            ``{'content': <stored JSON string>, 'token_count': int,
            'cached': True}`` on a hit, otherwise ``None``.
        """
        prefix_key = self._generate_prefix_key(system_prompt, tool_definitions)
        current_time = int(time.time())

        conn = self._connect()
        try:
            row = conn.execute("""
                SELECT prefix_content, token_count
                FROM prefix_cache
                WHERE prefix_hash = ? AND expires_at > ?
            """, (prefix_key, current_time)).fetchone()

            if row is not None:
                conn.execute("""
                    UPDATE prefix_cache
                    SET hit_count = hit_count + 1, last_used = ?
                    WHERE prefix_hash = ?
                """, (current_time, prefix_key))
                conn.commit()
        finally:
            conn.close()

        if row is None:
            self.stats['misses'] += 1
            return None

        content, token_count = row
        self.stats['hits'] += 1
        self.stats['tokens_saved'] += token_count
        # Savings = tokens billed at the cached rate instead of the full input rate.
        self.stats['cost_saved'] += token_count * (PRICING_INPUT - PRICING_CACHED)
        return {
            'content': content,
            'token_count': token_count,
            'cached': True
        }

    def cache_prefix(
        self,
        system_prompt: str,
        tool_definitions: list,
        token_count: int
    ):
        """Store (or refresh) a prefix for ``CACHE_TTL`` seconds.

        Prefixes with fewer than ``CACHE_PREFIX_MIN_LENGTH`` tokens are ignored.
        Re-caching an existing prefix replaces the row. That resets
        ``hit_count`` to 0 and restarts the TTL.
        """
        if token_count < CACHE_PREFIX_MIN_LENGTH:
            return

        prefix_key = self._generate_prefix_key(system_prompt, tool_definitions)
        prefix_content = json.dumps({
            'system_prompt': system_prompt,
            'tools': tool_definitions
        })
        current_time = int(time.time())
        expires_at = current_time + self.CACHE_TTL

        conn = self._connect()
        try:
            conn.execute("""
                INSERT OR REPLACE INTO prefix_cache
                (prefix_hash, prefix_content, token_count, created_at, expires_at, hit_count, last_used)
                VALUES (?, ?, ?, ?, ?, 0, ?)
            """, (prefix_key, prefix_content, token_count, current_time, expires_at, current_time))
            conn.commit()
        finally:
            conn.close()

    def is_prefix_cached(self, system_prompt: str, tool_definitions: list) -> bool:
        """Return True if a non-expired entry exists; does not touch any counters."""
        prefix_key = self._generate_prefix_key(system_prompt, tool_definitions)
        current_time = int(time.time())

        conn = self._connect()
        try:
            row = conn.execute("""
                SELECT 1 FROM prefix_cache
                WHERE prefix_hash = ? AND expires_at > ?
            """, (prefix_key, current_time)).fetchone()
        finally:
            conn.close()
        return row is not None

    def calculate_savings(self, cached_tokens: int, fresh_tokens: int) -> Dict[str, Any]:
        """Compare cached-rate cost for *cached_tokens* with full-rate cost for *fresh_tokens*.

        Costs are token counts multiplied by ``PRICING_CACHED`` / ``PRICING_INPUT``.
        The currency unit is whatever those constants use.
        ``savings_percent`` is 0 when ``fresh_cost`` is not positive.
        """
        cached_cost = cached_tokens * PRICING_CACHED
        fresh_cost = fresh_tokens * PRICING_INPUT
        savings = fresh_cost - cached_cost
        return {
            'cached_cost': cached_cost,
            'fresh_cost': fresh_cost,
            'savings': savings,
            'savings_percent': (savings / fresh_cost * 100) if fresh_cost > 0 else 0,
            'tokens_at_discount': cached_tokens
        }

    def get_statistics(self) -> Dict[str, Any]:
        """Summarize non-expired database entries and this instance's session counters.

        ``session_stats`` is a snapshot copy, so mutating it does not affect
        the live counters. ``hit_rate`` is 0 before any lookup has happened.
        """
        current_time = int(time.time())

        conn = self._connect()
        try:
            # One aggregate pass gives a consistent snapshot of all three figures.
            # COALESCE mirrors the previous ``or 0``: SUM over no rows is NULL.
            valid_entries, total_tokens, total_hits = conn.execute("""
                SELECT COUNT(*),
                       COALESCE(SUM(token_count), 0),
                       COALESCE(SUM(hit_count), 0)
                FROM prefix_cache
                WHERE expires_at > ?
            """, (current_time,)).fetchone()
        finally:
            conn.close()

        lookups = self.stats['hits'] + self.stats['misses']
        return {
            'cached_prefixes': valid_entries,
            'total_cached_tokens': total_tokens,
            'database_hits': total_hits,
            'session_stats': dict(self.stats),
            'hit_rate': self.stats['hits'] / lookups if lookups > 0 else 0
        }

    def clear_expired(self) -> int:
        """Delete entries whose ``expires_at`` has passed; return the number removed."""
        current_time = int(time.time())

        conn = self._connect()
        try:
            cursor = conn.execute(
                "DELETE FROM prefix_cache WHERE expires_at <= ?", (current_time,)
            )
            deleted = cursor.rowcount
            conn.commit()
        finally:
            conn.close()
        return deleted

    def clear_all(self) -> int:
        """Delete every cached prefix; return the number of rows removed."""
        conn = self._connect()
        try:
            cursor = conn.execute("DELETE FROM prefix_cache")
            deleted = cursor.rowcount
            conn.commit()
        finally:
            conn.close()
        return deleted
|
|
|
|
|
|
def create_prefix_cache(db_path: str) -> PromptPrefixCache:
    """Build a PromptPrefixCache backed by the SQLite file at *db_path*.

    Convenience factory; equivalent to calling the constructor directly.
    """
    cache = PromptPrefixCache(db_path)
    return cache
|