|
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from rp.core.assistant import Assistant, process_message
|
|
|
|
|
|
class TestAssistant(unittest.TestCase):
    """Tests for ``Assistant`` and ``process_message`` from ``rp.core.assistant``.

    Collaborators (API calls, markdown rendering, sqlite, environment
    lookups) are replaced with mocks in every test.
    """

    def setUp(self):
        """Create the mock ``args`` object that ``test_init`` hands to ``Assistant``."""
        self.args = MagicMock()
        # Set explicitly: an untouched MagicMock attribute is a truthy mock,
        # not False or a string.
        self.args.verbose = False
        self.args.debug = False
        self.args.no_syntax = False
        self.args.model = "test-model"
        self.args.api_url = "test-url"
        self.args.model_list_url = "test-list-url"

    @patch("sqlite3.connect")
    @patch("os.environ.get")
    @patch("rp.core.context.init_system_message")
    @patch("rp.core.enhanced_assistant.EnhancedAssistant")
    def test_init(self, mock_enhanced, mock_init_sys, mock_env, mock_sqlite):
        """Constructing ``Assistant`` sets ``api_key``/``model`` and connects to sqlite once.

        ``mock_enhanced`` is not used directly; the parameter exists only
        because the ``patch`` decorator injects it.
        """
        fake_env = {
            "OPENROUTER_API_KEY": "key",
            "AI_MODEL": "model",
            "API_URL": "url",
            "MODEL_LIST_URL": "list",
            "USE_TOOLS": "1",
            "STRICT_MODE": "0",
        }
        # ``default`` is optional, as it is on the real ``os.environ.get``, so a
        # one-argument lookup returns None instead of raising TypeError in the mock.
        mock_env.side_effect = lambda key, default=None: fake_env.get(key, default)
        mock_conn = MagicMock()
        mock_sqlite.return_value = mock_conn
        mock_init_sys.return_value = {"role": "system", "content": "sys"}

        assistant = Assistant(self.args)

        self.assertEqual(assistant.api_key, "key")
        # The value from ``args.model`` is expected, not the "AI_MODEL" entry.
        self.assertEqual(assistant.model, "test-model")
        mock_sqlite.assert_called_once()

    @patch("rp.core.assistant.call_api")
    @patch("rp.core.assistant.render_markdown")
    def test_process_response_no_tools(self, mock_render, mock_call):
        """A response without tool calls returns the rendered text and stores the message."""
        assistant = MagicMock()
        assistant.verbose = False
        assistant.syntax_highlighting = True
        mock_render.return_value = "rendered"

        response = {"choices": [{"message": {"content": "content"}}]}

        # Call the unbound method so the MagicMock plays the role of ``self``.
        result = Assistant.process_response(assistant, response)

        self.assertEqual(result, "rendered")
        assistant.messages.append.assert_called_with({"content": "content"})

    @patch("rp.core.assistant.call_api")
    @patch("rp.core.assistant.render_markdown")
    @patch("rp.core.assistant.get_tools_definition")
    def test_process_response_with_tools(self, mock_tools_def, mock_render, mock_call):
        """A response carrying tool calls leads to a follow-up ``call_api`` invocation."""
        assistant = MagicMock()
        assistant.verbose = False
        assistant.syntax_highlighting = True
        assistant.use_tools = True
        assistant.model = "model"
        assistant.api_url = "url"
        assistant.api_key = "key"
        mock_tools_def.return_value = []
        # The follow-up reply has no tool calls.
        mock_call.return_value = {"choices": [{"message": {"content": "follow"}}]}

        tool_call = {"id": "1", "function": {"name": "test", "arguments": "{}"}}
        response = {"choices": [{"message": {"tool_calls": [tool_call]}}]}

        with patch.object(
            assistant,
            "execute_tool_calls",
            return_value=[{"role": "tool", "content": "result"}],
        ):
            Assistant.process_response(assistant, response)

        mock_call.assert_called()

    @patch("rp.core.assistant.call_api")
    @patch("rp.core.assistant.get_tools_definition")
    def test_process_message(self, mock_tools, mock_call):
        """``process_message`` appends the user's text to ``assistant.messages``."""
        assistant = MagicMock()
        assistant.verbose = False
        assistant.use_tools = True
        assistant.model = "model"
        assistant.api_url = "url"
        assistant.api_key = "key"
        mock_tools.return_value = []
        mock_call.return_value = {"choices": [{"message": {"content": "response"}}]}

        # ``print`` is patched as well so the test produces no console output.
        with patch("rp.core.assistant.render_markdown", return_value="rendered"), patch(
            "builtins.print"
        ):
            process_message(assistant, "test message")

        assistant.messages.append.assert_called_with({"role": "user", "content": "test message"})
|
|
|
|
|
|
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|