|
|
@ -1,7 +1,5 @@ |
|
|
|
from typing import List, Optional, Union |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker |
|
|
|
from pymilvus import DataType, MilvusClient |
|
|
|
|
|
|
|
from deepsearcher.loader.splitter import Chunk |
|
|
|
from deepsearcher.utils import log |
|
|
@ -15,43 +13,36 @@ class Milvus(BaseVectorDB): |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
default_collection: str = "deepsearcher", |
|
|
|
uri: str = "http://localhost:19530", |
|
|
|
token: str = "root:Milvus", |
|
|
|
user: str = "", |
|
|
|
password: str = "", |
|
|
|
db: str = "default", |
|
|
|
hybrid: bool = False, |
|
|
|
**kwargs, |
|
|
|
): |
|
|
|
""" |
|
|
|
Initialize the Milvus client. |
|
|
|
|
|
|
|
Args: |
|
|
|
default_collection (str, optional): Default collection name. Defaults to "deepsearcher". |
|
|
|
uri (str, optional): URI for connecting to Milvus server. Defaults to "http://localhost:19530". |
|
|
|
token (str, optional): Authentication token for Milvus. Defaults to "root:Milvus". |
|
|
|
user (str, optional): Username for authentication. Defaults to "". |
|
|
|
password (str, optional): Password for authentication. Defaults to "". |
|
|
|
db (str, optional): Database name. Defaults to "default". |
|
|
|
hybrid (bool, optional): Whether to enable hybrid search. Defaults to False. |
|
|
|
**kwargs: Additional keyword arguments to pass to the MilvusClient. |
|
|
|
""" |
|
|
|
super().__init__(default_collection) |
|
|
|
self.default_collection = default_collection |
|
|
|
super().__init__() |
|
|
|
self.client = MilvusClient( |
|
|
|
uri=uri, user=user, password=password, token=token, db_name=db, timeout=30, **kwargs |
|
|
|
) |
|
|
|
|
|
|
|
self.hybrid = hybrid |
|
|
|
|
|
|
|
def init_collection( |
|
|
|
self, |
|
|
|
dim: int, |
|
|
|
collection: Optional[str] = "deepsearcher", |
|
|
|
description: Optional[str] = "", |
|
|
|
force_new_collection: bool = False, |
|
|
|
text_max_length: int = 65_535, |
|
|
|
collection: str, |
|
|
|
description: str, |
|
|
|
force_rebuild: bool = False, |
|
|
|
text_max_length: int = 65535, |
|
|
|
reference_max_length: int = 2048, |
|
|
|
metric_type: str = "L2", |
|
|
|
*args, |
|
|
@ -62,9 +53,9 @@ class Milvus(BaseVectorDB): |
|
|
|
|
|
|
|
Args: |
|
|
|
dim (int): Dimension of the vector embeddings. |
|
|
|
collection (Optional[str], optional): Collection name. Defaults to "deepsearcher". |
|
|
|
collection (Optional[str], optional): Collection name. |
|
|
|
description (Optional[str], optional): Collection description. Defaults to "". |
|
|
|
force_new_collection (bool, optional): Whether to force create a new collection if it already exists. Defaults to False. |
|
|
|
force_rebuild (bool, optional): Whether to force create a new collection if it already exists. Defaults to False. |
|
|
|
text_max_length (int, optional): Maximum length for text field. Defaults to 65_535. |
|
|
|
reference_max_length (int, optional): Maximum length for reference field. Defaults to 2048. |
|
|
|
metric_type (str, optional): Metric type for vector similarity search. Defaults to "L2". |
|
|
@ -80,7 +71,7 @@ class Milvus(BaseVectorDB): |
|
|
|
|
|
|
|
try: |
|
|
|
has_collection = self.client.has_collection(collection, timeout=5) |
|
|
|
if force_new_collection and has_collection: |
|
|
|
if force_rebuild and has_collection: |
|
|
|
self.client.drop_collection(collection) |
|
|
|
elif has_collection: |
|
|
|
return |
|
|
@ -90,42 +81,15 @@ class Milvus(BaseVectorDB): |
|
|
|
schema.add_field("id", DataType.VARCHAR, is_primary=True, max_length=64) |
|
|
|
schema.add_field("embedding", DataType.FLOAT_VECTOR, dim=dim) |
|
|
|
|
|
|
|
if self.hybrid: |
|
|
|
analyzer_params = {"tokenizer": "standard", "filter": ["lowercase"]} |
|
|
|
schema.add_field( |
|
|
|
"text", |
|
|
|
DataType.VARCHAR, |
|
|
|
max_length=text_max_length, |
|
|
|
analyzer_params=analyzer_params, |
|
|
|
enable_match=True, |
|
|
|
enable_analyzer=True, |
|
|
|
) |
|
|
|
else: |
|
|
|
schema.add_field("text", DataType.VARCHAR, max_length=text_max_length) |
|
|
|
schema.add_field("text", DataType.VARCHAR, max_length=text_max_length) |
|
|
|
|
|
|
|
schema.add_field("reference", DataType.VARCHAR, max_length=reference_max_length) |
|
|
|
schema.add_field("metadata", DataType.JSON) |
|
|
|
|
|
|
|
if self.hybrid: |
|
|
|
schema.add_field("sparse_vector", DataType.SPARSE_FLOAT_VECTOR) |
|
|
|
bm25_function = Function( |
|
|
|
name="bm25", |
|
|
|
function_type=FunctionType.BM25, |
|
|
|
input_field_names=["text"], |
|
|
|
output_field_names="sparse_vector", |
|
|
|
) |
|
|
|
schema.add_function(bm25_function) |
|
|
|
|
|
|
|
index_params = self.client.prepare_index_params() |
|
|
|
index_params.add_index(field_name="embedding", metric_type=metric_type) |
|
|
|
|
|
|
|
if self.hybrid: |
|
|
|
index_params.add_index( |
|
|
|
field_name="sparse_vector", |
|
|
|
index_type="SPARSE_INVERTED_INDEX", |
|
|
|
metric_type="BM25", |
|
|
|
) |
|
|
|
|
|
|
|
self.client.create_collection( |
|
|
|
collection, |
|
|
|
schema=schema, |
|
|
@ -138,8 +102,8 @@ class Milvus(BaseVectorDB): |
|
|
|
|
|
|
|
def insert_data( |
|
|
|
self, |
|
|
|
collection: Optional[str], |
|
|
|
chunks: List[Chunk], |
|
|
|
collection: str, |
|
|
|
chunks: list[Chunk], |
|
|
|
batch_size: int = 256, |
|
|
|
*args, |
|
|
|
**kwargs, |
|
|
@ -183,13 +147,13 @@ class Milvus(BaseVectorDB): |
|
|
|
|
|
|
|
def search_data( |
|
|
|
self, |
|
|
|
collection: Optional[str], |
|
|
|
vector: Union[np.array, List[float]], |
|
|
|
top_k: int = 5, |
|
|
|
query_text: Optional[str] = None, |
|
|
|
collection: str, |
|
|
|
vector: np.ndarray | list[float], |
|
|
|
top_k: int = 4, |
|
|
|
query_text: str = None, |
|
|
|
*args, |
|
|
|
**kwargs, |
|
|
|
) -> List[RetrievalResult]: |
|
|
|
) -> list[RetrievalResult]: |
|
|
|
""" |
|
|
|
Search for similar vectors in a Milvus collection. |
|
|
|
|
|
|
@ -207,35 +171,13 @@ class Milvus(BaseVectorDB): |
|
|
|
if not collection: |
|
|
|
collection = self.default_collection |
|
|
|
try: |
|
|
|
use_hybrid = self.hybrid and query_text |
|
|
|
|
|
|
|
if use_hybrid: |
|
|
|
sparse_search_params = {"metric_type": "BM25"} |
|
|
|
sparse_request = AnnSearchRequest( |
|
|
|
[query_text], "sparse_vector", sparse_search_params, limit=top_k |
|
|
|
) |
|
|
|
|
|
|
|
dense_search_params = {"metric_type": self.metric_type} |
|
|
|
dense_request = AnnSearchRequest( |
|
|
|
[vector], "embedding", dense_search_params, limit=top_k |
|
|
|
) |
|
|
|
|
|
|
|
search_results = self.client.hybrid_search( |
|
|
|
collection_name=collection, |
|
|
|
reqs=[sparse_request, dense_request], |
|
|
|
ranker=RRFRanker(), |
|
|
|
limit=top_k, |
|
|
|
output_fields=["embedding", "text", "reference", "metadata"], |
|
|
|
timeout=10, |
|
|
|
) |
|
|
|
else: |
|
|
|
search_results = self.client.search( |
|
|
|
collection_name=collection, |
|
|
|
data=[vector], |
|
|
|
limit=top_k, |
|
|
|
output_fields=["embedding", "text", "reference", "metadata"], |
|
|
|
timeout=10, |
|
|
|
) |
|
|
|
search_results = self.client.search( |
|
|
|
collection_name=collection, |
|
|
|
data=[vector], |
|
|
|
limit=top_k, |
|
|
|
output_fields=["embedding", "text", "reference", "metadata"], |
|
|
|
timeout=10, |
|
|
|
) |
|
|
|
|
|
|
|
return [ |
|
|
|
RetrievalResult( |
|
|
@ -252,7 +194,7 @@ class Milvus(BaseVectorDB): |
|
|
|
log.critical(f"fail to search data, error info: {e}") |
|
|
|
return [] |
|
|
|
|
|
|
|
def list_collections(self, *args, **kwargs) -> List[CollectionInfo]: |
|
|
|
def list_collections(self, *args, **kwargs) -> list[CollectionInfo]: |
|
|
|
""" |
|
|
|
List all collections in the Milvus database. |
|
|
|
|
|
|
@ -290,7 +232,7 @@ class Milvus(BaseVectorDB): |
|
|
|
log.critical(f"fail to list collections, error info: {e}") |
|
|
|
return collection_infos |
|
|
|
|
|
|
|
def clear_db(self, collection: str = "deepsearcher", *args, **kwargs): |
|
|
|
def clear_collection(self, collection: str, *args, **kwargs): |
|
|
|
""" |
|
|
|
Clear (drop) a collection from the Milvus database. |
|
|
|
|
|
|
@ -305,3 +247,18 @@ class Milvus(BaseVectorDB): |
|
|
|
self.client.drop_collection(collection) |
|
|
|
except Exception as e: |
|
|
|
log.warning(f"fail to clear db, error info: {e}") |
|
|
|
|
|
|
|
def clear_collections(self, *args, **kwargs): |
|
|
|
""" |
|
|
|
Clear (drop) all collections from the Milvus database. |
|
|
|
|
|
|
|
Args: |
|
|
|
*args: Variable length argument list. |
|
|
|
**kwargs: Arbitrary keyword arguments. |
|
|
|
""" |
|
|
|
try: |
|
|
|
collections = self.client.list_collections() |
|
|
|
for collection in collections: |
|
|
|
self.client.drop_collection(collection) |
|
|
|
except Exception as e: |
|
|
|
log.warning(f"fail to clear all collections, error info: {e}") |
|
|
|