import json
import sqlite3
import time
from contextlib import contextmanager
from typing import List, Dict, Any, Iterator, Optional


class ConversationMemory:
    """SQLite-backed store for conversations and the messages inside them.

    Two tables are used: ``conversation_history`` (one row per conversation)
    and ``conversation_messages`` (one row per message). Every public method
    opens its own short-lived connection to ``db_path``, so no database handle
    is held between calls. Consequence: ``':memory:'`` is not a usable
    ``db_path`` -- each call would see a brand-new, empty database.
    """

    def __init__(self, db_path: str):
        """Create the store and ensure its schema exists.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        self._initialize_memory()

    @contextmanager
    def _connect(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection that commits on success, rolls back on error, and is always closed."""
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        try:
            # The connection's own context manager commits or rolls back but
            # does NOT close the connection; the finally clause does that.
            with conn:
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _like_pattern(text: str) -> str:
        """Return a LIKE pattern matching *text* literally anywhere (escape char ``!``)."""
        # Escape the escape character first, then LIKE's two wildcards, so
        # user text such as "100%" or "a_b" is matched verbatim.
        escaped = text.replace('!', '!!').replace('%', '!%').replace('_', '!_')
        return f'%{escaped}%'

    def _initialize_memory(self):
        """Create tables and indexes if they do not already exist."""
        statements = (
            '''
            CREATE TABLE IF NOT EXISTS conversation_history (
                conversation_id TEXT PRIMARY KEY,
                session_id TEXT,
                started_at REAL NOT NULL,
                ended_at REAL,
                message_count INTEGER DEFAULT 0,
                summary TEXT,
                topics TEXT,
                metadata TEXT
            )
            ''',
            # NOTE: PRAGMA foreign_keys is never enabled here, so SQLite does
            # not enforce this FOREIGN KEY; messages for unknown conversations
            # are accepted.
            '''
            CREATE TABLE IF NOT EXISTS conversation_messages (
                message_id TEXT PRIMARY KEY,
                conversation_id TEXT NOT NULL,
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                timestamp REAL NOT NULL,
                tool_calls TEXT,
                metadata TEXT,
                FOREIGN KEY (conversation_id) REFERENCES conversation_history(conversation_id)
            )
            ''',
            '''
            CREATE INDEX IF NOT EXISTS idx_conv_session
            ON conversation_history(session_id)
            ''',
            '''
            CREATE INDEX IF NOT EXISTS idx_conv_started
            ON conversation_history(started_at DESC)
            ''',
            '''
            CREATE INDEX IF NOT EXISTS idx_msg_conversation
            ON conversation_messages(conversation_id)
            ''',
            '''
            CREATE INDEX IF NOT EXISTS idx_msg_timestamp
            ON conversation_messages(timestamp)
            ''',
        )
        with self._connect() as conn:
            for statement in statements:
                conn.execute(statement)

    def create_conversation(self, conversation_id: str, session_id: Optional[str] = None,
                            metadata: Optional[Dict[str, Any]] = None):
        """Insert a new conversation row stamped with the current time.

        Args:
            conversation_id: Unique id (primary key) of the conversation.
            session_id: Optional session the conversation belongs to.
            metadata: Optional JSON-serializable dict; falsy values store NULL.

        Raises:
            sqlite3.IntegrityError: If ``conversation_id`` already exists.
        """
        with self._connect() as conn:
            conn.execute('''
                INSERT INTO conversation_history
                (conversation_id, session_id, started_at, metadata)
                VALUES (?, ?, ?, ?)
            ''', (
                conversation_id,
                session_id,
                time.time(),
                json.dumps(metadata) if metadata else None
            ))

    def add_message(self, conversation_id: str, message_id: str, role: str, content: str,
                    tool_calls: Optional[List[Dict[str, Any]]] = None,
                    metadata: Optional[Dict[str, Any]] = None):
        """Append a message and bump the conversation's ``message_count``.

        Both writes happen in one transaction: if the insert fails, the counter
        is not incremented. If ``conversation_id`` has no history row, the
        message is still stored (the FK is not enforced) and no counter changes.

        Args:
            conversation_id: Conversation the message belongs to.
            message_id: Unique id (primary key) of the message.
            role: Speaker role string.
            content: Message text.
            tool_calls: Optional JSON-serializable list; falsy values store NULL.
            metadata: Optional JSON-serializable dict; falsy values store NULL.

        Raises:
            sqlite3.IntegrityError: If ``message_id`` already exists.
        """
        with self._connect() as conn:
            conn.execute('''
                INSERT INTO conversation_messages
                (message_id, conversation_id, role, content, timestamp, tool_calls, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?)
            ''', (
                message_id,
                conversation_id,
                role,
                content,
                time.time(),
                json.dumps(tool_calls) if tool_calls else None,
                json.dumps(metadata) if metadata else None
            ))
            conn.execute('''
                UPDATE conversation_history
                SET message_count = message_count + 1
                WHERE conversation_id = ?
            ''', (conversation_id,))

    def get_conversation_messages(self, conversation_id: str,
                                  limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Return messages of a conversation as dicts.

        Args:
            conversation_id: Conversation to read.
            limit: When None, every message is returned oldest-first. Otherwise
                the ``limit`` most recent messages are returned NEWEST-first
                (``limit=0`` returns an empty list).

        Returns:
            Dicts with keys ``message_id``, ``role``, ``content``, ``timestamp``,
            ``tool_calls`` and ``metadata`` (the last two JSON-decoded, or None).
        """
        # rowid breaks ties between messages inserted within the same clock
        # tick, keeping insertion order deterministic.
        with self._connect() as conn:
            if limit is not None:
                rows = conn.execute('''
                    SELECT message_id, role, content, timestamp, tool_calls, metadata
                    FROM conversation_messages
                    WHERE conversation_id = ?
                    ORDER BY timestamp DESC, rowid DESC
                    LIMIT ?
                ''', (conversation_id, limit)).fetchall()
            else:
                rows = conn.execute('''
                    SELECT message_id, role, content, timestamp, tool_calls, metadata
                    FROM conversation_messages
                    WHERE conversation_id = ?
                    ORDER BY timestamp ASC, rowid ASC
                ''', (conversation_id,)).fetchall()

        return [
            {
                'message_id': row[0],
                'role': row[1],
                'content': row[2],
                'timestamp': row[3],
                'tool_calls': json.loads(row[4]) if row[4] else None,
                'metadata': json.loads(row[5]) if row[5] else None
            }
            for row in rows
        ]

    def update_conversation_summary(self, conversation_id: str, summary: str,
                                    topics: Optional[List[str]] = None):
        """Set summary and topics and stamp ``ended_at`` with the current time.

        Args:
            conversation_id: Conversation to update (no-op if it does not exist).
            summary: Summary text.
            topics: Optional list of topic strings; falsy values store NULL.
        """
        # ensure_ascii=False stores non-ASCII topics verbatim so that
        # search_conversations' LIKE on the raw JSON text can match them.
        topics_json = json.dumps(topics, ensure_ascii=False) if topics else None
        with self._connect() as conn:
            conn.execute('''
                UPDATE conversation_history
                SET summary = ?, topics = ?, ended_at = ?
                WHERE conversation_id = ?
            ''', (summary, topics_json, time.time(), conversation_id))

    def search_conversations(self, query: str, limit: int = 10) -> List[Dict[str, Any]]:
        """Find conversations whose summary, topics, or any message contains *query*.

        *query* is matched literally (``%`` and ``_`` are not wildcards). As with
        SQLite's LIKE, matching is case-insensitive for ASCII letters only.

        Args:
            query: Substring to look for.
            limit: Maximum number of conversations to return.

        Returns:
            Conversation dicts, most recently started first.
        """
        pattern = self._like_pattern(query)
        with self._connect() as conn:
            rows = conn.execute('''
                SELECT h.conversation_id, h.session_id, h.started_at,
                       h.message_count, h.summary, h.topics
                FROM conversation_history h
                WHERE h.summary LIKE ? ESCAPE '!'
                   OR h.topics LIKE ? ESCAPE '!'
                   OR EXISTS (
                       SELECT 1 FROM conversation_messages m
                       WHERE m.conversation_id = h.conversation_id
                         AND m.content LIKE ? ESCAPE '!'
                   )
                ORDER BY h.started_at DESC, h.rowid DESC
                LIMIT ?
            ''', (pattern, pattern, pattern, limit)).fetchall()

        return [
            {
                'conversation_id': row[0],
                'session_id': row[1],
                'started_at': row[2],
                'message_count': row[3],
                'summary': row[4],
                'topics': json.loads(row[5]) if row[5] else []
            }
            for row in rows
        ]

    def get_recent_conversations(self, limit: int = 10,
                                 session_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return the most recently started conversations.

        Args:
            limit: Maximum number of conversations to return.
            session_id: If truthy, only conversations of this session are returned.

        Returns:
            Conversation dicts (including ``ended_at``), newest first.
        """
        with self._connect() as conn:
            if session_id:
                rows = conn.execute('''
                    SELECT conversation_id, session_id, started_at, ended_at,
                           message_count, summary, topics
                    FROM conversation_history
                    WHERE session_id = ?
                    ORDER BY started_at DESC, rowid DESC
                    LIMIT ?
                ''', (session_id, limit)).fetchall()
            else:
                rows = conn.execute('''
                    SELECT conversation_id, session_id, started_at, ended_at,
                           message_count, summary, topics
                    FROM conversation_history
                    ORDER BY started_at DESC, rowid DESC
                    LIMIT ?
                ''', (limit,)).fetchall()

        return [
            {
                'conversation_id': row[0],
                'session_id': row[1],
                'started_at': row[2],
                'ended_at': row[3],
                'message_count': row[4],
                'summary': row[5],
                'topics': json.loads(row[6]) if row[6] else []
            }
            for row in rows
        ]

    def delete_conversation(self, conversation_id: str) -> bool:
        """Delete a conversation and all of its messages in one transaction.

        Returns:
            True if a ``conversation_history`` row was removed.
        """
        with self._connect() as conn:
            conn.execute('DELETE FROM conversation_messages WHERE conversation_id = ?',
                         (conversation_id,))
            cursor = conn.execute('DELETE FROM conversation_history WHERE conversation_id = ?',
                                  (conversation_id,))
            return cursor.rowcount > 0

    def get_statistics(self) -> Dict[str, Any]:
        """Return aggregate counts over the whole store.

        Returns:
            ``total_conversations``, ``total_messages`` (rows actually stored), and
            ``average_messages_per_conversation`` computed over conversations
            whose ``message_count`` is positive, rounded to 2 decimals.
        """
        with self._connect() as conn:
            total_conversations = conn.execute(
                'SELECT COUNT(*) FROM conversation_history').fetchone()[0]
            total_messages = conn.execute(
                'SELECT COUNT(*) FROM conversation_messages').fetchone()[0]
            avg_messages = conn.execute('''
                SELECT AVG(message_count) FROM conversation_history
                WHERE message_count > 0
            ''').fetchone()[0] or 0

        return {
            'total_conversations': total_conversations,
            'total_messages': total_messages,
            'average_messages_per_conversation': round(avg_messages, 2)
        }