"""Tests for ``rp.memory.knowledge_store`` (KnowledgeStore and KnowledgeEntry)."""

import contextlib
import os
import queue
import sqlite3
import tempfile
import threading
import time

from rp.memory.knowledge_store import KnowledgeEntry, KnowledgeStore


class TestKnowledgeStore:
    def setup_method(self):
        """Set up a fresh test database file for each test."""
        # mkstemp creates an empty temp file; its path is handed to the store.
        self.db_fd, self.db_path = tempfile.mkstemp()
        self.store = KnowledgeStore(self.db_path)

    def teardown_method(self):
        """Release the store and remove the test database after each test."""
        try:
            # NOTE(review): KnowledgeStore may keep a sqlite connection open.
            # Close it explicitly when the API offers close() instead of
            # relying on garbage collection; an open handle prevents the
            # file from being unlinked on Windows.
            close = getattr(self.store, "close", None)
            if callable(close):
                close()
        finally:
            # Always release the fd and delete the file, even if close() fails.
            self.store = None
            os.close(self.db_fd)
            os.unlink(self.db_path)

    def test_init(self):
        """Test KnowledgeStore initialization."""
        assert self.store.db_path == self.db_path

        # Verify tables were created. ``closing`` guarantees the inspection
        # connection is closed even if something below raises (a bare
        # sqlite3.Connection context manager only commits/rolls back).
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
            tables = [row[0] for row in cursor.fetchall()]
        assert "knowledge_entries" in tables

    def test_add_entry(self):
        """Test adding a knowledge entry."""
        entry = KnowledgeEntry(
            entry_id="test_1",
            category="test",
            content="This is a test entry",
            metadata={"source": "test"},
            created_at=time.time(),
            updated_at=time.time(),
        )
        self.store.add_entry(entry)

        # Verify the entry was persisted by reading the table directly.
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
            row = conn.execute(
                "SELECT entry_id, category, content FROM knowledge_entries WHERE entry_id = ?",
                ("test_1",),
            ).fetchone()
        # Fail with a clear message rather than a TypeError on row[0].
        assert row is not None, "entry 'test_1' was not persisted"
        assert row[0] == "test_1"
        assert row[1] == "test"
        assert row[2] == "This is a test entry"

    def test_get_entry(self):
        """Test retrieving a knowledge entry."""
        entry = KnowledgeEntry(
            entry_id="test_get",
            category="test",
            content="Content to retrieve",
            metadata={},
            created_at=time.time(),
            updated_at=time.time(),
        )
        self.store.add_entry(entry)

        retrieved = self.store.get_entry("test_get")
        assert retrieved is not None
        assert retrieved.entry_id == "test_get"
        assert retrieved.content == "Content to retrieve"
        assert retrieved.access_count == 1  # Should be incremented by the read

    def test_get_entry_not_found(self):
        """Test retrieving a non-existent entry."""
        retrieved = self.store.get_entry("nonexistent")
        assert retrieved is None

    def test_search_entries_semantic(self):
        """Test semantic search."""
        entries = [
            KnowledgeEntry(
                "entry1", "personal", "John is a software engineer", {}, time.time(), time.time()
            ),
            KnowledgeEntry(
                "entry2", "personal", "Mary works as a designer", {}, time.time(), time.time()
            ),
            KnowledgeEntry(
                "entry3", "tech", "Python is a programming language", {}, time.time(), time.time()
            ),
        ]
        for entry in entries:
            self.store.add_entry(entry)

        results = self.store.search_entries("software engineer", top_k=2)
        assert len(results) >= 1
        # The most relevant entry must be among the results.
        found_ids = [r.entry_id for r in results]
        assert "entry1" in found_ids

    def test_search_entries_fts_exact(self):
        """Test full-text search with exact matches."""
        entries = [
            KnowledgeEntry(
                "exact1", "test", "Python programming language", {}, time.time(), time.time()
            ),
            KnowledgeEntry(
                "exact2", "test", "Java programming language", {}, time.time(), time.time()
            ),
            KnowledgeEntry(
                "exact3", "test", "Database design principles", {}, time.time(), time.time()
            ),
        ]
        for entry in entries:
            self.store.add_entry(entry)

        results = self.store.search_entries("programming language", top_k=3)
        assert len(results) >= 2
        found_ids = [r.entry_id for r in results]
        assert "exact1" in found_ids
        assert "exact2" in found_ids

    def test_search_entries_by_category(self):
        """Test searching entries by category."""
        entries = [
            KnowledgeEntry("cat1", "personal", "John's info", {}, time.time(), time.time()),
            KnowledgeEntry("cat2", "tech", "Python info", {}, time.time(), time.time()),
            KnowledgeEntry("cat3", "personal", "Jane's info", {}, time.time(), time.time()),
        ]
        for entry in entries:
            self.store.add_entry(entry)

        results = self.store.search_entries("info", category="personal", top_k=5)
        assert len(results) == 2
        found_ids = [r.entry_id for r in results]
        assert "cat1" in found_ids
        assert "cat3" in found_ids
        assert "cat2" not in found_ids

    def test_get_by_category(self):
        """Test getting entries by category."""
        entries = [
            KnowledgeEntry("get1", "personal", "Entry 1", {}, time.time(), time.time()),
            KnowledgeEntry("get2", "tech", "Entry 2", {}, time.time(), time.time()),
            # Created one second later so it sorts first by recency.
            KnowledgeEntry("get3", "personal", "Entry 3", {}, time.time() + 1, time.time() + 1),
        ]
        for entry in entries:
            self.store.add_entry(entry)

        personal_entries = self.store.get_by_category("personal")
        assert len(personal_entries) == 2
        # Should be ordered by importance_score DESC, created_at DESC
        assert personal_entries[0].entry_id == "get3"  # More recent
        assert personal_entries[1].entry_id == "get1"

    def test_update_importance(self):
        """Test updating entry importance."""
        entry = KnowledgeEntry(
            "importance_test", "test", "Test content", {}, time.time(), time.time()
        )
        self.store.add_entry(entry)

        self.store.update_importance("importance_test", 0.8)

        retrieved = self.store.get_entry("importance_test")
        assert retrieved.importance_score == 0.8

    def test_delete_entry(self):
        """Test deleting an entry."""
        entry = KnowledgeEntry("delete_test", "test", "To be deleted", {}, time.time(), time.time())
        self.store.add_entry(entry)

        result = self.store.delete_entry("delete_test")
        assert result is True

        # Verify it's gone
        retrieved = self.store.get_entry("delete_test")
        assert retrieved is None

    def test_delete_entry_not_found(self):
        """Test deleting a non-existent entry."""
        result = self.store.delete_entry("nonexistent")
        assert result is False

    def test_get_statistics(self):
        """Test getting store statistics."""
        entries = [
            KnowledgeEntry("stat1", "personal", "Personal info", {}, time.time(), time.time()),
            KnowledgeEntry("stat2", "tech", "Tech info", {}, time.time(), time.time()),
            KnowledgeEntry("stat3", "personal", "More personal info", {}, time.time(), time.time()),
        ]
        for entry in entries:
            self.store.add_entry(entry)

        stats = self.store.get_statistics()
        assert stats["total_entries"] == 3
        assert stats["total_categories"] == 2
        assert stats["category_distribution"]["personal"] == 2
        assert stats["category_distribution"]["tech"] == 1
        assert stats["vocabulary_size"] > 0

    def test_fts_search_exact_phrase(self):
        """Test FTS exact phrase matching."""
        entries = [
            KnowledgeEntry(
                "fts1", "test", "Python is great for programming", {}, time.time(), time.time()
            ),
            KnowledgeEntry(
                "fts2", "test", "Java is also good for programming", {}, time.time(), time.time()
            ),
            KnowledgeEntry(
                "fts3", "test", "Database management is important", {}, time.time(), time.time()
            ),
        ]
        for entry in entries:
            self.store.add_entry(entry)

        # _fts_search yields (entry_id, score) pairs.
        fts_results = self.store._fts_search("programming")
        assert len(fts_results) == 2
        entry_ids = [entry_id for entry_id, _score in fts_results]
        assert "fts1" in entry_ids
        assert "fts2" in entry_ids

    def test_fts_search_partial_match(self):
        """Test FTS partial word matching."""
        entries = [
            KnowledgeEntry("partial1", "test", "Python programming", {}, time.time(), time.time()),
            KnowledgeEntry("partial2", "test", "Java programming", {}, time.time(), time.time()),
            KnowledgeEntry("partial3", "test", "Database design", {}, time.time(), time.time()),
        ]
        for entry in entries:
            self.store.add_entry(entry)

        # "program" is a prefix of "programming" in two entries.
        fts_results = self.store._fts_search("program")
        assert len(fts_results) >= 2

    def test_combined_search_scoring(self):
        """Test that combined semantic + FTS search produces proper scoring."""
        entries = [
            KnowledgeEntry(
                "combined1", "test", "Python programming language", {}, time.time(), time.time()
            ),
            KnowledgeEntry(
                "combined2", "test", "Java programming language", {}, time.time(), time.time()
            ),
            KnowledgeEntry(
                "combined3", "test", "Database management system", {}, time.time(), time.time()
            ),
        ]
        for entry in entries:
            self.store.add_entry(entry)

        results = self.store.search_entries("programming language")
        assert len(results) >= 2

        # Every result must carry a search score in its metadata...
        for result in results:
            assert "search_score" in result.metadata
        # ...and at least one of those scores must be positive.
        assert any(result.metadata["search_score"] > 0 for result in results)

    def test_empty_search(self):
        """Test searching with no matches."""
        results = self.store.search_entries("nonexistent topic")
        assert len(results) == 0

    def test_search_empty_query(self):
        """Test searching with empty query."""
        results = self.store.search_entries("")
        assert len(results) == 0

    def test_thread_safety(self):
        """Test that the store can handle concurrent access."""
        num_workers = 5
        join_timeout_s = 30
        results = queue.Queue()

        def worker(worker_id):
            # Each worker reports True on success or the raised exception, so
            # failures in background threads surface in the main thread.
            try:
                entry = KnowledgeEntry(
                    f"thread_{worker_id}",
                    "test",
                    f"Content from worker {worker_id}",
                    {},
                    time.time(),
                    time.time(),
                )
                self.store.add_entry(entry)
                retrieved = self.store.get_entry(f"thread_{worker_id}")
                assert retrieved is not None
                results.put(True)
            except Exception as e:
                results.put(e)

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_workers)]
        for t in threads:
            t.start()

        # Bounded join so a deadlocked store fails the test instead of hanging it.
        for t in threads:
            t.join(timeout=join_timeout_s)
            assert not t.is_alive(), f"{t.name} did not finish within {join_timeout_s}s"

        # Every worker has finished, so all outcomes are already queued;
        # get_nowait raises queue.Empty instead of blocking if one is missing.
        outcomes = [results.get_nowait() for _ in range(num_workers)]
        failures = [outcome for outcome in outcomes if outcome is not True]
        assert not failures, f"worker failures: {failures!r}"

    def test_entry_to_dict(self):
        """Test KnowledgeEntry to_dict method."""
        entry = KnowledgeEntry(
            entry_id="dict_test",
            category="test",
            content="Test content",
            metadata={"key": "value"},
            created_at=1234567890.0,
            updated_at=1234567891.0,
            access_count=5,
            importance_score=0.8,
        )

        entry_dict = entry.to_dict()
        assert entry_dict["entry_id"] == "dict_test"
        assert entry_dict["category"] == "test"
        assert entry_dict["content"] == "Test content"
        assert entry_dict["metadata"]["key"] == "value"
        assert entry_dict["access_count"] == 5
        assert entry_dict["importance_score"] == 0.8