You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

170 lines
6.2 KiB

import unittest
from unittest.mock import patch, MagicMock
import os
import logging
# Disable logging for tests
logging.disable(logging.CRITICAL)
from deepsearcher.llm import AzureOpenAI
from deepsearcher.llm.base import ChatResponse
class TestAzureOpenAI(unittest.TestCase):
    """Tests for the Azure OpenAI LLM provider.

    For every test, ``sys.modules["openai"]`` is replaced with a ``MagicMock``
    whose ``AzureOpenAI`` factory returns a mock client. The assertions on
    ``self.mock_openai.AzureOpenAI`` rely on the provider resolving
    ``openai.AzureOpenAI`` at construction time (i.e. after ``setUp`` has
    installed the fake module) — NOTE(review): confirm against
    ``deepsearcher.llm.AzureOpenAI``.
    """

    def setUp(self):
        """Install the fake ``openai`` module and a canned chat completion."""
        # Mock object graph mirroring client.chat.completions.create(...).
        self.mock_openai = MagicMock()
        self.mock_client = MagicMock()
        self.mock_chat = MagicMock()
        self.mock_completions = MagicMock()

        self.mock_openai.AzureOpenAI = MagicMock(return_value=self.mock_client)
        self.mock_client.chat = self.mock_chat
        self.mock_chat.completions = self.mock_completions

        # Canned response: a single choice carrying text content, plus usage.
        self.mock_response = MagicMock()
        self.mock_choice = MagicMock()
        self.mock_message = MagicMock()
        self.mock_usage = MagicMock()
        self.mock_message.content = "Test response"
        self.mock_choice.message = self.mock_message
        self.mock_usage.total_tokens = 100
        self.mock_response.choices = [self.mock_choice]
        self.mock_response.usage = self.mock_usage
        self.mock_completions.create.return_value = self.mock_response

        # Swap in the fake module. Registering the stop with addCleanup right
        # after start() (instead of relying on tearDown) guarantees the real
        # sys.modules is restored even if anything later in setUp raises.
        self.module_patcher = patch.dict("sys.modules", {"openai": self.mock_openai})
        self.module_patcher.start()
        self.addCleanup(self.module_patcher.stop)

        # Explicit configuration shared by the tests below.
        self.test_model = "gpt-4"
        self.test_endpoint = "https://test.openai.azure.com"
        self.test_api_key = "test_api_key"
        self.test_api_version = "2024-02-15"

    def _create_llm(self):
        """Return a provider built from the explicit test parameters.

        ``os.environ`` is emptied during construction so that no AZURE_*
        variables on the host can influence the configuration.
        """
        with patch.dict("os.environ", {}, clear=True):
            return AzureOpenAI(
                model=self.test_model,
                azure_endpoint=self.test_endpoint,
                api_key=self.test_api_key,
                api_version=self.test_api_version,
            )

    def _assert_chat_roundtrip(self, messages):
        """Send *messages* through a fresh provider and verify request and response."""
        llm = self._create_llm()
        response = llm.chat(messages)

        # The request must carry the configured model and the messages verbatim.
        self.mock_completions.create.assert_called_once()
        _, kwargs = self.mock_completions.create.call_args
        self.assertEqual(kwargs["model"], self.test_model)
        self.assertEqual(kwargs["messages"], messages)

        # The canned completion from setUp must be mapped onto ChatResponse.
        self.assertIsInstance(response, ChatResponse)
        self.assertEqual(response.content, "Test response")
        self.assertEqual(response.total_tokens, 100)

    def test_init_with_parameters(self):
        """Explicit constructor arguments are forwarded to ``openai.AzureOpenAI``."""
        llm = self._create_llm()

        self.mock_openai.AzureOpenAI.assert_called_once_with(
            azure_endpoint=self.test_endpoint,
            api_key=self.test_api_key,
            api_version=self.test_api_version,
        )
        self.assertEqual(llm.model, self.test_model)

    def test_init_with_env_variables(self):
        """Endpoint and key are taken from the environment when not passed in."""
        env_endpoint = "https://env.openai.azure.com"
        env_api_key = "env_api_key"
        # clear=True keeps the test hermetic: only these two variables exist.
        with patch.dict(
            os.environ,
            {
                "AZURE_OPENAI_ENDPOINT": env_endpoint,
                "AZURE_OPENAI_KEY": env_api_key,
            },
            clear=True,
        ):
            llm = AzureOpenAI(model=self.test_model)

        # No api_version was supplied; the provider is expected to forward None.
        self.mock_openai.AzureOpenAI.assert_called_once_with(
            azure_endpoint=env_endpoint,
            api_key=env_api_key,
            api_version=None,
        )
        self.assertEqual(llm.model, self.test_model)

    def test_chat_single_message(self):
        """A single user message round-trips through ``chat``."""
        self._assert_chat_roundtrip([{"role": "user", "content": "Hello"}])

    def test_chat_multiple_messages(self):
        """A multi-turn conversation is passed to the client unchanged."""
        self._assert_chat_roundtrip([
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "How are you?"},
        ])

    def test_chat_with_error(self):
        """An exception raised by the client surfaces from ``chat``."""
        llm = self._create_llm()
        self.mock_completions.create.side_effect = Exception("Azure OpenAI API Error")

        with self.assertRaises(Exception) as context:
            llm.chat([{"role": "user", "content": "Hello"}])

        self.assertEqual(str(context.exception), "Azure OpenAI API Error")
if __name__ == "__main__":
    # Only when executed as a script (not on import): hand control to the
    # unittest command-line runner for the test cases defined in this module.
    unittest.main(module=__name__)