diff --git a/deepsearcher/agent/deep_search.py b/deepsearcher/agent/deep_search.py index 02565da..9d90a5a 100644 --- a/deepsearcher/agent/deep_search.py +++ b/deepsearcher/agent/deep_search.py @@ -325,7 +325,7 @@ class DeepSearch(BaseAgent): """ # Get max_iter from kwargs or use default max_iter = kwargs.get('max_iter', self.max_iter) - + ### SUB QUERIES ### send_think(f" {original_query} ") all_search_results = [] @@ -353,7 +353,7 @@ class DeepSearch(BaseAgent): send_search(f"Remove {undeduped_len - deduped_len} duplicates") # search_res_from_internet = deduplicate_results(search_res_from_internet) # all_search_res.extend(search_res_from_vectordb + search_res_from_internet) - + ### REFLECTION & GET MORE SUB QUERIES ### # Only generate more queries if we haven't reached the maximum iterations if it + 1 < max_iter: diff --git a/deepsearcher/backend/templates/index.html b/deepsearcher/backend/templates/index.html index 56bb73e..6689d51 100644 --- a/deepsearcher/backend/templates/index.html +++ b/deepsearcher/backend/templates/index.html @@ -330,6 +330,10 @@ \ No newline at end of file diff --git a/main.py b/main.py index 999df08..bfb5f11 100644 --- a/main.py +++ b/main.py @@ -3,9 +3,13 @@ import argparse import uvicorn from fastapi import Body, FastAPI, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import HTMLResponse +from fastapi.responses import HTMLResponse, StreamingResponse from pydantic import BaseModel import os +import asyncio +import json +import queue +from typing import AsyncGenerator from deepsearcher.configuration import Configuration, init_config from deepsearcher.offline_loading import load_from_local_files, load_from_website @@ -18,6 +22,9 @@ config = Configuration() init_config(config) +# 全局变量用于存储消息流回调 +message_callbacks = [] + class ProviderConfigRequest(BaseModel): """ @@ -232,6 +239,118 @@ def perform_query( raise HTTPException(status_code=500, detail=str(e)) +@app.get("/query-stream/") +async def perform_query_stream( + original_query: str = Query( + ..., + description="Your question here.", + examples=["Write a report about Milvus."], + ), + max_iter: int = Query( + 3, + description="The maximum number of iterations for reflection.", + ge=1, + examples=[3], + ), +) -> StreamingResponse: + """ + Perform a query with real-time streaming of progress and results. + + Args: + original_query (str): The user's question or query. + max_iter (int, optional): Maximum number of iterations for reflection. Defaults to 3. + + Returns: + StreamingResponse: A streaming response with real-time query progress and results. + """ + async def query_stream_generator() -> AsyncGenerator[str, None]: + """生成查询流""" + message_callback = None + try: + # 清空之前的消息 + message_stream = get_message_stream() + message_stream.clear_messages() + + # 发送查询开始消息 + yield f"data: {json.dumps({'type': 'query_start', 'query': original_query, 'max_iter': max_iter}, ensure_ascii=False)}\n\n" + + # 创建一个线程安全的队列来接收消息 + message_queue = queue.Queue() + + def message_callback(message): + """消息回调函数""" + try: + # 直接放入线程安全的队列 + message_queue.put(message) + except Exception as e: + print(f"Error in message callback: {e}") + + # 注册回调函数 + message_stream.add_callback(message_callback) + + # 在后台线程中执行查询 + def run_query(): + try: + result_text, _ = query(original_query, max_iter) + return result_text, None + except Exception as e: + return None, str(e) + + # 使用线程池执行查询 + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + query_future = executor.submit(run_query) + + # 监听消息和查询结果 + while True: + try: + # 检查查询是否已完成 + if query_future.done(): + result_text, error = query_future.result() + if error: + yield f"data: {json.dumps({'type': 'query_error', 'error': error}, ensure_ascii=False)}\n\n" + else: + yield f"data: {json.dumps({'type': 'query_complete', 'result': result_text}, ensure_ascii=False)}\n\n" + return + + # 尝试从队列获取消息,设置超时 + try: + message = message_queue.get(timeout=0.5) + message_data = { + 'type': message.type.value, + 'content': message.content, + 'timestamp': message.timestamp, + 'metadata': message.metadata or {} + } + yield f"data: {json.dumps(message_data, ensure_ascii=False)}\n\n" + except queue.Empty: + # 超时,继续循环 + continue + + except Exception as e: + print(f"Error in stream loop: {e}") + yield f"data: {json.dumps({'type': 'stream_error', 'error': str(e)}, ensure_ascii=False)}\n\n" + return + + except Exception as e: + yield f"data: {json.dumps({'type': 'stream_error', 'error': str(e)}, ensure_ascii=False)}\n\n" + finally: + # 清理回调函数 + if message_callback: + message_stream.remove_callback(message_callback) + + return StreamingResponse( + query_stream_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": "Cache-Control" + } + ) + + @app.get("/messages/") def get_messages(): """ @@ -249,6 +368,74 @@ def get_messages(): raise HTTPException(status_code=500, detail=str(e)) +@app.get("/stream-messages/") +async def stream_messages() -> StreamingResponse: + """ + Stream messages in real-time using Server-Sent Events. + + Returns: + StreamingResponse: A streaming response with real-time messages. + """ + async def event_generator() -> AsyncGenerator[str, None]: + """生成SSE事件流""" + # 创建一个队列来接收消息 + message_queue = asyncio.Queue() + + def message_callback(message): + """消息回调函数""" + try: + # 将消息放入队列 + asyncio.create_task(message_queue.put(message)) + except Exception as e: + print(f"Error in message callback: {e}") + + # 注册回调函数 + message_callbacks.append(message_callback) + message_stream = get_message_stream() + message_stream.add_callback(message_callback) + + try: + # 发送连接建立消息 + yield f"data: {json.dumps({'type': 'connection', 'message': 'Connected to message stream'}, ensure_ascii=False)}\n\n" + + while True: + try: + # 等待新消息,设置超时以便能够检查连接状态 + message = await asyncio.wait_for(message_queue.get(), timeout=1.0) + + # 发送消息数据 + message_data = { + 'type': message.type.value, + 'content': message.content, + 'timestamp': message.timestamp, + 'metadata': message.metadata or {} + } + yield f"data: {json.dumps(message_data, ensure_ascii=False)}\n\n" + + except asyncio.TimeoutError: + # 发送心跳保持连接 + yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': asyncio.get_event_loop().time()}, ensure_ascii=False)}\n\n" + + except Exception as e: + print(f"Error in event generator: {e}") + finally: + # 清理回调函数 + if message_callback in message_callbacks: + message_callbacks.remove(message_callback) + message_stream.remove_callback(message_callback) + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": "Cache-Control" + } + ) + + @app.post("/clear-messages/") def clear_messages(): """ diff --git a/test_iteration.py b/test_iteration.py deleted file mode 100644 index d27eeb7..0000000 --- a/test_iteration.py +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env python3 -""" -测试迭代逻辑的脚本 -""" - -def test_iteration_logic(): - """测试迭代逻辑""" - max_iter = 3 - - print(f"测试最大迭代次数: {max_iter}") - print("-" * 40) - - for it in range(max_iter): - print(f">> Iteration: {it + 1}") - - # 模拟搜索任务 - print(" 执行搜索任务...") - - # 检查是否达到最大迭代次数 - if it + 1 < max_iter: - print(" 反思并生成新的子查询...") - print(" 准备下一次迭代") - else: - print(" 达到最大迭代次数,退出") - break - - print("-" * 40) - print("测试完成!") - -if __name__ == "__main__": - test_iteration_logic()