import os
import tempfile
import unittest
from unittest.mock import MagicMock, patch

from pr.core.assistant import Assistant, process_message


class TestAssistant(unittest.TestCase):
    """Unit tests for ``pr.core.assistant.Assistant`` and ``process_message``."""

    # Environment values served by the patched ``os.environ.get`` in test_init;
    # any other key falls back to the caller-supplied default.
    FAKE_ENV = {
        'OPENROUTER_API_KEY': 'key',
        'AI_MODEL': 'model',
        'API_URL': 'url',
        'MODEL_LIST_URL': 'list',
        'USE_TOOLS': '1',
        'STRICT_MODE': '0',
    }

    def setUp(self):
        """Build a stand-in for the parsed CLI arguments given to ``Assistant``."""
        self.args = MagicMock()
        self.args.verbose = False
        self.args.debug = False
        self.args.no_syntax = False
        self.args.model = 'test-model'
        self.args.api_url = 'test-url'
        self.args.model_list_url = 'test-list-url'

    @patch('sqlite3.connect')
    @patch('os.environ.get')
    @patch('pr.core.context.init_system_message')
    @patch('pr.core.enhanced_assistant.EnhancedAssistant')
    def test_init(self, mock_enhanced, mock_init_sys, mock_env, mock_sqlite):
        """Constructing an Assistant reads the API key from the environment,
        keeps the model given in args (over AI_MODEL), and opens exactly one
        sqlite connection."""
        # NOTE(review): these patches target the names where they are defined.
        # If pr.core.assistant binds them via ``from ... import ...`` at module
        # load, the patch must target ``pr.core.assistant.<name>`` instead —
        # confirm against the module under test.
        #
        # ``default`` is optional so single-argument ``os.environ.get(key)``
        # calls made while the patch is active do not raise TypeError.
        mock_env.side_effect = (
            lambda key, default=None: self.FAKE_ENV.get(key, default)
        )
        mock_conn = MagicMock()
        mock_sqlite.return_value = mock_conn
        mock_init_sys.return_value = {'role': 'system', 'content': 'sys'}

        assistant = Assistant(self.args)

        self.assertEqual(assistant.api_key, 'key')
        self.assertEqual(assistant.model, 'test-model')
        mock_sqlite.assert_called_once()

    @patch('pr.core.assistant.call_api')
    @patch('pr.core.assistant.render_markdown')
    def test_process_response_no_tools(self, mock_render, mock_call):
        """A plain-content response is appended to the history and rendered."""
        assistant = MagicMock()
        assistant.messages = MagicMock()
        assistant.verbose = False
        assistant.syntax_highlighting = True
        mock_render.return_value = 'rendered'
        response = {'choices': [{'message': {'content': 'content'}}]}

        result = Assistant.process_response(assistant, response)

        self.assertEqual(result, 'rendered')
        assistant.messages.append.assert_called_with({'content': 'content'})

    @patch('pr.core.assistant.call_api')
    @patch('pr.core.assistant.render_markdown')
    @patch('pr.core.assistant.get_tools_definition')
    def test_process_response_with_tools(self, mock_tools_def, mock_render, mock_call):
        """A response carrying ``tool_calls`` is expected to trigger a
        follow-up ``call_api`` request."""
        assistant = MagicMock()
        assistant.messages = MagicMock()
        assistant.verbose = False
        assistant.syntax_highlighting = True
        assistant.use_tools = True
        assistant.model = 'model'
        assistant.api_url = 'url'
        assistant.api_key = 'key'
        mock_tools_def.return_value = []
        mock_call.return_value = {'choices': [{'message': {'content': 'follow'}}]}
        response = {
            'choices': [
                {
                    'message': {
                        'tool_calls': [
                            {'id': '1', 'function': {'name': 'test', 'arguments': '{}'}},
                        ],
                    },
                },
            ],
        }

        tool_results = [{'role': 'tool', 'content': 'result'}]
        with patch.object(assistant, 'execute_tool_calls', return_value=tool_results):
            Assistant.process_response(assistant, response)

        mock_call.assert_called()

    @patch('pr.core.assistant.call_api')
    @patch('pr.core.assistant.get_tools_definition')
    def test_process_message(self, mock_tools, mock_call):
        """``process_message`` records the user's text as a ``user`` message."""
        assistant = MagicMock()
        assistant.messages = MagicMock()
        assistant.verbose = False
        assistant.use_tools = True
        assistant.model = 'model'
        assistant.api_url = 'url'
        assistant.api_key = 'key'
        mock_tools.return_value = []
        mock_call.return_value = {'choices': [{'message': {'content': 'response'}}]}

        # Silence rendering and console output while the message is processed.
        with patch('pr.core.assistant.render_markdown', return_value='rendered'), \
                patch('builtins.print'):
            process_message(assistant, 'test message')

        assistant.messages.append.assert_called_with(
            {'role': 'user', 'content': 'test message'}
        )


if __name__ == '__main__':
    unittest.main()