# FastAPI service entry point for DeepSearcher: exposes endpoints for
# provider configuration, document/website ingestion, and (streaming) queries.
import argparse
import uvicorn
from fastapi import Body, FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel
import os
import asyncio
import json
import queue
from typing import AsyncGenerator
from deepsearcher.configuration import Configuration, init_config
from deepsearcher.offline_loading import load_from_local_files, load_from_website
from deepsearcher.online_query import query
from deepsearcher.utils.message_stream import get_message_stream
app = FastAPI()
config = Configuration()
init_config(config)
# Global registry of message-stream callbacks; /stream-messages/ appends one
# callback per client connection and removes it on disconnect.
message_callbacks = []
class ProviderConfigRequest(BaseModel):
    """
    Payload schema for the ``/set-provider-config/`` endpoint.

    Attributes:
        feature (str): Which feature to configure (e.g. 'embedding', 'llm').
        provider (str): Name of the backing provider (e.g. 'openai', 'azure').
        config (dict): Provider-specific configuration parameters.
    """

    feature: str
    provider: str
    config: dict
@app.get("/", response_class=HTMLResponse)
async def read_root():
"""
Serve the main HTML page.
"""
current_dir = os.path.dirname(os.path.abspath(__file__))
template_path = os.path.join(current_dir, "deepsearcher", "backend", "templates", "index.html")
try:
with open(template_path, encoding="utf-8") as file:
html_content = file.read()
return HTMLResponse(content=html_content, status_code=200)
except FileNotFoundError:
raise HTTPException(status_code=404, detail="Template file not found")
@app.post("/set-provider-config/")
def set_provider_config(request: ProviderConfigRequest):
"""
Set configuration for a specific provider.
Args:
request (ProviderConfigRequest): The request containing provider configuration.
Returns:
dict: A dictionary containing a success message and the updated configuration.
Raises:
HTTPException: If setting the provider config fails.
"""
try:
config.set_provider_config(request.feature, request.provider, request.config)
init_config(config)
return {
"message": "Provider config set successfully",
"provider": request.provider,
"config": request.config,
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to set provider config: {str(e)}")
@app.post("/load-files/")
def load_files(
paths: str | list[str] = Body(
...,
description="A list of file paths to be loaded.",
examples=["/path/to/file1", "/path/to/file2", "/path/to/dir1"],
),
collection_name: str = Body(
None,
description="Optional name for the collection.",
examples=["my_collection"],
),
collection_description: str = Body(
None,
description="Optional description for the collection.",
examples=["This is a test collection."],
),
batch_size: int = Body(
None,
description="Optional batch size for the collection.",
examples=[256],
),
force_rebuild: bool = Body(
False,
description="Whether to force rebuild the collection if it already exists.",
examples=[False],
),
):
"""
Load files into the vector database.
Args:
paths (Union[str, List[str]]): File paths or directories to load.
collection_name (str, optional): Name for the collection. Defaults to None.
collection_description (str, optional): Description for the collection. Defaults to None.
batch_size (int, optional): Batch size for processing. Defaults to None.
force_rebuild (bool, optional): Whether to force rebuild the collection. Defaults to False.
Returns:
dict: A dictionary containing a success message.
Raises:
HTTPException: If loading files fails.
"""
try:
load_from_local_files(
paths_or_directory=paths,
collection_name=collection_name,
collection_description=collection_description,
batch_size=batch_size if batch_size is not None else 8,
force_rebuild=force_rebuild,
)
return {"message": "Files loaded successfully."}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/load-website/")
def load_website(
urls: str | list[str] = Body(
...,
description="A list of URLs of websites to be loaded.",
examples=["https://milvus.io/docs/overview.md"],
),
collection_name: str = Body(
None,
description="Optional name for the collection.",
examples=["my_collection"],
),
collection_description: str = Body(
None,
description="Optional description for the collection.",
examples=["This is a test collection."],
),
batch_size: int = Body(
None,
description="Optional batch size for the collection.",
examples=[256],
),
force_rebuild: bool = Body(
False,
description="Whether to force rebuild the collection if it already exists.",
examples=[False],
),
):
"""
Load website content into the vector database.
Args:
urls (Union[str, List[str]]): URLs of websites to load.
collection_name (str, optional): Name for the collection. Defaults to None.
collection_description (str, optional): Description for the collection. Defaults to None.
batch_size (int, optional): Batch size for processing. Defaults to None.
force_rebuild (bool, optional): Whether to force rebuild the collection. Defaults to False.
Returns:
dict: A dictionary containing a success message.
Raises:
HTTPException: If loading website content fails.
"""
try:
load_from_website(
urls=urls,
collection_name=collection_name,
collection_description=collection_description,
batch_size=batch_size if batch_size is not None else 8,
force_rebuild=force_rebuild,
)
return {"message": "Website loaded successfully."}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/query/")
def perform_query(
original_query: str = Query(
...,
description="Your question here.",
examples=["Write a report about Milvus."],
),
max_iter: int = Query(
3,
description="The maximum number of iterations for reflection.",
ge=1,
examples=[3],
),
):
"""
Perform a query against the loaded data.
Args:
original_query (str): The user's question or query.
max_iter (int, optional): Maximum number of iterations for reflection. Defaults to 3.
Returns:
dict: A dictionary containing the query result and token consumption.
Raises:
HTTPException: If the query fails.
"""
try:
# 清空之前的消息
message_stream = get_message_stream()
message_stream.clear_messages()
result_text, _ = query(original_query, max_iter)
return {
"result": result_text,
"messages": message_stream.get_messages_as_dicts()
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/query-stream/")
async def perform_query_stream(
original_query: str = Query(
...,
description="Your question here.",
examples=["Write a report about Milvus."],
),
max_iter: int = Query(
3,
description="The maximum number of iterations for reflection.",
ge=1,
examples=[3],
),
) -> StreamingResponse:
"""
Perform a query with real-time streaming of progress and results.
Args:
original_query (str): The user's question or query.
max_iter (int, optional): Maximum number of iterations for reflection. Defaults to 3.
Returns:
StreamingResponse: A streaming response with real-time query progress and results.
"""
async def query_stream_generator() -> AsyncGenerator[str, None]:
"""生成查询流"""
message_callback = None
try:
# 清空之前的消息
message_stream = get_message_stream()
message_stream.clear_messages()
# 发送查询开始消息
yield f"data: {json.dumps({'type': 'start', 'content': '开始查询'}, ensure_ascii=False)}\n\n"
# 创建一个线程安全的队列来接收消息
message_queue = queue.Queue()
def message_callback(message):
"""消息回调函数"""
try:
# 直接放入线程安全的队列
message_queue.put(message)
except Exception as e:
print(f"Error in message callback: {e}")
# 注册回调函数
message_stream.add_callback(message_callback)
# 在后台线程中执行查询
def run_query():
try:
print(f"Starting query: {original_query} with max_iter: {max_iter}")
result_text, retrieval_results = query(original_query, max_iter)
print(f"Query completed with result length: {len(result_text) if result_text else 0}")
print(f"Retrieved {len(retrieval_results) if retrieval_results else 0} documents")
return result_text, None
except Exception as e:
import traceback
print(f"Query failed with error: {e}")
print(f"Traceback: {traceback.format_exc()}")
return None, str(e)
# 使用线程池执行查询
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
query_future = executor.submit(run_query)
# 监听消息和查询结果
while True:
try:
# 检查查询是否已完成
if query_future.done():
result_text, error = query_future.result()
if error:
print(f"Query error: {error}")
yield f"data: {json.dumps({'type': 'error', 'content': f'查询失败: {error}'}, ensure_ascii=False)}\n\n"
return
# 成功完成时,发送complete消息并结束
print("Query completed successfully, sending complete message")
yield f"data: {json.dumps({'type': 'complete', 'content': '查询完成'}, ensure_ascii=False)}\n\n"
print("Complete message sent, ending stream")
return
# 尝试从队列获取消息,设置超时
try:
message = message_queue.get(timeout=2.0)
message_data = {
'type': message.type.value,
'content': message.content,
'timestamp': message.timestamp,
'metadata': message.metadata or {}
}
yield f"data: {json.dumps(message_data, ensure_ascii=False)}\n\n"
except queue.Empty:
# 超时,继续循环
continue
except Exception as e:
print(f"Error in stream loop: {e}")
yield f"data: {json.dumps({'type': 'error', 'content': f'流处理错误: {str(e)}'}, ensure_ascii=False)}\n\n"
return
except Exception as e:
yield f"data: {json.dumps({'type': 'error', 'content': f'查询错误: {str(e)}'}, ensure_ascii=False)}\n\n"
finally:
# 清理回调函数
if message_callback:
message_stream.remove_callback(message_callback)
return StreamingResponse(
query_stream_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "Cache-Control"
}
)
@app.get("/messages/")
def get_messages():
"""
Get all messages from the message stream.
Returns:
dict: A dictionary containing all messages.
"""
try:
message_stream = get_message_stream()
return {
"messages": message_stream.get_messages_as_dicts()
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/stream-messages/")
async def stream_messages() -> StreamingResponse:
"""
Stream messages in real-time using Server-Sent Events.
Returns:
StreamingResponse: A streaming response with real-time messages.
"""
async def event_generator() -> AsyncGenerator[str, None]:
"""生成SSE事件流"""
# 创建一个队列来接收消息
message_queue = asyncio.Queue()
def message_callback(message):
"""消息回调函数"""
try:
# 将消息放入队列
asyncio.create_task(message_queue.put(message))
except Exception as e:
print(f"Error in message callback: {e}")
# 注册回调函数
message_callbacks.append(message_callback)
message_stream = get_message_stream()
message_stream.add_callback(message_callback)
try:
# 发送连接建立消息
yield f"data: {json.dumps({'type': 'connection', 'message': 'Connected to message stream'}, ensure_ascii=False)}\n\n"
while True:
try:
# 等待新消息,设置超时以便能够检查连接状态
message = await asyncio.wait_for(message_queue.get(), timeout=2.0)
# 发送消息数据
message_data = {
'type': message.type.value,
'content': message.content,
'timestamp': message.timestamp,
'metadata': message.metadata or {}
}
yield f"data: {json.dumps(message_data, ensure_ascii=False)}\n\n"
except asyncio.TimeoutError:
# 发送心跳保持连接
yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': asyncio.get_event_loop().time()}, ensure_ascii=False)}\n\n"
except Exception as e:
print(f"Error in event generator: {e}")
finally:
# 清理回调函数
if message_callback in message_callbacks:
message_callbacks.remove(message_callback)
message_stream.remove_callback(message_callback)
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": "Cache-Control"
}
)
@app.post("/clear-messages/")
def clear_messages():
"""
Clear all messages from the message stream.
Returns:
dict: A dictionary containing a success message.
"""
try:
message_stream = get_message_stream()
message_stream.clear_messages()
return {
"message": "Messages cleared successfully"
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="FastAPI Server")
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to")
parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to")
parser.add_argument("--enable-cors", action="store_true", default=False, help="Enable CORS support")
args = parser.parse_args()
if args.enable_cors:
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
print("CORS is enabled.")
else:
print("CORS is disabled.")
print(f"Starting server on {args.host}:{args.port}")
uvicorn.run(app, host=args.host, port=args.port)