You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
88 lines
2.8 KiB
88 lines
2.8 KiB
import ast
import json
import re
from abc import ABC
from typing import Any, Optional
|
|
|
|
|
|
class BaseLLM(ABC):
    """
    Abstract base class for language model implementations.

    This class defines the interface for language model implementations,
    including methods for chat-based interactions and parsing responses.
    Subclasses are expected to override :meth:`chat`; the static helpers
    :meth:`literal_eval` and :meth:`remove_think` post-process raw model text.
    """

    # Exception types ``ast.literal_eval`` / ``json.loads`` can raise on
    # malformed or pathologically nested input.
    _PARSE_ERRORS = (
        ValueError,
        TypeError,
        SyntaxError,
        MemoryError,
        RecursionError,
    )

    # Recognised opening code fences; the payload is everything between the
    # fence and the closing "```".
    _CODE_FENCES = ("```python", "```json", "```str", "```\n")

    def __init__(self):
        """
        Initialize a BaseLLM object.
        """
        pass

    def chat(self, messages: list[dict]) -> str:
        """
        Send a chat message to the language model and get a response.

        Subclasses must override this method; the base implementation only
        raises.

        Args:
            messages:
                A list of message dictionaries, typically in the format
                [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]

        Returns:
            response (str)
                The content of the llm response.

        Raises:
            NotImplementedError: When called on a class that does not
                override it.
        """
        # Raise instead of silently returning None (which violates ``-> str``).
        # Not an @abstractmethod so that BaseLLM itself stays instantiable.
        raise NotImplementedError(f"{type(self).__name__} must implement chat()")

    @staticmethod
    def literal_eval(response: str) -> Any:
        """
        Parse an LLM response string into a Python object.

        The response is stripped and its ``<think>...</think>`` section is
        removed (see :meth:`remove_think`). If what remains is wrapped in a
        ```python, ```json, ```str or bare ``` code fence, the fence is
        dropped and the payload is parsed with ``ast.literal_eval``, falling
        back to ``json.loads`` so JSON ``true``/``false``/``null`` work too.

        If the text as a whole is not a literal, it is searched for top-level
        ``[...]`` / ``{...}`` spans (nested brackets and brackets inside quoted
        strings are handled). Exactly one such span must exist; it is parsed
        the same way.

        Args:
            response: The string content to parse.

        Returns:
            The parsed Python object (list, dict, str, number, ...).

        Raises:
            ValueError: If the response content cannot be parsed.
        """
        text = BaseLLM.remove_think(response.strip())

        payload = BaseLLM._strip_code_fence(text)
        if payload is None:
            # Fenced with an unrecognised language tag: search the whole text.
            payload = text
        else:
            try:
                return BaseLLM._parse_literal(payload.strip())
            except ValueError:
                pass  # Not a bare literal; look for one embedded in prose.

        error_message = f"Invalid JSON/List format for response content:\n{payload}"
        spans = BaseLLM._bracketed_spans(payload)
        if len(spans) != 1:
            raise ValueError(error_message)
        try:
            return BaseLLM._parse_literal(spans[0])
        except ValueError as err:
            raise ValueError(error_message) from err

    @staticmethod
    def remove_think(response: str) -> str:
        """
        Remove the reasoning section emitted by reasoning models.

        When the response contains both ``<think>`` and ``</think>``, everything
        up to and including the first ``</think>`` is discarded (this includes
        any text before ``<think>``). The result is whitespace-stripped.

        Args:
            response: Raw model output.

        Returns:
            The response with the reasoning section removed, stripped.
        """
        if "<think>" in response and "</think>" in response:
            # partition splits at the first occurrence of the closing tag.
            response = response.partition("</think>")[2]
        return response.strip()

    @staticmethod
    def _strip_code_fence(text: str) -> Optional[str]:
        """
        Return the payload of a recognised fenced block, ``text`` unchanged if
        it is not fenced, or ``None`` if the fence tag is not recognised.
        """
        if not (text.startswith("```") and text.endswith("```")):
            return text
        for fence in BaseLLM._CODE_FENCES:
            if text.startswith(fence):
                return text[len(fence):-3]
        return None

    @staticmethod
    def _parse_literal(text: str) -> Any:
        """
        Parse ``text`` as a Python literal, falling back to JSON.

        Raises:
            ValueError: If ``text`` is neither a Python nor a JSON literal.
        """
        try:
            return ast.literal_eval(text)
        except BaseLLM._PARSE_ERRORS as literal_err:
            try:
                # JSON spells booleans/null as true/false/null, which
                # ast.literal_eval rejects.
                return json.loads(text)
            except BaseLLM._PARSE_ERRORS:
                raise ValueError(
                    f"Not a Python or JSON literal: {text!r}"
                ) from literal_err

    @staticmethod
    def _bracketed_spans(text: str) -> list[str]:
        """Return the top-level ``[...]`` / ``{...}`` spans of ``text``, left to right."""
        spans = []
        pos = 0
        while pos < len(text):
            if text[pos] in "[{":
                end = BaseLLM._closing_bracket_index(text, pos)
                if end != -1:
                    spans.append(text[pos:end + 1])
                    pos = end + 1
                    continue
                # Unbalanced opener (e.g. a stray bracket in prose): skip it.
            pos += 1
        return spans

    @staticmethod
    def _closing_bracket_index(text: str, start: int) -> int:
        """
        Return the index of the bracket that balances ``text[start]``, or -1.

        Brackets inside single- or double-quoted strings (with backslash
        escapes) are ignored. Bracket kinds are not matched against each
        other; a mismatched span is rejected later by the parser.
        """
        depth = 0
        quote = None
        escaped = False
        for index, char in enumerate(text[start:], start):
            if quote is not None:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == quote:
                    quote = None
            elif char in "'\"":
                quote = char
            elif char in "[{":
                depth += 1
            elif char in "]}":
                depth -= 1
                if depth == 0:
                    return index
        return -1
|
|
|