You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
290 lines
10 KiB
290 lines
10 KiB
import uuid
|
|
from typing import List, Optional, Union
|
|
|
|
import numpy as np
|
|
|
|
from deepsearcher.loader.splitter import Chunk
|
|
from deepsearcher.utils import log
|
|
from deepsearcher.vector_db.base import BaseVectorDB, CollectionInfo, RetrievalResult
|
|
|
|
# Default value for ``Qdrant(default_collection=...)``; methods fall back to the
# instance's default collection whenever they are called without one.
DEFAULT_COLLECTION_NAME = "deepsearcher"

# Keys used inside each Qdrant point payload to store the fields of a Chunk.
TEXT_PAYLOAD_KEY = "text"  # the chunk's text
REFERENCE_PAYLOAD_KEY = "reference"  # the chunk's reference
METADATA_PAYLOAD_KEY = "metadata"  # the chunk's metadata
|
|
|
|
|
|
class Qdrant(BaseVectorDB):
    """Vector DB implementation powered by [Qdrant](https://qdrant.tech/).

    Each inserted chunk becomes one Qdrant point with a random UUID id, the
    chunk's embedding as its vector, and a payload holding the chunk's text,
    reference and metadata under ``TEXT_PAYLOAD_KEY``, ``REFERENCE_PAYLOAD_KEY``
    and ``METADATA_PAYLOAD_KEY``.
    """

    def __init__(
        self,
        location: Optional[str] = None,
        url: Optional[str] = None,
        port: Optional[int] = 6333,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: Optional[bool] = None,
        api_key: Optional[str] = None,
        prefix: Optional[str] = None,
        timeout: Optional[int] = None,
        host: Optional[str] = None,
        path: Optional[str] = None,
        default_collection: str = DEFAULT_COLLECTION_NAME,
    ):
        """
        Initialize the Qdrant client with flexible connection options.

        All connection arguments are passed straight through to
        ``qdrant_client.QdrantClient``.

        Args:
            location (Optional[str], optional):
                - If ":memory:" - use in-memory Qdrant instance.
                - If str - use it as a URL parameter.
                - If None - use default values for host and port.
                Defaults to None.

            url (Optional[str], optional):
                URL for Qdrant service, can include scheme, host, port, and prefix.
                Allows flexible connection string specification.
                Defaults to None.

            port (Optional[int], optional):
                Port of the REST API interface.
                Defaults to 6333.

            grpc_port (int, optional):
                Port of the gRPC interface.
                Defaults to 6334.

            prefer_grpc (bool, optional):
                If True, use gRPC interface whenever possible in custom methods.
                Defaults to False.

            https (Optional[bool], optional):
                If True, use HTTPS (SSL) protocol.
                Defaults to None.

            api_key (Optional[str], optional):
                API key for authentication in Qdrant Cloud.
                Defaults to None.

            prefix (Optional[str], optional):
                If not None, add prefix to the REST URL path.
                Example: 'service/v1' results in 'http://localhost:6333/service/v1/{qdrant-endpoint}'
                Defaults to None.

            timeout (Optional[int], optional):
                Timeout for REST and gRPC API requests.
                Default is 5 seconds for REST and unlimited for gRPC.
                Defaults to None.

            host (Optional[str], optional):
                Host name of Qdrant service.
                If url and host are None, defaults to 'localhost'.
                Defaults to None.

            path (Optional[str], optional):
                Persistence path for QdrantLocal.
                Defaults to None.

            default_collection (str, optional):
                Default collection name to be used.

        Raises:
            ImportError: If the ``qdrant-client`` package is not installed.
        """
        # Imported lazily so the rest of the package works without qdrant-client.
        try:
            from qdrant_client import QdrantClient
        except ImportError as original_error:
            raise ImportError(
                "Qdrant client is not installed. Install it using: pip install qdrant-client\n"
            ) from original_error

        super().__init__(default_collection)
        self.client = QdrantClient(
            location=location,
            url=url,
            port=port,
            grpc_port=grpc_port,
            prefer_grpc=prefer_grpc,
            https=https,
            api_key=api_key,
            prefix=prefix,
            timeout=timeout,
            host=host,
            path=path,
        )

    def init_collection(
        self,
        dim: int,
        collection: Optional[str] = None,
        description: Optional[str] = "",
        force_new_collection: bool = False,
        text_max_length: int = 65_535,
        reference_max_length: int = 2048,
        distance_metric: str = "Cosine",
        *args,
        **kwargs,
    ):
        """
        Initialize a collection in Qdrant.

        An existing collection is kept as-is unless ``force_new_collection`` is
        True, in which case it is deleted and recreated.

        Args:
            dim (int): Dimension of the vector embeddings.
            collection (Optional[str], optional): Collection name. Falls back to
                ``self.default_collection`` when empty/None.
            description (Optional[str], optional): Accepted for interface
                compatibility; not used by this implementation. Defaults to "".
            force_new_collection (bool, optional): Whether to drop and recreate the
                collection if it already exists. Defaults to False.
            text_max_length (int, optional): Accepted for interface compatibility;
                not used by this implementation. Defaults to 65_535.
            reference_max_length (int, optional): Accepted for interface
                compatibility; not used by this implementation. Defaults to 2048.
            distance_metric (str, optional): Distance passed to
                ``models.VectorParams`` (e.g. "Cosine", "Euclid", "Dot").
                Defaults to "Cosine".
            *args: Accepted for interface compatibility and ignored. They are not
                forwarded because positional values would collide with the
                ``collection_name`` keyword given to ``create_collection``.
            **kwargs: Extra keyword arguments forwarded to
                ``QdrantClient.create_collection``.
        """
        from qdrant_client import models

        collection = collection or self.default_collection

        try:
            collection_exists = self.client.collection_exists(collection_name=collection)

            if force_new_collection and collection_exists:
                self.client.delete_collection(collection_name=collection)
                collection_exists = False

            if not collection_exists:
                self.client.create_collection(
                    collection_name=collection,
                    vectors_config=models.VectorParams(size=dim, distance=distance_metric),
                    **kwargs,
                )
                log.color_print(f"Created collection [{collection}] successfully")
        except Exception as e:
            log.critical(f"Failed to init Qdrant collection, error info: {e}")

    def insert_data(
        self,
        collection: Optional[str],
        chunks: List[Chunk],
        batch_size: int = 256,
        *args,
        **kwargs,
    ):
        """
        Insert data into a Qdrant collection.

        Chunks are upserted in batches of ``batch_size``; each chunk gets a fresh
        random UUID as its point id, so re-inserting the same chunk creates a new
        point rather than overwriting an old one.

        Args:
            collection (Optional[str]): Collection name. Falls back to
                ``self.default_collection`` when empty/None.
            chunks (List[Chunk]): List of Chunk objects to insert.
            batch_size (int, optional): Number of chunks to insert in each batch. Defaults to 256.
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).
        """
        from qdrant_client import models

        # Resolve the target once instead of on every batch.
        collection_name = collection or self.default_collection

        try:
            for start in range(0, len(chunks), batch_size):
                points = [
                    models.PointStruct(
                        id=uuid.uuid4().hex,
                        vector=chunk.embedding,
                        payload={
                            TEXT_PAYLOAD_KEY: chunk.text,
                            REFERENCE_PAYLOAD_KEY: chunk.reference,
                            METADATA_PAYLOAD_KEY: chunk.metadata,
                        },
                    )
                    for chunk in chunks[start : start + batch_size]
                ]
                self.client.upsert(collection_name=collection_name, points=points)
        except Exception as e:
            log.critical(f"Failed to insert data, error info: {e}")

    def search_data(
        self,
        collection: Optional[str],
        vector: Union[np.ndarray, List[float]],
        top_k: int = 5,
        *args,
        **kwargs,
    ) -> List[RetrievalResult]:
        """
        Search for similar vectors in a Qdrant collection.

        Args:
            collection (Optional[str]): Collection name. Falls back to
                ``self.default_collection`` when empty/None.
            vector (Union[np.ndarray, List[float]]): Query vector for similarity
                search. NumPy arrays are converted to a plain list first.
            top_k (int, optional): Number of results to return. Defaults to 5.
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).

        Returns:
            List[RetrievalResult]: Retrieval results (with stored vectors and
            payload fields) for the nearest points. Empty if the search failed
            and the error was not raised by ``log.critical``.
        """
        # Normalize array input so the client always receives a list of floats.
        query_vector = vector.tolist() if isinstance(vector, np.ndarray) else vector

        try:
            points = self.client.query_points(
                collection_name=collection or self.default_collection,
                query=query_vector,
                limit=top_k,
                with_payload=True,
                with_vectors=True,
            ).points

            results = []
            for point in points:
                # A point stored without a payload yields None; treat it as empty.
                payload = point.payload or {}
                results.append(
                    RetrievalResult(
                        embedding=point.vector,
                        text=payload.get(TEXT_PAYLOAD_KEY, ""),
                        reference=payload.get(REFERENCE_PAYLOAD_KEY, ""),
                        score=point.score,
                        metadata=payload.get(METADATA_PAYLOAD_KEY, {}),
                    )
                )
            return results
        except Exception as e:
            log.critical(f"Failed to search data, error info: {e}")
            return []

    def list_collections(self, *args, **kwargs) -> List[CollectionInfo]:
        """
        List all collections in the Qdrant database.

        Args:
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).

        Returns:
            List[CollectionInfo]: One entry per collection. Since Qdrant stores no
            collection description, the collection name is used as description.
        """
        collection_infos = []

        try:
            collections = self.client.get_collections().collections
            for collection in collections:
                collection_infos.append(
                    CollectionInfo(
                        collection_name=collection.name,
                        # Qdrant doesn't have a native description field
                        description=collection.name,
                    )
                )
        except Exception as e:
            log.critical(f"Failed to list collections, error info: {e}")

        return collection_infos

    def clear_db(self, collection: Optional[str] = None, *args, **kwargs):
        """
        Clear (drop) a collection from the Qdrant database.

        Failures are logged as warnings rather than propagated (best effort).

        Args:
            collection (str, optional): Collection name to drop. Falls back to
                ``self.default_collection`` when empty/None.
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).
        """
        try:
            self.client.delete_collection(collection_name=collection or self.default_collection)
        except Exception as e:
            log.warning(f"Failed to drop collection, error info: {e}")
|
|
|