|
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from rp.core.assistant import Assistant, process_message
|
|
|
|
|
|
class TestAssistant(unittest.TestCase):
    """Unit tests for ``Assistant`` and ``process_message`` in ``rp.core.assistant``.

    Collaborators (SQLite, environment lookups, API calls, markdown
    rendering, tool definitions, time and uuid generation) are patched so
    each test exercises only the logic under test.
    """

    def setUp(self):
        """Build a fake CLI ``args`` object shared by the tests."""
        self.args = MagicMock()
        self.args.verbose = False
        self.args.debug = False
        self.args.no_syntax = False
        self.args.model = "test-model"
        self.args.api_url = "test-url"
        self.args.model_list_url = "test-list-url"

    @patch("sqlite3.connect")
    @patch("os.environ.get")
    @patch("rp.core.context.init_system_message")
    @patch("rp.core.enhanced_assistant.EnhancedAssistant")
    def test_init(self, mock_enhanced, mock_init_sys, mock_env, mock_sqlite):
        """``Assistant(args)`` reads the API key from the environment, keeps
        ``args.model`` over the ``AI_MODEL`` value, and connects to SQLite once."""
        fake_env = {
            "OPENROUTER_API_KEY": "key",
            "AI_MODEL": "model",
            "API_URL": "url",
            "MODEL_LIST_URL": "list",
            "USE_TOOLS": "1",
            "STRICT_MODE": "0",
        }
        # The patch on os.environ.get is process-wide while active, so any code
        # may call it with just a key; the default must be optional or such
        # calls raise TypeError instead of returning None.
        mock_env.side_effect = lambda key, default=None: fake_env.get(key, default)
        mock_sqlite.return_value = MagicMock()
        mock_init_sys.return_value = {"role": "system", "content": "sys"}

        assistant = Assistant(self.args)

        self.assertEqual(assistant.api_key, "key")
        # args.model ("test-model") is expected to win over AI_MODEL ("model").
        self.assertEqual(assistant.model, "test-model")
        mock_sqlite.assert_called_once()

    @patch("rp.core.assistant.call_api")
    @patch("rp.core.assistant.render_markdown")
    def test_process_response_no_tools(self, mock_render, mock_call):
        """A plain-content response is appended to ``messages`` and rendered."""
        # mock_call is not asserted on; it is patched only so no real API call
        # can happen.
        assistant = MagicMock()
        assistant.verbose = False
        assistant.syntax_highlighting = True
        mock_render.return_value = "rendered"

        response = {"choices": [{"message": {"content": "content"}}]}

        result = Assistant.process_response(assistant, response)

        self.assertEqual(result, "rendered")
        assistant.messages.append.assert_called_with({"content": "content"})

    @patch("rp.core.assistant.call_api")
    @patch("rp.core.assistant.render_markdown")
    @patch("rp.core.assistant.get_tools_definition")
    def test_process_response_with_tools(self, mock_tools_def, mock_render, mock_call):
        """A response carrying ``tool_calls`` triggers a follow-up API call."""
        # mock_render is patched only to keep rendering out of the test.
        assistant = MagicMock()
        assistant.verbose = False
        assistant.syntax_highlighting = True
        assistant.use_tools = True
        assistant.model = "model"
        assistant.api_url = "url"
        assistant.api_key = "key"
        mock_tools_def.return_value = []
        mock_call.return_value = {"choices": [{"message": {"content": "follow"}}]}

        tool_call = {"id": "1", "function": {"name": "test", "arguments": "{}"}}
        response = {"choices": [{"message": {"tool_calls": [tool_call]}}]}

        with patch.object(
            assistant,
            "execute_tool_calls",
            return_value=[{"role": "tool", "content": "result"}],
        ):
            Assistant.process_response(assistant, response)

        mock_call.assert_called()

    @patch("rp.core.assistant.call_api")
    @patch("rp.core.assistant.get_tools_definition")
    @patch("time.time")
    @patch("uuid.uuid4")
    def test_process_message(self, mock_uuid, mock_time, mock_tools, mock_call):
        """``process_message`` stores the user text in the knowledge store as a
        ``KnowledgeEntry`` built from the (patched) time and uuid values."""
        fake_now = 1234567890.123456
        fake_uuid = "mock_uuid_value"

        assistant = MagicMock()
        assistant.verbose = False
        assistant.use_tools = True
        assistant.model = "model"
        assistant.api_url = "url"
        assistant.api_key = "key"
        # Categorizer returns a single fixed category for the user message.
        assistant.fact_extractor = MagicMock()
        assistant.fact_extractor.categorize_content.return_value = ["user_message"]
        # Knowledge store whose add_entry call is asserted on below.
        assistant.knowledge_store = MagicMock()
        assistant.knowledge_store.add_entry.return_value = None

        mock_tools.return_value = []
        mock_call.return_value = {"choices": [{"message": {"content": "response"}}]}
        # Pin time.time() and str(uuid.uuid4()) so the expected entry is
        # deterministic.
        mock_time.return_value = fake_now
        mock_uuid.return_value.__str__.return_value = fake_uuid

        with patch("rp.core.assistant.render_markdown", return_value="rendered"):
            with patch("builtins.print"):
                process_message(assistant, "test message")

        from rp.memory import KnowledgeEntry

        expected_entry = KnowledgeEntry(
            # The expected id is the first 16 characters of str(uuid4()).
            entry_id=fake_uuid[:16],
            category="user_message",
            content="test message",
            metadata={
                "type": "user_message",
                "confidence": 1.0,
                "source": "user_input",
            },
            created_at=fake_now,
            updated_at=fake_now,
        )

        assistant.knowledge_store.add_entry.assert_called_once_with(expected_entry)
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this file directly executes every TestCase defined in it;
    # module=__name__ names this same "__main__" module explicitly.
    unittest.main(module=__name__)
|