You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

277 lines
8.2 KiB

2 weeks ago
import argparse
import uvicorn
from fastapi import Body, FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
2 weeks ago
from pydantic import BaseModel
import os
2 weeks ago
from deepsearcher.configuration import Configuration, init_config
from deepsearcher.offline_loading import load_from_local_files, load_from_website
from deepsearcher.online_query import query
from deepsearcher.utils.message_stream import get_message_stream
2 weeks ago
# Build the shared Configuration object and hand it to init_config once at import
# time; the endpoints below read and update this same `config` object.
config = Configuration()
init_config(config)

# The ASGI application every route in this module is registered on.
app = FastAPI()
class ProviderConfigRequest(BaseModel):
    """
    Body of a provider-configuration request.

    Attributes:
        feature (str): Which feature is being configured, e.g. ``"embedding"`` or ``"llm"``.
        provider (str): Name of the provider to use for that feature, e.g. ``"openai"`` or ``"azure"``.
        config (dict): Provider-specific configuration parameters.
    """

    # Feature being configured, e.g. "embedding" or "llm".
    feature: str
    # Provider name, e.g. "openai" or "azure".
    provider: str
    # Provider-specific settings; no structure is enforced beyond being a dict.
    config: dict
2 weeks ago
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """
    Serve the main HTML page.

    The page is read from ``deepsearcher/backend/templates/index.html`` relative to
    the directory containing this module.

    Returns:
        HTMLResponse: The template contents with status code 200.

    Raises:
        HTTPException: 404 if the template file does not exist.
    """
    import asyncio

    current_dir = os.path.dirname(os.path.abspath(__file__))
    template_path = os.path.join(current_dir, "deepsearcher", "backend", "templates", "index.html")

    def _read_template() -> str:
        # Plain blocking read; executed in a worker thread below.
        with open(template_path, encoding="utf-8") as file:
            return file.read()

    try:
        # This is an `async def` endpoint, so a direct open()/read() would block the
        # event loop for every request; offload the file I/O to a thread instead.
        html_content = await asyncio.to_thread(_read_template)
    except FileNotFoundError:
        # The traceback of the missing file is not useful to the client.
        raise HTTPException(status_code=404, detail="Template file not found") from None
    return HTMLResponse(content=html_content, status_code=200)
2 weeks ago
@app.post("/set-provider-config/")
def set_provider_config(request: ProviderConfigRequest):
    """
    Update the module-level configuration for one feature's provider.

    Calls ``config.set_provider_config`` with the request's feature, provider and
    settings, then calls ``init_config(config)`` again.

    Args:
        request (ProviderConfigRequest): Feature name, provider name and provider settings.

    Returns:
        dict: ``message`` plus the ``provider`` name and ``config`` dict from the request.

    Raises:
        HTTPException: 500 carrying the underlying error text if either call fails.
    """
    feature, provider, settings = request.feature, request.provider, request.config
    try:
        config.set_provider_config(feature, provider, settings)
        # NOTE(review): if init_config raises after set_provider_config succeeded,
        # `config` presumably keeps the new settings while the 500 is returned — confirm.
        init_config(config)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to set provider config: {str(e)}")
    return {
        "message": "Provider config set successfully",
        "provider": provider,
        "config": settings,
    }
@app.post("/load-files/")
def load_files(
    paths: str | list[str] = Body(
        ...,
        description="A list of file paths to be loaded.",
        examples=["/path/to/file1", "/path/to/file2", "/path/to/dir1"],
    ),
    collection_name: str | None = Body(
        None,
        description="Optional name for the collection.",
        examples=["my_collection"],
    ),
    collection_description: str | None = Body(
        None,
        description="Optional description for the collection.",
        examples=["This is a test collection."],
    ),
    batch_size: int | None = Body(
        None,
        ge=1,
        description="Optional batch size for processing. Defaults to 8 when omitted.",
        examples=[256],
    ),
):
    """
    Load local files or directories into the vector database.

    Args:
        paths (str | list[str]): A single path or a list of file/directory paths to load.
        collection_name (str | None): Name for the collection; ``None`` is forwarded
            as-is to ``load_from_local_files``.
        collection_description (str | None): Description for the collection; ``None``
            is forwarded as-is to ``load_from_local_files``.
        batch_size (int | None): Batch size for processing, must be >= 1. When omitted
            or null, 8 is used.

    Returns:
        dict: ``{"message": "Files loaded successfully."}`` on success.

    Raises:
        HTTPException: 500 with the error text if loading fails. Invalid input
            (e.g. ``batch_size < 1``) is rejected by FastAPI with 422 before this runs.
    """
    # An omitted/null batch_size falls back to 8 instead of being forwarded as None.
    effective_batch_size = batch_size if batch_size is not None else 8
    try:
        load_from_local_files(
            paths_or_directory=paths,
            collection_name=collection_name,
            collection_description=collection_description,
            batch_size=effective_batch_size,
        )
    except Exception as e:
        # Chain the cause so server-side logs keep the original traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"message": "Files loaded successfully."}
