You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

291 lines
10 KiB

2 weeks ago
import uuid
from typing import List, Optional, Union
import numpy as np
from deepsearcher.loader.splitter import Chunk
from deepsearcher.utils import log
from deepsearcher.vector_db.base import BaseVectorDB, CollectionInfo, RetrievalResult
# Collection name used as the default for the `default_collection` constructor argument.
DEFAULT_COLLECTION_NAME = "deepsearcher"
# Keys under which a chunk's text, reference and metadata are stored in
# (and read back from) the payload of each Qdrant point.
TEXT_PAYLOAD_KEY = "text"
REFERENCE_PAYLOAD_KEY = "reference"
METADATA_PAYLOAD_KEY = "metadata"
class Qdrant(BaseVectorDB):
    """Vector DB implementation powered by [Qdrant](https://qdrant.tech/).

    Every operation is delegated to a ``qdrant_client.QdrantClient`` stored on
    ``self.client``. Each inserted chunk becomes one Qdrant point whose payload
    carries the chunk text, reference and metadata under the module-level
    ``*_PAYLOAD_KEY`` names.

    Failures inside the data/collection methods are caught and reported through
    the project ``log`` helper (``log.critical`` / ``log.warning``) instead of
    being handled individually.
    """

    def __init__(
        self,
        location: Optional[str] = None,
        url: Optional[str] = None,
        port: Optional[int] = 6333,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: Optional[bool] = None,
        api_key: Optional[str] = None,
        prefix: Optional[str] = None,
        timeout: Optional[int] = None,
        host: Optional[str] = None,
        path: Optional[str] = None,
        default_collection: str = DEFAULT_COLLECTION_NAME,
    ):
        """
        Initialize the Qdrant client with flexible connection options.

        All connection arguments are forwarded unchanged to ``QdrantClient``.

        Args:
            location (Optional[str], optional):
                - If ":memory:" - use in-memory Qdrant instance.
                - If str - use it as a URL parameter.
                - If None - use default values for host and port.
                Defaults to None.
            url (Optional[str], optional):
                URL for Qdrant service, can include scheme, host, port, and prefix.
                Allows flexible connection string specification.
                Defaults to None.
            port (Optional[int], optional):
                Port of the REST API interface.
                Defaults to 6333.
            grpc_port (int, optional):
                Port of the gRPC interface.
                Defaults to 6334.
            prefer_grpc (bool, optional):
                If True, use gRPC interface whenever possible in custom methods.
                Defaults to False.
            https (Optional[bool], optional):
                If True, use HTTPS (SSL) protocol.
                Defaults to None.
            api_key (Optional[str], optional):
                API key for authentication in Qdrant Cloud.
                Defaults to None.
            prefix (Optional[str], optional):
                If not None, add prefix to the REST URL path.
                Example: 'service/v1' results in 'http://localhost:6333/service/v1/{qdrant-endpoint}'
                Defaults to None.
            timeout (Optional[int], optional):
                Timeout for REST and gRPC API requests.
                Default is 5 seconds for REST and unlimited for gRPC.
                Defaults to None.
            host (Optional[str], optional):
                Host name of Qdrant service.
                If url and host are None, defaults to 'localhost'.
                Defaults to None.
            path (Optional[str], optional):
                Persistence path for QdrantLocal.
                Defaults to None.
            default_collection (str, optional):
                Default collection name to be used when a method is called
                without an explicit collection.

        Raises:
            ImportError: If the ``qdrant-client`` package is not installed.
        """
        # Imported lazily so the package is only required when this backend is used.
        try:
            from qdrant_client import QdrantClient
        except ImportError as original_error:
            raise ImportError(
                "Qdrant client is not installed. Install it using: pip install qdrant-client\n"
            ) from original_error

        super().__init__(default_collection)
        self.client = QdrantClient(
            location=location,
            url=url,
            port=port,
            grpc_port=grpc_port,
            prefer_grpc=prefer_grpc,
            https=https,
            api_key=api_key,
            prefix=prefix,
            timeout=timeout,
            host=host,
            path=path,
        )

    def init_collection(
        self,
        dim: int,
        collection: Optional[str] = None,
        description: Optional[str] = "",
        force_new_collection: bool = False,
        text_max_length: int = 65_535,
        reference_max_length: int = 2048,
        distance_metric: str = "Cosine",
        *args,
        **kwargs,
    ):
        """
        Initialize a collection in Qdrant.

        An existing collection is reused unless ``force_new_collection`` is set,
        in which case it is deleted and created again.

        Args:
            dim (int): Dimension of the vector embeddings.
            collection (Optional[str], optional): Collection name. Falls back to
                the default collection when empty or None.
            description (Optional[str], optional): Collection description.
                Accepted for interface compatibility; not used by this backend.
                Defaults to "".
            force_new_collection (bool, optional): Whether to drop and recreate
                the collection if it already exists. Defaults to False.
            text_max_length (int, optional): Maximum length for text field.
                Accepted for interface compatibility; not used by this backend.
                Defaults to 65_535.
            reference_max_length (int, optional): Maximum length for reference field.
                Accepted for interface compatibility; not used by this backend.
                Defaults to 2048.
            distance_metric (str, optional): Metric type for vector similarity
                search. Defaults to "Cosine".
            *args: Extra positional arguments forwarded to
                ``QdrantClient.create_collection`` after the collection name and
                the vector configuration.
            **kwargs: Extra keyword arguments forwarded to
                ``QdrantClient.create_collection``.
        """
        from qdrant_client import models

        collection = collection or self.default_collection

        try:
            collection_exists = self.client.collection_exists(collection_name=collection)

            if force_new_collection and collection_exists:
                self.client.delete_collection(collection_name=collection)
                collection_exists = False

            if not collection_exists:
                # The name and vector config are passed positionally so that any
                # extra *args follow them; mixing them as keywords with *args
                # would bind *args to the same parameters and raise TypeError.
                self.client.create_collection(
                    collection,
                    models.VectorParams(size=dim, distance=distance_metric),
                    *args,
                    **kwargs,
                )
                log.color_print(f"Created collection [{collection}] successfully")
        except Exception as e:
            log.critical(f"Failed to init Qdrant collection, error info: {e}")

    def insert_data(
        self,
        collection: Optional[str],
        chunks: List[Chunk],
        batch_size: int = 256,
        *args,
        **kwargs,
    ):
        """
        Insert data into a Qdrant collection.

        Chunks are upserted in consecutive batches; every chunk gets a freshly
        generated UUID as its point id.

        Args:
            collection (Optional[str]): Collection name. Falls back to the
                default collection when empty or None.
            chunks (List[Chunk]): List of Chunk objects to insert. Each chunk's
                ``embedding``, ``text``, ``reference`` and ``metadata`` are stored.
            batch_size (int, optional): Number of chunks to insert in each batch.
                Must be positive. Defaults to 256.
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).
        """
        from qdrant_client import models

        # Resolve the target once instead of on every batch.
        target_collection = collection or self.default_collection

        try:
            # A negative step would make range() empty and silently skip every
            # chunk; report it through the same error path as other failures.
            if batch_size < 1:
                raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

            for start in range(0, len(chunks), batch_size):
                batch_chunks = chunks[start : start + batch_size]
                points = [
                    models.PointStruct(
                        id=uuid.uuid4().hex,
                        vector=chunk.embedding,
                        payload={
                            TEXT_PAYLOAD_KEY: chunk.text,
                            REFERENCE_PAYLOAD_KEY: chunk.reference,
                            METADATA_PAYLOAD_KEY: chunk.metadata,
                        },
                    )
                    for chunk in batch_chunks
                ]
                self.client.upsert(collection_name=target_collection, points=points)
        except Exception as e:
            log.critical(f"Failed to insert data, error info: {e}")

    def search_data(
        self,
        collection: Optional[str],
        vector: Union[np.ndarray, List[float]],
        top_k: int = 5,
        *args,
        **kwargs,
    ) -> List[RetrievalResult]:
        """
        Search for similar vectors in a Qdrant collection.

        Args:
            collection (Optional[str]): Collection name. Falls back to the
                default collection when empty or None.
            vector (Union[np.ndarray, List[float]]): Query vector for similarity search.
            top_k (int, optional): Number of results to return. Defaults to 5.
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).

        Returns:
            List[RetrievalResult]: One result per returned point, carrying its
                vector, score and the text/reference/metadata read from the
                payload. An empty list is returned after a failure has been
                reported through ``log.critical``.
        """
        try:
            response = self.client.query_points(
                collection_name=collection or self.default_collection,
                query=vector,
                limit=top_k,
                with_payload=True,
                with_vectors=True,
            )

            retrieved = []
            for point in response.points:
                # A point may come back without a payload; treat that as empty
                # so one such point does not fail the whole search.
                payload = point.payload or {}
                retrieved.append(
                    RetrievalResult(
                        embedding=point.vector,
                        text=payload.get(TEXT_PAYLOAD_KEY, ""),
                        reference=payload.get(REFERENCE_PAYLOAD_KEY, ""),
                        score=point.score,
                        # Metadata stored as null is normalised to an empty dict.
                        metadata=payload.get(METADATA_PAYLOAD_KEY) or {},
                    )
                )
            return retrieved
        except Exception as e:
            log.critical(f"Failed to search data, error info: {e}")
            return []

    def list_collections(self, *args, **kwargs) -> List[CollectionInfo]:
        """
        List all collections in the Qdrant database.

        Args:
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).

        Returns:
            List[CollectionInfo]: List of collection information objects. The
                description of each entry is set to the collection name.
        """
        collection_infos = []
        try:
            for found in self.client.get_collections().collections:
                collection_infos.append(
                    CollectionInfo(
                        collection_name=found.name,
                        # Qdrant doesn't have a native description field
                        description=found.name,
                    )
                )
        except Exception as e:
            log.critical(f"Failed to list collections, error info: {e}")
        return collection_infos

    def clear_db(self, collection: Optional[str] = None, *args, **kwargs):
        """
        Clear (drop) a collection from the Qdrant database.

        Args:
            collection (str, optional): Collection name to drop. Falls back to
                the default collection when empty or None.
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).
        """
        try:
            self.client.delete_collection(collection_name=collection or self.default_collection)
        except Exception as e:
            # Dropping is best-effort: a failure is only logged as a warning.
            log.warning(f"Failed to drop collection, error info: {e}")