"""Tests for ConversationMemory (rp.memory.conversation_memory)."""

import sqlite3
import tempfile
import os
import time
from rp.memory.conversation_memory import ConversationMemory
class TestConversationMemory:
    """Exercise ConversationMemory against a throwaway SQLite database file.

    Every test gets its own temporary file (see ``setup_method``). Tests that
    inspect stored rows query that file directly with ``sqlite3`` through the
    ``_fetch_all`` / ``_fetch_one`` helpers, which always close their
    connection so a failing assertion cannot keep the file open when
    ``teardown_method`` unlinks it.
    """

    def setup_method(self):
        """Set up test database for each test."""
        self.db_fd, self.db_path = tempfile.mkstemp()
        try:
            self.memory = ConversationMemory(self.db_path)
        except Exception:
            # pytest does not run teardown_method when setup_method fails, so
            # the temporary file must be released here or it would be leaked.
            os.close(self.db_fd)
            os.unlink(self.db_path)
            raise

    def teardown_method(self):
        """Clean up test database after each test."""
        self.memory = None
        os.close(self.db_fd)
        os.unlink(self.db_path)

    def _fetch_all(self, sql, params=()):
        """Run ``sql`` on the test database file and return every row.

        A dedicated connection is opened per call and closed in ``finally``,
        so it is released even when the query itself raises.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            return conn.execute(sql, params).fetchall()
        finally:
            conn.close()

    def _fetch_one(self, sql, params=()):
        """Return the first row produced by ``sql``.

        Fails with an assertion (rather than a ``TypeError`` on ``None``
        indexing) when the query yields no rows.
        """
        rows = self._fetch_all(sql, params)
        assert rows, f"query returned no rows: {sql!r} params={params!r}"
        return rows[0]

    def test_init(self):
        """Test ConversationMemory initialization."""
        assert self.memory.db_path == self.db_path
        # Verify tables were created
        rows = self._fetch_all("SELECT name FROM sqlite_master WHERE type='table'")
        tables = [row[0] for row in rows]
        assert "conversation_history" in tables
        assert "conversation_messages" in tables

    def test_create_conversation(self):
        """Test creating a new conversation."""
        conversation_id = "test_conv_123"
        session_id = "test_session_456"
        self.memory.create_conversation(conversation_id, session_id)
        # Verify conversation was created
        row = self._fetch_one(
            "SELECT conversation_id, session_id FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert row[0] == conversation_id
        assert row[1] == session_id

    def test_create_conversation_without_session(self):
        """Test creating a conversation without session ID."""
        conversation_id = "test_conv_no_session"
        self.memory.create_conversation(conversation_id)
        row = self._fetch_one(
            "SELECT conversation_id, session_id FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert row[0] == conversation_id
        assert row[1] is None

    def test_create_conversation_with_metadata(self):
        """Test creating a conversation with metadata."""
        conversation_id = "test_conv_metadata"
        metadata = {"topic": "test", "priority": "high"}
        self.memory.create_conversation(conversation_id, metadata=metadata)
        row = self._fetch_one(
            "SELECT conversation_id, metadata FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert row[0] == conversation_id
        assert row[1] is not None

    def test_add_message(self):
        """Test adding a message to a conversation."""
        conversation_id = "test_conv_msg"
        message_id = "test_msg_123"
        role = "user"
        content = "Hello, world!"
        # Create conversation first
        self.memory.create_conversation(conversation_id)
        # Add message
        self.memory.add_message(conversation_id, message_id, role, content)
        # Verify message was added
        row = self._fetch_one(
            "SELECT message_id, conversation_id, role, content FROM conversation_messages WHERE message_id = ?",
            (message_id,),
        )
        assert row[0] == message_id
        assert row[1] == conversation_id
        assert row[2] == role
        assert row[3] == content
        # Verify message count was updated
        count_row = self._fetch_one(
            "SELECT message_count FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert count_row[0] == 1

    def test_add_message_with_tool_calls(self):
        """Test adding a message with tool calls."""
        conversation_id = "test_conv_tools"
        message_id = "test_msg_tools"
        role = "assistant"
        content = "I'll help you with that."
        tool_calls = [{"function": "test_func", "args": {"param": "value"}}]
        self.memory.create_conversation(conversation_id)
        self.memory.add_message(conversation_id, message_id, role, content, tool_calls=tool_calls)
        row = self._fetch_one(
            "SELECT tool_calls FROM conversation_messages WHERE message_id = ?", (message_id,)
        )
        assert row[0] is not None

    def test_add_message_with_metadata(self):
        """Test adding a message with metadata."""
        conversation_id = "test_conv_meta"
        message_id = "test_msg_meta"
        role = "user"
        content = "Test message"
        metadata = {"tokens": 5, "model": "gpt-4"}
        self.memory.create_conversation(conversation_id)
        self.memory.add_message(conversation_id, message_id, role, content, metadata=metadata)
        row = self._fetch_one(
            "SELECT metadata FROM conversation_messages WHERE message_id = ?", (message_id,)
        )
        assert row[0] is not None

    def test_get_conversation_messages(self):
        """Test retrieving conversation messages."""
        conversation_id = "test_conv_get"
        self.memory.create_conversation(conversation_id)
        # Add multiple messages
        messages = [
            ("msg1", "user", "Hello"),
            ("msg2", "assistant", "Hi there"),
            ("msg3", "user", "How are you?"),
        ]
        for msg_id, role, content in messages:
            self.memory.add_message(conversation_id, msg_id, role, content)
        # Retrieve all messages
        retrieved = self.memory.get_conversation_messages(conversation_id)
        assert len(retrieved) == 3
        assert retrieved[0]["message_id"] == "msg1"
        assert retrieved[1]["message_id"] == "msg2"
        assert retrieved[2]["message_id"] == "msg3"

    def test_get_conversation_messages_limited(self):
        """Test retrieving limited number of conversation messages."""
        conversation_id = "test_conv_limit"
        self.memory.create_conversation(conversation_id)
        # Add multiple messages
        for i in range(5):
            self.memory.add_message(conversation_id, f"msg{i}", "user", f"Message {i}")
        # Retrieve limited messages
        retrieved = self.memory.get_conversation_messages(conversation_id, limit=3)
        assert len(retrieved) == 3
        # Should return most recent messages first due to DESC order
        assert retrieved[0]["message_id"] == "msg4"
        assert retrieved[1]["message_id"] == "msg3"
        assert retrieved[2]["message_id"] == "msg2"

    def test_update_conversation_summary(self):
        """Test updating conversation summary."""
        conversation_id = "test_conv_summary"
        self.memory.create_conversation(conversation_id)
        summary = "This is a test conversation summary"
        topics = ["testing", "memory", "conversation"]
        self.memory.update_conversation_summary(conversation_id, summary, topics)
        row = self._fetch_one(
            "SELECT summary, topics, ended_at FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert row[0] == summary
        assert row[1] is not None  # topics should be stored
        assert row[2] is not None  # ended_at should be set

    def test_search_conversations(self):
        """Test searching conversations by content."""
        # Create conversations with different content
        conv1 = "conv_search_1"
        conv2 = "conv_search_2"
        conv3 = "conv_search_3"
        self.memory.create_conversation(conv1)
        self.memory.create_conversation(conv2)
        self.memory.create_conversation(conv3)
        # Add messages with searchable content
        self.memory.add_message(conv1, "msg1", "user", "Python programming tutorial")
        self.memory.add_message(conv2, "msg2", "user", "JavaScript development guide")
        self.memory.add_message(conv3, "msg3", "user", "Database design principles")
        # Search for "programming"
        results = self.memory.search_conversations("programming")
        assert len(results) == 1
        assert results[0]["conversation_id"] == conv1
        # Search for "development"
        results = self.memory.search_conversations("development")
        assert len(results) == 1
        assert results[0]["conversation_id"] == conv2

    def test_get_recent_conversations(self):
        """Test getting recent conversations."""
        # Create conversations at different times
        conv1 = "conv_recent_1"
        conv2 = "conv_recent_2"
        conv3 = "conv_recent_3"
        self.memory.create_conversation(conv1)
        time.sleep(0.01)  # Small delay to ensure different timestamps
        self.memory.create_conversation(conv2)
        time.sleep(0.01)
        self.memory.create_conversation(conv3)
        # Get recent conversations
        recent = self.memory.get_recent_conversations(limit=2)
        assert len(recent) == 2
        # Should be ordered by started_at DESC
        assert recent[0]["conversation_id"] == conv3
        assert recent[1]["conversation_id"] == conv2

    def test_get_recent_conversations_by_session(self):
        """Test getting recent conversations for a specific session."""
        session1 = "session_1"
        session2 = "session_2"
        conv1 = "conv_session_1"
        conv2 = "conv_session_2"
        conv3 = "conv_session_3"
        self.memory.create_conversation(conv1, session1)
        self.memory.create_conversation(conv2, session2)
        self.memory.create_conversation(conv3, session1)
        # Get conversations for session1
        session_convs = self.memory.get_recent_conversations(session_id=session1)
        assert len(session_convs) == 2
        conversation_ids = [c["conversation_id"] for c in session_convs]
        assert conv1 in conversation_ids
        assert conv3 in conversation_ids
        assert conv2 not in conversation_ids

    def test_delete_conversation(self):
        """Test deleting a conversation."""
        conversation_id = "conv_delete"
        self.memory.create_conversation(conversation_id)
        # Add some messages
        self.memory.add_message(conversation_id, "msg1", "user", "Test message")
        self.memory.add_message(conversation_id, "msg2", "assistant", "Response")
        # Delete conversation
        result = self.memory.delete_conversation(conversation_id)
        assert result is True
        # Verify conversation and messages are gone
        history_count = self._fetch_one(
            "SELECT COUNT(*) FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert history_count[0] == 0
        message_count = self._fetch_one(
            "SELECT COUNT(*) FROM conversation_messages WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert message_count[0] == 0

    def test_delete_nonexistent_conversation(self):
        """Test deleting a non-existent conversation."""
        result = self.memory.delete_conversation("nonexistent")
        assert result is False

    def test_get_statistics(self):
        """Test getting memory statistics."""
        # Create some conversations and messages
        for i in range(3):
            conv_id = f"conv_stats_{i}"
            self.memory.create_conversation(conv_id)
            for j in range(2):
                self.memory.add_message(conv_id, f"msg_{i}_{j}", "user", f"Message {j}")
        stats = self.memory.get_statistics()
        assert stats["total_conversations"] == 3
        assert stats["total_messages"] == 6
        assert stats["average_messages_per_conversation"] == 2.0

    def test_thread_safety(self):
        """Test that the memory can handle concurrent access."""
        import threading
        import queue

        results = queue.Queue()

        def worker(worker_id):
            # Report either True or the raised exception so the main thread
            # can assert on the outcome; exceptions in threads are otherwise
            # invisible to pytest.
            try:
                conv_id = f"conv_thread_{worker_id}"
                self.memory.create_conversation(conv_id)
                self.memory.add_message(conv_id, f"msg_{worker_id}", "user", f"Worker {worker_id}")
                results.put(True)
            except Exception as e:
                results.put(e)

        # Start multiple threads
        threads = []
        for i in range(5):
            t = threading.Thread(target=worker, args=(i,))
            threads.append(t)
            t.start()
        # Wait for all threads
        for t in threads:
            t.join()
        # Check results: every worker has already put exactly one item, so
        # these gets cannot block.
        for _ in range(5):
            result = results.get()
            assert result is True, f"worker raised: {result!r}"
        # Verify all conversations were created
        recent = self.memory.get_recent_conversations(limit=10)
        assert len(recent) >= 5