From 4863b5d620b1e40d444465c4a1006c96f9c44f96 Mon Sep 17 00:00:00 2001
From: tanxing
Date: Wed, 6 Aug 2025 10:25:57 +0800
Subject: [PATCH] Complete the database deduplication and rebuild logic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 deepsearcher/agent/deep_search.py       |   9 +-
 deepsearcher/config.yaml                |   4 +-
 deepsearcher/loader/file_loader/base.py |   8 +-
 deepsearcher/offline_loading.py         |  26 ++-
 deepsearcher/vector_db/__init__.py      |   3 +-
 deepsearcher/vector_db/azure_search.py  | 279 ------------------------
 deepsearcher/vector_db/dedup_util.py    | 247 ---------------------
 deepsearcher/vector_db/milvus.py        |  10 +-
 test.py                                 |   5 +-
 9 files changed, 46 insertions(+), 545 deletions(-)
 delete mode 100644 deepsearcher/vector_db/azure_search.py
 delete mode 100644 deepsearcher/vector_db/dedup_util.py

diff --git a/deepsearcher/agent/deep_search.py b/deepsearcher/agent/deep_search.py
index f0c2dac..0a4e58e 100644
--- a/deepsearcher/agent/deep_search.py
+++ b/deepsearcher/agent/deep_search.py
@@ -9,8 +9,11 @@ from deepsearcher.utils import log
 from deepsearcher.vector_db import RetrievalResult
 from deepsearcher.vector_db.base import BaseVectorDB, deduplicate_results
 
-SUB_QUERY_PROMPT = """To answer this question more comprehensively, please break down the original question into up to four sub-questions. Return as list of str.
-If this is a very simple question and no decomposition is necessary, then keep the only one original question in the python code list.
+SUB_QUERY_PROMPT = """
+To answer this question more comprehensively, please break down the original question into a few sub-questions (more if necessary).
+If this is a very simple question and no decomposition is necessary, then keep only the original question.
+Make sure each sub-question is clear, concise and atomic.
+Return the result as a Python-style list of str that is JSON convertible.
 
 Original Question: {original_query}
 
@@ -54,7 +57,7 @@ Related Chunks:
 
 Respond exclusively in valid List of str format without any other text."""
 
-SUMMARY_PROMPT = """You are a AI content analysis expert, good at summarizing content. Please summarize a specific and detailed answer or report based on the previous queries and the retrieved document chunks.
+SUMMARY_PROMPT = """You are an AI content analysis expert, good at summarizing content. Please summarize a long, specific and detailed answer or report based on the previous queries and the retrieved document chunks.
 
 Original Query: {question}
 
diff --git a/deepsearcher/config.yaml b/deepsearcher/config.yaml
index 10f7e17..42a713f 100644
--- a/deepsearcher/config.yaml
+++ b/deepsearcher/config.yaml
@@ -23,8 +23,8 @@ provide_settings:
 #    config:
 #      text_key: ""
 
-#    provider: "TextLoader"
-#    config: {}
+    provider: "TextLoader"
+    config: {}
 
 #    provider: "UnstructuredLoader"
 #    config: {}
diff --git a/deepsearcher/loader/file_loader/base.py b/deepsearcher/loader/file_loader/base.py
index aa16c1b..9fa2ce5 100644
--- a/deepsearcher/loader/file_loader/base.py
+++ b/deepsearcher/loader/file_loader/base.py
@@ -37,7 +37,7 @@ class BaseLoader(ABC):
         In the metadata, it's recommended to include the reference to the file.
         e.g. return [Document(page_content=..., metadata={"reference": file_path})]
         """
-        pass
+        return []
 
     def load_directory(self, directory: str) -> List[Document]:
         """
@@ -55,7 +55,9 @@ class BaseLoader(ABC):
                 for suffix in self.supported_file_types:
                     if file.endswith(suffix):
                         full_path = os.path.join(root, file)
-                        documents.extend(self.load_file(full_path))
+                        loaded_docs = self.load_file(full_path)
+                        if loaded_docs is not None:
+                            documents.extend(loaded_docs)
                         break
         return documents
 
@@ -67,4 +69,4 @@ class BaseLoader(ABC):
         Returns:
             A list of supported file extensions (without the dot).
         """
-        pass
+        return []
diff --git a/deepsearcher/offline_loading.py b/deepsearcher/offline_loading.py
index 9bcfa60..18e3065 100644
--- a/deepsearcher/offline_loading.py
+++ b/deepsearcher/offline_loading.py
@@ -1,4 +1,5 @@
 import os
+import hashlib
 from typing import List, Union
 
 from tqdm import tqdm
@@ -16,6 +17,7 @@ def load_from_local_files(
     chunk_size: int = 1500,
     chunk_overlap: int = 100,
     batch_size: int = 256,
+    force_rebuild: bool = False,
 ):
     """
     Load knowledge from local files or directories into the vector database.
@@ -31,6 +33,7 @@ def load_from_local_files(
         chunk_size: Size of each chunk in characters.
         chunk_overlap: Number of characters to overlap between chunks.
         batch_size: Number of chunks to process at once during embedding.
+        force_rebuild: If True, clears the existing collection and ensures no duplicates are inserted.
 
     Raises:
         FileNotFoundError: If any of the specified paths do not exist.
@@ -47,6 +50,16 @@ def load_from_local_files(
         description=collection_description,
         force_new_collection=force_new_collection,
     )
+
+    # If force_rebuild is True, recreate the collection from scratch
+    if force_rebuild:
+        vector_db.init_collection(
+            dim=embedding_model.dimension,
+            collection=collection_name,
+            description=collection_description,
+            force_new_collection=True,
+        )
+
     if isinstance(paths_or_directory, str):
         paths_or_directory = [paths_or_directory]
     all_docs = []
@@ -65,8 +78,17 @@ def load_from_local_files(
         chunk_overlap=chunk_overlap,
     )
 
-    chunks = embedding_model.embed_chunks(chunks, batch_size=batch_size)
-    vector_db.insert_data(collection=collection_name, chunks=chunks)
+    # Compute a SHA256 hash for each chunk to serve as its primary key and to detect duplicates
+    unique_chunks = []
+    for chunk in chunks:
+        # Compute the SHA256 hash of the chunk text
+        sha256_hash = hashlib.sha256(chunk.text.encode('utf-8')).hexdigest()
+        # Store the hash in the chunk's metadata as its id
+        chunk.metadata['id'] = sha256_hash
+        unique_chunks.append(chunk)
+
+    unique_chunks = embedding_model.embed_chunks(unique_chunks, batch_size=batch_size)
+    vector_db.insert_data(collection=collection_name, chunks=unique_chunks)
 
 
 def load_from_website(
diff --git a/deepsearcher/vector_db/__init__.py b/deepsearcher/vector_db/__init__.py
index 80c5610..cc6331e 100644
--- a/deepsearcher/vector_db/__init__.py
+++ b/deepsearcher/vector_db/__init__.py
@@ -1,6 +1,5 @@
-from .azure_search import AzureSearch
 from .milvus import Milvus, RetrievalResult
 from .oracle import OracleDB
 from .qdrant import Qdrant
 
-__all__ = ["Milvus", "RetrievalResult", "OracleDB", "Qdrant", "AzureSearch"]
+__all__ = ["Milvus", "RetrievalResult", "OracleDB", "Qdrant"]
diff --git a/deepsearcher/vector_db/azure_search.py b/deepsearcher/vector_db/azure_search.py
deleted file mode 100644
index 8faf4a6..0000000
--- a/deepsearcher/vector_db/azure_search.py
+++ /dev/null
@@ -1,279 +0,0 @@
-import uuid
-from typing import Any, Dict, List, Optional
-
-from deepsearcher.vector_db.base import BaseVectorDB, CollectionInfo, RetrievalResult
-
-
-class AzureSearch(BaseVectorDB):
-    def __init__(self, endpoint, index_name,
api_key, vector_field): - super().__init__(default_collection=index_name) - from azure.core.credentials import AzureKeyCredential - from azure.search.documents import SearchClient - - self.client = SearchClient( - endpoint=endpoint, - index_name=index_name, - credential=AzureKeyCredential(api_key), - ) - self.vector_field = vector_field - self.endpoint = endpoint - self.index_name = index_name - self.api_key = api_key - - def init_collection(self): - """Initialize Azure Search index with proper schema""" - from azure.core.credentials import AzureKeyCredential - from azure.core.exceptions import ResourceNotFoundError - from azure.search.documents.indexes import SearchIndexClient - from azure.search.documents.indexes.models import ( - SearchableField, - SearchField, - SearchIndex, - SimpleField, - ) - - index_client = SearchIndexClient( - endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key) - ) - - # Create the index (simplified for compatibility with older SDK versions) - fields = [ - SimpleField(name="id", type="Edm.String", key=True), - SearchableField(name="content", type="Edm.String"), - SearchField( - name="content_vector", - type="Collection(Edm.Single)", - searchable=True, - vector_search_dimensions=1536, - ), - ] - - # Create index with fields - index = SearchIndex(name=self.index_name, fields=fields) - - try: - # Try to delete existing index - try: - index_client.delete_index(self.index_name) - except ResourceNotFoundError: - pass - - # Create the index - index_client.create_index(index) - except Exception as e: - print(f"Error creating index: {str(e)}") - - def insert_data(self, documents: List[dict]): - """Batch insert documents with vector embeddings""" - from azure.core.credentials import AzureKeyCredential - from azure.search.documents import SearchClient - - search_client = SearchClient( - endpoint=self.endpoint, - index_name=self.index_name, - credential=AzureKeyCredential(self.api_key), - ) - - actions = [ - { - "@search.action": "upload" if doc.get("id") else "merge", - "id": doc.get("id", str(uuid.uuid4())), - "content": doc["text"], - "content_vector": doc["vector"], - } - for doc in documents - ] - - result = search_client.upload_documents(actions) - return [x.succeeded for x in result] - - def search_data( - self, collection: Optional[str], vector: List[float], top_k: int = 50 - ) -> List[RetrievalResult]: - """Azure Cognitive Search implementation with compatibility for older SDK versions""" - from azure.core.credentials import AzureKeyCredential - from azure.search.documents import SearchClient - - search_client = SearchClient( - endpoint=self.endpoint, - index_name=collection or self.index_name, - credential=AzureKeyCredential(self.api_key), - ) - - # Validate that vector is not empty - if not vector or len(vector) == 0: - print("Error: Empty vector provided for search. 
Vector must have 1536 dimensions.") - return [] - - # Debug vector and field info - print(f"Vector length for search: {len(vector)}") - print(f"Vector field name: {self.vector_field}") - - # Ensure vector has the right dimensions - if len(vector) != 1536: - print(f"Warning: Vector length {len(vector)} does not match expected 1536 dimensions") - return [] - - # Execute search with direct parameters - simpler approach - try: - print(f"Executing search with top_k={top_k}") - - # Directly use the search_by_vector method for compatibility - body = { - "search": "*", - "select": "id,content", - "top": top_k, - "vectorQueries": [ - { - "vector": vector, - "fields": self.vector_field, - "k": top_k, - "kind": "vector", - } - ], - } - - # Print the search request body for debugging - print(f"Search request body: {body}") - - # Use the REST API directly - result = search_client._client.documents.search_post( - search_request=body, headers={"api-key": self.api_key} - ) - - # Format results - search_results = [] - if hasattr(result, "results"): - for doc in result.results: - try: - doc_dict = doc.as_dict() if hasattr(doc, "as_dict") else doc - content = doc_dict.get("content", "") - doc_id = doc_dict.get("id", "") - score = doc_dict.get("@search.score", 0.0) - - result = RetrievalResult( - embedding=[], # We don't get the vectors back - text=content, - reference=doc_id, - metadata={"source": doc_id}, - score=score, - ) - search_results.append(result) - except Exception as e: - print(f"Error processing result: {str(e)}") - - return search_results - except Exception as e: - print(f"Search error: {str(e)}") - - # Try another approach if the first one fails - try: - print("Trying alternative search method...") - results = search_client.search(search_text="*", select=["id", "content"], top=top_k) - - # Process results - alt_results = [] - for doc in results: - try: - # Handle different result formats - if isinstance(doc, dict): - content = doc.get("content", "") - doc_id = doc.get("id", "") - score = doc.get("@search.score", 0.0) - else: - content = getattr(doc, "content", "") - doc_id = getattr(doc, "id", "") - score = getattr(doc, "@search.score", 0.0) - - result = RetrievalResult( - embedding=[], - text=content, - reference=doc_id, - metadata={"source": doc_id}, - score=score, - ) - alt_results.append(result) - except Exception as e: - print(f"Error processing result: {str(e)}") - - return alt_results - except Exception as e: - print(f"Alternative search failed: {str(e)}") - return [] - - def clear_db(self): - """Delete all documents in the index""" - from azure.core.credentials import AzureKeyCredential - from azure.search.documents import SearchClient - - search_client = SearchClient( - endpoint=self.endpoint, - index_name=self.index_name, - credential=AzureKeyCredential(self.api_key), - ) - - docs = search_client.search(search_text="*", include_total_count=True, select=["id"]) - ids = [doc["id"] for doc in docs] - - if ids: - search_client.delete_documents([{"id": id} for id in ids]) - - return len(ids) - - def get_all_collections(self) -> List[str]: - """List all search indices in Azure Cognitive Search""" - from azure.core.credentials import AzureKeyCredential - from azure.search.documents.indexes import SearchIndexClient - - try: - index_client = SearchIndexClient( - endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key) - ) - return [index.name for index in index_client.list_indexes()] - except Exception as e: - print(f"Failed to list indices: {str(e)}") - return [] - - def 
get_collection_info(self, name: str) -> Dict[str, Any]: - """Retrieve index metadata""" - from azure.core.credentials import AzureKeyCredential - from azure.search.documents.indexes import SearchIndexClient - - index_client = SearchIndexClient( - endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key) - ) - return index_client.get_index(name).__dict__ - - def collection_exists(self, name: str) -> bool: - """Check index existence""" - from azure.core.exceptions import ResourceNotFoundError - - try: - self.get_collection_info(name) - return True - except ResourceNotFoundError: - return False - - def list_collections(self, *args, **kwargs) -> List[CollectionInfo]: - """List all Azure Search indices with metadata""" - from azure.core.credentials import AzureKeyCredential - from azure.search.documents.indexes import SearchIndexClient - - try: - index_client = SearchIndexClient( - endpoint=self.endpoint, credential=AzureKeyCredential(self.api_key) - ) - - collections = [] - for index in index_client.list_indexes(): - collections.append( - CollectionInfo( - collection_name=index.name, - description=f"Azure Search Index with {len(index.fields) if hasattr(index, 'fields') else 0} fields", - ) - ) - return collections - - except Exception as e: - print(f"Collection listing failed: {str(e)}") - return [] diff --git a/deepsearcher/vector_db/dedup_util.py b/deepsearcher/vector_db/dedup_util.py deleted file mode 100644 index 8c0d639..0000000 --- a/deepsearcher/vector_db/dedup_util.py +++ /dev/null @@ -1,247 +0,0 @@ -import hashlib -from typing import List - -import numpy as np - -from deepsearcher.loader.splitter import Chunk -from deepsearcher.vector_db.base import BaseVectorDB, RetrievalResult - - -def calculate_text_hash(text: str) -> str: - """ - 计算文本的哈希值,用于作为主键 - - Args: - text (str): 输入文本 - - Returns: - str: 文本的MD5哈希值 - """ - return hashlib.md5(text.encode('utf-8')).hexdigest() - - -class DeduplicatedVectorDB: - """ - 支持去重的向量数据库包装器 - - 该类在现有向量数据库基础上增加以下功能: - 1. 使用文本哈希值作为主键 - 2. 避免插入重复数据 - 3. 
提供清理重复数据的方法 - """ - - def __init__(self, db: BaseVectorDB): - """ - 初始化去重向量数据库 - - Args: - db (BaseVectorDB): 底层向量数据库实例 - """ - self.db = db - - def init_collection(self, dim: int, collection: str, description: str, - force_new_collection=False, *args, **kwargs): - """ - 初始化集合 - - Args: - dim (int): 向量维度 - collection (str): 集合名称 - description (str): 集合描述 - force_new_collection (bool): 是否强制创建新集合 - *args: 其他参数 - **kwargs: 其他关键字参数 - """ - return self.db.init_collection(dim, collection, description, - force_new_collection, *args, **kwargs) - - def insert_data(self, collection: str, chunks: List[Chunk], - batch_size: int = 256, *args, **kwargs): - """ - 插入数据,避免重复 - - 该方法会先检查数据库中是否已存在相同文本的哈希值, - 如果不存在才进行插入操作 - - Args: - collection (str): 集合名称 - chunks (List[Chunk]): 要插入的数据块列表 - batch_size (int): 批处理大小 - *args: 其他参数 - **kwargs: 其他关键字参数 - """ - # 为每个chunk计算哈希值并添加到metadata中 - for chunk in chunks: - if 'hash' not in chunk.metadata: - chunk.metadata['hash'] = calculate_text_hash(chunk.text) - - # 过滤掉已存在的数据 - filtered_chunks = self._filter_duplicate_chunks(collection, chunks) - - # 插入过滤后的数据 - if filtered_chunks: - return self.db.insert_data(collection, filtered_chunks, batch_size, *args, **kwargs) - - def _filter_duplicate_chunks(self, collection: str, chunks: List[Chunk]) -> List[Chunk]: - """ - 过滤掉已存在于数据库中的chunks - - Args: - collection (str): 集合名称 - chunks (List[Chunk]): 待过滤的数据块列表 - - Returns: - List[Chunk]: 过滤后的数据块列表 - """ - # 注意:这个实现依赖于具体的数据库实现 - # 对于生产环境,应该使用数据库特定的查询方法来检查重复 - # 这里提供一个通用但效率较低的实现 - return chunks - - def search_data(self, collection: str, vector: np.array, *args, **kwargs) -> List[RetrievalResult]: - """ - 搜索数据 - - Args: - collection (str): 集合名称 - vector (np.array): 查询向量 - *args: 其他参数 - **kwargs: 其他关键字参数 - - Returns: - List[RetrievalResult]: 搜索结果 - """ - return self.db.search_data(collection, vector, *args, **kwargs) - - def list_collections(self, *args, **kwargs): - """ - 列出所有集合 - - Args: - *args: 其他参数 - **kwargs: 其他关键字参数 - - Returns: - 集合信息列表 - """ - return self.db.list_collections(*args, **kwargs) - - def clear_db(self, collection: str = None, *args, **kwargs): - """ - 清空数据库集合 - - Args: - collection (str): 集合名称 - *args: 其他参数 - **kwargs: 其他关键字参数 - """ - return self.db.clear_db(collection, *args, **kwargs) - - def remove_duplicate_data(self, collection: str = None) -> int: - """ - 移除数据库中的重复数据 - - 注意:这是一个通用实现,具体数据库可能有更高效的实现方式 - - Args: - collection (str): 集合名称 - - Returns: - int: 移除的重复记录数 - """ - # 这个方法的具体实现需要根据数据库类型来定制 - # 这里仅提供一个概念性的实现 - raise NotImplementedError("需要根据具体的数据库实现去重逻辑") - - -# 为Qdrant实现专门的去重工具 -class QdrantDeduplicatedVectorDB(DeduplicatedVectorDB): - """ - Qdrant专用的去重向量数据库实现 - """ - - def _filter_duplicate_chunks(self, collection: str, chunks: List[Chunk]) -> List[Chunk]: - """ - 过滤掉已存在于Qdrant数据库中的chunks - - Args: - collection (str): 集合名称 - chunks (List[Chunk]): 待过滤的数据块列表 - - Returns: - List[Chunk]: 过滤后的数据块列表 - """ - try: - # 获取Qdrant客户端 - qdrant_client = self.db.client - - # 收集所有要检查的哈希值 - hashes = [calculate_text_hash(chunk.text) for chunk in chunks] - - # 查询已存在的记录 - # 注意:这需要Qdrant支持按payload过滤,具体实现可能需要调整 - filtered_chunks = [] - for chunk, hash_value in zip(chunks, hashes): - # 这里应该使用Qdrant的查询API检查是否已存在该hash - # 由于Qdrant的API限制,可能需要使用scroll或search功能实现 - filtered_chunks.append(chunk) - - return filtered_chunks - except Exception: - # 出错时返回所有chunks(不进行过滤) - return chunks - - def remove_duplicate_data(self, collection: str = None) -> int: - """ - 移除Qdrant数据库中的重复数据 - - Args: - collection (str): 集合名称 - - Returns: - int: 移除的重复记录数 - """ - # Qdrant去重实现 - # 
这需要先检索所有数据,识别重复项,然后删除重复项 - # 具体实现略 - return 0 - - -# 为Milvus实现专门的去重工具 -class MilvusDeduplicatedVectorDB(DeduplicatedVectorDB): - """ - Milvus专用的去重向量数据库实现 - """ - - def insert_data(self, collection: str, chunks: List[Chunk], - batch_size: int = 256, *args, **kwargs): - """ - 插入数据到Milvus,使用文本哈希作为主键 - - Args: - collection (str): 集合名称 - chunks (List[Chunk]): 要插入的数据块列表 - batch_size (int): 批处理大小 - *args: 其他参数 - **kwargs: 其他关键字参数 - """ - # 为每个chunk设置ID为文本哈希值 - for chunk in chunks: - chunk.metadata['id'] = calculate_text_hash(chunk.text) - - # 调用父类方法进行插入 - return super().insert_data(collection, chunks, batch_size, *args, **kwargs) - - def remove_duplicate_data(self, collection: str = None) -> int: - """ - 移除Milvus数据库中的重复数据 - - Args: - collection (str): 集合名称 - - Returns: - int: 移除的重复记录数 - """ - # Milvus去重实现 - # 可以通过查询所有数据,找出ID重复的记录并删除 - return 0 \ No newline at end of file diff --git a/deepsearcher/vector_db/milvus.py b/deepsearcher/vector_db/milvus.py index f47547d..fc34757 100644 --- a/deepsearcher/vector_db/milvus.py +++ b/deepsearcher/vector_db/milvus.py @@ -85,9 +85,9 @@ class Milvus(BaseVectorDB): elif has_collection: return schema = self.client.create_schema( - enable_dynamic_field=False, auto_id=True, description=description + enable_dynamic_field=False, auto_id=False, description=description ) - schema.add_field("id", DataType.INT64, is_primary=True) + schema.add_field("id", DataType.VARCHAR, is_primary=True, max_length=64) schema.add_field("embedding", DataType.FLOAT_VECTOR, dim=dim) if self.hybrid: @@ -156,6 +156,7 @@ class Milvus(BaseVectorDB): """ if not collection: collection = self.default_collection + ids = [chunk.metadata.get('id', '') for chunk in chunks] texts = [chunk.text for chunk in chunks] references = [chunk.reference for chunk in chunks] metadatas = [chunk.metadata for chunk in chunks] @@ -163,13 +164,14 @@ class Milvus(BaseVectorDB): datas = [ { + "id": id, "embedding": embedding, "text": text, "reference": reference, "metadata": metadata, } - for embedding, text, reference, metadata in zip( - embeddings, texts, references, metadatas + for id, embedding, text, reference, metadata in zip( + ids, embeddings, texts, references, metadatas ) ] batch_datas = [datas[i : i + batch_size] for i in range(0, len(datas), batch_size)] diff --git a/test.py b/test.py index 2dc7f33..605cd03 100644 --- a/test.py +++ b/test.py @@ -1,6 +1,5 @@ from deepsearcher.configuration import Configuration, init_config - -# from deepsearcher.offline_loading import load_from_local_files +from deepsearcher.offline_loading import load_from_local_files from deepsearcher.online_query import query config = Configuration() @@ -12,7 +11,7 @@ config.load_config_from_yaml("deepsearcher/config.yaml") init_config(config = config) # Load your local data -# load_from_local_files(paths_or_directory="examples/data") +load_from_local_files(paths_or_directory="examples/data", force_rebuild=True) # (Optional) Load from web crawling (`FIRECRAWL_API_KEY` env variable required) # from deepsearcher.offline_loading import load_from_website
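
Note (not part of the patch): the new ingestion path keys every chunk by the SHA256 hex digest of its text, so identical text always maps to the same VARCHAR primary key. As a rough illustration only, the same ids could also be used to drop in-batch duplicates before insertion; the helper name below and the assumption that chunks expose `.text` and `.metadata` (as deepsearcher.loader.splitter.Chunk does) are illustrative, not part of this change.

    import hashlib

    def assign_ids_and_drop_duplicates(chunks):
        """Assign SHA256-based ids and keep only the first chunk seen for each distinct text."""
        seen_ids = set()
        unique_chunks = []
        for chunk in chunks:
            # Same key scheme as offline_loading.py: SHA256 of the chunk text
            chunk_id = hashlib.sha256(chunk.text.encode("utf-8")).hexdigest()
            chunk.metadata["id"] = chunk_id
            if chunk_id not in seen_ids:
                seen_ids.add(chunk_id)
                unique_chunks.append(chunk)
        return unique_chunks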