2025-11-06 15:15:06 +01:00
|
|
|
import sqlite3
|
|
|
|
|
import tempfile
|
|
|
|
|
import os
|
|
|
|
|
import time
|
|
|
|
|
|
2025-11-07 18:50:28 +01:00
|
|
|
from rp.memory.conversation_memory import ConversationMemory
|
2025-11-06 15:15:06 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestConversationMemory:
    """Tests for ``ConversationMemory`` against a throwaway SQLite file.

    ``setup_method`` gives every test a fresh database file and
    ``teardown_method`` removes it. Tests that inspect stored rows directly go
    through ``_fetchone`` / ``_fetchall``. Those helpers open a short-lived
    ``sqlite3`` connection and always close it, so a failing assertion cannot
    leave a connection open.
    """

    def setup_method(self):
        """Set up test database for each test."""
        # mkstemp creates the file and returns an open OS descriptor plus its
        # path; the descriptor is kept so teardown can close it.
        self.db_fd, self.db_path = tempfile.mkstemp()
        self.memory = ConversationMemory(self.db_path)

    def teardown_method(self):
        """Clean up test database after each test."""
        # Drop our reference first. NOTE(review): presumably this lets
        # ConversationMemory release any handle on the file; confirm whether
        # it offers an explicit close() that should be called instead.
        self.memory = None
        try:
            os.close(self.db_fd)
        finally:
            # Remove the file even if closing the descriptor fails.
            os.unlink(self.db_path)

    # -- helpers -------------------------------------------------------------

    def _fetchall(self, sql, params=()):
        """Execute *sql* with *params* on the test database and return all rows.

        A dedicated connection is opened per call and closed in ``finally`` so
        callers that assert on the result cannot leak it.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            return conn.execute(sql, params).fetchall()
        finally:
            conn.close()

    def _fetchone(self, sql, params=()):
        """Execute *sql* with *params* on the test database and return the first row.

        Returns ``None`` when the query matches nothing. The connection is
        always closed, as in ``_fetchall``.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            return conn.execute(sql, params).fetchone()
        finally:
            conn.close()

    # -- tests ---------------------------------------------------------------

    def test_init(self):
        """Test ConversationMemory initialization."""
        assert self.memory.db_path == self.db_path

        # Verify tables were created
        tables = {
            name
            for (name,) in self._fetchall(
                "SELECT name FROM sqlite_master WHERE type='table'"
            )
        }
        assert "conversation_history" in tables
        assert "conversation_messages" in tables

    def test_create_conversation(self):
        """Test creating a new conversation."""
        conversation_id = "test_conv_123"
        session_id = "test_session_456"

        self.memory.create_conversation(conversation_id, session_id)

        # Verify conversation was created
        row = self._fetchone(
            "SELECT conversation_id, session_id FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert row == (conversation_id, session_id)

    def test_create_conversation_without_session(self):
        """Test creating a conversation without session ID."""
        conversation_id = "test_conv_no_session"

        self.memory.create_conversation(conversation_id)

        row = self._fetchone(
            "SELECT conversation_id, session_id FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert row == (conversation_id, None)

    def test_create_conversation_with_metadata(self):
        """Test creating a conversation with metadata."""
        conversation_id = "test_conv_metadata"
        metadata = {"topic": "test", "priority": "high"}

        self.memory.create_conversation(conversation_id, metadata=metadata)

        row = self._fetchone(
            "SELECT conversation_id, metadata FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert row is not None, "conversation row was not inserted"
        assert row[0] == conversation_id
        # Only presence is checked; the serialized format is up to ConversationMemory.
        assert row[1] is not None

    def test_add_message(self):
        """Test adding a message to a conversation."""
        conversation_id = "test_conv_msg"
        message_id = "test_msg_123"
        role = "user"
        content = "Hello, world!"

        # Create conversation first
        self.memory.create_conversation(conversation_id)

        # Add message
        self.memory.add_message(conversation_id, message_id, role, content)

        # Verify message was added
        row = self._fetchone(
            "SELECT message_id, conversation_id, role, content FROM conversation_messages WHERE message_id = ?",
            (message_id,),
        )
        assert row == (message_id, conversation_id, role, content)

        # Verify message count was updated
        count_row = self._fetchone(
            "SELECT message_count FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert count_row == (1,)

    def test_add_message_with_tool_calls(self):
        """Test adding a message with tool calls."""
        conversation_id = "test_conv_tools"
        message_id = "test_msg_tools"
        role = "assistant"
        content = "I'll help you with that."
        tool_calls = [{"function": "test_func", "args": {"param": "value"}}]

        self.memory.create_conversation(conversation_id)
        self.memory.add_message(conversation_id, message_id, role, content, tool_calls=tool_calls)

        row = self._fetchone(
            "SELECT tool_calls FROM conversation_messages WHERE message_id = ?", (message_id,)
        )
        assert row is not None, "message row was not inserted"
        assert row[0] is not None

    def test_add_message_with_metadata(self):
        """Test adding a message with metadata."""
        conversation_id = "test_conv_meta"
        message_id = "test_msg_meta"
        role = "user"
        content = "Test message"
        metadata = {"tokens": 5, "model": "gpt-4"}

        self.memory.create_conversation(conversation_id)
        self.memory.add_message(conversation_id, message_id, role, content, metadata=metadata)

        row = self._fetchone(
            "SELECT metadata FROM conversation_messages WHERE message_id = ?", (message_id,)
        )
        assert row is not None, "message row was not inserted"
        assert row[0] is not None

    def test_get_conversation_messages(self):
        """Test retrieving conversation messages."""
        conversation_id = "test_conv_get"
        self.memory.create_conversation(conversation_id)

        # Add multiple messages
        messages = [
            ("msg1", "user", "Hello"),
            ("msg2", "assistant", "Hi there"),
            ("msg3", "user", "How are you?"),
        ]
        for msg_id, role, content in messages:
            self.memory.add_message(conversation_id, msg_id, role, content)

        # Retrieve all messages; without a limit they come back in insertion order.
        retrieved = self.memory.get_conversation_messages(conversation_id)
        assert [m["message_id"] for m in retrieved] == ["msg1", "msg2", "msg3"]

    def test_get_conversation_messages_limited(self):
        """Test retrieving limited number of conversation messages."""
        conversation_id = "test_conv_limit"
        self.memory.create_conversation(conversation_id)

        # Add multiple messages
        for i in range(5):
            self.memory.add_message(conversation_id, f"msg{i}", "user", f"Message {i}")

        # Retrieve limited messages.
        # Should return most recent messages first due to DESC order.
        # NOTE(review): test_get_conversation_messages expects insertion order
        # when no limit is given. Confirm this asymmetry is the intended API.
        retrieved = self.memory.get_conversation_messages(conversation_id, limit=3)
        assert [m["message_id"] for m in retrieved] == ["msg4", "msg3", "msg2"]

    def test_update_conversation_summary(self):
        """Test updating conversation summary."""
        conversation_id = "test_conv_summary"
        self.memory.create_conversation(conversation_id)

        summary = "This is a test conversation summary"
        topics = ["testing", "memory", "conversation"]

        self.memory.update_conversation_summary(conversation_id, summary, topics)

        row = self._fetchone(
            "SELECT summary, topics, ended_at FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert row is not None, "conversation row is missing"
        assert row[0] == summary
        assert row[1] is not None  # topics should be stored
        assert row[2] is not None  # ended_at should be set

    def test_search_conversations(self):
        """Test searching conversations by content."""
        # Create conversations with different content
        conv1 = "conv_search_1"
        conv2 = "conv_search_2"
        conv3 = "conv_search_3"

        self.memory.create_conversation(conv1)
        self.memory.create_conversation(conv2)
        self.memory.create_conversation(conv3)

        # Add messages with searchable content
        self.memory.add_message(conv1, "msg1", "user", "Python programming tutorial")
        self.memory.add_message(conv2, "msg2", "user", "JavaScript development guide")
        self.memory.add_message(conv3, "msg3", "user", "Database design principles")

        # Search for "programming"
        results = self.memory.search_conversations("programming")
        assert [r["conversation_id"] for r in results] == [conv1]

        # Search for "development"
        results = self.memory.search_conversations("development")
        assert [r["conversation_id"] for r in results] == [conv2]

    def test_get_recent_conversations(self):
        """Test getting recent conversations."""
        # Create conversations at different times
        conv1 = "conv_recent_1"
        conv2 = "conv_recent_2"
        conv3 = "conv_recent_3"

        # The sleeps aim for distinct timestamps. NOTE(review): this assumes
        # started_at has sub-second resolution; if it is second-granular, this
        # test is timing-sensitive. Confirm.
        self.memory.create_conversation(conv1)
        time.sleep(0.01)
        self.memory.create_conversation(conv2)
        time.sleep(0.01)
        self.memory.create_conversation(conv3)

        # Get recent conversations; they should be ordered by started_at DESC.
        recent = self.memory.get_recent_conversations(limit=2)
        assert [c["conversation_id"] for c in recent] == [conv3, conv2]

    def test_get_recent_conversations_by_session(self):
        """Test getting recent conversations for a specific session."""
        session1 = "session_1"
        session2 = "session_2"

        conv1 = "conv_session_1"
        conv2 = "conv_session_2"
        conv3 = "conv_session_3"

        self.memory.create_conversation(conv1, session1)
        self.memory.create_conversation(conv2, session2)
        self.memory.create_conversation(conv3, session1)

        # Get conversations for session1; ordering is not asserted here.
        session_convs = self.memory.get_recent_conversations(session_id=session1)
        conversation_ids = [c["conversation_id"] for c in session_convs]
        assert sorted(conversation_ids) == sorted([conv1, conv3])
        assert conv2 not in conversation_ids

    def test_delete_conversation(self):
        """Test deleting a conversation."""
        conversation_id = "conv_delete"
        self.memory.create_conversation(conversation_id)

        # Add some messages
        self.memory.add_message(conversation_id, "msg1", "user", "Test message")
        self.memory.add_message(conversation_id, "msg2", "assistant", "Response")

        # Delete conversation
        result = self.memory.delete_conversation(conversation_id)
        assert result is True

        # Verify conversation and messages are gone
        history_count = self._fetchone(
            "SELECT COUNT(*) FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert history_count == (0,)

        message_count = self._fetchone(
            "SELECT COUNT(*) FROM conversation_messages WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert message_count == (0,)

    def test_delete_nonexistent_conversation(self):
        """Test deleting a non-existent conversation."""
        result = self.memory.delete_conversation("nonexistent")
        assert result is False

    def test_get_statistics(self):
        """Test getting memory statistics."""
        # Create some conversations and messages: 3 conversations x 2 messages.
        for i in range(3):
            conv_id = f"conv_stats_{i}"
            self.memory.create_conversation(conv_id)
            for j in range(2):
                self.memory.add_message(conv_id, f"msg_{i}_{j}", "user", f"Message {j}")

        stats = self.memory.get_statistics()
        assert stats["total_conversations"] == 3
        assert stats["total_messages"] == 6
        assert stats["average_messages_per_conversation"] == 2.0

    def test_thread_safety(self):
        """Test that the memory can handle concurrent access."""
        import queue
        import threading

        num_workers = 5
        results = queue.Queue()

        def worker(worker_id):
            # Report success or the raised exception. An exception escaping a
            # thread would otherwise be invisible to the test.
            try:
                conv_id = f"conv_thread_{worker_id}"
                self.memory.create_conversation(conv_id)
                self.memory.add_message(conv_id, f"msg_{worker_id}", "user", f"Worker {worker_id}")
                results.put(True)
            except Exception as e:
                results.put(e)

        # Start multiple threads
        threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_workers)]
        for t in threads:
            t.start()

        # Wait for all threads. The join is bounded so a deadlock fails the
        # test instead of hanging the whole run.
        for t in threads:
            t.join(timeout=10)
            assert not t.is_alive(), f"{t.name} did not finish (possible deadlock)"

        # Check results. Each finished worker put one item, so get() will not
        # block. Collect every outcome so all errors are reported, not just the first.
        outcomes = [results.get() for _ in range(num_workers)]
        failures = [o for o in outcomes if o is not True]
        assert not failures, f"worker errors: {failures!r}"

        # Verify all conversations were created. The database is fresh for
        # each test, so exactly the workers' conversations should exist.
        recent = self.memory.get_recent_conversations(limit=10)
        assert {c["conversation_id"] for c in recent} == {
            f"conv_thread_{i}" for i in range(num_workers)
        }
|