You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

288 lines
9.4 KiB

2 weeks ago
import argparse
import uvicorn
from fastapi import Body, FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
2 weeks ago
from pydantic import BaseModel
import os
2 weeks ago
from deepsearcher.configuration import Configuration, init_config
from deepsearcher.offline_loading import load_from_local_files, load_from_website
from deepsearcher.online_query import query
# ASGI application object; the route decorators in this module register
# their handlers on it.
app = FastAPI()
# Module-level configuration object. It is handed to init_config once here,
# at import time, and is later modified by the set-provider-config handler.
config = Configuration()
init_config(config)
class ProviderConfigRequest(BaseModel):
    """
    Request body for setting a provider configuration.

    Attributes:
        feature (str): The feature to configure (e.g., 'embedding', 'llm').
        provider (str): The provider name (e.g., 'openai', 'azure').
        config (dict): Configuration parameters for the provider.
    """

    feature: str
    provider: str
    config: dict
2 weeks ago
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """
    Serve the main HTML page, or a built-in fallback page if it is missing.

    The template is looked up at ``deepsearcher/backend/templates/index.html``
    relative to the directory that contains this module.

    Returns:
        HTMLResponse: The template contents when the file can be opened;
        otherwise a static page that lists the API endpoints and shows the
        path that was tried. Both responses use status code 200.
    """
    # Directory that contains this module.
    module_dir = os.path.dirname(os.path.abspath(__file__))
    # Location of the front-end template relative to that directory.
    template_path = os.path.join(
        module_dir, "deepsearcher", "backend", "templates", "index.html"
    )

    try:
        with open(template_path, encoding="utf-8") as template_file:
            page = template_file.read()
    except FileNotFoundError:
        # No template on disk: answer with a simple default page instead.
        # The markup is kept flush-left so the emitted HTML stays unchanged.
        fallback_page = f"""
<!DOCTYPE html>
<html>
<head>
<title>DeepSearcher</title>
<meta charset="utf-8">
<style>
body {{ font-family: Arial, sans-serif; margin: 40px; }}
.container {{ max-width: 800px; margin: 0 auto; }}
h1 {{ color: #333; }}
.info {{ background: #f0f8ff; padding: 15px; border-radius: 5px; }}
.error {{ background: #ffe4e1; padding: 15px; border-radius: 5px; color: #d00; }}
</style>
</head>
<body>
<div class="container">
<h1>DeepSearcher</h1>
<div class="info">
<p>欢迎使用 DeepSearcher 智能搜索系统!</p>
<p>系统正在运行但未找到前端模板文件</p>
<p>请确认文件是否存在: {template_path}</p>
</div>
<div class="info">
<h2>API 接口</h2>
<p>您仍然可以通过以下 API 接口使用系统:</p>
<ul>
<li><code>POST /load-files/</code> - 加载本地文件</li>
<li><code>POST /load-website/</code> - 加载网站内容</li>
<li><code>GET /query/</code> - 执行查询</li>
</ul>
<p>有关 API 使用详情请查看 <a href="/docs">API 文档</a></p>
</div>
</div>
</body>
</html>
"""
        return HTMLResponse(content=fallback_page, status_code=200)
    else:
        return HTMLResponse(content=page, status_code=200)
2 weeks ago
@app.post("/set-provider-config/")
def set_provider_config(request: ProviderConfigRequest):
    """
    Set configuration for a specific provider.

    Args:
        request (ProviderConfigRequest): The feature, provider name and
            provider parameters to apply.

    Returns:
        dict: A success message together with the provider name and the
        config that were submitted.

    Raises:
        HTTPException: With status 500 if updating the configuration or
            re-running ``init_config`` raises.
    """
    try:
        config.set_provider_config(request.feature, request.provider, request.config)
        # Re-run initialisation with the updated configuration object.
        init_config(config)
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise HTTPException(
            status_code=500, detail=f"Failed to set provider config: {e}"
        ) from e

    return {
        "message": "Provider config set successfully",
        "provider": request.provider,
        "config": request.config,
    }
