9 changed files with 46 additions and 545 deletions
@@ -1,6 +1,5 @@
-from .azure_search import AzureSearch
 from .milvus import Milvus, RetrievalResult
 from .oracle import OracleDB
 from .qdrant import Qdrant
 
-__all__ = ["Milvus", "RetrievalResult", "OracleDB", "Qdrant", "AzureSearch"]
+__all__ = ["Milvus", "RetrievalResult", "OracleDB", "Qdrant"]
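With this hunk applied, AzureSearch is no longer re-exported from the deepsearcher.vector_db package, so only the remaining backends resolve at the package root. A hypothetical sanity check (not part of the diff):

# Hypothetical check, not part of this PR: the remaining backends still import,
# while the removed AzureSearch export now fails at the package root.
from deepsearcher.vector_db import Milvus, OracleDB, Qdrant, RetrievalResult

try:
    from deepsearcher.vector_db import AzureSearch
except ImportError:
    print("AzureSearch is no longer exported by deepsearcher.vector_db")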
@@ -1,279 +0,0 @@
import uuid
from typing import Any, Dict, List, Optional

from deepsearcher.vector_db.base import BaseVectorDB, CollectionInfo, RetrievalResult


class AzureSearch(BaseVectorDB):
    def __init__(self, endpoint, index_name, api_key, vector_field):
        super().__init__(default_collection=index_name)
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents import SearchClient

        self.client = SearchClient(
            endpoint=endpoint,
            index_name=index_name,
            credential=AzureKeyCredential(api_key),
        )
        self.vector_field = vector_field
        self.endpoint = endpoint
        self.index_name = index_name
        self.api_key = api_key

    def init_collection(self):
        """Initialize Azure Search index with proper schema"""
        from azure.core.credentials import AzureKeyCredential
        from azure.core.exceptions import ResourceNotFoundError
        from azure.search.documents.indexes import SearchIndexClient
        from azure.search.documents.indexes.models import (
            SearchableField,
            SearchField,
            SearchIndex,
            SimpleField,
        )

        index_client = SearchIndexClient(
            endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key)
        )

        # Create the index (simplified for compatibility with older SDK versions)
        fields = [
            SimpleField(name="id", type="Edm.String", key=True),
            SearchableField(name="content", type="Edm.String"),
            SearchField(
                name="content_vector",
                type="Collection(Edm.Single)",
                searchable=True,
                vector_search_dimensions=1536,
            ),
        ]

        # Create index with fields
        index = SearchIndex(name=self.index_name, fields=fields)

        try:
            # Try to delete existing index
            try:
                index_client.delete_index(self.index_name)
            except ResourceNotFoundError:
                pass

            # Create the index
            index_client.create_index(index)
        except Exception as e:
            print(f"Error creating index: {str(e)}")

    def insert_data(self, documents: List[dict]):
        """Batch insert documents with vector embeddings"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents import SearchClient

        search_client = SearchClient(
            endpoint=self.endpoint,
            index_name=self.index_name,
            credential=AzureKeyCredential(self.api_key),
        )

        actions = [
            {
                "@search.action": "upload" if doc.get("id") else "merge",
                "id": doc.get("id", str(uuid.uuid4())),
                "content": doc["text"],
                "content_vector": doc["vector"],
            }
            for doc in documents
        ]

        result = search_client.upload_documents(actions)
        return [x.succeeded for x in result]

    def search_data(
        self, collection: Optional[str], vector: List[float], top_k: int = 50
    ) -> List[RetrievalResult]:
        """Azure Cognitive Search implementation with compatibility for older SDK versions"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents import SearchClient

        search_client = SearchClient(
            endpoint=self.endpoint,
            index_name=collection or self.index_name,
            credential=AzureKeyCredential(self.api_key),
        )

        # Validate that vector is not empty
        if not vector or len(vector) == 0:
            print("Error: Empty vector provided for search. Vector must have 1536 dimensions.")
            return []

        # Debug vector and field info
        print(f"Vector length for search: {len(vector)}")
        print(f"Vector field name: {self.vector_field}")

        # Ensure vector has the right dimensions
        if len(vector) != 1536:
            print(f"Warning: Vector length {len(vector)} does not match expected 1536 dimensions")
            return []

        # Execute search with direct parameters - simpler approach
        try:
            print(f"Executing search with top_k={top_k}")

            # Directly use the search_by_vector method for compatibility
            body = {
                "search": "*",
                "select": "id,content",
                "top": top_k,
                "vectorQueries": [
                    {
                        "vector": vector,
                        "fields": self.vector_field,
                        "k": top_k,
                        "kind": "vector",
                    }
                ],
            }

            # Print the search request body for debugging
            print(f"Search request body: {body}")

            # Use the REST API directly
            result = search_client._client.documents.search_post(
                search_request=body, headers={"api-key": self.api_key}
            )

            # Format results
            search_results = []
            if hasattr(result, "results"):
                for doc in result.results:
                    try:
                        doc_dict = doc.as_dict() if hasattr(doc, "as_dict") else doc
                        content = doc_dict.get("content", "")
                        doc_id = doc_dict.get("id", "")
                        score = doc_dict.get("@search.score", 0.0)

                        result = RetrievalResult(
                            embedding=[],  # We don't get the vectors back
                            text=content,
                            reference=doc_id,
                            metadata={"source": doc_id},
                            score=score,
                        )
                        search_results.append(result)
                    except Exception as e:
                        print(f"Error processing result: {str(e)}")

            return search_results
        except Exception as e:
            print(f"Search error: {str(e)}")

            # Try another approach if the first one fails
            try:
                print("Trying alternative search method...")
                results = search_client.search(search_text="*", select=["id", "content"], top=top_k)

                # Process results
                alt_results = []
                for doc in results:
                    try:
                        # Handle different result formats
                        if isinstance(doc, dict):
                            content = doc.get("content", "")
                            doc_id = doc.get("id", "")
                            score = doc.get("@search.score", 0.0)
                        else:
                            content = getattr(doc, "content", "")
                            doc_id = getattr(doc, "id", "")
                            score = getattr(doc, "@search.score", 0.0)

                        result = RetrievalResult(
                            embedding=[],
                            text=content,
                            reference=doc_id,
                            metadata={"source": doc_id},
                            score=score,
                        )
                        alt_results.append(result)
                    except Exception as e:
                        print(f"Error processing result: {str(e)}")

                return alt_results
            except Exception as e:
                print(f"Alternative search failed: {str(e)}")
                return []

    def clear_db(self):
        """Delete all documents in the index"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents import SearchClient

        search_client = SearchClient(
            endpoint=self.endpoint,
            index_name=self.index_name,
            credential=AzureKeyCredential(self.api_key),
        )

        docs = search_client.search(search_text="*", include_total_count=True, select=["id"])
        ids = [doc["id"] for doc in docs]

        if ids:
            search_client.delete_documents([{"id": id} for id in ids])

        return len(ids)

    def get_all_collections(self) -> List[str]:
        """List all search indices in Azure Cognitive Search"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents.indexes import SearchIndexClient

        try:
            index_client = SearchIndexClient(
                endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key)
            )
            return [index.name for index in index_client.list_indexes()]
        except Exception as e:
            print(f"Failed to list indices: {str(e)}")
            return []

    def get_collection_info(self, name: str) -> Dict[str, Any]:
        """Retrieve index metadata"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents.indexes import SearchIndexClient

        index_client = SearchIndexClient(
            endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key)
        )
        return index_client.get_index(name).__dict__

    def collection_exists(self, name: str) -> bool:
        """Check index existence"""
        from azure.core.exceptions import ResourceNotFoundError

        try:
            self.get_collection_info(name)
            return True
        except ResourceNotFoundError:
            return False

    def list_collections(self, *args, **kwargs) -> List[CollectionInfo]:
        """List all Azure Search indices with metadata"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents.indexes import SearchIndexClient

        try:
            index_client = SearchIndexClient(
                endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key)
            )

            collections = []
            for index in index_client.list_indexes():
                collections.append(
                    CollectionInfo(
                        collection_name=index.name,
                        description=f"Azure Search Index with {len(index.fields) if hasattr(index, 'fields') else 0} fields",
                    )
                )
            return collections

        except Exception as e:
            print(f"Collection listing failed: {str(e)}")
            return []
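The removed init_collection builds the vector field without any vector-search algorithm configuration ("simplified for compatibility with older SDK versions"), which current service API versions reject. For reference, a minimal sketch of an equivalent index definition, assuming azure-search-documents >= 11.4; the profile and algorithm names are illustrative and not taken from the deleted code:

# Sketch: index with a vector field plus the HNSW configuration that newer
# azure-search-documents versions (>= 11.4) expect. Names are illustrative.
from azure.core.credentials import AzureKeyCredential
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
    HnswAlgorithmConfiguration,
    SearchableField,
    SearchField,
    SearchFieldDataType,
    SearchIndex,
    SimpleField,
    VectorSearch,
    VectorSearchProfile,
)


def build_index(endpoint: str, api_key: str, index_name: str, dim: int = 1536) -> None:
    client = SearchIndexClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))
    fields = [
        SimpleField(name="id", type=SearchFieldDataType.String, key=True),
        SearchableField(name="content", type=SearchFieldDataType.String),
        SearchField(
            name="content_vector",
            type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
            searchable=True,
            vector_search_dimensions=dim,
            vector_search_profile_name="default-profile",  # must match a profile below
        ),
    ]
    vector_search = VectorSearch(
        algorithms=[HnswAlgorithmConfiguration(name="default-hnsw")],
        profiles=[
            VectorSearchProfile(name="default-profile", algorithm_configuration_name="default-hnsw")
        ],
    )
    client.create_or_update_index(SearchIndex(name=index_name, fields=fields, vector_search=vector_search))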
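Likewise, the removed search_data sends its vectorQueries body through the private search_client._client.documents.search_post call. With a current SDK, the same k-nearest-neighbor query can go through the public client; a sketch under the same >= 11.4 assumption, reusing the id/content/content_vector field names from the deleted schema:

# Sketch: vector search through the public SearchClient API instead of the
# private REST call used in the deleted file. Assumes azure-search-documents >= 11.4.
from typing import List

from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient
from azure.search.documents.models import VectorizedQuery


def vector_search(endpoint: str, api_key: str, index_name: str,
                  vector: List[float], vector_field: str = "content_vector",
                  top_k: int = 50):
    client = SearchClient(endpoint=endpoint, index_name=index_name,
                          credential=AzureKeyCredential(api_key))
    results = client.search(
        search_text=None,  # pure vector query, no keyword component
        vector_queries=[
            VectorizedQuery(vector=vector, k_nearest_neighbors=top_k, fields=vector_field)
        ],
        select=["id", "content"],
        top=top_k,
    )
    # Each result behaves like a dict; "@search.score" carries the similarity score.
    return [(r["id"], r["content"], r["@search.score"]) for r in results]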
@@ -1,247 +0,0 @@
import hashlib
from typing import List

import numpy as np

from deepsearcher.loader.splitter import Chunk
from deepsearcher.vector_db.base import BaseVectorDB, RetrievalResult


def calculate_text_hash(text: str) -> str:
    """
    Compute a hash of the text, used as the primary key.

    Args:
        text (str): Input text.

    Returns:
        str: MD5 hash of the text.
    """
    return hashlib.md5(text.encode('utf-8')).hexdigest()


class DeduplicatedVectorDB:
    """
    Vector database wrapper with deduplication support.

    On top of an existing vector database, this class adds:
    1. Using the text hash as the primary key.
    2. Avoiding insertion of duplicate data.
    3. A method for cleaning up duplicate data.
    """

    def __init__(self, db: BaseVectorDB):
        """
        Initialize the deduplicating vector database.

        Args:
            db (BaseVectorDB): Underlying vector database instance.
        """
        self.db = db

    def init_collection(self, dim: int, collection: str, description: str,
                        force_new_collection=False, *args, **kwargs):
        """
        Initialize a collection.

        Args:
            dim (int): Vector dimensionality.
            collection (str): Collection name.
            description (str): Collection description.
            force_new_collection (bool): Whether to force creation of a new collection.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.
        """
        return self.db.init_collection(dim, collection, description,
                                       force_new_collection, *args, **kwargs)

    def insert_data(self, collection: str, chunks: List[Chunk],
                    batch_size: int = 256, *args, **kwargs):
        """
        Insert data while avoiding duplicates.

        This method first checks whether a record with the same text hash
        already exists in the database and only inserts chunks that do not.

        Args:
            collection (str): Collection name.
            chunks (List[Chunk]): Chunks to insert.
            batch_size (int): Batch size.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.
        """
        # Compute a hash for each chunk and store it in the metadata.
        for chunk in chunks:
            if 'hash' not in chunk.metadata:
                chunk.metadata['hash'] = calculate_text_hash(chunk.text)

        # Filter out chunks that already exist.
        filtered_chunks = self._filter_duplicate_chunks(collection, chunks)

        # Insert the filtered chunks.
        if filtered_chunks:
            return self.db.insert_data(collection, filtered_chunks, batch_size, *args, **kwargs)

    def _filter_duplicate_chunks(self, collection: str, chunks: List[Chunk]) -> List[Chunk]:
        """
        Filter out chunks that already exist in the database.

        Args:
            collection (str): Collection name.
            chunks (List[Chunk]): Chunks to filter.

        Returns:
            List[Chunk]: Filtered chunks.
        """
        # Note: this depends on the concrete database backend.
        # In production, a backend-specific query should be used to check for duplicates.
        # This is a generic but inefficient placeholder.
        return chunks

    def search_data(self, collection: str, vector: np.array, *args, **kwargs) -> List[RetrievalResult]:
        """
        Search data.

        Args:
            collection (str): Collection name.
            vector (np.array): Query vector.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            List[RetrievalResult]: Search results.
        """
        return self.db.search_data(collection, vector, *args, **kwargs)

    def list_collections(self, *args, **kwargs):
        """
        List all collections.

        Args:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            A list of collection info objects.
        """
        return self.db.list_collections(*args, **kwargs)

    def clear_db(self, collection: str = None, *args, **kwargs):
        """
        Clear a database collection.

        Args:
            collection (str): Collection name.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.
        """
        return self.db.clear_db(collection, *args, **kwargs)

    def remove_duplicate_data(self, collection: str = None) -> int:
        """
        Remove duplicate records from the database.

        Note: this is a generic implementation; a concrete backend may offer
        a more efficient approach.

        Args:
            collection (str): Collection name.

        Returns:
            int: Number of duplicate records removed.
        """
        # The concrete implementation depends on the database type;
        # only a conceptual placeholder is provided here.
        raise NotImplementedError("Deduplication must be implemented for the specific database backend")


# Deduplication helper specialized for Qdrant
class QdrantDeduplicatedVectorDB(DeduplicatedVectorDB):
    """
    Qdrant-specific deduplicating vector database implementation.
    """

    def _filter_duplicate_chunks(self, collection: str, chunks: List[Chunk]) -> List[Chunk]:
        """
        Filter out chunks that already exist in the Qdrant database.

        Args:
            collection (str): Collection name.
            chunks (List[Chunk]): Chunks to filter.

        Returns:
            List[Chunk]: Filtered chunks.
        """
        try:
            # Get the Qdrant client.
            qdrant_client = self.db.client

            # Collect the hashes to check.
            hashes = [calculate_text_hash(chunk.text) for chunk in chunks]

            # Query for existing records.
            # Note: this requires filtering by payload in Qdrant; the exact implementation may need adjustment.
            filtered_chunks = []
            for chunk, hash_value in zip(chunks, hashes):
                # The Qdrant query API should be used here to check whether the hash already exists.
                # Given the API's constraints, this may require the scroll or search endpoints.
                filtered_chunks.append(chunk)

            return filtered_chunks
        except Exception:
            # On error, return all chunks (no filtering).
            return chunks

    def remove_duplicate_data(self, collection: str = None) -> int:
        """
        Remove duplicate records from the Qdrant database.

        Args:
            collection (str): Collection name.

        Returns:
            int: Number of duplicate records removed.
        """
        # Qdrant deduplication: retrieve all records, identify duplicates, then delete them.
        # Implementation omitted.
        return 0


# Deduplication helper specialized for Milvus
class MilvusDeduplicatedVectorDB(DeduplicatedVectorDB):
    """
    Milvus-specific deduplicating vector database implementation.
    """

    def insert_data(self, collection: str, chunks: List[Chunk],
                    batch_size: int = 256, *args, **kwargs):
        """
        Insert data into Milvus, using the text hash as the primary key.

        Args:
            collection (str): Collection name.
            chunks (List[Chunk]): Chunks to insert.
            batch_size (int): Batch size.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.
        """
        # Set each chunk's ID to its text hash.
        for chunk in chunks:
            chunk.metadata['id'] = calculate_text_hash(chunk.text)

        # Delegate to the parent class for insertion.
        return super().insert_data(collection, chunks, batch_size, *args, **kwargs)

    def remove_duplicate_data(self, collection: str = None) -> int:
        """
        Remove duplicate records from the Milvus database.

        Args:
            collection (str): Collection name.

        Returns:
            int: Number of duplicate records removed.
        """
        # Milvus deduplication: query all data, find records with duplicate IDs, and delete them.
        # Implementation omitted.
        return 0
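The removed Qdrant _filter_duplicate_chunks stores a hash per chunk but never queries the server, so every chunk passes through unchanged. One way the lookup could work is a payload-filtered scroll; this is a sketch assuming the hashes live in a "hash" payload field and that self.db.client is a qdrant_client.QdrantClient:

# Sketch: find which text hashes already exist in a Qdrant collection by
# scrolling with a payload filter. Assumes hashes are stored in a "hash" payload field.
from typing import List, Set

from qdrant_client import QdrantClient
from qdrant_client.http import models as qmodels


def existing_hashes(client: QdrantClient, collection: str, hashes: List[str]) -> Set[str]:
    found: Set[str] = set()
    offset = None
    while True:
        points, offset = client.scroll(
            collection_name=collection,
            scroll_filter=qmodels.Filter(
                must=[qmodels.FieldCondition(key="hash", match=qmodels.MatchAny(any=hashes))]
            ),
            with_payload=["hash"],
            limit=256,
            offset=offset,
        )
        found.update(p.payload["hash"] for p in points if p.payload)
        if offset is None:
            break
    return found

# A chunk would then be kept for insertion only if its hash is not in existing_hashes(...).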
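Similarly, MilvusDeduplicatedVectorDB only copies the hash into chunk.metadata['id'] and leaves remove_duplicate_data as a stub. An alternative is to let Milvus enforce uniqueness by making the MD5 hash the primary key and upserting, so re-ingesting the same text overwrites the existing row. A sketch assuming pymilvus >= 2.4 and a collection whose schema has id/text/vector fields (names are illustrative):

# Sketch: hash-keyed upsert so Milvus itself deduplicates repeated text.
# Assumes a collection with a VARCHAR primary key "id" plus "text" and "vector" fields.
import hashlib
from typing import List

from pymilvus import MilvusClient


def upsert_chunks(client: MilvusClient, collection: str,
                  texts: List[str], vectors: List[List[float]]) -> None:
    rows = [
        {"id": hashlib.md5(t.encode("utf-8")).hexdigest(), "text": t, "vector": v}
        for t, v in zip(texts, vectors)
    ]
    client.upsert(collection_name=collection, data=rows)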