# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

# 284 lines
# 11 KiB

import unittest
from unittest.mock import MagicMock, patch, ANY
import os
class TestWatsonXEmbedding(unittest.TestCase):
    """Test cases for WatsonXEmbedding class.

    Every test patches the ``Credentials`` and ``Embeddings`` names looked up
    in ``deepsearcher.embedding.watsonx_embedding`` with mocks.  ``os.environ``
    is replaced with ``clear=True`` for the duration of each test, so only the
    variables listed by the test are visible and WATSONX_* values exported in
    the surrounding shell or CI job cannot change the outcome.

    Decorator order matters: ``patch`` decorators are applied bottom-up, so the
    ``Credentials`` mock is the first injected argument and the ``Embeddings``
    mock the second.  ``patch.dict`` injects no argument.
    """

    # Model id the tests expect when no ``model`` argument is passed.
    DEFAULT_MODEL = 'ibm/slate-125m-english-rtrvr-v2'

    # API key and URL only: no project id and no space id.
    _ENV_NO_PROJECT = {
        'WATSONX_APIKEY': 'test-api-key',
        'WATSONX_URL': 'https://test.watsonx.com',
    }

    # Fully configured environment (API key, URL and project id).
    _ENV_FULL = {**_ENV_NO_PROJECT, 'WATSONX_PROJECT_ID': 'test-project-id'}

    @staticmethod
    def _wire_mocks(credentials_cls, embeddings_cls):
        """Make both patched classes return fresh ``MagicMock`` instances.

        Args:
            credentials_cls: The mock standing in for ``Credentials``.
            embeddings_cls: The mock standing in for ``Embeddings``.

        Returns:
            A ``(credentials_instance, embeddings_instance)`` tuple holding
            the objects the two patched constructors will now return.
        """
        credentials_instance = MagicMock()
        embeddings_instance = MagicMock()
        credentials_cls.return_value = credentials_instance
        embeddings_cls.return_value = embeddings_instance
        return credentials_instance, embeddings_instance

    @patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
    @patch('deepsearcher.embedding.watsonx_embedding.Credentials')
    @patch.dict('os.environ', _ENV_FULL, clear=True)
    def test_init_with_env_vars(self, mock_credentials_class, mock_embeddings_class):
        """Test initialization with environment variables."""
        from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding

        credentials_instance, _ = self._wire_mocks(
            mock_credentials_class, mock_embeddings_class
        )

        embedding = WatsonXEmbedding()

        # Credentials must be built from the URL and API key in the environment.
        mock_credentials_class.assert_called_once_with(
            url='https://test.watsonx.com',
            api_key='test-api-key',
        )
        # Embeddings must receive the default model, the credentials object
        # created above, and the project id from the environment.
        mock_embeddings_class.assert_called_once_with(
            model_id=self.DEFAULT_MODEL,
            credentials=credentials_instance,
            project_id='test-project-id',
        )
        # Check default model and dimension.
        self.assertEqual(embedding.model, self.DEFAULT_MODEL)
        self.assertEqual(embedding.dimension, 768)

    @patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
    @patch('deepsearcher.embedding.watsonx_embedding.Credentials')
    @patch.dict('os.environ', _ENV_NO_PROJECT, clear=True)
    def test_init_with_space_id(self, mock_credentials_class, mock_embeddings_class):
        """Test initialization with space_id instead of project_id."""
        from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding

        credentials_instance, _ = self._wire_mocks(
            mock_credentials_class, mock_embeddings_class
        )

        WatsonXEmbedding(space_id='test-space-id')

        # With no project id in the (cleared) environment, the space id passed
        # by the caller is what must reach Embeddings.
        mock_embeddings_class.assert_called_once_with(
            model_id=self.DEFAULT_MODEL,
            credentials=credentials_instance,
            space_id='test-space-id',
        )

    @patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
    @patch('deepsearcher.embedding.watsonx_embedding.Credentials')
    @patch.dict('os.environ', {}, clear=True)
    def test_init_missing_api_key(self, mock_credentials_class, mock_embeddings_class):
        """Test initialization with missing API key."""
        from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding

        with self.assertRaises(ValueError) as context:
            WatsonXEmbedding()

        # The error message must name the missing variable.
        self.assertIn("WATSONX_APIKEY", str(context.exception))

    @patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
    @patch('deepsearcher.embedding.watsonx_embedding.Credentials')
    @patch.dict('os.environ', {'WATSONX_APIKEY': 'test-api-key'}, clear=True)
    def test_init_missing_url(self, mock_credentials_class, mock_embeddings_class):
        """Test initialization with missing URL."""
        from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding

        with self.assertRaises(ValueError) as context:
            WatsonXEmbedding()

        self.assertIn("WATSONX_URL", str(context.exception))

    @patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
    @patch('deepsearcher.embedding.watsonx_embedding.Credentials')
    @patch.dict('os.environ', _ENV_NO_PROJECT, clear=True)
    def test_init_missing_project_and_space_id(self, mock_credentials_class, mock_embeddings_class):
        """Test initialization with missing both project_id and space_id."""
        from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding

        with self.assertRaises(ValueError) as context:
            WatsonXEmbedding()

        self.assertIn("WATSONX_PROJECT_ID", str(context.exception))

    @patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
    @patch('deepsearcher.embedding.watsonx_embedding.Credentials')
    @patch.dict('os.environ', _ENV_FULL, clear=True)
    def test_embed_query(self, mock_credentials_class, mock_embeddings_class):
        """Test embedding a single query."""
        from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding

        _, embeddings_instance = self._wire_mocks(
            mock_credentials_class, mock_embeddings_class
        )
        # The mocked embed_query hands back the bare vector, not a dict
        # wrapping it.
        embeddings_instance.embed_query.return_value = [0.1] * 768

        embedding = WatsonXEmbedding()
        query = "This is a test query"

        result = embedding.embed_query(query)

        # The query must be forwarded as the ``text`` keyword argument.
        embeddings_instance.embed_query.assert_called_once_with(text=query)
        self.assertEqual(result, [0.1] * 768)

    @patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
    @patch('deepsearcher.embedding.watsonx_embedding.Credentials')
    @patch.dict('os.environ', _ENV_FULL, clear=True)
    def test_embed_documents(self, mock_credentials_class, mock_embeddings_class):
        """Test embedding multiple documents."""
        from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding

        _, embeddings_instance = self._wire_mocks(
            mock_credentials_class, mock_embeddings_class
        )
        # The mocked embed_documents returns one bare vector per document.
        embeddings_instance.embed_documents.return_value = [
            [0.1] * 768,
            [0.2] * 768,
            [0.3] * 768,
        ]

        embedding = WatsonXEmbedding()
        documents = ["Document 1", "Document 2", "Document 3"]

        results = embedding.embed_documents(documents)

        # The documents must be forwarded as the ``texts`` keyword argument.
        embeddings_instance.embed_documents.assert_called_once_with(texts=documents)
        self.assertEqual(len(results), 3)
        self.assertEqual(results[0], [0.1] * 768)
        self.assertEqual(results[1], [0.2] * 768)
        self.assertEqual(results[2], [0.3] * 768)

    @patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
    @patch('deepsearcher.embedding.watsonx_embedding.Credentials')
    @patch.dict('os.environ', _ENV_FULL, clear=True)
    def test_dimension_property(self, mock_credentials_class, mock_embeddings_class):
        """Test the dimension property for different models."""
        from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding

        self._wire_mocks(mock_credentials_class, mock_embeddings_class)

        # Default model.
        embedding = WatsonXEmbedding()
        self.assertEqual(embedding.dimension, 768)

        # A different known model.
        embedding = WatsonXEmbedding(model='ibm/slate-30m-english-rtrvr')
        self.assertEqual(embedding.dimension, 384)

        # An unknown model is expected to fall back to 768.
        embedding = WatsonXEmbedding(model='unknown-model')
        self.assertEqual(embedding.dimension, 768)

    @patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
    @patch('deepsearcher.embedding.watsonx_embedding.Credentials')
    @patch.dict('os.environ', _ENV_FULL, clear=True)
    def test_embed_query_error_handling(self, mock_credentials_class, mock_embeddings_class):
        """Test error handling in embed_query."""
        from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding

        _, embeddings_instance = self._wire_mocks(
            mock_credentials_class, mock_embeddings_class
        )
        embeddings_instance.embed_query.side_effect = Exception("API Error")

        embedding = WatsonXEmbedding()

        # The underlying exception must surface as a RuntimeError.
        with self.assertRaises(RuntimeError) as context:
            embedding.embed_query("test")

        self.assertIn("Error embedding query with WatsonX", str(context.exception))

    @patch('deepsearcher.embedding.watsonx_embedding.Embeddings')
    @patch('deepsearcher.embedding.watsonx_embedding.Credentials')
    @patch.dict('os.environ', _ENV_FULL, clear=True)
    def test_embed_documents_error_handling(self, mock_credentials_class, mock_embeddings_class):
        """Test error handling in embed_documents."""
        from deepsearcher.embedding.watsonx_embedding import WatsonXEmbedding

        _, embeddings_instance = self._wire_mocks(
            mock_credentials_class, mock_embeddings_class
        )
        embeddings_instance.embed_documents.side_effect = Exception("API Error")

        embedding = WatsonXEmbedding()

        # The underlying exception must surface as a RuntimeError.
        with self.assertRaises(RuntimeError) as context:
            embedding.embed_documents(["test"])

        self.assertIn("Error embedding documents with WatsonX", str(context.exception))
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()