You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
207 lines
6.2 KiB
207 lines
6.2 KiB
from abc import ABC, abstractmethod
|
|
from typing import List, Union
|
|
|
|
import numpy as np
|
|
|
|
from deepsearcher.loader.splitter import Chunk
|
|
|
|
|
|
class RetrievalResult:
    """
    Represents a result retrieved from the vector database.

    This class encapsulates the information about a retrieved document,
    including its embedding, text content, reference, metadata, and similarity score.

    Attributes:
        embedding: The vector embedding of the document.
        text: The text content of the document.
        reference: A reference to the source of the document.
        metadata: Additional metadata associated with the document.
        score: The similarity score of the document to the query.
    """

    def __init__(
        self,
        embedding: np.ndarray,
        text: str,
        reference: str,
        metadata: dict,
        score: float = 0.0,
    ):
        """
        Initialize a RetrievalResult object.

        Args:
            embedding: The vector embedding of the document.
            text: The text content of the document.
            reference: A reference to the source of the document.
            metadata: Additional metadata associated with the document.
            score: The similarity score of the document to the query. Defaults to 0.0.
        """
        self.embedding = embedding
        self.text = text
        self.reference = reference
        self.metadata = metadata
        self.score: float = score

    def __repr__(self):
        """
        Return a string representation of the RetrievalResult.

        Returns:
            A string of the form
            ``RetrievalResult(score=..., embedding=..., text=..., reference=..., metadata=...)``
            with every field enclosed inside the parentheses.
        """
        # All five fields must sit inside the parentheses; the closing ")"
        # goes after `metadata`, not after `reference`.
        return (
            f"RetrievalResult(score={self.score}, embedding={self.embedding}, "
            f"text={self.text}, reference={self.reference}, "
            f"metadata={self.metadata})"
        )
|
|
|
|
|
|
def deduplicate_results(results: List[RetrievalResult]) -> List[RetrievalResult]:
    """
    Drop results whose text has already been seen, preserving input order.

    Two results are considered duplicates when their ``text`` attributes are
    equal; only the earliest such result in ``results`` is retained.

    Args:
        results: The RetrievalResult objects to filter.

    Returns:
        A new list containing the first result for each distinct text, in the
        same relative order as the input.
    """
    seen_texts = set()
    unique = []
    for candidate in results:
        text = candidate.text
        # Skip anything whose text was already emitted earlier in the list.
        if text in seen_texts:
            continue
        seen_texts.add(text)
        unique.append(candidate)
    return unique
|
|
|
|
|
|
class CollectionInfo:
    """
    Describes a single collection held in the vector database.

    Instances are plain holders for a collection's name and description.

    Attributes:
        collection_name: Identifier of the collection.
        description: Free-text description of the collection.
    """

    def __init__(self, collection_name: str, description: str):
        """
        Record the name and description of a collection.

        Args:
            collection_name: Identifier of the collection.
            description: Free-text description of the collection.
        """
        self.collection_name, self.description = collection_name, description
|
|
|
|
|
|
class BaseVectorDB(ABC):
    """
    Abstract base class for vector database implementations.

    This class defines the interface for vector database implementations,
    including methods for initializing collections, inserting data, searching,
    listing collections, and clearing the database.

    Attributes:
        default_collection: The name of the default collection.
    """

    def __init__(
        self,
        default_collection: str = "deepsearcher",
        *args,
        **kwargs,
    ):
        """
        Initialize a BaseVectorDB object.

        Args:
            default_collection: The name of the default collection. Defaults to "deepsearcher".
            *args: Variable length argument list (accepted but not used by the base class).
            **kwargs: Arbitrary keyword arguments (accepted but not used by the base class).
        """
        self.default_collection = default_collection

    @abstractmethod
    def init_collection(
        self,
        dim: int,
        collection: str,
        description: str,
        force_new_collection=False,
        *args,
        **kwargs,
    ):
        """
        Initialize a collection in the vector database.

        Args:
            dim: The dimensionality of the vectors in the collection.
            collection: The name of the collection.
            description: The description of the collection.
            force_new_collection: If True, drop the existing collection and create a new one.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        pass

    @abstractmethod
    def insert_data(self, collection: str, chunks: List[Chunk], *args, **kwargs):
        """
        Insert data into a collection in the vector database.

        Args:
            collection: The name of the collection.
            chunks: A list of Chunk objects to insert.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        pass

    @abstractmethod
    def search_data(
        self, collection: str, vector: Union[np.ndarray, List[float]], *args, **kwargs
    ) -> List[RetrievalResult]:
        """
        Search for similar vectors in a collection.

        Args:
            collection: The name of the collection.
            vector: The query vector to search for, as a NumPy array or a list of floats.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            A list of RetrievalResult objects representing the search results.
        """
        pass

    def list_collections(self, *args, **kwargs) -> List[CollectionInfo]:
        """
        List all collections in the vector database.

        This method is intentionally not abstract, so subclasses may omit it.
        The base implementation reports no collections.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            A list of CollectionInfo objects representing the collections.
            The base implementation returns an empty list.
        """
        # Return an empty list rather than an implicit None so callers can
        # iterate the result as the return annotation promises.
        return []

    @abstractmethod
    def clear_db(self, *args, **kwargs):
        """
        Clear the vector database.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        pass
|
|
|