"""Tests for rp.memory.conversation_memory.ConversationMemory."""

import contextlib
import os
import queue
import sqlite3
import tempfile
import threading
import time

from rp.memory.conversation_memory import ConversationMemory


class TestConversationMemory:
    """Exercise ConversationMemory against a throwaway SQLite database file.

    Persisted state is verified by querying the database file directly through
    short-lived connections opened by the ``_fetchone`` / ``_fetchall`` helpers.
    """

    def setup_method(self):
        """Set up test database for each test."""
        # mkstemp returns an open OS-level descriptor plus the path; the
        # descriptor is closed again in teardown_method.
        self.db_fd, self.db_path = tempfile.mkstemp()
        self.memory = ConversationMemory(self.db_path)

    def teardown_method(self):
        """Clean up test database after each test."""
        # Drop the reference before removing the file (presumably so any
        # connection the instance holds can be released first).
        self.memory = None
        os.close(self.db_fd)
        os.unlink(self.db_path)

    def _fetchone(self, sql, params=()):
        """Run *sql* on a fresh connection and return the first row, or None.

        The connection is always closed, even if the caller's assertions later
        fail, so no handle to the temporary database is leaked.
        """
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
            return conn.execute(sql, params).fetchone()

    def _fetchall(self, sql, params=()):
        """Run *sql* on a fresh connection and return every row."""
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
            return conn.execute(sql, params).fetchall()

    def test_init(self):
        """Test ConversationMemory initialization."""
        assert self.memory.db_path == self.db_path

        # Verify tables were created
        tables = {
            row[0]
            for row in self._fetchall("SELECT name FROM sqlite_master WHERE type='table'")
        }
        assert "conversation_history" in tables
        assert "conversation_messages" in tables

    def test_create_conversation(self):
        """Test creating a new conversation."""
        conversation_id = "test_conv_123"
        session_id = "test_session_456"

        self.memory.create_conversation(conversation_id, session_id)

        # Verify conversation was created
        row = self._fetchone(
            "SELECT conversation_id, session_id FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert row is not None
        assert row[0] == conversation_id
        assert row[1] == session_id

    def test_create_conversation_without_session(self):
        """Test creating a conversation without session ID."""
        conversation_id = "test_conv_no_session"

        self.memory.create_conversation(conversation_id)

        row = self._fetchone(
            "SELECT conversation_id, session_id FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert row is not None
        assert row[0] == conversation_id
        assert row[1] is None

    def test_create_conversation_with_metadata(self):
        """Test creating a conversation with metadata."""
        conversation_id = "test_conv_metadata"
        metadata = {"topic": "test", "priority": "high"}

        self.memory.create_conversation(conversation_id, metadata=metadata)

        row = self._fetchone(
            "SELECT conversation_id, metadata FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert row is not None
        assert row[0] == conversation_id
        assert row[1] is not None

    def test_add_message(self):
        """Test adding a message to a conversation."""
        conversation_id = "test_conv_msg"
        message_id = "test_msg_123"
        role = "user"
        content = "Hello, world!"

        # Create conversation first
        self.memory.create_conversation(conversation_id)

        # Add message
        self.memory.add_message(conversation_id, message_id, role, content)

        # Verify message was added
        row = self._fetchone(
            "SELECT message_id, conversation_id, role, content FROM conversation_messages WHERE message_id = ?",
            (message_id,),
        )
        assert row is not None
        assert row[0] == message_id
        assert row[1] == conversation_id
        assert row[2] == role
        assert row[3] == content

        # Verify message count was updated
        count_row = self._fetchone(
            "SELECT message_count FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert count_row is not None
        assert count_row[0] == 1

    def test_add_message_with_tool_calls(self):
        """Test adding a message with tool calls."""
        conversation_id = "test_conv_tools"
        message_id = "test_msg_tools"
        role = "assistant"
        content = "I'll help you with that."
        tool_calls = [{"function": "test_func", "args": {"param": "value"}}]

        self.memory.create_conversation(conversation_id)
        self.memory.add_message(conversation_id, message_id, role, content, tool_calls=tool_calls)

        row = self._fetchone(
            "SELECT tool_calls FROM conversation_messages WHERE message_id = ?", (message_id,)
        )
        assert row is not None
        assert row[0] is not None

    def test_add_message_with_metadata(self):
        """Test adding a message with metadata."""
        conversation_id = "test_conv_meta"
        message_id = "test_msg_meta"
        role = "user"
        content = "Test message"
        metadata = {"tokens": 5, "model": "gpt-4"}

        self.memory.create_conversation(conversation_id)
        self.memory.add_message(conversation_id, message_id, role, content, metadata=metadata)

        row = self._fetchone(
            "SELECT metadata FROM conversation_messages WHERE message_id = ?", (message_id,)
        )
        assert row is not None
        assert row[0] is not None

    def test_get_conversation_messages(self):
        """Test retrieving conversation messages."""
        conversation_id = "test_conv_get"
        self.memory.create_conversation(conversation_id)

        # Add multiple messages
        messages = [
            ("msg1", "user", "Hello"),
            ("msg2", "assistant", "Hi there"),
            ("msg3", "user", "How are you?"),
        ]
        for msg_id, role, content in messages:
            self.memory.add_message(conversation_id, msg_id, role, content)

        # Retrieve all messages
        retrieved = self.memory.get_conversation_messages(conversation_id)
        assert len(retrieved) == 3
        assert retrieved[0]["message_id"] == "msg1"
        assert retrieved[1]["message_id"] == "msg2"
        assert retrieved[2]["message_id"] == "msg3"

    def test_get_conversation_messages_limited(self):
        """Test retrieving limited number of conversation messages."""
        conversation_id = "test_conv_limit"
        self.memory.create_conversation(conversation_id)

        # Add multiple messages
        for i in range(5):
            self.memory.add_message(conversation_id, f"msg{i}", "user", f"Message {i}")

        # Retrieve limited messages
        retrieved = self.memory.get_conversation_messages(conversation_id, limit=3)
        assert len(retrieved) == 3
        # Should return most recent messages first due to DESC order.
        # NOTE(review): test_get_conversation_messages expects oldest-first for
        # the unlimited call, so ordering differs with/without `limit` —
        # confirm this asymmetry is intended by ConversationMemory.
        assert retrieved[0]["message_id"] == "msg4"
        assert retrieved[1]["message_id"] == "msg3"
        assert retrieved[2]["message_id"] == "msg2"

    def test_update_conversation_summary(self):
        """Test updating conversation summary."""
        conversation_id = "test_conv_summary"
        self.memory.create_conversation(conversation_id)

        summary = "This is a test conversation summary"
        topics = ["testing", "memory", "conversation"]

        self.memory.update_conversation_summary(conversation_id, summary, topics)

        row = self._fetchone(
            "SELECT summary, topics, ended_at FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert row is not None
        assert row[0] == summary
        assert row[1] is not None  # topics should be stored
        assert row[2] is not None  # ended_at should be set

    def test_search_conversations(self):
        """Test searching conversations by content."""
        # Create conversations with different content
        conv1 = "conv_search_1"
        conv2 = "conv_search_2"
        conv3 = "conv_search_3"

        self.memory.create_conversation(conv1)
        self.memory.create_conversation(conv2)
        self.memory.create_conversation(conv3)

        # Add messages with searchable content
        self.memory.add_message(conv1, "msg1", "user", "Python programming tutorial")
        self.memory.add_message(conv2, "msg2", "user", "JavaScript development guide")
        self.memory.add_message(conv3, "msg3", "user", "Database design principles")

        # Search for "programming"
        results = self.memory.search_conversations("programming")
        assert len(results) == 1
        assert results[0]["conversation_id"] == conv1

        # Search for "development"
        results = self.memory.search_conversations("development")
        assert len(results) == 1
        assert results[0]["conversation_id"] == conv2

    def test_get_recent_conversations(self):
        """Test getting recent conversations."""
        # Create conversations at different times
        conv1 = "conv_recent_1"
        conv2 = "conv_recent_2"
        conv3 = "conv_recent_3"

        self.memory.create_conversation(conv1)
        time.sleep(0.01)  # Small delay to ensure different timestamps
        self.memory.create_conversation(conv2)
        time.sleep(0.01)
        self.memory.create_conversation(conv3)

        # Get recent conversations
        recent = self.memory.get_recent_conversations(limit=2)
        assert len(recent) == 2
        # Should be ordered by started_at DESC
        assert recent[0]["conversation_id"] == conv3
        assert recent[1]["conversation_id"] == conv2

    def test_get_recent_conversations_by_session(self):
        """Test getting recent conversations for a specific session."""
        session1 = "session_1"
        session2 = "session_2"

        conv1 = "conv_session_1"
        conv2 = "conv_session_2"
        conv3 = "conv_session_3"

        self.memory.create_conversation(conv1, session1)
        self.memory.create_conversation(conv2, session2)
        self.memory.create_conversation(conv3, session1)

        # Get conversations for session1
        session_convs = self.memory.get_recent_conversations(session_id=session1)
        assert len(session_convs) == 2

        conversation_ids = [c["conversation_id"] for c in session_convs]
        assert conv1 in conversation_ids
        assert conv3 in conversation_ids
        assert conv2 not in conversation_ids

    def test_delete_conversation(self):
        """Test deleting a conversation."""
        conversation_id = "conv_delete"
        self.memory.create_conversation(conversation_id)

        # Add some messages
        self.memory.add_message(conversation_id, "msg1", "user", "Test message")
        self.memory.add_message(conversation_id, "msg2", "assistant", "Response")

        # Delete conversation
        result = self.memory.delete_conversation(conversation_id)
        assert result is True

        # Verify conversation and messages are gone
        history_count = self._fetchone(
            "SELECT COUNT(*) FROM conversation_history WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert history_count[0] == 0

        message_count = self._fetchone(
            "SELECT COUNT(*) FROM conversation_messages WHERE conversation_id = ?",
            (conversation_id,),
        )
        assert message_count[0] == 0

    def test_delete_nonexistent_conversation(self):
        """Test deleting a non-existent conversation."""
        result = self.memory.delete_conversation("nonexistent")
        assert result is False

    def test_get_statistics(self):
        """Test getting memory statistics."""
        # Create some conversations and messages
        for i in range(3):
            conv_id = f"conv_stats_{i}"
            self.memory.create_conversation(conv_id)
            for j in range(2):
                self.memory.add_message(conv_id, f"msg_{i}_{j}", "user", f"Message {j}")

        stats = self.memory.get_statistics()

        assert stats["total_conversations"] == 3
        assert stats["total_messages"] == 6
        assert stats["average_messages_per_conversation"] == 2.0

    def test_thread_safety(self):
        """Test that the memory can handle concurrent access."""
        worker_count = 5
        results = queue.Queue()

        def worker(worker_id):
            try:
                conv_id = f"conv_thread_{worker_id}"
                self.memory.create_conversation(conv_id)
                self.memory.add_message(conv_id, f"msg_{worker_id}", "user", f"Worker {worker_id}")
                results.put(True)
            except Exception as e:
                # Hand the exception to the main thread so it can be reported.
                results.put(e)

        # Start multiple threads
        threads = [threading.Thread(target=worker, args=(i,)) for i in range(worker_count)]
        for t in threads:
            t.start()

        # Wait for all threads
        for t in threads:
            t.join()

        # Every worker has finished, so each outcome is already queued;
        # get_nowait() fails fast (queue.Empty) instead of hanging forever if
        # a worker died without reporting.
        outcomes = [results.get_nowait() for _ in range(worker_count)]
        failures = [outcome for outcome in outcomes if outcome is not True]
        assert not failures, f"worker failures: {failures!r}"

        # Verify all conversations were created
        recent = self.memory.get_recent_conversations(limit=10)
        assert len(recent) >= worker_count