You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

265 lines
9.8 KiB

2 weeks ago
import numpy as np
from pymilvus import DataType, MilvusClient
2 weeks ago
from deepsearcher.loader.splitter import Chunk
from deepsearcher.utils import log
from deepsearcher.vector_db.base import BaseVectorDB, CollectionInfo, RetrievalResult
class Milvus(BaseVectorDB):
    """Vector database backend that stores and searches chunks through a ``MilvusClient``."""

    # Shared placeholder; each instance replaces it with a live client in __init__.
    client: MilvusClient | None = None

    def __init__(
        self,
        uri: str = "http://localhost:19530",
        token: str = "root:Milvus",
        user: str = "",
        password: str = "",
        db: str = "default",
        **kwargs,
    ):
        """
        Initialize the Milvus client.

        Args:
            uri (str, optional): URI for connecting to Milvus server. Defaults to "http://localhost:19530".
            token (str, optional): Authentication token for Milvus. Defaults to "root:Milvus".
            user (str, optional): Username for authentication. Defaults to "".
            password (str, optional): Password for authentication. Defaults to "".
            db (str, optional): Database name. Defaults to "default".
            **kwargs: Additional keyword arguments to pass to the MilvusClient.
                A ``timeout`` entry overrides the 30 second default.
        """
        super().__init__()
        # Use setdefault so a caller-supplied ``timeout`` does not collide with
        # the default and raise a duplicate-keyword TypeError.
        kwargs.setdefault("timeout", 30)
        self.client = MilvusClient(
            uri=uri, user=user, password=password, token=token, db_name=db, **kwargs
        )

    def init_collection(
        self,
        dim: int,
        collection: str | None,
        description: str | None,
        force_rebuild: bool = False,
        text_max_length: int = 65535,
        reference_max_length: int = 2048,
        metric_type: str = "IP",
        *args,
        **kwargs,
    ):
        """
        Initialize a collection in Milvus.

        If the collection already exists it is left untouched, unless
        ``force_rebuild`` is set, in which case it is dropped and recreated.
        Any exception raised while talking to Milvus is logged at critical
        level and swallowed.

        Args:
            dim (int): Dimension of the vector embeddings.
            collection (Optional[str]): Collection name. A falsy value selects
                ``self.default_collection``.
            description (Optional[str]): Collection description. ``None`` is
                treated as "".
            force_rebuild (bool, optional): Whether to force create a new collection if it already exists. Defaults to False.
            text_max_length (int, optional): Maximum length for text field. Defaults to 65_535.
            reference_max_length (int, optional): Maximum length for reference field. Defaults to 2048.
            metric_type (str, optional): Metric type for vector similarity search. Defaults to "IP".
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).
        """
        if not collection:
            collection = self.default_collection
        if description is None:
            description = ""

        # Recorded even when the collection already exists or creation fails.
        self.metric_type = metric_type

        try:
            has_collection = self.client.has_collection(collection, timeout=5)
            if has_collection:
                if not force_rebuild:
                    return
                self.client.drop_collection(collection)

            schema = self.client.create_schema(
                enable_dynamic_field=False, auto_id=False, description=description
            )
            schema.add_field("id", DataType.VARCHAR, is_primary=True, max_length=64)
            schema.add_field("embedding", DataType.FLOAT_VECTOR, dim=dim)
            schema.add_field("text", DataType.VARCHAR, max_length=text_max_length)
            schema.add_field("reference", DataType.VARCHAR, max_length=reference_max_length)
            schema.add_field("metadata", DataType.JSON)

            index_params = self.client.prepare_index_params()
            index_params.add_index(field_name="embedding", metric_type=metric_type)

            self.client.create_collection(
                collection,
                schema=schema,
                index_params=index_params,
                consistency_level="Strong",
            )
            log.color_print(f"create collection [{collection}] successfully")
        except Exception as e:
            log.critical(f"fail to init db for milvus, error info: {e}")

    def insert_data(
        self,
        collection: str | None,
        chunks: list[Chunk],
        batch_size: int = 256,
        *args,
        **kwargs,
    ):
        """
        Insert data into a Milvus collection.

        Rows are sent in batches of ``batch_size``. If a batch fails, the error
        is logged at critical level and the remaining batches are not sent.

        Args:
            collection (Optional[str]): Collection name. A falsy value selects
                ``self.default_collection``.
            chunks (List[Chunk]): List of Chunk objects to insert. The primary
                key of each row is read from ``chunk.metadata['id']`` and falls
                back to "" when that key is absent.
            batch_size (int, optional): Number of chunks to insert in each batch. Defaults to 256.
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).
        """
        if not collection:
            collection = self.default_collection

        # Build one row per chunk directly instead of zipping parallel lists.
        rows = [
            {
                "id": chunk.metadata.get('id', ''),
                "embedding": chunk.embedding,
                "text": chunk.text,
                "reference": chunk.reference,
                "metadata": chunk.metadata,
            }
            for chunk in chunks
        ]

        try:
            for start in range(0, len(rows), batch_size):
                self.client.insert(
                    collection_name=collection, data=rows[start : start + batch_size]
                )
        except Exception as e:
            log.critical(f"fail to insert data, error info: {e}")

    def search_data(
        self,
        collection: str | None,
        vector: np.ndarray | list[float],
        top_k: int = 3,
        query_text: str | None = None,
        *args,
        **kwargs,
    ) -> list[RetrievalResult]:
        """
        Search for similar vectors in a Milvus collection.

        Args:
            collection (Optional[str]): Collection name. A falsy value selects
                ``self.default_collection``.
            vector (Union[np.array, List[float]]): Query vector for similarity search.
            top_k (int, optional): Number of results to return. Defaults to 3.
            query_text (Optional[str], optional): Original query text. Accepted
                for interface compatibility; this implementation does not use it.
                Defaults to None.
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).

        Returns:
            List[RetrievalResult]: Retrieval results for the query vector, or an
            empty list if the search raised an exception (the error is logged).
        """
        if not collection:
            collection = self.default_collection

        try:
            search_results = self.client.search(
                collection_name=collection,
                data=[vector],
                limit=top_k,
                output_fields=["embedding", "text", "reference", "metadata"],
                timeout=10,
            )
            # The result is iterated as a sequence of hit lists; flatten it
            # into a single list of RetrievalResult objects.
            return [
                RetrievalResult(
                    embedding=hit["entity"]["embedding"],
                    text=hit["entity"]["text"],
                    reference=hit["entity"]["reference"],
                    score=hit["distance"],
                    metadata=hit["entity"]["metadata"],
                )
                for hits in search_results
                for hit in hits
            ]
        except Exception as e:
            log.critical(f"fail to search data, error info: {e}")
            return []

    def list_collections(self, *args, **kwargs) -> list[CollectionInfo]:
        """
        List all collections in the Milvus database.

        Args:
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments. A non-zero ``dim`` entry
                filters out collections whose FLOAT_VECTOR field named
                "embedding" has a different dimension.

        Returns:
            List[CollectionInfo]: List of collection information objects. If an
            exception occurs part-way through, the error is logged and the
            entries gathered so far are returned.
        """
        collection_infos = []
        dim = kwargs.pop("dim", 0)
        try:
            for collection in self.client.list_collections():
                description = self.client.describe_collection(collection)
                if dim != 0:
                    dim_mismatch = any(
                        field_dict["name"] == "embedding"
                        and field_dict["type"] == DataType.FLOAT_VECTOR
                        and field_dict["params"]["dim"] != dim
                        for field_dict in description["fields"]
                    )
                    if dim_mismatch:
                        continue
                collection_infos.append(
                    CollectionInfo(
                        collection_name=collection,
                        description=description["description"],
                    )
                )
        except Exception as e:
            log.critical(f"fail to list collections, error info: {e}")
        return collection_infos

    def clear_collection(self, collection: str | None, *args, **kwargs):
        """
        Clear (drop) a collection from the Milvus database.

        Args:
            collection (Optional[str]): Collection name to drop. A falsy value
                selects ``self.default_collection``.
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).
        """
        if not collection:
            collection = self.default_collection
        try:
            self.client.drop_collection(collection)
        except Exception as e:
            log.warning(f"fail to clear db, error info: {e}")

    def clear_collections(self, *args, **kwargs):
        """
        Clear (drop) all collections from the Milvus database.

        If listing or dropping raises, a warning is logged and the remaining
        collections are left in place.

        Args:
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).
        """
        try:
            for collection in self.client.list_collections():
                self.client.drop_collection(collection)
        except Exception as e:
            log.warning(f"fail to clear all collections, error info: {e}")