import argparse import uvicorn from fastapi import Body, FastAPI, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse from pydantic import BaseModel import os from deepsearcher.configuration import Configuration, init_config from deepsearcher.offline_loading import load_from_local_files, load_from_website from deepsearcher.online_query import query app = FastAPI() config = Configuration() init_config(config) class ProviderConfigRequest(BaseModel): """ Request model for setting provider configuration. Attributes: feature (str): The feature to configure (e.g., 'embedding', 'llm'). provider (str): The provider name (e.g., 'openai', 'azure'). config (Dict): Configuration parameters for the provider. """ feature: str provider: str config: dict @app.get("/", response_class=HTMLResponse) async def read_root(): """ Serve the main HTML page. """ # 获取当前文件所在目录 current_dir = os.path.dirname(os.path.abspath(__file__)) # 构建模板文件路径 - 修复路径问题 template_path = os.path.join(current_dir, "deepsearcher", "backend", "templates", "index.html") # 读取HTML文件内容 try: with open(template_path, encoding="utf-8") as file: html_content = file.read() return HTMLResponse(content=html_content, status_code=200) except FileNotFoundError: # 如果找不到文件,提供一个简单的默认页面 default_html = f""" DeepSearcher

DeepSearcher

欢迎使用 DeepSearcher 智能搜索系统!

系统正在运行,但未找到前端模板文件。

请确认文件是否存在: {template_path}

API 接口

您仍然可以通过以下 API 接口使用系统:

有关 API 使用详情,请查看 API 文档

""" return HTMLResponse(content=default_html, status_code=200) @app.post("/set-provider-config/") def set_provider_config(request: ProviderConfigRequest): """ Set configuration for a specific provider. Args: request (ProviderConfigRequest): The request containing provider configuration. Returns: dict: A dictionary containing a success message and the updated configuration. Raises: HTTPException: If setting the provider config fails. """ try: config.set_provider_config(request.feature, request.provider, request.config) init_config(config) return { "message": "Provider config set successfully", "provider": request.provider, "config": request.config, } except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to set provider config: {str(e)}") @app.post("/load-files/") def load_files( paths: str | list[str] = Body( ..., description="A list of file paths to be loaded.", examples=["/path/to/file1", "/path/to/file2", "/path/to/dir1"], ), collection_name: str = Body( None, description="Optional name for the collection.", examples=["my_collection"], ), collection_description: str = Body( None, description="Optional description for the collection.", examples=["This is a test collection."], ), batch_size: int = Body( None, description="Optional batch size for the collection.", examples=[256], ), ): """ Load files into the vector database. Args: paths (Union[str, List[str]]): File paths or directories to load. collection_name (str, optional): Name for the collection. Defaults to None. collection_description (str, optional): Description for the collection. Defaults to None. batch_size (int, optional): Batch size for processing. Defaults to None. Returns: dict: A dictionary containing a success message. Raises: HTTPException: If loading files fails. """ try: # 修复batch_size为None时的问题 load_from_local_files( paths_or_directory=paths, collection_name=collection_name, collection_description=collection_description, batch_size=batch_size if batch_size is not None else 8, # 提供默认值 ) return {"message": "Files loaded successfully."} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.post("/load-website/") def load_website( urls: str | list[str] = Body( ..., description="A list of URLs of websites to be loaded.", examples=["https://milvus.io/docs/overview.md"], ), collection_name: str = Body( None, description="Optional name for the collection.", examples=["my_collection"], ), collection_description: str = Body( None, description="Optional description for the collection.", examples=["This is a test collection."], ), batch_size: int = Body( None, description="Optional batch size for the collection.", examples=[256], ), ): """ Load website content into the vector database. Args: urls (Union[str, List[str]]): URLs of websites to load. collection_name (str, optional): Name for the collection. Defaults to None. collection_description (str, optional): Description for the collection. Defaults to None. batch_size (int, optional): Batch size for processing. Defaults to None. Returns: dict: A dictionary containing a success message. Raises: HTTPException: If loading website content fails. """ try: # 修复batch_size为None时的问题 load_from_website( urls=urls, collection_name=collection_name, collection_description=collection_description, batch_size=batch_size if batch_size is not None else 256, # 提供默认值 ) return {"message": "Website loaded successfully."} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.get("/query/") def perform_query( original_query: str = Query( ..., description="Your question here.", examples=["Write a report about Milvus."], ), max_iter: int = Query( 3, description="The maximum number of iterations for reflection.", ge=1, examples=[3], ), ): """ Perform a query against the loaded data. Args: original_query (str): The user's question or query. max_iter (int, optional): Maximum number of iterations for reflection. Defaults to 3. Returns: dict: A dictionary containing the query result and token consumption. Raises: HTTPException: If the query fails. """ try: # 清除之前的进度消息 from deepsearcher.utils.log import clear_progress_messages clear_progress_messages() result_text, _, consume_token = query(original_query, max_iter) # 获取进度消息 from deepsearcher.utils.log import get_progress_messages progress_messages = get_progress_messages() return { "result": result_text, "consume_token": consume_token, "progress_messages": progress_messages } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": parser = argparse.ArgumentParser(description="FastAPI Server") parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to") parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to") parser.add_argument("--enable-cors", action="store_true", default=False, help="Enable CORS support") args = parser.parse_args() if args.enable_cors: app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) print("CORS is enabled.") else: print("CORS is disabled.") print(f"Starting server on {args.host}:{args.port}") uvicorn.run(app, host=args.host, port=args.port)