You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

305 lines
12 KiB

from typing import List, Optional, Union
import numpy as np
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker
from deepsearcher.loader.splitter import Chunk
from deepsearcher.utils import log
from deepsearcher.vector_db.base import BaseVectorDB, CollectionInfo, RetrievalResult
class Milvus(BaseVectorDB):
    """Milvus-backed implementation of the BaseVectorDB interface.

    Stores chunks with a dense ``embedding`` vector, ``text``, ``reference`` and
    JSON ``metadata``. When ``hybrid`` is enabled, it also stores a BM25
    ``sparse_vector`` derived from ``text``, and searches with both vectors
    fused by RRF.
    """

    client: Optional[MilvusClient] = None

    def __init__(
        self,
        default_collection: str = "deepsearcher",
        uri: str = "http://localhost:19530",
        token: str = "root:Milvus",
        user: str = "",
        password: str = "",
        db: str = "default",
        hybrid: bool = False,
        **kwargs,
    ):
        """
        Initialize the Milvus client.

        Args:
            default_collection (str, optional): Collection used when a method receives an
                empty/None collection name. Defaults to "deepsearcher".
            uri (str, optional): URI for connecting to Milvus server. Defaults to "http://localhost:19530".
            token (str, optional): Authentication token for Milvus. Defaults to "root:Milvus".
            user (str, optional): Username for authentication. Defaults to "".
            password (str, optional): Password for authentication. Defaults to "".
            db (str, optional): Database name. Defaults to "default".
            hybrid (bool, optional): Whether to enable hybrid (dense + BM25) search. Defaults to False.
            **kwargs: Additional keyword arguments to pass to the MilvusClient.
        """
        super().__init__(default_collection)
        self.default_collection = default_collection
        self.client = MilvusClient(
            uri=uri, user=user, password=password, token=token, db_name=db, timeout=30, **kwargs
        )
        self.hybrid = hybrid
        # Dense metric used by hybrid search requests. init_collection() overwrites it.
        # Defaulting here (to init_collection's own default) keeps hybrid search working
        # on an instance that only queries an existing collection. Previously that
        # path raised AttributeError inside search_data and silently returned [].
        self.metric_type = "L2"

    def init_collection(
        self,
        dim: int,
        collection: Optional[str] = "deepsearcher",
        description: Optional[str] = "",
        force_new_collection: bool = False,
        text_max_length: int = 65_535,
        reference_max_length: int = 2048,
        metric_type: str = "L2",
        *args,
        **kwargs,
    ):
        """
        Create a collection in Milvus unless it already exists.

        If the collection exists and ``force_new_collection`` is False, this returns
        without touching it (only ``self.metric_type`` is updated).

        Args:
            dim (int): Dimension of the vector embeddings.
            collection (Optional[str], optional): Collection name. Defaults to "deepsearcher";
                an empty/None value falls back to ``self.default_collection``.
            description (Optional[str], optional): Collection description. None is treated as "".
            force_new_collection (bool, optional): Drop and recreate the collection if it
                already exists. Defaults to False.
            text_max_length (int, optional): Maximum length for text field. Defaults to 65_535.
            reference_max_length (int, optional): Maximum length for reference field. Defaults to 2048.
            metric_type (str, optional): Metric type for the dense vector index. Defaults to "L2".
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).
        """
        if not collection:
            collection = self.default_collection
        if description is None:
            description = ""
        # Recorded even when the collection already exists, so that hybrid search
        # on this instance uses the caller-supplied metric.
        self.metric_type = metric_type
        try:
            if self.client.has_collection(collection, timeout=5):
                if not force_new_collection:
                    return
                self.client.drop_collection(collection)
            schema = self._build_schema(dim, description, text_max_length, reference_max_length)
            index_params = self._build_index_params(metric_type)
            self.client.create_collection(
                collection,
                schema=schema,
                index_params=index_params,
                consistency_level="Strong",
            )
            log.color_print(f"create collection [{collection}] successfully")
        except Exception as e:
            log.critical(f"fail to init db for milvus, error info: {e}")

    def _build_schema(
        self, dim: int, description: str, text_max_length: int, reference_max_length: int
    ):
        """Build the collection schema (plus a BM25 function and sparse field in hybrid mode)."""
        schema = self.client.create_schema(
            enable_dynamic_field=False, auto_id=True, description=description
        )
        schema.add_field("id", DataType.INT64, is_primary=True)
        schema.add_field("embedding", DataType.FLOAT_VECTOR, dim=dim)
        if self.hybrid:
            # The BM25 function added below reads this field, so it must be analyzed.
            schema.add_field(
                "text",
                DataType.VARCHAR,
                max_length=text_max_length,
                analyzer_params={"tokenizer": "standard", "filter": ["lowercase"]},
                enable_match=True,
                enable_analyzer=True,
            )
        else:
            schema.add_field("text", DataType.VARCHAR, max_length=text_max_length)
        schema.add_field("reference", DataType.VARCHAR, max_length=reference_max_length)
        schema.add_field("metadata", DataType.JSON)
        if self.hybrid:
            schema.add_field("sparse_vector", DataType.SPARSE_FLOAT_VECTOR)
            schema.add_function(
                Function(
                    name="bm25",
                    function_type=FunctionType.BM25,
                    input_field_names=["text"],
                    output_field_names="sparse_vector",
                )
            )
        return schema

    def _build_index_params(self, metric_type: str):
        """Build index params: dense index on ``embedding``, plus a BM25 sparse index in hybrid mode."""
        index_params = self.client.prepare_index_params()
        index_params.add_index(field_name="embedding", metric_type=metric_type)
        if self.hybrid:
            index_params.add_index(
                field_name="sparse_vector",
                index_type="SPARSE_INVERTED_INDEX",
                metric_type="BM25",
            )
        return index_params

    def insert_data(
        self,
        collection: Optional[str],
        chunks: List[Chunk],
        batch_size: int = 256,
        *args,
        **kwargs,
    ):
        """
        Insert chunks into a Milvus collection in batches.

        Args:
            collection (Optional[str]): Collection name. If empty/None, uses default_collection.
            chunks (List[Chunk]): Chunk objects to insert (embedding, text, reference, metadata).
            batch_size (int, optional): Number of chunks per insert call. Defaults to 256.
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).

        Raises:
            ValueError: If ``batch_size`` is less than 1. A negative value previously
                inserted nothing silently.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if not collection:
            collection = self.default_collection
        rows = [
            {
                "embedding": chunk.embedding,
                "text": chunk.text,
                "reference": chunk.reference,
                "metadata": chunk.metadata,
            }
            for chunk in chunks
        ]
        try:
            for start in range(0, len(rows), batch_size):
                self.client.insert(collection_name=collection, data=rows[start : start + batch_size])
        except Exception as e:
            # Batches inserted before the failure are not rolled back.
            log.critical(f"fail to insert data, error info: {e}")

    def search_data(
        self,
        collection: Optional[str],
        vector: Union[np.ndarray, List[float]],
        top_k: int = 5,
        query_text: Optional[str] = None,
        *args,
        **kwargs,
    ) -> List[RetrievalResult]:
        """
        Search for similar chunks in a Milvus collection.

        Uses hybrid search (BM25 over ``query_text`` plus dense search over ``vector``,
        fused with RRFRanker) when the instance is hybrid and ``query_text`` is
        non-empty. Otherwise it runs a plain dense search.

        Args:
            collection (Optional[str]): Collection name. If empty/None, uses default_collection.
            vector (Union[np.ndarray, List[float]]): Query vector for similarity search.
            top_k (int, optional): Number of results to return. Defaults to 5.
            query_text (Optional[str], optional): Original query text for hybrid search. Defaults to None.
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).

        Returns:
            List[RetrievalResult]: Retrieved chunks with their scores. An empty list is
            returned on failure (if log.critical does not raise).
        """
        if not collection:
            collection = self.default_collection
        output_fields = ["embedding", "text", "reference", "metadata"]
        try:
            if self.hybrid and query_text:
                sparse_request = AnnSearchRequest(
                    [query_text], "sparse_vector", {"metric_type": "BM25"}, limit=top_k
                )
                dense_request = AnnSearchRequest(
                    [vector], "embedding", {"metric_type": self.metric_type}, limit=top_k
                )
                search_results = self.client.hybrid_search(
                    collection_name=collection,
                    reqs=[sparse_request, dense_request],
                    ranker=RRFRanker(),
                    limit=top_k,
                    output_fields=output_fields,
                    timeout=10,
                )
            else:
                search_results = self.client.search(
                    collection_name=collection,
                    data=[vector],
                    limit=top_k,
                    output_fields=output_fields,
                    timeout=10,
                )
            # One hit list per query vector; a single query is sent, but flatten generically.
            return [
                RetrievalResult(
                    embedding=hit["entity"]["embedding"],
                    text=hit["entity"]["text"],
                    reference=hit["entity"]["reference"],
                    score=hit["distance"],
                    metadata=hit["entity"]["metadata"],
                )
                for hits in search_results
                for hit in hits
            ]
        except Exception as e:
            log.critical(f"fail to search data, error info: {e}")
            return []

    def list_collections(self, *args, **kwargs) -> List[CollectionInfo]:
        """
        List collections in the Milvus database.

        Args:
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments. ``dim`` (int, optional): when non-zero,
                collections whose FLOAT_VECTOR ``embedding`` field has a different
                dimension are skipped.

        Returns:
            List[CollectionInfo]: Name and description of each matching collection.
        """
        collection_infos = []
        dim = kwargs.pop("dim", 0)
        try:
            for collection in self.client.list_collections():
                description = self.client.describe_collection(collection)
                if dim != 0:
                    embedding_field = next(
                        (
                            field
                            for field in description["fields"]
                            if field["name"] == "embedding"
                            and field["type"] == DataType.FLOAT_VECTOR
                        ),
                        None,
                    )
                    if embedding_field is not None and embedding_field["params"]["dim"] != dim:
                        continue
                collection_infos.append(
                    CollectionInfo(
                        collection_name=collection,
                        description=description["description"],
                    )
                )
        except Exception as e:
            log.critical(f"fail to list collections, error info: {e}")
        return collection_infos

    def clear_db(self, collection: str = "deepsearcher", *args, **kwargs):
        """
        Drop a collection from the Milvus database; failures are logged as warnings.

        Args:
            collection (str, optional): Collection name to drop. Defaults to "deepsearcher";
                an empty/None value falls back to default_collection.
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).
        """
        if not collection:
            collection = self.default_collection
        try:
            self.client.drop_collection(collection)
        except Exception as e:
            log.warning(f"fail to clear db, error info: {e}")