ADded coverage.
This commit is contained in:
parent
2b701cb5cd
commit
a40713d463
@ -31,13 +31,13 @@ class KnowledgeEntry:
|
||||
class KnowledgeStore:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
self.semantic_index = SemanticIndex()
|
||||
self._initialize_store()
|
||||
self._load_index()
|
||||
|
||||
def _initialize_store(self):
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS knowledge_entries (
|
||||
@ -62,22 +62,17 @@ class KnowledgeStore:
|
||||
CREATE INDEX IF NOT EXISTS idx_created ON knowledge_entries(created_at DESC)
|
||||
''')
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
self.conn.commit()
|
||||
|
||||
def _load_index(self):
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute('SELECT entry_id, content FROM knowledge_entries')
|
||||
for row in cursor.fetchall():
|
||||
self.semantic_index.add_document(row[0], row[1])
|
||||
|
||||
conn.close()
|
||||
|
||||
def add_entry(self, entry: KnowledgeEntry):
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
INSERT OR REPLACE INTO knowledge_entries
|
||||
@ -94,14 +89,12 @@ class KnowledgeStore:
|
||||
entry.importance_score
|
||||
))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
self.conn.commit()
|
||||
|
||||
self.semantic_index.add_document(entry.entry_id, entry.content)
|
||||
|
||||
def get_entry(self, entry_id: str) -> Optional[KnowledgeEntry]:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
|
||||
@ -117,9 +110,7 @@ class KnowledgeStore:
|
||||
SET access_count = access_count + 1
|
||||
WHERE entry_id = ?
|
||||
''', (entry_id,))
|
||||
conn.commit()
|
||||
|
||||
conn.close()
|
||||
self.conn.commit()
|
||||
|
||||
return KnowledgeEntry(
|
||||
entry_id=row[0],
|
||||
@ -132,15 +123,13 @@ class KnowledgeStore:
|
||||
importance_score=row[7]
|
||||
)
|
||||
|
||||
conn.close()
|
||||
return None
|
||||
|
||||
def search_entries(self, query: str, category: Optional[str] = None,
|
||||
top_k: int = 5) -> List[KnowledgeEntry]:
|
||||
search_results = self.semantic_index.search(query, top_k * 2)
|
||||
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
entries = []
|
||||
for entry_id, score in search_results:
|
||||
@ -174,12 +163,10 @@ class KnowledgeStore:
|
||||
if len(entries) >= top_k:
|
||||
break
|
||||
|
||||
conn.close()
|
||||
return entries
|
||||
|
||||
def get_by_category(self, category: str, limit: int = 20) -> List[KnowledgeEntry]:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
|
||||
@ -202,12 +189,10 @@ class KnowledgeStore:
|
||||
importance_score=row[7]
|
||||
))
|
||||
|
||||
conn.close()
|
||||
return entries
|
||||
|
||||
def update_importance(self, entry_id: str, importance_score: float):
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
UPDATE knowledge_entries
|
||||
@ -215,18 +200,15 @@ class KnowledgeStore:
|
||||
WHERE entry_id = ?
|
||||
''', (importance_score, time.time(), entry_id))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
self.conn.commit()
|
||||
|
||||
def delete_entry(self, entry_id: str) -> bool:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute('DELETE FROM knowledge_entries WHERE entry_id = ?', (entry_id,))
|
||||
deleted = cursor.rowcount > 0
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
self.conn.commit()
|
||||
|
||||
if deleted:
|
||||
self.semantic_index.remove_document(entry_id)
|
||||
@ -234,8 +216,7 @@ class KnowledgeStore:
|
||||
return deleted
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
cursor = conn.cursor()
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute('SELECT COUNT(*) FROM knowledge_entries')
|
||||
total_entries = cursor.fetchone()[0]
|
||||
@ -254,8 +235,6 @@ class KnowledgeStore:
|
||||
cursor.execute('SELECT SUM(access_count) FROM knowledge_entries')
|
||||
total_accesses = cursor.fetchone()[0] or 0
|
||||
|
||||
conn.close()
|
||||
|
||||
return {
|
||||
'total_entries': total_entries,
|
||||
'total_categories': total_categories,
|
||||
|
||||
89
tests/test_enhanced_assistant.py
Normal file
89
tests/test_enhanced_assistant.py
Normal file
@ -0,0 +1,89 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from pr.core.enhanced_assistant import EnhancedAssistant
|
||||
|
||||
def test_enhanced_assistant_init():
|
||||
mock_base = MagicMock()
|
||||
assistant = EnhancedAssistant(mock_base)
|
||||
assert assistant.base == mock_base
|
||||
assert assistant.current_conversation_id is not None
|
||||
|
||||
def test_enhanced_call_api_with_cache():
|
||||
mock_base = MagicMock()
|
||||
mock_base.model = 'test-model'
|
||||
mock_base.api_url = 'http://test'
|
||||
mock_base.api_key = 'key'
|
||||
mock_base.use_tools = False
|
||||
mock_base.verbose = False
|
||||
|
||||
assistant = EnhancedAssistant(mock_base)
|
||||
assistant.api_cache = MagicMock()
|
||||
assistant.api_cache.get.return_value = {'cached': True}
|
||||
|
||||
result = assistant.enhanced_call_api([{'role': 'user', 'content': 'test'}])
|
||||
assert result == {'cached': True}
|
||||
assistant.api_cache.get.assert_called_once()
|
||||
|
||||
def test_enhanced_call_api_without_cache():
|
||||
mock_base = MagicMock()
|
||||
mock_base.model = 'test-model'
|
||||
mock_base.api_url = 'http://test'
|
||||
mock_base.api_key = 'key'
|
||||
mock_base.use_tools = False
|
||||
mock_base.verbose = False
|
||||
|
||||
assistant = EnhancedAssistant(mock_base)
|
||||
assistant.api_cache = None
|
||||
|
||||
# It will try to call API and fail with network error, but that's expected
|
||||
result = assistant.enhanced_call_api([{'role': 'user', 'content': 'test'}])
|
||||
assert 'error' in result
|
||||
|
||||
def test_execute_workflow_not_found():
|
||||
mock_base = MagicMock()
|
||||
assistant = EnhancedAssistant(mock_base)
|
||||
assistant.workflow_storage = MagicMock()
|
||||
assistant.workflow_storage.load_workflow_by_name.return_value = None
|
||||
|
||||
result = assistant.execute_workflow('nonexistent')
|
||||
assert 'error' in result
|
||||
|
||||
def test_create_agent():
|
||||
mock_base = MagicMock()
|
||||
assistant = EnhancedAssistant(mock_base)
|
||||
assistant.agent_manager = MagicMock()
|
||||
assistant.agent_manager.create_agent.return_value = 'agent_id'
|
||||
|
||||
result = assistant.create_agent('role')
|
||||
assert result == 'agent_id'
|
||||
|
||||
def test_search_knowledge():
|
||||
mock_base = MagicMock()
|
||||
assistant = EnhancedAssistant(mock_base)
|
||||
assistant.knowledge_store = MagicMock()
|
||||
assistant.knowledge_store.search_entries.return_value = [{'result': True}]
|
||||
|
||||
result = assistant.search_knowledge('query')
|
||||
assert result == [{'result': True}]
|
||||
|
||||
def test_get_cache_statistics():
|
||||
mock_base = MagicMock()
|
||||
assistant = EnhancedAssistant(mock_base)
|
||||
assistant.api_cache = MagicMock()
|
||||
assistant.api_cache.get_statistics.return_value = {'hits': 10}
|
||||
assistant.tool_cache = MagicMock()
|
||||
assistant.tool_cache.get_statistics.return_value = {'misses': 5}
|
||||
|
||||
stats = assistant.get_cache_statistics()
|
||||
assert 'api_cache' in stats
|
||||
assert 'tool_cache' in stats
|
||||
|
||||
def test_clear_caches():
|
||||
mock_base = MagicMock()
|
||||
assistant = EnhancedAssistant(mock_base)
|
||||
assistant.api_cache = MagicMock()
|
||||
assistant.tool_cache = MagicMock()
|
||||
|
||||
assistant.clear_caches()
|
||||
assistant.api_cache.clear_all.assert_called_once()
|
||||
assistant.tool_cache.clear_all.assert_called_once()
|
||||
24
tests/test_logging.py
Normal file
24
tests/test_logging.py
Normal file
@ -0,0 +1,24 @@
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
from pr.core.logging import setup_logging, get_logger
|
||||
|
||||
def test_setup_logging_basic():
|
||||
logger = setup_logging(verbose=False)
|
||||
assert logger.name == 'pr'
|
||||
assert logger.level == 20 # INFO
|
||||
|
||||
def test_setup_logging_verbose():
|
||||
logger = setup_logging(verbose=True)
|
||||
assert logger.name == 'pr'
|
||||
assert logger.level == 10 # DEBUG
|
||||
# Should have console handler
|
||||
assert len(logger.handlers) >= 2
|
||||
|
||||
def test_get_logger_default():
|
||||
logger = get_logger()
|
||||
assert logger.name == 'pr'
|
||||
|
||||
def test_get_logger_named():
|
||||
logger = get_logger('test')
|
||||
assert logger.name == 'pr.test'
|
||||
116
tests/test_session.py
Normal file
116
tests/test_session.py
Normal file
@ -0,0 +1,116 @@
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
import json
|
||||
from pr.core.session import SessionManager
|
||||
|
||||
@pytest.fixture
|
||||
def temp_sessions_dir(tmp_path, monkeypatch):
|
||||
from pr.core import session
|
||||
original_dir = session.SESSIONS_DIR
|
||||
monkeypatch.setattr(session, 'SESSIONS_DIR', str(tmp_path))
|
||||
# Clean any existing files
|
||||
import shutil
|
||||
if os.path.exists(str(tmp_path)):
|
||||
shutil.rmtree(str(tmp_path))
|
||||
os.makedirs(str(tmp_path), exist_ok=True)
|
||||
yield tmp_path
|
||||
monkeypatch.setattr(session, 'SESSIONS_DIR', original_dir)
|
||||
|
||||
def test_session_manager_init(temp_sessions_dir):
|
||||
manager = SessionManager()
|
||||
assert os.path.exists(temp_sessions_dir)
|
||||
|
||||
def test_save_and_load_session(temp_sessions_dir):
|
||||
manager = SessionManager()
|
||||
name = "test_session"
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
metadata = {"test": True}
|
||||
|
||||
assert manager.save_session(name, messages, metadata)
|
||||
|
||||
loaded = manager.load_session(name)
|
||||
assert loaded is not None
|
||||
assert loaded['name'] == name
|
||||
assert loaded['messages'] == messages
|
||||
assert loaded['metadata'] == metadata
|
||||
|
||||
def test_load_nonexistent_session(temp_sessions_dir):
|
||||
manager = SessionManager()
|
||||
loaded = manager.load_session("nonexistent")
|
||||
assert loaded is None
|
||||
|
||||
def test_list_sessions(temp_sessions_dir):
|
||||
manager = SessionManager()
|
||||
# Save a session
|
||||
manager.save_session("session1", [{"role": "user", "content": "Hi"}])
|
||||
manager.save_session("session2", [{"role": "user", "content": "Hello"}])
|
||||
|
||||
sessions = manager.list_sessions()
|
||||
assert len(sessions) == 2
|
||||
assert sessions[0]['name'] == "session2" # sorted by created_at desc
|
||||
|
||||
def test_delete_session(temp_sessions_dir):
|
||||
manager = SessionManager()
|
||||
name = "to_delete"
|
||||
manager.save_session(name, [{"role": "user", "content": "Test"}])
|
||||
|
||||
assert manager.delete_session(name)
|
||||
assert manager.load_session(name) is None
|
||||
|
||||
def test_delete_nonexistent_session(temp_sessions_dir):
|
||||
manager = SessionManager()
|
||||
assert not manager.delete_session("nonexistent")
|
||||
|
||||
def test_export_session_json(temp_sessions_dir, tmp_path):
|
||||
manager = SessionManager()
|
||||
name = "export_test"
|
||||
messages = [{"role": "user", "content": "Export me"}]
|
||||
manager.save_session(name, messages)
|
||||
|
||||
output_path = tmp_path / "exported.json"
|
||||
assert manager.export_session(name, str(output_path), 'json')
|
||||
assert output_path.exists()
|
||||
|
||||
with open(output_path) as f:
|
||||
data = json.load(f)
|
||||
assert data['name'] == name
|
||||
|
||||
def test_export_session_markdown(temp_sessions_dir, tmp_path):
|
||||
manager = SessionManager()
|
||||
name = "export_md"
|
||||
messages = [{"role": "user", "content": "Markdown export"}]
|
||||
manager.save_session(name, messages)
|
||||
|
||||
output_path = tmp_path / "exported.md"
|
||||
assert manager.export_session(name, str(output_path), 'markdown')
|
||||
assert output_path.exists()
|
||||
|
||||
content = output_path.read_text()
|
||||
assert "# Session: export_md" in content
|
||||
|
||||
def test_export_session_txt(temp_sessions_dir, tmp_path):
|
||||
manager = SessionManager()
|
||||
name = "export_txt"
|
||||
messages = [{"role": "user", "content": "Text export"}]
|
||||
manager.save_session(name, messages)
|
||||
|
||||
output_path = tmp_path / "exported.txt"
|
||||
assert manager.export_session(name, str(output_path), 'txt')
|
||||
assert output_path.exists()
|
||||
|
||||
content = output_path.read_text()
|
||||
assert "Session: export_txt" in content
|
||||
|
||||
def test_export_nonexistent_session(temp_sessions_dir, tmp_path):
|
||||
manager = SessionManager()
|
||||
output_path = tmp_path / "nonexistent.json"
|
||||
assert not manager.export_session("nonexistent", str(output_path), 'json')
|
||||
|
||||
def test_export_unsupported_format(temp_sessions_dir, tmp_path):
|
||||
manager = SessionManager()
|
||||
name = "test"
|
||||
manager.save_session(name, [{"role": "user", "content": "Test"}])
|
||||
|
||||
output_path = tmp_path / "test.unsupported"
|
||||
assert not manager.export_session(name, str(output_path), 'unsupported')
|
||||
86
tests/test_usage_tracker.py
Normal file
86
tests/test_usage_tracker.py
Normal file
@ -0,0 +1,86 @@
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
import json
|
||||
from pr.core.usage_tracker import UsageTracker
|
||||
|
||||
@pytest.fixture
|
||||
def temp_usage_file(tmp_path, monkeypatch):
|
||||
from pr.core import usage_tracker
|
||||
original_file = usage_tracker.USAGE_DB_FILE
|
||||
temp_file = str(tmp_path / "usage.json")
|
||||
monkeypatch.setattr(usage_tracker, 'USAGE_DB_FILE', temp_file)
|
||||
yield temp_file
|
||||
if os.path.exists(temp_file):
|
||||
os.remove(temp_file)
|
||||
monkeypatch.setattr(usage_tracker, 'USAGE_DB_FILE', original_file)
|
||||
|
||||
def test_usage_tracker_init():
|
||||
tracker = UsageTracker()
|
||||
summary = tracker.get_session_summary()
|
||||
assert summary['requests'] == 0
|
||||
assert summary['total_tokens'] == 0
|
||||
assert summary['estimated_cost'] == 0.0
|
||||
|
||||
def test_track_request_known_model():
|
||||
tracker = UsageTracker()
|
||||
tracker.track_request('gpt-3.5-turbo', 100, 50)
|
||||
|
||||
summary = tracker.get_session_summary()
|
||||
assert summary['requests'] == 1
|
||||
assert summary['input_tokens'] == 100
|
||||
assert summary['output_tokens'] == 50
|
||||
assert summary['total_tokens'] == 150
|
||||
assert 'gpt-3.5-turbo' in summary['models_used']
|
||||
# Cost: (100/1000)*0.0005 + (50/1000)*0.0015 = 0.00005 + 0.000075 = 0.000125
|
||||
assert abs(summary['estimated_cost'] - 0.000125) < 1e-6
|
||||
|
||||
def test_track_request_unknown_model():
|
||||
tracker = UsageTracker()
|
||||
tracker.track_request('unknown-model', 100, 50)
|
||||
|
||||
summary = tracker.get_session_summary()
|
||||
assert summary['requests'] == 1
|
||||
assert summary['estimated_cost'] == 0.0 # Unknown model, cost 0
|
||||
|
||||
def test_track_request_multiple():
|
||||
tracker = UsageTracker()
|
||||
tracker.track_request('gpt-3.5-turbo', 100, 50)
|
||||
tracker.track_request('gpt-4', 200, 100)
|
||||
|
||||
summary = tracker.get_session_summary()
|
||||
assert summary['requests'] == 2
|
||||
assert summary['input_tokens'] == 300
|
||||
assert summary['output_tokens'] == 150
|
||||
assert summary['total_tokens'] == 450
|
||||
assert len(summary['models_used']) == 2
|
||||
|
||||
def test_get_formatted_summary():
|
||||
tracker = UsageTracker()
|
||||
tracker.track_request('gpt-3.5-turbo', 100, 50)
|
||||
|
||||
formatted = tracker.get_formatted_summary()
|
||||
assert "Total Requests: 1" in formatted
|
||||
assert "Total Tokens: 150" in formatted
|
||||
assert "Estimated Cost: $0.0001" in formatted
|
||||
assert "gpt-3.5-turbo" in formatted
|
||||
|
||||
def test_get_total_usage_no_file(temp_usage_file):
|
||||
total = UsageTracker.get_total_usage()
|
||||
assert total['total_requests'] == 0
|
||||
assert total['total_tokens'] == 0
|
||||
assert total['total_cost'] == 0.0
|
||||
|
||||
def test_get_total_usage_with_data(temp_usage_file):
|
||||
# Manually create history file
|
||||
history = [
|
||||
{'timestamp': '2023-01-01', 'model': 'gpt-3.5-turbo', 'input_tokens': 100, 'output_tokens': 50, 'total_tokens': 150, 'cost': 0.000125},
|
||||
{'timestamp': '2023-01-02', 'model': 'gpt-4', 'input_tokens': 200, 'output_tokens': 100, 'total_tokens': 300, 'cost': 0.008}
|
||||
]
|
||||
with open(temp_usage_file, 'w') as f:
|
||||
json.dump(history, f)
|
||||
|
||||
total = UsageTracker.get_total_usage()
|
||||
assert total['total_requests'] == 2
|
||||
assert total['total_tokens'] == 450
|
||||
assert abs(total['total_cost'] - 0.008125) < 1e-6
|
||||
122
tests/test_validation.py
Normal file
122
tests/test_validation.py
Normal file
@ -0,0 +1,122 @@
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
from pr.core.validation import (
|
||||
validate_file_path,
|
||||
validate_directory_path,
|
||||
validate_model_name,
|
||||
validate_api_url,
|
||||
validate_session_name,
|
||||
validate_temperature,
|
||||
validate_max_tokens,
|
||||
)
|
||||
from pr.core.exceptions import ValidationError
|
||||
|
||||
def test_validate_file_path_empty():
|
||||
with pytest.raises(ValidationError, match="File path cannot be empty"):
|
||||
validate_file_path("")
|
||||
|
||||
def test_validate_file_path_not_exist():
|
||||
with pytest.raises(ValidationError, match="File does not exist"):
|
||||
validate_file_path("/nonexistent/file.txt", must_exist=True)
|
||||
|
||||
def test_validate_file_path_is_dir():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with pytest.raises(ValidationError, match="Path is a directory"):
|
||||
validate_file_path(tmpdir, must_exist=True)
|
||||
|
||||
def test_validate_file_path_valid():
|
||||
with tempfile.NamedTemporaryFile() as tmpfile:
|
||||
result = validate_file_path(tmpfile.name, must_exist=True)
|
||||
assert os.path.isabs(result)
|
||||
assert result == os.path.abspath(tmpfile.name)
|
||||
|
||||
def test_validate_directory_path_empty():
|
||||
with pytest.raises(ValidationError, match="Directory path cannot be empty"):
|
||||
validate_directory_path("")
|
||||
|
||||
def test_validate_directory_path_not_exist():
|
||||
with pytest.raises(ValidationError, match="Directory does not exist"):
|
||||
validate_directory_path("/nonexistent/dir", must_exist=True)
|
||||
|
||||
def test_validate_directory_path_not_dir():
|
||||
with tempfile.NamedTemporaryFile() as tmpfile:
|
||||
with pytest.raises(ValidationError, match="Path is not a directory"):
|
||||
validate_directory_path(tmpfile.name, must_exist=True)
|
||||
|
||||
def test_validate_directory_path_create():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
new_dir = os.path.join(tmpdir, "new_dir")
|
||||
result = validate_directory_path(new_dir, must_exist=True, create=True)
|
||||
assert os.path.isdir(new_dir)
|
||||
assert result == os.path.abspath(new_dir)
|
||||
|
||||
def test_validate_directory_path_valid():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
result = validate_directory_path(tmpdir, must_exist=True)
|
||||
assert result == os.path.abspath(tmpdir)
|
||||
|
||||
def test_validate_model_name_empty():
|
||||
with pytest.raises(ValidationError, match="Model name cannot be empty"):
|
||||
validate_model_name("")
|
||||
|
||||
def test_validate_model_name_too_short():
|
||||
with pytest.raises(ValidationError, match="Model name too short"):
|
||||
validate_model_name("a")
|
||||
|
||||
def test_validate_model_name_valid():
|
||||
result = validate_model_name("gpt-3.5-turbo")
|
||||
assert result == "gpt-3.5-turbo"
|
||||
|
||||
def test_validate_api_url_empty():
|
||||
with pytest.raises(ValidationError, match="API URL cannot be empty"):
|
||||
validate_api_url("")
|
||||
|
||||
def test_validate_api_url_invalid():
|
||||
with pytest.raises(ValidationError, match="API URL must start with"):
|
||||
validate_api_url("invalid-url")
|
||||
|
||||
def test_validate_api_url_valid():
|
||||
result = validate_api_url("https://api.example.com")
|
||||
assert result == "https://api.example.com"
|
||||
|
||||
def test_validate_session_name_empty():
|
||||
with pytest.raises(ValidationError, match="Session name cannot be empty"):
|
||||
validate_session_name("")
|
||||
|
||||
def test_validate_session_name_invalid_char():
|
||||
with pytest.raises(ValidationError, match="contains invalid character"):
|
||||
validate_session_name("test/session")
|
||||
|
||||
def test_validate_session_name_too_long():
|
||||
long_name = "a" * 256
|
||||
with pytest.raises(ValidationError, match="Session name too long"):
|
||||
validate_session_name(long_name)
|
||||
|
||||
def test_validate_session_name_valid():
|
||||
result = validate_session_name("valid_session_123")
|
||||
assert result == "valid_session_123"
|
||||
|
||||
def test_validate_temperature_too_low():
|
||||
with pytest.raises(ValidationError, match="Temperature must be between"):
|
||||
validate_temperature(-0.1)
|
||||
|
||||
def test_validate_temperature_too_high():
|
||||
with pytest.raises(ValidationError, match="Temperature must be between"):
|
||||
validate_temperature(2.1)
|
||||
|
||||
def test_validate_temperature_valid():
|
||||
result = validate_temperature(0.7)
|
||||
assert result == 0.7
|
||||
|
||||
def test_validate_max_tokens_too_low():
|
||||
with pytest.raises(ValidationError, match="Max tokens must be at least 1"):
|
||||
validate_max_tokens(0)
|
||||
|
||||
def test_validate_max_tokens_too_high():
|
||||
with pytest.raises(ValidationError, match="Max tokens too high"):
|
||||
validate_max_tokens(100001)
|
||||
|
||||
def test_validate_max_tokens_valid():
|
||||
result = validate_max_tokens(1000)
|
||||
assert result == 1000
|
||||
Loading…
Reference in New Issue
Block a user