From b53d77799bb777f1f2a1fb2b1f197038e15d63c5 Mon Sep 17 00:00:00 2001 From: tanxing Date: Wed, 13 Aug 2025 12:35:01 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dmax=5Fiter=E8=BF=AD?= =?UTF-8?q?=E4=BB=A3=E5=A4=B1=E6=95=88=E7=9A=84=E9=97=AE=E9=A2=98=20fix:?= =?UTF-8?q?=20=E4=BF=AE=E5=A4=8Ddefault=5Fcollection=E5=8F=82=E6=95=B0?= =?UTF-8?q?=E7=BC=BA=E5=A4=B1=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- deepsearcher/agent/deep_search.py | 7 +++++-- deepsearcher/config.yaml | 2 +- deepsearcher/offline_loading.py | 15 +++++++-------- deepsearcher/vector_db/milvus.py | 3 +++ main.py | 14 ++++++++++++++ 5 files changed, 30 insertions(+), 11 deletions(-) diff --git a/deepsearcher/agent/deep_search.py b/deepsearcher/agent/deep_search.py index bdc318d..02565da 100644 --- a/deepsearcher/agent/deep_search.py +++ b/deepsearcher/agent/deep_search.py @@ -323,6 +323,9 @@ class DeepSearch(BaseAgent): - A list of retrieved document results - Additional information about the retrieval process """ + # Get max_iter from kwargs or use default + max_iter = kwargs.get('max_iter', self.max_iter) + ### SUB QUERIES ### send_think(f" {original_query} ") all_search_results = [] @@ -336,7 +339,7 @@ class DeepSearch(BaseAgent): send_think(f"Break down the original query into new sub queries: {sub_queries}") all_sub_queries.extend(sub_queries) - for it in range(self.max_iter): + for it in range(max_iter): send_think(f">> Iteration: {it + 1}") # Execute all search tasks sequentially @@ -353,7 +356,7 @@ class DeepSearch(BaseAgent): ### REFLECTION & GET MORE SUB QUERIES ### # Only generate more queries if we haven't reached the maximum iterations - if it + 1 < self.max_iter: + if it + 1 < max_iter: send_think("Reflecting on the search results...") sub_queries = self._generate_more_sub_queries( original_query, all_sub_queries, all_search_results diff --git a/deepsearcher/config.yaml b/deepsearcher/config.yaml index 29e1b8e..fef4343 100644 --- a/deepsearcher/config.yaml +++ b/deepsearcher/config.yaml @@ -80,7 +80,7 @@ provide_settings: # port: 6333 query_settings: - max_iter: 1 + max_iter: 3 load_settings: chunk_size: 1024 diff --git a/deepsearcher/offline_loading.py b/deepsearcher/offline_loading.py index 57c1506..df753f5 100644 --- a/deepsearcher/offline_loading.py +++ b/deepsearcher/offline_loading.py @@ -43,14 +43,13 @@ def load_from_local_files( embedding_model = configuration.embedding_model file_loader = configuration.file_loader - # 如果force_rebuild为True,则强制重建集合 - if force_rebuild: - vector_db.init_collection( - dim=embedding_model.dimension, - collection=collection_name, - description=collection_description, - force_rebuild=True, - ) + # 初始化集合(如果不存在则创建,如果force_rebuild为True则重建) + vector_db.init_collection( + dim=embedding_model.dimension, + collection=collection_name, + description=collection_description, + force_rebuild=force_rebuild, + ) if isinstance(paths_or_directory, str): paths_or_directory = [paths_or_directory] diff --git a/deepsearcher/vector_db/milvus.py b/deepsearcher/vector_db/milvus.py index 98acb05..720bef9 100644 --- a/deepsearcher/vector_db/milvus.py +++ b/deepsearcher/vector_db/milvus.py @@ -18,6 +18,7 @@ class Milvus(BaseVectorDB): user: str = "", password: str = "", db: str = "default", + default_collection: str = "deepsearcher", **kwargs, ): """ @@ -29,9 +30,11 @@ class Milvus(BaseVectorDB): user (str, optional): Username for authentication. Defaults to "". password (str, optional): Password for authentication. Defaults to "". db (str, optional): Database name. Defaults to "default". + default_collection (str, optional): Default collection name. Defaults to "deepsearcher". **kwargs: Additional keyword arguments to pass to the MilvusClient. """ super().__init__() + self.default_collection = default_collection self.client = MilvusClient( uri=uri, user=user, password=password, token=token, db_name=db, timeout=30, **kwargs ) diff --git a/main.py b/main.py index d74b823..999df08 100644 --- a/main.py +++ b/main.py @@ -98,6 +98,11 @@ def load_files( description="Optional batch size for the collection.", examples=[256], ), + force_rebuild: bool = Body( + False, + description="Whether to force rebuild the collection if it already exists.", + examples=[False], + ), ): """ Load files into the vector database. @@ -107,6 +112,7 @@ def load_files( collection_name (str, optional): Name for the collection. Defaults to None. collection_description (str, optional): Description for the collection. Defaults to None. batch_size (int, optional): Batch size for processing. Defaults to None. + force_rebuild (bool, optional): Whether to force rebuild the collection. Defaults to False. Returns: dict: A dictionary containing a success message. @@ -120,6 +126,7 @@ def load_files( collection_name=collection_name, collection_description=collection_description, batch_size=batch_size if batch_size is not None else 8, + force_rebuild=force_rebuild, ) return {"message": "Files loaded successfully."} except Exception as e: @@ -148,6 +155,11 @@ def load_website( description="Optional batch size for the collection.", examples=[256], ), + force_rebuild: bool = Body( + False, + description="Whether to force rebuild the collection if it already exists.", + examples=[False], + ), ): """ Load website content into the vector database. @@ -157,6 +169,7 @@ def load_website( collection_name (str, optional): Name for the collection. Defaults to None. collection_description (str, optional): Description for the collection. Defaults to None. batch_size (int, optional): Batch size for processing. Defaults to None. + force_rebuild (bool, optional): Whether to force rebuild the collection. Defaults to False. Returns: dict: A dictionary containing a success message. @@ -170,6 +183,7 @@ def load_website( collection_name=collection_name, collection_description=collection_description, batch_size=batch_size if batch_size is not None else 8, + force_rebuild=force_rebuild, ) return {"message": "Website loaded successfully."} except Exception as e: