You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

237 lines
8.0 KiB

import unittest
from unittest.mock import patch, MagicMock
import numpy as np
import sys
from deepsearcher.vector_db import AzureSearch
from deepsearcher.vector_db.base import RetrievalResult
class TestAzureSearch(unittest.TestCase):
    """Tests for the Azure Search vector database implementation.

    The azure SDK is never used for real: ``setUp`` installs ``MagicMock``
    stand-ins for every ``azure.*`` module path in ``sys.modules``, and each
    test configures return values on the mocked client classes before
    exercising ``AzureSearch``.
    """

    # Shared constructor arguments used by every test.
    ENDPOINT = "https://test-search.search.windows.net"
    INDEX_NAME = "test-index"
    API_KEY = "test-key"
    VECTOR_FIELD = "content_vector"
    # Width of the random test vectors (the original tests noted that
    # Azure Search expects 1536 dimensions).
    DIM = 1536

    def setUp(self):
        """Install mocked ``azure.*`` modules and fetch the class under test."""
        # One mock per module path the implementation may import from.
        self.mock_azure = MagicMock()
        self.mock_core = MagicMock()
        self.mock_credentials = MagicMock()
        self.mock_exceptions = MagicMock()
        self.mock_search = MagicMock()
        self.mock_indexes = MagicMock()
        self.mock_models = MagicMock()

        # Names imported from azure.core.exceptions may appear in ``except``
        # clauses, which require real BaseException subclasses. A MagicMock
        # there raises TypeError as soon as any exception propagates.
        # NOTE(review): assumes the implementation catches
        # ResourceNotFoundError; add other exception names if it uses them.
        self.mock_exceptions.ResourceNotFoundError = type(
            "ResourceNotFoundError", (Exception,), {}
        )

        # Mirror the package hierarchy on the mocks so that attribute-style
        # access (e.g. ``azure.core.credentials``) resolves to the same
        # objects as the sys.modules entries below.
        self.mock_azure.core = self.mock_core
        self.mock_core.credentials = self.mock_credentials
        self.mock_core.exceptions = self.mock_exceptions
        self.mock_azure.search = self.mock_search
        # ``azure.search`` and ``azure.search.documents`` deliberately share one
        # mock so tests can configure ``self.mock_search.SearchClient``.
        self.mock_search.documents = self.mock_search
        self.mock_search.indexes = self.mock_indexes
        self.mock_indexes.models = self.mock_models

        # Index-schema model classes used when building an index.
        self.mock_models.SearchableField = MagicMock()
        self.mock_models.SimpleField = MagicMock()
        self.mock_models.SearchField = MagicMock()
        self.mock_models.SearchIndex = MagicMock()

        self.module_patcher = patch.dict(
            "sys.modules",
            {
                "azure": self.mock_azure,
                "azure.core": self.mock_core,
                "azure.core.credentials": self.mock_credentials,
                "azure.core.exceptions": self.mock_exceptions,
                "azure.search": self.mock_search,
                "azure.search.documents": self.mock_search,
                "azure.search.documents.indexes": self.mock_indexes,
                "azure.search.documents.indexes.models": self.mock_models,
            },
        )
        self.module_patcher.start()
        # Register the undo right after start(): tearDown is skipped when setUp
        # itself raises, which would leave sys.modules patched for later tests.
        self.addCleanup(self.module_patcher.stop)

        # deepsearcher.vector_db is already imported at module load time, so
        # this returns the cached objects. The mocks above therefore only take
        # effect if AzureSearch imports the azure SDK lazily inside its methods
        # (presumably the case, since these tests depend on it).
        from deepsearcher.vector_db import AzureSearch
        from deepsearcher.vector_db.base import RetrievalResult

        self.AzureSearch = AzureSearch
        self.RetrievalResult = RetrievalResult

    def _make_azure_search(self):
        """Construct the class under test with the shared test configuration."""
        return self.AzureSearch(
            endpoint=self.ENDPOINT,
            index_name=self.INDEX_NAME,
            api_key=self.API_KEY,
            vector_field=self.VECTOR_FIELD,
        )

    def test_init(self):
        """Test basic initialization."""
        mock_client = MagicMock()
        self.mock_search.SearchClient.return_value = mock_client

        azure_search = self._make_azure_search()

        self.assertEqual(azure_search.index_name, self.INDEX_NAME)
        self.assertEqual(azure_search.endpoint, self.ENDPOINT)
        self.assertEqual(azure_search.api_key, self.API_KEY)
        self.assertEqual(azure_search.vector_field, self.VECTOR_FIELD)
        self.assertIsNotNone(azure_search.client)

    def test_init_collection(self):
        """Test collection initialization."""
        mock_index_client = MagicMock()
        self.mock_indexes.SearchIndexClient.return_value = mock_index_client
        mock_index_client.create_index.return_value = None

        azure_search = self._make_azure_search()
        azure_search.init_collection()

        self.assertTrue(mock_index_client.create_index.called)

    def test_insert_data(self):
        """Test inserting data."""
        mock_client = MagicMock()
        self.mock_search.SearchClient.return_value = mock_client
        # Every uploaded document reports success.
        mock_client.upload_documents.return_value = [
            MagicMock(succeeded=True) for _ in range(2)
        ]

        azure_search = self._make_azure_search()

        rng = np.random.default_rng(seed=42)
        test_docs = [
            {"text": "hello world", "vector": rng.random(self.DIM).tolist(), "id": "doc1"},
            {"text": "hello azure search", "vector": rng.random(self.DIM).tolist(), "id": "doc2"},
        ]

        results = azure_search.insert_data(documents=test_docs)

        self.assertEqual(len(results), 2)
        self.assertTrue(all(results))

    def test_search_data(self):
        """Test search functionality."""
        mock_client = MagicMock()
        self.mock_search.SearchClient.return_value = mock_client

        mock_results = MagicMock()
        mock_results.results = [
            {"content": "hello world", "id": "doc1", "@search.score": 0.95},
            {"content": "hello azure search", "id": "doc2", "@search.score": 0.85},
        ]
        # NOTE(review): mirrors the implementation reaching into the SDK's
        # private generated client (``_client.documents.search_post``).
        mock_client._client.documents.search_post.return_value = mock_results

        azure_search = self._make_azure_search()

        rng = np.random.default_rng(seed=42)
        query_vector = rng.random(self.DIM).tolist()
        results = azure_search.search_data(
            collection=self.INDEX_NAME,
            vector=query_vector,
            top_k=2,
        )

        self.assertIsInstance(results, list)
        self.assertEqual(len(results), 2)
        for result in results:
            self.assertIsInstance(result, self.RetrievalResult)

    def test_clear_db(self):
        """Test clearing database."""
        mock_client = MagicMock()
        self.mock_search.SearchClient.return_value = mock_client
        # Documents the search call reports as present (and thus to delete).
        mock_client.search.return_value = [{"id": "doc1"}, {"id": "doc2"}]

        azure_search = self._make_azure_search()
        deleted_count = azure_search.clear_db()

        self.assertEqual(deleted_count, 2)

    def test_list_collections(self):
        """Test listing collections."""
        mock_index_client = MagicMock()
        self.mock_indexes.SearchIndexClient.return_value = mock_index_client

        # ``name`` must be assigned after construction: MagicMock(name=...)
        # sets the mock's repr name, not a ``.name`` attribute.
        mock_index1 = MagicMock()
        mock_index1.name = "test-index-1"
        mock_index1.fields = ["field1", "field2"]
        mock_index2 = MagicMock()
        mock_index2.name = "test-index-2"
        mock_index2.fields = ["field1", "field2", "field3"]
        mock_index_client.list_indexes.return_value = [mock_index1, mock_index2]

        azure_search = self._make_azure_search()
        collections = azure_search.list_collections()

        self.assertIsInstance(collections, list)
        self.assertEqual(len(collections), 2)
if __name__ == "__main__":
    # Running this file directly discovers and runs the TestCase classes
    # defined in it; "__main__" is unittest.main's default module.
    unittest.main(module="__main__")