diff --git a/deepsearcher/agent/deep_search.py b/deepsearcher/agent/deep_search.py
index d34f75e..bdc318d 100644
--- a/deepsearcher/agent/deep_search.py
+++ b/deepsearcher/agent/deep_search.py
@@ -2,6 +2,7 @@ from deepsearcher.agent.base import BaseAgent, describe_class
 from deepsearcher.embedding.base import BaseEmbedding
 from deepsearcher.llm.base import BaseLLM
 from deepsearcher.utils import log
+from deepsearcher.utils.message_stream import send_search, send_think, send_answer
 from deepsearcher.vector_db import RetrievalResult
 from deepsearcher.vector_db.base import BaseVectorDB, deduplicate
 from collections import defaultdict
@@ -230,14 +231,12 @@ class DeepSearch(BaseAgent):
         all_retrieved_results = []
         query_vector = self.embedding_model.embed_query(query)
         for collection in selected_collections:
-            log.color_print(f" Search [{query}] in [{collection}]... \n")
+            send_search(f"Search [{query}] in [{collection}]...")
             retrieved_results = self.vector_db.search_data(
                 collection=collection, vector=query_vector, query_text=query
             )
             if not retrieved_results or len(retrieved_results) == 0:
-                log.color_print(
-                    f" No relevant document chunks found in '{collection}'! \n"
-                )
+                send_search(f"No relevant document chunks found in '{collection}'!")
                 continue

             # Format all chunks for batch processing
@@ -288,13 +287,9 @@ class DeepSearch(BaseAgent):
                         references.add(retrieved_result.reference)

             if accepted_chunk_num > 0:
-                log.color_print(
-                    f" Accept {accepted_chunk_num} document chunk(s) from references: {list(references)} \n"
-                )
+                send_search(f"Accept {accepted_chunk_num} document chunk(s) from references: {list(references)}")
             else:
-                log.color_print(
-                    f" No document chunk accepted from '{collection}'! \n"
-                )
+                send_search(f"No document chunk accepted from '{collection}'!")
         return all_retrieved_results

     def _generate_more_sub_queries(
@@ -329,7 +324,7 @@
             - Additional information about the retrieval process
         """

         ### SUB QUERIES ###
-        log.color_print(f" {original_query} \n")
+        send_think(f" {original_query} ")
         all_search_results = []
         all_sub_queries = []
@@ -338,11 +333,11 @@
             log.color_print("No sub queries were generated by the LLM. Exiting.")
             return [], {}
         else:
-            log.color_print(f" Break down the original query into new sub queries: {sub_queries} ")
+            send_think(f"Break down the original query into new sub queries: {sub_queries}")
             all_sub_queries.extend(sub_queries)

         for it in range(self.max_iter):
-            log.color_print(f">> Iteration: {it + 1}\n")
+            send_think(f">> Iteration: {it + 1}")

             # Execute all search tasks sequentially
             for query in sub_queries:
@@ -352,26 +347,26 @@
             all_search_results = deduplicate(all_search_results)
             deduped_len = len(all_search_results)
             if undeduped_len - deduped_len != 0:
-                log.color_print(
-                    f" Remove {undeduped_len - deduped_len} duplicates "
-                )
+                send_search(f"Remove {undeduped_len - deduped_len} duplicates")
             # search_res_from_internet = deduplicate_results(search_res_from_internet)
             # all_search_res.extend(search_res_from_vectordb + search_res_from_internet)

-            if it + 1 >= self.max_iter:
-                log.color_print(" Exceeded maximum iterations. Exiting. ")
-                break
+
             ### REFLECTION & GET MORE SUB QUERIES ###
-            log.color_print(" Reflecting on the search results... ")
-            sub_queries = self._generate_more_sub_queries(
-                original_query, all_sub_queries, all_search_results
-            )
-            if not sub_queries or len(sub_queries) == 0:
-                log.color_print(" No new search queries were generated. Exiting. \n")
") - break + # Only generate more queries if we haven't reached the maximum iterations + if it + 1 < self.max_iter: + send_think("Reflecting on the search results...") + sub_queries = self._generate_more_sub_queries( + original_query, all_sub_queries, all_search_results + ) + if not sub_queries or len(sub_queries) == 0: + send_think("No new search queries were generated. Exiting.") + break + else: + send_think(f"New search queries for next iteration: {sub_queries}") + all_sub_queries.extend(sub_queries) else: - log.color_print( - f" New search queries for next iteration: {sub_queries} ") - all_sub_queries.extend(sub_queries) + send_think("Reached maximum iterations. Exiting.") + break all_search_results = deduplicate(all_search_results) return all_search_results, all_sub_queries @@ -394,20 +389,19 @@ class DeepSearch(BaseAgent): """ all_retrieved_results, all_sub_queries = self.retrieve(original_query, **kwargs) if not all_retrieved_results or len(all_retrieved_results) == 0: - log.color_print(f"No relevant information found for query '{original_query}'.") + send_think(f"No relevant information found for query '{original_query}'.") return "", [] chunks = self._format_chunks(all_retrieved_results) - log.color_print( - f" Summarize answer from all {len(all_retrieved_results)} retrieved chunks... \n" - ) + send_think(f"Summarize answer from all {len(all_retrieved_results)} retrieved chunks...") summary_prompt = SUMMARY_PROMPT.format( original_query=original_query, all_sub_queries=all_sub_queries, chunks=chunks ) response = self.llm.chat([{"role": "user", "content": summary_prompt}]) - log.color_print("\n==== FINAL ANSWER====\n") - log.color_print(self.llm.remove_think(response)) + final_answer = self.llm.remove_think(response) + send_answer("==== FINAL ANSWER====") + send_answer(final_answer) return self.llm.remove_think(response), all_retrieved_results def _format_chunks(self, retrieved_results: list[RetrievalResult]): diff --git a/deepsearcher/backend/templates/index.html b/deepsearcher/backend/templates/index.html index 6c9d3e2..56bb73e 100644 --- a/deepsearcher/backend/templates/index.html +++ b/deepsearcher/backend/templates/index.html @@ -192,6 +192,49 @@ line-height: 1.6; } + .message-stream { + margin-top: 16px; + } + + .message-container { + max-height: 400px; + overflow-y: auto; + border: 1px solid var(--border-color); + border-radius: 6px; + background-color: white; + padding: 12px; + } + + .message { + margin-bottom: 12px; + padding: 8px 12px; + border-radius: 4px; + border-left: 4px solid; + font-size: 14px; + line-height: 1.4; + } + + .message-search { + background-color: #f0f9ff; + border-left-color: var(--info-color); + } + + .message-think { + background-color: #fef3c7; + border-left-color: var(--warning-color); + } + + .message-answer { + background-color: #d1fae5; + border-left-color: var(--success-color); + } + + .message-timestamp { + font-size: 12px; + color: var(--text-secondary); + margin-top: 4px; + } + footer { text-align: center; margin-top: 30px; @@ -266,11 +309,17 @@ +

                 <h3>查询结果:</h3>
                 <div id="resultText"></div>
+
+                <div class="message-stream">
+                    <h3>处理过程:</h3>
+                    <div class="message-container" id="messageContainer"></div>
+                    <button id="clearMessagesBtn">清空消息</button>
+                </div>
@@ -300,6 +349,32 @@
             }
         }

+        // Utility: render the message stream
+        function displayMessages(messages) {
+            const container = document.getElementById('messageContainer');
+            container.innerHTML = '';
+
+            messages.forEach(message => {
+                const messageElement = document.createElement('div');
+                messageElement.className = `message message-${message.type}`;
+
+                const contentElement = document.createElement('div');
+                contentElement.textContent = message.content;
+
+                const timestampElement = document.createElement('div');
+                timestampElement.className = 'message-timestamp';
+                const date = new Date(message.timestamp * 1000);
+                timestampElement.textContent = date.toLocaleTimeString();
+
+                messageElement.appendChild(contentElement);
+                messageElement.appendChild(timestampElement);
+                container.appendChild(messageElement);
+            });
+
+            // Scroll to the newest message
+            container.scrollTop = container.scrollHeight;
+        }
+
         // 工具函数：隐藏状态信息
         function hideStatus(elementId) {
             const statusElement = document.getElementById(elementId);
@@ -388,6 +463,28 @@
             }
         });

+        // Clear the message list
+        document.getElementById('clearMessagesBtn').addEventListener('click', async function() {
+            try {
+                const response = await fetch('/clear-messages/', {
+                    method: 'POST',
+                    headers: {
+                        'Content-Type': 'application/json'
+                    }
+                });
+
+                if (response.ok) {
+                    const container = document.getElementById('messageContainer');
+                    container.innerHTML = '';
+                    showStatus('queryStatus', '消息已清空', 'success');
+                } else {
+                    showStatus('queryStatus', '清空消息失败', 'error');
+                }
+            } catch (error) {
+                showStatus('queryStatus', `请求失败: ${error.message}`, 'error');
+            }
+        });
+
         // 加载网站内容功能
         document.getElementById('loadWebsiteBtn').addEventListener('click', async function() {
             const button = this;
@@ -466,6 +563,12 @@
                 if (response.ok) {
                     showStatus('queryStatus', '查询完成', 'success');
                     document.getElementById('resultText').textContent = data.result;
+
+                    // Render the message stream
+                    if (data.messages && data.messages.length > 0) {
+                        displayMessages(data.messages);
+                    }
+
                     showResult();
                 } else {
                     showStatus('queryStatus', `查询失败: ${data.detail}`, 'error');
diff --git a/deepsearcher/utils/message_stream.py b/deepsearcher/utils/message_stream.py
new file mode 100644
index 0000000..a0ed0e7
--- /dev/null
+++ b/deepsearcher/utils/message_stream.py
@@ -0,0 +1,138 @@
+import json
+import time
+from typing import Dict, List, Optional, Callable, Any
+from enum import Enum
+from dataclasses import dataclass, asdict
+from datetime import datetime
+
+
+class MessageType(Enum):
+    """Enumeration of message types."""
+    SEARCH = "search"
+    THINK = "think"
+    ANSWER = "answer"
+
+
+@dataclass
+class Message:
+    """Data structure for a single message."""
+    type: MessageType
+    content: str
+    timestamp: float
+    metadata: Optional[Dict[str, Any]] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert the message to a dictionary."""
+        return {
+            "type": self.type.value,
+            "content": self.content,
+            "timestamp": self.timestamp,
+            "metadata": self.metadata or {}
+        }
+
+    def to_json(self) -> str:
+        """Convert the message to a JSON string."""
+        return json.dumps(self.to_dict(), ensure_ascii=False)
+
+
+class MessageStream:
+    """Manager for the message stream."""
+
+    def __init__(self):
+        self._messages: List[Message] = []
+        self._callbacks: List[Callable[[Message], None]] = []
+        self._enabled = True
+
+    def add_callback(self, callback: Callable[[Message], None]):
+        """Register a callback that is invoked for every new message."""
+        self._callbacks.append(callback)
+
+    def remove_callback(self, callback: Callable[[Message], None]):
+        """Remove a previously registered callback."""
+        if callback in self._callbacks:
+            self._callbacks.remove(callback)
+
+    def enable(self):
+        """Enable the message stream."""
+        self._enabled = True
+
+    def disable(self):
+        """Disable the message stream."""
+        self._enabled = False
+
+    def send_message(self, msg_type: MessageType, content: str, metadata: Optional[Dict[str, Any]] = None):
+        """Send a message of the given type."""
+        if not self._enabled:
+            return
+
+        message = Message(
+            type=msg_type,
+            content=content,
+            timestamp=time.time(),
+            metadata=metadata
+        )
+
+        self._messages.append(message)
+
+        # Invoke all registered callbacks
+        for callback in self._callbacks:
+            try:
+                callback(message)
+            except Exception as e:
+                print(f"Error in message callback: {e}")
+
+    def send_search(self, content: str, metadata: Optional[Dict[str, Any]] = None):
+        """Send a search message."""
+        self.send_message(MessageType.SEARCH, content, metadata)
+
+    def send_think(self, content: str, metadata: Optional[Dict[str, Any]] = None):
+        """Send a think message."""
+        self.send_message(MessageType.THINK, content, metadata)
+
+    def send_answer(self, content: str, metadata: Optional[Dict[str, Any]] = None):
+        """Send an answer message."""
+        self.send_message(MessageType.ANSWER, content, metadata)
+
+    def get_messages(self) -> List[Message]:
+        """Return a copy of all messages."""
+        return self._messages.copy()
+
+    def get_messages_by_type(self, msg_type: MessageType) -> List[Message]:
+        """Return all messages of the given type."""
+        return [msg for msg in self._messages if msg.type == msg_type]
+
+    def clear_messages(self):
+        """Clear all messages."""
+        self._messages.clear()
+
+    def get_messages_as_dicts(self) -> List[Dict[str, Any]]:
+        """Return all messages as dictionaries."""
+        return [msg.to_dict() for msg in self._messages]
+
+    def get_messages_as_json(self) -> str:
+        """Return all messages as a JSON string."""
+        return json.dumps(self.get_messages_as_dicts(), ensure_ascii=False)
+
+
+# Global message stream instance
+message_stream = MessageStream()
+
+
+def send_search(content: str, metadata: Optional[Dict[str, Any]] = None):
+    """Module-level helper for sending a search message."""
+    message_stream.send_search(content, metadata)
+
+
+def send_think(content: str, metadata: Optional[Dict[str, Any]] = None):
+    """Module-level helper for sending a think message."""
+    message_stream.send_think(content, metadata)
+
+
+def send_answer(content: str, metadata: Optional[Dict[str, Any]] = None):
+    """Module-level helper for sending an answer message."""
+    message_stream.send_answer(content, metadata)
+
+
+def get_message_stream() -> MessageStream:
+    """Return the global message stream instance."""
+    return message_stream
diff --git a/docs/message_stream_system.md b/docs/message_stream_system.md
new file mode 100644
index 0000000..b572003
--- /dev/null
+++ b/docs/message_stream_system.md
@@ -0,0 +1,197 @@
+# Message Stream System
+
+## Overview
+
+The DeepSearcher message stream system is a new message transport mechanism that replaces the previous log-based output. It supports three message types (`search`, `think`, and `answer`) and provides flexible message management and delivery.
+
+## Message Types
+
+### 1. Search messages
+- **Type**: `search`
+- **Purpose**: describe search-related operations and their status
+- **Examples**:
+  - "Searching the vector database for relevant information..."
+  - "Found 5 relevant document chunks"
+  - "Searching for the definition of artificial intelligence..."
+
+### 2. Think messages
+- **Type**: `think`
+- **Purpose**: describe reasoning and intermediate thinking steps
+- **Examples**:
+  - "Start processing query: What is artificial intelligence?"
+  - "Analyzing search results..."
+  - "Generating sub-queries: definition, history, and applications of AI"
+
+### 3. Answer messages
+- **Type**: `answer`
+- **Purpose**: carry the final answer and results
+- **Examples**:
+  - "==== FINAL ANSWER===="
+  - "Artificial intelligence is a branch of computer science..."
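+
+Internally, each entry is a `Message` dataclass from `deepsearcher/utils/message_stream.py`. A minimal sketch of what a single message looks like when serialized (the content string is just an illustrative value):
+
+```python
+import time
+
+from deepsearcher.utils.message_stream import Message, MessageType
+
+# Build one message by hand and inspect its wire format.
+msg = Message(type=MessageType.SEARCH, content="Search [AI] in [docs]...", timestamp=time.time())
+print(msg.to_dict())   # {'type': 'search', 'content': 'Search [AI] in [docs]...', 'timestamp': ..., 'metadata': {}}
+print(msg.to_json())   # the same data as a JSON string (ensure_ascii=False)
+```
+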
+## Core Components
+
+### The MessageStream class
+
+The main message stream manager. It provides the following functionality:
+
+```python
+from deepsearcher.utils.message_stream import MessageStream
+
+# Create a message stream instance
+message_stream = MessageStream()
+
+# Send messages
+message_stream.send_search("Search content...")
+message_stream.send_think("Think content...")
+message_stream.send_answer("Answer content...")
+
+# Read messages
+messages = message_stream.get_messages()
+messages_dict = message_stream.get_messages_as_dicts()
+messages_json = message_stream.get_messages_as_json()
+```
+
+### Global helper functions
+
+For convenience, module-level functions are provided:
+
+```python
+from deepsearcher.utils.message_stream import send_search, send_think, send_answer
+
+# Send messages directly
+send_search("Search content...")
+send_think("Think content...")
+send_answer("Answer content...")
+```
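+
+`MessageStream` also exposes `add_callback()`, which is not shown above. As a rough sketch (using only the API in this patch), a callback can forward every message somewhere else the moment it is emitted, for example printing it as JSON:
+
+```python
+from deepsearcher.utils.message_stream import get_message_stream
+
+stream = get_message_stream()
+
+def print_as_json(message):
+    # Message.to_json() serializes type, content, timestamp, and metadata.
+    print(message.to_json())
+
+# The callback fires for every subsequent send_search/send_think/send_answer call.
+stream.add_callback(print_as_json)
+stream.send_think("Callbacks fire immediately...")
+```
+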
+## API Endpoints
+
+### 1. Get messages
+```
+GET /messages/
+```
+
+Returns the list of all messages:
+```json
+{
+  "messages": [
+    {
+      "type": "search",
+      "content": "Search content...",
+      "timestamp": 1755043653.9606102,
+      "metadata": {}
+    }
+  ]
+}
+```
+
+### 2. Clear messages
+```
+POST /clear-messages/
+```
+
+Clears all messages and returns a success status:
+```json
+{
+  "message": "Messages cleared successfully"
+}
+```
+
+### 3. Query endpoint (updated)
+```
+GET /query/?original_query=...&max_iter=...
+```
+
+The response now includes the message stream:
+```json
+{
+  "result": "Final answer...",
+  "messages": [
+    {
+      "type": "search",
+      "content": "Search content...",
+      "timestamp": 1755043653.9606102,
+      "metadata": {}
+    }
+  ]
+}
+```
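+
+A small client-side sketch of how these endpoints can be consumed (this assumes the FastAPI app from `main.py` is running locally on the default `http://localhost:8000`; the query text is only an example):
+
+```python
+import requests
+
+BASE = "http://localhost:8000"
+
+# Run a query; the response carries both the answer and the message stream.
+resp = requests.get(f"{BASE}/query/", params={"original_query": "What is AI?", "max_iter": 3})
+data = resp.json()
+print(data["result"])
+for msg in data["messages"]:
+    print(f"[{msg['type']}] {msg['content']}")
+
+# The message buffer can also be fetched separately, then cleared.
+print(requests.get(f"{BASE}/messages/").json())
+requests.post(f"{BASE}/clear-messages/")
+```
+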
+ """ + try: + message_stream = get_message_stream() + message_stream.clear_messages() + return { + "message": "Messages cleared successfully" } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/test_iteration.py b/test_iteration.py new file mode 100644 index 0000000..d27eeb7 --- /dev/null +++ b/test_iteration.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +""" +测试迭代逻辑的脚本 +""" + +def test_iteration_logic(): + """测试迭代逻辑""" + max_iter = 3 + + print(f"测试最大迭代次数: {max_iter}") + print("-" * 40) + + for it in range(max_iter): + print(f">> Iteration: {it + 1}") + + # 模拟搜索任务 + print(" 执行搜索任务...") + + # 检查是否达到最大迭代次数 + if it + 1 < max_iter: + print(" 反思并生成新的子查询...") + print(" 准备下一次迭代") + else: + print(" 达到最大迭代次数,退出") + break + + print("-" * 40) + print("测试完成!") + +if __name__ == "__main__": + test_iteration_logic()