|
|
|
import argparse
|
|
|
|
|
|
|
|
import uvicorn
|
|
|
|
from fastapi import Body, FastAPI, HTTPException, Query
|
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from fastapi.responses import HTMLResponse
|
|
|
|
from pydantic import BaseModel
|
|
|
|
import os
|
|
|
|
|
|
|
|
from deepsearcher.configuration import Configuration, init_config
|
|
|
|
from deepsearcher.offline_loading import load_from_local_files, load_from_website
|
|
|
|
from deepsearcher.online_query import query
|
|
|
|
|
|
|
|
# ASGI application instance; every route handler below is registered on it,
# and it is what ``uvicorn.run`` serves in the ``__main__`` block.
app = FastAPI()

# Module-wide configuration object shared by all request handlers
# (/set-provider-config/ mutates it at runtime).
config = Configuration()
# Initialize from the configuration once, at import time. Presumably this sets up
# the global components used by the query/loading functions — confirm in
# deepsearcher.configuration.init_config.
init_config(config)
|
|
|
|
|
|
|
|
|
|
|
|
class ProviderConfigRequest(BaseModel):
    """
    JSON body accepted by POST /set-provider-config/.

    Attributes:
        feature (str): Component to reconfigure, for example 'llm' or 'embedding'.
        provider (str): Provider to use for that component, for example 'openai' or 'azure'.
        config (dict): Provider-specific settings, passed through as-is to
            ``Configuration.set_provider_config``.
    """

    # Name of the component being configured.
    feature: str
    # Name of the provider backing that component.
    provider: str
    # Free-form provider parameters; the expected keys depend on the provider.
    config: dict
|
|
|
|
|
|
|
|
|
|
|
|
def _render_missing_template_page(template_path):
    """
    Build the fallback HTML page shown when the front-end template is missing.

    Args:
        template_path (str): Absolute path where the template was expected.
            It is HTML-escaped before being embedded in the page.

    Returns:
        str: A complete HTML document listing the available API endpoints.
    """
    import html

    # The path is embedded in markup, so escape it: characters such as '<' or '&'
    # in a filesystem path would otherwise produce malformed HTML.
    safe_path = html.escape(template_path)
    # Literal CSS braces are doubled because this is an f-string.
    return f"""
<!DOCTYPE html>
<html>
<head>
    <title>DeepSearcher</title>
    <meta charset="utf-8">
    <style>
        body {{ font-family: Arial, sans-serif; margin: 40px; }}
        .container {{ max-width: 800px; margin: 0 auto; }}
        h1 {{ color: #333; }}
        .info {{ background: #f0f8ff; padding: 15px; border-radius: 5px; }}
        .error {{ background: #ffe4e1; padding: 15px; border-radius: 5px; color: #d00; }}
    </style>
</head>
<body>
    <div class="container">
        <h1>DeepSearcher</h1>
        <div class="info">
            <p>欢迎使用 DeepSearcher 智能搜索系统!</p>
            <p>系统正在运行,但未找到前端模板文件。</p>
            <p>请确认文件是否存在: {safe_path}</p>
        </div>
        <div class="info">
            <h2>API 接口</h2>
            <p>您仍然可以通过以下 API 接口使用系统:</p>
            <ul>
                <li><code>POST /load-files/</code> - 加载本地文件</li>
                <li><code>POST /load-website/</code> - 加载网站内容</li>
                <li><code>GET /query/</code> - 执行查询</li>
            </ul>
            <p>有关 API 使用详情,请查看 <a href="/docs">API 文档</a></p>
        </div>
    </div>
</body>
</html>
"""