@app.post("/load-website/")
def load_website(
    urls: str | list[str] = Body(
        ...,
        description="A list of URLs of websites to be loaded.",
        examples=["https://milvus.io/docs/overview.md"],
    ),
    collection_name: str | None = Body(
        None,
        description="Optional name for the collection.",
        examples=["my_collection"],
    ),
    collection_description: str | None = Body(
        None,
        description="Optional description for the collection.",
        examples=["This is a test collection."],
    ),
    batch_size: int | None = Body(
        None,
        ge=1,
        description="Optional batch size for processing. Defaults to 8 when omitted.",
        examples=[256],
    ),
):
    """
    Load website content into the vector database.

    Args:
        urls (str | list[str]): A single URL or a list of website URLs to load.
        collection_name (str | None): Name for the collection; ``None`` is forwarded
            as-is to ``load_from_website``.
        collection_description (str | None): Description for the collection; ``None``
            is forwarded as-is to ``load_from_website``.
        batch_size (int | None): Batch size for processing, must be >= 1. When omitted
            or null, 8 is used.

    Returns:
        dict: ``{"message": "Website loaded successfully."}`` on success.

    Raises:
        HTTPException: 500 with the error text if loading fails. Invalid input
            (e.g. ``batch_size < 1``) is rejected by FastAPI with 422 before this runs.
    """
    # An omitted/null batch_size falls back to 8 instead of being forwarded as None.
    effective_batch_size = batch_size if batch_size is not None else 8
    try:
        load_from_website(
            urls=urls,
            collection_name=collection_name,
            collection_description=collection_description,
            batch_size=effective_batch_size,
        )
    except Exception as e:
        # Chain the cause so server-side logs keep the original traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"message": "Website loaded successfully."}
@app.get("/query/")
def perform_query(
    original_query: str = Query(
        ...,
        description="Your question here.",
        examples=["Write a report about Milvus."],
    ),
    max_iter: int = Query(
        3,
        description="The maximum number of iterations for reflection.",
        ge=1,
        examples=[3],
    ),
):
    """
    Perform a query against the loaded data.

    The message stream is cleared first, so the returned messages are those recorded
    while this query ran.

    Args:
        original_query (str): The user's question or query.
        max_iter (int, optional): Maximum number of iterations for reflection (>= 1).
            Defaults to 3.

    Returns:
        dict: ``{"result": <answer text>, "messages": <message_stream.get_messages_as_dicts()>}``.

    Raises:
        HTTPException: 500 with the error text if the query or message-stream access fails.
    """
    try:
        # Clear previous messages.
        # NOTE(review): get_message_stream() looks like a shared accessor; if it returns a
        # process-wide instance, concurrent queries would clear/mix each other's messages — confirm.
        message_stream = get_message_stream()
        message_stream.clear_messages()
        # Assumes query() returns exactly two values; the second one is discarded here.
        result_text, _ = query(original_query, max_iter)
        return {
            "result": result_text,
            "messages": message_stream.get_messages_as_dicts(),
        }
    except Exception as e:
        # Chain the cause so server-side logs keep the original traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/messages/")
def get_messages():
    """
    Return the current contents of the message stream.

    Returns:
        dict: ``{"messages": ...}`` holding the result of ``get_messages_as_dicts()``.

    Raises:
        HTTPException: 500 with the error text if the stream cannot be read.
    """
    try:
        snapshot = get_message_stream().get_messages_as_dicts()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"messages": snapshot}
@app.post("/clear-messages/")
def clear_messages():
    """
    Empty the message stream.

    Returns:
        dict: A confirmation ``message``.

    Raises:
        HTTPException: 500 with the error text if clearing fails.
    """
    try:
        get_message_stream().clear_messages()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"message": "Messages cleared successfully"}
if __name__ == "__main__":
    # Command-line options used when this module is run directly as a server.
    cli = argparse.ArgumentParser(description="FastAPI Server")
    cli.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to")
    cli.add_argument("--port", type=int, default=8000, help="Port to bind the server to")
    cli.add_argument("--enable-cors", action="store_true", default=False, help="Enable CORS support")
    options = cli.parse_args()

    if not options.enable_cors:
        print("CORS is disabled.")
    else:
        # NOTE(review): wildcard origins combined with allow_credentials=True may allow
        # credentialed cross-origin requests from any site — confirm this is intended.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        print("CORS is enabled.")

    print(f"Starting server on {options.host}:{options.port}")
    uvicorn.run(app, host=options.host, port=options.port)