ADded coverage.

This commit is contained in:
retoor 2025-11-04 08:01:20 +01:00
parent 2b701cb5cd
commit a40713d463
6 changed files with 452 additions and 36 deletions

View File

@ -31,13 +31,13 @@ class KnowledgeEntry:
class KnowledgeStore:
def __init__(self, db_path: str):
self.db_path = db_path
self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
self.semantic_index = SemanticIndex()
self._initialize_store()
self._load_index()
def _initialize_store(self):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor = self.conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS knowledge_entries (
@ -62,22 +62,17 @@ class KnowledgeStore:
CREATE INDEX IF NOT EXISTS idx_created ON knowledge_entries(created_at DESC)
''')
conn.commit()
conn.close()
self.conn.commit()
def _load_index(self):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor = self.conn.cursor()
cursor.execute('SELECT entry_id, content FROM knowledge_entries')
for row in cursor.fetchall():
self.semantic_index.add_document(row[0], row[1])
conn.close()
def add_entry(self, entry: KnowledgeEntry):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor = self.conn.cursor()
cursor.execute('''
INSERT OR REPLACE INTO knowledge_entries
@ -94,14 +89,12 @@ class KnowledgeStore:
entry.importance_score
))
conn.commit()
conn.close()
self.conn.commit()
self.semantic_index.add_document(entry.entry_id, entry.content)
def get_entry(self, entry_id: str) -> Optional[KnowledgeEntry]:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor = self.conn.cursor()
cursor.execute('''
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
@ -117,9 +110,7 @@ class KnowledgeStore:
SET access_count = access_count + 1
WHERE entry_id = ?
''', (entry_id,))
conn.commit()
conn.close()
self.conn.commit()
return KnowledgeEntry(
entry_id=row[0],
@ -132,15 +123,13 @@ class KnowledgeStore:
importance_score=row[7]
)
conn.close()
return None
def search_entries(self, query: str, category: Optional[str] = None,
top_k: int = 5) -> List[KnowledgeEntry]:
search_results = self.semantic_index.search(query, top_k * 2)
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor = self.conn.cursor()
entries = []
for entry_id, score in search_results:
@ -174,12 +163,10 @@ class KnowledgeStore:
if len(entries) >= top_k:
break
conn.close()
return entries
def get_by_category(self, category: str, limit: int = 20) -> List[KnowledgeEntry]:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor = self.conn.cursor()
cursor.execute('''
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
@ -202,12 +189,10 @@ class KnowledgeStore:
importance_score=row[7]
))
conn.close()
return entries
def update_importance(self, entry_id: str, importance_score: float):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor = self.conn.cursor()
cursor.execute('''
UPDATE knowledge_entries
@ -215,18 +200,15 @@ class KnowledgeStore:
WHERE entry_id = ?
''', (importance_score, time.time(), entry_id))
conn.commit()
conn.close()
self.conn.commit()
def delete_entry(self, entry_id: str) -> bool:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor = self.conn.cursor()
cursor.execute('DELETE FROM knowledge_entries WHERE entry_id = ?', (entry_id,))
deleted = cursor.rowcount > 0
conn.commit()
conn.close()
self.conn.commit()
if deleted:
self.semantic_index.remove_document(entry_id)
@ -234,8 +216,7 @@ class KnowledgeStore:
return deleted
def get_statistics(self) -> Dict[str, Any]:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor = self.conn.cursor()
cursor.execute('SELECT COUNT(*) FROM knowledge_entries')
total_entries = cursor.fetchone()[0]
@ -254,8 +235,6 @@ class KnowledgeStore:
cursor.execute('SELECT SUM(access_count) FROM knowledge_entries')
total_accesses = cursor.fetchone()[0] or 0
conn.close()
return {
'total_entries': total_entries,
'total_categories': total_categories,

View File

@ -0,0 +1,89 @@
import pytest
from unittest.mock import MagicMock
from pr.core.enhanced_assistant import EnhancedAssistant
def test_enhanced_assistant_init():
mock_base = MagicMock()
assistant = EnhancedAssistant(mock_base)
assert assistant.base == mock_base
assert assistant.current_conversation_id is not None
def test_enhanced_call_api_with_cache():
mock_base = MagicMock()
mock_base.model = 'test-model'
mock_base.api_url = 'http://test'
mock_base.api_key = 'key'
mock_base.use_tools = False
mock_base.verbose = False
assistant = EnhancedAssistant(mock_base)
assistant.api_cache = MagicMock()
assistant.api_cache.get.return_value = {'cached': True}
result = assistant.enhanced_call_api([{'role': 'user', 'content': 'test'}])
assert result == {'cached': True}
assistant.api_cache.get.assert_called_once()
def test_enhanced_call_api_without_cache():
mock_base = MagicMock()
mock_base.model = 'test-model'
mock_base.api_url = 'http://test'
mock_base.api_key = 'key'
mock_base.use_tools = False
mock_base.verbose = False
assistant = EnhancedAssistant(mock_base)
assistant.api_cache = None
# It will try to call API and fail with network error, but that's expected
result = assistant.enhanced_call_api([{'role': 'user', 'content': 'test'}])
assert 'error' in result
def test_execute_workflow_not_found():
mock_base = MagicMock()
assistant = EnhancedAssistant(mock_base)
assistant.workflow_storage = MagicMock()
assistant.workflow_storage.load_workflow_by_name.return_value = None
result = assistant.execute_workflow('nonexistent')
assert 'error' in result
def test_create_agent():
mock_base = MagicMock()
assistant = EnhancedAssistant(mock_base)
assistant.agent_manager = MagicMock()
assistant.agent_manager.create_agent.return_value = 'agent_id'
result = assistant.create_agent('role')
assert result == 'agent_id'
def test_search_knowledge():
mock_base = MagicMock()
assistant = EnhancedAssistant(mock_base)
assistant.knowledge_store = MagicMock()
assistant.knowledge_store.search_entries.return_value = [{'result': True}]
result = assistant.search_knowledge('query')
assert result == [{'result': True}]
def test_get_cache_statistics():
mock_base = MagicMock()
assistant = EnhancedAssistant(mock_base)
assistant.api_cache = MagicMock()
assistant.api_cache.get_statistics.return_value = {'hits': 10}
assistant.tool_cache = MagicMock()
assistant.tool_cache.get_statistics.return_value = {'misses': 5}
stats = assistant.get_cache_statistics()
assert 'api_cache' in stats
assert 'tool_cache' in stats
def test_clear_caches():
mock_base = MagicMock()
assistant = EnhancedAssistant(mock_base)
assistant.api_cache = MagicMock()
assistant.tool_cache = MagicMock()
assistant.clear_caches()
assistant.api_cache.clear_all.assert_called_once()
assistant.tool_cache.clear_all.assert_called_once()

24
tests/test_logging.py Normal file
View File

@ -0,0 +1,24 @@
import pytest
import tempfile
import os
from pr.core.logging import setup_logging, get_logger
def test_setup_logging_basic():
logger = setup_logging(verbose=False)
assert logger.name == 'pr'
assert logger.level == 20 # INFO
def test_setup_logging_verbose():
logger = setup_logging(verbose=True)
assert logger.name == 'pr'
assert logger.level == 10 # DEBUG
# Should have console handler
assert len(logger.handlers) >= 2
def test_get_logger_default():
logger = get_logger()
assert logger.name == 'pr'
def test_get_logger_named():
logger = get_logger('test')
assert logger.name == 'pr.test'

116
tests/test_session.py Normal file
View File

@ -0,0 +1,116 @@
import pytest
import tempfile
import os
import json
from pr.core.session import SessionManager
@pytest.fixture
def temp_sessions_dir(tmp_path, monkeypatch):
from pr.core import session
original_dir = session.SESSIONS_DIR
monkeypatch.setattr(session, 'SESSIONS_DIR', str(tmp_path))
# Clean any existing files
import shutil
if os.path.exists(str(tmp_path)):
shutil.rmtree(str(tmp_path))
os.makedirs(str(tmp_path), exist_ok=True)
yield tmp_path
monkeypatch.setattr(session, 'SESSIONS_DIR', original_dir)
def test_session_manager_init(temp_sessions_dir):
manager = SessionManager()
assert os.path.exists(temp_sessions_dir)
def test_save_and_load_session(temp_sessions_dir):
manager = SessionManager()
name = "test_session"
messages = [{"role": "user", "content": "Hello"}]
metadata = {"test": True}
assert manager.save_session(name, messages, metadata)
loaded = manager.load_session(name)
assert loaded is not None
assert loaded['name'] == name
assert loaded['messages'] == messages
assert loaded['metadata'] == metadata
def test_load_nonexistent_session(temp_sessions_dir):
manager = SessionManager()
loaded = manager.load_session("nonexistent")
assert loaded is None
def test_list_sessions(temp_sessions_dir):
manager = SessionManager()
# Save a session
manager.save_session("session1", [{"role": "user", "content": "Hi"}])
manager.save_session("session2", [{"role": "user", "content": "Hello"}])
sessions = manager.list_sessions()
assert len(sessions) == 2
assert sessions[0]['name'] == "session2" # sorted by created_at desc
def test_delete_session(temp_sessions_dir):
manager = SessionManager()
name = "to_delete"
manager.save_session(name, [{"role": "user", "content": "Test"}])
assert manager.delete_session(name)
assert manager.load_session(name) is None
def test_delete_nonexistent_session(temp_sessions_dir):
manager = SessionManager()
assert not manager.delete_session("nonexistent")
def test_export_session_json(temp_sessions_dir, tmp_path):
manager = SessionManager()
name = "export_test"
messages = [{"role": "user", "content": "Export me"}]
manager.save_session(name, messages)
output_path = tmp_path / "exported.json"
assert manager.export_session(name, str(output_path), 'json')
assert output_path.exists()
with open(output_path) as f:
data = json.load(f)
assert data['name'] == name
def test_export_session_markdown(temp_sessions_dir, tmp_path):
manager = SessionManager()
name = "export_md"
messages = [{"role": "user", "content": "Markdown export"}]
manager.save_session(name, messages)
output_path = tmp_path / "exported.md"
assert manager.export_session(name, str(output_path), 'markdown')
assert output_path.exists()
content = output_path.read_text()
assert "# Session: export_md" in content
def test_export_session_txt(temp_sessions_dir, tmp_path):
manager = SessionManager()
name = "export_txt"
messages = [{"role": "user", "content": "Text export"}]
manager.save_session(name, messages)
output_path = tmp_path / "exported.txt"
assert manager.export_session(name, str(output_path), 'txt')
assert output_path.exists()
content = output_path.read_text()
assert "Session: export_txt" in content
def test_export_nonexistent_session(temp_sessions_dir, tmp_path):
manager = SessionManager()
output_path = tmp_path / "nonexistent.json"
assert not manager.export_session("nonexistent", str(output_path), 'json')
def test_export_unsupported_format(temp_sessions_dir, tmp_path):
manager = SessionManager()
name = "test"
manager.save_session(name, [{"role": "user", "content": "Test"}])
output_path = tmp_path / "test.unsupported"
assert not manager.export_session(name, str(output_path), 'unsupported')

View File

@ -0,0 +1,86 @@
import pytest
import tempfile
import os
import json
from pr.core.usage_tracker import UsageTracker
@pytest.fixture
def temp_usage_file(tmp_path, monkeypatch):
from pr.core import usage_tracker
original_file = usage_tracker.USAGE_DB_FILE
temp_file = str(tmp_path / "usage.json")
monkeypatch.setattr(usage_tracker, 'USAGE_DB_FILE', temp_file)
yield temp_file
if os.path.exists(temp_file):
os.remove(temp_file)
monkeypatch.setattr(usage_tracker, 'USAGE_DB_FILE', original_file)
def test_usage_tracker_init():
tracker = UsageTracker()
summary = tracker.get_session_summary()
assert summary['requests'] == 0
assert summary['total_tokens'] == 0
assert summary['estimated_cost'] == 0.0
def test_track_request_known_model():
tracker = UsageTracker()
tracker.track_request('gpt-3.5-turbo', 100, 50)
summary = tracker.get_session_summary()
assert summary['requests'] == 1
assert summary['input_tokens'] == 100
assert summary['output_tokens'] == 50
assert summary['total_tokens'] == 150
assert 'gpt-3.5-turbo' in summary['models_used']
# Cost: (100/1000)*0.0005 + (50/1000)*0.0015 = 0.00005 + 0.000075 = 0.000125
assert abs(summary['estimated_cost'] - 0.000125) < 1e-6
def test_track_request_unknown_model():
tracker = UsageTracker()
tracker.track_request('unknown-model', 100, 50)
summary = tracker.get_session_summary()
assert summary['requests'] == 1
assert summary['estimated_cost'] == 0.0 # Unknown model, cost 0
def test_track_request_multiple():
tracker = UsageTracker()
tracker.track_request('gpt-3.5-turbo', 100, 50)
tracker.track_request('gpt-4', 200, 100)
summary = tracker.get_session_summary()
assert summary['requests'] == 2
assert summary['input_tokens'] == 300
assert summary['output_tokens'] == 150
assert summary['total_tokens'] == 450
assert len(summary['models_used']) == 2
def test_get_formatted_summary():
tracker = UsageTracker()
tracker.track_request('gpt-3.5-turbo', 100, 50)
formatted = tracker.get_formatted_summary()
assert "Total Requests: 1" in formatted
assert "Total Tokens: 150" in formatted
assert "Estimated Cost: $0.0001" in formatted
assert "gpt-3.5-turbo" in formatted
def test_get_total_usage_no_file(temp_usage_file):
total = UsageTracker.get_total_usage()
assert total['total_requests'] == 0
assert total['total_tokens'] == 0
assert total['total_cost'] == 0.0
def test_get_total_usage_with_data(temp_usage_file):
# Manually create history file
history = [
{'timestamp': '2023-01-01', 'model': 'gpt-3.5-turbo', 'input_tokens': 100, 'output_tokens': 50, 'total_tokens': 150, 'cost': 0.000125},
{'timestamp': '2023-01-02', 'model': 'gpt-4', 'input_tokens': 200, 'output_tokens': 100, 'total_tokens': 300, 'cost': 0.008}
]
with open(temp_usage_file, 'w') as f:
json.dump(history, f)
total = UsageTracker.get_total_usage()
assert total['total_requests'] == 2
assert total['total_tokens'] == 450
assert abs(total['total_cost'] - 0.008125) < 1e-6

122
tests/test_validation.py Normal file
View File

@ -0,0 +1,122 @@
import pytest
import tempfile
import os
from pr.core.validation import (
validate_file_path,
validate_directory_path,
validate_model_name,
validate_api_url,
validate_session_name,
validate_temperature,
validate_max_tokens,
)
from pr.core.exceptions import ValidationError
def test_validate_file_path_empty():
with pytest.raises(ValidationError, match="File path cannot be empty"):
validate_file_path("")
def test_validate_file_path_not_exist():
with pytest.raises(ValidationError, match="File does not exist"):
validate_file_path("/nonexistent/file.txt", must_exist=True)
def test_validate_file_path_is_dir():
with tempfile.TemporaryDirectory() as tmpdir:
with pytest.raises(ValidationError, match="Path is a directory"):
validate_file_path(tmpdir, must_exist=True)
def test_validate_file_path_valid():
with tempfile.NamedTemporaryFile() as tmpfile:
result = validate_file_path(tmpfile.name, must_exist=True)
assert os.path.isabs(result)
assert result == os.path.abspath(tmpfile.name)
def test_validate_directory_path_empty():
with pytest.raises(ValidationError, match="Directory path cannot be empty"):
validate_directory_path("")
def test_validate_directory_path_not_exist():
with pytest.raises(ValidationError, match="Directory does not exist"):
validate_directory_path("/nonexistent/dir", must_exist=True)
def test_validate_directory_path_not_dir():
with tempfile.NamedTemporaryFile() as tmpfile:
with pytest.raises(ValidationError, match="Path is not a directory"):
validate_directory_path(tmpfile.name, must_exist=True)
def test_validate_directory_path_create():
with tempfile.TemporaryDirectory() as tmpdir:
new_dir = os.path.join(tmpdir, "new_dir")
result = validate_directory_path(new_dir, must_exist=True, create=True)
assert os.path.isdir(new_dir)
assert result == os.path.abspath(new_dir)
def test_validate_directory_path_valid():
with tempfile.TemporaryDirectory() as tmpdir:
result = validate_directory_path(tmpdir, must_exist=True)
assert result == os.path.abspath(tmpdir)
def test_validate_model_name_empty():
with pytest.raises(ValidationError, match="Model name cannot be empty"):
validate_model_name("")
def test_validate_model_name_too_short():
with pytest.raises(ValidationError, match="Model name too short"):
validate_model_name("a")
def test_validate_model_name_valid():
result = validate_model_name("gpt-3.5-turbo")
assert result == "gpt-3.5-turbo"
def test_validate_api_url_empty():
with pytest.raises(ValidationError, match="API URL cannot be empty"):
validate_api_url("")
def test_validate_api_url_invalid():
with pytest.raises(ValidationError, match="API URL must start with"):
validate_api_url("invalid-url")
def test_validate_api_url_valid():
result = validate_api_url("https://api.example.com")
assert result == "https://api.example.com"
def test_validate_session_name_empty():
with pytest.raises(ValidationError, match="Session name cannot be empty"):
validate_session_name("")
def test_validate_session_name_invalid_char():
with pytest.raises(ValidationError, match="contains invalid character"):
validate_session_name("test/session")
def test_validate_session_name_too_long():
long_name = "a" * 256
with pytest.raises(ValidationError, match="Session name too long"):
validate_session_name(long_name)
def test_validate_session_name_valid():
result = validate_session_name("valid_session_123")
assert result == "valid_session_123"
def test_validate_temperature_too_low():
with pytest.raises(ValidationError, match="Temperature must be between"):
validate_temperature(-0.1)
def test_validate_temperature_too_high():
with pytest.raises(ValidationError, match="Temperature must be between"):
validate_temperature(2.1)
def test_validate_temperature_valid():
result = validate_temperature(0.7)
assert result == 0.7
def test_validate_max_tokens_too_low():
with pytest.raises(ValidationError, match="Max tokens must be at least 1"):
validate_max_tokens(0)
def test_validate_max_tokens_too_high():
with pytest.raises(ValidationError, match="Max tokens too high"):
validate_max_tokens(100001)
def test_validate_max_tokens_valid():
result = validate_max_tokens(1000)
assert result == 1000