@@ -3,9 +3,13 @@ import argparse
import uvicorn
from fastapi import Body, FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel
import os
import asyncio
import json
import queue
from typing import AsyncGenerator

from deepsearcher.configuration import Configuration, init_config
from deepsearcher.offline_loading import load_from_local_files, load_from_website
@@ -18,6 +22,9 @@ config = Configuration()

init_config(config)

# Global variable for storing message-stream callbacks.
message_callbacks = []


class ProviderConfigRequest(BaseModel):
    """
@@ -98,6 +105,11 @@ def load_files(
        description="Optional batch size for the collection.",
        examples=[256],
    ),
    force_rebuild: bool = Body(
        False,
        description="Whether to force rebuild the collection if it already exists.",
        examples=[False],
    ),
):
    """
    Load files into the vector database.
@@ -107,6 +119,7 @@ def load_files(
        collection_name (str, optional): Name for the collection. Defaults to None.
        collection_description (str, optional): Description for the collection. Defaults to None.
        batch_size (int, optional): Batch size for processing. Defaults to None.
        force_rebuild (bool, optional): Whether to force rebuild the collection. Defaults to False.

    Returns:
        dict: A dictionary containing a success message.
@@ -120,6 +133,7 @@ def load_files(
            collection_name=collection_name,
            collection_description=collection_description,
            batch_size=batch_size if batch_size is not None else 8,
            force_rebuild=force_rebuild,
        )
        return {"message": "Files loaded successfully."}
    except Exception as e:
@@ -148,6 +162,11 @@ def load_website(
        description="Optional batch size for the collection.",
        examples=[256],
    ),
    force_rebuild: bool = Body(
        False,
        description="Whether to force rebuild the collection if it already exists.",
        examples=[False],
    ),
):
    """
    Load website content into the vector database.
@@ -157,6 +176,7 @@ def load_website(
        collection_name (str, optional): Name for the collection. Defaults to None.
        collection_description (str, optional): Description for the collection. Defaults to None.
        batch_size (int, optional): Batch size for processing. Defaults to None.
        force_rebuild (bool, optional): Whether to force rebuild the collection. Defaults to False.

    Returns:
        dict: A dictionary containing a success message.
@@ -170,6 +190,7 @@ def load_website(
            collection_name=collection_name,
            collection_description=collection_description,
            batch_size=batch_size if batch_size is not None else 8,
            force_rebuild=force_rebuild,
        )
        return {"message": "Website loaded successfully."}
    except Exception as e:
@@ -218,6 +239,125 @@ def perform_query(
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/query-stream/")
async def perform_query_stream(
    original_query: str = Query(
        ...,
        description="Your question here.",
        examples=["Write a report about Milvus."],
    ),
    max_iter: int = Query(
        3,
        description="The maximum number of iterations for reflection.",
        ge=1,
        examples=[3],
    ),
) -> StreamingResponse:
    """
    Perform a query with real-time streaming of progress and results.

    Args:
        original_query (str): The user's question or query.
        max_iter (int, optional): Maximum number of iterations for reflection. Defaults to 3.

    Returns:
        StreamingResponse: A streaming response with real-time query progress and results.
    """
    async def query_stream_generator() -> AsyncGenerator[str, None]:
        """Generate the query event stream."""
        message_callback = None
        try:
            # Clear previous messages.
            message_stream = get_message_stream()
            message_stream.clear_messages()

            # Send the query-start message.
            yield f"data: {json.dumps({'type': 'query_start', 'query': original_query, 'max_iter': max_iter}, ensure_ascii=False)}\n\n"

            # Create a thread-safe queue to receive messages.
            message_queue = queue.Queue()

            def message_callback(message):
                """Message callback."""
                try:
                    # Put the message directly into the thread-safe queue.
                    message_queue.put(message)
                except Exception as e:
                    print(f"Error in message callback: {e}")

            # Register the callback.
            message_stream.add_callback(message_callback)

            # Execute the query in a background thread.
            def run_query():
                try:
                    print(f"Starting query: {original_query} with max_iter: {max_iter}")
                    result_text, _ = query(original_query, max_iter)
                    print(f"Query completed with result length: {len(result_text) if result_text else 0}")
                    return result_text, None
                except Exception as e:
                    print(f"Query failed with error: {e}")
                    return None, str(e)

            # Run the query in a thread pool.
            import concurrent.futures
            with concurrent.futures.ThreadPoolExecutor() as executor:
                query_future = executor.submit(run_query)
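                # The blocking query() call runs in the worker thread above, so this
                # coroutine can keep draining message_queue and yielding SSE events
                # while the search is still in progress.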

                # Listen for messages and the query result.
                while True:
                    try:
                        # Check whether the query has finished.
                        if query_future.done():
                            result_text, error = query_future.result()
                            if error:
                                print(f"Query error: {error}")
                                yield f"data: {json.dumps({'type': 'query_error', 'error': error}, ensure_ascii=False)}\n\n"
                                return
                            # On success, send the complete message and end the stream.
                            print("Query completed successfully, sending complete message")
                            yield f"data: {json.dumps({'type': 'complete'}, ensure_ascii=False)}\n\n"
                            print("Complete message sent, ending stream")
                            return

                        # Try to get a message from the queue, with a timeout.
                        try:
                            message = message_queue.get(timeout=0.5)
                            message_data = {
                                'type': message.type.value,
                                'content': message.content,
                                'timestamp': message.timestamp,
                                'metadata': message.metadata or {}
                            }
                            yield f"data: {json.dumps(message_data, ensure_ascii=False)}\n\n"
                        except queue.Empty:
                            # Timed out; continue the loop.
                            continue

                    except Exception as e:
                        print(f"Error in stream loop: {e}")
                        yield f"data: {json.dumps({'type': 'stream_error', 'error': str(e)}, ensure_ascii=False)}\n\n"
                        return

        except Exception as e:
            yield f"data: {json.dumps({'type': 'stream_error', 'error': str(e)}, ensure_ascii=False)}\n\n"
        finally:
            # Clean up the callback.
            if message_callback:
                message_stream.remove_callback(message_callback)

    return StreamingResponse(
        query_stream_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control"
        }
    )
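
# A minimal way to try the streaming endpoint above (assuming the API runs on
# the default uvicorn host/port, e.g. http://localhost:8000):
#   curl -N "http://localhost:8000/query-stream/?original_query=Write%20a%20report%20about%20Milvus.&max_iter=3"
# Each "data:" line carries a JSON payload whose "type" field is query_start,
# one of the message types emitted during the query, complete, query_error,
# or stream_error.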


@app.get("/messages/")
def get_messages():
    """
@@ -235,6 +375,74 @@ def get_messages():
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/stream-messages/")
async def stream_messages() -> StreamingResponse:
    """
    Stream messages in real-time using Server-Sent Events.

    Returns:
        StreamingResponse: A streaming response with real-time messages.
    """
    async def event_generator() -> AsyncGenerator[str, None]:
        """Generate the SSE event stream."""
        # Create a queue to receive messages.
        message_queue = asyncio.Queue()

        def message_callback(message):
            """Message callback."""
            try:
                # Put the message into the queue.
                asyncio.create_task(message_queue.put(message))
            except Exception as e:
                print(f"Error in message callback: {e}")

        # Register the callback.
        message_callbacks.append(message_callback)
        message_stream = get_message_stream()
        message_stream.add_callback(message_callback)

        try:
            # Send a connection-established message.
            yield f"data: {json.dumps({'type': 'connection', 'message': 'Connected to message stream'}, ensure_ascii=False)}\n\n"

            while True:
                try:
                    # Wait for a new message, with a timeout so the connection state can be checked.
                    message = await asyncio.wait_for(message_queue.get(), timeout=1.0)

                    # Send the message data.
                    message_data = {
                        'type': message.type.value,
                        'content': message.content,
                        'timestamp': message.timestamp,
                        'metadata': message.metadata or {}
                    }
                    yield f"data: {json.dumps(message_data, ensure_ascii=False)}\n\n"

                except asyncio.TimeoutError:
                    # Send a heartbeat to keep the connection alive.
                    yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': asyncio.get_event_loop().time()}, ensure_ascii=False)}\n\n"

        except Exception as e:
            print(f"Error in event generator: {e}")
        finally:
            # Clean up the callback.
            if message_callback in message_callbacks:
                message_callbacks.remove(message_callback)
            message_stream.remove_callback(message_callback)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control"
        }
    )
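
# A minimal way to subscribe to the message stream above (same host/port
# assumption as for /query-stream/):
#   curl -N "http://localhost:8000/stream-messages/"
# The stream begins with a "connection" event, then emits message events as
# they arrive, plus periodic "heartbeat" events (roughly every second) while idle.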


@app.post("/clear-messages/")
def clear_messages():
    """