330 lines
9.8 KiB
Python
Raw Normal View History

2025-11-04 05:17:27 +01:00
import json
import sqlite3
import time
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional
2025-11-04 05:17:27 +01:00
class ConversationMemory:
    """Persist conversations and their messages in a SQLite database.

    Two tables are used: ``conversation_history`` (one row per conversation)
    and ``conversation_messages`` (one row per message). Every public method
    opens its own short-lived connection to ``db_path`` and closes it before
    returning, so an instance keeps no connection open between calls.
    """

    def __init__(self, db_path: str):
        """Store ``db_path`` and create the schema if it does not exist yet.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        self._initialize_memory()

    @contextmanager
    def _connection(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection that commits on success and is always closed.

        ``with conn`` commits when the body completes and rolls back when it
        raises; the ``finally`` closes the handle in both cases so a failing
        statement (e.g. a duplicate primary key) cannot leak a connection.
        """
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _like_pattern(text: str) -> str:
        """Return a ``%text%`` LIKE pattern that matches ``text`` literally.

        Backslash, percent and underscore are prefixed with a backslash, so
        the SQL using the pattern must declare a backslash ESCAPE character.
        """
        escaped = (
            text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )
        return f"%{escaped}%"

    def _initialize_memory(self) -> None:
        """Create the tables and indexes used by this class if absent."""
        with self._connection() as conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS conversation_history (
                    conversation_id TEXT PRIMARY KEY,
                    session_id TEXT,
                    started_at REAL NOT NULL,
                    ended_at REAL,
                    message_count INTEGER DEFAULT 0,
                    summary TEXT,
                    topics TEXT,
                    metadata TEXT
                )
                """
            )
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS conversation_messages (
                    message_id TEXT PRIMARY KEY,
                    conversation_id TEXT NOT NULL,
                    role TEXT NOT NULL,
                    content TEXT NOT NULL,
                    timestamp REAL NOT NULL,
                    tool_calls TEXT,
                    metadata TEXT,
                    FOREIGN KEY (conversation_id)
                        REFERENCES conversation_history(conversation_id)
                )
                """
            )
            for statement in (
                "CREATE INDEX IF NOT EXISTS idx_conv_session "
                "ON conversation_history(session_id)",
                "CREATE INDEX IF NOT EXISTS idx_conv_started "
                "ON conversation_history(started_at DESC)",
                "CREATE INDEX IF NOT EXISTS idx_msg_conversation "
                "ON conversation_messages(conversation_id)",
                "CREATE INDEX IF NOT EXISTS idx_msg_timestamp "
                "ON conversation_messages(timestamp)",
            ):
                conn.execute(statement)

    def create_conversation(
        self,
        conversation_id: str,
        session_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Insert a new conversation row stamped with the current time.

        Args:
            conversation_id: Unique id of the conversation (primary key).
            session_id: Optional session the conversation belongs to.
            metadata: Optional JSON-serializable dict; an empty dict is
                stored as NULL, like ``None``.

        Raises:
            sqlite3.IntegrityError: If ``conversation_id`` already exists.
        """
        with self._connection() as conn:
            conn.execute(
                """
                INSERT INTO conversation_history
                (conversation_id, session_id, started_at, metadata)
                VALUES (?, ?, ?, ?)
                """,
                (
                    conversation_id,
                    session_id,
                    time.time(),
                    json.dumps(metadata) if metadata else None,
                ),
            )

    def add_message(
        self,
        conversation_id: str,
        message_id: str,
        role: str,
        content: str,
        tool_calls: Optional[List[Dict[str, Any]]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Append a message and bump the conversation's ``message_count``.

        Both writes happen in one transaction. The FOREIGN KEY is not
        enforced unless the ``foreign_keys`` pragma is enabled, which this
        class never does; so a message for an unknown ``conversation_id`` is
        stored, but no ``message_count`` is incremented.

        Args:
            conversation_id: Conversation the message belongs to.
            message_id: Unique id of the message (primary key).
            role: Speaker role, stored verbatim.
            content: Message text.
            tool_calls: Optional JSON-serializable list; empty is stored as NULL.
            metadata: Optional JSON-serializable dict; empty is stored as NULL.

        Raises:
            sqlite3.IntegrityError: If ``message_id`` already exists.
        """
        with self._connection() as conn:
            conn.execute(
                """
                INSERT INTO conversation_messages
                (message_id, conversation_id, role, content, timestamp, tool_calls, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    message_id,
                    conversation_id,
                    role,
                    content,
                    time.time(),
                    json.dumps(tool_calls) if tool_calls else None,
                    json.dumps(metadata) if metadata else None,
                ),
            )
            conn.execute(
                """
                UPDATE conversation_history
                SET message_count = message_count + 1
                WHERE conversation_id = ?
                """,
                (conversation_id,),
            )

    def get_conversation_messages(
        self, conversation_id: str, limit: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """Return the messages of a conversation as dicts.

        Without a limit, all messages are returned oldest-first. With a
        truthy ``limit``, only the newest ``limit`` messages are returned,
        newest-first. ``limit=0`` is treated like ``None`` (no limit).
        Messages with equal timestamps are ordered by insertion order.

        Args:
            conversation_id: Conversation whose messages to fetch.
            limit: Optional maximum number of (most recent) messages.

        Returns:
            Dicts with ``message_id``, ``role``, ``content``, ``timestamp``,
            ``tool_calls`` and ``metadata`` (the last two decoded from JSON,
            or ``None`` when not stored).
        """
        with self._connection() as conn:
            if limit:
                rows = conn.execute(
                    """
                    SELECT message_id, role, content, timestamp, tool_calls, metadata
                    FROM conversation_messages
                    WHERE conversation_id = ?
                    ORDER BY timestamp DESC, rowid DESC
                    LIMIT ?
                    """,
                    (conversation_id, limit),
                ).fetchall()
            else:
                rows = conn.execute(
                    """
                    SELECT message_id, role, content, timestamp, tool_calls, metadata
                    FROM conversation_messages
                    WHERE conversation_id = ?
                    ORDER BY timestamp ASC, rowid ASC
                    """,
                    (conversation_id,),
                ).fetchall()
        return [
            {
                "message_id": message_id,
                "role": role,
                "content": content,
                "timestamp": timestamp,
                "tool_calls": json.loads(tool_calls) if tool_calls else None,
                "metadata": json.loads(metadata) if metadata else None,
            }
            for message_id, role, content, timestamp, tool_calls, metadata in rows
        ]

    def update_conversation_summary(
        self, conversation_id: str, summary: str, topics: Optional[List[str]] = None
    ) -> None:
        """Set the summary and topics of a conversation and stamp ``ended_at``.

        Topics are stored as JSON with ``ensure_ascii=False`` so that
        non-ASCII topics stay searchable by ``search_conversations`` (the
        default escaping would store "café" as ``caf\\u00e9``).

        Args:
            conversation_id: Conversation to update; unknown ids are a no-op.
            summary: New summary text.
            topics: Optional list of topic strings; empty is stored as NULL.
        """
        with self._connection() as conn:
            conn.execute(
                """
                UPDATE conversation_history
                SET summary = ?, topics = ?, ended_at = ?
                WHERE conversation_id = ?
                """,
                (
                    summary,
                    json.dumps(topics, ensure_ascii=False) if topics else None,
                    time.time(),
                    conversation_id,
                ),
            )

    def search_conversations(self, query: str, limit: int = 10) -> List[Dict[str, Any]]:
        """Find conversations whose summary, topics or messages contain ``query``.

        ``query`` is matched as a literal substring: LIKE wildcards in it are
        escaped. Matching is SQLite LIKE, i.e. case-insensitive for ASCII
        letters only. Results are newest-first by ``started_at``.

        Args:
            query: Substring to look for.
            limit: Maximum number of conversations to return.

        Returns:
            Dicts with ``conversation_id``, ``session_id``, ``started_at``,
            ``message_count``, ``summary`` and ``topics`` (decoded list, or
            ``[]`` when none are stored).
        """
        pattern = self._like_pattern(query)
        with self._connection() as conn:
            # EXISTS avoids multiplying each conversation by its message count
            # (which the previous LEFT JOIN + DISTINCT had to undo).
            rows = conn.execute(
                """
                SELECT h.conversation_id, h.session_id, h.started_at,
                       h.message_count, h.summary, h.topics
                FROM conversation_history h
                WHERE h.summary LIKE ? ESCAPE '\\'
                   OR h.topics LIKE ? ESCAPE '\\'
                   OR EXISTS (
                       SELECT 1
                       FROM conversation_messages m
                       WHERE m.conversation_id = h.conversation_id
                         AND m.content LIKE ? ESCAPE '\\'
                   )
                ORDER BY h.started_at DESC, h.rowid DESC
                LIMIT ?
                """,
                (pattern, pattern, pattern, limit),
            ).fetchall()
        return [
            {
                "conversation_id": conv_id,
                "session_id": sess_id,
                "started_at": started_at,
                "message_count": message_count,
                "summary": summary,
                "topics": json.loads(topics) if topics else [],
            }
            for conv_id, sess_id, started_at, message_count, summary, topics in rows
        ]

    def get_recent_conversations(
        self, limit: int = 10, session_id: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Return the most recently started conversations, newest-first.

        Args:
            limit: Maximum number of conversations to return.
            session_id: When truthy, only conversations of this session.

        Returns:
            Dicts with ``conversation_id``, ``session_id``, ``started_at``,
            ``ended_at``, ``message_count``, ``summary`` and ``topics``
            (decoded list, or ``[]`` when none are stored).
        """
        if session_id:
            where_clause, params = "WHERE session_id = ?", (session_id, limit)
        else:
            where_clause, params = "", (limit,)
        # Only constant SQL fragments are interpolated; values stay bound.
        sql = f"""
            SELECT conversation_id, session_id, started_at, ended_at,
                   message_count, summary, topics
            FROM conversation_history
            {where_clause}
            ORDER BY started_at DESC, rowid DESC
            LIMIT ?
        """
        with self._connection() as conn:
            rows = conn.execute(sql, params).fetchall()
        return [
            {
                "conversation_id": row[0],
                "session_id": row[1],
                "started_at": row[2],
                "ended_at": row[3],
                "message_count": row[4],
                "summary": row[5],
                "topics": json.loads(row[6]) if row[6] else [],
            }
            for row in rows
        ]

    def delete_conversation(self, conversation_id: str) -> bool:
        """Delete a conversation and all of its messages in one transaction.

        Args:
            conversation_id: Conversation to delete.

        Returns:
            ``True`` if a ``conversation_history`` row was deleted. Orphan
            messages without a history row are still removed, but yield
            ``False``.
        """
        with self._connection() as conn:
            conn.execute(
                "DELETE FROM conversation_messages WHERE conversation_id = ?",
                (conversation_id,),
            )
            cursor = conn.execute(
                "DELETE FROM conversation_history WHERE conversation_id = ?",
                (conversation_id,),
            )
            return cursor.rowcount > 0

    def get_statistics(self) -> Dict[str, Any]:
        """Return aggregate counts over all stored conversations.

        Returns:
            ``total_conversations`` and ``total_messages`` (row counts), and
            ``average_messages_per_conversation``: the mean ``message_count``
            over conversations that have at least one message, rounded to two
            decimals (0 when there are none).
        """
        with self._connection() as conn:
            (total_conversations,) = conn.execute(
                "SELECT COUNT(*) FROM conversation_history"
            ).fetchone()
            (total_messages,) = conn.execute(
                "SELECT COUNT(*) FROM conversation_messages"
            ).fetchone()
            (avg_messages,) = conn.execute(
                "SELECT AVG(message_count) FROM conversation_history "
                "WHERE message_count > 0"
            ).fetchone()
        return {
            "total_conversations": total_conversations,
            "total_messages": total_messages,
            "average_messages_per_conversation": round(avg_messages or 0, 2),
        }