diff --git a/deepsearcher/agent/deep_search.py b/deepsearcher/agent/deep_search.py
index d34f75e..bdc318d 100644
--- a/deepsearcher/agent/deep_search.py
+++ b/deepsearcher/agent/deep_search.py
@@ -2,6 +2,7 @@ from deepsearcher.agent.base import BaseAgent, describe_class
from deepsearcher.embedding.base import BaseEmbedding
from deepsearcher.llm.base import BaseLLM
from deepsearcher.utils import log
+from deepsearcher.utils.message_stream import send_search, send_think, send_answer
from deepsearcher.vector_db import RetrievalResult
from deepsearcher.vector_db.base import BaseVectorDB, deduplicate
from collections import defaultdict
@@ -230,14 +231,12 @@ class DeepSearch(BaseAgent):
all_retrieved_results = []
query_vector = self.embedding_model.embed_query(query)
for collection in selected_collections:
- log.color_print(f" Search [{query}] in [{collection}]... \n")
+ send_search(f"Search [{query}] in [{collection}]...")
retrieved_results = self.vector_db.search_data(
collection=collection, vector=query_vector, query_text=query
)
if not retrieved_results or len(retrieved_results) == 0:
- log.color_print(
- f" No relevant document chunks found in '{collection}'! \n"
- )
+ send_search(f"No relevant document chunks found in '{collection}'!")
continue
# Format all chunks for batch processing
@@ -288,13 +287,9 @@ class DeepSearch(BaseAgent):
references.add(retrieved_result.reference)
if accepted_chunk_num > 0:
- log.color_print(
- f" Accept {accepted_chunk_num} document chunk(s) from references: {list(references)} \n"
- )
+ send_search(f"Accept {accepted_chunk_num} document chunk(s) from references: {list(references)}")
else:
- log.color_print(
- f" No document chunk accepted from '{collection}'! \n"
- )
+ send_search(f"No document chunk accepted from '{collection}'!")
return all_retrieved_results
def _generate_more_sub_queries(
@@ -329,7 +324,7 @@ class DeepSearch(BaseAgent):
- Additional information about the retrieval process
"""
### SUB QUERIES ###
- log.color_print(f" {original_query} \n")
+ send_think(f"{original_query}")
all_search_results = []
all_sub_queries = []
@@ -338,11 +333,11 @@ class DeepSearch(BaseAgent):
log.color_print("No sub queries were generated by the LLM. Exiting.")
return [], {}
else:
- log.color_print(f" Break down the original query into new sub queries: {sub_queries} ")
+ send_think(f"Break down the original query into new sub queries: {sub_queries}")
all_sub_queries.extend(sub_queries)
for it in range(self.max_iter):
- log.color_print(f">> Iteration: {it + 1}\n")
+ send_think(f">> Iteration: {it + 1}")
# Execute all search tasks sequentially
for query in sub_queries:
@@ -352,26 +347,26 @@ class DeepSearch(BaseAgent):
all_search_results = deduplicate(all_search_results)
deduped_len = len(all_search_results)
if undeduped_len - deduped_len != 0:
- log.color_print(
- f" Remove {undeduped_len - deduped_len} duplicates "
- )
+ send_search(f"Remove {undeduped_len - deduped_len} duplicates")
# search_res_from_internet = deduplicate_results(search_res_from_internet)
# all_search_res.extend(search_res_from_vectordb + search_res_from_internet)
- if it + 1 >= self.max_iter:
- log.color_print(" Exceeded maximum iterations. Exiting. ")
- break
+
### REFLECTION & GET MORE SUB QUERIES ###
- log.color_print(" Reflecting on the search results... ")
- sub_queries = self._generate_more_sub_queries(
- original_query, all_sub_queries, all_search_results
- )
- if not sub_queries or len(sub_queries) == 0:
- log.color_print(" No new search queries were generated. Exiting. ")
- break
+ # Only generate more queries if we haven't reached the maximum iterations
+ if it + 1 < self.max_iter:
+ send_think("Reflecting on the search results...")
+ sub_queries = self._generate_more_sub_queries(
+ original_query, all_sub_queries, all_search_results
+ )
+ if not sub_queries or len(sub_queries) == 0:
+ send_think("No new search queries were generated. Exiting.")
+ break
+ else:
+ send_think(f"New search queries for next iteration: {sub_queries}")
+ all_sub_queries.extend(sub_queries)
else:
- log.color_print(
- f" New search queries for next iteration: {sub_queries} ")
- all_sub_queries.extend(sub_queries)
+ send_think("Reached maximum iterations. Exiting.")
+ break
all_search_results = deduplicate(all_search_results)
return all_search_results, all_sub_queries
@@ -394,20 +389,19 @@ class DeepSearch(BaseAgent):
"""
all_retrieved_results, all_sub_queries = self.retrieve(original_query, **kwargs)
if not all_retrieved_results or len(all_retrieved_results) == 0:
- log.color_print(f"No relevant information found for query '{original_query}'.")
+ send_think(f"No relevant information found for query '{original_query}'.")
return "", []
chunks = self._format_chunks(all_retrieved_results)
- log.color_print(
- f" Summarize answer from all {len(all_retrieved_results)} retrieved chunks... \n"
- )
+ send_think(f"Summarize answer from all {len(all_retrieved_results)} retrieved chunks...")
summary_prompt = SUMMARY_PROMPT.format(
original_query=original_query,
all_sub_queries=all_sub_queries,
chunks=chunks
)
response = self.llm.chat([{"role": "user", "content": summary_prompt}])
- log.color_print("\n==== FINAL ANSWER====\n")
- log.color_print(self.llm.remove_think(response))
+ final_answer = self.llm.remove_think(response)
+ send_answer("==== FINAL ANSWER====")
+ send_answer(final_answer)
+ return final_answer, all_retrieved_results
def _format_chunks(self, retrieved_results: list[RetrievalResult]):
diff --git a/deepsearcher/backend/templates/index.html b/deepsearcher/backend/templates/index.html
index 6c9d3e2..56bb73e 100644
--- a/deepsearcher/backend/templates/index.html
+++ b/deepsearcher/backend/templates/index.html
@@ -192,6 +192,49 @@
line-height: 1.6;
}
+ .message-stream {
+ margin-top: 16px;
+ }
+
+ .message-container {
+ max-height: 400px;
+ overflow-y: auto;
+ border: 1px solid var(--border-color);
+ border-radius: 6px;
+ background-color: white;
+ padding: 12px;
+ }
+
+ .message {
+ margin-bottom: 12px;
+ padding: 8px 12px;
+ border-radius: 4px;
+ border-left: 4px solid;
+ font-size: 14px;
+ line-height: 1.4;
+ }
+
+ .message-search {
+ background-color: #f0f9ff;
+ border-left-color: var(--info-color);
+ }
+
+ .message-think {
+ background-color: #fef3c7;
+ border-left-color: var(--warning-color);
+ }
+
+ .message-answer {
+ background-color: #d1fae5;
+ border-left-color: var(--success-color);
+ }
+
+ .message-timestamp {
+ font-size: 12px;
+ color: var(--text-secondary);
+ margin-top: 4px;
+ }
+
footer {
text-align: center;
margin-top: 30px;
@@ -266,11 +309,17 @@
+
@@ -300,6 +349,32 @@
}
}
+ // 工具函数:显示消息流
+ function displayMessages(messages) {
+ const container = document.getElementById('messageContainer');
+ container.innerHTML = '';
+
+ messages.forEach(message => {
+ const messageElement = document.createElement('div');
+ messageElement.className = `message message-${message.type}`;
+
+ const contentElement = document.createElement('div');
+ contentElement.textContent = message.content;
+
+ const timestampElement = document.createElement('div');
+ timestampElement.className = 'message-timestamp';
+ const date = new Date(message.timestamp * 1000);
+ timestampElement.textContent = date.toLocaleTimeString();
+
+ messageElement.appendChild(contentElement);
+ messageElement.appendChild(timestampElement);
+ container.appendChild(messageElement);
+ });
+
+ // 滚动到底部
+ container.scrollTop = container.scrollHeight;
+ }
+
// 工具函数:隐藏状态信息
function hideStatus(elementId) {
const statusElement = document.getElementById(elementId);
@@ -388,6 +463,28 @@
}
});
+ // 清空消息功能
+ document.getElementById('clearMessagesBtn').addEventListener('click', async function() {
+ try {
+ const response = await fetch('/clear-messages/', {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json'
+ }
+ });
+
+ if (response.ok) {
+ const container = document.getElementById('messageContainer');
+ container.innerHTML = '';
+ showStatus('queryStatus', '消息已清空', 'success');
+ } else {
+ showStatus('queryStatus', '清空消息失败', 'error');
+ }
+ } catch (error) {
+ showStatus('queryStatus', `请求失败: ${error.message}`, 'error');
+ }
+ });
+
// 加载网站内容功能
document.getElementById('loadWebsiteBtn').addEventListener('click', async function() {
const button = this;
@@ -466,6 +563,12 @@
if (response.ok) {
showStatus('queryStatus', '查询完成', 'success');
document.getElementById('resultText').textContent = data.result;
+
+ // 显示消息流
+ if (data.messages && data.messages.length > 0) {
+ displayMessages(data.messages);
+ }
+
showResult();
} else {
showStatus('queryStatus', `查询失败: ${data.detail}`, 'error');
diff --git a/deepsearcher/utils/message_stream.py b/deepsearcher/utils/message_stream.py
new file mode 100644
index 0000000..a0ed0e7
--- /dev/null
+++ b/deepsearcher/utils/message_stream.py
@@ -0,0 +1,138 @@
+import json
+import time
+from typing import Dict, List, Optional, Callable, Any
+from enum import Enum
+from dataclasses import dataclass
+
+
+
+class MessageType(Enum):
+ """消息类型枚举"""
+ SEARCH = "search"
+ THINK = "think"
+ ANSWER = "answer"
+
+
+@dataclass
+class Message:
+ """消息数据结构"""
+ type: MessageType
+ content: str
+ timestamp: float
+ metadata: Optional[Dict[str, Any]] = None
+
+ def to_dict(self) -> Dict[str, Any]:
+ """转换为字典格式"""
+ return {
+ "type": self.type.value,
+ "content": self.content,
+ "timestamp": self.timestamp,
+ "metadata": self.metadata or {}
+ }
+
+ def to_json(self) -> str:
+ """转换为JSON字符串"""
+ return json.dumps(self.to_dict(), ensure_ascii=False)
+
+
+class MessageStream:
+ """消息流管理器"""
+
+ def __init__(self):
+ self._messages: List[Message] = []
+ self._callbacks: List[Callable[[Message], None]] = []
+ self._enabled = True
+
+ def add_callback(self, callback: Callable[[Message], None]):
+ """添加消息回调函数"""
+ self._callbacks.append(callback)
+
+ def remove_callback(self, callback: Callable[[Message], None]):
+ """移除消息回调函数"""
+ if callback in self._callbacks:
+ self._callbacks.remove(callback)
+
+ def enable(self):
+ """启用消息流"""
+ self._enabled = True
+
+ def disable(self):
+ """禁用消息流"""
+ self._enabled = False
+
+ def send_message(self, msg_type: MessageType, content: str, metadata: Optional[Dict[str, Any]] = None):
+ """发送消息"""
+ if not self._enabled:
+ return
+
+ message = Message(
+ type=msg_type,
+ content=content,
+ timestamp=time.time(),
+ metadata=metadata
+ )
+
+ self._messages.append(message)
+
+ # 调用所有回调函数
+ for callback in self._callbacks:
+ try:
+ callback(message)
+ except Exception as e:
+ print(f"Error in message callback: {e}")
+
+ def send_search(self, content: str, metadata: Optional[Dict[str, Any]] = None):
+ """发送搜索消息"""
+ self.send_message(MessageType.SEARCH, content, metadata)
+
+ def send_think(self, content: str, metadata: Optional[Dict[str, Any]] = None):
+ """发送思考消息"""
+ self.send_message(MessageType.THINK, content, metadata)
+
+ def send_answer(self, content: str, metadata: Optional[Dict[str, Any]] = None):
+ """发送答案消息"""
+ self.send_message(MessageType.ANSWER, content, metadata)
+
+ def get_messages(self) -> List[Message]:
+ """获取所有消息"""
+ return self._messages.copy()
+
+ def get_messages_by_type(self, msg_type: MessageType) -> List[Message]:
+ """根据类型获取消息"""
+ return [msg for msg in self._messages if msg.type == msg_type]
+
+ def clear_messages(self):
+ """清空所有消息"""
+ self._messages.clear()
+
+ def get_messages_as_dicts(self) -> List[Dict[str, Any]]:
+ """获取所有消息的字典格式"""
+ return [msg.to_dict() for msg in self._messages]
+
+ def get_messages_as_json(self) -> str:
+ """获取所有消息的JSON格式"""
+ return json.dumps(self.get_messages_as_dicts(), ensure_ascii=False)
+
+
+# 全局消息流实例
+message_stream = MessageStream()
+
+
+def send_search(content: str, metadata: Optional[Dict[str, Any]] = None):
+ """全局搜索消息发送函数"""
+ message_stream.send_search(content, metadata)
+
+
+def send_think(content: str, metadata: Optional[Dict[str, Any]] = None):
+ """全局思考消息发送函数"""
+ message_stream.send_think(content, metadata)
+
+
+def send_answer(content: str, metadata: Optional[Dict[str, Any]] = None):
+ """全局答案消息发送函数"""
+ message_stream.send_answer(content, metadata)
+
+
+def get_message_stream() -> MessageStream:
+ """获取全局消息流实例"""
+ return message_stream
diff --git a/docs/message_stream_system.md b/docs/message_stream_system.md
new file mode 100644
index 0000000..b572003
--- /dev/null
+++ b/docs/message_stream_system.md
@@ -0,0 +1,197 @@
+# 消息流系统 (Message Stream System)
+
+## 概述
+
+DeepSearcher 的消息流系统是一个新的消息传输机制,用于替代原来的日志传输方式。该系统支持三种类型的消息:`search`、`think` 和 `answer`,并提供了灵活的消息管理和传输功能。
+
+## 消息类型
+
+### 1. Search 消息
+- **类型**: `search`
+- **用途**: 表示搜索相关的操作和状态
+- **示例**:
+ - "在向量数据库中搜索相关信息..."
+ - "找到5个相关文档片段"
+ - "搜索人工智能的定义..."
+
+### 2. Think 消息
+- **类型**: `think`
+- **用途**: 表示思考和推理过程
+- **示例**:
+ - "开始处理查询: 什么是人工智能?"
+ - "分析搜索结果..."
+ - "生成子查询: 人工智能的定义、历史、应用"
+
+### 3. Answer 消息
+- **类型**: `answer`
+- **用途**: 表示最终答案和结果
+- **示例**:
+ - "==== FINAL ANSWER===="
+ - "人工智能是计算机科学的一个分支..."
+
+## 核心组件
+
+### MessageStream 类
+
+主要的消息流管理器,提供以下功能:
+
+```python
+from deepsearcher.utils.message_stream import MessageStream
+
+# 创建消息流实例
+message_stream = MessageStream()
+
+# 发送消息
+message_stream.send_search("搜索内容...")
+message_stream.send_think("思考内容...")
+message_stream.send_answer("答案内容...")
+
+# 获取消息
+messages = message_stream.get_messages()
+messages_dict = message_stream.get_messages_as_dicts()
+messages_json = message_stream.get_messages_as_json()
+```
+
+### 全局函数
+
+为了方便使用,提供了全局函数:
+
+```python
+from deepsearcher.utils.message_stream import send_search, send_think, send_answer
+
+# 直接发送消息
+send_search("搜索内容...")
+send_think("思考内容...")
+send_answer("答案内容...")
+```
+
+## API 接口
+
+### 1. 获取消息
+```
+GET /messages/
+```
+
+返回所有消息的列表:
+```json
+{
+ "messages": [
+ {
+ "type": "search",
+ "content": "搜索内容...",
+ "timestamp": 1755043653.9606102,
+ "metadata": {}
+ }
+ ]
+}
+```
+
+### 2. 清空消息
+```
+POST /clear-messages/
+```
+
+清空所有消息并返回成功状态:
+```json
+{
+ "message": "Messages cleared successfully"
+}
+```
+
+### 3. 查询接口(已更新)
+```
+GET /query/?original_query=&max_iter=
+```
+
+现在返回结果包含消息流:
+```json
+{
+ "result": "最终答案...",
+ "messages": [
+ {
+ "type": "search",
+ "content": "搜索内容...",
+ "timestamp": 1755043653.9606102,
+ "metadata": {}
+ }
+ ]
+}
+```
+
+## 前端集成
+
+前端界面现在包含一个消息流显示区域,实时显示处理过程中的各种消息:
+
+### CSS 样式
+- `.message-search`: 搜索消息样式(蓝色边框)
+- `.message-think`: 思考消息样式(黄色边框)
+- `.message-answer`: 答案消息样式(绿色边框)
+
+### JavaScript 功能
+- `displayMessages(messages)`: 显示消息流
+- 自动滚动到最新消息
+- 时间戳显示
+
+## 使用示例
+
+### 后端使用
+
+```python
+from deepsearcher.utils.message_stream import send_search, send_think, send_answer
+
+# 在搜索过程中发送消息
+send_think("开始处理查询...")
+send_search("在数据库中搜索...")
+send_search("找到相关文档...")
+send_think("分析结果...")
+send_answer("最终答案...")
+```
+
+### 前端使用
+
+```javascript
+// 获取消息
+const response = await fetch('/query/?original_query=test&max_iter=3');
+const data = await response.json();
+
+// 显示消息流
+if (data.messages && data.messages.length > 0) {
+ displayMessages(data.messages);
+}
+```
+
+## 优势
+
+1. **结构化数据**: 消息包含类型、内容、时间戳和元数据
+2. **类型安全**: 使用枚举确保消息类型的一致性
+3. **灵活传输**: 支持多种输出格式(字典、JSON)
+4. **实时显示**: 前端可以实时显示处理过程
+5. **易于扩展**: 可以轻松添加新的消息类型和功能
+
+## 迁移说明
+
+从原来的日志系统迁移到新的消息流系统:
+
+### 原来的代码
+```python
+log.color_print(f" Search [{query}] in [{collection}]... \n")
+log.color_print(f" Summarize answer... \n")
+log.color_print(f" Final answer... \n")
+```
+
+### 新的代码
+```python
+send_search(f"Search [{query}] in [{collection}]...")
+send_think("Summarize answer...")
+send_answer("Final answer...")
+```
+
+## 测试
+
+运行测试脚本验证系统功能:
+
+```bash
+python test_message_stream.py
+```
+
+这将测试消息流的基本功能,包括消息发送、获取和格式化。
diff --git a/main.py b/main.py
index 77eafdc..d74b823 100644
--- a/main.py
+++ b/main.py
@@ -10,6 +10,7 @@ import os
from deepsearcher.configuration import Configuration, init_config
from deepsearcher.offline_loading import load_from_local_files, load_from_website
from deepsearcher.online_query import query
+from deepsearcher.utils.message_stream import get_message_stream
app = FastAPI()
@@ -203,10 +204,50 @@ def perform_query(
HTTPException: If the query fails.
"""
try:
+ # 清空之前的消息
+ message_stream = get_message_stream()
+ message_stream.clear_messages()
+
result_text, _ = query(original_query, max_iter)
return {
- "result": result_text
+ "result": result_text,
+ "messages": message_stream.get_messages_as_dicts()
+ }
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.get("/messages/")
+def get_messages():
+ """
+ Get all messages from the message stream.
+
+ Returns:
+ dict: A dictionary containing all messages.
+ """
+ try:
+ message_stream = get_message_stream()
+ return {
+ "messages": message_stream.get_messages_as_dicts()
+ }
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/clear-messages/")
+def clear_messages():
+ """
+ Clear all messages from the message stream.
+
+ Returns:
+ dict: A dictionary containing a success message.
+ """
+ try:
+ message_stream = get_message_stream()
+ message_stream.clear_messages()
+ return {
+ "message": "Messages cleared successfully"
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
diff --git a/test_iteration.py b/test_iteration.py
new file mode 100644
index 0000000..d27eeb7
--- /dev/null
+++ b/test_iteration.py
@@ -0,0 +1,31 @@
+#!/usr/bin/env python3
+"""
+测试迭代逻辑的脚本
+"""
+
+def test_iteration_logic():
+ """测试迭代逻辑"""
+ max_iter = 3
+
+ print(f"测试最大迭代次数: {max_iter}")
+ print("-" * 40)
+
+ for it in range(max_iter):
+ print(f">> Iteration: {it + 1}")
+
+ # 模拟搜索任务
+ print(" 执行搜索任务...")
+
+ # 检查是否达到最大迭代次数
+ if it + 1 < max_iter:
+ print(" 反思并生成新的子查询...")
+ print(" 准备下一次迭代")
+ else:
+ print(" 达到最大迭代次数,退出")
+ break
+
+ print("-" * 40)
+ print("测试完成!")
+
+if __name__ == "__main__":
+ test_iteration_logic()