You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
148 lines
5.1 KiB
148 lines
5.1 KiB
import unittest
|
|
import json
|
|
import os
|
|
from unittest.mock import patch, MagicMock
|
|
import logging
|
|
|
|
# Suppress every log record at CRITICAL severity and below for the whole
# process, so logging from the code under test does not clutter test output.
logging.disable(level=logging.CRITICAL)
|
|
|
|
from deepsearcher.embedding import BedrockEmbedding
|
|
from deepsearcher.embedding.bedrock_embedding import (
|
|
MODEL_ID_TITAN_TEXT_V2,
|
|
MODEL_ID_TITAN_TEXT_G1,
|
|
MODEL_ID_COHERE_ENGLISH_V3,
|
|
)
|
|
|
|
|
|
class TestBedrockEmbedding(unittest.TestCase):
    """Tests for the BedrockEmbedding class.

    For the duration of each test, ``sys.modules["boto3"]`` is replaced by a
    MagicMock whose ``client(...)`` factory returns one shared mock client.
    That client's ``invoke_model`` returns a canned response whose body reads
    as a JSON document holding a 1024-element embedding of ``0.1`` values.
    """

    def setUp(self):
        """Install the mock ``boto3`` module and the canned Bedrock response."""
        # Mock module plus the single client its ``client(...)`` factory hands
        # out, so tests can inspect both the factory call and client calls.
        self.mock_boto3 = MagicMock()
        self.mock_client = MagicMock()
        self.mock_boto3.client = MagicMock(return_value=self.mock_client)

        # While active, ``import boto3`` resolves to the mock module.
        # NOTE(review): this only intercepts boto3 if BedrockEmbedding imports
        # it lazily (after this patch starts) — confirm against the module.
        self.module_patcher = patch.dict("sys.modules", {"boto3": self.mock_boto3})
        self.module_patcher.start()

        # Canned response: ``body.read()`` yields a JSON string carrying a
        # 1024-dimensional embedding.
        self.mock_response = {
            "body": MagicMock(),
            "ResponseMetadata": {"HTTPStatusCode": 200},
        }
        self.mock_response["body"].read.return_value = json.dumps({"embedding": [0.1] * 1024})
        self.mock_client.invoke_model.return_value = self.mock_response

    def tearDown(self):
        """Restore the ``sys.modules`` entries replaced in setUp."""
        self.module_patcher.stop()

    @patch.dict("os.environ", {}, clear=True)
    def test_init_default(self):
        """Default construction: us-east-1, no credentials, Titan Text V2."""
        embedding = BedrockEmbedding()

        # The environment is cleared, so no credentials should be forwarded.
        self.mock_boto3.client.assert_called_once_with(
            "bedrock-runtime",
            region_name="us-east-1",
            aws_access_key_id=None,
            aws_secret_access_key=None,
        )
        self.assertEqual(embedding.model, MODEL_ID_TITAN_TEXT_V2)

    @patch.dict(
        "os.environ",
        {"AWS_ACCESS_KEY_ID": "test_key", "AWS_SECRET_ACCESS_KEY": "test_secret"},
        clear=True,
    )
    def test_init_with_credentials(self):
        """AWS credentials from the environment are passed to ``boto3.client``."""
        BedrockEmbedding()

        # ``assert_called_once_with`` rather than ``assert_called_with`` (which
        # only inspects the most recent call): a second client construction
        # would be a regression, exactly as in test_init_default.
        self.mock_boto3.client.assert_called_once_with(
            "bedrock-runtime",
            region_name="us-east-1",
            aws_access_key_id="test_key",
            aws_secret_access_key="test_secret",
        )

    @patch.dict("os.environ", {}, clear=True)
    def test_init_with_different_models(self):
        """An explicitly requested model id is stored on the instance."""
        for model_id in (MODEL_ID_TITAN_TEXT_G1, MODEL_ID_COHERE_ENGLISH_V3):
            with self.subTest(model=model_id):
                embedding = BedrockEmbedding(model=model_id)
                self.assertEqual(embedding.model, model_id)

    @patch.dict("os.environ", {}, clear=True)
    def test_embed_query(self):
        """A single query issues one ``invoke_model`` call and returns its embedding."""
        embedding = BedrockEmbedding()

        query = "test query"
        result = embedding.embed_query(query)

        # Request shape expected for the default (Titan Text V2) model.
        self.mock_client.invoke_model.assert_called_once_with(
            modelId=MODEL_ID_TITAN_TEXT_V2,
            body=json.dumps({"inputText": query}),
        )

        self.assertEqual(len(result), 1024)
        self.assertEqual(result, [0.1] * 1024)

    @patch.dict("os.environ", {}, clear=True)
    def test_embed_documents(self):
        """Each document is embedded with its own ``invoke_model`` call."""
        embedding = BedrockEmbedding()

        texts = ["text 1", "text 2", "text 3"]
        results = embedding.embed_documents(texts)

        # One request per text; call order is deliberately not asserted.
        self.assertEqual(self.mock_client.invoke_model.call_count, len(texts))
        for text in texts:
            self.mock_client.invoke_model.assert_any_call(
                modelId=MODEL_ID_TITAN_TEXT_V2,
                body=json.dumps({"inputText": text}),
            )

        self.assertEqual(len(results), len(texts))
        for result in results:
            self.assertEqual(len(result), 1024)
            self.assertEqual(result, [0.1] * 1024)

    @patch.dict("os.environ", {}, clear=True)
    def test_dimension_property(self):
        """``dimension`` reports the embedding size of the configured model."""
        # Default model (Titan Text V2), exercised via the no-argument path.
        self.assertEqual(BedrockEmbedding().dimension, 1024)

        expected_dimensions = {
            MODEL_ID_TITAN_TEXT_G1: 1536,
            MODEL_ID_COHERE_ENGLISH_V3: 1024,
        }
        for model_id, expected in expected_dimensions.items():
            with self.subTest(model=model_id):
                self.assertEqual(BedrockEmbedding(model=model_id).dimension, expected)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: discover and run every TestCase defined in this
    # module (``module="__main__"`` is unittest.main's default, made explicit).
    unittest.main(module="__main__")
|