Browse Source

移除默认collection

main
tanxing 1 week ago
parent
commit
3fd43ef722
  1. 14
      deepsearcher/agent/deep_search.py
  2. 3
      deepsearcher/configuration.py
  3. 17
      deepsearcher/offline_loading.py
  4. 4
      deepsearcher/vector_db/__init__.py
  5. 44
      deepsearcher/vector_db/base.py
  6. 109
      deepsearcher/vector_db/milvus.py
  7. 7
      test.py

14
deepsearcher/agent/deep_search.py

@ -5,7 +5,7 @@ from deepsearcher.embedding.base import BaseEmbedding
from deepsearcher.llm.base import BaseLLM from deepsearcher.llm.base import BaseLLM
from deepsearcher.utils import log from deepsearcher.utils import log
from deepsearcher.vector_db import RetrievalResult from deepsearcher.vector_db import RetrievalResult
from deepsearcher.vector_db.base import BaseVectorDB, deduplicate_results from deepsearcher.vector_db.base import BaseVectorDB, deduplicate
COLLECTION_ROUTE_PROMPT = """ COLLECTION_ROUTE_PROMPT = """
I provide you with collection_name(s) and corresponding collection_description(s). I provide you with collection_name(s) and corresponding collection_description(s).
@ -72,8 +72,8 @@ Respond exclusively in valid List of str format without any other text."""
SUMMARY_PROMPT = """ SUMMARY_PROMPT = """
You are a AI content analysis expert, good at summarizing content. You are a AI content analysis expert.
Please summarize a long, specific and detailed answer or report based on the previous queries and the retrieved document chunks. Please generate a long, specific and detailed answer or report based on the previous queries and the retrieved document chunks.
Original Query: {question} Original Query: {question}
@ -148,13 +148,13 @@ class DeepSearch(BaseAgent):
log.color_print( log.color_print(
"No collections found in the vector database. Please check the database connection." "No collections found in the vector database. Please check the database connection."
) )
return [], 0 return []
if len(collection_infos) == 1: if len(collection_infos) == 1:
the_only_collection = collection_infos[0].collection_name the_only_collection = collection_infos[0].collection_name
log.color_print( log.color_print(
f"<think> Perform search [{query}] on the vector DB collection: {the_only_collection} </think>\n" f"<think> Perform search [{query}] on the vector DB collection: {the_only_collection} </think>\n"
) )
return [the_only_collection], 0 return [the_only_collection]
vector_db_search_prompt = COLLECTION_ROUTE_PROMPT.format( vector_db_search_prompt = COLLECTION_ROUTE_PROMPT.format(
question=query, question=query,
collection_info=[ collection_info=[
@ -341,7 +341,7 @@ class DeepSearch(BaseAgent):
search_res = result search_res = result
search_res_from_vectordb.extend(search_res) search_res_from_vectordb.extend(search_res)
search_res_from_vectordb = deduplicate_results(search_res_from_vectordb) search_res_from_vectordb = deduplicate(search_res_from_vectordb)
# search_res_from_internet = deduplicate_results(search_res_from_internet) # search_res_from_internet = deduplicate_results(search_res_from_internet)
all_search_res.extend(search_res_from_vectordb + search_res_from_internet) all_search_res.extend(search_res_from_vectordb + search_res_from_internet)
if iter == max_iter - 1: if iter == max_iter - 1:
@ -361,7 +361,7 @@ class DeepSearch(BaseAgent):
) )
all_sub_queries.extend(sub_gap_queries) all_sub_queries.extend(sub_gap_queries)
all_search_res = deduplicate_results(all_search_res) all_search_res = deduplicate(all_search_res)
additional_info = {"all_sub_queries": all_sub_queries} additional_info = {"all_sub_queries": all_sub_queries}
return all_search_res, additional_info return all_search_res, additional_info

3
deepsearcher/configuration.py

@ -197,8 +197,7 @@ def init_config(config: Configuration):
file_loader, \ file_loader, \
vector_db, \ vector_db, \
web_crawler, \ web_crawler, \
default_searcher, \ default_searcher
naive_rag
module_factory = ModuleFactory(config) module_factory = ModuleFactory(config)
llm = module_factory.create_llm() llm = module_factory.create_llm()
embedding_model = module_factory.create_embedding() embedding_model = module_factory.create_embedding()

17
deepsearcher/offline_loading.py

@ -12,11 +12,10 @@ def load_from_local_files(
paths_or_directory: str | list[str], paths_or_directory: str | list[str],
collection_name: str = None, collection_name: str = None,
collection_description: str = None, collection_description: str = None,
force_new_collection: bool = False, force_rebuild: bool = False,
chunk_size: int = 1500, chunk_size: int = 1500,
chunk_overlap: int = 100, chunk_overlap: int = 100,
batch_size: int = 256, batch_size: int = 256
force_rebuild: bool = False,
): ):
""" """
Load knowledge from local files or directories into the vector database. Load knowledge from local files or directories into the vector database.
@ -28,7 +27,7 @@ def load_from_local_files(
paths_or_directory: A single path or a list of paths to files or directories to load. paths_or_directory: A single path or a list of paths to files or directories to load.
collection_name: Name of the collection to store the data in. If None, uses the default collection. collection_name: Name of the collection to store the data in. If None, uses the default collection.
collection_description: Description of the collection. If None, no description is set. collection_description: Description of the collection. If None, no description is set.
force_new_collection: If True, drops the existing collection and creates a new one. force_rebuild: If True, drops the existing collection and creates a new one.
chunk_size: Size of each chunk in characters. chunk_size: Size of each chunk in characters.
chunk_overlap: Number of characters to overlap between chunks. chunk_overlap: Number of characters to overlap between chunks.
batch_size: Number of chunks to process at once during embedding. batch_size: Number of chunks to process at once during embedding.
@ -47,7 +46,7 @@ def load_from_local_files(
dim=embedding_model.dimension, dim=embedding_model.dimension,
collection=collection_name, collection=collection_name,
description=collection_description, description=collection_description,
force_new_collection=force_new_collection, force_rebuild=force_rebuild,
) )
# 如果force_rebuild为True,则强制重建集合 # 如果force_rebuild为True,则强制重建集合
@ -56,7 +55,7 @@ def load_from_local_files(
dim=embedding_model.dimension, dim=embedding_model.dimension,
collection=collection_name, collection=collection_name,
description=collection_description, description=collection_description,
force_new_collection=True, force_rebuild=True,
) )
if isinstance(paths_or_directory, str): if isinstance(paths_or_directory, str):
@ -94,7 +93,7 @@ def load_from_website(
urls: str | list[str], urls: str | list[str],
collection_name: str = None, collection_name: str = None,
collection_description: str = None, collection_description: str = None,
force_new_collection: bool = False, force_rebuild: bool = False,
chunk_size: int = 1500, chunk_size: int = 1500,
chunk_overlap: int = 100, chunk_overlap: int = 100,
batch_size: int = 256, batch_size: int = 256,
@ -110,7 +109,7 @@ def load_from_website(
urls: A single URL or a list of URLs to crawl. urls: A single URL or a list of URLs to crawl.
collection_name: Name of the collection to store the data in. If None, uses the default collection. collection_name: Name of the collection to store the data in. If None, uses the default collection.
collection_description: Description of the collection. If None, no description is set. collection_description: Description of the collection. If None, no description is set.
force_new_collection: If True, drops the existing collection and creates a new one. force_rebuild: If True, drops the existing collection and creates a new one.
chunk_size: Size of each chunk in characters. chunk_size: Size of each chunk in characters.
chunk_overlap: Number of characters to overlap between chunks. chunk_overlap: Number of characters to overlap between chunks.
batch_size: Number of chunks to process at once during embedding. batch_size: Number of chunks to process at once during embedding.
@ -126,7 +125,7 @@ def load_from_website(
dim=embedding_model.dimension, dim=embedding_model.dimension,
collection=collection_name, collection=collection_name,
description=collection_description, description=collection_description,
force_new_collection=force_new_collection, force_rebuild=force_rebuild,
) )
all_docs = web_crawler.crawl_urls(urls, **crawl_kwargs) all_docs = web_crawler.crawl_urls(urls, **crawl_kwargs)

4
deepsearcher/vector_db/__init__.py

@ -1,5 +1,3 @@
from .milvus import Milvus, RetrievalResult from .milvus import Milvus, RetrievalResult
from .oracle import OracleDB
from .qdrant import Qdrant
__all__ = ["Milvus", "RetrievalResult", "OracleDB", "Qdrant"] __all__ = ["Milvus", "RetrievalResult"]

44
deepsearcher/vector_db/base.py

@ -1,5 +1,4 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Union
import numpy as np import numpy as np
@ -23,7 +22,7 @@ class RetrievalResult:
def __init__( def __init__(
self, self,
embedding: np.array, embedding: np.ndarray,
text: str, text: str,
reference: str, reference: str,
metadata: dict, metadata: dict,
@ -55,7 +54,7 @@ class RetrievalResult:
return f"RetrievalResult(score={self.score}, embedding={self.embedding}, text={self.text}, reference={self.reference}), metadata={self.metadata}" return f"RetrievalResult(score={self.score}, embedding={self.embedding}, text={self.text}, reference={self.reference}), metadata={self.metadata}"
def deduplicate_results(results: List[RetrievalResult]) -> List[RetrievalResult]: def deduplicate(results: list[RetrievalResult]) -> list[RetrievalResult]:
""" """
Remove duplicate results based on text content. Remove duplicate results based on text content.
@ -68,13 +67,13 @@ def deduplicate_results(results: List[RetrievalResult]) -> List[RetrievalResult]
Returns: Returns:
A list of deduplicated RetrievalResult objects. A list of deduplicated RetrievalResult objects.
""" """
all_text_set = set() text = set()
deduplicated_results = [] deduplicated = []
for result in results: for result in results:
if result.text not in all_text_set: if result.text not in text:
all_text_set.add(result.text) text.add(result.text)
deduplicated_results.append(result) deduplicated.append(result)
return deduplicated_results return deduplicated
class CollectionInfo: class CollectionInfo:
@ -114,7 +113,6 @@ class BaseVectorDB(ABC):
def __init__( def __init__(
self, self,
default_collection: str = "deepsearcher",
*args, *args,
**kwargs, **kwargs,
): ):
@ -126,7 +124,6 @@ class BaseVectorDB(ABC):
*args: Variable length argument list. *args: Variable length argument list.
**kwargs: Arbitrary keyword arguments. **kwargs: Arbitrary keyword arguments.
""" """
self.default_collection = default_collection
@abstractmethod @abstractmethod
def init_collection( def init_collection(
@ -134,7 +131,7 @@ class BaseVectorDB(ABC):
dim: int, dim: int,
collection: str, collection: str,
description: str, description: str,
force_new_collection=False, force_rebuild=False,
*args, *args,
**kwargs, **kwargs,
): ):
@ -145,14 +142,14 @@ class BaseVectorDB(ABC):
dim: The dimensionality of the vectors in the collection. dim: The dimensionality of the vectors in the collection.
collection: The name of the collection. collection: The name of the collection.
description: The description of the collection. description: The description of the collection.
force_new_collection: If True, drop the existing collection and create a new one. force_rebuild: If True, drop the existing collection and create a new one.
*args: Variable length argument list. *args: Variable length argument list.
**kwargs: Arbitrary keyword arguments. **kwargs: Arbitrary keyword arguments.
""" """
pass pass
@abstractmethod @abstractmethod
def insert_data(self, collection: str, chunks: List[Chunk], *args, **kwargs): def insert_data(self, collection: str, chunks: list[Chunk], *args, **kwargs):
""" """
Insert data into a collection in the vector database. Insert data into a collection in the vector database.
@ -166,8 +163,8 @@ class BaseVectorDB(ABC):
@abstractmethod @abstractmethod
def search_data( def search_data(
self, collection: str, vector: Union[np.array, List[float]], *args, **kwargs self, collection: str, vector: np.ndarray | list[float], *args, **kwargs
) -> List[RetrievalResult]: ) -> list[RetrievalResult]:
""" """
Search for similar vectors in a collection. Search for similar vectors in a collection.
@ -182,7 +179,7 @@ class BaseVectorDB(ABC):
""" """
pass pass
def list_collections(self, *args, **kwargs) -> List[CollectionInfo]: def list_collections(self, *args, **kwargs) -> list[CollectionInfo]:
""" """
List all collections in the vector database. List all collections in the vector database.
@ -196,7 +193,7 @@ class BaseVectorDB(ABC):
pass pass
@abstractmethod @abstractmethod
def clear_db(self, *args, **kwargs): def clear_collection(self, *args, **kwargs):
""" """
Clear the vector database. Clear the vector database.
@ -205,3 +202,14 @@ class BaseVectorDB(ABC):
**kwargs: Arbitrary keyword arguments. **kwargs: Arbitrary keyword arguments.
""" """
pass pass
@abstractmethod
def clear_collections(self, *args, **kwargs):
"""
Clear all collections in the vector database.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
pass

