You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

157 lines
5.3 KiB

import unittest
import numpy as np
from typing import List
from deepsearcher.vector_db.base import (
RetrievalResult,
deduplicate_results,
CollectionInfo,
BaseVectorDB,
)
from deepsearcher.loader.splitter import Chunk
class TestRetrievalResult(unittest.TestCase):
    """Unit tests for constructing and printing ``RetrievalResult`` objects."""

    def setUp(self):
        """Create the field values shared by every test in this class."""
        self.embedding = np.array([0.1, 0.2, 0.3])
        self.text = "Test text"
        self.reference = "test.txt"
        self.metadata = {"key": "value"}
        self.score = 0.95

    def _make_result(self, **overrides):
        """Build a ``RetrievalResult`` from the fixture values.

        Args:
            **overrides: Extra keyword arguments (for example ``score``)
                forwarded unchanged to the ``RetrievalResult`` constructor.

        Returns:
            The newly constructed ``RetrievalResult``.
        """
        return RetrievalResult(
            embedding=self.embedding,
            text=self.text,
            reference=self.reference,
            metadata=self.metadata,
            **overrides,
        )

    def test_init(self):
        """Every constructor argument is readable back as an attribute."""
        result = self._make_result(score=self.score)

        # assert_array_equal reports the mismatching elements on failure,
        # unlike assertTrue(np.array_equal(...)) which only says
        # "False is not true".
        np.testing.assert_array_equal(result.embedding, self.embedding)
        self.assertEqual(result.text, self.text)
        self.assertEqual(result.reference, self.reference)
        self.assertEqual(result.metadata, self.metadata)
        self.assertEqual(result.score, self.score)

    def test_init_default_score(self):
        """Omitting ``score`` leaves it at 0.0."""
        result = self._make_result()
        self.assertEqual(result.score, 0.0)

    def test_repr(self):
        """``repr`` renders score, embedding, text, reference and metadata."""
        result = self._make_result(score=self.score)

        # NOTE(review): the ")" sits before ", metadata=" in the expected
        # text; presumably this mirrors the current __repr__ output exactly.
        # Confirm against the implementation before "fixing" it here.
        expected = (
            f"RetrievalResult(score={self.score}, embedding={self.embedding}, "
            f"text={self.text}, reference={self.reference}), metadata={self.metadata}"
        )
        self.assertEqual(repr(result), expected)
class TestDeduplicateResults(unittest.TestCase):
    """Exercise ``deduplicate_results`` on unique, repeated and empty input."""

    def setUp(self):
        """Prepare two distinct embedding/text pairs plus shared reference data."""
        self.embedding1, self.embedding2 = (
            np.array([0.1, 0.2, 0.3]),
            np.array([0.4, 0.5, 0.6]),
        )
        self.text1, self.text2 = "Text 1", "Text 2"
        self.reference = "test.txt"
        self.metadata = {"key": "value"}

    def test_no_duplicates(self):
        """Input without repeats comes back with the same two entries."""
        first = RetrievalResult(self.embedding1, self.text1, self.reference, self.metadata)
        second = RetrievalResult(self.embedding2, self.text2, self.reference, self.metadata)
        candidates = [first, second]

        unique = deduplicate_results(candidates)

        self.assertEqual(len(unique), 2)
        self.assertEqual(unique, candidates)

    def test_with_duplicates(self):
        """A repeated entry is dropped; the two remaining texts keep their order."""
        fields = (
            (self.embedding1, self.text1),
            (self.embedding2, self.text2),
            (self.embedding1, self.text1),  # repeat of the first entry
        )
        candidates = [
            RetrievalResult(vector, content, self.reference, self.metadata)
            for vector, content in fields
        ]

        unique = deduplicate_results(candidates)

        self.assertEqual(len(unique), 2)
        first_kept, second_kept = unique[0], unique[1]
        self.assertEqual(first_kept.text, self.text1)
        self.assertEqual(second_kept.text, self.text2)

    def test_empty_list(self):
        """An empty input list yields a result of length zero."""
        unique = deduplicate_results([])
        self.assertEqual(len(unique), 0)
class TestCollectionInfo(unittest.TestCase):
    """Check that ``CollectionInfo`` exposes its constructor arguments."""

    def test_init(self):
        """The name and description passed in are readable as attributes."""
        expected_name = "test_collection"
        expected_description = "Test collection description"

        info = CollectionInfo(expected_name, expected_description)

        self.assertEqual(info.collection_name, expected_name)
        self.assertEqual(info.description, expected_description)
class MockVectorDB(BaseVectorDB):
    """No-op ``BaseVectorDB`` subclass used as a test double.

    Every operation ignores its arguments, so tests can exercise the
    behavior inherited from ``BaseVectorDB`` without a real backend.
    """

    def init_collection(self, dim, collection, description, force_new_collection=False, *args, **kwargs):
        """Accept the collection parameters and do nothing."""
        return None

    def insert_data(self, collection, chunks, *args, **kwargs):
        """Accept the chunks and discard them."""
        return None

    def search_data(self, collection, vector, *args, **kwargs) -> List[RetrievalResult]:
        """Always report no matches by returning a new empty list."""
        return []

    def clear_db(self, *args, **kwargs):
        """Do nothing; there is no stored state to clear."""
        return None
class TestBaseVectorDB(unittest.TestCase):
    """Cover the behavior ``MockVectorDB`` inherits from ``BaseVectorDB``."""

    def setUp(self):
        """Instantiate the mock database with default constructor arguments."""
        self.db = MockVectorDB()

    def test_init_default(self):
        """Without arguments the default collection is ``"deepsearcher"``."""
        actual = self.db.default_collection
        self.assertEqual(actual, "deepsearcher")

    def test_init_custom_collection(self):
        """A ``default_collection`` passed to the constructor is stored as given."""
        wanted = "custom_collection"
        custom_db = MockVectorDB(default_collection=wanted)
        self.assertEqual(custom_db.default_collection, wanted)

    def test_list_collections_default(self):
        """The inherited ``list_collections`` returns ``None``."""
        listing = self.db.list_collections()
        self.assertIsNone(listing)
# Run the suite when this file is executed directly as a script;
# importing the module does not trigger this call.
if __name__ == "__main__":
    unittest.main()