You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
196 lines
7.5 KiB
196 lines
7.5 KiB
# Standard library.
import logging
import os
import unittest
from unittest.mock import MagicMock, patch

# Silence log output (CRITICAL and below) while the tests run. Kept ahead of
# the project imports, matching the original statement order.
logging.disable(logging.CRITICAL)

# Code under test.
from deepsearcher.llm import Bedrock
from deepsearcher.llm.base import ChatResponse


class TestBedrock(unittest.TestCase):
    """Tests for the AWS Bedrock LLM provider.

    Each test runs with ``boto3`` replaced in ``sys.modules`` by a
    ``MagicMock``. ``boto3.client(...)`` returns ``self.mock_client``, whose
    ``converse`` method returns the canned ``self.mock_response``.
    """

    def setUp(self):
        """Install a fake ``boto3`` module whose client returns a canned reply."""
        # Fake module and the client object its ``client()`` factory hands out.
        self.mock_boto3 = MagicMock()
        self.mock_client = MagicMock()
        self.mock_boto3.client = MagicMock(return_value=self.mock_client)

        # Canned ``converse`` response. The text carries a newline on purpose:
        # test_chat_single_message expects the content without it.
        self.mock_response = {
            "output": {
                "message": {
                    "content": [{"text": "Test response\nwith newline"}],
                },
            },
            "usage": {
                "totalTokens": 100,
            },
        }
        self.mock_client.converse.return_value = self.mock_response

        # Any ``import boto3`` executed while this patch is active resolves
        # to the fake module.
        self.module_patcher = patch.dict('sys.modules', {'boto3': self.mock_boto3})
        self.module_patcher.start()

    def tearDown(self):
        """Restore ``sys.modules`` to its state before ``setUp``."""
        self.module_patcher.stop()

    def test_init_default(self):
        """Default construction: default region/model/max_tokens, no credentials."""
        # Empty environment so no AWS_* variable can reach the constructor.
        with patch.dict('os.environ', {}, clear=True):
            llm = Bedrock()

            # The client must be created exactly once, with the defaults.
            self.mock_boto3.client.assert_called_once_with(
                "bedrock-runtime",
                region_name="us-west-2",
                aws_access_key_id=None,
                aws_secret_access_key=None,
                aws_session_token=None,
            )

            # Default attributes.
            self.assertEqual(llm.model, "us.deepseek.r1-v1:0")
            self.assertEqual(llm.max_tokens, 20000)

    def test_init_with_aws_credentials_from_env(self):
        """Credentials present in the environment are passed to ``boto3.client``."""
        credentials = {
            "AWS_ACCESS_KEY_ID": "test_access_key",
            "AWS_SECRET_ACCESS_KEY": "test_secret_key",
            "AWS_SESSION_TOKEN": "test_session_token",
        }
        with patch.dict(os.environ, credentials):
            # Only the constructor's side effect on the fake boto3 is checked.
            Bedrock()

            self.mock_boto3.client.assert_called_with(
                "bedrock-runtime",
                region_name="us-west-2",
                aws_access_key_id="test_access_key",
                aws_secret_access_key="test_secret_key",
                aws_session_token="test_session_token",
            )

    def test_init_with_aws_credentials_parameters(self):
        """Credentials given as keyword arguments are passed to ``boto3.client``."""
        with patch.dict('os.environ', {}, clear=True):
            Bedrock(
                aws_access_key_id="param_access_key",
                aws_secret_access_key="param_secret_key",
                aws_session_token="param_session_token",
            )

            self.mock_boto3.client.assert_called_with(
                "bedrock-runtime",
                region_name="us-west-2",
                aws_access_key_id="param_access_key",
                aws_secret_access_key="param_secret_key",
                aws_session_token="param_session_token",
            )

    def test_init_with_custom_model_and_tokens(self):
        """Custom ``model`` and ``max_tokens`` are stored on the instance."""
        with patch.dict('os.environ', {}, clear=True):
            llm = Bedrock(model="custom.model", max_tokens=1000)

            self.assertEqual(llm.model, "custom.model")
            self.assertEqual(llm.max_tokens, 1000)

    def test_init_with_custom_region(self):
        """A custom ``region_name`` is passed to ``boto3.client``."""
        with patch.dict('os.environ', {}, clear=True):
            Bedrock(region_name="us-east-1")

            self.mock_boto3.client.assert_called_with(
                "bedrock-runtime",
                region_name="us-east-1",
                aws_access_key_id=None,
                aws_secret_access_key=None,
                aws_session_token=None,
            )

    def test_chat_single_message(self):
        """A single plain-text message is wrapped and sent via ``converse``."""
        with patch.dict('os.environ', {}, clear=True):
            llm = Bedrock()
            messages = [{"role": "user", "content": "Hello"}]
            response = llm.chat(messages)

            # ``converse`` must be called once, with keyword arguments.
            self.mock_client.converse.assert_called_once()
            call_args = self.mock_client.converse.call_args
            self.assertEqual(call_args[1]["modelId"], "us.deepseek.r1-v1:0")
            self.assertEqual(
                call_args[1]["messages"],
                [{"role": "user", "content": [{"text": "Hello"}]}],
            )
            self.assertEqual(call_args[1]["inferenceConfig"], {"maxTokens": 20000})

            # The canned text contains "\n"; the expected content does not.
            self.assertIsInstance(response, ChatResponse)
            self.assertEqual(response.content, "Test responsewith newline")
            self.assertEqual(response.total_tokens, 100)

    def test_chat_multiple_messages(self):
        """Every message of a multi-turn conversation is wrapped, order kept."""
        with patch.dict('os.environ', {}, clear=True):
            llm = Bedrock()
            messages = [
                {"role": "system", "content": "You are a helpful assistant"},
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi there!"},
                {"role": "user", "content": "How are you?"},
            ]
            # The return value is covered by test_chat_single_message.
            llm.chat(messages)

            self.mock_client.converse.assert_called_once()
            call_args = self.mock_client.converse.call_args

            expected_messages = [
                {"role": "system", "content": [{"text": "You are a helpful assistant"}]},
                {"role": "user", "content": [{"text": "Hello"}]},
                {"role": "assistant", "content": [{"text": "Hi there!"}]},
                {"role": "user", "content": [{"text": "How are you?"}]},
            ]
            self.assertEqual(call_args[1]["messages"], expected_messages)

    def test_chat_with_error(self):
        """An exception raised by ``converse`` reaches the caller unchanged."""
        with patch.dict('os.environ', {}, clear=True):
            llm = Bedrock()
            # Make the fake client fail on the next ``converse`` call.
            self.mock_client.converse.side_effect = Exception("AWS Bedrock Error")

            messages = [{"role": "user", "content": "Hello"}]
            with self.assertRaises(Exception) as context:
                llm.chat(messages)

            self.assertEqual(str(context.exception), "AWS Bedrock Error")

    def test_chat_with_preformatted_messages(self):
        """Messages already in ``[{"text": ...}]`` form are sent as they are."""
        with patch.dict('os.environ', {}, clear=True):
            llm = Bedrock()
            messages = [
                {
                    "role": "user",
                    "content": [{"text": "Hello"}],
                },
            ]
            # Independent expected value: comparing the recorded argument with
            # the very list that was passed in could not detect an in-place
            # modification of the message dicts.
            expected_messages = [
                {
                    "role": "user",
                    "content": [{"text": "Hello"}],
                },
            ]
            llm.chat(messages)

            # The message format must have been preserved.
            call_args = self.mock_client.converse.call_args
            self.assertEqual(call_args[1]["messages"], expected_messages)


# Run the tests when this file is executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()