# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 239 lines
# 8.6 KiB

import unittest
import sys
from unittest.mock import patch, MagicMock
from deepsearcher.embedding import OllamaEmbedding
@patch.dict("os.environ", {}, clear=True)
class TestOllamaEmbedding(unittest.TestCase):
    """Tests for the OllamaEmbedding class.

    Every test runs against a mocked ``ollama`` module: ``setUp`` places a
    ``MagicMock`` under ``sys.modules["ollama"]`` whose ``Client`` attribute is
    ``self.mock_ollama_client`` and whose instances are ``self.mock_client``.

    The class-level ``patch.dict`` empties ``os.environ`` for the duration of
    each ``test*`` method (``patch.dict`` used as a class decorator wraps every
    method whose name starts with ``patch.TEST_PREFIX``).
    """

    @staticmethod
    def _vector(a, b, c, last):
        """Return a fresh 1024-element list: ``[a, b, c]`` * 341 plus ``last``.

        A new list is built on every call so tests never share mutable
        fixture data.
        """
        return [a, b, c] * 341 + [last]

    def setUp(self):
        """Install the mocked ``ollama`` module and default client behaviour."""
        # Client instance handed out by the mocked ``ollama.Client(...)``.
        self.mock_client = MagicMock()
        self.mock_client.embed.return_value = {"embeddings": [[0.1] * 1024]}

        # Mocked ``Client`` class and the module that exposes it.
        self.mock_ollama_client = MagicMock(return_value=self.mock_client)
        mock_ollama_module = MagicMock()
        mock_ollama_module.Client = self.mock_ollama_client

        # NOTE(review): OllamaEmbedding is imported at module load, before this
        # patch is active, so it presumably imports ``ollama`` lazily when it
        # is constructed -- confirm against the implementation.
        self.module_patcher = patch.dict("sys.modules", {"ollama": mock_ollama_module})
        self.module_patcher.start()
        # Registered immediately after start() so the patch is undone even if
        # a later setUp step raises (tearDown would be skipped in that case).
        self.addCleanup(self.module_patcher.stop)

    def test_init_default(self):
        """Default construction uses the local host URL and default attributes."""
        embedding = OllamaEmbedding(model="bge-m3")

        self.mock_ollama_client.assert_called_once_with(host="http://localhost:11434/")
        self.assertEqual(embedding.model, "bge-m3")
        self.assertEqual(embedding.dim, 1024)
        self.assertEqual(embedding.batch_size, 32)

    def test_init_with_base_url(self):
        """A custom ``base_url`` is forwarded to the client as ``host``."""
        OllamaEmbedding(base_url="http://custom-ollama-server:11434/")

        self.mock_ollama_client.assert_called_once_with(
            host="http://custom-ollama-server:11434/"
        )

    def test_init_with_model_name(self):
        """``model_name`` selects the model and its associated dimension."""
        embedding = OllamaEmbedding(model_name="mxbai-embed-large")

        self.assertEqual(embedding.model, "mxbai-embed-large")
        # NOTE(review): 768 is the value this suite expects the class to report
        # for this model -- verify against the implementation's dimension table.
        self.assertEqual(embedding.dim, 768)

    def test_init_with_dimension(self):
        """An explicit ``dimension`` is stored on the instance."""
        embedding = OllamaEmbedding(dimension=512)

        self.assertEqual(embedding.dim, 512)

    def test_embed_query(self):
        """A single query is sent to ``embed`` and its first embedding returned."""
        embedding = OllamaEmbedding(model="bge-m3")
        expected = self._vector(0.1, 0.2, 0.3, 0.4)
        self.mock_client.embed.return_value = {"embeddings": [expected]}

        result = embedding.embed_query("test query")

        self.mock_client.embed.assert_called_once_with(model="bge-m3", input="test query")
        self.assertEqual(len(result), 1024)
        self.assertEqual(result, self._vector(0.1, 0.2, 0.3, 0.4))

    def test_embed_documents_small_batch(self):
        """Fewer texts than ``batch_size`` are embedded with one ``embed`` call."""
        embedding = OllamaEmbedding(model="bge-m3")
        mock_embeddings = [
            self._vector(0.1, 0.2, 0.3, 0.4),
            self._vector(0.4, 0.5, 0.6, 0.7),
            self._vector(0.7, 0.8, 0.9, 0.1),
        ]
        self.mock_client.embed.return_value = {"embeddings": mock_embeddings}
        texts = ["text 1", "text 2", "text 3"]

        results = embedding.embed_documents(texts)

        self.mock_client.embed.assert_called_once_with(model="bge-m3", input=texts)
        self.assertEqual(len(results), 3)
        for expected, actual in zip(mock_embeddings, results):
            self.assertEqual(len(actual), 1024)
            self.assertEqual(actual, expected)

    def test_embed_documents_large_batch(self):
        """More texts than ``batch_size`` are split into ordered ``embed`` calls."""
        embedding = OllamaEmbedding(model="bge-m3")
        embedding.batch_size = 2  # force three texts into two batches

        first_batch = [
            self._vector(0.1, 0.2, 0.3, 0.4),
            self._vector(0.4, 0.5, 0.6, 0.7),
        ]
        second_batch = [self._vector(0.7, 0.8, 0.9, 0.1)]
        # One response per expected call, consumed in call order.
        self.mock_client.embed.side_effect = [
            {"embeddings": first_batch},
            {"embeddings": second_batch},
        ]
        texts = ["text 1", "text 2", "text 3"]

        results = embedding.embed_documents(texts)

        self.assertEqual(self.mock_client.embed.call_count, 2)
        # assert_has_calls checks order as well as arguments; the result
        # assertions below only make sense if the batches went out in order.
        self.mock_client.embed.assert_has_calls(
            [
                unittest.mock.call(model="bge-m3", input=["text 1", "text 2"]),
                unittest.mock.call(model="bge-m3", input=["text 3"]),
            ]
        )
        self.assertEqual(len(results), 3)
        self.assertEqual(results[0], first_batch[0])
        self.assertEqual(results[1], first_batch[1])
        self.assertEqual(results[2], second_batch[0])

    def test_embed_documents_no_batching(self):
        """With ``batch_size == 0`` each text goes through ``embed_query``."""
        embedding = OllamaEmbedding(model="bge-m3")
        embedding.batch_size = 0  # disable batching

        embed_query_calls = []

        def fake_embed_query(text):
            """Record the text and return a constant 1024-element embedding."""
            embed_query_calls.append(text)
            return [0.1] * 1024

        texts = ["text 1", "text 2", "text 3"]

        # patch.object restores the real method on exit, even if the call raises.
        with patch.object(embedding, "embed_query", new=fake_embed_query):
            results = embedding.embed_documents(texts)

        self.assertEqual(len(embed_query_calls), 3)
        self.assertEqual(embed_query_calls, texts)
        self.assertEqual(len(results), 3)
        for result in results:
            self.assertEqual(len(result), 1024)
            self.assertEqual(result, [0.1] * 1024)

    def test_dimension_property(self):
        """``dimension`` reflects the chosen model or an explicit override."""
        cases = [
            ({"model": "bge-m3"}, 1024),
            ({"model": "mxbai-embed-large"}, 768),
            ({"model": "nomic-embed-text"}, 768),
            ({"dimension": 512}, 512),
        ]
        for kwargs, expected in cases:
            # subTest reports each configuration separately and keeps going
            # after a failure instead of stopping at the first mismatch.
            with self.subTest(**kwargs):
                embedding = OllamaEmbedding(**kwargs)
                self.assertEqual(embedding.dimension, expected)
if __name__ == "__main__":
    # Run the test suite when this file is executed directly as a script.
    unittest.main()