You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
76 lines
2.3 KiB
76 lines
2.3 KiB
from typing import List
|
|
|
|
from tqdm import tqdm
|
|
|
|
from deepsearcher.loader.splitter import Chunk
|
|
|
|
|
|
class BaseEmbedding:
    """
    Abstract base class for embedding model implementations.

    This class defines the interface for embedding model implementations,
    including methods for embedding queries and documents, and a property
    for the dimensionality of the embeddings.

    The base ``embed_query`` and ``dimension`` are placeholders that return
    ``None``; ``embed_documents`` and ``embed_chunks`` are built on top of
    ``embed_query`` and work as soon as a subclass provides it.
    """

    def embed_query(self, text: str) -> List[float]:
        """
        Embed a single query text.

        Args:
            text: The query text to embed.

        Returns:
            A list of floats representing the embedding vector. The base
            implementation is a placeholder and returns ``None``.
        """
        pass

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a list of document texts.

        This default implementation calls embed_query for each text,
        but implementations may override this with a more efficient batch method.

        Args:
            texts: A list of document texts to embed.

        Returns:
            A list of embedding vectors, one for each input text.
        """
        return [self.embed_query(text) for text in texts]

    def embed_chunks(self, chunks: List["Chunk"], batch_size: int = 256) -> List["Chunk"]:
        """
        Embed a list of Chunk objects.

        This method extracts the text from each chunk, embeds it in batches
        via ``embed_documents``, and stores each resulting vector on the
        matching chunk's ``embedding`` attribute.

        Args:
            chunks: A list of Chunk objects to embed.
            batch_size: The number of chunks to process in each batch.
                Must be at least 1.

        Returns:
            The input list of Chunk objects, updated with embeddings.

        Raises:
            ValueError: If ``batch_size`` is less than 1, or if the number of
                embeddings produced does not match the number of chunks.
        """
        # A zero step makes range() fail with an unrelated-looking message and
        # a negative step yields no batches at all, which would silently leave
        # every chunk without an embedding. Reject both before doing any work.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        texts = [chunk.text for chunk in chunks]
        batch_texts = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]

        embeddings = []
        for batch_text in tqdm(batch_texts, desc="Embedding chunks"):
            embeddings.extend(self.embed_documents(batch_text))

        # zip() stops at the shorter input, so a miscounted result would leave
        # trailing chunks unembedded without any error. Fail loudly instead.
        if len(embeddings) != len(texts):
            raise ValueError(
                f"embed_documents returned {len(embeddings)} embeddings "
                f"for {len(texts)} chunks"
            )

        for chunk, embedding in zip(chunks, embeddings):
            chunk.embedding = embedding
        return chunks

    @property
    def dimension(self) -> int:
        """
        Get the dimensionality of the embeddings.

        Returns:
            The number of dimensions in the embedding vectors. The base
            implementation is a placeholder and returns ``None``.
        """
        pass
|
|
|