260 lines
9.1 KiB
Python
260 lines
9.1 KiB
Python
|
|
import contextlib
import json
import sqlite3
import time
from typing import Any, Dict, Iterator, List, Optional
|
||
|
|
|
||
|
|
class ConversationMemory:
    """SQLite-backed store for conversations and their messages.

    Two tables are used:

    * ``conversation_history`` - one row per conversation: session id,
      start/end timestamps, a running ``message_count`` (maintained by
      :meth:`add_message`), summary, JSON-encoded topics and metadata.
    * ``conversation_messages`` - one row per message, linked to its
      conversation through ``conversation_id``.

    Every public method opens its own short-lived connection, runs its
    statements in a single transaction, and always closes the connection,
    including when a statement raises. Timestamps are ``time.time()`` values.
    """

    def __init__(self, db_path: str):
        """Create the store and make sure the schema exists.

        Args:
            db_path: Path handed to ``sqlite3.connect``.
        """
        self.db_path = db_path
        self._initialize_memory()

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @contextlib.contextmanager
    def _connection(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection inside a transaction, then close it.

        ``with conn`` commits when the body succeeds and rolls back when it
        raises; it does NOT close the connection, hence the ``finally``.
        """
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _to_json(value: Any, *, ensure_ascii: bool = True) -> Optional[str]:
        """Serialise ``value`` to JSON; falsy values (None, {}, []) become NULL."""
        return json.dumps(value, ensure_ascii=ensure_ascii) if value else None

    @staticmethod
    def _from_json(text: Optional[str], default: Any = None) -> Any:
        """Decode a stored JSON column, returning ``default`` for NULL/empty."""
        return json.loads(text) if text else default

    @staticmethod
    def _like_pattern(text: str) -> str:
        """Build a ``%text%`` LIKE pattern with wildcards in ``text`` escaped.

        The escape character is ``!`` and must match the ``ESCAPE '!'``
        clause used in the SQL; it is escaped first so it stays literal too.
        """
        escaped = text.replace("!", "!!").replace("%", "!%").replace("_", "!_")
        return f"%{escaped}%"

    def _initialize_memory(self):
        """Create both tables and their indexes if they do not exist yet."""
        statements = (
            '''
            CREATE TABLE IF NOT EXISTS conversation_history (
                conversation_id TEXT PRIMARY KEY,
                session_id TEXT,
                started_at REAL NOT NULL,
                ended_at REAL,
                message_count INTEGER DEFAULT 0,
                summary TEXT,
                topics TEXT,
                metadata TEXT
            )
            ''',
            '''
            CREATE TABLE IF NOT EXISTS conversation_messages (
                message_id TEXT PRIMARY KEY,
                conversation_id TEXT NOT NULL,
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                timestamp REAL NOT NULL,
                tool_calls TEXT,
                metadata TEXT,
                FOREIGN KEY (conversation_id) REFERENCES conversation_history(conversation_id)
            )
            ''',
            'CREATE INDEX IF NOT EXISTS idx_conv_session ON conversation_history(session_id)',
            'CREATE INDEX IF NOT EXISTS idx_conv_started ON conversation_history(started_at DESC)',
            'CREATE INDEX IF NOT EXISTS idx_msg_conversation ON conversation_messages(conversation_id)',
            'CREATE INDEX IF NOT EXISTS idx_msg_timestamp ON conversation_messages(timestamp)',
        )
        with self._connection() as conn:
            for statement in statements:
                conn.execute(statement)

    # ------------------------------------------------------------------ #
    # Writes
    # ------------------------------------------------------------------ #

    def create_conversation(self, conversation_id: str, session_id: Optional[str] = None,
                            metadata: Optional[Dict[str, Any]] = None):
        """Insert a new conversation row stamped with the current time.

        Args:
            conversation_id: Primary key for the new conversation.
            session_id: Optional session the conversation belongs to.
            metadata: Optional JSON-serialisable dict; empty/None stores NULL.

        Raises:
            sqlite3.IntegrityError: If ``conversation_id`` already exists.
        """
        with self._connection() as conn:
            conn.execute(
                '''
                INSERT INTO conversation_history
                    (conversation_id, session_id, started_at, metadata)
                VALUES (?, ?, ?, ?)
                ''',
                (conversation_id, session_id, time.time(), self._to_json(metadata)),
            )

    def add_message(self, conversation_id: str, message_id: str, role: str,
                    content: str, tool_calls: Optional[List[Dict[str, Any]]] = None,
                    metadata: Optional[Dict[str, Any]] = None):
        """Append a message and bump the conversation's ``message_count``.

        Both statements run in one transaction, so the counter and the
        message table cannot drift apart on failure. If no conversation row
        exists for ``conversation_id`` the message is still stored and the
        counter update simply matches nothing.

        Args:
            conversation_id: Conversation the message belongs to.
            message_id: Primary key for the message.
            role: Speaker role string, stored as-is.
            content: Message text.
            tool_calls: Optional JSON-serialisable list; empty/None stores NULL.
            metadata: Optional JSON-serialisable dict; empty/None stores NULL.

        Raises:
            sqlite3.IntegrityError: If ``message_id`` already exists.
        """
        with self._connection() as conn:
            conn.execute(
                '''
                INSERT INTO conversation_messages
                    (message_id, conversation_id, role, content, timestamp, tool_calls, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                ''',
                (
                    message_id,
                    conversation_id,
                    role,
                    content,
                    time.time(),
                    self._to_json(tool_calls),
                    self._to_json(metadata),
                ),
            )
            conn.execute(
                '''
                UPDATE conversation_history
                SET message_count = message_count + 1
                WHERE conversation_id = ?
                ''',
                (conversation_id,),
            )

    def update_conversation_summary(self, conversation_id: str, summary: str,
                                    topics: Optional[List[str]] = None):
        """Store a summary/topics for a conversation and stamp ``ended_at``.

        Topics are stored as a JSON array without ASCII escaping so that
        :meth:`search_conversations` can match non-ASCII topic text.

        Args:
            conversation_id: Conversation to update (no-op if missing).
            summary: Summary text.
            topics: Optional list of topic strings; empty/None stores NULL.
        """
        with self._connection() as conn:
            conn.execute(
                '''
                UPDATE conversation_history
                SET summary = ?, topics = ?, ended_at = ?
                WHERE conversation_id = ?
                ''',
                (summary, self._to_json(topics, ensure_ascii=False), time.time(), conversation_id),
            )

    def delete_conversation(self, conversation_id: str) -> bool:
        """Delete a conversation and all of its messages.

        Returns:
            True if a conversation row was deleted, False if none existed.
        """
        with self._connection() as conn:
            conn.execute('DELETE FROM conversation_messages WHERE conversation_id = ?',
                         (conversation_id,))
            cursor = conn.execute('DELETE FROM conversation_history WHERE conversation_id = ?',
                                  (conversation_id,))
            deleted = cursor.rowcount > 0
        return deleted

    # ------------------------------------------------------------------ #
    # Reads
    # ------------------------------------------------------------------ #

    def get_conversation_messages(self, conversation_id: str,
                                  limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Return a conversation's messages, oldest first.

        Args:
            conversation_id: Conversation whose messages to fetch.
            limit: ``None`` returns every message. Otherwise only the
                ``limit`` most recent messages are returned, still in
                chronological order; ``0`` returns an empty list.

        Returns:
            Dicts with ``message_id``, ``role``, ``content``, ``timestamp``,
            and decoded ``tool_calls``/``metadata`` (None when not stored).
        """
        # rowid breaks ties between messages written within the same
        # time.time() tick, keeping insertion order stable.
        if limit is None:
            sql = '''
                SELECT message_id, role, content, timestamp, tool_calls, metadata
                FROM conversation_messages
                WHERE conversation_id = ?
                ORDER BY timestamp ASC, rowid ASC
            '''
            params: tuple = (conversation_id,)
        else:
            # Pick the newest `limit` rows, then flip them back to
            # chronological order so both branches return the same ordering.
            sql = '''
                SELECT message_id, role, content, timestamp, tool_calls, metadata
                FROM (
                    SELECT rowid AS rid, message_id, role, content, timestamp,
                           tool_calls, metadata
                    FROM conversation_messages
                    WHERE conversation_id = ?
                    ORDER BY timestamp DESC, rowid DESC
                    LIMIT ?
                )
                ORDER BY timestamp ASC, rid ASC
            '''
            params = (conversation_id, limit)

        with self._connection() as conn:
            rows = conn.execute(sql, params).fetchall()

        return [
            {
                'message_id': message_id,
                'role': role,
                'content': content,
                'timestamp': timestamp,
                'tool_calls': self._from_json(tool_calls),
                'metadata': self._from_json(metadata),
            }
            for message_id, role, content, timestamp, tool_calls, metadata in rows
        ]

    def search_conversations(self, query: str, limit: int = 10) -> List[Dict[str, Any]]:
        """Find conversations whose summary, topics, or any message contains ``query``.

        ``query`` is matched literally as a substring: ``%`` and ``_`` are
        escaped rather than acting as LIKE wildcards. Matching follows
        SQLite's LIKE rules (case-insensitive for ASCII letters only).

        Args:
            query: Text to search for.
            limit: Maximum number of conversations to return.

        Returns:
            Newest-first dicts with ``conversation_id``, ``session_id``,
            ``started_at``, ``message_count``, ``summary`` and decoded
            ``topics`` (``[]`` when not stored).
        """
        # EXISTS yields one row per conversation without the DISTINCT over a
        # conversation x messages join.
        sql = '''
            SELECT h.conversation_id, h.session_id, h.started_at,
                   h.message_count, h.summary, h.topics
            FROM conversation_history h
            WHERE h.summary LIKE :pattern ESCAPE '!'
               OR h.topics LIKE :pattern ESCAPE '!'
               OR EXISTS (
                   SELECT 1
                   FROM conversation_messages m
                   WHERE m.conversation_id = h.conversation_id
                     AND m.content LIKE :pattern ESCAPE '!'
               )
            ORDER BY h.started_at DESC
            LIMIT :limit
        '''
        with self._connection() as conn:
            rows = conn.execute(sql, {'pattern': self._like_pattern(query),
                                      'limit': limit}).fetchall()

        return [
            {
                'conversation_id': conversation_id,
                'session_id': session_id,
                'started_at': started_at,
                'message_count': message_count,
                'summary': summary,
                'topics': self._from_json(topics, []),
            }
            for conversation_id, session_id, started_at, message_count, summary, topics in rows
        ]

    def get_recent_conversations(self, limit: int = 10,
                                 session_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return the most recently started conversations, newest first.

        Args:
            limit: Maximum number of conversations to return.
            session_id: When not None, only conversations of that session.

        Returns:
            Dicts with ``conversation_id``, ``session_id``, ``started_at``,
            ``ended_at``, ``message_count``, ``summary`` and decoded
            ``topics`` (``[]`` when not stored).
        """
        columns = '''
            SELECT conversation_id, session_id, started_at, ended_at,
                   message_count, summary, topics
            FROM conversation_history
        '''
        if session_id is not None:
            sql = columns + ' WHERE session_id = ? ORDER BY started_at DESC LIMIT ?'
            params: tuple = (session_id, limit)
        else:
            sql = columns + ' ORDER BY started_at DESC LIMIT ?'
            params = (limit,)

        with self._connection() as conn:
            rows = conn.execute(sql, params).fetchall()

        return [
            {
                'conversation_id': conv_id,
                'session_id': sess_id,
                'started_at': started_at,
                'ended_at': ended_at,
                'message_count': message_count,
                'summary': summary,
                'topics': self._from_json(topics, []),
            }
            for conv_id, sess_id, started_at, ended_at, message_count, summary, topics in rows
        ]

    def get_statistics(self) -> Dict[str, Any]:
        """Return aggregate counts over the whole store.

        Returns:
            ``total_conversations``, ``total_messages`` (rows in the message
            table) and ``average_messages_per_conversation`` (mean of
            ``message_count`` over conversations with at least one message,
            rounded to 2 places; 0 when there are none).
        """
        with self._connection() as conn:
            total_conversations, total_messages, avg_messages = conn.execute(
                '''
                SELECT
                    (SELECT COUNT(*) FROM conversation_history),
                    (SELECT COUNT(*) FROM conversation_messages),
                    (SELECT AVG(message_count) FROM conversation_history
                     WHERE message_count > 0)
                '''
            ).fetchone()

        return {
            'total_conversations': total_conversations,
            'total_messages': total_messages,
            # AVG over zero rows is NULL.
            'average_messages_per_conversation': round(avg_messages or 0, 2),
        }
|