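"""Unit tests for deepsearcher.embedding.SentenceTransformerEmbedding.

The real sentence_transformers package is replaced with a MagicMock module in
setUp, so these tests run without the library installed.
"""
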
import unittest
import sys
import logging
from unittest.mock import patch, MagicMock

# Disable logging for tests
logging.disable(logging.CRITICAL)

from deepsearcher.embedding import SentenceTransformerEmbedding


class TestSentenceTransformerEmbedding(unittest.TestCase):
    """Tests for the SentenceTransformerEmbedding class."""

    def setUp(self):
        """Set up test fixtures."""
        # Create a mock module for sentence_transformers
        mock_st_module = MagicMock()

        # Create a mock SentenceTransformer class
        self.mock_sentence_transformer = MagicMock()
        mock_st_module.SentenceTransformer = self.mock_sentence_transformer

        # Add the mock module to sys.modules so importing sentence_transformers resolves to the mock
        self.module_patcher = patch.dict('sys.modules', {'sentence_transformers': mock_st_module})
        self.module_patcher.start()

        # Set up the mock model instance returned by SentenceTransformer(...)
        self.mock_model = MagicMock()
        self.mock_sentence_transformer.return_value = self.mock_model

        # Configure the mock encode method
        mock_embedding = [[0.1, 0.2, 0.3] * 341 + [0.4]]  # 1024 dimensions
        self.mock_model.encode.return_value = MagicMock()
        self.mock_model.encode.return_value.tolist.return_value = mock_embedding

    def tearDown(self):
        """Clean up test fixtures."""
        self.module_patcher.stop()

    @patch.dict('os.environ', {}, clear=True)
    def test_init(self):
        """Test initialization."""
        # Create instance to test
        embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3")

        # Check that SentenceTransformer was called with the right model
        self.mock_sentence_transformer.assert_called_once_with("BAAI/bge-m3")

        # Check that model and client were set correctly
        self.assertEqual(embedding.model, "BAAI/bge-m3")
        self.assertEqual(embedding.client, self.mock_model)

        # Check batch size default
        self.assertEqual(embedding.batch_size, 32)

        # Test with model_name parameter
        self.mock_sentence_transformer.reset_mock()
        embedding = SentenceTransformerEmbedding(model_name="BAAI/bge-large-zh-v1.5")
        self.mock_sentence_transformer.assert_called_once_with("BAAI/bge-large-zh-v1.5")
        self.assertEqual(embedding.model, "BAAI/bge-large-zh-v1.5")

        # Test with custom batch size
        self.mock_sentence_transformer.reset_mock()
        embedding = SentenceTransformerEmbedding(batch_size=64)
        self.assertEqual(embedding.batch_size, 64)

    @patch.dict('os.environ', {}, clear=True)
    def test_embed_query(self):
        """Test embedding a single query."""
        # Create instance to test
        embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3")

        # Mock the encode response for a single query
        single_embedding = [0.1, 0.2, 0.3] * 341 + [0.4]  # 1024 dimensions
        self.mock_model.encode.return_value = MagicMock()
        self.mock_model.encode.return_value.tolist.return_value = [single_embedding]

        # Call the method
        result = embedding.embed_query("test query")

        # Verify encode was called correctly
        self.mock_model.encode.assert_called_once_with("test query")

        # Check the result
        self.assertEqual(len(result), 1024)
        self.assertEqual(result, single_embedding)

    @patch.dict('os.environ', {}, clear=True)
    def test_embed_documents_small_batch(self):
        """Test embedding documents with a small batch (fewer texts than the batch size)."""
        # Create instance to test
        embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3")

        # Mock the encode response for documents
        batch_embeddings = [
            [0.1, 0.2, 0.3] * 341 + [0.4],  # 1024 dimensions
            [0.4, 0.5, 0.6] * 341 + [0.7],
            [0.7, 0.8, 0.9] * 341 + [0.1],
        ]
        self.mock_model.encode.return_value = MagicMock()
        self.mock_model.encode.return_value.tolist.return_value = batch_embeddings

        # Create test texts
        texts = ["text 1", "text 2", "text 3"]

        # Call the method
        results = embedding.embed_documents(texts)

        # Verify encode was called correctly
        self.mock_model.encode.assert_called_once_with(texts)

        # Check the results
        self.assertEqual(len(results), 3)
        for i, result in enumerate(results):
            self.assertEqual(len(result), 1024)
            self.assertEqual(result, batch_embeddings[i])

    @patch.dict('os.environ', {}, clear=True)
    def test_embed_documents_large_batch(self):
        """Test embedding documents with a large batch (more texts than the batch size)."""
        # Create instance to test with a small batch size
        embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3", batch_size=2)

        # Mock the encode response for the first batch
        batch1_embeddings = [
            [0.1, 0.2, 0.3] * 341 + [0.4],  # 1024 dimensions
            [0.4, 0.5, 0.6] * 341 + [0.7],
        ]
        # Mock the encode response for the second batch
        batch2_embeddings = [
            [0.7, 0.8, 0.9] * 341 + [0.1],
        ]

        # Set up the mock to return different values on each call
        self.mock_model.encode.side_effect = [
            MagicMock(tolist=lambda: batch1_embeddings),
            MagicMock(tolist=lambda: batch2_embeddings),
        ]

        # Create test texts
        texts = ["text 1", "text 2", "text 3"]

        # Call the method
        results = embedding.embed_documents(texts)

        # Verify encode was called twice with the right batches
        self.assertEqual(self.mock_model.encode.call_count, 2)
        self.mock_model.encode.assert_any_call(["text 1", "text 2"])
        self.mock_model.encode.assert_any_call(["text 3"])

        # Check the results
        self.assertEqual(len(results), 3)
        self.assertEqual(results[0], batch1_embeddings[0])
        self.assertEqual(results[1], batch1_embeddings[1])
        self.assertEqual(results[2], batch2_embeddings[0])

    @patch.dict('os.environ', {}, clear=True)
    def test_embed_documents_no_batching(self):
        """Test embedding documents with batching disabled."""
        # Create instance to test with batching disabled
        embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3", batch_size=0)

        # Replace embed_query with a stub that records the texts it is called with
        original_embed_query = embedding.embed_query
        embed_query_calls = []

        def mock_embed_query(text):
            embed_query_calls.append(text)
            return [0.1] * 1024  # Return a simple mock embedding

        embedding.embed_query = mock_embed_query

        # Create test texts
        texts = ["text 1", "text 2", "text 3"]

        # Call the method
        results = embedding.embed_documents(texts)

        # Check that embed_query was called once per text, in order
        self.assertEqual(len(embed_query_calls), 3)
        self.assertEqual(embed_query_calls, texts)

        # Check the results
        self.assertEqual(len(results), 3)
        for result in results:
            self.assertEqual(len(result), 1024)
            self.assertEqual(result, [0.1] * 1024)

        # Restore the original method
        embedding.embed_query = original_embed_query

    @patch.dict('os.environ', {}, clear=True)
    def test_dimension_property(self):
        """Test the dimension property."""
        # Create instance to test
        embedding = SentenceTransformerEmbedding(model="BAAI/bge-m3")

        # Check dimension for BAAI/bge-m3
        self.assertEqual(embedding.dimension, 1024)

        # Test with different models
        self.mock_sentence_transformer.reset_mock()
        embedding = SentenceTransformerEmbedding(model="BAAI/bge-large-zh-v1.5")
        self.assertEqual(embedding.dimension, 1024)

        self.mock_sentence_transformer.reset_mock()
        embedding = SentenceTransformerEmbedding(model="BAAI/bge-large-en-v1.5")
        self.assertEqual(embedding.dimension, 1024)


if __name__ == "__main__":
    unittest.main()