345 lines
12 KiB
Python
Raw Normal View History

import os
import sqlite3
import tempfile
import time
from rp.memory.knowledge_store import KnowledgeEntry, KnowledgeStore
class TestKnowledgeStore:
    """Tests for ``KnowledgeStore`` running against a temporary SQLite file.

    Each test gets its own database file, created in ``setup_method`` and
    removed in ``teardown_method``, so tests never see each other's entries.
    """

    @staticmethod
    def _make_entry(entry_id, category, content, metadata=None, timestamp=None):
        """Build a ``KnowledgeEntry`` with identical created/updated times.

        Args:
            entry_id: Identifier of the entry.
            category: Category of the entry.
            content: Text content of the entry.
            metadata: Metadata dict; a fresh empty dict is used when omitted so
                entries never share one mutable object.
            timestamp: Value for both ``created_at`` and ``updated_at``;
                defaults to the current time.
        """
        if timestamp is None:
            timestamp = time.time()
        return KnowledgeEntry(
            entry_id=entry_id,
            category=category,
            content=content,
            metadata={} if metadata is None else metadata,
            created_at=timestamp,
            updated_at=timestamp,
        )

    def _add_entries(self, *specs):
        """Add one entry per ``(entry_id, category, content)`` tuple, in order."""
        for entry_id, category, content in specs:
            self.store.add_entry(self._make_entry(entry_id, category, content))

    def setup_method(self):
        """Set up test database for each test."""
        self.db_fd, self.db_path = tempfile.mkstemp()
        try:
            self.store = KnowledgeStore(self.db_path)
        except Exception:
            # pytest does not run teardown_method when setup_method fails, so
            # remove the temporary file here before propagating the error.
            os.close(self.db_fd)
            os.unlink(self.db_path)
            raise

    def teardown_method(self):
        """Clean up test database after each test."""
        # Drop our reference to the store before deleting its file.
        self.store = None
        try:
            os.close(self.db_fd)
        finally:
            # Remove the file even if closing the descriptor failed.
            os.unlink(self.db_path)

    def test_init(self):
        """Test KnowledgeStore initialization."""
        assert self.store.db_path == self.db_path

        # Verify tables were created, using an independent connection that is
        # closed even when the query fails.
        conn = sqlite3.connect(self.db_path)
        try:
            rows = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
        finally:
            conn.close()
        tables = [row[0] for row in rows]
        assert "knowledge_entries" in tables

    def test_add_entry(self):
        """Test adding a knowledge entry."""
        entry = KnowledgeEntry(
            entry_id="test_1",
            category="test",
            content="This is a test entry",
            metadata={"source": "test"},
            created_at=time.time(),
            updated_at=time.time(),
        )
        self.store.add_entry(entry)

        # Verify entry was added by reading the table directly.
        conn = sqlite3.connect(self.db_path)
        try:
            row = conn.execute(
                "SELECT entry_id, category, content FROM knowledge_entries WHERE entry_id = ?",
                ("test_1",),
            ).fetchone()
        finally:
            conn.close()

        # Fail with a clear assertion instead of a TypeError on row[0].
        assert row is not None
        assert row[0] == "test_1"
        assert row[1] == "test"
        assert row[2] == "This is a test entry"

    def test_get_entry(self):
        """Test retrieving a knowledge entry."""
        self.store.add_entry(self._make_entry("test_get", "test", "Content to retrieve"))

        retrieved = self.store.get_entry("test_get")

        assert retrieved is not None
        assert retrieved.entry_id == "test_get"
        assert retrieved.content == "Content to retrieve"
        assert retrieved.access_count == 1  # Should be incremented

    def test_get_entry_not_found(self):
        """Test retrieving a non-existent entry."""
        assert self.store.get_entry("nonexistent") is None

    def test_search_entries_semantic(self):
        """Test semantic search."""
        self._add_entries(
            ("entry1", "personal", "John is a software engineer"),
            ("entry2", "personal", "Mary works as a designer"),
            ("entry3", "tech", "Python is a programming language"),
        )

        results = self.store.search_entries("software engineer", top_k=2)

        assert len(results) >= 1
        # Should find the most relevant entry
        found_ids = [r.entry_id for r in results]
        assert "entry1" in found_ids

    def test_search_entries_fts_exact(self):
        """Test full-text search with exact matches."""
        self._add_entries(
            ("exact1", "test", "Python programming language"),
            ("exact2", "test", "Java programming language"),
            ("exact3", "test", "Database design principles"),
        )

        results = self.store.search_entries("programming language", top_k=3)

        assert len(results) >= 2
        found_ids = [r.entry_id for r in results]
        assert "exact1" in found_ids
        assert "exact2" in found_ids

    def test_search_entries_by_category(self):
        """Test searching entries by category."""
        self._add_entries(
            ("cat1", "personal", "John's info"),
            ("cat2", "tech", "Python info"),
            ("cat3", "personal", "Jane's info"),
        )

        results = self.store.search_entries("info", category="personal", top_k=5)

        assert len(results) == 2
        found_ids = [r.entry_id for r in results]
        assert "cat1" in found_ids
        assert "cat3" in found_ids
        assert "cat2" not in found_ids

    def test_get_by_category(self):
        """Test getting entries by category."""
        self._add_entries(
            ("get1", "personal", "Entry 1"),
            ("get2", "tech", "Entry 2"),
        )
        # Timestamp one second in the future so "get3" is the newest entry.
        self.store.add_entry(
            self._make_entry("get3", "personal", "Entry 3", timestamp=time.time() + 1)
        )

        personal_entries = self.store.get_by_category("personal")

        assert len(personal_entries) == 2
        # Should be ordered by importance_score DESC, created_at DESC
        assert personal_entries[0].entry_id == "get3"  # More recent
        assert personal_entries[1].entry_id == "get1"

    def test_update_importance(self):
        """Test updating entry importance."""
        self.store.add_entry(self._make_entry("importance_test", "test", "Test content"))

        self.store.update_importance("importance_test", 0.8)

        retrieved = self.store.get_entry("importance_test")
        assert retrieved.importance_score == 0.8

    def test_delete_entry(self):
        """Test deleting an entry."""
        self.store.add_entry(self._make_entry("delete_test", "test", "To be deleted"))

        result = self.store.delete_entry("delete_test")

        assert result is True
        # Verify it's gone
        assert self.store.get_entry("delete_test") is None

    def test_delete_entry_not_found(self):
        """Test deleting a non-existent entry."""
        result = self.store.delete_entry("nonexistent")
        assert result is False

    def test_get_statistics(self):
        """Test getting store statistics."""
        self._add_entries(
            ("stat1", "personal", "Personal info"),
            ("stat2", "tech", "Tech info"),
            ("stat3", "personal", "More personal info"),
        )

        stats = self.store.get_statistics()

        assert stats["total_entries"] == 3
        assert stats["total_categories"] == 2
        assert stats["category_distribution"]["personal"] == 2
        assert stats["category_distribution"]["tech"] == 1
        assert stats["vocabulary_size"] > 0

    def test_fts_search_exact_phrase(self):
        """Test FTS exact phrase matching."""
        self._add_entries(
            ("fts1", "test", "Python is great for programming"),
            ("fts2", "test", "Java is also good for programming"),
            ("fts3", "test", "Database management is important"),
        )

        fts_results = self.store._fts_search("programming")

        assert len(fts_results) == 2
        # Each result unpacks as an (entry_id, score) pair.
        entry_ids = [entry_id for entry_id, _score in fts_results]
        assert "fts1" in entry_ids
        assert "fts2" in entry_ids

    def test_fts_search_partial_match(self):
        """Test FTS partial word matching."""
        self._add_entries(
            ("partial1", "test", "Python programming"),
            ("partial2", "test", "Java programming"),
            ("partial3", "test", "Database design"),
        )

        fts_results = self.store._fts_search("program")

        assert len(fts_results) >= 2

    def test_combined_search_scoring(self):
        """Test that combined semantic + FTS search produces proper scoring."""
        self._add_entries(
            ("combined1", "test", "Python programming language"),
            ("combined2", "test", "Java programming language"),
            ("combined3", "test", "Database management system"),
        )

        results = self.store.search_entries("programming language")

        assert len(results) >= 2
        # Every result must carry a search score ...
        for result in results:
            assert "search_score" in result.metadata
        # ... and at least one of them should be positive.
        assert any(result.metadata["search_score"] > 0 for result in results)

    def test_empty_search(self):
        """Test searching with no matches."""
        results = self.store.search_entries("nonexistent topic")
        assert len(results) == 0

    def test_search_empty_query(self):
        """Test searching with empty query."""
        results = self.store.search_entries("")
        assert len(results) == 0

    def test_thread_safety(self):
        """Test that the store can handle concurrent access."""
        import queue
        import threading

        worker_count = 5
        outcomes = queue.Queue()

        def worker(worker_id):
            """Add and re-read one entry, reporting True or the raised exception."""
            try:
                self.store.add_entry(
                    self._make_entry(
                        f"thread_{worker_id}", "test", f"Content from worker {worker_id}"
                    )
                )
                retrieved = self.store.get_entry(f"thread_{worker_id}")
                assert retrieved is not None
                outcomes.put(True)
            except Exception as exc:
                # Exceptions do not propagate out of a thread; hand the failure
                # (including AssertionError) to the main thread instead.
                outcomes.put(exc)

        # Start multiple threads
        threads = []
        for worker_id in range(worker_count):
            thread = threading.Thread(target=worker, args=(worker_id,))
            threads.append(thread)
            thread.start()

        # Wait for all threads
        for thread in threads:
            thread.join()

        # Check results. All workers have finished, so every outcome is already
        # queued; get_nowait() fails fast instead of blocking forever if one
        # worker reported nothing.
        for _ in range(worker_count):
            outcome = outcomes.get_nowait()
            assert outcome is True, f"worker failed: {outcome!r}"

    def test_entry_to_dict(self):
        """Test KnowledgeEntry to_dict method."""
        entry = KnowledgeEntry(
            entry_id="dict_test",
            category="test",
            content="Test content",
            metadata={"key": "value"},
            created_at=1234567890.0,
            updated_at=1234567891.0,
            access_count=5,
            importance_score=0.8,
        )

        entry_dict = entry.to_dict()

        assert entry_dict["entry_id"] == "dict_test"
        assert entry_dict["category"] == "test"
        assert entry_dict["content"] == "Test content"
        assert entry_dict["metadata"]["key"] == "value"
        assert entry_dict["access_count"] == 5
        assert entry_dict["importance_score"] == 0.8