diff --git a/deepsearcher/agent/deep_search.py b/deepsearcher/agent/deep_search.py
index ce55f29..ab620a2 100644
--- a/deepsearcher/agent/deep_search.py
+++ b/deepsearcher/agent/deep_search.py
@@ -5,7 +5,7 @@ from deepsearcher.embedding.base import BaseEmbedding
 from deepsearcher.llm.base import BaseLLM
 from deepsearcher.utils import log
 from deepsearcher.vector_db import RetrievalResult
-from deepsearcher.vector_db.base import BaseVectorDB, deduplicate_results
+from deepsearcher.vector_db.base import BaseVectorDB, deduplicate
 
 COLLECTION_ROUTE_PROMPT = """
 I provide you with collection_name(s) and corresponding collection_description(s).
@@ -72,8 +72,8 @@ Respond exclusively in valid List of str format without any other text."""
 
 SUMMARY_PROMPT = """
-You are a AI content analysis expert, good at summarizing content.
-Please summarize a long, specific and detailed answer or report based on the previous queries and the retrieved document chunks.
+You are an AI content analysis expert.
+Please generate a long, specific and detailed answer or report based on the previous queries and the retrieved document chunks.
 
 Original Query:
 {question}
@@ -148,13 +148,13 @@ class DeepSearch(BaseAgent):
             log.color_print(
                 "No collections found in the vector database. Please check the database connection."
             )
-            return [], 0
+            return []
         if len(collection_infos) == 1:
             the_only_collection = collection_infos[0].collection_name
             log.color_print(
                 f" Perform search [{query}] on the vector DB collection: {the_only_collection} \n"
             )
-            return [the_only_collection], 0
+            return [the_only_collection]
         vector_db_search_prompt = COLLECTION_ROUTE_PROMPT.format(
             question=query,
             collection_info=[
@@ -341,7 +341,7 @@ class DeepSearch(BaseAgent):
                     search_res = result
                 search_res_from_vectordb.extend(search_res)
 
-        search_res_from_vectordb = deduplicate_results(search_res_from_vectordb)
+        search_res_from_vectordb = deduplicate(search_res_from_vectordb)
         # search_res_from_internet = deduplicate_results(search_res_from_internet)
         all_search_res.extend(search_res_from_vectordb + search_res_from_internet)
         if iter == max_iter - 1:
@@ -361,7 +361,7 @@ class DeepSearch(BaseAgent):
             )
             all_sub_queries.extend(sub_gap_queries)
 
-        all_search_res = deduplicate_results(all_search_res)
+        all_search_res = deduplicate(all_search_res)
         additional_info = {"all_sub_queries": all_sub_queries}
         return all_search_res, additional_info
 
diff --git a/deepsearcher/configuration.py b/deepsearcher/configuration.py
index 2698a80..09b741c 100644
--- a/deepsearcher/configuration.py
+++ b/deepsearcher/configuration.py
@@ -197,8 +197,7 @@ def init_config(config: Configuration):
         file_loader, \
         vector_db, \
         web_crawler, \
-        default_searcher, \
-        naive_rag
+        default_searcher
     module_factory = ModuleFactory(config)
     llm = module_factory.create_llm()
     embedding_model = module_factory.create_embedding()
diff --git a/deepsearcher/offline_loading.py b/deepsearcher/offline_loading.py
index 01598b8..f460420 100644
--- a/deepsearcher/offline_loading.py
+++ b/deepsearcher/offline_loading.py
@@ -12,11 +12,10 @@ def load_from_local_files(
     paths_or_directory: str | list[str],
     collection_name: str = None,
     collection_description: str = None,
-    force_new_collection: bool = False,
+    force_rebuild: bool = False,
     chunk_size: int = 1500,
     chunk_overlap: int = 100,
-    batch_size: int = 256,
-    force_rebuild: bool = False,
+    batch_size: int = 256
 ):
     """
     Load knowledge from local files or directories into the vector database.
