9 changed files with 46 additions and 545 deletions
@@ -1,6 +1,5 @@
-from .azure_search import AzureSearch
 from .milvus import Milvus, RetrievalResult
 from .oracle import OracleDB
 from .qdrant import Qdrant
 
-__all__ = ["Milvus", "RetrievalResult", "OracleDB", "Qdrant", "AzureSearch"]
+__all__ = ["Milvus", "RetrievalResult", "OracleDB", "Qdrant"]
@@ -1,279 +0,0 @@
import uuid
from typing import Any, Dict, List, Optional

from deepsearcher.vector_db.base import BaseVectorDB, CollectionInfo, RetrievalResult


class AzureSearch(BaseVectorDB):
    def __init__(self, endpoint, index_name, api_key, vector_field):
        super().__init__(default_collection=index_name)
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents import SearchClient

        self.client = SearchClient(
            endpoint=endpoint,
            index_name=index_name,
            credential=AzureKeyCredential(api_key),
        )
        self.vector_field = vector_field
        self.endpoint = endpoint
        self.index_name = index_name
        self.api_key = api_key

    def init_collection(self):
        """Initialize Azure Search index with proper schema"""
        from azure.core.credentials import AzureKeyCredential
        from azure.core.exceptions import ResourceNotFoundError
        from azure.search.documents.indexes import SearchIndexClient
        from azure.search.documents.indexes.models import (
            SearchableField,
            SearchField,
            SearchIndex,
            SimpleField,
        )

        index_client = SearchIndexClient(
            endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key)
        )

        # Create the index (simplified for compatibility with older SDK versions)
        fields = [
            SimpleField(name="id", type="Edm.String", key=True),
            SearchableField(name="content", type="Edm.String"),
            SearchField(
                name="content_vector",
                type="Collection(Edm.Single)",
                searchable=True,
                vector_search_dimensions=1536,
            ),
        ]

        # Create index with fields
        index = SearchIndex(name=self.index_name, fields=fields)

        try:
            # Try to delete existing index
            try:
                index_client.delete_index(self.index_name)
            except ResourceNotFoundError:
                pass

            # Create the index
            index_client.create_index(index)
        except Exception as e:
            print(f"Error creating index: {str(e)}")

    def insert_data(self, documents: List[dict]):
        """Batch insert documents with vector embeddings"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents import SearchClient

        search_client = SearchClient(
            endpoint=self.endpoint,
            index_name=self.index_name,
            credential=AzureKeyCredential(self.api_key),
        )

        actions = [
            {
                "@search.action": "upload" if doc.get("id") else "merge",
                "id": doc.get("id", str(uuid.uuid4())),
                "content": doc["text"],
                "content_vector": doc["vector"],
            }
            for doc in documents
        ]

        result = search_client.upload_documents(actions)
        return [x.succeeded for x in result]

    def search_data(
        self, collection: Optional[str], vector: List[float], top_k: int = 50
    ) -> List[RetrievalResult]:
        """Azure Cognitive Search implementation with compatibility for older SDK versions"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents import SearchClient

        search_client = SearchClient(
            endpoint=self.endpoint,
            index_name=collection or self.index_name,
            credential=AzureKeyCredential(self.api_key),
        )

        # Validate that vector is not empty
        if not vector or len(vector) == 0:
            print("Error: Empty vector provided for search. Vector must have 1536 dimensions.")
            return []

        # Debug vector and field info
        print(f"Vector length for search: {len(vector)}")
        print(f"Vector field name: {self.vector_field}")

        # Ensure vector has the right dimensions
        if len(vector) != 1536:
            print(f"Warning: Vector length {len(vector)} does not match expected 1536 dimensions")
            return []

        # Execute search with direct parameters - simpler approach
        try:
            print(f"Executing search with top_k={top_k}")

            # Directly use the search_by_vector method for compatibility
            body = {
                "search": "*",
                "select": "id,content",
                "top": top_k,
                "vectorQueries": [
                    {
                        "vector": vector,
                        "fields": self.vector_field,
                        "k": top_k,
                        "kind": "vector",
                    }
                ],
            }

            # Print the search request body for debugging
            print(f"Search request body: {body}")

            # Use the REST API directly
            result = search_client._client.documents.search_post(
                search_request=body, headers={"api-key": self.api_key}
            )

            # Format results
            search_results = []
            if hasattr(result, "results"):
                for doc in result.results:
                    try:
                        doc_dict = doc.as_dict() if hasattr(doc, "as_dict") else doc
                        content = doc_dict.get("content", "")
                        doc_id = doc_dict.get("id", "")
                        score = doc_dict.get("@search.score", 0.0)

                        result = RetrievalResult(
                            embedding=[],  # We don't get the vectors back
                            text=content,
                            reference=doc_id,
                            metadata={"source": doc_id},
                            score=score,
                        )
                        search_results.append(result)
                    except Exception as e:
                        print(f"Error processing result: {str(e)}")

            return search_results
        except Exception as e:
            print(f"Search error: {str(e)}")

            # Try another approach if the first one fails
            try:
                print("Trying alternative search method...")
                results = search_client.search(search_text="*", select=["id", "content"], top=top_k)

                # Process results
                alt_results = []
                for doc in results:
                    try:
                        # Handle different result formats
                        if isinstance(doc, dict):
                            content = doc.get("content", "")
                            doc_id = doc.get("id", "")
                            score = doc.get("@search.score", 0.0)
                        else:
                            content = getattr(doc, "content", "")
                            doc_id = getattr(doc, "id", "")
                            score = getattr(doc, "@search.score", 0.0)

                        result = RetrievalResult(
                            embedding=[],
                            text=content,
                            reference=doc_id,
                            metadata={"source": doc_id},
                            score=score,
                        )
                        alt_results.append(result)
                    except Exception as e:
                        print(f"Error processing result: {str(e)}")

                return alt_results
            except Exception as e:
                print(f"Alternative search failed: {str(e)}")
                return []

    def clear_db(self):
        """Delete all documents in the index"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents import SearchClient

        search_client = SearchClient(
            endpoint=self.endpoint,
            index_name=self.index_name,
            credential=AzureKeyCredential(self.api_key),
        )

        docs = search_client.search(search_text="*", include_total_count=True, select=["id"])
        ids = [doc["id"] for doc in docs]

        if ids:
            search_client.delete_documents([{"id": id} for id in ids])

        return len(ids)

    def get_all_collections(self) -> List[str]:
        """List all search indices in Azure Cognitive Search"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents.indexes import SearchIndexClient

        try:
            index_client = SearchIndexClient(
                endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key)
            )
            return [index.name for index in index_client.list_indexes()]
        except Exception as e:
            print(f"Failed to list indices: {str(e)}")
            return []

    def get_collection_info(self, name: str) -> Dict[str, Any]:
        """Retrieve index metadata"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents.indexes import SearchIndexClient

        index_client = SearchIndexClient(
            endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key)
        )
        return index_client.get_index(name).__dict__

    def collection_exists(self, name: str) -> bool:
        """Check index existence"""
        from azure.core.exceptions import ResourceNotFoundError

        try:
            self.get_collection_info(name)
            return True
        except ResourceNotFoundError:
            return False

    def list_collections(self, *args, **kwargs) -> List[CollectionInfo]:
        """List all Azure Search indices with metadata"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents.indexes import SearchIndexClient

        try:
            index_client = SearchIndexClient(
                endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key)
            )

            collections = []
            for index in index_client.list_indexes():
                collections.append(
                    CollectionInfo(
                        collection_name=index.name,
                        description=f"Azure Search Index with {len(index.fields) if hasattr(index, 'fields') else 0} fields",
                    )
                )
            return collections

        except Exception as e:
            print(f"Collection listing failed: {str(e)}")
            return []
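The deleted `search_data` above builds a raw REST request body and calls the private `search_client._client.documents.search_post` for compatibility with older SDK versions. For reference, a minimal sketch of the same vector query through the public API of a newer `azure-search-documents` release (assuming 11.4+, where `VectorizedQuery` is available; the endpoint, index name, key, and the 1536-dimension `content_vector` field are placeholders):

```python
from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient
from azure.search.documents.models import VectorizedQuery

# Placeholder connection details; substitute real values.
client = SearchClient(
    endpoint="https://<service>.search.windows.net",
    index_name="<index>",
    credential=AzureKeyCredential("<api-key>"),
)

def vector_search(vector: list[float], top_k: int = 50):
    # Single k-NN vector query against the "content_vector" field,
    # returning only the id and content fields plus the search score.
    results = client.search(
        search_text=None,
        vector_queries=[
            VectorizedQuery(vector=vector, k_nearest_neighbors=top_k, fields="content_vector")
        ],
        select=["id", "content"],
        top=top_k,
    )
    return [(r["id"], r["content"], r["@search.score"]) for r in results]
```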
@@ -1,247 +0,0 @@
import hashlib
from typing import List

import numpy as np

from deepsearcher.loader.splitter import Chunk
from deepsearcher.vector_db.base import BaseVectorDB, RetrievalResult


def calculate_text_hash(text: str) -> str:
    """
    Compute the hash of a piece of text, used as its primary key.

    Args:
        text (str): Input text.

    Returns:
        str: MD5 hash of the text.
    """
    return hashlib.md5(text.encode('utf-8')).hexdigest()


class DeduplicatedVectorDB:
    """
    Vector database wrapper with deduplication support.

    On top of an existing vector database, this class adds:
    1. Using the text hash as the primary key.
    2. Avoiding insertion of duplicate data.
    3. A method for cleaning up duplicate data.
    """

    def __init__(self, db: BaseVectorDB):
        """
        Initialize the deduplicated vector database.

        Args:
            db (BaseVectorDB): The underlying vector database instance.
        """
        self.db = db

    def init_collection(self, dim: int, collection: str, description: str,
                        force_new_collection=False, *args, **kwargs):
        """
        Initialize a collection.

        Args:
            dim (int): Vector dimension.
            collection (str): Collection name.
            description (str): Collection description.
            force_new_collection (bool): Whether to force creation of a new collection.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.
        """
        return self.db.init_collection(dim, collection, description,
                                       force_new_collection, *args, **kwargs)

    def insert_data(self, collection: str, chunks: List[Chunk],
                    batch_size: int = 256, *args, **kwargs):
        """
        Insert data while avoiding duplicates.

        This method first checks whether the database already contains an entry
        with the same text hash, and only inserts chunks that are not present.

        Args:
            collection (str): Collection name.
            chunks (List[Chunk]): Chunks to insert.
            batch_size (int): Batch size.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.
        """
        # Compute a hash for each chunk and store it in the metadata
        for chunk in chunks:
            if 'hash' not in chunk.metadata:
                chunk.metadata['hash'] = calculate_text_hash(chunk.text)

        # Filter out chunks that already exist
        filtered_chunks = self._filter_duplicate_chunks(collection, chunks)

        # Insert the filtered chunks
        if filtered_chunks:
            return self.db.insert_data(collection, filtered_chunks, batch_size, *args, **kwargs)

    def _filter_duplicate_chunks(self, collection: str, chunks: List[Chunk]) -> List[Chunk]:
        """
        Filter out chunks that already exist in the database.

        Args:
            collection (str): Collection name.
            chunks (List[Chunk]): Chunks to filter.

        Returns:
            List[Chunk]: The filtered chunks.
        """
        # Note: this implementation depends on the underlying database.
        # In production, a database-specific query should be used to check for duplicates.
        # Only a generic but less efficient implementation is provided here.
        return chunks

    def search_data(self, collection: str, vector: np.array, *args, **kwargs) -> List[RetrievalResult]:
        """
        Search data.

        Args:
            collection (str): Collection name.
            vector (np.array): Query vector.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            List[RetrievalResult]: Search results.
        """
        return self.db.search_data(collection, vector, *args, **kwargs)

    def list_collections(self, *args, **kwargs):
        """
        List all collections.

        Args:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            A list of collection information.
        """
        return self.db.list_collections(*args, **kwargs)

    def clear_db(self, collection: str = None, *args, **kwargs):
        """
        Clear a database collection.

        Args:
            collection (str): Collection name.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.
        """
        return self.db.clear_db(collection, *args, **kwargs)

    def remove_duplicate_data(self, collection: str = None) -> int:
        """
        Remove duplicate data from the database.

        Note: this is a generic interface; specific databases may offer a more
        efficient implementation.

        Args:
            collection (str): Collection name.

        Returns:
            int: Number of duplicate records removed.
        """
        # The concrete implementation must be tailored to the database type;
        # only a conceptual placeholder is provided here.
        raise NotImplementedError("Deduplication logic must be implemented for the specific database")


# Dedicated deduplication helper for Qdrant
class QdrantDeduplicatedVectorDB(DeduplicatedVectorDB):
    """
    Qdrant-specific deduplicated vector database implementation.
    """

    def _filter_duplicate_chunks(self, collection: str, chunks: List[Chunk]) -> List[Chunk]:
        """
        Filter out chunks that already exist in the Qdrant database.

        Args:
            collection (str): Collection name.
            chunks (List[Chunk]): Chunks to filter.

        Returns:
            List[Chunk]: The filtered chunks.
        """
        try:
            # Get the Qdrant client
            qdrant_client = self.db.client

            # Collect all hashes to check
            hashes = [calculate_text_hash(chunk.text) for chunk in chunks]

            # Query for existing records
            # Note: this requires Qdrant payload filtering; the implementation may need adjustment.
            filtered_chunks = []
            for chunk, hash_value in zip(chunks, hashes):
                # Qdrant's query API should be used here to check whether the hash already exists.
                # Given the limits of Qdrant's API, this may need the scroll or search functionality.
                filtered_chunks.append(chunk)

            return filtered_chunks
        except Exception:
            # On error, return all chunks (no filtering)
            return chunks

    def remove_duplicate_data(self, collection: str = None) -> int:
        """
        Remove duplicate data from the Qdrant database.

        Args:
            collection (str): Collection name.

        Returns:
            int: Number of duplicate records removed.
        """
        # Qdrant deduplication:
        # retrieve all data, identify duplicates, then delete them.
        # Concrete implementation omitted.
        return 0


# Dedicated deduplication helper for Milvus
class MilvusDeduplicatedVectorDB(DeduplicatedVectorDB):
    """
    Milvus-specific deduplicated vector database implementation.
    """

    def insert_data(self, collection: str, chunks: List[Chunk],
                    batch_size: int = 256, *args, **kwargs):
        """
        Insert data into Milvus, using the text hash as the primary key.

        Args:
            collection (str): Collection name.
            chunks (List[Chunk]): Chunks to insert.
            batch_size (int): Batch size.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.
        """
        # Set each chunk's ID to the hash of its text
        for chunk in chunks:
            chunk.metadata['id'] = calculate_text_hash(chunk.text)

        # Delegate the insertion to the parent class
        return super().insert_data(collection, chunks, batch_size, *args, **kwargs)

    def remove_duplicate_data(self, collection: str = None) -> int:
        """
        Remove duplicate data from the Milvus database.

        Args:
            collection (str): Collection name.

        Returns:
            int: Number of duplicate records removed.
        """
        # Milvus deduplication:
        # query all data, find records with duplicate IDs, and delete them.
        return 0
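The `_filter_duplicate_chunks` override in the deleted Qdrant class only describes the intended payload check in comments and never filters anything. A minimal sketch of that check with `qdrant-client`, assuming chunks are stored with their MD5 text hash in a `hash` payload field (the helper name and batch size are illustrative):

```python
from qdrant_client import QdrantClient, models

def find_existing_hashes(client: QdrantClient, collection: str, hashes: list[str]) -> set[str]:
    """Scroll the collection for points whose 'hash' payload matches any candidate hash."""
    existing: set[str] = set()
    offset = None
    while True:
        points, offset = client.scroll(
            collection_name=collection,
            scroll_filter=models.Filter(
                must=[models.FieldCondition(key="hash", match=models.MatchAny(any=hashes))]
            ),
            with_payload=["hash"],
            with_vectors=False,
            limit=256,
            offset=offset,
        )
        existing.update(p.payload["hash"] for p in points if p.payload)
        if offset is None:
            break
    return existing

# Chunks whose hash is already present would then be dropped before calling insert_data():
# new_chunks = [c for c in chunks if c.metadata["hash"] not in existing]
```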