"""FastAPI server exposing configuration, loading and query endpoints."""

import argparse

import uvicorn
from fastapi import Body, FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
import os

from deepsearcher.configuration import Configuration, init_config
from deepsearcher.offline_loading import load_from_local_files, load_from_website
from deepsearcher.online_query import query

app = FastAPI()

config = Configuration()

init_config(config)

# Batch size handed to the loaders when the request does not supply one.
_DEFAULT_BATCH_SIZE = 8


class ProviderConfigRequest(BaseModel):
    """
    Request model for setting provider configuration.

    Attributes:
        feature (str): The feature to configure (e.g., 'embedding', 'llm').
        provider (str): The provider name (e.g., 'openai', 'azure').
        config (dict): Configuration parameters for the provider.
    """

    feature: str
    provider: str
    config: dict


@app.get("/", response_class=HTMLResponse)
async def read_root():
    """
    Serve the main HTML page.

    The template is looked up relative to this file, under
    ``deepsearcher/backend/templates/index.html``.

    Returns:
        HTMLResponse: The template contents with status code 200.

    Raises:
        HTTPException: 404 if the template file does not exist.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    template_path = os.path.join(
        current_dir, "deepsearcher", "backend", "templates", "index.html"
    )
    # Only opening/reading the file can raise FileNotFoundError, so the
    # response is built outside the try block.
    try:
        with open(template_path, encoding="utf-8") as file:
            html_content = file.read()
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail="Template file not found") from e
    return HTMLResponse(content=html_content, status_code=200)


@app.post("/set-provider-config/")
def set_provider_config(request: ProviderConfigRequest):
    """
    Set configuration for a specific provider.

    The module-level configuration is updated and then re-initialized via
    ``init_config`` so the change takes effect.

    Args:
        request (ProviderConfigRequest): The request containing provider configuration.

    Returns:
        dict: A dictionary containing a success message, the provider name and
            the configuration that was submitted.

    Raises:
        HTTPException: 500 if setting the provider config fails.
    """
    try:
        config.set_provider_config(request.feature, request.provider, request.config)
        init_config(config)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to set provider config: {e}"
        ) from e
    return {
        "message": "Provider config set successfully",
        "provider": request.provider,
        "config": request.config,
    }


@app.post("/load-files/")
def load_files(
    paths: str | list[str] = Body(
        ...,
        description="A list of file paths to be loaded.",
        examples=["/path/to/file1", "/path/to/file2", "/path/to/dir1"],
    ),
    collection_name: str | None = Body(
        None,
        description="Optional name for the collection.",
        examples=["my_collection"],
    ),
    collection_description: str | None = Body(
        None,
        description="Optional description for the collection.",
        examples=["This is a test collection."],
    ),
    batch_size: int | None = Body(
        None,
        description="Optional batch size for the collection.",
        examples=[256],
    ),
):
    """
    Load files into the vector database.

    Args:
        paths (str | list[str]): File paths or directories to load.
        collection_name (str, optional): Name for the collection. Defaults to None.
        collection_description (str, optional): Description for the collection.
            Defaults to None.
        batch_size (int, optional): Batch size for processing. When omitted
            (None), a batch size of 8 is passed to the loader.

    Returns:
        dict: A dictionary containing a success message.

    Raises:
        HTTPException: 500 if loading files fails.
    """
    effective_batch_size = batch_size if batch_size is not None else _DEFAULT_BATCH_SIZE
    try:
        load_from_local_files(
            paths_or_directory=paths,
            collection_name=collection_name,
            collection_description=collection_description,
            batch_size=effective_batch_size,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"message": "Files loaded successfully."}


@app.post("/load-website/")
def load_website(
    urls: str | list[str] = Body(
        ...,
        description="A list of URLs of websites to be loaded.",
        examples=["https://milvus.io/docs/overview.md"],
    ),
    collection_name: str | None = Body(
        None,
        description="Optional name for the collection.",
        examples=["my_collection"],
    ),
    collection_description: str | None = Body(
        None,
        description="Optional description for the collection.",
        examples=["This is a test collection."],
    ),
    batch_size: int | None = Body(
        None,
        description="Optional batch size for the collection.",
        examples=[256],
    ),
):
    """
    Load website content into the vector database.

    Args:
        urls (str | list[str]): URLs of websites to load.
        collection_name (str, optional): Name for the collection. Defaults to None.
        collection_description (str, optional): Description for the collection.
            Defaults to None.
        batch_size (int, optional): Batch size for processing. When omitted
            (None), a batch size of 8 is passed to the loader.

    Returns:
        dict: A dictionary containing a success message.

    Raises:
        HTTPException: 500 if loading website content fails.
    """
    effective_batch_size = batch_size if batch_size is not None else _DEFAULT_BATCH_SIZE
    try:
        load_from_website(
            urls=urls,
            collection_name=collection_name,
            collection_description=collection_description,
            batch_size=effective_batch_size,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"message": "Website loaded successfully."}


@app.get("/query/")
def perform_query(
    original_query: str = Query(
        ...,
        description="Your question here.",
        examples=["Write a report about Milvus."],
    ),
    max_iter: int = Query(
        3,
        description="The maximum number of iterations for reflection.",
        ge=1,
        examples=[3],
    ),
):
    """
    Perform a query against the loaded data.

    Args:
        original_query (str): The user's question or query.
        max_iter (int, optional): Maximum number of iterations for reflection.
            Must be at least 1. Defaults to 3.

    Returns:
        dict: A dictionary with the answer text under the key ``"result"``.
            The second value returned by ``query`` is discarded.

    Raises:
        HTTPException: 500 if the query fails.
    """
    try:
        result_text, _ = query(original_query, max_iter)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"result": result_text}


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FastAPI Server")
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Host to bind the server to"
    )
    parser.add_argument(
        "--port", type=int, default=8000, help="Port to bind the server to"
    )
    parser.add_argument(
        "--enable-cors", action="store_true", default=False, help="Enable CORS support"
    )
    args = parser.parse_args()

    if args.enable_cors:
        # NOTE(review): wildcard origins together with allow_credentials=True is
        # a very permissive policy; confirm this is intended before exposing the
        # server beyond a trusted network.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        print("CORS is enabled.")
    else:
        print("CORS is disabled.")

    print(f"Starting server on {args.host}:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port)