@@ -28,7 +27,7 @@ def load_from_local_files(
         paths_or_directory: A single path or a list of paths to files or directories to load.
         collection_name: Name of the collection to store the data in. If None, uses the default collection.
         collection_description: Description of the collection. If None, no description is set.
-        force_new_collection: If True, drops the existing collection and creates a new one.
+        force_rebuild: If True, drops the existing collection and creates a new one.
         chunk_size: Size of each chunk in characters.
         chunk_overlap: Number of characters to overlap between chunks.
         batch_size: Number of chunks to process at once during embedding.
@@ -47,7 +46,7 @@ def load_from_local_files(
         dim=embedding_model.dimension,
         collection=collection_name,
         description=collection_description,
-        force_new_collection=force_new_collection,
+        force_rebuild=force_rebuild,
     )
 
     # If force_rebuild is True, force a rebuild of the collection
@@ -56,7 +55,7 @@ def load_from_local_files(
             dim=embedding_model.dimension,
             collection=collection_name,
             description=collection_description,
-            force_new_collection=True,
+            force_rebuild=True,
         )
 
     if isinstance(paths_or_directory, str):
@@ -94,7 +93,7 @@ def load_from_website(
     urls: str | list[str],
     collection_name: str = None,
     collection_description: str = None,
-    force_new_collection: bool = False,
+    force_rebuild: bool = False,
     chunk_size: int = 1500,
     chunk_overlap: int = 100,
     batch_size: int = 256,
@@ -110,7 +109,7 @@ def load_from_website(
         urls: A single URL or a list of URLs to crawl.
         collection_name: Name of the collection to store the data in. If None, uses the default collection.
         collection_description: Description of the collection. If None, no description is set.
-        force_new_collection: If True, drops the existing collection and creates a new one.
+        force_rebuild: If True, drops the existing collection and creates a new one.
         chunk_size: Size of each chunk in characters.
         chunk_overlap: Number of characters to overlap between chunks.
         batch_size: Number of chunks to process at once during embedding.
@@ -126,7 +125,7 @@ def load_from_website(
         dim=embedding_model.dimension,
         collection=collection_name,
         description=collection_description,
-        force_new_collection=force_new_collection,
+        force_rebuild=force_rebuild,
     )
 
     all_docs = web_crawler.crawl_urls(urls, **crawl_kwargs)
diff --git a/deepsearcher/vector_db/__init__.py b/deepsearcher/vector_db/__init__.py
index cc6331e..100b5f4 100644
--- a/deepsearcher/vector_db/__init__.py
+++ b/deepsearcher/vector_db/__init__.py
@@ -1,5 +1,3 @@
 from .milvus import Milvus, RetrievalResult
-from .oracle import OracleDB
-from .qdrant import Qdrant
 
-__all__ = ["Milvus", "RetrievalResult", "OracleDB", "Qdrant"]
+__all__ = ["Milvus", "RetrievalResult"]
diff --git a/deepsearcher/vector_db/base.py b/deepsearcher/vector_db/base.py
index 6962e58..8f250e0 100644
--- a/deepsearcher/vector_db/base.py
+++ b/deepsearcher/vector_db/base.py
@@ -1,5 +1,4 @@
 from abc import ABC, abstractmethod
-from typing import List, Union
 
 import numpy as np
 
@@ -23,7 +22,7 @@ class RetrievalResult:
 
     def __init__(
         self,
-        embedding: np.array,
+        embedding: np.ndarray,
         text: str,
         reference: str,
         metadata: dict,
@@ -55,7 +54,7 @@ class RetrievalResult:
         return f"RetrievalResult(score={self.score}, embedding={self.embedding}, text={self.text}, reference={self.reference}), metadata={self.metadata}"
 
 
-def deduplicate_results(results: List[RetrievalResult]) -> List[RetrievalResult]:
+def deduplicate(results: list[RetrievalResult]) -> list[RetrievalResult]:
     """
     Remove duplicate results based on text content.
@@ -68,13 +67,13 @@ def deduplicate_results(results: List[RetrievalResult]) -> List[RetrievalResult]
     Returns:
         A list of deduplicated RetrievalResult objects.
     """
-    all_text_set = set()
-    deduplicated_results = []
+    seen = set()
+    deduplicated = []
     for result in results:
-        if result.text not in all_text_set:
-            all_text_set.add(result.text)
-            deduplicated_results.append(result)
-    return deduplicated_results
+        if result.text not in seen:
+            seen.add(result.text)
+            deduplicated.append(result)
+    return deduplicated
 
 
 class CollectionInfo:
@@ -114,7 +113,6 @@ class BaseVectorDB(ABC):
 
     def __init__(
         self,
-        default_collection: str = "deepsearcher",
         *args,
         **kwargs,
     ):
@@ -126,7 +124,6 @@
             *args: Variable length argument list.
             **kwargs: Arbitrary keyword arguments.
         """
-        self.default_collection = default_collection
 
     @abstractmethod
     def init_collection(
@@ -134,7 +131,7 @@
         dim: int,
         collection: str,
         description: str,
-        force_new_collection=False,
+        force_rebuild=False,
         *args,
         **kwargs,
     ):
@@ -145,14 +142,14 @@
             dim: The dimensionality of the vectors in the collection.
             collection: The name of the collection.
             description: The description of the collection.
-            force_new_collection: If True, drop the existing collection and create a new one.
+            force_rebuild: If True, drop the existing collection and create a new one.
             *args: Variable length argument list.
             **kwargs: Arbitrary keyword arguments.
         """
         pass
 
     @abstractmethod
-    def insert_data(self, collection: str, chunks: List[Chunk], *args, **kwargs):
+    def insert_data(self, collection: str, chunks: list[Chunk], *args, **kwargs):
         """
         Insert data into a collection in the vector database.
 
@@ -166,8 +163,8 @@
 
     @abstractmethod
     def search_data(
-        self, collection: str, vector: Union[np.array, List[float]], *args, **kwargs
-    ) -> List[RetrievalResult]:
+        self, collection: str, vector: np.ndarray | list[float], *args, **kwargs
+    ) -> list[RetrievalResult]:
         """
         Search for similar vectors in a collection.
 
@@ -182,7 +179,7 @@
         """
         pass
 
-    def list_collections(self, *args, **kwargs) -> List[CollectionInfo]:
+    def list_collections(self, *args, **kwargs) -> list[CollectionInfo]:
         """
         List all collections in the vector database.
 
@@ -196,7 +193,7 @@
         pass
 
     @abstractmethod
-    def clear_db(self, *args, **kwargs):
+    def clear_collection(self, *args, **kwargs):
         """
-        Clear the vector database.
+        Clear a collection in the vector database.
 
@@ -205,3 +202,14 @@
         **kwargs: Arbitrary keyword arguments.
         """
         pass
+
+    @abstractmethod
+    def clear_collections(self, *args, **kwargs):
+        """
+        Clear all collections in the vector database.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+        """
+        pass
diff --git a/deepsearcher/vector_db/milvus.py b/deepsearcher/vector_db/milvus.py
index fc34757..0d19b41 100644
--- a/deepsearcher/vector_db/milvus.py
+++ b/deepsearcher/vector_db/milvus.py
@@ -1,7 +1,5 @@
-from typing import List, Optional, Union
-
 import numpy as np
-from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker
+from pymilvus import DataType, MilvusClient
 
 from deepsearcher.loader.splitter import Chunk
 from deepsearcher.utils import log
@@ -15,43 +13,36 @@ class Milvus(BaseVectorDB):
 
     def __init__(
        self,
-        default_collection: str = "deepsearcher",
         uri: str = "http://localhost:19530",
         token: str = "root:Milvus",
         user: str = "",
         password: str = "",
         db: str = "default",
-        hybrid: bool = False,
         **kwargs,
     ):
         """
         Initialize the Milvus client.
 
         Args:
-            default_collection (str, optional): Default collection name. Defaults to "deepsearcher".
             uri (str, optional): URI for connecting to Milvus server. Defaults to "http://localhost:19530".
             token (str, optional): Authentication token for Milvus. Defaults to "root:Milvus".
             user (str, optional): Username for authentication. Defaults to "".
             password (str, optional): Password for authentication. Defaults to "".
             db (str, optional): Database name. Defaults to "default".
-            hybrid (bool, optional): Whether to enable hybrid search. Defaults to False.
             **kwargs: Additional keyword arguments to pass to the MilvusClient.
         """
-        super().__init__(default_collection)
-        self.default_collection = default_collection
+        super().__init__()
         self.client = MilvusClient(
             uri=uri, user=user, password=password, token=token, db_name=db, timeout=30, **kwargs
         )
-        self.hybrid = hybrid
 
     def init_collection(
         self,
         dim: int,
-        collection: Optional[str] = "deepsearcher",
-        description: Optional[str] = "",
-        force_new_collection: bool = False,
-        text_max_length: int = 65_535,
+        collection: str,
+        description: str,
+        force_rebuild: bool = False,
+        text_max_length: int = 65535,
         reference_max_length: int = 2048,
         metric_type: str = "L2",
         *args,
         **kwargs,
@@ -62,9 +53,9 @@
 
         Args:
             dim (int): Dimension of the vector embeddings.
-            collection (Optional[str], optional): Collection name. Defaults to "deepsearcher".
-            description (Optional[str], optional): Collection description. Defaults to "".
-            force_new_collection (bool, optional): Whether to force create a new collection if it already exists. Defaults to False.
+            collection (str): Collection name.
+            description (str): Collection description.
+            force_rebuild (bool, optional): Whether to force creation of a new collection if one already exists. Defaults to False.
             text_max_length (int, optional): Maximum length for text field. Defaults to 65_535.
             reference_max_length (int, optional): Maximum length for reference field. Defaults to 2048.
             metric_type (str, optional): Metric type for vector similarity search. Defaults to "L2".
@@ -80,7 +71,7 @@
         try:
             has_collection = self.client.has_collection(collection, timeout=5)
-            if force_new_collection and has_collection:
+            if force_rebuild and has_collection:
                 self.client.drop_collection(collection)
             elif has_collection:
                 return
@@ -90,42 +81,15 @@
             schema.add_field("id", DataType.VARCHAR, is_primary=True, max_length=64)
             schema.add_field("embedding", DataType.FLOAT_VECTOR, dim=dim)
-            if self.hybrid:
-                analyzer_params = {"tokenizer": "standard", "filter": ["lowercase"]}
-                schema.add_field(
-                    "text",
-                    DataType.VARCHAR,
-                    max_length=text_max_length,
-                    analyzer_params=analyzer_params,
-                    enable_match=True,
-                    enable_analyzer=True,
-                )
-            else:
-                schema.add_field("text", DataType.VARCHAR, max_length=text_max_length)
+            schema.add_field("text", DataType.VARCHAR, max_length=text_max_length)
             schema.add_field("reference", DataType.VARCHAR, max_length=reference_max_length)
             schema.add_field("metadata", DataType.JSON)
-            if self.hybrid:
-                schema.add_field("sparse_vector", DataType.SPARSE_FLOAT_VECTOR)
-                bm25_function = Function(
-                    name="bm25",
-                    function_type=FunctionType.BM25,
-                    input_field_names=["text"],
-                    output_field_names="sparse_vector",
-                )
-                schema.add_function(bm25_function)
 
             index_params = self.client.prepare_index_params()
             index_params.add_index(field_name="embedding", metric_type=metric_type)
-            if self.hybrid:
-                index_params.add_index(
-                    field_name="sparse_vector",
-                    index_type="SPARSE_INVERTED_INDEX",
-                    metric_type="BM25",
-                )
-
             self.client.create_collection(
                 collection,
                 schema=schema,
@@ -138,8 +102,8 @@
 
     def insert_data(
         self,
-        collection: Optional[str],
-        chunks: List[Chunk],
+        collection: str,
+        chunks: list[Chunk],
         batch_size: int = 256,
         *args,
         **kwargs,
@@ -183,13 +147,13 @@
 
     def search_data(
         self,
-        collection: Optional[str],
-        vector: Union[np.array, List[float]],
-        top_k: int = 5,
-        query_text: Optional[str] = None,
+        collection: str,
+        vector: np.ndarray | list[float],
+        top_k: int = 4,
+        query_text: str | None = None,
         *args,
         **kwargs,
-    ) -> List[RetrievalResult]:
+    ) -> list[RetrievalResult]:
         """
         Search for similar vectors in a Milvus collection.
@@ -207,35 +171,13 @@
         if not collection:
             collection = self.default_collection
         try:
-            use_hybrid = self.hybrid and query_text
-
-            if use_hybrid:
-                sparse_search_params = {"metric_type": "BM25"}
-                sparse_request = AnnSearchRequest(
-                    [query_text], "sparse_vector", sparse_search_params, limit=top_k
-                )
-
-                dense_search_params = {"metric_type": self.metric_type}
-                dense_request = AnnSearchRequest(
-                    [vector], "embedding", dense_search_params, limit=top_k
-                )
-
-                search_results = self.client.hybrid_search(
-                    collection_name=collection,
-                    reqs=[sparse_request, dense_request],
-                    ranker=RRFRanker(),
-                    limit=top_k,
-                    output_fields=["embedding", "text", "reference", "metadata"],
-                    timeout=10,
-                )
-            else:
-                search_results = self.client.search(
-                    collection_name=collection,
-                    data=[vector],
-                    limit=top_k,
-                    output_fields=["embedding", "text", "reference", "metadata"],
-                    timeout=10,
-                )
+            search_results = self.client.search(
+                collection_name=collection,
+                data=[vector],
+                limit=top_k,
+                output_fields=["embedding", "text", "reference", "metadata"],
+                timeout=10,
+            )
 
             return [
                 RetrievalResult(
@@ -252,7 +194,7 @@
             log.critical(f"fail to search data, error info: {e}")
             return []
 
-    def list_collections(self, *args, **kwargs) -> List[CollectionInfo]:
+    def list_collections(self, *args, **kwargs) -> list[CollectionInfo]:
         """
         List all collections in the Milvus database.
 
@@ -290,7 +232,7 @@
             log.critical(f"fail to list collections, error info: {e}")
         return collection_infos
 
-    def clear_db(self, collection: str = "deepsearcher", *args, **kwargs):
+    def clear_collection(self, collection: str, *args, **kwargs):
         """
         Clear (drop) a collection from the Milvus database.
 
@@ -305,3 +247,18 @@
             self.client.drop_collection(collection)
         except Exception as e:
             log.warning(f"fail to clear db, error info: {e}")
+
+    def clear_collections(self, *args, **kwargs):
+        """
+        Clear (drop) all collections from the Milvus database.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+        """
+        try:
+            collections = self.client.list_collections()
+            for collection in collections:
+                self.client.drop_collection(collection)
+        except Exception as e:
+            log.warning(f"fail to clear all collections, error info: {e}")
diff --git a/test.py b/test.py
index d32932d..0e78fc8 100644
--- a/test.py
+++ b/test.py
@@ -11,7 +11,12 @@ config.load_config_from_yaml("deepsearcher/config.yaml")
 init_config(config = config)
 
 # Load your local data
-load_from_local_files(paths_or_directory="examples/data", force_rebuild=True, batch_size=8)
+load_from_local_files(
+    paths_or_directory="examples/data",
+    collection_name="default",
+    collection_description="a general collection for all documents",
+    force_rebuild=True, batch_size=8
+)
 
 # (Optional) Load from web crawling (`FIRECRAWL_API_KEY` env variable required)
 # from deepsearcher.offline_loading import load_from_website