|
import sqlite3
|
|
import tempfile
|
|
import os
|
|
import time
|
|
|
|
from pr.memory.knowledge_store import KnowledgeStore, KnowledgeEntry
|
|
|
|
|
|
class TestKnowledgeStore:
    """Exercise KnowledgeStore against a throwaway SQLite database file.

    ``setup_method`` gives every test its own temporary database file, so
    tests do not see each other's entries.
    """

    def setup_method(self):
        """Create a temporary database file and open a store on it."""
        self.db_fd, self.db_path = tempfile.mkstemp()
        self.store = KnowledgeStore(self.db_path)

    def teardown_method(self):
        """Drop the store reference and delete the temporary database file."""
        self.store = None
        try:
            os.close(self.db_fd)
        finally:
            # Remove the file even if closing the descriptor failed, so a
            # broken run does not leave databases behind in the temp dir.
            os.unlink(self.db_path)

    # ------------------------------------------------------------------
    # Helpers (no ``test_`` prefix, so pytest does not collect them)
    # ------------------------------------------------------------------

    @staticmethod
    def _make_entry(entry_id, category, content, metadata=None, timestamp=None):
        """Build a KnowledgeEntry whose created/updated times are equal.

        Args:
            entry_id: Identifier of the new entry.
            category: Category label of the new entry.
            content: Text content of the new entry.
            metadata: Optional metadata dict; a fresh empty dict is used when
                omitted so entries never share one mutable object.
            timestamp: Value used for both ``created_at`` and ``updated_at``;
                defaults to the current time.

        Returns:
            The constructed KnowledgeEntry.
        """
        if timestamp is None:
            timestamp = time.time()
        return KnowledgeEntry(
            entry_id=entry_id,
            category=category,
            content=content,
            metadata={} if metadata is None else metadata,
            created_at=timestamp,
            updated_at=timestamp,
        )

    def _add_entries(self, *specs):
        """Add one entry per ``(entry_id, category, content)`` tuple, in order."""
        for entry_id, category, content in specs:
            self.store.add_entry(self._make_entry(entry_id, category, content))

    def _fetch_rows(self, sql, params=()):
        """Run a query on the database file directly and return all rows.

        This goes through ``sqlite3`` rather than the store, so tests can
        check what was actually written. The connection is closed even when
        the query raises; ``with sqlite3.connect(...)`` would not do that,
        because the connection's context manager only manages transactions.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            return conn.execute(sql, params).fetchall()
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Initialization and basic CRUD
    # ------------------------------------------------------------------

    def test_init(self):
        """Test KnowledgeStore initialization."""
        assert self.store.db_path == self.db_path

        # Verify tables were created
        rows = self._fetch_rows("SELECT name FROM sqlite_master WHERE type='table'")
        tables = [row[0] for row in rows]
        assert "knowledge_entries" in tables

    def test_add_entry(self):
        """Test adding a knowledge entry."""
        entry = self._make_entry(
            "test_1",
            "test",
            "This is a test entry",
            metadata={"source": "test"},
        )

        self.store.add_entry(entry)

        # Verify entry was added
        rows = self._fetch_rows(
            "SELECT entry_id, category, content FROM knowledge_entries WHERE entry_id = ?",
            ("test_1",),
        )
        # Fail with a readable message instead of a TypeError on a missing row.
        assert rows, "entry 'test_1' was not written to knowledge_entries"
        assert rows[0] == ("test_1", "test", "This is a test entry")

    def test_get_entry(self):
        """Test retrieving a knowledge entry."""
        self.store.add_entry(
            self._make_entry("test_get", "test", "Content to retrieve")
        )

        retrieved = self.store.get_entry("test_get")

        assert retrieved is not None
        assert retrieved.entry_id == "test_get"
        assert retrieved.content == "Content to retrieve"
        assert retrieved.access_count == 1  # Should be incremented

    def test_get_entry_not_found(self):
        """Test retrieving a non-existent entry."""
        assert self.store.get_entry("nonexistent") is None

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------

    def test_search_entries_semantic(self):
        """Test semantic search."""
        self._add_entries(
            ("entry1", "personal", "John is a software engineer"),
            ("entry2", "personal", "Mary works as a designer"),
            ("entry3", "tech", "Python is a programming language"),
        )

        results = self.store.search_entries("software engineer", top_k=2)

        assert len(results) >= 1
        # Should find the most relevant entry
        found_ids = [r.entry_id for r in results]
        assert "entry1" in found_ids

    def test_search_entries_fts_exact(self):
        """Test full-text search with exact matches."""
        self._add_entries(
            ("exact1", "test", "Python programming language"),
            ("exact2", "test", "Java programming language"),
            ("exact3", "test", "Database design principles"),
        )

        results = self.store.search_entries("programming language", top_k=3)

        assert len(results) >= 2
        found_ids = [r.entry_id for r in results]
        assert "exact1" in found_ids
        assert "exact2" in found_ids

    def test_search_entries_by_category(self):
        """Test searching entries by category."""
        self._add_entries(
            ("cat1", "personal", "John's info"),
            ("cat2", "tech", "Python info"),
            ("cat3", "personal", "Jane's info"),
        )

        results = self.store.search_entries("info", category="personal", top_k=5)

        assert len(results) == 2
        found_ids = [r.entry_id for r in results]
        assert "cat1" in found_ids
        assert "cat3" in found_ids
        assert "cat2" not in found_ids

    def test_get_by_category(self):
        """Test getting entries by category."""
        now = time.time()
        self.store.add_entry(self._make_entry("get1", "personal", "Entry 1", timestamp=now))
        self.store.add_entry(self._make_entry("get2", "tech", "Entry 2", timestamp=now))
        # One second newer than get1 so the expected ordering is unambiguous.
        self.store.add_entry(
            self._make_entry("get3", "personal", "Entry 3", timestamp=now + 1)
        )

        personal_entries = self.store.get_by_category("personal")

        assert len(personal_entries) == 2
        # Should be ordered by importance_score DESC, created_at DESC
        assert personal_entries[0].entry_id == "get3"  # More recent
        assert personal_entries[1].entry_id == "get1"

    # ------------------------------------------------------------------
    # Mutation
    # ------------------------------------------------------------------

    def test_update_importance(self):
        """Test updating entry importance."""
        self.store.add_entry(
            self._make_entry("importance_test", "test", "Test content")
        )

        self.store.update_importance("importance_test", 0.8)

        retrieved = self.store.get_entry("importance_test")
        assert retrieved is not None
        assert retrieved.importance_score == 0.8

    def test_delete_entry(self):
        """Test deleting an entry."""
        self.store.add_entry(self._make_entry("delete_test", "test", "To be deleted"))

        assert self.store.delete_entry("delete_test") is True

        # Verify it's gone
        assert self.store.get_entry("delete_test") is None

    def test_delete_entry_not_found(self):
        """Test deleting a non-existent entry."""
        assert self.store.delete_entry("nonexistent") is False

    # ------------------------------------------------------------------
    # Statistics
    # ------------------------------------------------------------------

    def test_get_statistics(self):
        """Test getting store statistics."""
        self._add_entries(
            ("stat1", "personal", "Personal info"),
            ("stat2", "tech", "Tech info"),
            ("stat3", "personal", "More personal info"),
        )

        stats = self.store.get_statistics()

        assert stats["total_entries"] == 3
        assert stats["total_categories"] == 2
        assert stats["category_distribution"]["personal"] == 2
        assert stats["category_distribution"]["tech"] == 1
        assert stats["vocabulary_size"] > 0

    # ------------------------------------------------------------------
    # Full-text search internals
    # ------------------------------------------------------------------

    def test_fts_search_exact_phrase(self):
        """Test FTS exact phrase matching."""
        self._add_entries(
            ("fts1", "test", "Python is great for programming"),
            ("fts2", "test", "Java is also good for programming"),
            ("fts3", "test", "Database management is important"),
        )

        fts_results = self.store._fts_search("programming")

        assert len(fts_results) == 2
        # Each result is unpacked as an (entry_id, score) pair.
        entry_ids = [entry_id for entry_id, _score in fts_results]
        assert "fts1" in entry_ids
        assert "fts2" in entry_ids

    def test_fts_search_partial_match(self):
        """Test FTS partial word matching."""
        self._add_entries(
            ("partial1", "test", "Python programming"),
            ("partial2", "test", "Java programming"),
            ("partial3", "test", "Database design"),
        )

        fts_results = self.store._fts_search("program")

        assert len(fts_results) >= 2

    def test_combined_search_scoring(self):
        """Test that combined semantic + FTS search produces proper scoring."""
        self._add_entries(
            ("combined1", "test", "Python programming language"),
            ("combined2", "test", "Java programming language"),
            ("combined3", "test", "Database management system"),
        )

        results = self.store.search_entries("programming language")

        assert len(results) >= 2

        # Every result must carry a search score ...
        for result in results:
            assert "search_score" in result.metadata
        # ... and at least one of them should be positive.
        assert any(result.metadata["search_score"] > 0 for result in results)

    def test_empty_search(self):
        """Test searching with no matches."""
        results = self.store.search_entries("nonexistent topic")
        assert len(results) == 0

    def test_search_empty_query(self):
        """Test searching with empty query."""
        results = self.store.search_entries("")
        assert len(results) == 0

    # ------------------------------------------------------------------
    # Concurrency
    # ------------------------------------------------------------------

    def test_thread_safety(self):
        """Test that the store can handle concurrent access."""
        import queue
        import threading

        results = queue.Queue()

        def worker(worker_id):
            entry_id = f"thread_{worker_id}"
            try:
                self.store.add_entry(
                    self._make_entry(
                        entry_id, "test", f"Content from worker {worker_id}"
                    )
                )
                retrieved = self.store.get_entry(entry_id)
                assert retrieved is not None
                results.put(True)
            except Exception as exc:
                # Deliberately broad: an exception raised in a thread would
                # otherwise be lost, so hand it to the main thread to report.
                results.put(exc)

        # Start multiple threads
        threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
        for thread in threads:
            thread.start()

        # Wait for all threads
        for thread in threads:
            thread.join()

        # Check results. Every worker has finished, so each outcome is already
        # queued; get_nowait() fails fast (queue.Empty) instead of blocking
        # forever if a worker died without reporting.
        outcomes = [results.get_nowait() for _ in threads]
        failures = [outcome for outcome in outcomes if outcome is not True]
        assert not failures, f"worker threads failed: {failures!r}"

    # ------------------------------------------------------------------
    # KnowledgeEntry
    # ------------------------------------------------------------------

    def test_entry_to_dict(self):
        """Test KnowledgeEntry to_dict method."""
        entry = KnowledgeEntry(
            entry_id="dict_test",
            category="test",
            content="Test content",
            metadata={"key": "value"},
            created_at=1234567890.0,
            updated_at=1234567891.0,
            access_count=5,
            importance_score=0.8,
        )

        entry_dict = entry.to_dict()

        assert entry_dict["entry_id"] == "dict_test"
        assert entry_dict["category"] == "test"
        assert entry_dict["content"] == "Test content"
        assert entry_dict["metadata"]["key"] == "value"
        assert entry_dict["access_count"] == 5
        assert entry_dict["importance_score"] == 0.8