import numpy as np from pymilvus import DataType, MilvusClient from deepsearcher.loader.splitter import Chunk from deepsearcher.utils import log from deepsearcher.vector_db.base import BaseVectorDB, CollectionInfo, RetrievalResult class Milvus(BaseVectorDB): """Milvus class is a subclass of DB class.""" client: MilvusClient = None def __init__( self, uri: str = "http://localhost:19530", token: str = "root:Milvus", user: str = "", password: str = "", db: str = "default", **kwargs, ): """ Initialize the Milvus client. Args: uri (str, optional): URI for connecting to Milvus server. Defaults to "http://localhost:19530". token (str, optional): Authentication token for Milvus. Defaults to "root:Milvus". user (str, optional): Username for authentication. Defaults to "". password (str, optional): Password for authentication. Defaults to "". db (str, optional): Database name. Defaults to "default". **kwargs: Additional keyword arguments to pass to the MilvusClient. """ super().__init__() self.client = MilvusClient( uri=uri, user=user, password=password, token=token, db_name=db, timeout=30, **kwargs ) def init_collection( self, dim: int, collection: str, description: str, force_rebuild: bool = False, text_max_length: int = 65535, reference_max_length: int = 2048, metric_type: str = "IP", *args, **kwargs, ): """ Initialize a collection in Milvus. Args: dim (int): Dimension of the vector embeddings. collection (Optional[str], optional): Collection name. description (Optional[str], optional): Collection description. Defaults to "". force_rebuild (bool, optional): Whether to force create a new collection if it already exists. Defaults to False. text_max_length (int, optional): Maximum length for text field. Defaults to 65_535. reference_max_length (int, optional): Maximum length for reference field. Defaults to 2048. metric_type (str, optional): Metric type for vector similarity search. Defaults to "IP". *args: Variable length argument list. **kwargs: Arbitrary keyword arguments. """ if not collection: collection = self.default_collection if description is None: description = "" self.metric_type = metric_type try: has_collection = self.client.has_collection(collection, timeout=5) if force_rebuild and has_collection: self.client.drop_collection(collection) elif has_collection: return schema = self.client.create_schema( enable_dynamic_field=False, auto_id=False, description=description ) schema.add_field("id", DataType.VARCHAR, is_primary=True, max_length=64) schema.add_field("embedding", DataType.FLOAT_VECTOR, dim=dim) schema.add_field("text", DataType.VARCHAR, max_length=text_max_length) schema.add_field("reference", DataType.VARCHAR, max_length=reference_max_length) schema.add_field("metadata", DataType.JSON) index_params = self.client.prepare_index_params() index_params.add_index(field_name="embedding", metric_type=metric_type) self.client.create_collection( collection, schema=schema, index_params=index_params, consistency_level="Strong", ) log.color_print(f"create collection [{collection}] successfully") except Exception as e: log.critical(f"fail to init db for milvus, error info: {e}") def insert_data( self, collection: str, chunks: list[Chunk], batch_size: int = 256, *args, **kwargs, ): """ Insert data into a Milvus collection. Args: collection (Optional[str]): Collection name. If None, uses default_collection. chunks (List[Chunk]): List of Chunk objects to insert. batch_size (int, optional): Number of chunks to insert in each batch. Defaults to 256. *args: Variable length argument list. **kwargs: Arbitrary keyword arguments. """ if not collection: collection = self.default_collection ids = [chunk.metadata.get('id', '') for chunk in chunks] texts = [chunk.text for chunk in chunks] references = [chunk.reference for chunk in chunks] metadatas = [chunk.metadata for chunk in chunks] embeddings = [chunk.embedding for chunk in chunks] datas = [ { "id": id, "embedding": embedding, "text": text, "reference": reference, "metadata": metadata, } for id, embedding, text, reference, metadata in zip( ids, embeddings, texts, references, metadatas ) ] batch_datas = [datas[i : i + batch_size] for i in range(0, len(datas), batch_size)] try: for batch_data in batch_datas: self.client.insert(collection_name=collection, data=batch_data) except Exception as e: log.critical(f"fail to insert data, error info: {e}") def search_data( self, collection: str, vector: np.ndarray | list[float], top_k: int = 4, query_text: str = None, *args, **kwargs, ) -> list[RetrievalResult]: """ Search for similar vectors in a Milvus collection. Args: collection (Optional[str]): Collection name. If None, uses default_collection. vector (Union[np.array, List[float]]): Query vector for similarity search. top_k (int, optional): Number of results to return. Defaults to 5. query_text (Optional[str], optional): Original query text for hybrid search. Defaults to None. *args: Variable length argument list. **kwargs: Arbitrary keyword arguments. Returns: List[RetrievalResult]: List of retrieval results containing similar vectors. """ if not collection: collection = self.default_collection try: search_results = self.client.search( collection_name=collection, data=[vector], limit=top_k, output_fields=["embedding", "text", "reference", "metadata"], timeout=10, ) return [ RetrievalResult( embedding=b["entity"]["embedding"], text=b["entity"]["text"], reference=b["entity"]["reference"], score=b["distance"], metadata=b["entity"]["metadata"], ) for a in search_results for b in a ] except Exception as e: log.critical(f"fail to search data, error info: {e}") return [] def list_collections(self, *args, **kwargs) -> list[CollectionInfo]: """ List all collections in the Milvus database. Args: *args: Variable length argument list. **kwargs: Arbitrary keyword arguments. Returns: List[CollectionInfo]: List of collection information objects. """ collection_infos = [] dim = kwargs.pop("dim", 0) try: collections = self.client.list_collections() for collection in collections: description = self.client.describe_collection(collection) if dim != 0: skip = False for field_dict in description["fields"]: if ( field_dict["name"] == "embedding" and field_dict["type"] == DataType.FLOAT_VECTOR ): if field_dict["params"]["dim"] != dim: skip = True if skip: continue collection_infos.append( CollectionInfo( collection_name=collection, description=description["description"], ) ) except Exception as e: log.critical(f"fail to list collections, error info: {e}") return collection_infos def clear_collection(self, collection: str, *args, **kwargs): """ Clear (drop) a collection from the Milvus database. Args: collection (str, optional): Collection name to drop. Defaults to "deepsearcher". *args: Variable length argument list. **kwargs: Arbitrary keyword arguments. """ if not collection: collection = self.default_collection try: self.client.drop_collection(collection) except Exception as e: log.warning(f"fail to clear db, error info: {e}") def clear_collections(self, *args, **kwargs): """ Clear (drop) all collections from the Milvus database. Args: *args: Variable length argument list. **kwargs: Arbitrary keyword arguments. """ try: collections = self.client.list_collections() for collection in collections: self.client.drop_collection(collection) except Exception as e: log.warning(f"fail to clear all collections, error info: {e}")