
Remove default collection

main
tanxing 1 week ago
parent commit 3fd43ef722
7 changed files:

  1. deepsearcher/agent/deep_search.py (14 changes)
  2. deepsearcher/configuration.py (3 changes)
  3. deepsearcher/offline_loading.py (17 changes)
  4. deepsearcher/vector_db/__init__.py (4 changes)
  5. deepsearcher/vector_db/base.py (44 changes)
  6. deepsearcher/vector_db/milvus.py (125 changes)
  7. test.py (7 changes)

deepsearcher/agent/deep_search.py (14 changes)

@@ -5,7 +5,7 @@ from deepsearcher.embedding.base import BaseEmbedding
 from deepsearcher.llm.base import BaseLLM
 from deepsearcher.utils import log
 from deepsearcher.vector_db import RetrievalResult
-from deepsearcher.vector_db.base import BaseVectorDB, deduplicate_results
+from deepsearcher.vector_db.base import BaseVectorDB, deduplicate

 COLLECTION_ROUTE_PROMPT = """
 I provide you with collection_name(s) and corresponding collection_description(s).
@@ -72,8 +72,8 @@ Respond exclusively in valid List of str format without any other text."""

 SUMMARY_PROMPT = """
-You are a AI content analysis expert, good at summarizing content.
-Please summarize a long, specific and detailed answer or report based on the previous queries and the retrieved document chunks.
+You are a AI content analysis expert.
+Please generate a long, specific and detailed answer or report based on the previous queries and the retrieved document chunks.

 Original Query: {question}
@@ -148,13 +148,13 @@ class DeepSearch(BaseAgent):
             log.color_print(
                 "No collections found in the vector database. Please check the database connection."
             )
-            return [], 0
+            return []
         if len(collection_infos) == 1:
             the_only_collection = collection_infos[0].collection_name
             log.color_print(
                 f"<think> Perform search [{query}] on the vector DB collection: {the_only_collection} </think>\n"
             )
-            return [the_only_collection], 0
+            return [the_only_collection]
         vector_db_search_prompt = COLLECTION_ROUTE_PROMPT.format(
             question=query,
             collection_info=[
@@ -341,7 +341,7 @@ class DeepSearch(BaseAgent):
                     search_res = result
                 search_res_from_vectordb.extend(search_res)
-            search_res_from_vectordb = deduplicate_results(search_res_from_vectordb)
+            search_res_from_vectordb = deduplicate(search_res_from_vectordb)
             # search_res_from_internet = deduplicate_results(search_res_from_internet)
             all_search_res.extend(search_res_from_vectordb + search_res_from_internet)
             if iter == max_iter - 1:
@@ -361,7 +361,7 @@ class DeepSearch(BaseAgent):
             )
             all_sub_queries.extend(sub_gap_queries)

-        all_search_res = deduplicate_results(all_search_res)
+        all_search_res = deduplicate(all_search_res)
        additional_info = {"all_sub_queries": all_sub_queries}
        return all_search_res, additional_info
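Net effect in this file: the helper rename (deduplicate_results → deduplicate), and the collection-routing method now returns only the list of collection names, dropping the trailing 0 from its old return value. A standalone sketch of the new return contract; the function and variable names below are illustrative, since the hunks show only the method body:

# Illustrative sketch only; names are hypothetical, logic mirrors the hunk above.
def route_collections(collection_infos: list) -> list[str]:
    if not collection_infos:
        return []  # was: return [], 0
    if len(collection_infos) == 1:
        return [collection_infos[0].collection_name]  # was: return [...], 0
    return [info.collection_name for info in collection_infos]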

deepsearcher/configuration.py (3 changes)

@@ -197,8 +197,7 @@ def init_config(config: Configuration):
         file_loader, \
         vector_db, \
         web_crawler, \
-        default_searcher, \
-        naive_rag
+        default_searcher
     module_factory = ModuleFactory(config)
     llm = module_factory.create_llm()
     embedding_model = module_factory.create_embedding()

deepsearcher/offline_loading.py (17 changes)

@@ -12,11 +12,10 @@ def load_from_local_files(
     paths_or_directory: str | list[str],
     collection_name: str = None,
     collection_description: str = None,
-    force_new_collection: bool = False,
+    force_rebuild: bool = False,
     chunk_size: int = 1500,
     chunk_overlap: int = 100,
-    batch_size: int = 256,
-    force_rebuild: bool = False,
+    batch_size: int = 256
 ):
     """
     Load knowledge from local files or directories into the vector database.
@@ -28,7 +27,7 @@ def load_from_local_files(
         paths_or_directory: A single path or a list of paths to files or directories to load.
         collection_name: Name of the collection to store the data in. If None, uses the default collection.
         collection_description: Description of the collection. If None, no description is set.
-        force_new_collection: If True, drops the existing collection and creates a new one.
+        force_rebuild: If True, drops the existing collection and creates a new one.
         chunk_size: Size of each chunk in characters.
         chunk_overlap: Number of characters to overlap between chunks.
         batch_size: Number of chunks to process at once during embedding.
@@ -47,7 +46,7 @@ def load_from_local_files(
         dim=embedding_model.dimension,
         collection=collection_name,
         description=collection_description,
-        force_new_collection=force_new_collection,
+        force_rebuild=force_rebuild,
     )
     # If force_rebuild is True, force a rebuild of the collection
@@ -56,7 +55,7 @@ def load_from_local_files(
             dim=embedding_model.dimension,
             collection=collection_name,
             description=collection_description,
-            force_new_collection=True,
+            force_rebuild=True,
         )

     if isinstance(paths_or_directory, str):
@@ -94,7 +93,7 @@ def load_from_website(
     urls: str | list[str],
     collection_name: str = None,
     collection_description: str = None,
-    force_new_collection: bool = False,
+    force_rebuild: bool = False,
     chunk_size: int = 1500,
     chunk_overlap: int = 100,
     batch_size: int = 256,
@@ -110,7 +109,7 @@ def load_from_website(
         urls: A single URL or a list of URLs to crawl.
         collection_name: Name of the collection to store the data in. If None, uses the default collection.
         collection_description: Description of the collection. If None, no description is set.
-        force_new_collection: If True, drops the existing collection and creates a new one.
+        force_rebuild: If True, drops the existing collection and creates a new one.
         chunk_size: Size of each chunk in characters.
         chunk_overlap: Number of characters to overlap between chunks.
         batch_size: Number of chunks to process at once during embedding.
@@ -126,7 +125,7 @@ def load_from_website(
         dim=embedding_model.dimension,
         collection=collection_name,
         description=collection_description,
-        force_new_collection=force_new_collection,
+        force_rebuild=force_rebuild,
     )
     all_docs = web_crawler.crawl_urls(urls, **crawl_kwargs)
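Callers now request a rebuild through the single force_rebuild flag; the redundant force_new_collection/force_rebuild pair is gone. A minimal sketch of the updated call, mirroring the test.py change at the end of this commit (path and collection values are placeholders):

from deepsearcher.offline_loading import load_from_local_files

# Drop and recreate the collection before loading.
load_from_local_files(
    paths_or_directory="examples/data",
    collection_name="default",
    collection_description="a general collection for all documents",
    force_rebuild=True,  # renamed from force_new_collection
    batch_size=8,
)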

deepsearcher/vector_db/__init__.py (4 changes)

@@ -1,5 +1,3 @@
 from .milvus import Milvus, RetrievalResult
-from .oracle import OracleDB
-from .qdrant import Qdrant

-__all__ = ["Milvus", "RetrievalResult", "OracleDB", "Qdrant"]
+__all__ = ["Milvus", "RetrievalResult"]

deepsearcher/vector_db/base.py (44 changes)

@@ -1,5 +1,4 @@
 from abc import ABC, abstractmethod
-from typing import List, Union

 import numpy as np
@@ -23,7 +22,7 @@ class RetrievalResult:
     def __init__(
         self,
-        embedding: np.array,
+        embedding: np.ndarray,
         text: str,
         reference: str,
         metadata: dict,
@@ -55,7 +54,7 @@ class RetrievalResult:
         return f"RetrievalResult(score={self.score}, embedding={self.embedding}, text={self.text}, reference={self.reference}), metadata={self.metadata}"

-def deduplicate_results(results: List[RetrievalResult]) -> List[RetrievalResult]:
+def deduplicate(results: list[RetrievalResult]) -> list[RetrievalResult]:
     """
     Remove duplicate results based on text content.
@@ -68,13 +67,13 @@ def deduplicate_results(results: List[RetrievalResult]) -> List[RetrievalResult]
     Returns:
         A list of deduplicated RetrievalResult objects.
     """
-    all_text_set = set()
-    deduplicated_results = []
+    text = set()
+    deduplicated = []
     for result in results:
-        if result.text not in all_text_set:
-            all_text_set.add(result.text)
-            deduplicated_results.append(result)
-    return deduplicated_results
+        if result.text not in text:
+            text.add(result.text)
+            deduplicated.append(result)
+    return deduplicated

 class CollectionInfo:
@@ -114,7 +113,6 @@ class BaseVectorDB(ABC):
     def __init__(
         self,
-        default_collection: str = "deepsearcher",
         *args,
         **kwargs,
     ):
@@ -126,7 +124,6 @@ class BaseVectorDB(ABC):
             *args: Variable length argument list.
             **kwargs: Arbitrary keyword arguments.
         """
-        self.default_collection = default_collection

     @abstractmethod
     def init_collection(
@@ -134,7 +131,7 @@ class BaseVectorDB(ABC):
         dim: int,
         collection: str,
         description: str,
-        force_new_collection=False,
+        force_rebuild=False,
         *args,
         **kwargs,
     ):
@@ -145,14 +142,14 @@ class BaseVectorDB(ABC):
             dim: The dimensionality of the vectors in the collection.
             collection: The name of the collection.
             description: The description of the collection.
-            force_new_collection: If True, drop the existing collection and create a new one.
+            force_rebuild: If True, drop the existing collection and create a new one.
             *args: Variable length argument list.
             **kwargs: Arbitrary keyword arguments.
         """
         pass

     @abstractmethod
-    def insert_data(self, collection: str, chunks: List[Chunk], *args, **kwargs):
+    def insert_data(self, collection: str, chunks: list[Chunk], *args, **kwargs):
         """
         Insert data into a collection in the vector database.
@@ -166,8 +163,8 @@ class BaseVectorDB(ABC):

     @abstractmethod
     def search_data(
-        self, collection: str, vector: Union[np.array, List[float]], *args, **kwargs
-    ) -> List[RetrievalResult]:
+        self, collection: str, vector: np.ndarray | list[float], *args, **kwargs
+    ) -> list[RetrievalResult]:
         """
         Search for similar vectors in a collection.
@@ -182,7 +179,7 @@ class BaseVectorDB(ABC):
         """
         pass

-    def list_collections(self, *args, **kwargs) -> List[CollectionInfo]:
+    def list_collections(self, *args, **kwargs) -> list[CollectionInfo]:
         """
         List all collections in the vector database.
@@ -196,7 +193,7 @@ class BaseVectorDB(ABC):
         pass

     @abstractmethod
-    def clear_db(self, *args, **kwargs):
+    def clear_collection(self, *args, **kwargs):
         """
         Clear the vector database.
@@ -205,3 +202,14 @@ class BaseVectorDB(ABC):
             **kwargs: Arbitrary keyword arguments.
         """
         pass
+
+    @abstractmethod
+    def clear_collections(self, *args, **kwargs):
+        """
+        Clear all collections in the vector database.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+        """
+        pass
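base.py defines both RetrievalResult and the renamed deduplicate helper, so a quick sanity sketch of the new name; it assumes RetrievalResult can be constructed from the four fields shown in its __init__ hunk, with any remaining fields (such as score) defaulting:

import numpy as np
from deepsearcher.vector_db.base import RetrievalResult, deduplicate

# Two results sharing the same text collapse to one; order is preserved.
a = RetrievalResult(embedding=np.zeros(4), text="same chunk", reference="a.md", metadata={})
b = RetrievalResult(embedding=np.ones(4), text="same chunk", reference="b.md", metadata={})
print(len(deduplicate([a, b])))  # expected: 1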

deepsearcher/vector_db/milvus.py (125 changes)

@@ -1,7 +1,5 @@
-from typing import List, Optional, Union
-
 import numpy as np
-from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker
+from pymilvus import DataType, MilvusClient

 from deepsearcher.loader.splitter import Chunk
 from deepsearcher.utils import log
@@ -15,43 +13,36 @@ class Milvus(BaseVectorDB):
     def __init__(
         self,
-        default_collection: str = "deepsearcher",
         uri: str = "http://localhost:19530",
         token: str = "root:Milvus",
         user: str = "",
         password: str = "",
         db: str = "default",
-        hybrid: bool = False,
         **kwargs,
     ):
         """
         Initialize the Milvus client.

         Args:
-            default_collection (str, optional): Default collection name. Defaults to "deepsearcher".
             uri (str, optional): URI for connecting to Milvus server. Defaults to "http://localhost:19530".
             token (str, optional): Authentication token for Milvus. Defaults to "root:Milvus".
             user (str, optional): Username for authentication. Defaults to "".
             password (str, optional): Password for authentication. Defaults to "".
             db (str, optional): Database name. Defaults to "default".
-            hybrid (bool, optional): Whether to enable hybrid search. Defaults to False.
             **kwargs: Additional keyword arguments to pass to the MilvusClient.
         """
-        super().__init__(default_collection)
-        self.default_collection = default_collection
+        super().__init__()
         self.client = MilvusClient(
             uri=uri, user=user, password=password, token=token, db_name=db, timeout=30, **kwargs
         )
-        self.hybrid = hybrid

     def init_collection(
         self,
         dim: int,
-        collection: Optional[str] = "deepsearcher",
-        description: Optional[str] = "",
-        force_new_collection: bool = False,
-        text_max_length: int = 65_535,
+        collection: str,
+        description: str,
+        force_rebuild: bool = False,
+        text_max_length: int = 65535,
         reference_max_length: int = 2048,
         metric_type: str = "L2",
         *args,
@@ -62,9 +53,9 @@ class Milvus(BaseVectorDB):

         Args:
             dim (int): Dimension of the vector embeddings.
-            collection (Optional[str], optional): Collection name. Defaults to "deepsearcher".
+            collection (Optional[str], optional): Collection name.
             description (Optional[str], optional): Collection description. Defaults to "".
-            force_new_collection (bool, optional): Whether to force create a new collection if it already exists. Defaults to False.
+            force_rebuild (bool, optional): Whether to force create a new collection if it already exists. Defaults to False.
             text_max_length (int, optional): Maximum length for text field. Defaults to 65_535.
             reference_max_length (int, optional): Maximum length for reference field. Defaults to 2048.
             metric_type (str, optional): Metric type for vector similarity search. Defaults to "L2".
@@ -80,7 +71,7 @@ class Milvus(BaseVectorDB):
         try:
             has_collection = self.client.has_collection(collection, timeout=5)
-            if force_new_collection and has_collection:
+            if force_rebuild and has_collection:
                 self.client.drop_collection(collection)
             elif has_collection:
                 return
@@ -90,42 +81,15 @@ class Milvus(BaseVectorDB):
             schema.add_field("id", DataType.VARCHAR, is_primary=True, max_length=64)
             schema.add_field("embedding", DataType.FLOAT_VECTOR, dim=dim)
-            if self.hybrid:
-                analyzer_params = {"tokenizer": "standard", "filter": ["lowercase"]}
-                schema.add_field(
-                    "text",
-                    DataType.VARCHAR,
-                    max_length=text_max_length,
-                    analyzer_params=analyzer_params,
-                    enable_match=True,
-                    enable_analyzer=True,
-                )
-            else:
-                schema.add_field("text", DataType.VARCHAR, max_length=text_max_length)
+            schema.add_field("text", DataType.VARCHAR, max_length=text_max_length)
             schema.add_field("reference", DataType.VARCHAR, max_length=reference_max_length)
             schema.add_field("metadata", DataType.JSON)
-            if self.hybrid:
-                schema.add_field("sparse_vector", DataType.SPARSE_FLOAT_VECTOR)
-                bm25_function = Function(
-                    name="bm25",
-                    function_type=FunctionType.BM25,
-                    input_field_names=["text"],
-                    output_field_names="sparse_vector",
-                )
-                schema.add_function(bm25_function)

             index_params = self.client.prepare_index_params()
             index_params.add_index(field_name="embedding", metric_type=metric_type)
-            if self.hybrid:
-                index_params.add_index(
-                    field_name="sparse_vector",
-                    index_type="SPARSE_INVERTED_INDEX",
-                    metric_type="BM25",
-                )
             self.client.create_collection(
                 collection,
                 schema=schema,
@@ -138,8 +102,8 @@ class Milvus(BaseVectorDB):

     def insert_data(
         self,
-        collection: Optional[str],
-        chunks: List[Chunk],
+        collection: str,
+        chunks: list[Chunk],
         batch_size: int = 256,
         *args,
         **kwargs,
@@ -183,13 +147,13 @@ class Milvus(BaseVectorDB):

     def search_data(
         self,
-        collection: Optional[str],
-        vector: Union[np.array, List[float]],
-        top_k: int = 5,
-        query_text: Optional[str] = None,
+        collection: str,
+        vector: np.ndarray | list[float],
+        top_k: int = 4,
+        query_text: str = None,
         *args,
         **kwargs,
-    ) -> List[RetrievalResult]:
+    ) -> list[RetrievalResult]:
         """
         Search for similar vectors in a Milvus collection.
@@ -207,35 +171,13 @@ class Milvus(BaseVectorDB):
         if not collection:
             collection = self.default_collection
         try:
-            use_hybrid = self.hybrid and query_text
-            if use_hybrid:
-                sparse_search_params = {"metric_type": "BM25"}
-                sparse_request = AnnSearchRequest(
-                    [query_text], "sparse_vector", sparse_search_params, limit=top_k
-                )
-                dense_search_params = {"metric_type": self.metric_type}
-                dense_request = AnnSearchRequest(
-                    [vector], "embedding", dense_search_params, limit=top_k
-                )
-                search_results = self.client.hybrid_search(
-                    collection_name=collection,
-                    reqs=[sparse_request, dense_request],
-                    ranker=RRFRanker(),
-                    limit=top_k,
-                    output_fields=["embedding", "text", "reference", "metadata"],
-                    timeout=10,
-                )
-            else:
-                search_results = self.client.search(
-                    collection_name=collection,
-                    data=[vector],
-                    limit=top_k,
-                    output_fields=["embedding", "text", "reference", "metadata"],
-                    timeout=10,
-                )
+            search_results = self.client.search(
+                collection_name=collection,
+                data=[vector],
+                limit=top_k,
+                output_fields=["embedding", "text", "reference", "metadata"],
+                timeout=10,
+            )

             return [
                 RetrievalResult(
@@ -252,7 +194,7 @@ class Milvus(BaseVectorDB):
             log.critical(f"fail to search data, error info: {e}")
             return []

-    def list_collections(self, *args, **kwargs) -> List[CollectionInfo]:
+    def list_collections(self, *args, **kwargs) -> list[CollectionInfo]:
         """
         List all collections in the Milvus database.
@@ -290,7 +232,7 @@ class Milvus(BaseVectorDB):
             log.critical(f"fail to list collections, error info: {e}")
         return collection_infos

-    def clear_db(self, collection: str = "deepsearcher", *args, **kwargs):
+    def clear_collection(self, collection: str, *args, **kwargs):
         """
         Clear (drop) a collection from the Milvus database.
@@ -305,3 +247,18 @@ class Milvus(BaseVectorDB):
             self.client.drop_collection(collection)
         except Exception as e:
             log.warning(f"fail to clear db, error info: {e}")
+
+    def clear_collections(self, *args, **kwargs):
+        """
+        Clear (drop) all collections from the Milvus database.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+        """
+        try:
+            collections = self.client.list_collections()
+            for collection in collections:
+                self.client.drop_collection(collection)
+        except Exception as e:
+            log.warning(f"fail to clear all collections, error info: {e}")
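With collection now a required argument and the new clear_collections method, resetting a store looks roughly like the sketch below; the URI and token mirror the constructor defaults above, and the 4-dimensional collection is illustrative only:

from deepsearcher.vector_db import Milvus

db = Milvus(uri="http://localhost:19530", token="root:Milvus")
db.clear_collections()  # drops every collection in the database
db.init_collection(dim=4, collection="default", description="scratch collection")
print([info.collection_name for info in db.list_collections()])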

test.py (7 changes)

@@ -11,7 +11,12 @@ config.load_config_from_yaml("deepsearcher/config.yaml")
 init_config(config = config)

 # Load your local data
-load_from_local_files(paths_or_directory="examples/data", force_rebuild=True, batch_size=8)
+load_from_local_files(
+    paths_or_directory="examples/data",
+    collection_name="default",
+    collection_description="a general collection for all documents",
+    force_rebuild=True, batch_size=8
+)

 # (Optional) Load from web crawling (`FIRECRAWL_API_KEY` env variable required)
 # from deepsearcher.offline_loading import load_from_website
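The optional web-crawling path takes the same renamed flag; a sketch following load_from_website's signature in this commit, with a placeholder URL (FIRECRAWL_API_KEY must be set):

from deepsearcher.offline_loading import load_from_website

# URL and collection values are placeholders.
load_from_website(
    urls="https://example.com/docs",
    collection_name="default",
    collection_description="a general collection for all documents",
    force_rebuild=False,
)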
