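# NOTE(review): the lines below appear to be web-UI text captured together with
# this file (a repository topic-editor hint and a file-size banner); they are not code.
# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 149 lines
# 4.4 KiB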

import json
import logging
import time
from dataclasses import dataclass, asdict
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Callable, Any
class MessageType(Enum):
    """Closed set of message kinds carried by a message stream.

    The member value is the lowercase string written to the ``"type"`` field
    when a message is converted to a dict / JSON.
    """

    # Emitted by the ``send_search`` helpers.
    SEARCH = "search"
    # Emitted by the ``send_think`` helpers.
    THINK = "think"
    # Emitted by the ``send_answer`` helpers.
    ANSWER = "answer"
    # Emitted by the ``send_complete`` helpers (with empty content).
    COMPLETE = "complete"
@dataclass
class Message:
    """A single message recorded by a message stream."""

    # Kind of message; serialized as its string value.
    type: MessageType
    # Payload text (``send_complete`` sends an empty string).
    content: str
    # Creation time; MessageStream.send_message fills this from time.time().
    timestamp: float
    # Optional extra data; ``None`` is rendered as ``{}`` by to_dict().
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of this message, suitable for JSON.

        ``type`` is replaced by its string value and a missing/empty
        ``metadata`` becomes an empty dict. ``metadata`` is shallow-copied,
        so mutating the returned dict does not alter this message.
        """
        return {
            "type": self.type.value,
            "content": self.content,
            "timestamp": self.timestamp,
            # Copy rather than alias: returning self.metadata directly would
            # let consumers of the dict silently rewrite stored history.
            "metadata": dict(self.metadata) if self.metadata else {},
        }

    def to_json(self) -> str:
        """Serialize :meth:`to_dict` to JSON, keeping non-ASCII text unescaped.

        Raises:
            TypeError: if ``metadata`` holds values ``json`` cannot encode.
        """
        return json.dumps(self.to_dict(), ensure_ascii=False)
class MessageStream:
    """In-memory message history plus synchronous subscriber callbacks.

    While enabled, each message sent through :meth:`send_message` (or one of
    the ``send_*`` shortcuts) is appended to the history and then passed to
    every registered callback, in registration order, within the same call.
    """

    def __init__(self) -> None:
        self._messages: List[Message] = []  # full history, oldest first
        self._callbacks: List[Callable[[Message], None]] = []  # subscribers
        self._enabled = True  # when False, send_message is a no-op

    def add_callback(self, callback: Callable[[Message], None]) -> None:
        """Register *callback* to receive every subsequently sent message."""
        self._callbacks.append(callback)

    def remove_callback(self, callback: Callable[[Message], None]) -> None:
        """Unregister *callback* (first occurrence); unknown callbacks are ignored."""
        if callback in self._callbacks:
            self._callbacks.remove(callback)

    def enable(self) -> None:
        """Resume recording and dispatching messages."""
        self._enabled = True

    def disable(self) -> None:
        """Drop sent messages (neither stored nor dispatched) until re-enabled."""
        self._enabled = False

    def send_message(
        self,
        msg_type: MessageType,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Record a message and deliver it to every registered callback.

        Does nothing while the stream is disabled. An exception raised by one
        callback is logged (with traceback) and does not prevent delivery to
        the remaining callbacks.

        Args:
            msg_type: Kind of message.
            content: Message text.
            metadata: Optional extra data attached to the message.
        """
        if not self._enabled:
            return
        message = Message(
            type=msg_type,
            content=content,
            timestamp=time.time(),
            metadata=metadata,
        )
        self._messages.append(message)
        # Iterate over a snapshot: a callback that unregisters itself (or
        # another callback) would otherwise shift the live list and cause the
        # next callback to be skipped. Callbacks registered during dispatch
        # start receiving from the next message.
        for callback in tuple(self._callbacks):
            try:
                callback(message)
            except Exception as exc:  # best-effort: one bad subscriber must not break others
                logging.getLogger(__name__).exception(
                    "Error in message callback: %s", exc
                )

    def send_search(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None:
        """Send a ``MessageType.SEARCH`` message."""
        self.send_message(MessageType.SEARCH, content, metadata)

    def send_think(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None:
        """Send a ``MessageType.THINK`` message."""
        self.send_message(MessageType.THINK, content, metadata)

    def send_answer(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> None:
        """Send a ``MessageType.ANSWER`` message."""
        self.send_message(MessageType.ANSWER, content, metadata)

    def send_complete(self, metadata: Optional[Dict[str, Any]] = None) -> None:
        """Send a ``MessageType.COMPLETE`` message with empty content."""
        self.send_message(MessageType.COMPLETE, "", metadata)

    def get_messages(self) -> List[Message]:
        """Return a shallow copy of the history, oldest first."""
        return self._messages.copy()

    def get_messages_by_type(self, msg_type: MessageType) -> List[Message]:
        """Return the stored messages of kind *msg_type*, in send order."""
        return [msg for msg in self._messages if msg.type == msg_type]

    def clear_messages(self) -> None:
        """Discard the stored history; registered callbacks are kept."""
        self._messages.clear()

    def get_messages_as_dicts(self) -> List[Dict[str, Any]]:
        """Return the history converted with :meth:`Message.to_dict`."""
        return [msg.to_dict() for msg in self._messages]

    def get_messages_as_json(self) -> str:
        """Return the history as a JSON array string (non-ASCII kept as-is)."""
        return json.dumps(self.get_messages_as_dicts(), ensure_ascii=False)
# Shared, module-wide stream instance used by the module-level send_* helpers.
message_stream: MessageStream = MessageStream()
def send_search(content: str, metadata: Optional[Dict[str, Any]] = None):
    """Send a SEARCH message on the shared module-level stream.

    Convenience wrapper around ``message_stream.send_search``.
    """
    stream = message_stream
    stream.send_search(content, metadata=metadata)
def send_think(content: str, metadata: Optional[Dict[str, Any]] = None):
    """Send a THINK message on the shared module-level stream.

    Convenience wrapper around ``message_stream.send_think``.
    """
    stream = message_stream
    stream.send_think(content, metadata=metadata)
def send_answer(content: str, metadata: Optional[Dict[str, Any]] = None):
    """Send an ANSWER message on the shared module-level stream.

    Convenience wrapper around ``message_stream.send_answer``.
    """
    stream = message_stream
    stream.send_answer(content, metadata=metadata)
def send_complete(metadata: Optional[Dict[str, Any]] = None):
    """Send a COMPLETE message on the shared module-level stream.

    Convenience wrapper around ``message_stream.send_complete``.
    """
    stream = message_stream
    stream.send_complete(metadata=metadata)
def get_message_stream() -> MessageStream:
    """Return the shared module-level stream (the ``message_stream`` global)."""
    shared = message_stream
    return shared