import os
from typing import List

from openai import OpenAI
from openai._types import NOT_GIVEN

from deepsearcher.embedding.base import BaseEmbedding


class OpenAIEmbedding(BaseEmbedding):
    """
    OpenAI embedding model implementation.

    This class provides an interface to the OpenAI embedding API, which offers
    various embedding models for text processing.

    For more information, see:
    https://platform.openai.com/docs/guides/embeddings/use-cases
    """

    # Native output sizes of OpenAI's hosted embedding models (per the OpenAI
    # embeddings guide). Used by ``dimension`` only when no explicit
    # ``dimension`` was configured.
    _NATIVE_MODEL_DIMS = {
        "text-embedding-ada-002": 1536,
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
    }

    def __init__(self, model: str, **kwargs):
        """
        Initialize the OpenAI embedding model.

        Args:
            model (str): The model identifier to use for embeddings.
            **kwargs: Additional keyword arguments. Recognized keys are removed;
                anything left over is forwarded to the ``OpenAI`` client.
                - api_key (str): The API key. Defaults to the
                  ``OPENAI_API_KEY`` environment variable.
                - base_url (str): The base URL. Defaults to the
                  ``OPENAI_BASE_URL`` environment variable.
                - model_name (str): Alternative way to specify the model;
                  overrides ``model`` when present.
                - dimension (int): The dimension of the embedding vectors.
                  Defaults to ``NOT_GIVEN`` (use the model's native size).
                - dim_change (bool): Whether the configured ``dimension``
                  should be sent to the API to change the size of the
                  generated embeddings. Defaults to ``False``.
        """
        # Previously ``api_key`` and ``dim_change`` were only bound when
        # present in kwargs, so omitting either raised UnboundLocalError.
        api_key = kwargs.pop("api_key", os.getenv("OPENAI_API_KEY"))
        base_url = kwargs.pop("base_url", os.getenv("OPENAI_BASE_URL"))
        model = kwargs.pop("model_name", model)
        dimension = kwargs.pop("dimension", NOT_GIVEN)
        dim_change = kwargs.pop("dim_change", False)

        self.dim = dimension
        self.dim_change = dim_change
        self.model = model
        self.client = OpenAI(api_key=api_key, base_url=base_url, **kwargs)

    def _dimensions_arg(self):
        """Return the value to pass as ``dimensions`` to ``embeddings.create``."""
        # Only forward the configured size when the caller explicitly opted in
        # with dim_change=True. Otherwise send NOT_GIVEN so the request omits
        # the field and the API uses the model's native size. Reads ``self.dim``
        # directly so the ``dimension`` property's fallback is never sent.
        return self.dim if self.dim_change is True else NOT_GIVEN

    def embed_query(self, text: str) -> List[float]:
        """
        Embed a single query text.

        Args:
            text (str): The query text to embed.

        Returns:
            List[float]: A list of floats representing the embedding vector.
        """
        response = self.client.embeddings.create(
            input=[text],
            model=self.model,
            dimensions=self._dimensions_arg(),
        )
        return response.data[0].embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a list of document texts.

        Args:
            texts (List[str]): A list of document texts to embed.

        Returns:
            List[List[float]]: A list of embedding vectors, one for each input
            text. An empty input yields an empty list without calling the API.
        """
        if not texts:
            # Nothing to embed; skip a pointless (and likely rejected) request.
            return []
        response = self.client.embeddings.create(
            input=texts,
            model=self.model,
            dimensions=self._dimensions_arg(),
        )
        return [r.embedding for r in response.data]

    @property
    def dimension(self) -> int:
        """
        Get the dimensionality of the embeddings for the current model.

        Returns:
            int: The configured ``dimension`` if one was given; otherwise the
            native size of a known OpenAI model.
        """
        if self.dim is NOT_GIVEN:
            # NOTE(review): for models missing from the table (e.g. a custom
            # ``base_url`` backend) this still returns the NOT_GIVEN sentinel,
            # as before. Configure ``dimension`` explicitly for such models.
            return self._NATIVE_MODEL_DIMS.get(self.model, self.dim)
        return self.dim