import unittest
from unittest.mock import patch, MagicMock
import tempfile
import os
from pr.core.assistant import Assistant, process_message
class TestAssistant(unittest.TestCase):
    """Unit tests for ``pr.core.assistant.Assistant`` and ``process_message``.

    Collaborators (SQLite, environment lookups, ``call_api``, markdown
    rendering, tool definitions) are patched out so each test exercises only
    the logic under test.
    """

    def setUp(self):
        # Stand-in for the parsed command-line arguments handed to Assistant.
        self.args = MagicMock()
        self.args.verbose = False
        self.args.debug = False
        self.args.no_syntax = False
        self.args.model = 'test-model'
        self.args.api_url = 'test-url'
        self.args.model_list_url = 'test-list-url'

    @patch('sqlite3.connect')
    @patch('os.environ.get')
    @patch('pr.core.context.init_system_message')
    @patch('pr.core.enhanced_assistant.EnhancedAssistant')
    def test_init(self, mock_enhanced, mock_init_sys, mock_env, mock_sqlite):
        """Constructor reads the API key from env, keeps the args model, opens one DB connection."""
        fake_env = {
            'OPENROUTER_API_KEY': 'key',
            'AI_MODEL': 'model',
            'API_URL': 'url',
            'MODEL_LIST_URL': 'list',
            'USE_TOOLS': '1',
            'STRICT_MODE': '0',
        }
        # ``os.environ.get`` is patched process-wide, so the fake must also
        # accept the one-argument form ``os.environ.get(key)``; a lambda with
        # two required parameters raises TypeError on such calls.
        mock_env.side_effect = lambda key, default=None: fake_env.get(key, default)
        mock_sqlite.return_value = MagicMock()
        mock_init_sys.return_value = {'role': 'system', 'content': 'sys'}

        assistant = Assistant(self.args)

        self.assertEqual(assistant.api_key, 'key')
        # The fake env sets AI_MODEL='model', but the model from args is expected.
        self.assertEqual(assistant.model, 'test-model')
        mock_sqlite.assert_called_once()

    @patch('pr.core.assistant.call_api')
    @patch('pr.core.assistant.render_markdown')
    def test_process_response_no_tools(self, mock_render, mock_call):
        """A content-only response is appended to history and returned rendered."""
        # process_response is invoked unbound on a mock "self" so only its own
        # logic runs; call_api is patched defensively so nothing hits the network.
        assistant = MagicMock()
        assistant.messages = MagicMock()
        assistant.verbose = False
        assistant.syntax_highlighting = True
        mock_render.return_value = 'rendered'
        response = {'choices': [{'message': {'content': 'content'}}]}

        result = Assistant.process_response(assistant, response)

        self.assertEqual(result, 'rendered')
        assistant.messages.append.assert_called_with({'content': 'content'})

    @patch('pr.core.assistant.call_api')
    @patch('pr.core.assistant.render_markdown')
    @patch('pr.core.assistant.get_tools_definition')
    def test_process_response_with_tools(self, mock_tools_def, mock_render, mock_call):
        """A response carrying tool_calls leads to a follow-up call_api request."""
        assistant = MagicMock()
        assistant.messages = MagicMock()
        assistant.verbose = False
        assistant.syntax_highlighting = True
        assistant.use_tools = True
        assistant.model = 'model'
        assistant.api_url = 'url'
        assistant.api_key = 'key'
        mock_tools_def.return_value = []
        mock_call.return_value = {'choices': [{'message': {'content': 'follow'}}]}
        tool_call = {'id': '1', 'function': {'name': 'test', 'arguments': '{}'}}
        response = {'choices': [{'message': {'tool_calls': [tool_call]}}]}
        tool_results = [{'role': 'tool', 'content': 'result'}]

        with patch.object(assistant, 'execute_tool_calls', return_value=tool_results):
            Assistant.process_response(assistant, response)

        mock_call.assert_called()

    @patch('pr.core.assistant.call_api')
    @patch('pr.core.assistant.get_tools_definition')
    def test_process_message(self, mock_tools, mock_call):
        """process_message records the user's text in the conversation history."""
        assistant = MagicMock()
        assistant.messages = MagicMock()
        assistant.verbose = False
        assistant.use_tools = True
        assistant.model = 'model'
        assistant.api_url = 'url'
        assistant.api_key = 'key'
        mock_tools.return_value = []
        mock_call.return_value = {'choices': [{'message': {'content': 'response'}}]}

        with patch('pr.core.assistant.render_markdown', return_value='rendered'):
            # Silence console output produced while handling the message.
            with patch('builtins.print'):
                process_message(assistant, 'test message')

        # assert_called_with only inspects the most recent append; what matters
        # here is that the user turn was recorded, regardless of later appends.
        assistant.messages.append.assert_any_call(
            {'role': 'user', 'content': 'test message'}
        )
if __name__ == '__main__':
    # Run this module's tests with unittest's command-line runner when executed as a script.
    unittest.main()