You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

104 lines
3.1 KiB

2 weeks ago
import os
from typing import List
from openai import OpenAI
from openai._types import NOT_GIVEN
from deepsearcher.embedding.base import BaseEmbedding
class OpenAIEmbedding(BaseEmbedding):
    """
    OpenAI embedding model implementation.

    This class provides an interface to the OpenAI embedding API, which offers
    various embedding models for text processing.

    For more information, see:
    https://platform.openai.com/docs/guides/embeddings/use-cases
    """

    def __init__(self, model: str, **kwargs):
        """
        Initialize the OpenAI embedding model.

        Args:
            model (str): The model identifier to use for embeddings.
            **kwargs: Additional keyword arguments. The keys listed below are
                consumed here; anything left over is forwarded unchanged to
                the ``OpenAI`` client constructor.
                - api_key (str): The API key. Defaults to the
                  ``OPENAI_API_KEY`` environment variable.
                - base_url (str): The base URL. Defaults to the
                  ``OPENAI_BASE_URL`` environment variable.
                - model_name (str): Alternative way to specify the model;
                  when given it takes precedence over ``model``.
                - dimension (int): The dimension of the embedding vectors.
                  Defaults to ``NOT_GIVEN``.
                - dim_change (bool): Whether it's able to change the dimension
                  of the generated embeddings. Defaults to ``False``.
        """
        # Every optional setting gets an explicit default. Previously
        # ``api_key`` and ``dim_change`` were only bound when the caller
        # passed them, so omitting either raised UnboundLocalError below.
        api_key = kwargs.pop("api_key", os.getenv("OPENAI_API_KEY"))
        base_url = kwargs.pop("base_url", os.getenv("OPENAI_BASE_URL"))
        if "model_name" in kwargs:
            model = kwargs.pop("model_name")
        dimension = kwargs.pop("dimension", NOT_GIVEN)
        dim_change = kwargs.pop("dim_change", False)

        self.dim = dimension
        self.dim_change = dim_change
        self.model = model
        self.client = OpenAI(api_key=api_key, base_url=base_url, **kwargs)

    def _requested_dimensions(self):
        """Return the value to pass as ``dimensions`` in an embeddings request.

        The configured dimension is sent only when ``dim_change`` is exactly
        ``True``; otherwise ``NOT_GIVEN`` is returned so no explicit dimension
        is passed to the API call.
        """
        return self.dimension if self.dim_change is True else NOT_GIVEN

    def embed_query(self, text: str) -> List[float]:
        """
        Embed a single query text.

        Args:
            text (str): The query text to embed.

        Returns:
            List[float]: A list of floats representing the embedding vector.
        """
        response = self.client.embeddings.create(
            input=[text],
            model=self.model,
            dimensions=self._requested_dimensions(),
        )
        # One input was sent, so the first data entry holds its embedding.
        return response.data[0].embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a list of document texts.

        All texts are sent in a single embeddings request.

        Args:
            texts (List[str]): A list of document texts to embed.

        Returns:
            List[List[float]]: A list of embedding vectors, one for each input text.
        """
        response = self.client.embeddings.create(
            input=texts,
            model=self.model,
            dimensions=self._requested_dimensions(),
        )
        return [item.embedding for item in response.data]

    @property
    def dimension(self) -> int:
        """
        Get the dimensionality of the embeddings for the current model.

        Returns:
            int: The ``dimension`` value supplied at construction time. If none
            was supplied this is the ``NOT_GIVEN`` sentinel rather than an int.
        """
        return self.dim