You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
104 lines
3.1 KiB
104 lines
3.1 KiB
2 weeks ago
|
import os
|
||
|
from typing import List
|
||
|
|
||
|
from openai import OpenAI
|
||
|
from openai._types import NOT_GIVEN
|
||
|
|
||
|
from deepsearcher.embedding.base import BaseEmbedding
|
||
|
|
||
|
|
||
|
class OpenAIEmbedding(BaseEmbedding):
    """
    OpenAI embedding model implementation.

    This class provides an interface to the OpenAI embedding API, which offers
    various embedding models for text processing.

    For more information, see:
    https://platform.openai.com/docs/guides/embeddings/use-cases
    """

    def __init__(self, model: str, **kwargs):
        """
        Initialize the OpenAI embedding model.

        Args:
            model (str): The model identifier to use for embeddings.
            **kwargs: Additional keyword arguments. The keys listed below are
                consumed here; everything else is forwarded unchanged to the
                ``OpenAI`` client constructor.
                - api_key (str): The API key. Defaults to the
                  ``OPENAI_API_KEY`` environment variable.
                - base_url (str): The base URL. Defaults to the
                  ``OPENAI_BASE_URL`` environment variable.
                - model_name (str): Alternative way to specify the model;
                  takes precedence over ``model`` when present.
                - dimension (int): The dimension of the embedding vectors.
                  Defaults to the ``NOT_GIVEN`` sentinel.
                - dim_change (bool): Whether it's able to change the dimension
                  of the generated embeddings. Defaults to ``False``.
        """
        # Every optional setting gets an explicit default so that omitting it
        # can never leave a local name unbound (previously a missing
        # ``api_key`` or ``dim_change`` raised UnboundLocalError).
        api_key = kwargs.pop("api_key", os.getenv("OPENAI_API_KEY"))
        base_url = kwargs.pop("base_url", os.getenv("OPENAI_BASE_URL"))
        model = kwargs.pop("model_name", model)
        dimension = kwargs.pop("dimension", NOT_GIVEN)
        dim_change = kwargs.pop("dim_change", False)

        self.dim = dimension
        self.dim_change = dim_change
        self.model = model

        # Whatever is left in kwargs is passed straight through to the client.
        self.client = OpenAI(api_key=api_key, base_url=base_url, **kwargs)

    def _request_dimensions(self):
        """Return the value to send as ``dimensions`` in an embeddings request.

        The configured dimension is only sent when ``dim_change`` is exactly
        ``True``; otherwise ``NOT_GIVEN`` is used so the argument is omitted.
        """
        return self.dimension if self.dim_change is True else NOT_GIVEN

    def embed_query(self, text: str) -> List[float]:
        """
        Embed a single query text.

        Args:
            text (str): The query text to embed.

        Returns:
            List[float]: A list of floats representing the embedding vector.
        """
        response = self.client.embeddings.create(
            input=[text],
            model=self.model,
            dimensions=self._request_dimensions(),
        )
        return response.data[0].embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a list of document texts.

        Args:
            texts (List[str]): A list of document texts to embed.

        Returns:
            List[List[float]]: A list of embedding vectors, one for each input
            text. An empty input yields an empty list without any API call.
        """
        # Nothing to embed: skip the network round-trip entirely.
        if not texts:
            return []

        response = self.client.embeddings.create(
            input=texts,
            model=self.model,
            dimensions=self._request_dimensions(),
        )
        return [r.embedding for r in response.data]

    @property
    def dimension(self) -> int:
        """
        Get the dimensionality of the embeddings for the current model.

        Returns:
            int: The number of dimensions in the embedding vectors. This is the
            ``dimension`` value supplied at construction time; when none was
            supplied it is the ``NOT_GIVEN`` sentinel rather than an integer.
        """
        return self.dim
|