@app.post("/load-files/")
def load_files(
    paths: str | list[str] = Body(
        ...,
        description="A list of file paths to be loaded.",
        examples=["/path/to/file1", "/path/to/file2", "/path/to/dir1"],
    ),
    # The optional fields default to None, so they are annotated as nullable;
    # this lets a client send an explicit JSON null as well as omit the key.
    collection_name: str | None = Body(
        None,
        description="Optional name for the collection.",
        examples=["my_collection"],
    ),
    collection_description: str | None = Body(
        None,
        description="Optional description for the collection.",
        examples=["This is a test collection."],
    ),
    batch_size: int | None = Body(
        None,
        description="Optional batch size for the collection.",
        examples=[256],
    ),
):
    """
    Load files into the vector database.

    Args:
        paths (str | list[str]): File paths or directories to load.
        collection_name (str | None): Name for the collection. Defaults to None.
        collection_description (str | None): Description for the collection.
            Defaults to None.
        batch_size (int | None): Batch size for processing. When omitted (None),
            8 is passed to the loader.

    Returns:
        dict: A dictionary containing a success message.

    Raises:
        HTTPException: With status 500 if loading files fails.
    """
    # The loader is always given a concrete batch size, never None.
    effective_batch_size = batch_size if batch_size is not None else 8
    try:
        load_from_local_files(
            paths_or_directory=paths,
            collection_name=collection_name,
            collection_description=collection_description,
            batch_size=effective_batch_size,
        )
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": "Files loaded successfully."}
@app.post("/load-website/")
def load_website(
    urls: str | list[str] = Body(
        ...,
        description="A list of URLs of websites to be loaded.",
        examples=["https://milvus.io/docs/overview.md"],
    ),
    # The optional fields default to None, so they are annotated as nullable;
    # this lets a client send an explicit JSON null as well as omit the key.
    collection_name: str | None = Body(
        None,
        description="Optional name for the collection.",
        examples=["my_collection"],
    ),
    collection_description: str | None = Body(
        None,
        description="Optional description for the collection.",
        examples=["This is a test collection."],
    ),
    batch_size: int | None = Body(
        None,
        description="Optional batch size for the collection.",
        examples=[256],
    ),
):
    """
    Load website content into the vector database.

    Args:
        urls (str | list[str]): URLs of websites to load.
        collection_name (str | None): Name for the collection. Defaults to None.
        collection_description (str | None): Description for the collection.
            Defaults to None.
        batch_size (int | None): Batch size for processing. When omitted (None),
            256 is passed to the loader.

    Returns:
        dict: A dictionary containing a success message.

    Raises:
        HTTPException: With status 500 if loading website content fails.
    """
    # The loader is always given a concrete batch size, never None.
    effective_batch_size = batch_size if batch_size is not None else 256
    try:
        load_from_website(
            urls=urls,
            collection_name=collection_name,
            collection_description=collection_description,
            batch_size=effective_batch_size,
        )
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": "Website loaded successfully."}
@app.get("/query/")
def perform_query(
    original_query: str = Query(
        ...,
        description="Your question here.",
        examples=["Write a report about Milvus."],
    ),
    max_iter: int = Query(
        3,
        description="The maximum number of iterations for reflection.",
        ge=1,
        examples=[3],
    ),
):
    """
    Perform a query against the loaded data.

    Args:
        original_query (str): The user's question or query.
        max_iter (int, optional): Maximum number of iterations for reflection.
            Must be at least 1. Defaults to 3.

    Returns:
        dict: A dictionary with the keys ``result`` (the answer text),
        ``consume_token`` (token consumption reported by the query) and
        ``progress_messages`` (messages collected while the query ran).

    Raises:
        HTTPException: With status 500 if the query fails.
    """
    try:
        # Imported inside the handler, within the try block, so that an
        # import failure is reported as HTTP 500 like any other error.
        from deepsearcher.utils.log import (
            clear_progress_messages,
            get_progress_messages,
        )

        # Discard progress messages left over from a previous request.
        clear_progress_messages()
        result_text, _, consume_token = query(original_query, max_iter)
        # Collect the progress messages recorded during this query.
        progress_messages = get_progress_messages()
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "result": result_text,
        "consume_token": consume_token,
        "progress_messages": progress_messages,
    }
if __name__ == "__main__":
    # Script entry point: read command-line options, optionally install the
    # CORS middleware, then serve the app with uvicorn.
    cli = argparse.ArgumentParser(description="FastAPI Server")
    cli.add_argument(
        "--host", type=str, default="0.0.0.0", help="Host to bind the server to"
    )
    cli.add_argument(
        "--port", type=int, default=8000, help="Port to bind the server to"
    )
    cli.add_argument(
        "--enable-cors",
        action="store_true",
        default=False,
        help="Enable CORS support",
    )
    options = cli.parse_args()

    if not options.enable_cors:
        print("CORS is disabled.")
    else:
        # NOTE(review): wildcard origins combined with allow_credentials=True
        # is very permissive — confirm this is intended for deployments.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        print("CORS is enabled.")

    bind_host = options.host
    bind_port = options.port
    print(f"Starting server on {bind_host}:{bind_port}")
    uvicorn.run(app, host=bind_host, port=bind_port)