109
deepsearcher/vector_db/milvus.py

@ -1,7 +1,5 @@
from typing import List, Optional, Union
import numpy as np import numpy as np
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker from pymilvus import DataType, MilvusClient
from deepsearcher.loader.splitter import Chunk from deepsearcher.loader.splitter import Chunk
from deepsearcher.utils import log from deepsearcher.utils import log
@ -15,43 +13,36 @@ class Milvus(BaseVectorDB):
def __init__( def __init__(
self, self,
default_collection: str = "deepsearcher",
uri: str = "http://localhost:19530", uri: str = "http://localhost:19530",
token: str = "root:Milvus", token: str = "root:Milvus",
user: str = "", user: str = "",
password: str = "", password: str = "",
db: str = "default", db: str = "default",
hybrid: bool = False,
**kwargs, **kwargs,
): ):
""" """
Initialize the Milvus client. Initialize the Milvus client.
Args: Args:
default_collection (str, optional): Default collection name. Defaults to "deepsearcher".
uri (str, optional): URI for connecting to Milvus server. Defaults to "http://localhost:19530". uri (str, optional): URI for connecting to Milvus server. Defaults to "http://localhost:19530".
token (str, optional): Authentication token for Milvus. Defaults to "root:Milvus". token (str, optional): Authentication token for Milvus. Defaults to "root:Milvus".
user (str, optional): Username for authentication. Defaults to "". user (str, optional): Username for authentication. Defaults to "".
password (str, optional): Password for authentication. Defaults to "". password (str, optional): Password for authentication. Defaults to "".
db (str, optional): Database name. Defaults to "default". db (str, optional): Database name. Defaults to "default".
hybrid (bool, optional): Whether to enable hybrid search. Defaults to False.
**kwargs: Additional keyword arguments to pass to the MilvusClient. **kwargs: Additional keyword arguments to pass to the MilvusClient.
""" """
super().__init__(default_collection) super().__init__()
self.default_collection = default_collection
self.client = MilvusClient( self.client = MilvusClient(
uri=uri, user=user, password=password, token=token, db_name=db, timeout=30, **kwargs uri=uri, user=user, password=password, token=token, db_name=db, timeout=30, **kwargs
) )
self.hybrid = hybrid
def init_collection( def init_collection(
self, self,
dim: int, dim: int,
collection: Optional[str] = "deepsearcher", collection: str,
description: Optional[str] = "", description: str,
force_new_collection: bool = False, force_rebuild: bool = False,
text_max_length: int = 65_535, text_max_length: int = 65535,
reference_max_length: int = 2048, reference_max_length: int = 2048,
metric_type: str = "L2", metric_type: str = "L2",
*args, *args,
@ -62,9 +53,9 @@ class Milvus(BaseVectorDB):
Args: Args:
dim (int): Dimension of the vector embeddings. dim (int): Dimension of the vector embeddings.
collection (Optional[str], optional): Collection name. Defaults to "deepsearcher". collection (Optional[str], optional): Collection name.
description (Optional[str], optional): Collection description. Defaults to "". description (Optional[str], optional): Collection description. Defaults to "".
force_new_collection (bool, optional): Whether to force create a new collection if it already exists. Defaults to False. force_rebuild (bool, optional): Whether to force create a new collection if it already exists. Defaults to False.
text_max_length (int, optional): Maximum length for text field. Defaults to 65_535. text_max_length (int, optional): Maximum length for text field. Defaults to 65_535.
reference_max_length (int, optional): Maximum length for reference field. Defaults to 2048. reference_max_length (int, optional): Maximum length for reference field. Defaults to 2048.
metric_type (str, optional): Metric type for vector similarity search. Defaults to "L2". metric_type (str, optional): Metric type for vector similarity search. Defaults to "L2".
@ -80,7 +71,7 @@ class Milvus(BaseVectorDB):
try: try:
has_collection = self.client.has_collection(collection, timeout=5) has_collection = self.client.has_collection(collection, timeout=5)
if force_new_collection and has_collection: if force_rebuild and has_collection:
self.client.drop_collection(collection) self.client.drop_collection(collection)
elif has_collection: elif has_collection:
return return
@ -90,42 +81,15 @@ class Milvus(BaseVectorDB):
schema.add_field("id", DataType.VARCHAR, is_primary=True, max_length=64) schema.add_field("id", DataType.VARCHAR, is_primary=True, max_length=64)
schema.add_field("embedding", DataType.FLOAT_VECTOR, dim=dim) schema.add_field("embedding", DataType.FLOAT_VECTOR, dim=dim)
if self.hybrid:
analyzer_params = {"tokenizer": "standard", "filter": ["lowercase"]}
schema.add_field(
"text",
DataType.VARCHAR,
max_length=text_max_length,
analyzer_params=analyzer_params,
enable_match=True,
enable_analyzer=True,
)
else:
schema.add_field("text", DataType.VARCHAR, max_length=text_max_length) schema.add_field("text", DataType.VARCHAR, max_length=text_max_length)
schema.add_field("reference", DataType.VARCHAR, max_length=reference_max_length) schema.add_field("reference", DataType.VARCHAR, max_length=reference_max_length)
schema.add_field("metadata", DataType.JSON) schema.add_field("metadata", DataType.JSON)
if self.hybrid:
schema.add_field("sparse_vector", DataType.SPARSE_FLOAT_VECTOR)
bm25_function = Function(
name="bm25",
function_type=FunctionType.BM25,
input_field_names=["text"],
output_field_names="sparse_vector",
)
schema.add_function(bm25_function)
index_params = self.client.prepare_index_params() index_params = self.client.prepare_index_params()
index_params.add_index(field_name="embedding", metric_type=metric_type) index_params.add_index(field_name="embedding", metric_type=metric_type)
if self.hybrid:
index_params.add_index(
field_name="sparse_vector",
index_type="SPARSE_INVERTED_INDEX",
metric_type="BM25",
)
self.client.create_collection( self.client.create_collection(
collection, collection,
schema=schema, schema=schema,
@ -138,8 +102,8 @@ class Milvus(BaseVectorDB):
def insert_data( def insert_data(
self, self,
collection: Optional[str], collection: str,
chunks: List[Chunk], chunks: list[Chunk],
batch_size: int = 256, batch_size: int = 256,
*args, *args,
**kwargs, **kwargs,
@ -183,13 +147,13 @@ class Milvus(BaseVectorDB):
def search_data( def search_data(
self, self,
collection: Optional[str], collection: str,
vector: Union[np.array, List[float]], vector: np.ndarray | list[float],
top_k: int = 5, top_k: int = 4,
query_text: Optional[str] = None, query_text: str = None,
*args, *args,
**kwargs, **kwargs,
) -> List[RetrievalResult]: ) -> list[RetrievalResult]:
""" """
Search for similar vectors in a Milvus collection. Search for similar vectors in a Milvus collection.
@ -207,28 +171,6 @@ class Milvus(BaseVectorDB):
if not collection: if not collection:
collection = self.default_collection collection = self.default_collection
try: try:
use_hybrid = self.hybrid and query_text
if use_hybrid:
sparse_search_params = {"metric_type": "BM25"}
sparse_request = AnnSearchRequest(
[query_text], "sparse_vector", sparse_search_params, limit=top_k
)
dense_search_params = {"metric_type": self.metric_type}
dense_request = AnnSearchRequest(
[vector], "embedding", dense_search_params, limit=top_k
)
search_results = self.client.hybrid_search(
collection_name=collection,
reqs=[sparse_request, dense_request],
ranker=RRFRanker(),
limit=top_k,
output_fields=["embedding", "text", "reference", "metadata"],
timeout=10,
)
else:
search_results = self.client.search( search_results = self.client.search(
collection_name=collection, collection_name=collection,
data=[vector], data=[vector],
@ -252,7 +194,7 @@ class Milvus(BaseVectorDB):
log.critical(f"fail to search data, error info: {e}") log.critical(f"fail to search data, error info: {e}")
return [] return []
def list_collections(self, *args, **kwargs) -> List[CollectionInfo]: def list_collections(self, *args, **kwargs) -> list[CollectionInfo]:
""" """
List all collections in the Milvus database. List all collections in the Milvus database.
@ -290,7 +232,7 @@ class Milvus(BaseVectorDB):
log.critical(f"fail to list collections, error info: {e}") log.critical(f"fail to list collections, error info: {e}")
return collection_infos return collection_infos
def clear_db(self, collection: str = "deepsearcher", *args, **kwargs): def clear_collection(self, collection: str, *args, **kwargs):
""" """
Clear (drop) a collection from the Milvus database. Clear (drop) a collection from the Milvus database.
@ -305,3 +247,18 @@ class Milvus(BaseVectorDB):
self.client.drop_collection(collection) self.client.drop_collection(collection)
except Exception as e: except Exception as e:
log.warning(f"fail to clear db, error info: {e}") log.warning(f"fail to clear db, error info: {e}")
def clear_collections(self, *args, **kwargs):
"""
Clear (drop) all collections from the Milvus database.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
try:
collections = self.client.list_collections()
for collection in collections:
self.client.drop_collection(collection)
except Exception as e:
log.warning(f"fail to clear all collections, error info: {e}")

7
test.py

@ -11,7 +11,12 @@ config.load_config_from_yaml("deepsearcher/config.yaml")
init_config(config = config) init_config(config = config)
# Load your local data # Load your local data
load_from_local_files(paths_or_directory="examples/data", force_rebuild=True, batch_size=8) load_from_local_files(
paths_or_directory="examples/data",
collection_name="default",
collection_description="a general collection for all documents",
force_rebuild=True, batch_size=8
)
# (Optional) Load from web crawling (`FIRECRAWL_API_KEY` env variable required) # (Optional) Load from web crawling (`FIRECRAWL_API_KEY` env variable required)
# from deepsearcher.offline_loading import load_from_website # from deepsearcher.offline_loading import load_from_website

Loading…
Cancel
Save