
Complete the database deduplication and rebuild logic

Branch: main
Committed by tanxing, 2 weeks ago
Commit: 4863b5d620
  1. deepsearcher/agent/deep_search.py (9 lines changed)
  2. deepsearcher/config.yaml (4 lines changed)
  3. deepsearcher/loader/file_loader/base.py (8 lines changed)
  4. deepsearcher/offline_loading.py (26 lines changed)
  5. deepsearcher/vector_db/__init__.py (3 lines changed)
  6. deepsearcher/vector_db/azure_search.py (279 lines; file deleted)
  7. deepsearcher/vector_db/dedup_util.py (247 lines; file deleted)
  8. deepsearcher/vector_db/milvus.py (10 lines changed)
  9. test.py (5 lines changed)

deepsearcher/agent/deep_search.py (9 lines changed)

@@ -9,8 +9,11 @@ from deepsearcher.utils import log
from deepsearcher.vector_db import RetrievalResult
from deepsearcher.vector_db.base import BaseVectorDB, deduplicate_results

-SUB_QUERY_PROMPT = """To answer this question more comprehensively, please break down the original question into up to four sub-questions. Return as list of str.
-If this is a very simple question and no decomposition is necessary, then keep the only one original question in the python code list.
+SUB_QUERY_PROMPT = """
+To answer this question more comprehensively, please break down the original question into a few sub-questions (more if necessary).
+If this is a very simple question and no decomposition is necessary, then keep only the original question.
+Make sure each sub-question is clear, concise, and atomic.
+Return as a list of str in Python style that is JSON convertible.

Original Question: {original_query}
@@ -54,7 +57,7 @@ Related Chunks:
Respond exclusively in valid List of str format without any other text."""

-SUMMARY_PROMPT = """You are a AI content analysis expert, good at summarizing content. Please summarize a specific and detailed answer or report based on the previous queries and the retrieved document chunks.
+SUMMARY_PROMPT = """You are an AI content analysis expert, good at summarizing content. Please summarize a long, specific, and detailed answer or report based on the previous queries and the retrieved document chunks.

Original Query: {question}
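Note (illustration, not part of the commit): the updated SUB_QUERY_PROMPT asks for output that is both a Python-style list of strings and JSON convertible, so a downstream parse along these lines must succeed. The response text below is hypothetical.

    import json

    # Hypothetical model response to SUB_QUERY_PROMPT
    response = '["What is a chunk in DeepSearcher?", "How is each chunk identified?", "How are duplicates handled at insert time?"]'
    sub_queries = json.loads(response)  # "JSON convertible" means this parse succeeds
    assert all(isinstance(q, str) for q in sub_queries)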

deepsearcher/config.yaml (4 lines changed)

@@ -23,8 +23,8 @@ provide_settings:
# config:
# text_key: ""
-# provider: "TextLoader"
-# config: {}
+provider: "TextLoader"
+config: {}
# provider: "UnstructuredLoader"
# config: {}

deepsearcher/loader/file_loader/base.py (8 lines changed)

@@ -37,7 +37,7 @@ class BaseLoader(ABC):
        In the metadata, it's recommended to include the reference to the file.
        e.g. return [Document(page_content=..., metadata={"reference": file_path})]
        """
-        pass
+        return []

    def load_directory(self, directory: str) -> List[Document]:
        """
@@ -55,7 +55,9 @@ class BaseLoader(ABC):
                for suffix in self.supported_file_types:
                    if file.endswith(suffix):
                        full_path = os.path.join(root, file)
-                        documents.extend(self.load_file(full_path))
+                        loaded_docs = self.load_file(full_path)
+                        if loaded_docs is not None:
+                            documents.extend(loaded_docs)
                        break
        return documents
@@ -67,4 +69,4 @@ class BaseLoader(ABC):
        Returns:
            A list of supported file extensions (without the dot).
        """
-        pass
+        return []

deepsearcher/offline_loading.py (26 lines changed)

@@ -1,4 +1,5 @@
import os
+import hashlib
from typing import List, Union

from tqdm import tqdm
@@ -16,6 +17,7 @@ def load_from_local_files(
    chunk_size: int = 1500,
    chunk_overlap: int = 100,
    batch_size: int = 256,
+    force_rebuild: bool = False,
):
    """
    Load knowledge from local files or directories into the vector database.
@@ -31,6 +33,7 @@
        chunk_size: Size of each chunk in characters.
        chunk_overlap: Number of characters to overlap between chunks.
        batch_size: Number of chunks to process at once during embedding.
+        force_rebuild: If True, clears the existing collection and ensures no duplicates are inserted.

    Raises:
        FileNotFoundError: If any of the specified paths do not exist.
@@ -47,6 +50,16 @@
        description=collection_description,
        force_new_collection=force_new_collection,
    )
+    # If force_rebuild is True, rebuild the collection from scratch
+    if force_rebuild:
+        vector_db.init_collection(
+            dim=embedding_model.dimension,
+            collection=collection_name,
+            description=collection_description,
+            force_new_collection=True,
+        )
+
    if isinstance(paths_or_directory, str):
        paths_or_directory = [paths_or_directory]
    all_docs = []
@@ -65,8 +78,17 @@
        chunk_overlap=chunk_overlap,
    )
-    chunks = embedding_model.embed_chunks(chunks, batch_size=batch_size)
-    vector_db.insert_data(collection=collection_name, chunks=chunks)
+    # Compute a SHA-256 hash of each chunk's text to use as its primary key and catch duplicates
+    unique_chunks = []
+    for chunk in chunks:
+        # Hash the chunk text
+        sha256_hash = hashlib.sha256(chunk.text.encode('utf-8')).hexdigest()
+        # Store the hash in the chunk's metadata as its id
+        chunk.metadata['id'] = sha256_hash
+        unique_chunks.append(chunk)
+    unique_chunks = embedding_model.embed_chunks(unique_chunks, batch_size=batch_size)
+    vector_db.insert_data(collection=collection_name, chunks=unique_chunks)
def load_from_website(
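Note (sketch, not part of the commit): the new ingestion path keys every chunk by sha256(chunk.text), so identical content always maps to the same primary key across runs. The loop above tags each chunk but does not drop duplicates inside a single batch; a small helper along the following lines could do that client-side, while cross-run duplicates are left to the database's handling of the primary key. Chunk objects with .text and .metadata attributes are assumed.

    import hashlib

    def dedup_by_content(chunks):
        # Keep the first occurrence of each distinct chunk text within a batch;
        # the SHA-256 hex digest doubles as the stable primary key.
        seen, unique = set(), []
        for chunk in chunks:
            key = hashlib.sha256(chunk.text.encode("utf-8")).hexdigest()
            if key in seen:
                continue
            seen.add(key)
            chunk.metadata["id"] = key
            unique.append(chunk)
        return unique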

deepsearcher/vector_db/__init__.py (3 lines changed)

@@ -1,6 +1,5 @@
-from .azure_search import AzureSearch
from .milvus import Milvus, RetrievalResult
from .oracle import OracleDB
from .qdrant import Qdrant

-__all__ = ["Milvus", "RetrievalResult", "OracleDB", "Qdrant", "AzureSearch"]
+__all__ = ["Milvus", "RetrievalResult", "OracleDB", "Qdrant"]

deepsearcher/vector_db/azure_search.py (279 lines; file deleted)

@@ -1,279 +0,0 @@
import uuid
from typing import Any, Dict, List, Optional

from deepsearcher.vector_db.base import BaseVectorDB, CollectionInfo, RetrievalResult


class AzureSearch(BaseVectorDB):
    def __init__(self, endpoint, index_name, api_key, vector_field):
        super().__init__(default_collection=index_name)
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents import SearchClient

        self.client = SearchClient(
            endpoint=endpoint,
            index_name=index_name,
            credential=AzureKeyCredential(api_key),
        )
        self.vector_field = vector_field
        self.endpoint = endpoint
        self.index_name = index_name
        self.api_key = api_key

    def init_collection(self):
        """Initialize Azure Search index with proper schema"""
        from azure.core.credentials import AzureKeyCredential
        from azure.core.exceptions import ResourceNotFoundError
        from azure.search.documents.indexes import SearchIndexClient
        from azure.search.documents.indexes.models import (
            SearchableField,
            SearchField,
            SearchIndex,
            SimpleField,
        )

        index_client = SearchIndexClient(
            endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key)
        )

        # Create the index (simplified for compatibility with older SDK versions)
        fields = [
            SimpleField(name="id", type="Edm.String", key=True),
            SearchableField(name="content", type="Edm.String"),
            SearchField(
                name="content_vector",
                type="Collection(Edm.Single)",
                searchable=True,
                vector_search_dimensions=1536,
            ),
        ]

        # Create index with fields
        index = SearchIndex(name=self.index_name, fields=fields)
        try:
            # Try to delete existing index
            try:
                index_client.delete_index(self.index_name)
            except ResourceNotFoundError:
                pass
            # Create the index
            index_client.create_index(index)
        except Exception as e:
            print(f"Error creating index: {str(e)}")

    def insert_data(self, documents: List[dict]):
        """Batch insert documents with vector embeddings"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents import SearchClient

        search_client = SearchClient(
            endpoint=self.endpoint,
            index_name=self.index_name,
            credential=AzureKeyCredential(self.api_key),
        )
        actions = [
            {
                "@search.action": "upload" if doc.get("id") else "merge",
                "id": doc.get("id", str(uuid.uuid4())),
                "content": doc["text"],
                "content_vector": doc["vector"],
            }
            for doc in documents
        ]
        result = search_client.upload_documents(actions)
        return [x.succeeded for x in result]

    def search_data(
        self, collection: Optional[str], vector: List[float], top_k: int = 50
    ) -> List[RetrievalResult]:
        """Azure Cognitive Search implementation with compatibility for older SDK versions"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents import SearchClient

        search_client = SearchClient(
            endpoint=self.endpoint,
            index_name=collection or self.index_name,
            credential=AzureKeyCredential(self.api_key),
        )

        # Validate that vector is not empty
        if not vector or len(vector) == 0:
            print("Error: Empty vector provided for search. Vector must have 1536 dimensions.")
            return []

        # Debug vector and field info
        print(f"Vector length for search: {len(vector)}")
        print(f"Vector field name: {self.vector_field}")

        # Ensure vector has the right dimensions
        if len(vector) != 1536:
            print(f"Warning: Vector length {len(vector)} does not match expected 1536 dimensions")
            return []

        # Execute search with direct parameters - simpler approach
        try:
            print(f"Executing search with top_k={top_k}")
            # Directly use the search_by_vector method for compatibility
            body = {
                "search": "*",
                "select": "id,content",
                "top": top_k,
                "vectorQueries": [
                    {
                        "vector": vector,
                        "fields": self.vector_field,
                        "k": top_k,
                        "kind": "vector",
                    }
                ],
            }
            # Print the search request body for debugging
            print(f"Search request body: {body}")
            # Use the REST API directly
            result = search_client._client.documents.search_post(
                search_request=body, headers={"api-key": self.api_key}
            )

            # Format results
            search_results = []
            if hasattr(result, "results"):
                for doc in result.results:
                    try:
                        doc_dict = doc.as_dict() if hasattr(doc, "as_dict") else doc
                        content = doc_dict.get("content", "")
                        doc_id = doc_dict.get("id", "")
                        score = doc_dict.get("@search.score", 0.0)
                        result = RetrievalResult(
                            embedding=[],  # We don't get the vectors back
                            text=content,
                            reference=doc_id,
                            metadata={"source": doc_id},
                            score=score,
                        )
                        search_results.append(result)
                    except Exception as e:
                        print(f"Error processing result: {str(e)}")
            return search_results
        except Exception as e:
            print(f"Search error: {str(e)}")
            # Try another approach if the first one fails
            try:
                print("Trying alternative search method...")
                results = search_client.search(search_text="*", select=["id", "content"], top=top_k)

                # Process results
                alt_results = []
                for doc in results:
                    try:
                        # Handle different result formats
                        if isinstance(doc, dict):
                            content = doc.get("content", "")
                            doc_id = doc.get("id", "")
                            score = doc.get("@search.score", 0.0)
                        else:
                            content = getattr(doc, "content", "")
                            doc_id = getattr(doc, "id", "")
                            score = getattr(doc, "@search.score", 0.0)
                        result = RetrievalResult(
                            embedding=[],
                            text=content,
                            reference=doc_id,
                            metadata={"source": doc_id},
                            score=score,
                        )
                        alt_results.append(result)
                    except Exception as e:
                        print(f"Error processing result: {str(e)}")
                return alt_results
            except Exception as e:
                print(f"Alternative search failed: {str(e)}")
                return []

    def clear_db(self):
        """Delete all documents in the index"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents import SearchClient

        search_client = SearchClient(
            endpoint=self.endpoint,
            index_name=self.index_name,
            credential=AzureKeyCredential(self.api_key),
        )
        docs = search_client.search(search_text="*", include_total_count=True, select=["id"])
        ids = [doc["id"] for doc in docs]
        if ids:
            search_client.delete_documents([{"id": id} for id in ids])
        return len(ids)

    def get_all_collections(self) -> List[str]:
        """List all search indices in Azure Cognitive Search"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents.indexes import SearchIndexClient

        try:
            index_client = SearchIndexClient(
                endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key)
            )
            return [index.name for index in index_client.list_indexes()]
        except Exception as e:
            print(f"Failed to list indices: {str(e)}")
            return []

    def get_collection_info(self, name: str) -> Dict[str, Any]:
        """Retrieve index metadata"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents.indexes import SearchIndexClient

        index_client = SearchIndexClient(
            endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key)
        )
        return index_client.get_index(name).__dict__

    def collection_exists(self, name: str) -> bool:
        """Check index existence"""
        from azure.core.exceptions import ResourceNotFoundError

        try:
            self.get_collection_info(name)
            return True
        except ResourceNotFoundError:
            return False

    def list_collections(self, *args, **kwargs) -> List[CollectionInfo]:
        """List all Azure Search indices with metadata"""
        from azure.core.credentials import AzureKeyCredential
        from azure.search.documents.indexes import SearchIndexClient

        try:
            index_client = SearchIndexClient(
                endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key)
            )
            collections = []
            for index in index_client.list_indexes():
                collections.append(
                    CollectionInfo(
                        collection_name=index.name,
                        description=f"Azure Search Index with {len(index.fields) if hasattr(index, 'fields') else 0} fields",
                    )
                )
            return collections
        except Exception as e:
            print(f"Collection listing failed: {str(e)}")
            return []

deepsearcher/vector_db/dedup_util.py (247 lines; file deleted)

@@ -1,247 +0,0 @@
import hashlib
from typing import List

import numpy as np

from deepsearcher.loader.splitter import Chunk
from deepsearcher.vector_db.base import BaseVectorDB, RetrievalResult


def calculate_text_hash(text: str) -> str:
    """
    Compute a hash of the text to use as a primary key.

    Args:
        text (str): The input text.

    Returns:
        str: The MD5 hash of the text.
    """
    return hashlib.md5(text.encode('utf-8')).hexdigest()


class DeduplicatedVectorDB:
    """
    A deduplicating wrapper around a vector database.

    On top of an existing vector database, this class adds:
    1. Using the text hash as the primary key.
    2. Avoiding insertion of duplicate data.
    3. A method for cleaning up duplicate records.
    """

    def __init__(self, db: BaseVectorDB):
        """
        Initialize the deduplicating vector database.

        Args:
            db (BaseVectorDB): The underlying vector database instance.
        """
        self.db = db

    def init_collection(self, dim: int, collection: str, description: str,
                        force_new_collection=False, *args, **kwargs):
        """
        Initialize a collection.

        Args:
            dim (int): Vector dimension.
            collection (str): Collection name.
            description (str): Collection description.
            force_new_collection (bool): Whether to force creation of a new collection.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.
        """
        return self.db.init_collection(dim, collection, description,
                                       force_new_collection, *args, **kwargs)

    def insert_data(self, collection: str, chunks: List[Chunk],
                    batch_size: int = 256, *args, **kwargs):
        """
        Insert data while avoiding duplicates.

        This method first checks whether the hash of the same text already exists
        in the database and only inserts a chunk if it does not.

        Args:
            collection (str): Collection name.
            chunks (List[Chunk]): Chunks to insert.
            batch_size (int): Batch size.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.
        """
        # Compute a hash for each chunk and store it in its metadata
        for chunk in chunks:
            if 'hash' not in chunk.metadata:
                chunk.metadata['hash'] = calculate_text_hash(chunk.text)
        # Filter out chunks that already exist
        filtered_chunks = self._filter_duplicate_chunks(collection, chunks)
        # Insert the filtered chunks
        if filtered_chunks:
            return self.db.insert_data(collection, filtered_chunks, batch_size, *args, **kwargs)

    def _filter_duplicate_chunks(self, collection: str, chunks: List[Chunk]) -> List[Chunk]:
        """
        Filter out chunks that already exist in the database.

        Args:
            collection (str): Collection name.
            chunks (List[Chunk]): Chunks to filter.

        Returns:
            List[Chunk]: The filtered chunks.
        """
        # Note: this implementation depends on the concrete database backend.
        # In production, a database-specific query should be used to check for duplicates.
        # This is a generic but inefficient placeholder.
        return chunks

    def search_data(self, collection: str, vector: np.array, *args, **kwargs) -> List[RetrievalResult]:
        """
        Search data.

        Args:
            collection (str): Collection name.
            vector (np.array): Query vector.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            List[RetrievalResult]: Search results.
        """
        return self.db.search_data(collection, vector, *args, **kwargs)

    def list_collections(self, *args, **kwargs):
        """
        List all collections.

        Args:
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            A list of collection info objects.
        """
        return self.db.list_collections(*args, **kwargs)

    def clear_db(self, collection: str = None, *args, **kwargs):
        """
        Clear a database collection.

        Args:
            collection (str): Collection name.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.
        """
        return self.db.clear_db(collection, *args, **kwargs)

    def remove_duplicate_data(self, collection: str = None) -> int:
        """
        Remove duplicate records from the database.

        Note: this is a generic implementation; a concrete database may offer
        a more efficient approach.

        Args:
            collection (str): Collection name.

        Returns:
            int: Number of duplicate records removed.
        """
        # The concrete implementation must be tailored to the database backend;
        # only a conceptual placeholder is provided here.
        raise NotImplementedError("Deduplication must be implemented for the specific database backend")


# Dedicated deduplication helper for Qdrant
class QdrantDeduplicatedVectorDB(DeduplicatedVectorDB):
    """
    Qdrant-specific deduplicating vector database implementation.
    """

    def _filter_duplicate_chunks(self, collection: str, chunks: List[Chunk]) -> List[Chunk]:
        """
        Filter out chunks that already exist in the Qdrant database.

        Args:
            collection (str): Collection name.
            chunks (List[Chunk]): Chunks to filter.

        Returns:
            List[Chunk]: The filtered chunks.
        """
        try:
            # Get the Qdrant client
            qdrant_client = self.db.client
            # Collect the hashes to check
            hashes = [calculate_text_hash(chunk.text) for chunk in chunks]
            # Query for existing records
            # Note: this requires Qdrant to support filtering by payload; the implementation may need adjustment
            filtered_chunks = []
            for chunk, hash_value in zip(chunks, hashes):
                # The Qdrant query API should be used here to check whether the hash already exists;
                # given the API constraints, this may need the scroll or search functionality.
                filtered_chunks.append(chunk)
            return filtered_chunks
        except Exception:
            # On error, return all chunks (no filtering)
            return chunks

    def remove_duplicate_data(self, collection: str = None) -> int:
        """
        Remove duplicate records from the Qdrant database.

        Args:
            collection (str): Collection name.

        Returns:
            int: Number of duplicate records removed.
        """
        # Qdrant deduplication:
        # retrieve all data, identify duplicates, then delete them.
        # Implementation omitted.
        return 0


# Dedicated deduplication helper for Milvus
class MilvusDeduplicatedVectorDB(DeduplicatedVectorDB):
    """
    Milvus-specific deduplicating vector database implementation.
    """

    def insert_data(self, collection: str, chunks: List[Chunk],
                    batch_size: int = 256, *args, **kwargs):
        """
        Insert data into Milvus using the text hash as the primary key.

        Args:
            collection (str): Collection name.
            chunks (List[Chunk]): Chunks to insert.
            batch_size (int): Batch size.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.
        """
        # Set each chunk's id to the hash of its text
        for chunk in chunks:
            chunk.metadata['id'] = calculate_text_hash(chunk.text)
        # Delegate the insert to the parent class
        return super().insert_data(collection, chunks, batch_size, *args, **kwargs)

    def remove_duplicate_data(self, collection: str = None) -> int:
        """
        Remove duplicate records from the Milvus database.

        Args:
            collection (str): Collection name.

        Returns:
            int: Number of duplicate records removed.
        """
        # Milvus deduplication:
        # query all data, find records with duplicate IDs, and delete them.
        return 0

deepsearcher/vector_db/milvus.py (10 lines changed)

@@ -85,9 +85,9 @@ class Milvus(BaseVectorDB):
        elif has_collection:
            return
        schema = self.client.create_schema(
-            enable_dynamic_field=False, auto_id=True, description=description
+            enable_dynamic_field=False, auto_id=False, description=description
        )
-        schema.add_field("id", DataType.INT64, is_primary=True)
+        schema.add_field("id", DataType.VARCHAR, is_primary=True, max_length=64)
        schema.add_field("embedding", DataType.FLOAT_VECTOR, dim=dim)
        if self.hybrid:
@@ -156,6 +156,7 @@ class Milvus(BaseVectorDB):
        """
        if not collection:
            collection = self.default_collection
+        ids = [chunk.metadata.get('id', '') for chunk in chunks]
        texts = [chunk.text for chunk in chunks]
        references = [chunk.reference for chunk in chunks]
        metadatas = [chunk.metadata for chunk in chunks]
@@ -163,13 +164,14 @@ class Milvus(BaseVectorDB):
        datas = [
            {
+                "id": id,
                "embedding": embedding,
                "text": text,
                "reference": reference,
                "metadata": metadata,
            }
-            for embedding, text, reference, metadata in zip(
-                embeddings, texts, references, metadatas
+            for id, embedding, text, reference, metadata in zip(
+                ids, embeddings, texts, references, metadatas
            )
        ]
        batch_datas = [datas[i : i + batch_size] for i in range(0, len(datas), batch_size)]
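Note (sketch, not part of the commit): with auto_id=False and a deterministic VARCHAR key, re-ingesting the same text can be made idempotent by upserting on that key; inserting two rows with the same id is generally not rejected by Milvus, so an upsert (or client-side filtering before insert) is what actually prevents duplicate rows. A minimal pymilvus sketch, assuming a local Milvus Lite file, an existing collection named "deepsearcher" with the schema above, and a 1536-dimensional embedding:

    import hashlib
    from pymilvus import MilvusClient

    text = "example chunk"
    pk = hashlib.sha256(text.encode("utf-8")).hexdigest()  # same text -> same key

    client = MilvusClient(uri="./milvus_demo.db")  # assumed Milvus Lite database
    client.upsert(
        collection_name="deepsearcher",  # assumed collection name
        data=[{
            "id": pk,
            "embedding": [0.0] * 1536,   # placeholder vector; dimension assumed
            "text": text,
            "reference": "examples/data",
            "metadata": {},
        }],
    )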

test.py (5 lines changed)

@@ -1,6 +1,5 @@
from deepsearcher.configuration import Configuration, init_config
-# from deepsearcher.offline_loading import load_from_local_files
+from deepsearcher.offline_loading import load_from_local_files
from deepsearcher.online_query import query

config = Configuration()
@@ -12,7 +11,7 @@ config.load_config_from_yaml("deepsearcher/config.yaml")
init_config(config = config)

# Load your local data
-# load_from_local_files(paths_or_directory="examples/data")
+load_from_local_files(paths_or_directory="examples/data", force_rebuild=True)

# (Optional) Load from web crawling (`FIRECRAWL_API_KEY` env variable required)
# from deepsearcher.offline_loading import load_from_website
