
Implement the database deduplication and rebuild logic

main
tanxing, 2 weeks ago
parent commit 4863b5d620
  1. deepsearcher/agent/deep_search.py (9 lines changed)
  2. deepsearcher/config.yaml (4 lines changed)
  3. deepsearcher/loader/file_loader/base.py (8 lines changed)
  4. deepsearcher/offline_loading.py (26 lines changed)
  5. deepsearcher/vector_db/__init__.py (3 lines changed)
  6. deepsearcher/vector_db/azure_search.py (279 lines removed; file deleted)
  7. deepsearcher/vector_db/dedup_util.py (247 lines removed; file deleted)
  8. deepsearcher/vector_db/milvus.py (10 lines changed)
  9. test.py (5 lines changed)

deepsearcher/agent/deep_search.py (9 lines changed)

@@ -9,8 +9,11 @@ from deepsearcher.utils import log
 from deepsearcher.vector_db import RetrievalResult
 from deepsearcher.vector_db.base import BaseVectorDB, deduplicate_results
-SUB_QUERY_PROMPT = """To answer this question more comprehensively, please break down the original question into up to four sub-questions. Return as list of str.
-If this is a very simple question and no decomposition is necessary, then keep the only one original question in the python code list.
+SUB_QUERY_PROMPT = """
+To answer this question more comprehensively, please break down the original question into a few sub-questions (more if necessary).
+If this is a very simple question and no decomposition is necessary, then keep only the original question.
+Make sure each sub-question is clear, concise and atomic.
+Return as a list of str in Python style that is JSON convertible.
 Original Question: {original_query}
@@ -54,7 +57,7 @@ Related Chunks:
 Respond exclusively in valid List of str format without any other text."""
-SUMMARY_PROMPT = """You are an AI content analysis expert, good at summarizing content. Please summarize a specific and detailed answer or report based on the previous queries and the retrieved document chunks.
+SUMMARY_PROMPT = """You are an AI content analysis expert, good at summarizing content. Please summarize a long, specific and detailed answer or report based on the previous queries and the retrieved document chunks.
 Original Query: {question}
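
The revised SUB_QUERY_PROMPT asks the model for a Python-style, JSON-convertible list of strings. Below is a minimal sketch of how such a response could be parsed defensively; the helper name parse_sub_queries is hypothetical and not part of this commit.

import ast
import json

def parse_sub_queries(raw_response: str) -> list:
    """Parse an LLM response expected to be a list of strings, e.g. '["q1", "q2"]'."""
    try:
        parsed = json.loads(raw_response)        # JSON-style list
    except json.JSONDecodeError:
        parsed = ast.literal_eval(raw_response)  # Python-style list (single quotes etc.)
    if isinstance(parsed, list) and all(isinstance(q, str) for q in parsed):
        return parsed
    raise ValueError("Expected a list of strings")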

deepsearcher/config.yaml (4 lines changed)

@@ -23,8 +23,8 @@ provide_settings:
 # config:
 #   text_key: ""
-# provider: "TextLoader"
-# config: {}
+provider: "TextLoader"
+config: {}
 # provider: "UnstructuredLoader"
 # config: {}

deepsearcher/loader/file_loader/base.py (8 lines changed)

@@ -37,7 +37,7 @@ class BaseLoader(ABC):
         In the metadata, it's recommended to include the reference to the file.
         e.g. return [Document(page_content=..., metadata={"reference": file_path})]
         """
-        pass
+        return []
     def load_directory(self, directory: str) -> List[Document]:
         """
@@ -55,7 +55,9 @@ class BaseLoader(ABC):
                 for suffix in self.supported_file_types:
                     if file.endswith(suffix):
                         full_path = os.path.join(root, file)
-                        documents.extend(self.load_file(full_path))
+                        loaded_docs = self.load_file(full_path)
+                        if loaded_docs is not None:
+                            documents.extend(loaded_docs)
                         break
         return documents
@@ -67,4 +69,4 @@ class BaseLoader(ABC):
         Returns:
             A list of supported file extensions (without the dot).
         """
-        pass
+        return []
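
With load_file and supported_file_types now defaulting to empty lists and load_directory guarding against None, a concrete loader only has to implement those two members. A minimal sketch of a hypothetical plain-text loader under those assumptions follows; it assumes Document is the langchain-style class referenced in the docstring above and that supported_file_types is exposed as a property, so the exact import paths may differ from the actual repo.

import os
from typing import List

from langchain_core.documents import Document  # assumed import path
from deepsearcher.loader.file_loader.base import BaseLoader

class PlainTextLoader(BaseLoader):
    """Hypothetical loader for .txt files; illustration only, not part of this commit."""

    def load_file(self, file_path: str) -> List[Document]:
        if not os.path.isfile(file_path):
            return []  # the updated base class treats an empty list as "nothing loaded"
        with open(file_path, "r", encoding="utf-8") as f:
            return [Document(page_content=f.read(), metadata={"reference": file_path})]

    @property
    def supported_file_types(self) -> List[str]:
        return ["txt"]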

deepsearcher/offline_loading.py (26 lines changed)

@@ -1,4 +1,5 @@
 import os
+import hashlib
 from typing import List, Union
 from tqdm import tqdm
@@ -16,6 +17,7 @@ def load_from_local_files(
     chunk_size: int = 1500,
     chunk_overlap: int = 100,
     batch_size: int = 256,
+    force_rebuild: bool = False,
 ):
     """
     Load knowledge from local files or directories into the vector database.
@@ -31,6 +33,7 @@ def load_from_local_files(
         chunk_size: Size of each chunk in characters.
         chunk_overlap: Number of characters to overlap between chunks.
        batch_size: Number of chunks to process at once during embedding.
+        force_rebuild: If True, clears the existing collection and ensures no duplicates are inserted.
     Raises:
         FileNotFoundError: If any of the specified paths do not exist.
@@ -47,6 +50,16 @@ def load_from_local_files(
         description=collection_description,
         force_new_collection=force_new_collection,
     )
+    # If force_rebuild is True, force the collection to be rebuilt
+    if force_rebuild:
+        vector_db.init_collection(
+            dim=embedding_model.dimension,
+            collection=collection_name,
+            description=collection_description,
+            force_new_collection=True,
+        )
     if isinstance(paths_or_directory, str):
         paths_or_directory = [paths_or_directory]
     all_docs = []
@@ -65,8 +78,17 @@ def load_from_local_files(
         chunk_overlap=chunk_overlap,
     )
-    chunks = embedding_model.embed_chunks(chunks, batch_size=batch_size)
-    vector_db.insert_data(collection=collection_name, chunks=chunks)
+    # Compute a SHA-256 hash for each chunk to serve as its primary key and to check for duplicates
+    unique_chunks = []
+    for chunk in chunks:
+        # Compute the SHA-256 hash of the chunk text
+        sha256_hash = hashlib.sha256(chunk.text.encode('utf-8')).hexdigest()
+        # Store the hash in the chunk's metadata
+        chunk.metadata['id'] = sha256_hash
+        unique_chunks.append(chunk)
+    unique_chunks = embedding_model.embed_chunks(unique_chunks, batch_size=batch_size)
+    vector_db.insert_data(collection=collection_name, chunks=unique_chunks)
 def load_from_website(
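
The new loading path derives each chunk's primary key from the SHA-256 hash of its text, so identical content always maps to the same id no matter when or how often it is ingested; that deterministic mapping is what the deduplication relies on. A minimal standalone sketch of the idea:

import hashlib

def chunk_id(text: str) -> str:
    # Deterministic key: identical chunk text always produces the same hex digest
    return hashlib.sha256(text.encode("utf-8")).hexdigest()

# Re-ingesting identical content yields an identical primary key, never a fresh one
assert chunk_id("same chunk text") == chunk_id("same chunk text")
assert chunk_id("same chunk text") != chunk_id("different chunk text")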

deepsearcher/vector_db/__init__.py (3 lines changed)

@@ -1,6 +1,5 @@
-from .azure_search import AzureSearch
 from .milvus import Milvus, RetrievalResult
 from .oracle import OracleDB
 from .qdrant import Qdrant
-__all__ = ["Milvus", "RetrievalResult", "OracleDB", "Qdrant", "AzureSearch"]
+__all__ = ["Milvus", "RetrievalResult", "OracleDB", "Qdrant"]

deepsearcher/vector_db/azure_search.py (279 lines removed; file deleted)

@@ -1,279 +0,0 @@
-import uuid
-from typing import Any, Dict, List, Optional
-from deepsearcher.vector_db.base import BaseVectorDB, CollectionInfo, RetrievalResult
-class AzureSearch(BaseVectorDB):
-    def __init__(self, endpoint, index_name, api_key, vector_field):
-        super().__init__(default_collection=index_name)
-        from azure.core.credentials import AzureKeyCredential
-        from azure.search.documents import SearchClient
-        self.client = SearchClient(
-            endpoint=endpoint,
-            index_name=index_name,
-            credential=AzureKeyCredential(api_key),
-        )
-        self.vector_field = vector_field
-        self.endpoint = endpoint
-        self.index_name = index_name
-        self.api_key = api_key
-    def init_collection(self):
-        """Initialize Azure Search index with proper schema"""
-        from azure.core.credentials import AzureKeyCredential
-        from azure.core.exceptions import ResourceNotFoundError
-        from azure.search.documents.indexes import SearchIndexClient
-        from azure.search.documents.indexes.models import (
-            SearchableField,
-            SearchField,
-            SearchIndex,
-            SimpleField,
-        )
-        index_client = SearchIndexClient(
-            endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key)
-        )
-        # Create the index (simplified for compatibility with older SDK versions)
-        fields = [
-            SimpleField(name="id", type="Edm.String", key=True),
-            SearchableField(name="content", type="Edm.String"),
-            SearchField(
-                name="content_vector",
-                type="Collection(Edm.Single)",
-                searchable=True,
-                vector_search_dimensions=1536,
-            ),
-        ]
-        # Create index with fields
-        index = SearchIndex(name=self.index_name, fields=fields)
-        try:
-            # Try to delete existing index
-            try:
-                index_client.delete_index(self.index_name)
-            except ResourceNotFoundError:
-                pass
-            # Create the index
-            index_client.create_index(index)
-        except Exception as e:
-            print(f"Error creating index: {str(e)}")
-    def insert_data(self, documents: List[dict]):
-        """Batch insert documents with vector embeddings"""
-        from azure.core.credentials import AzureKeyCredential
-        from azure.search.documents import SearchClient
-        search_client = SearchClient(
-            endpoint=self.endpoint,
-            index_name=self.index_name,
-            credential=AzureKeyCredential(self.api_key),
-        )
-        actions = [
-            {
-                "@search.action": "upload" if doc.get("id") else "merge",
-                "id": doc.get("id", str(uuid.uuid4())),
-                "content": doc["text"],
-                "content_vector": doc["vector"],
-            }
-            for doc in documents
-        ]
-        result = search_client.upload_documents(actions)
-        return [x.succeeded for x in result]
-    def search_data(
-        self, collection: Optional[str], vector: List[float], top_k: int = 50
-    ) -> List[RetrievalResult]:
-        """Azure Cognitive Search implementation with compatibility for older SDK versions"""
-        from azure.core.credentials import AzureKeyCredential
-        from azure.search.documents import SearchClient
-        search_client = SearchClient(
-            endpoint=self.endpoint,
-            index_name=collection or self.index_name,
-            credential=AzureKeyCredential(self.api_key),
-        )
-        # Validate that vector is not empty
-        if not vector or len(vector) == 0:
-            print("Error: Empty vector provided for search. Vector must have 1536 dimensions.")
-            return []
-        # Debug vector and field info
-        print(f"Vector length for search: {len(vector)}")
-        print(f"Vector field name: {self.vector_field}")
-        # Ensure vector has the right dimensions
-        if len(vector) != 1536:
-            print(f"Warning: Vector length {len(vector)} does not match expected 1536 dimensions")
-            return []
-        # Execute search with direct parameters - simpler approach
-        try:
-            print(f"Executing search with top_k={top_k}")
-            # Directly use the search_by_vector method for compatibility
-            body = {
-                "search": "*",
-                "select": "id,content",
-                "top": top_k,
-                "vectorQueries": [
-                    {
-                        "vector": vector,
-                        "fields": self.vector_field,
-                        "k": top_k,
-                        "kind": "vector",
-                    }
-                ],
-            }
-            # Print the search request body for debugging
-            print(f"Search request body: {body}")
-            # Use the REST API directly
-            result = search_client._client.documents.search_post(
-                search_request=body, headers={"api-key": self.api_key}
-            )
-            # Format results
-            search_results = []
-            if hasattr(result, "results"):
-                for doc in result.results:
-                    try:
-                        doc_dict = doc.as_dict() if hasattr(doc, "as_dict") else doc
-                        content = doc_dict.get("content", "")
-                        doc_id = doc_dict.get("id", "")
-                        score = doc_dict.get("@search.score", 0.0)
-                        result = RetrievalResult(
-                            embedding=[],  # We don't get the vectors back
-                            text=content,
-                            reference=doc_id,
-                            metadata={"source": doc_id},
-                            score=score,
-                        )
-                        search_results.append(result)
-                    except Exception as e:
-                        print(f"Error processing result: {str(e)}")
-            return search_results
-        except Exception as e:
-            print(f"Search error: {str(e)}")
-            # Try another approach if the first one fails
-            try:
-                print("Trying alternative search method...")
-                results = search_client.search(search_text="*", select=["id", "content"], top=top_k)
-                # Process results
-                alt_results = []
-                for doc in results:
-                    try:
-                        # Handle different result formats
-                        if isinstance(doc, dict):
-                            content = doc.get("content", "")
-                            doc_id = doc.get("id", "")
-                            score = doc.get("@search.score", 0.0)
-                        else:
-                            content = getattr(doc, "content", "")
-                            doc_id = getattr(doc, "id", "")
-                            score = getattr(doc, "@search.score", 0.0)
-                        result = RetrievalResult(
-                            embedding=[],
-                            text=content,
-                            reference=doc_id,
-                            metadata={"source": doc_id},
-                            score=score,
-                        )
-                        alt_results.append(result)
-                    except Exception as e:
-                        print(f"Error processing result: {str(e)}")
-                return alt_results
-            except Exception as e:
-                print(f"Alternative search failed: {str(e)}")
-                return []
-    def clear_db(self):
-        """Delete all documents in the index"""
-        from azure.core.credentials import AzureKeyCredential
-        from azure.search.documents import SearchClient
-        search_client = SearchClient(
-            endpoint=self.endpoint,
-            index_name=self.index_name,
-            credential=AzureKeyCredential(self.api_key),
-        )
-        docs = search_client.search(search_text="*", include_total_count=True, select=["id"])
-        ids = [doc["id"] for doc in docs]
-        if ids:
-            search_client.delete_documents([{"id": id} for id in ids])
-        return len(ids)
-    def get_all_collections(self) -> List[str]:
-        """List all search indices in Azure Cognitive Search"""
-        from azure.core.credentials import AzureKeyCredential
-        from azure.search.documents.indexes import SearchIndexClient
-        try:
-            index_client = SearchIndexClient(
-                endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key)
-            )
-            return [index.name for index in index_client.list_indexes()]
-        except Exception as e:
-            print(f"Failed to list indices: {str(e)}")
-            return []
-    def get_collection_info(self, name: str) -> Dict[str, Any]:
-        """Retrieve index metadata"""
-        from azure.core.credentials import AzureKeyCredential
-        from azure.search.documents.indexes import SearchIndexClient
-        index_client = SearchIndexClient(
-            endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key)
-        )
-        return index_client.get_index(name).__dict__
-    def collection_exists(self, name: str) -> bool:
-        """Check index existence"""
-        from azure.core.exceptions import ResourceNotFoundError
-        try:
-            self.get_collection_info(name)
-            return True
-        except ResourceNotFoundError:
-            return False
-    def list_collections(self, *args, **kwargs) -> List[CollectionInfo]:
-        """List all Azure Search indices with metadata"""
-        from azure.core.credentials import AzureKeyCredential
-        from azure.search.documents.indexes import SearchIndexClient
-        try:
-            index_client = SearchIndexClient(
-                endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key)
-            )
-            collections = []
-            for index in index_client.list_indexes():
-                collections.append(
-                    CollectionInfo(
-                        collection_name=index.name,
-                        description=f"Azure Search Index with {len(index.fields) if hasattr(index, 'fields') else 0} fields",
-                    )
-                )
-            return collections
-        except Exception as e:
-            print(f"Collection listing failed: {str(e)}")
-            return []

deepsearcher/vector_db/dedup_util.py (247 lines removed; file deleted)

@@ -1,247 +0,0 @@
-import hashlib
-from typing import List
-import numpy as np
-from deepsearcher.loader.splitter import Chunk
-from deepsearcher.vector_db.base import BaseVectorDB, RetrievalResult
-def calculate_text_hash(text: str) -> str:
-    """
-    Compute a hash of the text to use as the primary key.
-    Args:
-        text (str): Input text
-    Returns:
-        str: MD5 hash of the text
-    """
-    return hashlib.md5(text.encode('utf-8')).hexdigest()
-class DeduplicatedVectorDB:
-    """
-    Vector database wrapper with deduplication support.
-    On top of an existing vector database, this class adds:
-    1. Text hashes used as primary keys
-    2. Avoiding insertion of duplicate data
-    3. A method for cleaning up duplicate data
-    """
-    def __init__(self, db: BaseVectorDB):
-        """
-        Initialize the deduplicated vector database.
-        Args:
-            db (BaseVectorDB): Underlying vector database instance
-        """
-        self.db = db
-    def init_collection(self, dim: int, collection: str, description: str,
-                        force_new_collection=False, *args, **kwargs):
-        """
-        Initialize a collection.
-        Args:
-            dim (int): Vector dimension
-            collection (str): Collection name
-            description (str): Collection description
-            force_new_collection (bool): Whether to force creation of a new collection
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        return self.db.init_collection(dim, collection, description,
-                                       force_new_collection, *args, **kwargs)
-    def insert_data(self, collection: str, chunks: List[Chunk],
-                    batch_size: int = 256, *args, **kwargs):
-        """
-        Insert data while avoiding duplicates.
-        This method first checks whether a record with the same text hash already
-        exists in the database and only inserts chunks that are not present.
-        Args:
-            collection (str): Collection name
-            chunks (List[Chunk]): Chunks to insert
-            batch_size (int): Batch size
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        # Compute a hash for each chunk and store it in the metadata
-        for chunk in chunks:
-            if 'hash' not in chunk.metadata:
-                chunk.metadata['hash'] = calculate_text_hash(chunk.text)
-        # Filter out data that already exists
-        filtered_chunks = self._filter_duplicate_chunks(collection, chunks)
-        # Insert the filtered data
-        if filtered_chunks:
-            return self.db.insert_data(collection, filtered_chunks, batch_size, *args, **kwargs)
-    def _filter_duplicate_chunks(self, collection: str, chunks: List[Chunk]) -> List[Chunk]:
-        """
-        Filter out chunks that already exist in the database.
-        Args:
-            collection (str): Collection name
-            chunks (List[Chunk]): Chunks to filter
-        Returns:
-            List[Chunk]: Filtered chunks
-        """
-        # Note: this implementation depends on the concrete database backend.
-        # In production, a database-specific query should be used to check for duplicates.
-        # This is a generic but inefficient implementation.
-        return chunks
-    def search_data(self, collection: str, vector: np.array, *args, **kwargs) -> List[RetrievalResult]:
-        """
-        Search data.
-        Args:
-            collection (str): Collection name
-            vector (np.array): Query vector
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        Returns:
-            List[RetrievalResult]: Search results
-        """
-        return self.db.search_data(collection, vector, *args, **kwargs)
-    def list_collections(self, *args, **kwargs):
-        """
-        List all collections.
-        Args:
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        Returns:
-            A list of collection info objects
-        """
-        return self.db.list_collections(*args, **kwargs)
-    def clear_db(self, collection: str = None, *args, **kwargs):
-        """
-        Clear a collection in the database.
-        Args:
-            collection (str): Collection name
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        return self.db.clear_db(collection, *args, **kwargs)
-    def remove_duplicate_data(self, collection: str = None) -> int:
-        """
-        Remove duplicate data from the database.
-        Note: this is a generic implementation; specific databases may offer more efficient approaches.
-        Args:
-            collection (str): Collection name
-        Returns:
-            int: Number of duplicate records removed
-        """
-        # The concrete implementation depends on the database type;
-        # only a conceptual implementation is provided here.
-        raise NotImplementedError("Deduplication logic must be implemented for the specific database backend")
-# Dedicated deduplication helper for Qdrant
-class QdrantDeduplicatedVectorDB(DeduplicatedVectorDB):
-    """
-    Qdrant-specific deduplicated vector database implementation.
-    """
-    def _filter_duplicate_chunks(self, collection: str, chunks: List[Chunk]) -> List[Chunk]:
-        """
-        Filter out chunks that already exist in the Qdrant database.
-        Args:
-            collection (str): Collection name
-            chunks (List[Chunk]): Chunks to filter
-        Returns:
-            List[Chunk]: Filtered chunks
-        """
-        try:
-            # Get the Qdrant client
-            qdrant_client = self.db.client
-            # Collect all hashes to check
-            hashes = [calculate_text_hash(chunk.text) for chunk in chunks]
-            # Query for existing records.
-            # Note: this requires Qdrant to support filtering by payload; the implementation may need adjustment.
-            filtered_chunks = []
-            for chunk, hash_value in zip(chunks, hashes):
-                # A Qdrant query should be used here to check whether this hash already exists.
-                # Due to Qdrant API limitations, this may require the scroll or search endpoints.
-                filtered_chunks.append(chunk)
-            return filtered_chunks
-        except Exception:
-            # On error, return all chunks (no filtering)
-            return chunks
-    def remove_duplicate_data(self, collection: str = None) -> int:
-        """
-        Remove duplicate data from the Qdrant database.
-        Args:
-            collection (str): Collection name
-        Returns:
-            int: Number of duplicate records removed
-        """
-        # Qdrant deduplication would require retrieving all data,
-        # identifying duplicates and deleting them.
-        # Concrete implementation omitted.
-        return 0
-# Dedicated deduplication helper for Milvus
-class MilvusDeduplicatedVectorDB(DeduplicatedVectorDB):
-    """
-    Milvus-specific deduplicated vector database implementation.
-    """
-    def insert_data(self, collection: str, chunks: List[Chunk],
-                    batch_size: int = 256, *args, **kwargs):
-        """
-        Insert data into Milvus, using the text hash as the primary key.
-        Args:
-            collection (str): Collection name
-            chunks (List[Chunk]): Chunks to insert
-            batch_size (int): Batch size
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        # Set each chunk's ID to the hash of its text
-        for chunk in chunks:
-            chunk.metadata['id'] = calculate_text_hash(chunk.text)
-        # Delegate the insertion to the parent class
-        return super().insert_data(collection, chunks, batch_size, *args, **kwargs)
-    def remove_duplicate_data(self, collection: str = None) -> int:
-        """
-        Remove duplicate data from the Milvus database.
-        Args:
-            collection (str): Collection name
-        Returns:
-            int: Number of duplicate records removed
-        """
-        # Milvus deduplication could query all data, find records with duplicate IDs and delete them.
-        return 0
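
The removed helper keyed chunks by MD5 inside a wrapper class, whereas the new loading path computes SHA-256 keys directly in offline_loading.py. One practical difference is digest length, which is why the Milvus id field below is declared as VARCHAR with max_length=64. A small sketch of the comparison:

import hashlib

text = "example chunk"
md5_key = hashlib.md5(text.encode("utf-8")).hexdigest()        # 32 hex chars (old dedup_util keys)
sha256_key = hashlib.sha256(text.encode("utf-8")).hexdigest()  # 64 hex chars (new chunk ids)
print(len(md5_key), len(sha256_key))  # 32 64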

deepsearcher/vector_db/milvus.py (10 lines changed)

@@ -85,9 +85,9 @@ class Milvus(BaseVectorDB):
         elif has_collection:
             return
         schema = self.client.create_schema(
-            enable_dynamic_field=False, auto_id=True, description=description
+            enable_dynamic_field=False, auto_id=False, description=description
         )
-        schema.add_field("id", DataType.INT64, is_primary=True)
+        schema.add_field("id", DataType.VARCHAR, is_primary=True, max_length=64)
         schema.add_field("embedding", DataType.FLOAT_VECTOR, dim=dim)
         if self.hybrid:
@@ -156,6 +156,7 @@ class Milvus(BaseVectorDB):
         """
         if not collection:
             collection = self.default_collection
+        ids = [chunk.metadata.get('id', '') for chunk in chunks]
         texts = [chunk.text for chunk in chunks]
         references = [chunk.reference for chunk in chunks]
         metadatas = [chunk.metadata for chunk in chunks]
@@ -163,13 +164,14 @@ class Milvus(BaseVectorDB):
         datas = [
             {
+                "id": id,
                 "embedding": embedding,
                 "text": text,
                 "reference": reference,
                 "metadata": metadata,
             }
-            for embedding, text, reference, metadata in zip(
-                embeddings, texts, references, metadatas
+            for id, embedding, text, reference, metadata in zip(
+                ids, embeddings, texts, references, metadatas
             )
         ]
         batch_datas = [datas[i : i + batch_size] for i in range(0, len(datas), batch_size)]
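
The schema change replaces the auto-generated INT64 primary key with a caller-supplied VARCHAR key sized for a SHA-256 hex digest. Below is a standalone sketch of that schema shape using the pymilvus MilvusClient API; the connection URI, collection name, extra fields and index settings are illustrative assumptions, not taken from the repo.

from pymilvus import DataType, MilvusClient

client = MilvusClient(uri="http://localhost:19530")

# Caller-supplied string primary key: auto_id=False, VARCHAR sized for a SHA-256 hex digest
schema = client.create_schema(enable_dynamic_field=False, auto_id=False, description="demo collection")
schema.add_field("id", DataType.VARCHAR, is_primary=True, max_length=64)
schema.add_field("embedding", DataType.FLOAT_VECTOR, dim=768)
schema.add_field("text", DataType.VARCHAR, max_length=65535)

index_params = client.prepare_index_params()
index_params.add_index(field_name="embedding", index_type="AUTOINDEX", metric_type="IP")

client.create_collection(collection_name="demo_collection", schema=schema, index_params=index_params)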

test.py (5 lines changed)

@@ -1,6 +1,5 @@
 from deepsearcher.configuration import Configuration, init_config
-# from deepsearcher.offline_loading import load_from_local_files
+from deepsearcher.offline_loading import load_from_local_files
 from deepsearcher.online_query import query
 config = Configuration()
@@ -12,7 +11,7 @@ config.load_config_from_yaml("deepsearcher/config.yaml")
 init_config(config = config)
 # Load your local data
-# load_from_local_files(paths_or_directory="examples/data")
+load_from_local_files(paths_or_directory="examples/data", force_rebuild=True)
 # (Optional) Load from web crawling (`FIRECRAWL_API_KEY` env variable required)
 # from deepsearcher.offline_loading import load_from_website
