You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
279 lines
10 KiB
279 lines
10 KiB
import uuid
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from deepsearcher.vector_db.base import BaseVectorDB, CollectionInfo, RetrievalResult
|
|
|
|
|
|
class AzureSearch(BaseVectorDB):
    """Vector database backend that stores and queries chunks in an Azure Search index.

    Documents are stored with three fields: ``id`` (key), ``content`` (text) and a
    vector field whose name is given by ``vector_field``.

    The Azure SDK imports are kept local to the methods that need them, so
    importing this module does not by itself require the SDK to be installed.
    """

    # Azure Search accepts at most 1000 actions in a single indexing request,
    # so uploads and deletes are sent in chunks of this size.
    _MAX_BATCH_SIZE = 1000

    # Field name used when the caller passes an empty ``vector_field``.
    _DEFAULT_VECTOR_FIELD = "content_vector"

    def __init__(
        self,
        endpoint: str,
        index_name: str,
        api_key: str,
        vector_field: str,
        dim: int = 1536,
    ):
        """Store connection settings and create the default search client.

        Args:
            endpoint: URL of the Azure Search service.
            index_name: Name of the index used by default; also passed to the
                base class as the default collection.
            api_key: Admin/query key used to authenticate every request.
            vector_field: Name of the index field holding the embedding. An
                empty value falls back to ``"content_vector"``.
            dim: Number of dimensions of the stored embeddings. Defaults to
                1536, the value that was previously hard-coded.
        """
        super().__init__(default_collection=index_name)
        self.endpoint = endpoint
        self.index_name = index_name
        self.api_key = api_key
        self.vector_field = vector_field or self._DEFAULT_VECTOR_FIELD
        self.dim = dim
        self.client = self._search_client()

    def _credential(self):
        """Build the key credential shared by all Azure clients."""
        from azure.core.credentials import AzureKeyCredential

        return AzureKeyCredential(self.api_key)

    def _search_client(self, index_name: Optional[str] = None):
        """Return a document-level client for ``index_name`` (default index if falsy)."""
        from azure.search.documents import SearchClient

        return SearchClient(
            endpoint=self.endpoint,
            index_name=index_name or self.index_name,
            credential=self._credential(),
        )

    def _index_client(self):
        """Return a client for index-management operations."""
        from azure.search.documents.indexes import SearchIndexClient

        return SearchIndexClient(endpoint=self.endpoint, credential=self._credential())

    @staticmethod
    def _to_retrieval_result(doc) -> RetrievalResult:
        """Convert one raw search hit (model object or dict) to a RetrievalResult."""
        if hasattr(doc, "as_dict"):
            doc = doc.as_dict()

        if isinstance(doc, dict):
            content = doc.get("content", "")
            doc_id = doc.get("id", "")
            score = doc.get("@search.score", 0.0)
        else:
            content = getattr(doc, "content", "")
            doc_id = getattr(doc, "id", "")
            score = getattr(doc, "@search.score", 0.0)

        return RetrievalResult(
            embedding=[],  # The query only selects id and content, not the vector.
            text=content,
            reference=doc_id,
            metadata={"source": doc_id},
            score=score,
        )

    def _collect_results(self, docs) -> List[RetrievalResult]:
        """Convert every hit in ``docs``, skipping (and reporting) hits that fail."""
        collected = []
        for doc in docs:
            try:
                collected.append(self._to_retrieval_result(doc))
            except Exception as e:
                print(f"Error processing result: {str(e)}")
        return collected

    def init_collection(self):
        """(Re)create the default index with the id / content / vector schema.

        Warning: an existing index with the same name is deleted first, so all
        documents stored in it are lost. Failures while deleting or creating
        the index are printed and swallowed (best effort).
        """
        from azure.core.exceptions import ResourceNotFoundError
        from azure.search.documents.indexes.models import (
            SearchableField,
            SearchField,
            SearchIndex,
            SimpleField,
        )

        index_client = self._index_client()

        # Simplified schema for compatibility with older SDK versions.
        # NOTE(review): newer azure-search-documents releases appear to require
        # a vector search profile on vector fields; none is configured here --
        # confirm against the installed SDK version.
        fields = [
            SimpleField(name="id", type="Edm.String", key=True),
            SearchableField(name="content", type="Edm.String"),
            SearchField(
                # Must match the field queried in search_data and written in
                # insert_data, hence self.vector_field rather than a literal.
                name=self.vector_field,
                type="Collection(Edm.Single)",
                searchable=True,
                vector_search_dimensions=self.dim,
            ),
        ]
        index = SearchIndex(name=self.index_name, fields=fields)

        try:
            # Drop any previous index; a missing index is not an error.
            try:
                index_client.delete_index(self.index_name)
            except ResourceNotFoundError:
                pass

            index_client.create_index(index)
        except Exception as e:
            print(f"Error creating index: {str(e)}")

    def insert_data(self, documents: List[dict]):
        """Upload documents with their embeddings to the default index.

        Args:
            documents: Dicts with the keys ``"text"`` and ``"vector"`` and an
                optional ``"id"``. A missing or empty id is replaced by a
                freshly generated UUID.

        Returns:
            One boolean per document, in input order, telling whether the
            service reported the upload as succeeded. Empty input returns ``[]``
            without contacting the service.

        Raises:
            KeyError: If a document lacks ``"text"`` or ``"vector"``.
        """
        if not documents:
            return []

        search_client = self._search_client()

        # upload_documents() applies the "upload" action itself, so no
        # "@search.action" key is put into the payload. (The previous code
        # tagged id-less documents as "merge", which cannot succeed for a
        # brand-new key.)
        payload = [
            {
                "id": doc.get("id") or str(uuid.uuid4()),
                "content": doc["text"],
                self.vector_field: doc["vector"],
            }
            for doc in documents
        ]

        outcomes = []
        for start in range(0, len(payload), self._MAX_BATCH_SIZE):
            batch = payload[start : start + self._MAX_BATCH_SIZE]
            outcomes.extend(item.succeeded for item in search_client.upload_documents(batch))
        return outcomes

    def search_data(
        self, collection: Optional[str], vector: List[float], top_k: int = 50
    ) -> List[RetrievalResult]:
        """Run a vector search and return the best-matching documents.

        The request is first sent through the SDK's low-level ``search_post``
        operation with a raw ``vectorQueries`` body. If that fails for any
        reason, a plain match-all text query is issued instead; results of this
        fallback are NOT ranked by vector similarity.

        Args:
            collection: Index to query; the default index is used when falsy.
            vector: Query embedding; must contain exactly ``self.dim`` values.
            top_k: Maximum number of results to return.

        Returns:
            The converted hits, or ``[]`` when the vector is empty, has the
            wrong dimension, or both search attempts fail.
        """
        # Explicit None/length check instead of truthiness so that array-like
        # inputs do not raise on ambiguous truth values.
        if vector is None or len(vector) == 0:
            print(
                f"Error: Empty vector provided for search. Vector must have {self.dim} dimensions."
            )
            return []

        print(f"Vector length for search: {len(vector)}")
        print(f"Vector field name: {self.vector_field}")

        if len(vector) != self.dim:
            print(
                f"Warning: Vector length {len(vector)} does not match expected {self.dim} dimensions"
            )
            return []

        search_client = self._search_client(collection)

        try:
            print(f"Executing search with top_k={top_k}")

            body = {
                "search": "*",
                "select": "id,content",
                "top": top_k,
                "vectorQueries": [
                    {
                        "vector": vector,
                        "fields": self.vector_field,
                        "k": top_k,
                        "kind": "vector",
                    }
                ],
            }

            # The request body is intentionally not printed: it contains the
            # whole embedding and would flood the output.

            # Relies on a private SDK attribute to post the raw request body.
            response = search_client._client.documents.search_post(
                search_request=body, headers={"api-key": self.api_key}
            )

            hits = response.results if hasattr(response, "results") else []
            return self._collect_results(hits)
        except Exception as e:
            print(f"Search error: {str(e)}")

        # Fallback: match-all text query through the public API.
        try:
            print("Trying alternative search method...")
            fallback_hits = search_client.search(
                search_text="*", select=["id", "content"], top=top_k
            )
            return self._collect_results(fallback_hits)
        except Exception as e:
            print(f"Alternative search failed: {str(e)}")
            return []

    def clear_db(self):
        """Delete every document in the default index.

        Returns:
            The number of document ids that were found and submitted for
            deletion.
        """
        search_client = self._search_client()

        hits = search_client.search(search_text="*", select=["id"])
        doc_ids = [hit["id"] for hit in hits]

        for start in range(0, len(doc_ids), self._MAX_BATCH_SIZE):
            chunk = doc_ids[start : start + self._MAX_BATCH_SIZE]
            search_client.delete_documents([{"id": doc_id} for doc_id in chunk])

        return len(doc_ids)

    def get_all_collections(self) -> List[str]:
        """Return the names of all indexes of the service (``[]`` on failure)."""
        try:
            return [index.name for index in self._index_client().list_indexes()]
        except Exception as e:
            print(f"Failed to list indices: {str(e)}")
            return []

    def get_collection_info(self, name: str) -> Dict[str, Any]:
        """Return the attribute dictionary of the index called ``name``.

        Errors raised by the SDK while fetching the index are not caught here;
        ``collection_exists`` relies on ``ResourceNotFoundError`` propagating.
        """
        return self._index_client().get_index(name).__dict__

    def collection_exists(self, name: str) -> bool:
        """Return True if an index called ``name`` exists, False if it is not found."""
        from azure.core.exceptions import ResourceNotFoundError

        try:
            self.get_collection_info(name)
        except ResourceNotFoundError:
            return False
        return True

    def list_collections(self, *args, **kwargs) -> List[CollectionInfo]:
        """Return a CollectionInfo for every index of the service (``[]`` on failure).

        Extra positional and keyword arguments are accepted and ignored.
        """
        try:
            collections = []
            for index in self._index_client().list_indexes():
                # ``fields`` may be absent or None; treat both as zero fields
                # instead of letting len() abort the whole listing.
                field_count = len(getattr(index, "fields", None) or [])
                collections.append(
                    CollectionInfo(
                        collection_name=index.name,
                        description=f"Azure Search Index with {field_count} fields",
                    )
                )
            return collections
        except Exception as e:
            print(f"Collection listing failed: {str(e)}")
            return []
|
|
|