Added coverage.
This commit is contained in:
parent 2b701cb5cd
commit a40713d463
@@ -31,13 +31,13 @@ class KnowledgeEntry:
 class KnowledgeStore:
     def __init__(self, db_path: str):
         self.db_path = db_path
+        self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
         self.semantic_index = SemanticIndex()
         self._initialize_store()
         self._load_index()
 
     def _initialize_store(self):
-        conn = sqlite3.connect(self.db_path, check_same_thread=False)
-        cursor = conn.cursor()
+        cursor = self.conn.cursor()
 
         cursor.execute('''
             CREATE TABLE IF NOT EXISTS knowledge_entries (
@@ -62,22 +62,17 @@ class KnowledgeStore:
             CREATE INDEX IF NOT EXISTS idx_created ON knowledge_entries(created_at DESC)
         ''')
 
-        conn.commit()
-        conn.close()
+        self.conn.commit()
 
     def _load_index(self):
-        conn = sqlite3.connect(self.db_path, check_same_thread=False)
-        cursor = conn.cursor()
+        cursor = self.conn.cursor()
 
         cursor.execute('SELECT entry_id, content FROM knowledge_entries')
         for row in cursor.fetchall():
             self.semantic_index.add_document(row[0], row[1])
 
-        conn.close()
-
     def add_entry(self, entry: KnowledgeEntry):
-        conn = sqlite3.connect(self.db_path, check_same_thread=False)
-        cursor = conn.cursor()
+        cursor = self.conn.cursor()
 
         cursor.execute('''
             INSERT OR REPLACE INTO knowledge_entries
@@ -94,14 +89,12 @@ class KnowledgeStore:
             entry.importance_score
         ))
 
-        conn.commit()
-        conn.close()
+        self.conn.commit()
 
         self.semantic_index.add_document(entry.entry_id, entry.content)
 
     def get_entry(self, entry_id: str) -> Optional[KnowledgeEntry]:
-        conn = sqlite3.connect(self.db_path, check_same_thread=False)
-        cursor = conn.cursor()
+        cursor = self.conn.cursor()
 
         cursor.execute('''
             SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
@@ -117,9 +110,7 @@ class KnowledgeStore:
                 SET access_count = access_count + 1
                 WHERE entry_id = ?
             ''', (entry_id,))
-            conn.commit()
+            self.conn.commit()
 
-            conn.close()
-
             return KnowledgeEntry(
                 entry_id=row[0],
@@ -132,15 +123,13 @@ class KnowledgeStore:
                 importance_score=row[7]
             )
 
-        conn.close()
         return None
 
     def search_entries(self, query: str, category: Optional[str] = None,
                        top_k: int = 5) -> List[KnowledgeEntry]:
         search_results = self.semantic_index.search(query, top_k * 2)
 
-        conn = sqlite3.connect(self.db_path, check_same_thread=False)
-        cursor = conn.cursor()
+        cursor = self.conn.cursor()
 
         entries = []
         for entry_id, score in search_results:
@@ -174,12 +163,10 @@ class KnowledgeStore:
             if len(entries) >= top_k:
                 break
 
-        conn.close()
         return entries
 
     def get_by_category(self, category: str, limit: int = 20) -> List[KnowledgeEntry]:
-        conn = sqlite3.connect(self.db_path, check_same_thread=False)
-        cursor = conn.cursor()
+        cursor = self.conn.cursor()
 
         cursor.execute('''
             SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
@@ -202,12 +189,10 @@ class KnowledgeStore:
                 importance_score=row[7]
             ))
 
-        conn.close()
         return entries
 
     def update_importance(self, entry_id: str, importance_score: float):
-        conn = sqlite3.connect(self.db_path, check_same_thread=False)
-        cursor = conn.cursor()
+        cursor = self.conn.cursor()
 
         cursor.execute('''
             UPDATE knowledge_entries
@@ -215,18 +200,15 @@ class KnowledgeStore:
             WHERE entry_id = ?
         ''', (importance_score, time.time(), entry_id))
 
-        conn.commit()
-        conn.close()
+        self.conn.commit()
 
     def delete_entry(self, entry_id: str) -> bool:
-        conn = sqlite3.connect(self.db_path, check_same_thread=False)
-        cursor = conn.cursor()
+        cursor = self.conn.cursor()
 
         cursor.execute('DELETE FROM knowledge_entries WHERE entry_id = ?', (entry_id,))
         deleted = cursor.rowcount > 0
 
-        conn.commit()
-        conn.close()
+        self.conn.commit()
 
         if deleted:
             self.semantic_index.remove_document(entry_id)
@@ -234,8 +216,7 @@ class KnowledgeStore:
         return deleted
 
     def get_statistics(self) -> Dict[str, Any]:
-        conn = sqlite3.connect(self.db_path, check_same_thread=False)
-        cursor = conn.cursor()
+        cursor = self.conn.cursor()
 
         cursor.execute('SELECT COUNT(*) FROM knowledge_entries')
         total_entries = cursor.fetchone()[0]
@@ -254,8 +235,6 @@ class KnowledgeStore:
         cursor.execute('SELECT SUM(access_count) FROM knowledge_entries')
         total_accesses = cursor.fetchone()[0] or 0
 
-        conn.close()
-
         return {
            'total_entries': total_entries,
            'total_categories': total_categories,
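Review note: the refactor above replaces per-call sqlite3.connect(...)/conn.close() pairs with a single connection opened in __init__. Keep in mind that check_same_thread=False only disables sqlite3's same-thread assertion; it does not make a shared connection safe for concurrent use, and the commit adds no close() for the long-lived connection. A minimal sketch of one way to serialize access (hypothetical, not part of this commit):

import sqlite3
import threading

class LockedConnection:
    """Hypothetical wrapper: one shared connection, one operation at a time."""
    def __init__(self, db_path: str):
        self._conn = sqlite3.connect(db_path, check_same_thread=False)
        self._lock = threading.Lock()

    def execute(self, sql: str, params: tuple = ()):
        with self._lock:  # serialize all access to the shared connection
            cursor = self._conn.cursor()
            cursor.execute(sql, params)
            self._conn.commit()
            return cursor.fetchall()

    def close(self):
        with self._lock:
            self._conn.close()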
tests/test_enhanced_assistant.py (new file, 89 lines)
@@ -0,0 +1,89 @@
+import pytest
+from unittest.mock import MagicMock
+from pr.core.enhanced_assistant import EnhancedAssistant
+
+def test_enhanced_assistant_init():
+    mock_base = MagicMock()
+    assistant = EnhancedAssistant(mock_base)
+    assert assistant.base == mock_base
+    assert assistant.current_conversation_id is not None
+
+def test_enhanced_call_api_with_cache():
+    mock_base = MagicMock()
+    mock_base.model = 'test-model'
+    mock_base.api_url = 'http://test'
+    mock_base.api_key = 'key'
+    mock_base.use_tools = False
+    mock_base.verbose = False
+
+    assistant = EnhancedAssistant(mock_base)
+    assistant.api_cache = MagicMock()
+    assistant.api_cache.get.return_value = {'cached': True}
+
+    result = assistant.enhanced_call_api([{'role': 'user', 'content': 'test'}])
+    assert result == {'cached': True}
+    assistant.api_cache.get.assert_called_once()
+
+def test_enhanced_call_api_without_cache():
+    mock_base = MagicMock()
+    mock_base.model = 'test-model'
+    mock_base.api_url = 'http://test'
+    mock_base.api_key = 'key'
+    mock_base.use_tools = False
+    mock_base.verbose = False
+
+    assistant = EnhancedAssistant(mock_base)
+    assistant.api_cache = None
+
+    # It will try to call the API and fail with a network error; that is expected
+    result = assistant.enhanced_call_api([{'role': 'user', 'content': 'test'}])
+    assert 'error' in result
+
+def test_execute_workflow_not_found():
+    mock_base = MagicMock()
+    assistant = EnhancedAssistant(mock_base)
+    assistant.workflow_storage = MagicMock()
+    assistant.workflow_storage.load_workflow_by_name.return_value = None
+
+    result = assistant.execute_workflow('nonexistent')
+    assert 'error' in result
+
+def test_create_agent():
+    mock_base = MagicMock()
+    assistant = EnhancedAssistant(mock_base)
+    assistant.agent_manager = MagicMock()
+    assistant.agent_manager.create_agent.return_value = 'agent_id'
+
+    result = assistant.create_agent('role')
+    assert result == 'agent_id'
+
+def test_search_knowledge():
+    mock_base = MagicMock()
+    assistant = EnhancedAssistant(mock_base)
+    assistant.knowledge_store = MagicMock()
+    assistant.knowledge_store.search_entries.return_value = [{'result': True}]
+
+    result = assistant.search_knowledge('query')
+    assert result == [{'result': True}]
+
+def test_get_cache_statistics():
+    mock_base = MagicMock()
+    assistant = EnhancedAssistant(mock_base)
+    assistant.api_cache = MagicMock()
+    assistant.api_cache.get_statistics.return_value = {'hits': 10}
+    assistant.tool_cache = MagicMock()
+    assistant.tool_cache.get_statistics.return_value = {'misses': 5}
+
+    stats = assistant.get_cache_statistics()
+    assert 'api_cache' in stats
+    assert 'tool_cache' in stats
+
+def test_clear_caches():
+    mock_base = MagicMock()
+    assistant = EnhancedAssistant(mock_base)
+    assistant.api_cache = MagicMock()
+    assistant.tool_cache = MagicMock()
+
+    assistant.clear_caches()
+    assistant.api_cache.clear_all.assert_called_once()
+    assistant.tool_cache.clear_all.assert_called_once()
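The two enhanced_call_api tests pin down a cache-first contract: a cache hit short-circuits the API call, and with api_cache set to None the call falls through to the real API (which fails here with a network error). A rough sketch of that control flow, with the cache lookup signature and the fallback call assumed for illustration:

def enhanced_call_api(self, messages):
    # Sketch only; the real method lives in pr.core.enhanced_assistant
    if self.api_cache is not None:
        cached = self.api_cache.get(messages)  # lookup signature assumed
        if cached is not None:
            return cached  # cache hit: no network call
    return self._call_api(messages)  # assumed fallback; may return {'error': ...}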
tests/test_logging.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+import pytest
+import tempfile
+import os
+from pr.core.logging import setup_logging, get_logger
+
+def test_setup_logging_basic():
+    logger = setup_logging(verbose=False)
+    assert logger.name == 'pr'
+    assert logger.level == 20  # INFO
+
+def test_setup_logging_verbose():
+    logger = setup_logging(verbose=True)
+    assert logger.name == 'pr'
+    assert logger.level == 10  # DEBUG
+    # Should have console handler
+    assert len(logger.handlers) >= 2
+
+def test_get_logger_default():
+    logger = get_logger()
+    assert logger.name == 'pr'
+
+def test_get_logger_named():
+    logger = get_logger('test')
+    assert logger.name == 'pr.test'
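The numeric levels asserted above are the standard library constants (logging.INFO == 20, logging.DEBUG == 10); asserting against the named constants would avoid the magic numbers:

import logging

# e.g. in test_setup_logging_verbose:
# assert logger.level == logging.DEBUG  # instead of == 10
assert logging.INFO == 20
assert logging.DEBUG == 10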
tests/test_session.py (new file, 116 lines)
@@ -0,0 +1,116 @@
+import pytest
+import tempfile
+import os
+import json
+from pr.core.session import SessionManager
+
+@pytest.fixture
+def temp_sessions_dir(tmp_path, monkeypatch):
+    from pr.core import session
+    original_dir = session.SESSIONS_DIR
+    monkeypatch.setattr(session, 'SESSIONS_DIR', str(tmp_path))
+    # Clean any existing files
+    import shutil
+    if os.path.exists(str(tmp_path)):
+        shutil.rmtree(str(tmp_path))
+    os.makedirs(str(tmp_path), exist_ok=True)
+    yield tmp_path
+    monkeypatch.setattr(session, 'SESSIONS_DIR', original_dir)
+
+def test_session_manager_init(temp_sessions_dir):
+    manager = SessionManager()
+    assert os.path.exists(temp_sessions_dir)
+
+def test_save_and_load_session(temp_sessions_dir):
+    manager = SessionManager()
+    name = "test_session"
+    messages = [{"role": "user", "content": "Hello"}]
+    metadata = {"test": True}
+
+    assert manager.save_session(name, messages, metadata)
+
+    loaded = manager.load_session(name)
+    assert loaded is not None
+    assert loaded['name'] == name
+    assert loaded['messages'] == messages
+    assert loaded['metadata'] == metadata
+
+def test_load_nonexistent_session(temp_sessions_dir):
+    manager = SessionManager()
+    loaded = manager.load_session("nonexistent")
+    assert loaded is None
+
+def test_list_sessions(temp_sessions_dir):
+    manager = SessionManager()
+    # Save two sessions
+    manager.save_session("session1", [{"role": "user", "content": "Hi"}])
+    manager.save_session("session2", [{"role": "user", "content": "Hello"}])
+
+    sessions = manager.list_sessions()
+    assert len(sessions) == 2
+    assert sessions[0]['name'] == "session2"  # sorted by created_at desc
+
+def test_delete_session(temp_sessions_dir):
+    manager = SessionManager()
+    name = "to_delete"
+    manager.save_session(name, [{"role": "user", "content": "Test"}])
+
+    assert manager.delete_session(name)
+    assert manager.load_session(name) is None
+
+def test_delete_nonexistent_session(temp_sessions_dir):
+    manager = SessionManager()
+    assert not manager.delete_session("nonexistent")
+
+def test_export_session_json(temp_sessions_dir, tmp_path):
+    manager = SessionManager()
+    name = "export_test"
+    messages = [{"role": "user", "content": "Export me"}]
+    manager.save_session(name, messages)
+
+    output_path = tmp_path / "exported.json"
+    assert manager.export_session(name, str(output_path), 'json')
+    assert output_path.exists()
+
+    with open(output_path) as f:
+        data = json.load(f)
+    assert data['name'] == name
+
+def test_export_session_markdown(temp_sessions_dir, tmp_path):
+    manager = SessionManager()
+    name = "export_md"
+    messages = [{"role": "user", "content": "Markdown export"}]
+    manager.save_session(name, messages)
+
+    output_path = tmp_path / "exported.md"
+    assert manager.export_session(name, str(output_path), 'markdown')
+    assert output_path.exists()
+
+    content = output_path.read_text()
+    assert "# Session: export_md" in content
+
+def test_export_session_txt(temp_sessions_dir, tmp_path):
+    manager = SessionManager()
+    name = "export_txt"
+    messages = [{"role": "user", "content": "Text export"}]
+    manager.save_session(name, messages)
+
+    output_path = tmp_path / "exported.txt"
+    assert manager.export_session(name, str(output_path), 'txt')
+    assert output_path.exists()
+
+    content = output_path.read_text()
+    assert "Session: export_txt" in content
+
+def test_export_nonexistent_session(temp_sessions_dir, tmp_path):
+    manager = SessionManager()
+    output_path = tmp_path / "nonexistent.json"
+    assert not manager.export_session("nonexistent", str(output_path), 'json')
+
+def test_export_unsupported_format(temp_sessions_dir, tmp_path):
+    manager = SessionManager()
+    name = "test"
+    manager.save_session(name, [{"role": "user", "content": "Test"}])
+
+    output_path = tmp_path / "test.unsupported"
+    assert not manager.export_session(name, str(output_path), 'unsupported')
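A note on the temp_sessions_dir fixture: monkeypatch.setattr is undone automatically at test teardown, and tmp_path is always a fresh, empty, per-test directory, so the manual restore after yield and the shutil cleanup are redundant. An equivalent smaller fixture (a sketch under those assumptions):

import pytest

@pytest.fixture
def temp_sessions_dir(tmp_path, monkeypatch):
    from pr.core import session
    # monkeypatch reverts this setattr on its own after the test
    monkeypatch.setattr(session, 'SESSIONS_DIR', str(tmp_path))
    return tmp_path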
tests/test_usage_tracker.py (new file, 86 lines)
@@ -0,0 +1,86 @@
+import pytest
+import tempfile
+import os
+import json
+from pr.core.usage_tracker import UsageTracker
+
+@pytest.fixture
+def temp_usage_file(tmp_path, monkeypatch):
+    from pr.core import usage_tracker
+    original_file = usage_tracker.USAGE_DB_FILE
+    temp_file = str(tmp_path / "usage.json")
+    monkeypatch.setattr(usage_tracker, 'USAGE_DB_FILE', temp_file)
+    yield temp_file
+    if os.path.exists(temp_file):
+        os.remove(temp_file)
+    monkeypatch.setattr(usage_tracker, 'USAGE_DB_FILE', original_file)
+
+def test_usage_tracker_init():
+    tracker = UsageTracker()
+    summary = tracker.get_session_summary()
+    assert summary['requests'] == 0
+    assert summary['total_tokens'] == 0
+    assert summary['estimated_cost'] == 0.0
+
+def test_track_request_known_model():
+    tracker = UsageTracker()
+    tracker.track_request('gpt-3.5-turbo', 100, 50)
+
+    summary = tracker.get_session_summary()
+    assert summary['requests'] == 1
+    assert summary['input_tokens'] == 100
+    assert summary['output_tokens'] == 50
+    assert summary['total_tokens'] == 150
+    assert 'gpt-3.5-turbo' in summary['models_used']
+    # Cost: (100/1000)*0.0005 + (50/1000)*0.0015 = 0.00005 + 0.000075 = 0.000125
+    assert abs(summary['estimated_cost'] - 0.000125) < 1e-6
+
+def test_track_request_unknown_model():
+    tracker = UsageTracker()
+    tracker.track_request('unknown-model', 100, 50)
+
+    summary = tracker.get_session_summary()
+    assert summary['requests'] == 1
+    assert summary['estimated_cost'] == 0.0  # Unknown model, cost 0
+
+def test_track_request_multiple():
+    tracker = UsageTracker()
+    tracker.track_request('gpt-3.5-turbo', 100, 50)
+    tracker.track_request('gpt-4', 200, 100)
+
+    summary = tracker.get_session_summary()
+    assert summary['requests'] == 2
+    assert summary['input_tokens'] == 300
+    assert summary['output_tokens'] == 150
+    assert summary['total_tokens'] == 450
+    assert len(summary['models_used']) == 2
+
+def test_get_formatted_summary():
+    tracker = UsageTracker()
+    tracker.track_request('gpt-3.5-turbo', 100, 50)
+
+    formatted = tracker.get_formatted_summary()
+    assert "Total Requests: 1" in formatted
+    assert "Total Tokens: 150" in formatted
+    assert "Estimated Cost: $0.0001" in formatted
+    assert "gpt-3.5-turbo" in formatted
+
+def test_get_total_usage_no_file(temp_usage_file):
+    total = UsageTracker.get_total_usage()
+    assert total['total_requests'] == 0
+    assert total['total_tokens'] == 0
+    assert total['total_cost'] == 0.0
+
+def test_get_total_usage_with_data(temp_usage_file):
+    # Manually create a history file
+    history = [
+        {'timestamp': '2023-01-01', 'model': 'gpt-3.5-turbo', 'input_tokens': 100, 'output_tokens': 50, 'total_tokens': 150, 'cost': 0.000125},
+        {'timestamp': '2023-01-02', 'model': 'gpt-4', 'input_tokens': 200, 'output_tokens': 100, 'total_tokens': 300, 'cost': 0.008}
+    ]
+    with open(temp_usage_file, 'w') as f:
+        json.dump(history, f)
+
+    total = UsageTracker.get_total_usage()
+    assert total['total_requests'] == 2
+    assert total['total_tokens'] == 450
+    assert abs(total['total_cost'] - 0.008125) < 1e-6
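The cost comment in test_track_request_known_model implies per-1K-token rates of 0.0005 USD (input) and 0.0015 USD (output) for gpt-3.5-turbo, with unknown models costing zero. A sketch of that arithmetic (the rate table is inferred from the tests, not read from pr.core.usage_tracker):

PRICING = {'gpt-3.5-turbo': (0.0005, 0.0015)}  # USD per 1K (input, output) tokens

def estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    rate_in, rate_out = PRICING.get(model, (0.0, 0.0))  # unknown model -> cost 0
    return (input_tokens / 1000) * rate_in + (output_tokens / 1000) * rate_out

assert abs(estimate_cost('gpt-3.5-turbo', 100, 50) - 0.000125) < 1e-9
assert estimate_cost('unknown-model', 100, 50) == 0.0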
tests/test_validation.py (new file, 122 lines)
@@ -0,0 +1,122 @@
+import pytest
+import tempfile
+import os
+from pr.core.validation import (
+    validate_file_path,
+    validate_directory_path,
+    validate_model_name,
+    validate_api_url,
+    validate_session_name,
+    validate_temperature,
+    validate_max_tokens,
+)
+from pr.core.exceptions import ValidationError
+
+def test_validate_file_path_empty():
+    with pytest.raises(ValidationError, match="File path cannot be empty"):
+        validate_file_path("")
+
+def test_validate_file_path_not_exist():
+    with pytest.raises(ValidationError, match="File does not exist"):
+        validate_file_path("/nonexistent/file.txt", must_exist=True)
+
+def test_validate_file_path_is_dir():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        with pytest.raises(ValidationError, match="Path is a directory"):
+            validate_file_path(tmpdir, must_exist=True)
+
+def test_validate_file_path_valid():
+    with tempfile.NamedTemporaryFile() as tmpfile:
+        result = validate_file_path(tmpfile.name, must_exist=True)
+        assert os.path.isabs(result)
+        assert result == os.path.abspath(tmpfile.name)
+
+def test_validate_directory_path_empty():
+    with pytest.raises(ValidationError, match="Directory path cannot be empty"):
+        validate_directory_path("")
+
+def test_validate_directory_path_not_exist():
+    with pytest.raises(ValidationError, match="Directory does not exist"):
+        validate_directory_path("/nonexistent/dir", must_exist=True)
+
+def test_validate_directory_path_not_dir():
+    with tempfile.NamedTemporaryFile() as tmpfile:
+        with pytest.raises(ValidationError, match="Path is not a directory"):
+            validate_directory_path(tmpfile.name, must_exist=True)
+
+def test_validate_directory_path_create():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        new_dir = os.path.join(tmpdir, "new_dir")
+        result = validate_directory_path(new_dir, must_exist=True, create=True)
+        assert os.path.isdir(new_dir)
+        assert result == os.path.abspath(new_dir)
+
+def test_validate_directory_path_valid():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        result = validate_directory_path(tmpdir, must_exist=True)
+        assert result == os.path.abspath(tmpdir)
+
+def test_validate_model_name_empty():
+    with pytest.raises(ValidationError, match="Model name cannot be empty"):
+        validate_model_name("")
+
+def test_validate_model_name_too_short():
+    with pytest.raises(ValidationError, match="Model name too short"):
+        validate_model_name("a")
+
+def test_validate_model_name_valid():
+    result = validate_model_name("gpt-3.5-turbo")
+    assert result == "gpt-3.5-turbo"
+
+def test_validate_api_url_empty():
+    with pytest.raises(ValidationError, match="API URL cannot be empty"):
+        validate_api_url("")
+
+def test_validate_api_url_invalid():
+    with pytest.raises(ValidationError, match="API URL must start with"):
+        validate_api_url("invalid-url")
+
+def test_validate_api_url_valid():
+    result = validate_api_url("https://api.example.com")
+    assert result == "https://api.example.com"
+
+def test_validate_session_name_empty():
+    with pytest.raises(ValidationError, match="Session name cannot be empty"):
+        validate_session_name("")
+
+def test_validate_session_name_invalid_char():
+    with pytest.raises(ValidationError, match="contains invalid character"):
+        validate_session_name("test/session")
+
+def test_validate_session_name_too_long():
+    long_name = "a" * 256
+    with pytest.raises(ValidationError, match="Session name too long"):
+        validate_session_name(long_name)
+
+def test_validate_session_name_valid():
+    result = validate_session_name("valid_session_123")
+    assert result == "valid_session_123"
+
+def test_validate_temperature_too_low():
+    with pytest.raises(ValidationError, match="Temperature must be between"):
+        validate_temperature(-0.1)
+
+def test_validate_temperature_too_high():
+    with pytest.raises(ValidationError, match="Temperature must be between"):
+        validate_temperature(2.1)
+
+def test_validate_temperature_valid():
+    result = validate_temperature(0.7)
+    assert result == 0.7
+
+def test_validate_max_tokens_too_low():
+    with pytest.raises(ValidationError, match="Max tokens must be at least 1"):
+        validate_max_tokens(0)
+
+def test_validate_max_tokens_too_high():
+    with pytest.raises(ValidationError, match="Max tokens too high"):
+        validate_max_tokens(100001)
+
+def test_validate_max_tokens_valid():
+    result = validate_max_tokens(1000)
+    assert result == 1000
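Taken together, the temperature tests fix the accepted range at [0.0, 2.0]: -0.1 and 2.1 are rejected, 0.7 passes. A sketch of the contract they pin down (bounds and message inferred from the tests; the real implementation is pr.core.validation.validate_temperature):

class ValidationError(Exception):
    pass

def validate_temperature(value: float) -> float:
    # Bounds inferred from the tests: -0.1 and 2.1 must fail, 0.7 must pass
    if not 0.0 <= value <= 2.0:
        raise ValidationError("Temperature must be between 0.0 and 2.0")
    return value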