You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
237 lines
8.0 KiB
237 lines
8.0 KiB
import unittest
|
|
from unittest.mock import patch, MagicMock
|
|
import numpy as np
|
|
import sys
|
|
|
|
from deepsearcher.vector_db import AzureSearch
|
|
from deepsearcher.vector_db.base import RetrievalResult
|
|
|
|
|
|
class TestAzureSearch(unittest.TestCase):
    """Tests for the Azure Search vector database implementation.

    The real ``azure`` SDK is never imported: ``setUp`` installs
    ``MagicMock`` stand-ins for every ``azure.*`` module into
    ``sys.modules`` before importing ``AzureSearch``. Each test then
    configures the mocked client it expects the code under test to use.
    """

    # Connection settings shared by every test.
    ENDPOINT = "https://test-search.search.windows.net"
    INDEX_NAME = "test-index"
    API_KEY = "test-key"
    VECTOR_FIELD = "content_vector"
    # Width of the random test vectors; Azure Search is expected to use 1536.
    EMBEDDING_DIM = 1536

    def setUp(self):
        """Install mocked ``azure`` modules and import the code under test."""
        # Create mock modules.
        self.mock_azure = MagicMock()
        self.mock_search = MagicMock()
        self.mock_indexes = MagicMock()
        self.mock_models = MagicMock()
        self.mock_credentials = MagicMock()
        self.mock_exceptions = MagicMock()

        # Mirror the package hierarchy so attribute access
        # (e.g. ``azure.search.documents.indexes.models``) resolves to the
        # same mocks as the ``sys.modules`` entries registered below.
        self.mock_azure.search = self.mock_search
        self.mock_search.documents = self.mock_search
        self.mock_search.documents.indexes = self.mock_indexes
        self.mock_indexes.models = self.mock_models
        self.mock_azure.core = self.mock_credentials
        self.mock_azure.core.credentials = self.mock_credentials
        self.mock_azure.core.exceptions = self.mock_exceptions

        # An ``except SomeError:`` clause needs a real BaseException
        # subclass. A plain MagicMock attribute would raise TypeError as
        # soon as the code under test tried to catch it.
        self.mock_exceptions.ResourceExistsError = type(
            "ResourceExistsError", (Exception,), {}
        )
        self.mock_exceptions.ResourceNotFoundError = type(
            "ResourceNotFoundError", (Exception,), {}
        )

        # Mock specific models needed for init_collection.
        self.mock_models.SearchableField = MagicMock()
        self.mock_models.SimpleField = MagicMock()
        self.mock_models.SearchField = MagicMock()
        self.mock_models.SearchIndex = MagicMock()

        self.module_patcher = patch.dict(sys.modules, {
            "azure": self.mock_azure,
            "azure.core": self.mock_credentials,
            "azure.core.credentials": self.mock_credentials,
            "azure.core.exceptions": self.mock_exceptions,
            "azure.search": self.mock_search,
            "azure.search.documents": self.mock_search,
            "azure.search.documents.indexes": self.mock_indexes,
            "azure.search.documents.indexes.models": self.mock_models,
        })
        self.module_patcher.start()
        # Register the undo immediately. If anything later in setUp raises
        # (e.g. the import below), tearDown is skipped but cleanups still
        # run, so the mocked modules never leak into other tests.
        self.addCleanup(self.module_patcher.stop)

        # Import after mocking.
        from deepsearcher.vector_db import AzureSearch
        from deepsearcher.vector_db.base import RetrievalResult

        self.AzureSearch = AzureSearch
        self.RetrievalResult = RetrievalResult

    def _make_azure_search(self):
        """Return an ``AzureSearch`` built with the shared test settings."""
        return self.AzureSearch(
            endpoint=self.ENDPOINT,
            index_name=self.INDEX_NAME,
            api_key=self.API_KEY,
            vector_field=self.VECTOR_FIELD,
        )

    def test_init(self):
        """Test basic initialization."""
        mock_client = MagicMock()
        self.mock_search.SearchClient.return_value = mock_client

        azure_search = self._make_azure_search()

        self.assertEqual(azure_search.index_name, self.INDEX_NAME)
        self.assertEqual(azure_search.endpoint, self.ENDPOINT)
        self.assertEqual(azure_search.api_key, self.API_KEY)
        self.assertEqual(azure_search.vector_field, self.VECTOR_FIELD)
        self.assertIsNotNone(azure_search.client)

    def test_init_collection(self):
        """Test collection initialization."""
        mock_index_client = MagicMock()
        self.mock_indexes.SearchIndexClient.return_value = mock_index_client
        mock_index_client.create_index.return_value = None

        azure_search = self._make_azure_search()
        azure_search.init_collection()

        mock_index_client.create_index.assert_called()

    def test_insert_data(self):
        """Test inserting data."""
        mock_client = MagicMock()
        self.mock_search.SearchClient.return_value = mock_client

        # Mock a successful upload result: one entry per document.
        mock_client.upload_documents.return_value = [
            MagicMock(succeeded=True) for _ in range(2)
        ]

        azure_search = self._make_azure_search()

        rng = np.random.default_rng(seed=42)
        test_docs = [
            {
                "text": "hello world",
                "vector": rng.random(self.EMBEDDING_DIM).tolist(),
                "id": "doc1",
            },
            {
                "text": "hello azure search",
                "vector": rng.random(self.EMBEDDING_DIM).tolist(),
                "id": "doc2",
            },
        ]

        results = azure_search.insert_data(documents=test_docs)
        self.assertEqual(len(results), 2)
        self.assertTrue(all(results))

    def test_search_data(self):
        """Test search functionality."""
        mock_client = MagicMock()
        self.mock_search.SearchClient.return_value = mock_client

        # Raw hits returned by the low-level search endpoint the
        # implementation is expected to call.
        mock_results = MagicMock()
        mock_results.results = [
            {"content": "hello world", "id": "doc1", "@search.score": 0.95},
            {"content": "hello azure search", "id": "doc2", "@search.score": 0.85},
        ]
        mock_client._client.documents.search_post.return_value = mock_results

        azure_search = self._make_azure_search()

        rng = np.random.default_rng(seed=42)
        query_vector = rng.random(self.EMBEDDING_DIM).tolist()
        results = azure_search.search_data(
            collection=self.INDEX_NAME,
            vector=query_vector,
            top_k=2,
        )

        self.assertIsInstance(results, list)
        self.assertEqual(len(results), 2)
        # Verify results are RetrievalResult objects.
        for result in results:
            self.assertIsInstance(result, self.RetrievalResult)

    def test_clear_db(self):
        """Test clearing database."""
        mock_client = MagicMock()
        self.mock_search.SearchClient.return_value = mock_client

        # Documents the search call reports; all of them should be deleted.
        mock_client.search.return_value = [{"id": "doc1"}, {"id": "doc2"}]

        azure_search = self._make_azure_search()

        deleted_count = azure_search.clear_db()
        self.assertEqual(deleted_count, 2)

    def test_list_collections(self):
        """Test listing collections."""
        mock_index_client = MagicMock()
        self.mock_indexes.SearchIndexClient.return_value = mock_index_client

        # ``name`` must be assigned after construction: passed to the
        # MagicMock constructor it would only set the mock's repr name.
        mock_index1 = MagicMock()
        mock_index1.name = "test-index-1"
        mock_index1.fields = ["field1", "field2"]

        mock_index2 = MagicMock()
        mock_index2.name = "test-index-2"
        mock_index2.fields = ["field1", "field2", "field3"]

        mock_index_client.list_indexes.return_value = [mock_index1, mock_index2]

        azure_search = self._make_azure_search()

        collections = azure_search.list_collections()
        self.assertIsInstance(collections, list)
        self.assertEqual(len(collections), 2)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: discover and run every TestCase defined in this
    # module, using unittest's default (non-verbose) output level.
    unittest.main(verbosity=1)
|