You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

105 lines
3.6 KiB

2 weeks ago
import unittest
from typing import List
from unittest.mock import patch, MagicMock
from deepsearcher.embedding.base import BaseEmbedding
from deepsearcher.loader.splitter import Chunk
class ConcreteEmbedding(BaseEmbedding):
"""A concrete implementation of BaseEmbedding for testing."""
def __init__(self, dimension=768):
self._dimension = dimension
def embed_query(self, text: str) -> List[float]:
"""Simple implementation that returns a vector of the given dimension."""
return [0.1] * self._dimension
@property
def dimension(self) -> int:
return self._dimension
class TestBaseEmbedding(unittest.TestCase):
"""Tests for the BaseEmbedding base class."""
@patch.dict('os.environ', {}, clear=True)
def test_embed_query(self):
"""Test the embed_query method."""
embedding = ConcreteEmbedding()
result = embedding.embed_query("test text")
self.assertEqual(len(result), 768)
self.assertEqual(result, [0.1] * 768)
@patch.dict('os.environ', {}, clear=True)
def test_embed_documents(self):
"""Test the embed_documents method."""
embedding = ConcreteEmbedding()
texts = ["text 1", "text 2", "text 3"]
results = embedding.embed_documents(texts)
# Check we got the right number of embeddings
self.assertEqual(len(results), 3)
# Check each embedding
for result in results:
self.assertEqual(len(result), 768)
self.assertEqual(result, [0.1] * 768)
@patch('deepsearcher.embedding.base.tqdm')
@patch.dict('os.environ', {}, clear=True)
def test_embed_chunks(self, mock_tqdm):
"""Test the embed_chunks method."""
embedding = ConcreteEmbedding()
# Set up mock tqdm to just return the iterable
mock_tqdm.return_value = lambda x, desc: x
# Create test chunks
chunks = [
Chunk(text="text 1", reference="ref1"),
Chunk(text="text 2", reference="ref2"),
Chunk(text="text 3", reference="ref3")
]
# Create a spy on embed_documents
original_embed_documents = embedding.embed_documents
embed_documents_calls = []
def mock_embed_documents(texts):
embed_documents_calls.append(texts)
return original_embed_documents(texts)
embedding.embed_documents = mock_embed_documents
# Mock tqdm to return the batch_texts directly
mock_tqdm.side_effect = lambda x, **kwargs: x
# Call the method
result_chunks = embedding.embed_chunks(chunks, batch_size=2)
# Verify embed_documents was called correctly
self.assertEqual(len(embed_documents_calls), 2) # Should be called twice with batch_size=2
self.assertEqual(embed_documents_calls[0], ["text 1", "text 2"])
self.assertEqual(embed_documents_calls[1], ["text 3"])
# Verify chunks were updated with embeddings
self.assertEqual(len(result_chunks), 3)
for chunk in result_chunks:
self.assertEqual(len(chunk.embedding), 768)
self.assertEqual(chunk.embedding, [0.1] * 768)
@patch.dict('os.environ', {}, clear=True)
def test_dimension_property(self):
"""Test the dimension property."""
embedding = ConcreteEmbedding()
self.assertEqual(embedding.dimension, 768)
# Test with different dimension
embedding = ConcreteEmbedding(dimension=1024)
self.assertEqual(embedding.dimension, 1024)
if __name__ == "__main__":
unittest.main()