import hashlib
import json
import logging
import sqlite3
import time
from contextlib import closing
from typing import Any, Dict, Optional

from rp.config import CACHE_PREFIX_MIN_LENGTH, PRICING_CACHED, PRICING_INPUT
# Module-level logger obtained under the "rp" logger name.
logger = logging.getLogger("rp")
class PromptPrefixCache:
    """SQLite-backed cache of prompt prefixes (system prompt + tool definitions).

    Each entry is keyed by the SHA-256 of the JSON serialization of the
    system prompt and tool definitions, and expires ``CACHE_TTL`` seconds
    after it was written.  Every operation opens its own short-lived
    connection, so ``db_path`` should name an on-disk database: a private
    in-memory database (``":memory:"``) would not persist between calls.

    Attributes:
        db_path: Path handed to ``sqlite3.connect`` for every operation.
        stats: In-memory counters for this instance only (``hits``,
            ``misses``, ``tokens_saved``, ``cost_saved``); they are not
            persisted to the database.
    """

    # Lifetime of a cache entry in seconds (added to ``int(time.time())``).
    CACHE_TTL = 3600

    def __init__(self, db_path: str):
        """Create the cache table if needed and reset the session counters.

        Args:
            db_path: Path of the SQLite database file to use.

        Raises:
            sqlite3.Error: If the database cannot be opened or initialized.
        """
        self.db_path = db_path
        self._init_cache()
        self.stats = {
            'hits': 0,
            'misses': 0,
            'tokens_saved': 0,
            'cost_saved': 0.0
        }

    def _connect(self):
        """Return a context manager yielding a connection that is always closed.

        ``sqlite3.Connection`` used directly in a ``with`` statement only
        manages the transaction and leaves the connection open, so it is
        wrapped in ``closing`` to guarantee ``close()`` even when a
        statement raises.
        """
        return closing(sqlite3.connect(self.db_path, check_same_thread=False))

    def _init_cache(self):
        """Create the ``prefix_cache`` table and its expiry index if missing."""
        with self._connect() as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS prefix_cache (
                    prefix_hash TEXT PRIMARY KEY,
                    prefix_content TEXT NOT NULL,
                    token_count INTEGER NOT NULL,
                    created_at INTEGER NOT NULL,
                    expires_at INTEGER NOT NULL,
                    hit_count INTEGER DEFAULT 0,
                    last_used INTEGER
                )
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_prefix_expires ON prefix_cache(expires_at)
            """)
            conn.commit()

    def _generate_prefix_key(self, system_prompt: str, tool_definitions: list) -> str:
        """Return the hex SHA-256 digest identifying a prompt prefix.

        The prompt and tools are serialized with ``sort_keys=True`` so that
        dictionaries differing only in key order produce the same key.

        Raises:
            TypeError: If ``tool_definitions`` is not JSON-serializable.
        """
        cache_data = {
            'system_prompt': system_prompt,
            'tools': tool_definitions
        }
        serialized = json.dumps(cache_data, sort_keys=True)
        return hashlib.sha256(serialized.encode()).hexdigest()

    def get_cached_prefix(
        self,
        system_prompt: str,
        tool_definitions: list
    ) -> Optional[Dict[str, Any]]:
        """Look up a non-expired cache entry for the given prefix.

        On a hit the row's ``hit_count`` and ``last_used`` are updated and the
        session counters (``hits``, ``tokens_saved``, ``cost_saved``) are
        increased; on a miss only ``misses`` is incremented.

        Args:
            system_prompt: System prompt text of the prefix.
            tool_definitions: Tool definitions that are part of the prefix.

        Returns:
            ``{'content', 'token_count', 'cached': True}`` on a hit, where
            ``content`` is the stored JSON string; ``None`` on a miss.
        """
        prefix_key = self._generate_prefix_key(system_prompt, tool_definitions)
        current_time = int(time.time())
        with self._connect() as conn:
            row = conn.execute("""
                SELECT prefix_content, token_count
                FROM prefix_cache
                WHERE prefix_hash = ? AND expires_at > ?
            """, (prefix_key, current_time)).fetchone()
            if row is not None:
                conn.execute("""
                    UPDATE prefix_cache
                    SET hit_count = hit_count + 1, last_used = ?
                    WHERE prefix_hash = ?
                """, (current_time, prefix_key))
                conn.commit()

        if row is None:
            self.stats['misses'] += 1
            return None

        content, token_count = row
        self.stats['hits'] += 1
        self.stats['tokens_saved'] += token_count
        # Saving is the difference between the regular and the cached price
        # applied to every token of the prefix.
        self.stats['cost_saved'] += token_count * (PRICING_INPUT - PRICING_CACHED)
        return {
            'content': content,
            'token_count': token_count,
            'cached': True
        }

    def cache_prefix(
        self,
        system_prompt: str,
        tool_definitions: list,
        token_count: int
    ) -> None:
        """Store a prefix, replacing any existing entry with the same key.

        Prefixes whose ``token_count`` is below ``CACHE_PREFIX_MIN_LENGTH``
        are ignored.  Re-caching an existing prefix replaces the row, which
        resets its ``hit_count`` to 0 and restarts its TTL.

        Args:
            system_prompt: System prompt text of the prefix.
            tool_definitions: Tool definitions that are part of the prefix.
            token_count: Number of tokens in the prefix.
        """
        if token_count < CACHE_PREFIX_MIN_LENGTH:
            return
        prefix_key = self._generate_prefix_key(system_prompt, tool_definitions)
        prefix_content = json.dumps({
            'system_prompt': system_prompt,
            'tools': tool_definitions
        })
        current_time = int(time.time())
        expires_at = current_time + self.CACHE_TTL
        with self._connect() as conn:
            conn.execute("""
                INSERT OR REPLACE INTO prefix_cache
                (prefix_hash, prefix_content, token_count, created_at, expires_at, hit_count, last_used)
                VALUES (?, ?, ?, ?, ?, 0, ?)
            """, (prefix_key, prefix_content, token_count, current_time, expires_at, current_time))
            conn.commit()

    def is_prefix_cached(self, system_prompt: str, tool_definitions: list) -> bool:
        """Return True if a non-expired entry exists for the given prefix.

        Unlike ``get_cached_prefix`` this neither touches the row's hit
        bookkeeping nor the session counters.
        """
        prefix_key = self._generate_prefix_key(system_prompt, tool_definitions)
        current_time = int(time.time())
        with self._connect() as conn:
            row = conn.execute("""
                SELECT 1 FROM prefix_cache
                WHERE prefix_hash = ? AND expires_at > ?
            """, (prefix_key, current_time)).fetchone()
        return row is not None

    def calculate_savings(self, cached_tokens: int, fresh_tokens: int) -> Dict[str, Any]:
        """Compare the cost of ``cached_tokens`` against ``fresh_tokens``.

        Args:
            cached_tokens: Token count priced at ``PRICING_CACHED``.
            fresh_tokens: Token count priced at ``PRICING_INPUT``.

        Returns:
            Dict with ``cached_cost``, ``fresh_cost``, ``savings``
            (``fresh_cost - cached_cost``), ``savings_percent`` (0 when
            ``fresh_cost`` is not positive) and ``tokens_at_discount``.
        """
        cached_cost = cached_tokens * PRICING_CACHED
        fresh_cost = fresh_tokens * PRICING_INPUT
        savings = fresh_cost - cached_cost
        return {
            'cached_cost': cached_cost,
            'fresh_cost': fresh_cost,
            'savings': savings,
            'savings_percent': (savings / fresh_cost * 100) if fresh_cost > 0 else 0,
            'tokens_at_discount': cached_tokens
        }

    def get_statistics(self) -> Dict[str, Any]:
        """Summarize the non-expired database entries and the session counters.

        Returns:
            Dict with ``cached_prefixes``, ``total_cached_tokens`` and
            ``database_hits`` (all computed over non-expired rows),
            ``session_stats`` (the live ``self.stats`` dict, not a copy) and
            ``hit_rate`` (session hits / lookups, 0 before any lookup).
        """
        current_time = int(time.time())
        with self._connect() as conn:
            # One aggregate query gives a consistent snapshot of all three
            # figures; SUM() yields NULL when no row matches.
            valid_entries, total_tokens, total_hits = conn.execute(
                "SELECT COUNT(*), SUM(token_count), SUM(hit_count) "
                "FROM prefix_cache WHERE expires_at > ?",
                (current_time,),
            ).fetchone()

        lookups = self.stats['hits'] + self.stats['misses']
        hit_rate = self.stats['hits'] / lookups if lookups > 0 else 0
        return {
            'cached_prefixes': valid_entries,
            'total_cached_tokens': total_tokens or 0,
            'database_hits': total_hits or 0,
            'session_stats': self.stats,
            'hit_rate': hit_rate
        }

    def clear_expired(self) -> int:
        """Delete entries whose expiry time has passed; return how many."""
        current_time = int(time.time())
        with self._connect() as conn:
            cursor = conn.execute(
                "DELETE FROM prefix_cache WHERE expires_at <= ?", (current_time,)
            )
            deleted = cursor.rowcount
            conn.commit()
        return deleted

    def clear_all(self) -> int:
        """Delete every entry in the cache table; return how many were removed."""
        with self._connect() as conn:
            cursor = conn.execute("DELETE FROM prefix_cache")
            deleted = cursor.rowcount
            conn.commit()
        return deleted
def create_prefix_cache(db_path: str) -> PromptPrefixCache:
    """Build a ``PromptPrefixCache`` backed by the SQLite database at *db_path*."""
    cache = PromptPrefixCache(db_path=db_path)
    return cache