You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
148 lines
4.4 KiB
148 lines
4.4 KiB
import json
|
|
import time
|
|
from typing import Dict, List, Optional, Callable, Any
|
|
from enum import Enum
|
|
from dataclasses import dataclass, asdict
|
|
from datetime import datetime
|
|
|
|
|
|
class MessageType(Enum):
    """Kinds of messages that can travel through the message stream.

    Each member's value is the lowercase string written under the
    ``"type"`` key when a message is converted to a dictionary.
    """

    SEARCH = "search"
    THINK = "think"
    ANSWER = "answer"
    COMPLETE = "complete"
|
|
|
|
|
|
@dataclass
class Message:
    """A single message record with its type, text, time and extras."""

    # Category of the message (a MessageType member).
    type: MessageType
    # Text payload of the message.
    content: str
    # Creation time as a float; MessageStream.send_message passes time.time().
    timestamp: float
    # Optional free-form extra data; None is serialized as an empty dict.
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the message as a plain dictionary.

        Returns:
            A dict with the keys ``"type"`` (the enum member's string
            value), ``"content"``, ``"timestamp"`` and ``"metadata"``.
            ``"metadata"`` is always a dict: an empty one when the message
            has no metadata, otherwise a shallow copy so that changes to
            the returned mapping do not leak back into this message.
        """
        metadata = dict(self.metadata) if self.metadata else {}
        return {
            "type": self.type.value,
            "content": self.content,
            "timestamp": self.timestamp,
            "metadata": metadata,
        }

    def to_json(self) -> str:
        """Return the message serialized as a JSON string.

        Non-ASCII characters are written as-is rather than escaped
        (``ensure_ascii=False``).
        """
        return json.dumps(self.to_dict(), ensure_ascii=False)
|
|
|
|
|
|
class MessageStream:
    """Collects messages in order and forwards each one to callbacks.

    Every message sent while the stream is enabled is appended to an
    internal history and then passed to each registered callback.
    """

    def __init__(self) -> None:
        """Create an enabled stream with no messages and no callbacks."""
        self._messages: List[Message] = []
        self._callbacks: List[Callable[[Message], None]] = []
        self._enabled = True

    def add_callback(self, callback: Callable[[Message], None]) -> None:
        """Register ``callback`` to be called with every new message."""
        self._callbacks.append(callback)

    def remove_callback(self, callback: Callable[[Message], None]) -> None:
        """Unregister ``callback``; do nothing if it is not registered."""
        if callback in self._callbacks:
            self._callbacks.remove(callback)

    def enable(self) -> None:
        """Turn the stream on so that sent messages are recorded."""
        self._enabled = True

    def disable(self) -> None:
        """Turn the stream off; sent messages are dropped while disabled."""
        self._enabled = False

    def send_message(
        self,
        msg_type: MessageType,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Record a new message and notify every registered callback.

        Nothing is recorded and no callback runs while the stream is
        disabled. The message timestamp is taken from ``time.time()``.

        Args:
            msg_type: Category of the message.
            content: Text payload of the message.
            metadata: Optional extra data stored on the message.
        """
        if not self._enabled:
            return

        message = Message(
            type=msg_type,
            content=content,
            timestamp=time.time(),
            metadata=metadata,
        )
        self._messages.append(message)

        # Dispatch over a snapshot: a callback may add or remove callbacks
        # (including itself) while running, and mutating the live list
        # during iteration would cause other callbacks to be skipped.
        for callback in list(self._callbacks):
            try:
                callback(message)
            except Exception as e:
                # Best effort: one failing callback must not stop the rest.
                print(f"Error in message callback: {e}")

    def send_search(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None:
        """Send a message of type ``MessageType.SEARCH``."""
        self.send_message(MessageType.SEARCH, content, metadata)

    def send_think(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None:
        """Send a message of type ``MessageType.THINK``."""
        self.send_message(MessageType.THINK, content, metadata)

    def send_answer(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None:
        """Send a message of type ``MessageType.ANSWER``."""
        self.send_message(MessageType.ANSWER, content, metadata)

    def send_complete(self, metadata: Optional[Dict[str, Any]] = None) -> None:
        """Send a ``MessageType.COMPLETE`` message with empty content."""
        self.send_message(MessageType.COMPLETE, "", metadata)

    def get_messages(self) -> List[Message]:
        """Return a new list holding all recorded messages in send order."""
        return self._messages.copy()

    def get_messages_by_type(self, msg_type: MessageType) -> List[Message]:
        """Return the recorded messages whose type equals ``msg_type``."""
        return [msg for msg in self._messages if msg.type == msg_type]

    def clear_messages(self) -> None:
        """Discard every recorded message; callbacks stay registered."""
        self._messages.clear()

    def get_messages_as_dicts(self) -> List[Dict[str, Any]]:
        """Return every recorded message converted with ``to_dict``."""
        return [msg.to_dict() for msg in self._messages]

    def get_messages_as_json(self) -> str:
        """Return all recorded messages as one JSON array string.

        Non-ASCII characters are written as-is (``ensure_ascii=False``).
        """
        return json.dumps(self.get_messages_as_dicts(), ensure_ascii=False)
|
|
|
|
|
|
# Global message stream instance used by the module-level helper functions.
message_stream = MessageStream()
|
|
|
|
|
|
def send_search(content: str, metadata: Optional[Dict[str, Any]] = None):
    """Send a search message on the global message stream.

    Args:
        content: Text payload of the message.
        metadata: Optional extra data attached to the message.
    """
    message_stream.send_search(content, metadata)
|
|
|
|
|
|
def send_think(content: str, metadata: Optional[Dict[str, Any]] = None):
    """Send a think message on the global message stream.

    Args:
        content: Text payload of the message.
        metadata: Optional extra data attached to the message.
    """
    message_stream.send_think(content, metadata)
|
|
|
|
|
|
def send_answer(content: str, metadata: Optional[Dict[str, Any]] = None):
    """Send an answer message on the global message stream.

    Args:
        content: Text payload of the message.
        metadata: Optional extra data attached to the message.
    """
    message_stream.send_answer(content, metadata)
|
|
|
|
|
|
def send_complete(metadata: Optional[Dict[str, Any]] = None):
    """Send a completion message on the global message stream.

    Args:
        metadata: Optional extra data attached to the message.
    """
    message_stream.send_complete(metadata)
|
|
|
|
|
|
def get_message_stream() -> MessageStream:
    """Return the module-level global ``MessageStream`` instance."""
    return message_stream
|
|
|