Browse Source

feat: 添加查询中前端消息流输出

main
tanxing 5 days ago
parent
commit
b4303f8edf
  1. 4
      deepsearcher/agent/deep_search.py
  2. 223
      deepsearcher/backend/templates/index.html
  3. 189
      main.py
  4. 31
      test_iteration.py

4
deepsearcher/agent/deep_search.py

@ -325,7 +325,7 @@ class DeepSearch(BaseAgent):
"""
# Get max_iter from kwargs or use default
max_iter = kwargs.get('max_iter', self.max_iter)
### SUB QUERIES ###
send_think(f"<query> {original_query} </query>")
all_search_results = []
@ -353,7 +353,7 @@ class DeepSearch(BaseAgent):
send_search(f"Remove {undeduped_len - deduped_len} duplicates")
# search_res_from_internet = deduplicate_results(search_res_from_internet)
# all_search_res.extend(search_res_from_vectordb + search_res_from_internet)
### REFLECTION & GET MORE SUB QUERIES ###
# Only generate more queries if we haven't reached the maximum iterations
if it + 1 < max_iter:

223
deepsearcher/backend/templates/index.html

@ -330,6 +330,10 @@
</div>
<script>
// 全局变量
let eventSource = null;
let isStreaming = false;
// 工具函数:显示状态信息
function showStatus(elementId, message, type) {
const statusElement = document.getElementById(elementId);
@ -355,26 +359,50 @@
container.innerHTML = '';
messages.forEach(message => {
const messageElement = document.createElement('div');
messageElement.className = `message message-${message.type}`;
const contentElement = document.createElement('div');
contentElement.textContent = message.content;
const timestampElement = document.createElement('div');
timestampElement.className = 'message-timestamp';
const date = new Date(message.timestamp * 1000);
timestampElement.textContent = date.toLocaleTimeString();
messageElement.appendChild(contentElement);
messageElement.appendChild(timestampElement);
container.appendChild(messageElement);
addMessageToContainer(message);
});
// 滚动到底部
container.scrollTop = container.scrollHeight;
}
// 工具函数:添加单个消息到容器
function addMessageToContainer(message) {
console.log('Adding message to container:', message);
const container = document.getElementById('messageContainer');
if (!container) {
console.error('Message container not found!');
return;
}
const messageElement = document.createElement('div');
messageElement.className = `message message-${message.type}`;
const contentElement = document.createElement('div');
contentElement.textContent = message.content;
const timestampElement = document.createElement('div');
timestampElement.className = 'message-timestamp';
const date = new Date(message.timestamp * 1000);
timestampElement.textContent = date.toLocaleTimeString();
messageElement.appendChild(contentElement);
messageElement.appendChild(timestampElement);
container.appendChild(messageElement);
// 确保结果容器是可见的
const resultContainer = document.getElementById('queryResult');
if (resultContainer && !resultContainer.classList.contains('visible')) {
resultContainer.classList.add('visible');
}
// 滚动到底部
container.scrollTop = container.scrollHeight;
console.log('Message added successfully, container now has', container.children.length, 'messages');
}
// 工具函数:隐藏状态信息
function hideStatus(elementId) {
const statusElement = document.getElementById(elementId);
@ -418,6 +446,105 @@
}
}
// 工具函数:关闭EventSource连接
function closeEventSource() {
if (eventSource) {
eventSource.close();
eventSource = null;
}
isStreaming = false;
}
// 工具函数:处理实时消息流
function handleStreamMessage(data) {
try {
const message = JSON.parse(data);
switch (message.type) {
case 'connection':
console.log('Connected to message stream:', message.message);
break;
case 'heartbeat':
// 心跳消息,不需要处理
break;
case 'query_start':
console.log('Query started:', message.query);
showStatus('queryStatus', '查询已开始,正在处理...', 'loading');
break;
case 'query_complete':
console.log('Query completed');
showStatus('queryStatus', '查询完成', 'success');
if (message.result) {
document.getElementById('resultText').textContent = message.result;
showResult();
}
// 关闭EventSource连接
if (window.currentEventSource) {
window.currentEventSource.close();
window.currentEventSource = null;
}
isStreaming = false;
setButtonLoading(document.getElementById('queryBtn'), false);
break;
case 'query_error':
console.error('Query error:', message.error);
showStatus('queryStatus', `查询失败: ${message.error}`, 'error');
// 关闭EventSource连接
if (window.currentEventSource) {
window.currentEventSource.close();
window.currentEventSource = null;
}
isStreaming = false;
setButtonLoading(document.getElementById('queryBtn'), false);
break;
case 'stream_error':
console.error('Stream error:', message.error);
showStatus('queryStatus', `流错误: ${message.error}`, 'error');
// 关闭EventSource连接
if (window.currentEventSource) {
window.currentEventSource.close();
window.currentEventSource = null;
}
isStreaming = false;
setButtonLoading(document.getElementById('queryBtn'), false);
break;
case 'search':
case 'think':
case 'answer':
// 处理常规消息
console.log('Processing message type:', message.type, 'with content:', message.content.substring(0, 100) + '...');
addMessageToContainer(message);
break;
default:
console.log('Unknown message type:', message.type);
}
} catch (error) {
console.error('Error parsing message:', error);
}
}
// 工具函数:开始实时消息流
function startMessageStream() {
closeEventSource(); // 关闭之前的连接
eventSource = new EventSource('/stream-messages/');
eventSource.onopen = function(event) {
console.log('EventSource connection opened');
};
eventSource.onmessage = function(event) {
handleStreamMessage(event.data);
};
eventSource.onerror = function(event) {
console.error('EventSource error:', event);
if (eventSource.readyState === EventSource.CLOSED) {
console.log('EventSource connection closed');
}
};
}
// 加载文件功能
document.getElementById('loadFilesBtn').addEventListener('click', async function() {
const button = this;
@ -530,7 +657,7 @@
}
});
// 查询功能
// 查询功能 - 使用实时流
document.getElementById('queryBtn').addEventListener('click', async function() {
const button = this;
const queryText = document.getElementById('queryText').value;
@ -546,39 +673,63 @@
return;
}
if (isStreaming) {
showStatus('queryStatus', '查询正在进行中,请等待完成', 'error');
return;
}
setButtonLoading(button, true);
showStatus('queryStatus', '正在处理查询...', 'loading');
showStatus('queryStatus', '正在启动查询...', 'loading');
hideResult();
// 清空消息容器
const container = document.getElementById('messageContainer');
container.innerHTML = '';
try {
const response = await fetch(`/query/?original_query=${encodeURIComponent(queryText)}&max_iter=${maxIter}`, {
method: 'GET',
headers: {
'Content-Type': 'application/json'
}
});
isStreaming = true;
const data = await response.json();
// 使用EventSource直接连接到查询流
const eventSource = new EventSource(`/query-stream/?original_query=${encodeURIComponent(queryText)}&max_iter=${maxIter}`);
if (response.ok) {
showStatus('queryStatus', '查询完成', 'success');
document.getElementById('resultText').textContent = data.result;
// 显示消息流
if (data.messages && data.messages.length > 0) {
displayMessages(data.messages);
// 保存EventSource引用以便后续关闭
window.currentEventSource = eventSource;
eventSource.onopen = function(event) {
console.log('EventSource connection opened for query');
showStatus('queryStatus', '查询已开始,正在处理...', 'loading');
};
eventSource.onmessage = function(event) {
console.log('Received message:', event.data);
handleStreamMessage(event.data);
};
eventSource.onerror = function(event) {
console.error('EventSource error:', event);
if (eventSource.readyState === EventSource.CLOSED) {
console.log('EventSource connection closed');
isStreaming = false;
setButtonLoading(button, false);
window.currentEventSource = null;
}
showResult();
} else {
showStatus('queryStatus', `查询失败: ${data.detail}`, 'error');
}
};
} catch (error) {
console.error('Query error:', error);
showStatus('queryStatus', `请求失败: ${error.message}`, 'error');
} finally {
isStreaming = false;
setButtonLoading(button, false);
}
});
// 页面卸载时清理连接
window.addEventListener('beforeunload', function() {
if (window.currentEventSource) {
window.currentEventSource.close();
window.currentEventSource = null;
}
});
</script>
</body>
</html>

189
main.py

@ -3,9 +3,13 @@ import argparse
import uvicorn
from fastapi import Body, FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel
import os
import asyncio
import json
import queue
from typing import AsyncGenerator
from deepsearcher.configuration import Configuration, init_config
from deepsearcher.offline_loading import load_from_local_files, load_from_website
@ -18,6 +22,9 @@ config = Configuration()
init_config(config)
# 全局变量用于存储消息流回调
message_callbacks = []
class ProviderConfigRequest(BaseModel):
"""
@ -232,6 +239,118 @@ def perform_query(
raise HTTPException(status_code=500, detail=str(e))
@app.get("/query-stream/")
async def perform_query_stream(
original_query: str = Query(
...,
description="Your question here.",
examples=["Write a report about Milvus."],
),
max_iter: int = Query(
3,
description="The maximum number of iterations for reflection.",
ge=1,
examples=[3],
),
) -> StreamingResponse:
"""
Perform a query with real-time streaming of progress and results.
Args:
original_query (str): The user's question or query.
max_iter (int, optional): Maximum number of iterations for reflection. Defaults to 3.
Returns:
StreamingResponse: A streaming response with real-time query progress and results.
"""
async def query_stream_generator() -> AsyncGenerator[str, None]:
"""生成查询流"""
message_callback = None
try:
# 清空之前的消息
message_stream = get_message_stream()
message_stream.clear_messages()
# 发送查询开始消息
yield f"data: {json.dumps({'type': 'query_start', 'query': original_query, 'max_iter': max_iter}, ensure_ascii=False)}\n\n"
# 创建一个线程安全的队列来接收消息
message_queue = queue.Queue()
def message_callback(message):
"""消息回调函数"""
try:
# 直接放入线程安全的队列
message_queue.put(message)
except Exception as e:
print(f"Error in message callback: {e}")
# 注册回调函数
message_stream.add_callback(message_callback)
# 在后台线程中执行查询
def run_query():
try:
result_text, _ = query(original_query, max_iter)
return result_text, None
except Exception as e:
return None, str(e)
# 使用线程池执行查询
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
query_future = executor.submit(run_query)
# 监听消息和查询结果
while True:
try:
# 检查查询是否已完成
if query_future.done():
result_text, error = query_future.result()
if error:
yield f"data: {json.dumps({'type': 'query_error', 'error': error}, ensure_ascii=False)}\n\n"
else:
yield f"data: {json.dumps({'type': 'query_complete', 'result': result_text}, ensure_ascii=False)}\n\n"
return
# 尝试从队列获取消息,设置超时
try:
message = message_queue.get(timeout=0.5)
message_data = {
'type': message.type.value,
'content': message.content,
'timestamp': message.timestamp,
'metadata': message.metadata or {}
}
yield f"data: {json.dumps(message_data, ensure_ascii=False)}\n\n"
except queue.Empty:
# 超时,继续循环
continue
except Exception as e:
print(f"Error in stream loop: {e}")
yield f"data: {json.dumps({'type': 'stream_error', 'error': str(e)}, ensure_ascii=False)}\n\n"
return
except Exception as e:
yield f"data: {json.dumps({'type': 'stream_error', 'error': str(e)}, ensure_ascii=False)}\n\n"
finally:
# 清理回调函数
if message_callback:
message_stream.remove_callback(message_callback)
return StreamingResponse(
query_stream_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "Cache-Control"
}
)
@app.get("/messages/")
def get_messages():
"""
@ -249,6 +368,74 @@ def get_messages():
raise HTTPException(status_code=500, detail=str(e))
@app.get("/stream-messages/")
async def stream_messages() -> StreamingResponse:
"""
Stream messages in real-time using Server-Sent Events.
Returns:
StreamingResponse: A streaming response with real-time messages.
"""
async def event_generator() -> AsyncGenerator[str, None]:
"""生成SSE事件流"""
# 创建一个队列来接收消息
message_queue = asyncio.Queue()
def message_callback(message):
"""消息回调函数"""
try:
# 将消息放入队列
asyncio.create_task(message_queue.put(message))
except Exception as e:
print(f"Error in message callback: {e}")
# 注册回调函数
message_callbacks.append(message_callback)
message_stream = get_message_stream()
message_stream.add_callback(message_callback)
try:
# 发送连接建立消息
yield f"data: {json.dumps({'type': 'connection', 'message': 'Connected to message stream'}, ensure_ascii=False)}\n\n"
while True:
try:
# 等待新消息,设置超时以便能够检查连接状态
message = await asyncio.wait_for(message_queue.get(), timeout=1.0)
# 发送消息数据
message_data = {
'type': message.type.value,
'content': message.content,
'timestamp': message.timestamp,
'metadata': message.metadata or {}
}
yield f"data: {json.dumps(message_data, ensure_ascii=False)}\n\n"
except asyncio.TimeoutError:
# 发送心跳保持连接
yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': asyncio.get_event_loop().time()}, ensure_ascii=False)}\n\n"
except Exception as e:
print(f"Error in event generator: {e}")
finally:
# 清理回调函数
if message_callback in message_callbacks:
message_callbacks.remove(message_callback)
message_stream.remove_callback(message_callback)
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "Cache-Control"
}
)
@app.post("/clear-messages/")
def clear_messages():
"""

31
test_iteration.py

@ -1,31 +0,0 @@
#!/usr/bin/env python3
"""
测试迭代逻辑的脚本
"""
def test_iteration_logic():
"""测试迭代逻辑"""
max_iter = 3
print(f"测试最大迭代次数: {max_iter}")
print("-" * 40)
for it in range(max_iter):
print(f">> Iteration: {it + 1}")
# 模拟搜索任务
print(" 执行搜索任务...")
# 检查是否达到最大迭代次数
if it + 1 < max_iter:
print(" 反思并生成新的子查询...")
print(" 准备下一次迭代")
else:
print(" 达到最大迭代次数,退出")
break
print("-" * 40)
print("测试完成!")
if __name__ == "__main__":
test_iteration_logic()
Loading…
Cancel
Save