@app.get("/", response_class=HTMLResponse)
async def read_root():
    """
    Serve the main HTML page.

    Reads ``deepsearcher/backend/templates/index.html`` relative to the directory
    containing this file and returns its contents. If the file does not exist,
    a small built-in page listing the API endpoints is returned instead, still
    with HTTP status 200.

    Returns:
        HTMLResponse: The template contents, or the fallback page.
    """
    # Resolve the template relative to this source file (not the current working
    # directory), so the server finds it regardless of where it was launched from.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    template_path = os.path.join(current_dir, "deepsearcher", "backend", "templates", "index.html")

    try:
        with open(template_path, encoding="utf-8") as file:
            html_content = file.read()
    except FileNotFoundError:
        # Missing template: keep the server usable by serving a fallback page.
        return HTMLResponse(content=_render_missing_template_page(template_path), status_code=200)
    return HTMLResponse(content=html_content, status_code=200)
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/set-provider-config/")
def set_provider_config(request: ProviderConfigRequest):
    """
    Set configuration for a specific provider.

    Updates the module-level ``config`` for ``request.feature`` and then calls
    ``init_config(config)`` again so the change is applied.

    Args:
        request (ProviderConfigRequest): Feature name, provider name and
            provider-specific parameters.

    Returns:
        dict: ``message``, the ``provider`` name, and the ``config`` dict echoed
        back exactly as it was received.

    Raises:
        HTTPException: 500 if updating or re-initializing the configuration fails.
    """
    try:
        config.set_provider_config(request.feature, request.provider, request.config)
        # NOTE(review): if init_config raises, the change made by
        # set_provider_config above is not rolled back, so ``config`` may be left
        # partially updated — confirm whether that is acceptable.
        init_config(config)
    except Exception as e:
        # Chain the original exception so server-side tracebacks keep the root cause.
        raise HTTPException(status_code=500, detail=f"Failed to set provider config: {e}") from e
    # NOTE(review): request.config is echoed verbatim; if it carries credentials
    # such as API keys, consider whether returning them is necessary.
    return {
        "message": "Provider config set successfully",
        "provider": request.provider,
        "config": request.config,
    }
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/load-files/")
def load_files(
    paths: str | list[str] = Body(
        ...,
        description="A list of file paths to be loaded.",
        examples=["/path/to/file1", "/path/to/file2", "/path/to/dir1"],
    ),
    collection_name: str | None = Body(
        None,
        description="Optional name for the collection.",
        examples=["my_collection"],
    ),
    collection_description: str | None = Body(
        None,
        description="Optional description for the collection.",
        examples=["This is a test collection."],
    ),
    batch_size: int | None = Body(
        None,
        description="Optional batch size for the collection.",
        examples=[256],
    ),
):
    """
    Load files into the vector database.

    Args:
        paths (str | list[str]): File paths or directories (on the server) to load.
        collection_name (str | None, optional): Name for the collection. Defaults to None.
        collection_description (str | None, optional): Description for the collection.
            Defaults to None.
        batch_size (int | None, optional): Batch size for processing. When omitted
            or null, 8 is used.

    Returns:
        dict: A dictionary containing a success message.

    Raises:
        HTTPException: 500 if loading files fails; the detail is the error message.
    """
    # NOTE(review): ``paths`` are resolved on the server and this endpoint has no
    # visible access control, so any client can ask the server to ingest any file
    # the process can read. Confirm this is acceptable for the deployment.

    # The optional parameters are explicitly ``X | None`` so an explicit JSON null
    # is accepted the same way as an omitted field. An unset batch size
    # falls back to 8 rather than passing None down to the loader.
    effective_batch_size = batch_size if batch_size is not None else 8
    try:
        load_from_local_files(
            paths_or_directory=paths,
            collection_name=collection_name,
            collection_description=collection_description,
            batch_size=effective_batch_size,
        )
    except Exception as e:
        # Chain the cause so the server-side traceback is preserved.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"message": "Files loaded successfully."}
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/load-website/")
def load_website(
    urls: str | list[str] = Body(
        ...,
        description="A list of URLs of websites to be loaded.",
        examples=["https://milvus.io/docs/overview.md"],
    ),
    collection_name: str | None = Body(
        None,
        description="Optional name for the collection.",
        examples=["my_collection"],
    ),
    collection_description: str | None = Body(
        None,
        description="Optional description for the collection.",
        examples=["This is a test collection."],
    ),
    batch_size: int | None = Body(
        None,
        description="Optional batch size for the collection.",
        examples=[256],
    ),
):
    """
    Load website content into the vector database.

    Args:
        urls (str | list[str]): URLs of websites to load.
        collection_name (str | None, optional): Name for the collection. Defaults to None.
        collection_description (str | None, optional): Description for the collection.
            Defaults to None.
        batch_size (int | None, optional): Batch size for processing. When omitted
            or null, 256 is used.

    Returns:
        dict: A dictionary containing a success message.

    Raises:
        HTTPException: 500 if loading website content fails; the detail is the
            error message.
    """
    # NOTE(review): presumably load_from_website fetches these URLs from the
    # server, so an unauthenticated client could make the server request arbitrary
    # addresses (including internal ones). Confirm and restrict if needed.

    # The optional parameters are explicitly ``X | None`` so an explicit JSON null
    # is accepted the same way as an omitted field. An unset batch size
    # falls back to 256 rather than passing None down to the loader.
    effective_batch_size = batch_size if batch_size is not None else 256
    try:
        load_from_website(
            urls=urls,
            collection_name=collection_name,
            collection_description=collection_description,
            batch_size=effective_batch_size,
        )
    except Exception as e:
        # Chain the cause so the server-side traceback is preserved.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"message": "Website loaded successfully."}
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/query/")
def perform_query(
    original_query: str = Query(
        ...,
        description="Your question here.",
        examples=["Write a report about Milvus."],
    ),
    max_iter: int = Query(
        3,
        description="The maximum number of iterations for reflection.",
        ge=1,
        examples=[3],
    ),
):
    """
    Perform a query against the loaded data.

    Clears the collected progress messages, runs ``query``, and returns its answer
    together with the progress messages gathered while it ran.

    Args:
        original_query (str): The user's question or query.
        max_iter (int, optional): Maximum number of iterations for reflection
            (must be >= 1). Defaults to 3.

    Returns:
        dict: ``result`` (answer text), ``consume_token`` (the third value returned
        by ``query``, presumably token usage) and ``progress_messages``.

    Raises:
        HTTPException: 500 if the query fails; the detail is the error message.
    """
    try:
        # Imported lazily as before, but both helpers at once so a missing helper
        # fails before the query runs rather than after it.
        from deepsearcher.utils.log import clear_progress_messages, get_progress_messages

        # Drop progress messages left over from earlier queries.
        clear_progress_messages()
        # query() returns a 3-tuple; its middle element is not used by this endpoint.
        result_text, _, consume_token = query(original_query, max_iter)
        progress_messages = get_progress_messages()
    except Exception as e:
        # Chain the cause so the server-side traceback is preserved.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # NOTE(review): the progress-message store appears to be process-global
    # (cleared/read through module functions, not per request). Sync endpoints run
    # in FastAPI's threadpool, so concurrent queries may clear or see each other's
    # messages — confirm in deepsearcher.utils.log.
    return {
        "result": result_text,
        "consume_token": consume_token,
        "progress_messages": progress_messages,
    }
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line options used when this module is run directly as a server.
    cli = argparse.ArgumentParser(description="FastAPI Server")
    cli.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to")
    cli.add_argument("--port", type=int, default=8000, help="Port to bind the server to")
    cli.add_argument("--enable-cors", action="store_true", default=False, help="Enable CORS support")
    options = cli.parse_args()

    # NOTE(review): the default host 0.0.0.0 listens on every interface, and the
    # routes in this module have no visible authentication (e.g. /load-files/
    # ingests server-side paths supplied by the client). Confirm this is intended.

    if not options.enable_cors:
        print("CORS is disabled.")
    else:
        # NOTE(review): allow_origins=["*"] together with allow_credentials=True
        # makes Starlette reflect the caller's Origin on credentialed requests,
        # i.e. every site is trusted; consider an explicit origin list.
        cors_options = {
            "allow_origins": ["*"],
            "allow_credentials": True,
            "allow_methods": ["*"],
            "allow_headers": ["*"],
        }
        app.add_middleware(CORSMiddleware, **cors_options)
        print("CORS is enabled.")

    print(f"Starting server on {options.host}:{options.port}")
    uvicorn.run(app, host=options.host, port=options.port)
|