You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
157 lines
5.3 KiB
157 lines
5.3 KiB
import unittest
|
|
import numpy as np
|
|
from typing import List
|
|
|
|
from deepsearcher.vector_db.base import (
|
|
RetrievalResult,
|
|
deduplicate_results,
|
|
CollectionInfo,
|
|
BaseVectorDB,
|
|
)
|
|
from deepsearcher.loader.splitter import Chunk
|
|
|
|
|
|
class TestRetrievalResult(unittest.TestCase):
    """Unit tests covering construction and repr of RetrievalResult."""

    def setUp(self):
        """Prepare the field values shared by every test in this class."""
        self.score = 0.95
        self.metadata = {"key": "value"}
        self.reference = "test.txt"
        self.text = "Test text"
        self.embedding = np.array([0.1, 0.2, 0.3])

    def _make_result(self, **extra):
        """Build a RetrievalResult from the shared fixtures.

        Any keyword arguments in ``extra`` (e.g. ``score``) are forwarded
        after the four base fields.
        """
        return RetrievalResult(
            embedding=self.embedding,
            text=self.text,
            reference=self.reference,
            metadata=self.metadata,
            **extra,
        )

    def test_init(self):
        """Every constructor argument is stored on the matching attribute."""
        built = self._make_result(score=self.score)

        # The embedding is a numpy array, so it needs an element-wise comparison.
        self.assertTrue(np.array_equal(built.embedding, self.embedding))
        for attr in ("text", "reference", "metadata", "score"):
            self.assertEqual(getattr(built, attr), getattr(self, attr))

    def test_init_default_score(self):
        """Leaving out ``score`` yields a score of 0.0."""
        built = self._make_result()
        self.assertEqual(built.score, 0.0)

    def test_repr(self):
        """repr() of a result matches the expected formatted string."""
        built = self._make_result(score=self.score)
        expected = f"RetrievalResult(score={self.score}, embedding={self.embedding}, text={self.text}, reference={self.reference}), metadata={self.metadata}"
        self.assertEqual(repr(built), expected)
|
|
|
|
|
|
class TestDeduplicateResults(unittest.TestCase):
    """Unit tests for the deduplicate_results helper."""

    def setUp(self):
        """Prepare two distinct embedding/text pairs plus shared reference data."""
        self.embedding1 = np.array([0.1, 0.2, 0.3])
        self.text1 = "Text 1"
        self.embedding2 = np.array([0.4, 0.5, 0.6])
        self.text2 = "Text 2"
        self.reference = "test.txt"
        self.metadata = {"key": "value"}

    def _result(self, embedding, text):
        """Build a RetrievalResult using the shared reference and metadata."""
        return RetrievalResult(embedding, text, self.reference, self.metadata)

    def test_no_duplicates(self):
        """A list with no repeated entries is returned with both entries intact."""
        first = self._result(self.embedding1, self.text1)
        second = self._result(self.embedding2, self.text2)
        results = [first, second]

        deduplicated = deduplicate_results(results)

        self.assertEqual(len(deduplicated), 2)
        self.assertEqual(deduplicated, results)

    def test_with_duplicates(self):
        """A repeated entry is dropped and the remaining texts keep their order."""
        results = [
            self._result(embedding, text)
            for embedding, text in (
                (self.embedding1, self.text1),
                (self.embedding2, self.text2),
                (self.embedding1, self.text1),
            )
        ]

        deduplicated = deduplicate_results(results)

        self.assertEqual(len(deduplicated), 2)
        remaining_texts = [item.text for item in deduplicated]
        self.assertEqual(remaining_texts, [self.text1, self.text2])

    def test_empty_list(self):
        """Deduplicating an empty list produces a result of length zero."""
        deduplicated = deduplicate_results([])
        self.assertEqual(len(deduplicated), 0)
|
|
|
|
|
|
class TestCollectionInfo(unittest.TestCase):
    """Unit tests for the CollectionInfo container."""

    def test_init(self):
        """Both constructor arguments are exposed as attributes."""
        expected_name = "test_collection"
        expected_description = "Test collection description"

        info = CollectionInfo(expected_name, expected_description)

        self.assertEqual(info.collection_name, expected_name)
        self.assertEqual(info.description, expected_description)
|
|
|
|
|
|
class MockVectorDB(BaseVectorDB):
    """Minimal BaseVectorDB subclass whose methods do nothing, for use in tests."""

    def init_collection(self, dim, collection, description, force_new_collection=False, *args, **kwargs):
        """Accept the collection parameters and do nothing."""
        return None

    def insert_data(self, collection, chunks, *args, **kwargs):
        """Accept the chunks and do nothing."""
        return None

    def search_data(self, collection, vector, *args, **kwargs) -> List[RetrievalResult]:
        """Return an empty result list regardless of the query."""
        no_results = []
        return no_results

    def clear_db(self, *args, **kwargs):
        """Do nothing; this mock stores no data."""
        return None
|
|
|
|
|
|
class TestBaseVectorDB(unittest.TestCase):
    """Unit tests for behavior provided by BaseVectorDB, exercised through MockVectorDB."""

    def setUp(self):
        """Create a mock database constructed with no arguments."""
        self.db = MockVectorDB()

    def test_init_default(self):
        """With no arguments, the default collection is "deepsearcher"."""
        actual = self.db.default_collection
        self.assertEqual(actual, "deepsearcher")

    def test_init_custom_collection(self):
        """A ``default_collection`` passed to the constructor is stored as given."""
        wanted = "custom_collection"
        custom_db = MockVectorDB(default_collection=wanted)
        self.assertEqual(custom_db.default_collection, wanted)

    def test_list_collections_default(self):
        """list_collections, which the mock does not override, returns None."""
        listed = self.db.list_collections()
        self.assertIsNone(listed)
|
|
|
|
|
|
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|