@@ -1,5 +1,3 @@
 import os
 
 from deepsearcher.llm.base import BaseLLM
@@ -30,12 +28,8 @@ class OpenAILLM(BaseLLM):
         self.model = model
         if "api_key" in kwargs:
             api_key = kwargs.pop("api_key")
         else:
             api_key = os.getenv("OPENAI_API_KEY")
         if "base_url" in kwargs:
             base_url = kwargs.pop("base_url")
         else:
             base_url = os.getenv("OPENAI_BASE_URL")
         self.client = OpenAI(api_key=api_key, base_url=base_url, **kwargs)
 
+    def chat(self, messages: list[dict], stream_callback=None) -> str:
@@ -50,18 +44,27 @@ class OpenAILLM(BaseLLM):
         Returns:
             response (str)
         """
-        completion = self.client.chat.completions.create(
+        with self.client.chat.completions.create(
             model=self.model,
             messages=messages,
-            stream=True
-        )
-        response = ""
-        for chunk in completion:
-            stream_response = chunk.choices[0].delta.content
-            if stream_response:
-                print(stream_response, end="", flush=True)
-                response += stream_response
-                if stream_callback:
-                    stream_callback(stream_response)
+            stream=True,
+        ) as stream:
+            content = ""
+            reasoning_content = ""
+            for chunk in stream:
+                if not chunk.choices:
+                    continue
+                else:
+                    delta = chunk.choices[0].delta
+                    if hasattr(delta, 'reasoning_content') and delta.reasoning_content is not None:
+                        print(delta.reasoning_content, end='', flush=True)
+                        reasoning_content += delta.reasoning_content
+                        if stream_callback:
+                            stream_callback(delta.reasoning_content)
+                    if hasattr(delta, 'content') and delta.content is not None:
+                        print(delta.content, end="", flush=True)
+                        content += delta.content
+                        if stream_callback:
+                            stream_callback(delta.content)
         print("\n")
-        return response
+        return content
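
For reference, a minimal sketch of how a caller might drive the new streaming interface. The import path, model name, and message are assumptions for illustration, not part of this change:

    from deepsearcher.llm.openai_llm import OpenAILLM  # assumed module path

    chunks: list[str] = []

    # api_key/base_url fall back to the OPENAI_API_KEY / OPENAI_BASE_URL env vars.
    llm = OpenAILLM(model="gpt-4o-mini")
    answer = llm.chat(
        [{"role": "user", "content": "Say hello."}],
        stream_callback=chunks.append,  # invoked with every reasoning and content delta
    )
    # `answer` holds only the assembled content; reasoning deltas reach the
    # callback (and stdout) but are excluded from the return value.

Using create(..., stream=True) as a context manager leans on openai-python v1, where the returned Stream object implements the context-manager protocol and closes the underlying HTTP response on exit.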