|
|
@ -1,5 +1,4 @@ |
|
|
|
import argparse |
|
|
|
from typing import Dict, List, Union |
|
|
|
|
|
|
|
import uvicorn |
|
|
|
from fastapi import Body, FastAPI, HTTPException, Query |
|
|
@ -31,7 +30,7 @@ class ProviderConfigRequest(BaseModel): |
|
|
|
|
|
|
|
feature: str |
|
|
|
provider: str |
|
|
|
config: Dict |
|
|
|
config: dict |
|
|
|
|
|
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse) |
|
|
@ -43,15 +42,15 @@ async def read_root(): |
|
|
|
current_dir = os.path.dirname(os.path.abspath(__file__)) |
|
|
|
# 构建模板文件路径 - 修复路径问题 |
|
|
|
template_path = os.path.join(current_dir, "deepsearcher", "backend", "templates", "index.html") |
|
|
|
|
|
|
|
|
|
|
|
# 读取HTML文件内容 |
|
|
|
try: |
|
|
|
with open(template_path, "r", encoding="utf-8") as file: |
|
|
|
with open(template_path, encoding="utf-8") as file: |
|
|
|
html_content = file.read() |
|
|
|
return HTMLResponse(content=html_content, status_code=200) |
|
|
|
except FileNotFoundError: |
|
|
|
# 如果找不到文件,提供一个简单的默认页面 |
|
|
|
default_html = """ |
|
|
|
default_html = f""" |
|
|
|
<!DOCTYPE html> |
|
|
|
<html> |
|
|
|
<head> |
|
|
@ -71,7 +70,7 @@ async def read_root(): |
|
|
|
<div class="info"> |
|
|
|
<p>欢迎使用 DeepSearcher 智能搜索系统!</p> |
|
|
|
<p>系统正在运行,但未找到前端模板文件。</p> |
|
|
|
<p>请确认文件是否存在: {}</p> |
|
|
|
<p>请确认文件是否存在: {template_path}</p> |
|
|
|
</div> |
|
|
|
<div class="info"> |
|
|
|
<h2>API 接口</h2> |
|
|
@ -86,7 +85,7 @@ async def read_root(): |
|
|
|
</div> |
|
|
|
</body> |
|
|
|
</html> |
|
|
|
""".format(template_path) |
|
|
|
""" |
|
|
|
return HTMLResponse(content=default_html, status_code=200) |
|
|
|
|
|
|
|
|
|
|
@ -118,7 +117,7 @@ def set_provider_config(request: ProviderConfigRequest): |
|
|
|
|
|
|
|
@app.post("/load-files/") |
|
|
|
def load_files( |
|
|
|
paths: Union[str, List[str]] = Body( |
|
|
|
paths: str | list[str] = Body( |
|
|
|
..., |
|
|
|
description="A list of file paths to be loaded.", |
|
|
|
examples=["/path/to/file1", "/path/to/file2", "/path/to/dir1"], |
|
|
@ -160,7 +159,7 @@ def load_files( |
|
|
|
paths_or_directory=paths, |
|
|
|
collection_name=collection_name, |
|
|
|
collection_description=collection_description, |
|
|
|
batch_size=batch_size if batch_size is not None else 256, # 提供默认值 |
|
|
|
batch_size=batch_size if batch_size is not None else 8, # 提供默认值 |
|
|
|
) |
|
|
|
return {"message": "Files loaded successfully."} |
|
|
|
except Exception as e: |
|
|
@ -169,7 +168,7 @@ def load_files( |
|
|
|
|
|
|
|
@app.post("/load-website/") |
|
|
|
def load_website( |
|
|
|
urls: Union[str, List[str]] = Body( |
|
|
|
urls: str | list[str] = Body( |
|
|
|
..., |
|
|
|
description="A list of URLs of websites to be loaded.", |
|
|
|
examples=["https://milvus.io/docs/overview.md"], |
|
|
@ -249,15 +248,15 @@ def perform_query( |
|
|
|
# 清除之前的进度消息 |
|
|
|
from deepsearcher.utils.log import clear_progress_messages |
|
|
|
clear_progress_messages() |
|
|
|
|
|
|
|
|
|
|
|
result_text, _, consume_token = query(original_query, max_iter) |
|
|
|
|
|
|
|
|
|
|
|
# 获取进度消息 |
|
|
|
from deepsearcher.utils.log import get_progress_messages |
|
|
|
progress_messages = get_progress_messages() |
|
|
|
|
|
|
|
|
|
|
|
return { |
|
|
|
"result": result_text, |
|
|
|
"result": result_text, |
|
|
|
"consume_token": consume_token, |
|
|
|
"progress_messages": progress_messages |
|
|
|
} |
|
|
@ -271,7 +270,7 @@ if __name__ == "__main__": |
|
|
|
parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to") |
|
|
|
parser.add_argument("--enable-cors", action="store_true", default=False, help="Enable CORS support") |
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
|
|
|
if args.enable_cors: |
|
|
|
app.add_middleware( |
|
|
|
CORSMiddleware, |
|
|
@ -283,6 +282,6 @@ if __name__ == "__main__": |
|
|
|
print("CORS is enabled.") |
|
|
|
else: |
|
|
|
print("CORS is disabled.") |
|
|
|
|
|
|
|
|
|
|
|
print(f"Starting server on {args.host}:{args.port}") |
|
|
|
uvicorn.run(app, host=args.host, port=args.port) |
|
|
|
uvicorn.run(app, host=args.host, port=args.port) |
|
|
|