import argparse
import asyncio
import json
import os
import queue
from typing import AsyncGenerator

import uvicorn
from fastapi import Body, FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from deepsearcher.configuration import Configuration, init_config
from deepsearcher.offline_loading import load_from_local_files, load_from_website
from deepsearcher.online_query import query
from deepsearcher.utils.message_stream import get_message_stream

app = FastAPI()

# Configure static file serving.
current_dir = os.path.dirname(os.path.abspath(__file__))
static_dir = os.path.join(current_dir, "deepsearcher", "templates", "static")
app.mount("/static", StaticFiles(directory=static_dir), name="static")

config = Configuration()
init_config(config)

# Global list used to store message stream callbacks.
message_callbacks = []

class ProviderConfigRequest(BaseModel):
    """
    Request model for setting provider configuration.

    Attributes:
        feature (str): The feature to configure (e.g., 'embedding', 'llm').
        provider (str): The provider name (e.g., 'openai', 'azure').
        config (Dict): Configuration parameters for the provider.
    """

    feature: str
    provider: str
    config: dict

@app.get("/", response_class=HTMLResponse)
async def read_root():
    """
    Serve the main HTML page.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    template_path = os.path.join(current_dir, "deepsearcher", "templates", "html", "index.html")

    try:
        with open(template_path, encoding="utf-8") as file:
            html_content = file.read()
        return HTMLResponse(content=html_content, status_code=200)
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="Template file not found")

@app.post("/set-provider-config/")
def set_provider_config(request: ProviderConfigRequest):
    """
    Set configuration for a specific provider.

    Args:
        request (ProviderConfigRequest): The request containing provider configuration.

    Returns:
        dict: A dictionary containing a success message and the updated configuration.

    Raises:
        HTTPException: If setting the provider config fails.
    """
    try:
        config.set_provider_config(request.feature, request.provider, request.config)
        init_config(config)
        return {
            "message": "Provider config set successfully",
            "provider": request.provider,
            "config": request.config,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to set provider config: {str(e)}")

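# Illustrative client call for /set-provider-config/ above (kept as a comment so it
# is not executed when this module is imported). It assumes the server is reachable
# at http://localhost:8000 and that the `requests` package is installed; the provider
# name and config keys shown are placeholders, not values required by deepsearcher.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/set-provider-config/",
#       json={
#           "feature": "llm",
#           "provider": "openai",
#           "config": {"model": "gpt-4o-mini"},  # hypothetical config keys
#       },
#   )
#   print(resp.json())
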
@app.post("/load-files/")
def load_files(
    paths: str | list[str] = Body(
        ...,
        description="A list of file paths to be loaded.",
        examples=["/path/to/file1", "/path/to/file2", "/path/to/dir1"],
    ),
    collection_name: str = Body(
        None,
        description="Optional name for the collection.",
        examples=["my_collection"],
    ),
    collection_description: str = Body(
        None,
        description="Optional description for the collection.",
        examples=["This is a test collection."],
    ),
    batch_size: int = Body(
        None,
        description="Optional batch size for the collection.",
        examples=[256],
    ),
    force_rebuild: bool = Body(
        True,
        description="Whether to force rebuild the collection if it already exists.",
        examples=[False],
    ),
):
    """
    Load files into the vector database.

    Args:
        paths (Union[str, List[str]]): File paths or directories to load.
        collection_name (str, optional): Name for the collection. Defaults to None.
        collection_description (str, optional): Description for the collection. Defaults to None.
        batch_size (int, optional): Batch size for processing. Defaults to None.
        force_rebuild (bool, optional): Whether to force rebuild the collection. Defaults to True.

    Returns:
        dict: A dictionary containing a success message.

    Raises:
        HTTPException: If loading files fails.
    """
    try:
        load_from_local_files(
            paths_or_directory=paths,
            collection_name=collection_name,
            collection_description=collection_description,
            batch_size=batch_size if batch_size is not None else 8,
            force_rebuild=force_rebuild,
        )
        return {"message": "Successfully loaded"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

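# Illustrative client call for /load-files/ (commented out so it does not run at
# import time). Host, port, and the example path are assumptions, not values shipped
# with this module.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/load-files/",
#       json={
#           "paths": ["/path/to/docs"],  # hypothetical file(s) or directory to ingest
#           "collection_name": "my_collection",
#           "batch_size": 256,
#           "force_rebuild": False,
#       },
#   )
#   print(resp.json())
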
@app.post("/load-website/")
def load_website(
    urls: str | list[str] = Body(
        ...,
        description="A list of URLs of websites to be loaded.",
        examples=["https://milvus.io/docs/overview.md"],
    ),
    collection_name: str = Body(
        None,
        description="Optional name for the collection.",
        examples=["my_collection"],
    ),
    collection_description: str = Body(
        None,
        description="Optional description for the collection.",
        examples=["This is a test collection."],
    ),
    batch_size: int = Body(
        None,
        description="Optional batch size for the collection.",
        examples=[256],
    ),
    force_rebuild: bool = Body(
        False,
        description="Whether to force rebuild the collection if it already exists.",
        examples=[False],
    ),
):
    """
    Load website content into the vector database.

    Args:
        urls (Union[str, List[str]]): URLs of websites to load.
        collection_name (str, optional): Name for the collection. Defaults to None.
        collection_description (str, optional): Description for the collection. Defaults to None.
        batch_size (int, optional): Batch size for processing. Defaults to None.
        force_rebuild (bool, optional): Whether to force rebuild the collection. Defaults to False.

    Returns:
        dict: A dictionary containing a success message.

    Raises:
        HTTPException: If loading website content fails.
    """
    try:
        load_from_website(
            urls=urls,
            collection_name=collection_name,
            collection_description=collection_description,
            batch_size=batch_size if batch_size is not None else 8,
            force_rebuild=force_rebuild,
        )
        return {"message": "Successfully loaded"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/query-stream/")
async def perform_query_stream(
    original_query: str = Query(
        ...,
        description="Your question here.",
        examples=["Write a report about Milvus."],
    ),
    max_iter: int = Query(
        3,
        description="The maximum number of iterations for reflection.",
        ge=1,
        examples=[3],
    ),
) -> StreamingResponse:
    """
    Perform a query with real-time streaming of progress and results.

    Args:
        original_query (str): The user's question or query.
        max_iter (int, optional): Maximum number of iterations for reflection. Defaults to 3.

    Returns:
        StreamingResponse: A streaming response with real-time query progress and results.
    """

    async def query_stream_generator() -> AsyncGenerator[str, None]:
        """Generate the query event stream."""
        message_callback = None
        try:
            # Clear any messages left over from previous queries.
            message_stream = get_message_stream()
            message_stream.clear_messages()

            # Send the "query started" event.
            yield f"data: {json.dumps({'type': 'start', 'content': 'Query started'}, ensure_ascii=False)}\n\n"

            # Thread-safe queue that receives messages from the worker thread.
            message_queue = queue.Queue()

            def message_callback(message):
                """Message callback: push each message onto the thread-safe queue."""
                try:
                    message_queue.put(message)
                except Exception as e:
                    print(f"Error in message callback: {e}")

            # Register the callback on the message stream.
            message_stream.add_callback(message_callback)

            # Run the query in a background thread.
            def run_query():
                try:
                    print(f"Starting query: {original_query} with max_iter: {max_iter}")
                    result_text, retrieval_results = query(original_query, max_iter=max_iter)
                    print(f"Query completed with result length: {len(result_text) if result_text else 0}")
                    print(f"Retrieved {len(retrieval_results) if retrieval_results else 0} documents")
                    return result_text, None
                except Exception as e:
                    import traceback

                    print(f"Query failed with error: {e}")
                    print(f"Traceback: {traceback.format_exc()}")
                    return None, str(e)

            # Execute the query on a thread pool.
            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor() as executor:
                query_future = executor.submit(run_query)

                # Forward stream messages while waiting for the query result.
                while True:
                    try:
                        # Check whether the query has finished.
                        if query_future.done():
                            result_text, error = query_future.result()
                            if error:
                                print(f"Query error: {error}")
                                yield f"data: {json.dumps({'type': 'error', 'content': f'Query failed: {error}'}, ensure_ascii=False)}\n\n"
                                return
                            # On success, first send the answer event with the query result.
                            if result_text:
                                print(f"Sending answer message with result length: {len(result_text)}")
                                yield f"data: {json.dumps({'type': 'answer', 'content': result_text}, ensure_ascii=False)}\n\n"
                            # Then send the completion event and end the stream.
                            print("Query completed successfully, sending complete message")
                            yield f"data: {json.dumps({'type': 'complete', 'content': 'Query complete'}, ensure_ascii=False)}\n\n"
                            print("Complete message sent, ending stream")
                            return

                        # Try to fetch a message from the queue, with a timeout.
                        try:
                            message = message_queue.get(timeout=2.0)
                            message_data = {
                                "type": message.type.value,
                                "content": message.content,
                                "timestamp": message.timestamp,
                                "metadata": message.metadata or {},
                            }
                            yield f"data: {json.dumps(message_data, ensure_ascii=False)}\n\n"
                        except queue.Empty:
                            # Timed out; keep looping.
                            continue

                    except Exception as e:
                        print(f"Error in stream loop: {e}")
                        yield f"data: {json.dumps({'type': 'error', 'content': f'Stream processing error: {str(e)}'}, ensure_ascii=False)}\n\n"
                        return

        except Exception as e:
            yield f"data: {json.dumps({'type': 'error', 'content': f'Query error: {str(e)}'}, ensure_ascii=False)}\n\n"
        finally:
            # Unregister the callback.
            if message_callback:
                message_stream.remove_callback(message_callback)

    return StreamingResponse(
        query_stream_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control",
        },
    )

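# Illustrative consumer of the /query-stream/ SSE endpoint (left commented out so it
# is not executed when this module is imported). It assumes the server runs at
# http://localhost:8000 and that the `requests` package is available; each event
# arrives as a line of the form "data: {...json...}".
#
#   import json
#   import requests
#
#   with requests.get(
#       "http://localhost:8000/query-stream/",
#       params={"original_query": "Write a report about Milvus.", "max_iter": 3},
#       stream=True,
#   ) as resp:
#       for line in resp.iter_lines(decode_unicode=True):
#           if line and line.startswith("data: "):
#               event = json.loads(line[len("data: "):])
#               print(event["type"], event.get("content", ""))
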
@app.get("/messages/")
def get_messages():
    """
    Get all messages from the message stream.

    Returns:
        dict: A dictionary containing all messages.
    """
    try:
        message_stream = get_message_stream()
        return {"messages": message_stream.get_messages_as_dicts()}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/stream-messages/")
async def stream_messages() -> StreamingResponse:
    """
    Stream messages in real-time using Server-Sent Events.

    Returns:
        StreamingResponse: A streaming response with real-time messages.
    """

    async def event_generator() -> AsyncGenerator[str, None]:
        """Generate the SSE event stream."""
        # Queue that receives messages from the stream callbacks.
        message_queue = asyncio.Queue()
        loop = asyncio.get_running_loop()

        def message_callback(message):
            """Message callback: hand the message over to the event loop's queue.

            call_soon_threadsafe is used because callbacks may fire from worker
            threads that have no running event loop of their own.
            """
            try:
                loop.call_soon_threadsafe(message_queue.put_nowait, message)
            except Exception as e:
                print(f"Error in message callback: {e}")

        # Register the callback.
        message_callbacks.append(message_callback)
        message_stream = get_message_stream()
        message_stream.add_callback(message_callback)

        try:
            # Send the "connection established" event.
            yield f"data: {json.dumps({'type': 'connection', 'message': 'Connected to message stream'}, ensure_ascii=False)}\n\n"

            while True:
                try:
                    # Wait for the next message, with a timeout so the connection state can be checked.
                    message = await asyncio.wait_for(message_queue.get(), timeout=2.0)

                    # Send the message data.
                    message_data = {
                        "type": message.type.value,
                        "content": message.content,
                        "timestamp": message.timestamp,
                        "metadata": message.metadata or {},
                    }
                    yield f"data: {json.dumps(message_data, ensure_ascii=False)}\n\n"

                except asyncio.TimeoutError:
                    # Send a heartbeat to keep the connection alive.
                    yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': loop.time()}, ensure_ascii=False)}\n\n"

        except Exception as e:
            print(f"Error in event generator: {e}")
        finally:
            # Unregister the callback.
            if message_callback in message_callbacks:
                message_callbacks.remove(message_callback)
            message_stream.remove_callback(message_callback)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control",
        },
    )

@app.post("/clear-messages/")
def clear_messages():
    """
    Clear all messages from the message stream.

    Returns:
        dict: A dictionary containing a success message.
    """
    try:
        message_stream = get_message_stream()
        message_stream.clear_messages()
        return {"message": "Messages cleared successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FastAPI Server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to")
    parser.add_argument("--enable-cors", action="store_true", default=False, help="Enable CORS support")
    args = parser.parse_args()

    if args.enable_cors:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        print("CORS is enabled.")
    else:
        print("CORS is disabled.")

    print(f"Starting server on {args.host}:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port)
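
# Example launch command (illustrative only; the module filename below is an
# assumption, substitute the actual filename of this script):
#
#   python main.py --host 0.0.0.0 --port 8000 --enable-cors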