Browse Source

feat:优化消息发送

main
tanxing 5 days ago
parent
commit
9eafef3c61
  1. 64
      deepsearcher/agent/deep_search.py
  2. 103
      deepsearcher/backend/templates/index.html
  3. 138
      deepsearcher/utils/message_stream.py
  4. 197
      docs/message_stream_system.md
  5. 43
      main.py
  6. 31
      test_iteration.py

64
deepsearcher/agent/deep_search.py

@ -2,6 +2,7 @@ from deepsearcher.agent.base import BaseAgent, describe_class
from deepsearcher.embedding.base import BaseEmbedding
from deepsearcher.llm.base import BaseLLM
from deepsearcher.utils import log
from deepsearcher.utils.message_stream import send_search, send_think, send_answer
from deepsearcher.vector_db import RetrievalResult
from deepsearcher.vector_db.base import BaseVectorDB, deduplicate
from collections import defaultdict
@ -230,14 +231,12 @@ class DeepSearch(BaseAgent):
all_retrieved_results = []
query_vector = self.embedding_model.embed_query(query)
for collection in selected_collections:
log.color_print(f"<search> Search [{query}] in [{collection}]... </search>\n")
send_search(f"Search [{query}] in [{collection}]...")
retrieved_results = self.vector_db.search_data(
collection=collection, vector=query_vector, query_text=query
)
if not retrieved_results or len(retrieved_results) == 0:
log.color_print(
f"<search> No relevant document chunks found in '{collection}'! </search>\n"
)
send_search(f"No relevant document chunks found in '{collection}'!")
continue
# Format all chunks for batch processing
@ -288,13 +287,9 @@ class DeepSearch(BaseAgent):
references.add(retrieved_result.reference)
if accepted_chunk_num > 0:
log.color_print(
f"<search> Accept {accepted_chunk_num} document chunk(s) from references: {list(references)} </search>\n"
)
send_search(f"Accept {accepted_chunk_num} document chunk(s) from references: {list(references)}")
else:
log.color_print(
f"<search> No document chunk accepted from '{collection}'! </search>\n"
)
send_search(f"No document chunk accepted from '{collection}'!")
return all_retrieved_results
def _generate_more_sub_queries(
@ -329,7 +324,7 @@ class DeepSearch(BaseAgent):
- Additional information about the retrieval process
"""
### SUB QUERIES ###
log.color_print(f"<query> {original_query} </query>\n")
send_think(f"<query> {original_query} </query>")
all_search_results = []
all_sub_queries = []
@ -338,11 +333,11 @@ class DeepSearch(BaseAgent):
log.color_print("No sub queries were generated by the LLM. Exiting.")
return [], {}
else:
log.color_print(f"</think> Break down the original query into new sub queries: {sub_queries} ")
send_think(f"Break down the original query into new sub queries: {sub_queries}")
all_sub_queries.extend(sub_queries)
for it in range(self.max_iter):
log.color_print(f">> Iteration: {it + 1}\n")
send_think(f">> Iteration: {it + 1}")
# Execute all search tasks sequentially
for query in sub_queries:
@ -352,26 +347,26 @@ class DeepSearch(BaseAgent):
all_search_results = deduplicate(all_search_results)
deduped_len = len(all_search_results)
if undeduped_len - deduped_len != 0:
log.color_print(
f"<search> Remove {undeduped_len - deduped_len} duplicates </search> "
)
send_search(f"Remove {undeduped_len - deduped_len} duplicates")
# search_res_from_internet = deduplicate_results(search_res_from_internet)
# all_search_res.extend(search_res_from_vectordb + search_res_from_internet)
if it + 1 >= self.max_iter:
log.color_print("</think> Exceeded maximum iterations. Exiting. ")
break
### REFLECTION & GET MORE SUB QUERIES ###
log.color_print("</think> Reflecting on the search results... ")
sub_queries = self._generate_more_sub_queries(
original_query, all_sub_queries, all_search_results
)
if not sub_queries or len(sub_queries) == 0:
log.color_print("</think> No new search queries were generated. Exiting. ")
break
# Only generate more queries if we haven't reached the maximum iterations
if it + 1 < self.max_iter:
send_think("Reflecting on the search results...")
sub_queries = self._generate_more_sub_queries(
original_query, all_sub_queries, all_search_results
)
if not sub_queries or len(sub_queries) == 0:
send_think("No new search queries were generated. Exiting.")
break
else:
send_think(f"New search queries for next iteration: {sub_queries}")
all_sub_queries.extend(sub_queries)
else:
log.color_print(
f"</think> New search queries for next iteration: {sub_queries} ")
all_sub_queries.extend(sub_queries)
send_think("Reached maximum iterations. Exiting.")
break
all_search_results = deduplicate(all_search_results)
return all_search_results, all_sub_queries
@ -394,20 +389,19 @@ class DeepSearch(BaseAgent):
"""
all_retrieved_results, all_sub_queries = self.retrieve(original_query, **kwargs)
if not all_retrieved_results or len(all_retrieved_results) == 0:
log.color_print(f"No relevant information found for query '{original_query}'.")
send_think(f"No relevant information found for query '{original_query}'.")
return "", []
chunks = self._format_chunks(all_retrieved_results)
log.color_print(
f"<think> Summarize answer from all {len(all_retrieved_results)} retrieved chunks... </think>\n"
)
send_think(f"Summarize answer from all {len(all_retrieved_results)} retrieved chunks...")
summary_prompt = SUMMARY_PROMPT.format(
original_query=original_query,
all_sub_queries=all_sub_queries,
chunks=chunks
)
response = self.llm.chat([{"role": "user", "content": summary_prompt}])
log.color_print("\n==== FINAL ANSWER====\n")
log.color_print(self.llm.remove_think(response))
final_answer = self.llm.remove_think(response)
send_answer("==== FINAL ANSWER====")
send_answer(final_answer)
return self.llm.remove_think(response), all_retrieved_results
def _format_chunks(self, retrieved_results: list[RetrievalResult]):

103
deepsearcher/backend/templates/index.html

@ -192,6 +192,49 @@
line-height: 1.6;
}
/* Wrapper around the message panel shown under the query result */
.message-stream {
    margin-top: 16px;
}
/* Scrollable panel that holds the individual messages */
.message-container {
    max-height: 400px;
    overflow-y: auto;
    border: 1px solid var(--border-color);
    border-radius: 6px;
    background-color: white;
    padding: 12px;
}
/* Base style shared by every message; the left border colour is set per type below */
.message {
    margin-bottom: 12px;
    padding: 8px 12px;
    border-radius: 4px;
    border-left: 4px solid;
    font-size: 14px;
    line-height: 1.4;
}
/* "search" messages */
.message-search {
    background-color: #f0f9ff;
    border-left-color: var(--info-color);
}
/* "think" messages */
.message-think {
    background-color: #fef3c7;
    border-left-color: var(--warning-color);
}
/* "answer" messages */
.message-answer {
    background-color: #d1fae5;
    border-left-color: var(--success-color);
}
/* Time label rendered under each message's content */
.message-timestamp {
    font-size: 12px;
    color: var(--text-secondary);
    margin-top: 4px;
}
footer {
text-align: center;
margin-top: 30px;
@ -266,11 +309,17 @@
<input type="number" id="maxIter" min="1" max="10" value="3">
</div>
<button id="queryBtn">执行查询</button>
<button id="clearMessagesBtn" style="margin-left: 10px; background-color: var(--text-secondary);">清空消息</button>
<div id="queryStatus" class="status"></div>
<div id="queryResult" class="result-container">
<h3>查询结果:</h3>
<div class="query-result" id="resultText"></div>
<h3>处理过程:</h3>
<div id="messageStream" class="message-stream">
<div class="message-container" id="messageContainer"></div>
</div>
</div>
</div>
</main>
@ -300,6 +349,32 @@
}
}
// Utility: render the message stream into the message panel.
// Each message is expected to carry `type`, `content` and `timestamp` (in seconds).
function displayMessages(messages) {
    const panel = document.getElementById('messageContainer');
    // Start from an empty panel so repeated queries do not accumulate entries
    panel.innerHTML = '';

    for (const entry of messages) {
        const row = document.createElement('div');
        row.className = `message message-${entry.type}`;

        // textContent (not innerHTML) keeps message text from being parsed as markup
        const body = document.createElement('div');
        body.textContent = entry.content;

        // Timestamps arrive in seconds; Date expects milliseconds
        const stamp = document.createElement('div');
        stamp.className = 'message-timestamp';
        stamp.textContent = new Date(entry.timestamp * 1000).toLocaleTimeString();

        row.appendChild(body);
        row.appendChild(stamp);
        panel.appendChild(row);
    }

    // Scroll to the bottom so the newest message is visible
    panel.scrollTop = panel.scrollHeight;
}
// 工具函数:隐藏状态信息
function hideStatus(elementId) {
const statusElement = document.getElementById(elementId);
@ -388,6 +463,28 @@
}
});
// Clear-messages button: POST to /clear-messages/ and, on success, empty the message panel
document.getElementById('clearMessagesBtn').addEventListener('click', async function() {
    try {
        const response = await fetch('/clear-messages/', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json'
            }
        });
        if (response.ok) {
            // Only wipe the local panel after the server reported success
            const container = document.getElementById('messageContainer');
            container.innerHTML = '';
            showStatus('queryStatus', '消息已清空', 'success');
        } else {
            showStatus('queryStatus', '清空消息失败', 'error');
        }
    } catch (error) {
        // The request itself failed (fetch rejected); surface the error text to the user
        showStatus('queryStatus', `请求失败: ${error.message}`, 'error');
    }
});
// 加载网站内容功能
document.getElementById('loadWebsiteBtn').addEventListener('click', async function() {
const button = this;
@ -466,6 +563,12 @@
if (response.ok) {
showStatus('queryStatus', '查询完成', 'success');
document.getElementById('resultText').textContent = data.result;
// 显示消息流
if (data.messages && data.messages.length > 0) {
displayMessages(data.messages);
}
showResult();
} else {
showStatus('queryStatus', `查询失败: ${data.detail}`, 'error');

138
deepsearcher/utils/message_stream.py

@ -0,0 +1,138 @@
import json
import time
from typing import Dict, List, Optional, Callable, Any
from enum import Enum
from dataclasses import dataclass, asdict
from datetime import datetime
class MessageType(Enum):
    """Enumeration of the message categories carried by the message stream."""

    # Search-related operations and status
    SEARCH = "search"
    # Thinking / reasoning progress
    THINK = "think"
    # Final answer and results
    ANSWER = "answer"
@dataclass
class Message:
    """A single entry in the message stream.

    Attributes:
        type: Category of the message.
        content: Human-readable message text.
        timestamp: Creation time as a float (``MessageStream.send_message``
            fills it with ``time.time()``).
        metadata: Optional extra key/value data attached to the message.
    """

    type: MessageType
    content: str
    timestamp: float
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation of the message.

        Returns:
            A dict with the keys ``type`` (the enum's value), ``content``,
            ``timestamp`` and ``metadata`` (always a dict, empty when no
            metadata was supplied).
        """
        return {
            "type": self.type.value,
            "content": self.content,
            "timestamp": self.timestamp,
            # Hand out a shallow copy so callers that edit the returned dict
            # cannot silently modify the stored message.
            "metadata": dict(self.metadata) if self.metadata else {},
        }

    def to_json(self) -> str:
        """Return the message serialized as a JSON string.

        Non-ASCII characters are written as-is rather than escaped.
        """
        return json.dumps(self.to_dict(), ensure_ascii=False)
class MessageStream:
    """In-memory message stream manager.

    Stores every message that is sent while the stream is enabled and
    forwards each new message to the registered callbacks.
    """

    def __init__(self):
        """Create an empty, enabled stream with no callbacks."""
        self._messages: List[Message] = []
        self._callbacks: List[Callable[[Message], None]] = []
        self._enabled = True

    def add_callback(self, callback: Callable[[Message], None]):
        """Register a function to be called with every new message."""
        self._callbacks.append(callback)

    def remove_callback(self, callback: Callable[[Message], None]):
        """Unregister a callback; does nothing if it is not registered."""
        if callback in self._callbacks:
            self._callbacks.remove(callback)

    def enable(self):
        """Enable the stream so that sent messages are recorded."""
        self._enabled = True

    def disable(self):
        """Disable the stream; messages sent while disabled are dropped."""
        self._enabled = False

    def send_message(self, msg_type: MessageType, content: str, metadata: Optional[Dict[str, Any]] = None):
        """Record a message and notify all callbacks.

        Args:
            msg_type: Category of the message.
            content: Message text.
            metadata: Optional extra data stored with the message.
        """
        if not self._enabled:
            return
        message = Message(
            type=msg_type,
            content=content,
            timestamp=time.time(),
            metadata=metadata
        )
        self._messages.append(message)
        # Iterate over a snapshot: a callback may add or remove callbacks
        # (including itself) while being notified, and mutating the list
        # during iteration would make the loop skip entries.
        for callback in tuple(self._callbacks):
            try:
                callback(message)
            except Exception as e:
                # Best effort: one failing listener must not stop the others
                # or the caller that emitted the message.
                print(f"Error in message callback: {e}")

    def send_search(self, content: str, metadata: Optional[Dict[str, Any]] = None):
        """Send a message of type ``MessageType.SEARCH``."""
        self.send_message(MessageType.SEARCH, content, metadata)

    def send_think(self, content: str, metadata: Optional[Dict[str, Any]] = None):
        """Send a message of type ``MessageType.THINK``."""
        self.send_message(MessageType.THINK, content, metadata)

    def send_answer(self, content: str, metadata: Optional[Dict[str, Any]] = None):
        """Send a message of type ``MessageType.ANSWER``."""
        self.send_message(MessageType.ANSWER, content, metadata)

    def get_messages(self) -> List[Message]:
        """Return a copy of the list of all recorded messages."""
        return self._messages.copy()

    def get_messages_by_type(self, msg_type: MessageType) -> List[Message]:
        """Return the recorded messages whose type equals ``msg_type``."""
        return [msg for msg in self._messages if msg.type == msg_type]

    def clear_messages(self):
        """Remove all recorded messages."""
        self._messages.clear()

    def get_messages_as_dicts(self) -> List[Dict[str, Any]]:
        """Return all recorded messages converted with ``Message.to_dict``."""
        return [msg.to_dict() for msg in self._messages]

    def get_messages_as_json(self) -> str:
        """Return all recorded messages as one JSON array string."""
        return json.dumps(self.get_messages_as_dicts(), ensure_ascii=False)
# Global message stream instance used by the module-level helper functions
message_stream = MessageStream()
def send_search(content: str, metadata: Optional[Dict[str, Any]] = None):
    """Send a search message through the global message stream.

    Args:
        content: Message text.
        metadata: Optional extra data stored with the message.
    """
    message_stream.send_search(content, metadata)
def send_think(content: str, metadata: Optional[Dict[str, Any]] = None):
    """Send a think message through the global message stream.

    Args:
        content: Message text.
        metadata: Optional extra data stored with the message.
    """
    message_stream.send_think(content, metadata)
def send_answer(content: str, metadata: Optional[Dict[str, Any]] = None):
    """Send an answer message through the global message stream.

    Args:
        content: Message text.
        metadata: Optional extra data stored with the message.
    """
    message_stream.send_answer(content, metadata)
def get_message_stream() -> MessageStream:
    """Return the global message stream instance."""
    return message_stream

197
docs/message_stream_system.md

@ -0,0 +1,197 @@
# 消息流系统 (Message Stream System)
## 概述
DeepSearcher 的消息流系统是一个新的消息传输机制,用于替代原来的日志传输方式。该系统支持三种类型的消息:`search`、`think` 和 `answer`,并提供了灵活的消息管理和传输功能。
## 消息类型
### 1. Search 消息
- **类型**: `search`
- **用途**: 表示搜索相关的操作和状态
- **示例**:
- "在向量数据库中搜索相关信息..."
- "找到5个相关文档片段"
- "搜索人工智能的定义..."
### 2. Think 消息
- **类型**: `think`
- **用途**: 表示思考和推理过程
- **示例**:
- "开始处理查询: 什么是人工智能?"
- "分析搜索结果..."
- "生成子查询: 人工智能的定义、历史、应用"
### 3. Answer 消息
- **类型**: `answer`
- **用途**: 表示最终答案和结果
- **示例**:
- "==== FINAL ANSWER===="
- "人工智能是计算机科学的一个分支..."
## 核心组件
### MessageStream 类
主要的消息流管理器,提供以下功能:
```python
from deepsearcher.utils.message_stream import MessageStream
# 创建消息流实例
message_stream = MessageStream()
# 发送消息
message_stream.send_search("搜索内容...")
message_stream.send_think("思考内容...")
message_stream.send_answer("答案内容...")
# 获取消息
messages = message_stream.get_messages()
messages_dict = message_stream.get_messages_as_dicts()
messages_json = message_stream.get_messages_as_json()
```
### 全局函数
为了方便使用,提供了全局函数:
```python
from deepsearcher.utils.message_stream import send_search, send_think, send_answer
# 直接发送消息
send_search("搜索内容...")
send_think("思考内容...")
send_answer("答案内容...")
```
## API 接口
### 1. 获取消息
```
GET /messages/
```
返回所有消息的列表:
```json
{
"messages": [
{
"type": "search",
"content": "搜索内容...",
"timestamp": 1755043653.9606102,
"metadata": {}
}
]
}
```
### 2. 清空消息
```
POST /clear-messages/
```
清空所有消息并返回成功状态:
```json
{
"message": "Messages cleared successfully"
}
```
### 3. 查询接口(已更新)
```
GET /query/?original_query=<query>&max_iter=<iterations>
```
现在返回结果包含消息流:
```json
{
"result": "最终答案...",
"messages": [
{
"type": "search",
"content": "搜索内容...",
"timestamp": 1755043653.9606102,
"metadata": {}
}
]
}
```
## 前端集成
前端界面现在包含一个消息流显示区域。查询完成后,`/query/` 响应中返回的消息会一次性渲染到该区域,用于展示处理过程中的各种消息(当前实现并非逐条实时推送):
### CSS 样式
- `.message-search`: 搜索消息样式(蓝色边框)
- `.message-think`: 思考消息样式(黄色边框)
- `.message-answer`: 答案消息样式(绿色边框)
### JavaScript 功能
- `displayMessages(messages)`: 显示消息流
- 自动滚动到最新消息
- 时间戳显示
## 使用示例
### 后端使用
```python
from deepsearcher.utils.message_stream import send_search, send_think, send_answer
# 在搜索过程中发送消息
send_think("开始处理查询...")
send_search("在数据库中搜索...")
send_search("找到相关文档...")
send_think("分析结果...")
send_answer("最终答案...")
```
### 前端使用
```javascript
// 获取消息
const response = await fetch('/query/?original_query=test&max_iter=3');
const data = await response.json();
// 显示消息流
if (data.messages && data.messages.length > 0) {
displayMessages(data.messages);
}
```
## 优势
1. **结构化数据**: 消息包含类型、内容、时间戳和元数据
2. **类型安全**: 使用枚举确保消息类型的一致性
3. **灵活传输**: 支持多种输出格式(字典、JSON)
4. **过程展示**: 查询完成后,前端可以按顺序展示整个处理过程中的消息
5. **易于扩展**: 可以轻松添加新的消息类型和功能
## 迁移说明
从原来的日志系统迁移到新的消息流系统:
### 原来的代码
```python
log.color_print(f"<search> Search [{query}] in [{collection}]... </search>\n")
log.color_print(f"<think> Summarize answer... </think>\n")
log.color_print(f"<answer> Final answer... </answer>\n")
```
### 新的代码
```python
send_search(f"Search [{query}] in [{collection}]...")
send_think("Summarize answer...")
send_answer("Final answer...")
```
## 测试
运行测试脚本验证系统功能:
```bash
python test_message_stream.py
```
这将测试消息流的基本功能,包括消息发送、获取和格式化。

43
main.py

@ -10,6 +10,7 @@ import os
from deepsearcher.configuration import Configuration, init_config
from deepsearcher.offline_loading import load_from_local_files, load_from_website
from deepsearcher.online_query import query
from deepsearcher.utils.message_stream import get_message_stream
app = FastAPI()
@ -203,10 +204,50 @@ def perform_query(
HTTPException: If the query fails.
"""
try:
# 清空之前的消息
message_stream = get_message_stream()
message_stream.clear_messages()
result_text, _ = query(original_query, max_iter)
return {
"result": result_text
"result": result_text,
"messages": message_stream.get_messages_as_dicts()
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/messages/")
def get_messages():
    """
    Get all messages from the message stream.

    Returns:
        dict: A dictionary with a single key ``"messages"`` holding the
            list of messages in dictionary form.

    Raises:
        HTTPException: With status 500 if reading the messages fails.
    """
    try:
        message_stream = get_message_stream()
        return {
            "messages": message_stream.get_messages_as_dicts()
        }
    except Exception as e:
        # Chain the original exception so its traceback is not lost when
        # the error is converted into an HTTP 500 response.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/clear-messages/")
def clear_messages():
    """
    Clear all messages from the message stream.

    Returns:
        dict: A dictionary containing a success message.

    Raises:
        HTTPException: With status 500 if clearing the messages fails.
    """
    try:
        message_stream = get_message_stream()
        message_stream.clear_messages()
        return {
            "message": "Messages cleared successfully"
        }
    except Exception as e:
        # Chain the original exception so its traceback is not lost when
        # the error is converted into an HTTP 500 response.
        raise HTTPException(status_code=500, detail=str(e)) from e

31
test_iteration.py

@ -0,0 +1,31 @@
#!/usr/bin/env python3
"""
测试迭代逻辑的脚本
"""
def test_iteration_logic():
    """Print a dry-run trace of the bounded search/reflect iteration loop.

    Every round prints the iteration header and the search step. All rounds
    except the last also print the reflection steps; the last round prints
    the "maximum reached" line and leaves the loop.
    """
    max_iter = 3
    separator = "-" * 40

    print(f"测试最大迭代次数: {max_iter}")
    print(separator)

    round_no = 1
    while round_no <= max_iter:
        print(f">> Iteration: {round_no}")
        # Simulated search task
        print(" 执行搜索任务...")
        # Last allowed round: report it and stop without reflecting
        if round_no >= max_iter:
            print(" 达到最大迭代次数,退出")
            break
        print(" 反思并生成新的子查询...")
        print(" 准备下一次迭代")
        round_no += 1

    print(separator)
    print("测试完成!")
# Run the iteration trace when this file is executed as a script
if __name__ == "__main__":
    test_iteration_logic()
Loading…
Cancel
Save