You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
70 lines
2.5 KiB
70 lines
2.5 KiB
from deepsearcher.llm.base import BaseLLM
|
|
|
|
|
|
class OpenAILLM(BaseLLM):
    """
    OpenAI language model implementation.

    This class provides an interface to interact with OpenAI's language models
    through their API.

    Attributes:
        model (str): The OpenAI model identifier to use.
        client: The OpenAI client instance.
    """

    def __init__(self, model: str = "o1-mini", **kwargs):
        """
        Initialize an OpenAI language model client.

        Args:
            model (str, optional): The model identifier to use. Defaults to "o1-mini".
            **kwargs: Additional keyword arguments to pass to the OpenAI client.
                - api_key: OpenAI API key. If not provided, uses OPENAI_API_KEY
                  environment variable.
                - base_url: OpenAI API base URL. If not provided, uses
                  OPENAI_BASE_URL environment variable.
        """
        import os

        from openai import OpenAI

        self.model = model
        # Pop the two explicitly handled options so they are not forwarded twice
        # through **kwargs. When absent, fall back to the environment variables
        # (previously an absent key left the local unbound -> UnboundLocalError).
        api_key = kwargs.pop("api_key", os.getenv("OPENAI_API_KEY"))
        base_url = kwargs.pop("base_url", os.getenv("OPENAI_BASE_URL"))
        self.client = OpenAI(api_key=api_key, base_url=base_url, **kwargs)

    def chat(self, messages: list[dict]) -> str:
        """
        Send a chat message to the OpenAI model and get a response.

        The completion is requested with ``stream=True``. Each streamed delta is
        echoed to stdout as it arrives: ``reasoning_content`` (when the delta
        carries one) followed by ``content``. A blank line is printed once the
        stream is exhausted.

        Args:
            messages (list[dict]):
                A list of message dictionaries, typically in the format
                [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]

        Returns:
            str: The concatenated ``content`` of all streamed deltas. Reasoning
            text is printed but not included in the returned string.
        """
        content_parts: list[str] = []
        # NOTE(review): OpenAI documents that o1-family reasoning models (the
        # default "o1-mini") do not accept temperature/top_p/presence_penalty;
        # confirm the configured model/provider supports these sampling settings.
        with self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            stream=True,
            temperature=0.8,
            top_p=0.9,
            presence_penalty=1.4,
        ) as stream:
            # Stream the output to the console (kept for interactive testing).
            for chunk in stream:
                # Skip chunks with an empty `choices` list (OpenAI can send these,
                # e.g. a usage-only final chunk).
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta

                # `reasoning_content` may be absent from the delta object;
                # presumably only some OpenAI-compatible providers send it.
                reasoning = getattr(delta, "reasoning_content", None)
                if reasoning is not None:
                    print(reasoning, end="", flush=True)

                text = getattr(delta, "content", None)
                if text is not None:
                    print(text, end="", flush=True)
                    content_parts.append(text)

        print("\n")
        # Join once at the end instead of repeated string concatenation.
        return "".join(content_parts)
|
|
|