From a40713d46319220c82ce8c02fad99763a829a073 Mon Sep 17 00:00:00 2001 From: retoor Date: Tue, 4 Nov 2025 08:01:20 +0100 Subject: [PATCH] ADded coverage. --- pr/memory/knowledge_store.py | 51 ++++--------- tests/test_enhanced_assistant.py | 89 ++++++++++++++++++++++ tests/test_logging.py | 24 ++++++ tests/test_session.py | 116 +++++++++++++++++++++++++++++ tests/test_usage_tracker.py | 86 ++++++++++++++++++++++ tests/test_validation.py | 122 +++++++++++++++++++++++++++++++ 6 files changed, 452 insertions(+), 36 deletions(-) create mode 100644 tests/test_enhanced_assistant.py create mode 100644 tests/test_logging.py create mode 100644 tests/test_session.py create mode 100644 tests/test_usage_tracker.py create mode 100644 tests/test_validation.py diff --git a/pr/memory/knowledge_store.py b/pr/memory/knowledge_store.py index 3fa6732..9625a70 100644 --- a/pr/memory/knowledge_store.py +++ b/pr/memory/knowledge_store.py @@ -31,13 +31,13 @@ class KnowledgeEntry: class KnowledgeStore: def __init__(self, db_path: str): self.db_path = db_path + self.conn = sqlite3.connect(self.db_path, check_same_thread=False) self.semantic_index = SemanticIndex() self._initialize_store() self._load_index() def _initialize_store(self): - conn = sqlite3.connect(self.db_path, check_same_thread=False) - cursor = conn.cursor() + cursor = self.conn.cursor() cursor.execute(''' CREATE TABLE IF NOT EXISTS knowledge_entries ( @@ -62,22 +62,17 @@ class KnowledgeStore: CREATE INDEX IF NOT EXISTS idx_created ON knowledge_entries(created_at DESC) ''') - conn.commit() - conn.close() + self.conn.commit() def _load_index(self): - conn = sqlite3.connect(self.db_path, check_same_thread=False) - cursor = conn.cursor() + cursor = self.conn.cursor() cursor.execute('SELECT entry_id, content FROM knowledge_entries') for row in cursor.fetchall(): self.semantic_index.add_document(row[0], row[1]) - conn.close() - def add_entry(self, entry: KnowledgeEntry): - conn = sqlite3.connect(self.db_path, check_same_thread=False) - cursor = conn.cursor() + cursor = self.conn.cursor() cursor.execute(''' INSERT OR REPLACE INTO knowledge_entries @@ -94,14 +89,12 @@ class KnowledgeStore: entry.importance_score )) - conn.commit() - conn.close() + self.conn.commit() self.semantic_index.add_document(entry.entry_id, entry.content) def get_entry(self, entry_id: str) -> Optional[KnowledgeEntry]: - conn = sqlite3.connect(self.db_path, check_same_thread=False) - cursor = conn.cursor() + cursor = self.conn.cursor() cursor.execute(''' SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score @@ -117,9 +110,7 @@ class KnowledgeStore: SET access_count = access_count + 1 WHERE entry_id = ? ''', (entry_id,)) - conn.commit() - - conn.close() + self.conn.commit() return KnowledgeEntry( entry_id=row[0], @@ -132,15 +123,13 @@ class KnowledgeStore: importance_score=row[7] ) - conn.close() return None def search_entries(self, query: str, category: Optional[str] = None, top_k: int = 5) -> List[KnowledgeEntry]: search_results = self.semantic_index.search(query, top_k * 2) - conn = sqlite3.connect(self.db_path, check_same_thread=False) - cursor = conn.cursor() + cursor = self.conn.cursor() entries = [] for entry_id, score in search_results: @@ -174,12 +163,10 @@ class KnowledgeStore: if len(entries) >= top_k: break - conn.close() return entries def get_by_category(self, category: str, limit: int = 20) -> List[KnowledgeEntry]: - conn = sqlite3.connect(self.db_path, check_same_thread=False) - cursor = conn.cursor() + cursor = self.conn.cursor() cursor.execute(''' SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score @@ -202,12 +189,10 @@ class KnowledgeStore: importance_score=row[7] )) - conn.close() return entries def update_importance(self, entry_id: str, importance_score: float): - conn = sqlite3.connect(self.db_path, check_same_thread=False) - cursor = conn.cursor() + cursor = self.conn.cursor() cursor.execute(''' UPDATE knowledge_entries @@ -215,18 +200,15 @@ class KnowledgeStore: WHERE entry_id = ? ''', (importance_score, time.time(), entry_id)) - conn.commit() - conn.close() + self.conn.commit() def delete_entry(self, entry_id: str) -> bool: - conn = sqlite3.connect(self.db_path, check_same_thread=False) - cursor = conn.cursor() + cursor = self.conn.cursor() cursor.execute('DELETE FROM knowledge_entries WHERE entry_id = ?', (entry_id,)) deleted = cursor.rowcount > 0 - conn.commit() - conn.close() + self.conn.commit() if deleted: self.semantic_index.remove_document(entry_id) @@ -234,8 +216,7 @@ class KnowledgeStore: return deleted def get_statistics(self) -> Dict[str, Any]: - conn = sqlite3.connect(self.db_path, check_same_thread=False) - cursor = conn.cursor() + cursor = self.conn.cursor() cursor.execute('SELECT COUNT(*) FROM knowledge_entries') total_entries = cursor.fetchone()[0] @@ -254,8 +235,6 @@ class KnowledgeStore: cursor.execute('SELECT SUM(access_count) FROM knowledge_entries') total_accesses = cursor.fetchone()[0] or 0 - conn.close() - return { 'total_entries': total_entries, 'total_categories': total_categories, diff --git a/tests/test_enhanced_assistant.py b/tests/test_enhanced_assistant.py new file mode 100644 index 0000000..d700f7e --- /dev/null +++ b/tests/test_enhanced_assistant.py @@ -0,0 +1,89 @@ +import pytest +from unittest.mock import MagicMock +from pr.core.enhanced_assistant import EnhancedAssistant + +def test_enhanced_assistant_init(): + mock_base = MagicMock() + assistant = EnhancedAssistant(mock_base) + assert assistant.base == mock_base + assert assistant.current_conversation_id is not None + +def test_enhanced_call_api_with_cache(): + mock_base = MagicMock() + mock_base.model = 'test-model' + mock_base.api_url = 'http://test' + mock_base.api_key = 'key' + mock_base.use_tools = False + mock_base.verbose = False + + assistant = EnhancedAssistant(mock_base) + assistant.api_cache = MagicMock() + assistant.api_cache.get.return_value = {'cached': True} + + result = assistant.enhanced_call_api([{'role': 'user', 'content': 'test'}]) + assert result == {'cached': True} + assistant.api_cache.get.assert_called_once() + +def test_enhanced_call_api_without_cache(): + mock_base = MagicMock() + mock_base.model = 'test-model' + mock_base.api_url = 'http://test' + mock_base.api_key = 'key' + mock_base.use_tools = False + mock_base.verbose = False + + assistant = EnhancedAssistant(mock_base) + assistant.api_cache = None + + # It will try to call API and fail with network error, but that's expected + result = assistant.enhanced_call_api([{'role': 'user', 'content': 'test'}]) + assert 'error' in result + +def test_execute_workflow_not_found(): + mock_base = MagicMock() + assistant = EnhancedAssistant(mock_base) + assistant.workflow_storage = MagicMock() + assistant.workflow_storage.load_workflow_by_name.return_value = None + + result = assistant.execute_workflow('nonexistent') + assert 'error' in result + +def test_create_agent(): + mock_base = MagicMock() + assistant = EnhancedAssistant(mock_base) + assistant.agent_manager = MagicMock() + assistant.agent_manager.create_agent.return_value = 'agent_id' + + result = assistant.create_agent('role') + assert result == 'agent_id' + +def test_search_knowledge(): + mock_base = MagicMock() + assistant = EnhancedAssistant(mock_base) + assistant.knowledge_store = MagicMock() + assistant.knowledge_store.search_entries.return_value = [{'result': True}] + + result = assistant.search_knowledge('query') + assert result == [{'result': True}] + +def test_get_cache_statistics(): + mock_base = MagicMock() + assistant = EnhancedAssistant(mock_base) + assistant.api_cache = MagicMock() + assistant.api_cache.get_statistics.return_value = {'hits': 10} + assistant.tool_cache = MagicMock() + assistant.tool_cache.get_statistics.return_value = {'misses': 5} + + stats = assistant.get_cache_statistics() + assert 'api_cache' in stats + assert 'tool_cache' in stats + +def test_clear_caches(): + mock_base = MagicMock() + assistant = EnhancedAssistant(mock_base) + assistant.api_cache = MagicMock() + assistant.tool_cache = MagicMock() + + assistant.clear_caches() + assistant.api_cache.clear_all.assert_called_once() + assistant.tool_cache.clear_all.assert_called_once() diff --git a/tests/test_logging.py b/tests/test_logging.py new file mode 100644 index 0000000..9f19267 --- /dev/null +++ b/tests/test_logging.py @@ -0,0 +1,24 @@ +import pytest +import tempfile +import os +from pr.core.logging import setup_logging, get_logger + +def test_setup_logging_basic(): + logger = setup_logging(verbose=False) + assert logger.name == 'pr' + assert logger.level == 20 # INFO + +def test_setup_logging_verbose(): + logger = setup_logging(verbose=True) + assert logger.name == 'pr' + assert logger.level == 10 # DEBUG + # Should have console handler + assert len(logger.handlers) >= 2 + +def test_get_logger_default(): + logger = get_logger() + assert logger.name == 'pr' + +def test_get_logger_named(): + logger = get_logger('test') + assert logger.name == 'pr.test' diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 0000000..833d692 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,116 @@ +import pytest +import tempfile +import os +import json +from pr.core.session import SessionManager + +@pytest.fixture +def temp_sessions_dir(tmp_path, monkeypatch): + from pr.core import session + original_dir = session.SESSIONS_DIR + monkeypatch.setattr(session, 'SESSIONS_DIR', str(tmp_path)) + # Clean any existing files + import shutil + if os.path.exists(str(tmp_path)): + shutil.rmtree(str(tmp_path)) + os.makedirs(str(tmp_path), exist_ok=True) + yield tmp_path + monkeypatch.setattr(session, 'SESSIONS_DIR', original_dir) + +def test_session_manager_init(temp_sessions_dir): + manager = SessionManager() + assert os.path.exists(temp_sessions_dir) + +def test_save_and_load_session(temp_sessions_dir): + manager = SessionManager() + name = "test_session" + messages = [{"role": "user", "content": "Hello"}] + metadata = {"test": True} + + assert manager.save_session(name, messages, metadata) + + loaded = manager.load_session(name) + assert loaded is not None + assert loaded['name'] == name + assert loaded['messages'] == messages + assert loaded['metadata'] == metadata + +def test_load_nonexistent_session(temp_sessions_dir): + manager = SessionManager() + loaded = manager.load_session("nonexistent") + assert loaded is None + +def test_list_sessions(temp_sessions_dir): + manager = SessionManager() + # Save a session + manager.save_session("session1", [{"role": "user", "content": "Hi"}]) + manager.save_session("session2", [{"role": "user", "content": "Hello"}]) + + sessions = manager.list_sessions() + assert len(sessions) == 2 + assert sessions[0]['name'] == "session2" # sorted by created_at desc + +def test_delete_session(temp_sessions_dir): + manager = SessionManager() + name = "to_delete" + manager.save_session(name, [{"role": "user", "content": "Test"}]) + + assert manager.delete_session(name) + assert manager.load_session(name) is None + +def test_delete_nonexistent_session(temp_sessions_dir): + manager = SessionManager() + assert not manager.delete_session("nonexistent") + +def test_export_session_json(temp_sessions_dir, tmp_path): + manager = SessionManager() + name = "export_test" + messages = [{"role": "user", "content": "Export me"}] + manager.save_session(name, messages) + + output_path = tmp_path / "exported.json" + assert manager.export_session(name, str(output_path), 'json') + assert output_path.exists() + + with open(output_path) as f: + data = json.load(f) + assert data['name'] == name + +def test_export_session_markdown(temp_sessions_dir, tmp_path): + manager = SessionManager() + name = "export_md" + messages = [{"role": "user", "content": "Markdown export"}] + manager.save_session(name, messages) + + output_path = tmp_path / "exported.md" + assert manager.export_session(name, str(output_path), 'markdown') + assert output_path.exists() + + content = output_path.read_text() + assert "# Session: export_md" in content + +def test_export_session_txt(temp_sessions_dir, tmp_path): + manager = SessionManager() + name = "export_txt" + messages = [{"role": "user", "content": "Text export"}] + manager.save_session(name, messages) + + output_path = tmp_path / "exported.txt" + assert manager.export_session(name, str(output_path), 'txt') + assert output_path.exists() + + content = output_path.read_text() + assert "Session: export_txt" in content + +def test_export_nonexistent_session(temp_sessions_dir, tmp_path): + manager = SessionManager() + output_path = tmp_path / "nonexistent.json" + assert not manager.export_session("nonexistent", str(output_path), 'json') + +def test_export_unsupported_format(temp_sessions_dir, tmp_path): + manager = SessionManager() + name = "test" + manager.save_session(name, [{"role": "user", "content": "Test"}]) + + output_path = tmp_path / "test.unsupported" + assert not manager.export_session(name, str(output_path), 'unsupported') diff --git a/tests/test_usage_tracker.py b/tests/test_usage_tracker.py new file mode 100644 index 0000000..5f6591d --- /dev/null +++ b/tests/test_usage_tracker.py @@ -0,0 +1,86 @@ +import pytest +import tempfile +import os +import json +from pr.core.usage_tracker import UsageTracker + +@pytest.fixture +def temp_usage_file(tmp_path, monkeypatch): + from pr.core import usage_tracker + original_file = usage_tracker.USAGE_DB_FILE + temp_file = str(tmp_path / "usage.json") + monkeypatch.setattr(usage_tracker, 'USAGE_DB_FILE', temp_file) + yield temp_file + if os.path.exists(temp_file): + os.remove(temp_file) + monkeypatch.setattr(usage_tracker, 'USAGE_DB_FILE', original_file) + +def test_usage_tracker_init(): + tracker = UsageTracker() + summary = tracker.get_session_summary() + assert summary['requests'] == 0 + assert summary['total_tokens'] == 0 + assert summary['estimated_cost'] == 0.0 + +def test_track_request_known_model(): + tracker = UsageTracker() + tracker.track_request('gpt-3.5-turbo', 100, 50) + + summary = tracker.get_session_summary() + assert summary['requests'] == 1 + assert summary['input_tokens'] == 100 + assert summary['output_tokens'] == 50 + assert summary['total_tokens'] == 150 + assert 'gpt-3.5-turbo' in summary['models_used'] + # Cost: (100/1000)*0.0005 + (50/1000)*0.0015 = 0.00005 + 0.000075 = 0.000125 + assert abs(summary['estimated_cost'] - 0.000125) < 1e-6 + +def test_track_request_unknown_model(): + tracker = UsageTracker() + tracker.track_request('unknown-model', 100, 50) + + summary = tracker.get_session_summary() + assert summary['requests'] == 1 + assert summary['estimated_cost'] == 0.0 # Unknown model, cost 0 + +def test_track_request_multiple(): + tracker = UsageTracker() + tracker.track_request('gpt-3.5-turbo', 100, 50) + tracker.track_request('gpt-4', 200, 100) + + summary = tracker.get_session_summary() + assert summary['requests'] == 2 + assert summary['input_tokens'] == 300 + assert summary['output_tokens'] == 150 + assert summary['total_tokens'] == 450 + assert len(summary['models_used']) == 2 + +def test_get_formatted_summary(): + tracker = UsageTracker() + tracker.track_request('gpt-3.5-turbo', 100, 50) + + formatted = tracker.get_formatted_summary() + assert "Total Requests: 1" in formatted + assert "Total Tokens: 150" in formatted + assert "Estimated Cost: $0.0001" in formatted + assert "gpt-3.5-turbo" in formatted + +def test_get_total_usage_no_file(temp_usage_file): + total = UsageTracker.get_total_usage() + assert total['total_requests'] == 0 + assert total['total_tokens'] == 0 + assert total['total_cost'] == 0.0 + +def test_get_total_usage_with_data(temp_usage_file): + # Manually create history file + history = [ + {'timestamp': '2023-01-01', 'model': 'gpt-3.5-turbo', 'input_tokens': 100, 'output_tokens': 50, 'total_tokens': 150, 'cost': 0.000125}, + {'timestamp': '2023-01-02', 'model': 'gpt-4', 'input_tokens': 200, 'output_tokens': 100, 'total_tokens': 300, 'cost': 0.008} + ] + with open(temp_usage_file, 'w') as f: + json.dump(history, f) + + total = UsageTracker.get_total_usage() + assert total['total_requests'] == 2 + assert total['total_tokens'] == 450 + assert abs(total['total_cost'] - 0.008125) < 1e-6 diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 0000000..7e1a2b4 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,122 @@ +import pytest +import tempfile +import os +from pr.core.validation import ( + validate_file_path, + validate_directory_path, + validate_model_name, + validate_api_url, + validate_session_name, + validate_temperature, + validate_max_tokens, +) +from pr.core.exceptions import ValidationError + +def test_validate_file_path_empty(): + with pytest.raises(ValidationError, match="File path cannot be empty"): + validate_file_path("") + +def test_validate_file_path_not_exist(): + with pytest.raises(ValidationError, match="File does not exist"): + validate_file_path("/nonexistent/file.txt", must_exist=True) + +def test_validate_file_path_is_dir(): + with tempfile.TemporaryDirectory() as tmpdir: + with pytest.raises(ValidationError, match="Path is a directory"): + validate_file_path(tmpdir, must_exist=True) + +def test_validate_file_path_valid(): + with tempfile.NamedTemporaryFile() as tmpfile: + result = validate_file_path(tmpfile.name, must_exist=True) + assert os.path.isabs(result) + assert result == os.path.abspath(tmpfile.name) + +def test_validate_directory_path_empty(): + with pytest.raises(ValidationError, match="Directory path cannot be empty"): + validate_directory_path("") + +def test_validate_directory_path_not_exist(): + with pytest.raises(ValidationError, match="Directory does not exist"): + validate_directory_path("/nonexistent/dir", must_exist=True) + +def test_validate_directory_path_not_dir(): + with tempfile.NamedTemporaryFile() as tmpfile: + with pytest.raises(ValidationError, match="Path is not a directory"): + validate_directory_path(tmpfile.name, must_exist=True) + +def test_validate_directory_path_create(): + with tempfile.TemporaryDirectory() as tmpdir: + new_dir = os.path.join(tmpdir, "new_dir") + result = validate_directory_path(new_dir, must_exist=True, create=True) + assert os.path.isdir(new_dir) + assert result == os.path.abspath(new_dir) + +def test_validate_directory_path_valid(): + with tempfile.TemporaryDirectory() as tmpdir: + result = validate_directory_path(tmpdir, must_exist=True) + assert result == os.path.abspath(tmpdir) + +def test_validate_model_name_empty(): + with pytest.raises(ValidationError, match="Model name cannot be empty"): + validate_model_name("") + +def test_validate_model_name_too_short(): + with pytest.raises(ValidationError, match="Model name too short"): + validate_model_name("a") + +def test_validate_model_name_valid(): + result = validate_model_name("gpt-3.5-turbo") + assert result == "gpt-3.5-turbo" + +def test_validate_api_url_empty(): + with pytest.raises(ValidationError, match="API URL cannot be empty"): + validate_api_url("") + +def test_validate_api_url_invalid(): + with pytest.raises(ValidationError, match="API URL must start with"): + validate_api_url("invalid-url") + +def test_validate_api_url_valid(): + result = validate_api_url("https://api.example.com") + assert result == "https://api.example.com" + +def test_validate_session_name_empty(): + with pytest.raises(ValidationError, match="Session name cannot be empty"): + validate_session_name("") + +def test_validate_session_name_invalid_char(): + with pytest.raises(ValidationError, match="contains invalid character"): + validate_session_name("test/session") + +def test_validate_session_name_too_long(): + long_name = "a" * 256 + with pytest.raises(ValidationError, match="Session name too long"): + validate_session_name(long_name) + +def test_validate_session_name_valid(): + result = validate_session_name("valid_session_123") + assert result == "valid_session_123" + +def test_validate_temperature_too_low(): + with pytest.raises(ValidationError, match="Temperature must be between"): + validate_temperature(-0.1) + +def test_validate_temperature_too_high(): + with pytest.raises(ValidationError, match="Temperature must be between"): + validate_temperature(2.1) + +def test_validate_temperature_valid(): + result = validate_temperature(0.7) + assert result == 0.7 + +def test_validate_max_tokens_too_low(): + with pytest.raises(ValidationError, match="Max tokens must be at least 1"): + validate_max_tokens(0) + +def test_validate_max_tokens_too_high(): + with pytest.raises(ValidationError, match="Max tokens too high"): + validate_max_tokens(100001) + +def test_validate_max_tokens_valid(): + result = validate_max_tokens(1000) + assert result == 1000