|
|
|
import argparse
|
|
|
|
|
|
|
|
import uvicorn
|
|
|
|
from fastapi import Body, FastAPI, HTTPException, Query
|
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from fastapi.responses import HTMLResponse, StreamingResponse, FileResponse, PlainTextResponse
|
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
from pydantic import BaseModel
|
|
|
|
import os
|
|
|
|
import asyncio
|
|
|
|
import json
|
|
|
|
import queue
|
|
|
|
from collections.abc import AsyncGenerator
|
|
|
|
|
|
|
|
from deepsearcher.configuration import Configuration, init_config
|
|
|
|
from deepsearcher.offline_loading import load_from_local_files
|
|
|
|
from deepsearcher.online_query import query
|
|
|
|
from deepsearcher.utils.message_stream import get_message_stream
|
|
|
|
|
|
|
|
app = FastAPI()

# Serve static assets from deepsearcher/templates/static (resolved relative to
# this file) under the /static URL prefix.
current_dir = os.path.dirname(os.path.abspath(__file__))
static_dir = os.path.join(current_dir, "deepsearcher", "templates", "static")
app.mount("/static", StaticFiles(directory=static_dir), name="static")

# Module-wide configuration object, initialised once at import time.
# /set-provider-config/ mutates this same object and calls init_config() again.
config = Configuration()

init_config(config)

# Global registry of message-stream callbacks registered by /stream-messages/
# clients; each entry is removed again when that client's stream ends.
message_callbacks = []
|
|
|
|
|
|
|
|
|
|
|
|
class ProviderConfigRequest(BaseModel):
    """
    Request body for POST /set-provider-config/.

    The three fields are passed positionally to
    ``Configuration.set_provider_config(feature, provider, config)``.

    Attributes:
        feature (str): The feature to configure (e.g., 'embedding', 'llm').
        provider (str): The provider name (e.g., 'openai', 'azure').
        config (dict): Configuration parameters for the provider.
    """

    # Name of the feature slot being configured.
    feature: str
    # Provider implementation to select for that feature.
    provider: str
    # Free-form provider settings; only checked to be a dict here, the
    # per-provider keys are not validated by this model.
    config: dict
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """
    Return the web UI entry page (templates/html/index.html) as HTML.

    Raises:
        HTTPException: 404 when the index template does not exist.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    index_file = os.path.join(base_dir, "deepsearcher", "templates", "html", "index.html")

    try:
        with open(index_file, encoding="utf-8") as handle:
            page = handle.read()
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="Template file not found")

    return HTMLResponse(content=page, status_code=200)
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/set-provider-config/")
def set_provider_config(request: ProviderConfigRequest):
    """
    Switch the provider used for one feature and re-initialise the modules.

    The settings are written into the module-level ``config`` object and
    ``init_config`` is re-run so the change takes effect for later requests.

    Args:
        request (ProviderConfigRequest): Feature name, provider name and settings.

    Returns:
        dict: A confirmation message plus the provider and config that were applied.
        NOTE(review): the submitted config is echoed back verbatim and may
        contain secrets such as API keys.

    Raises:
        HTTPException: 500 if updating or re-initialising the configuration fails.
    """
    feature, provider, provider_settings = request.feature, request.provider, request.config

    try:
        config.set_provider_config(feature, provider, provider_settings)
        init_config(config)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Failed to set provider config: {str(exc)}")

    return {
        "message": "Provider config set successfully",
        "provider": provider,
        "config": provider_settings,
    }
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/load-files/")
def load_files(
    paths: str | list[str] = Body(
        ...,
        description="A list of file paths to be loaded.",
        examples=["/path/to/file1", "/path/to/file2", "/path/to/dir1"],
    ),
    collection_name: str | None = Body(
        None,
        description="Optional name for the collection.",
        examples=["my_collection"],
    ),
    collection_description: str | None = Body(
        None,
        description="Optional description for the collection.",
        examples=["This is a test collection."],
    ),
    batch_size: int | None = Body(
        None,
        description="Optional batch size for the collection.",
        examples=[256],
    ),
    force_rebuild: bool = Body(
        True,
        description="Whether to force rebuild the collection if it already exists.",
        examples=[False],
    ),
):
    """
    Load files into the vector database.

    Args:
        paths (str | list[str]): A single path or a list of file paths / directories to load.
        collection_name (str | None, optional): Name for the collection. Defaults to None.
        collection_description (str | None, optional): Description for the collection.
            Defaults to None.
        batch_size (int | None, optional): Batch size for processing. When omitted
            (or null) a batch size of 8 is used.
        force_rebuild (bool, optional): Whether to force rebuild the collection if it
            already exists. Defaults to True.

    Returns:
        dict: A dictionary containing a success message.

    Raises:
        HTTPException: 500 if loading files fails (the underlying error text is
            returned as the detail).
    """
    # The optional fields are typed ``| None`` so that an explicit JSON null is
    # accepted and treated the same as omitting the field.
    try:
        load_from_local_files(
            paths_or_directory=paths,
            collection_name=collection_name,
            collection_description=collection_description,
            # 8 is this endpoint's effective default batch size.
            batch_size=batch_size if batch_size is not None else 8,
            force_rebuild=force_rebuild,
        )
        return {"message": "成功加载"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/query-stream/")
async def perform_query_stream(
    original_query: str = Query(
        ...,
        description="Your question here.",
        examples=["Write a report about Milvus."],
    ),
    max_iter: int = Query(
        3,
        description="The maximum number of iterations for reflection.",
        ge=1,
        examples=[3],
    ),
) -> StreamingResponse:
    """
    Perform a query with real-time streaming of progress and results.

    The blocking ``query`` call runs on a dedicated worker thread while this
    coroutine relays progress messages to the client as Server-Sent Events,
    without ever blocking the event loop. Frames emitted directly here have
    ``type`` ``start``, ``answer``, ``complete`` or ``error``; messages from
    the message stream are forwarded with their own type, timestamp and
    metadata.

    Args:
        original_query (str): The user's question or query.
        max_iter (int, optional): Maximum number of iterations for reflection. Defaults to 3.

    Returns:
        StreamingResponse: A ``text/event-stream`` response with real-time query
        progress and results.
    """
    import concurrent.futures

    def _sse(payload: dict) -> str:
        """Encode one payload as an SSE ``data:`` frame."""
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    async def query_stream_generator() -> AsyncGenerator[str, None]:
        """Yield SSE frames for the query's progress, answer and completion."""
        message_stream = None
        message_callback = None
        executor = None
        try:
            # Drop messages left over from earlier queries.
            # NOTE(review): get_message_stream() looks like a shared instance; if
            # so, concurrent queries clear and observe each other's messages.
            message_stream = get_message_stream()
            message_stream.clear_messages()

            yield _sse({'type': 'start', 'content': '开始查询'})

            loop = asyncio.get_running_loop()
            # Callbacks fire on the worker thread; call_soon_threadsafe hands each
            # message to the event loop so we can await it instead of blocking.
            message_queue: asyncio.Queue = asyncio.Queue()
            query_done = object()  # sentinel queued once the query has finished

            def message_callback(message):
                """Forward a message-stream message to this request's queue."""
                try:
                    loop.call_soon_threadsafe(message_queue.put_nowait, message)
                except Exception as e:
                    print(f"Error in message callback: {e}")

            message_stream.add_callback(message_callback)

            def run_query():
                """Run the blocking query; return (result_text, error_message)."""
                try:
                    print(f"Starting query: {original_query} with max_iter: {max_iter}")
                    result_text, retrieval_results = query(original_query, max_iter=max_iter)
                    print(f"Query completed with result length: {len(result_text) if result_text else 0}")
                    print(f"Retrieved {len(retrieval_results) if retrieval_results else 0} documents")
                    return result_text, None
                except Exception as e:
                    import traceback

                    print(f"Query failed with error: {e}")
                    print(f"Traceback: {traceback.format_exc()}")
                    return None, str(e)

            executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            query_future = loop.run_in_executor(executor, run_query)
            # The future's done-callback runs on the loop after every message the
            # worker scheduled before returning, so the sentinel is always last
            # and no progress message is lost.
            query_future.add_done_callback(lambda _f: message_queue.put_nowait(query_done))

            try:
                while True:
                    message = await message_queue.get()
                    if message is query_done:
                        break
                    yield _sse({
                        'type': message.type.value,
                        'content': message.content,
                        'timestamp': message.timestamp,
                        'metadata': message.metadata or {},
                    })

                result_text, error = query_future.result()
                if error:
                    print(f"Query error: {error}")
                    yield _sse({'type': 'error', 'content': f'查询失败: {error}'})
                    return

                # On success, send the answer first (if any), then the completion marker.
                if result_text:
                    print(f"Sending answer message with result length: {len(result_text)}")
                    yield _sse({'type': 'answer', 'content': result_text})
                print("Query completed successfully, sending complete message")
                yield _sse({'type': 'complete', 'content': '查询完成'})
                print("Complete message sent, ending stream")
                return
            except Exception as e:
                print(f"Error in stream loop: {e}")
                yield _sse({'type': 'error', 'content': f'流处理错误: {str(e)}'})
                return

        except Exception as e:
            yield _sse({'type': 'error', 'content': f'查询错误: {str(e)}'})
        finally:
            if message_callback is not None and message_stream is not None:
                message_stream.remove_callback(message_callback)
            if executor is not None:
                # Never wait here: on client disconnect this would block the event
                # loop until the (possibly long) query finishes. The worker thread
                # completes on its own and its result is discarded.
                executor.shutdown(wait=False)

    return StreamingResponse(
        query_stream_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control",
        },
    )
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/messages/")
def get_messages():
    """
    Return every message currently held by the message stream.

    Returns:
        dict: ``{"messages": [...]}`` with each message serialised as a dict.

    Raises:
        HTTPException: 500 if the message stream cannot be read.
    """
    try:
        snapshot = get_message_stream().get_messages_as_dicts()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"messages": snapshot}
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/stream-messages/")
async def stream_messages() -> StreamingResponse:
    """
    Stream messages in real-time using Server-Sent Events.

    After an initial ``connection`` frame, every message published to the
    message stream is forwarded to the client. A ``heartbeat`` frame is sent
    whenever no message arrives for 2 seconds.

    Returns:
        StreamingResponse: A ``text/event-stream`` response with real-time messages.
    """

    def _sse(payload: dict) -> str:
        """Encode one payload as an SSE ``data:`` frame."""
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    async def event_generator() -> AsyncGenerator[str, None]:
        """Yield SSE frames for message-stream messages, plus periodic heartbeats."""
        loop = asyncio.get_running_loop()
        message_queue: asyncio.Queue = asyncio.Queue()

        def message_callback(message):
            """Hand a message to this client's queue; safe to call from any thread."""
            try:
                # Messages may be published from a worker thread (e.g. while
                # /query-stream/ runs a query on its thread pool). There,
                # asyncio.create_task() fails with "no running event loop", so
                # schedule the put on this client's loop instead.
                loop.call_soon_threadsafe(message_queue.put_nowait, message)
            except Exception as e:
                print(f"Error in message callback: {e}")

        message_stream = get_message_stream()
        # Register the callback both globally and on the message stream.
        message_callbacks.append(message_callback)
        message_stream.add_callback(message_callback)

        try:
            yield _sse({'type': 'connection', 'message': 'Connected to message stream'})

            while True:
                try:
                    # Bounded wait so an idle connection still gets heartbeats.
                    message = await asyncio.wait_for(message_queue.get(), timeout=2.0)
                except asyncio.TimeoutError:
                    yield _sse({'type': 'heartbeat', 'timestamp': loop.time()})
                    continue

                yield _sse({
                    'type': message.type.value,
                    'content': message.content,
                    'timestamp': message.timestamp,
                    'metadata': message.metadata or {},
                })

        except Exception as e:
            print(f"Error in event generator: {e}")
        finally:
            # Unregister so disconnected clients stop receiving messages.
            if message_callback in message_callbacks:
                message_callbacks.remove(message_callback)
            message_stream.remove_callback(message_callback)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control",
        },
    )
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/clear-messages/")
def clear_messages():
    """
    Remove every message from the message stream.

    Returns:
        dict: A confirmation message.

    Raises:
        HTTPException: 500 if the message stream cannot be cleared.
    """
    try:
        get_message_stream().clear_messages()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"message": "Messages cleared successfully"}
|
|
|
|
|
|
|
|
|
|
|
|
def _read_text_file(path) -> str:
    """Read *path* as UTF-8 (falling back to Latin-1); raise HTTP 500 if unreadable."""
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return f.read()
    except UnicodeDecodeError:
        # UTF-8 decoding failed; Latin-1 maps every byte, so this only fails on I/O errors.
        try:
            with open(path, 'r', encoding='latin-1') as f:
                return f.read()
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error reading file: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error reading file: {str(e)}")


def _file_preview_html(path) -> str:
    """Return inner HTML previewing the first 1024 bytes of *path*, or a notice."""
    import codecs
    import html

    try:
        with open(path, 'rb') as f:
            sample = f.read(1024)
    except Exception:
        return """
            <div class="binary-notice">
                <strong>注意:</strong>无法读取文件内容。
            </div>
            """

    # A NUL byte in the sample is treated as the mark of a binary file.
    if b'\x00' in sample:
        return """
            <div class="binary-notice">
                <strong>注意:</strong>这是一个二进制文件,无法在浏览器中直接显示内容。
            </div>
            """

    try:
        # final=False: a multi-byte character cut off at the 1024-byte boundary
        # is held back instead of making a valid UTF-8 file look like non-text.
        text_content = codecs.getincrementaldecoder('utf-8')().decode(sample, final=False)
    except UnicodeDecodeError:
        return """
            <div class="binary-notice">
                <strong>注意:</strong>此文件包含非文本内容,无法在浏览器中直接显示。
            </div>
            """

    # File contents are untrusted: escape them so they cannot inject markup/script.
    return f"""
            <pre>{html.escape(text_content)}</pre>
            """


def _render_file_viewer(file_path_obj, decoded_path: str, file_path: str) -> str:
    """Build the full HTML viewer page for a file that is not served as text/*."""
    import html
    import urllib.parse

    # The path comes from the request URL: HTML-escape everything shown on the
    # page, and URL-quote the download link so '#', '?', spaces etc. survive.
    # quote() round-trips: the server decodes it back to exactly file_path.
    name = html.escape(file_path_obj.name)
    shown_path = html.escape(decoded_path)
    download_href = "/file/" + urllib.parse.quote(file_path) + "?download=true"

    html_content = f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>文件查看器 - {name}</title>
    <style>
        body {{
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            margin: 0;
            padding: 20px;
            background-color: #f5f5f5;
        }}
        .container {{
            max-width: 1200px;
            margin: 0 auto;
            background-color: white;
            border-radius: 8px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
            overflow: hidden;
        }}
        .header {{
            background-color: #2c3e50;
            color: white;
            padding: 15px 20px;
            display: flex;
            justify-content: space-between;
            align-items: center;
        }}
        .header h1 {{
            margin: 0;
            font-size: 1.5em;
        }}
        .file-info {{
            font-size: 0.9em;
            opacity: 0.8;
        }}
        .content {{
            padding: 20px;
        }}
        .download-btn {{
            background-color: #3498db;
            color: white;
            border: none;
            padding: 8px 16px;
            border-radius: 4px;
            cursor: pointer;
            text-decoration: none;
            display: inline-block;
            margin-left: 10px;
        }}
        .download-btn:hover {{
            background-color: #2980b9;
        }}
        pre {{
            background-color: #f8f9fa;
            border: 1px solid #e9ecef;
            border-radius: 4px;
            padding: 15px;
            overflow-x: auto;
            white-space: pre-wrap;
            word-wrap: break-word;
            font-family: 'Courier New', monospace;
            font-size: 14px;
            line-height: 1.5;
        }}
        .binary-notice {{
            background-color: #fff3cd;
            border: 1px solid #ffeaa7;
            border-radius: 4px;
            padding: 15px;
            margin-bottom: 20px;
            color: #856404;
        }}
    </style>
</head>
<body>
    <div class="container">
        <div class="header">
            <div>
                <h1>{name}</h1>
                <div class="file-info">
                    路径: {shown_path}<br>
                    大小: {file_path_obj.stat().st_size:,} 字节
                </div>
            </div>
            <a href="{download_href}" class="download-btn">下载文件</a>
        </div>
        <div class="content">
"""
    html_content += _file_preview_html(file_path_obj)
    html_content += """
        </div>
    </div>
</body>
</html>
"""
    return html_content


@app.get("/file/{file_path:path}")
def serve_file(file_path: str, download: bool = Query(False, description="Whether to download the file")):
    """
    Serve local files for file:// URIs in generated reports.

    This endpoint allows accessing local files that are referenced in the
    generated reports. The file_path parameter should be the URL-encoded
    path to the file.

    Behaviour:
        * ``download=true``: the raw file is returned as an attachment.
        * MIME type ``text/*``: the whole file is returned as text with that type.
        * anything else: an HTML viewer page with file info and a preview of the
          first 1024 bytes (or a notice for binary/unreadable content).

    NOTE(review): any absolute path readable by the server process can be fetched
    here, and the CLI binds 0.0.0.0 by default. Consider restricting requests to
    the directories that were loaded. Likewise, ``text/html`` files are served as
    ``text/html`` and rendered by the browser on this origin; confirm that is intended.

    Args:
        file_path (str): The URL-encoded file path
        download (bool): Whether to return the file as a download.

    Returns:
        FileResponse, PlainTextResponse or HTMLResponse: The file (or a viewer page).

    Raises:
        HTTPException: 400 for relative paths or non-files, 404 if the file does
            not exist, 500 if it cannot be read or processed.
    """
    import mimetypes
    import urllib.parse
    from pathlib import Path

    try:
        decoded_path = urllib.parse.unquote(file_path)
        file_path_obj = Path(decoded_path)

        if not file_path_obj.is_absolute():
            raise HTTPException(status_code=400, detail="Only absolute file paths are allowed")
        if not file_path_obj.exists():
            raise HTTPException(status_code=404, detail=f"File not found: {decoded_path}")
        if not file_path_obj.is_file():
            raise HTTPException(status_code=400, detail=f"Path is not a file: {decoded_path}")

        if download:
            return FileResponse(
                path=str(file_path_obj),
                filename=file_path_obj.name,
                media_type='application/octet-stream',
            )

        mime_type, _ = mimetypes.guess_type(str(file_path_obj))
        if mime_type and mime_type.startswith('text/'):
            # Only text files are read in full; other types only need a 1 KiB sample.
            content = _read_text_file(file_path_obj)
            return PlainTextResponse(content=content, media_type=mime_type)

        return HTMLResponse(content=_render_file_viewer(file_path_obj, decoded_path, file_path))

    except HTTPException:
        # Already carries the intended status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing file request: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, optionally enable CORS, run uvicorn.
    cli = argparse.ArgumentParser(description="FastAPI Server")
    cli.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to")
    cli.add_argument("--port", type=int, default=8000, help="Port to bind the server to")
    cli.add_argument("--enable-cors", action="store_true", default=False, help="Enable CORS support")
    options = cli.parse_args()

    if not options.enable_cors:
        print("CORS is disabled.")
    else:
        # The middleware is only installed when launched through this script,
        # not when the app is served by an external ASGI runner.
        # NOTE(review): wildcard origins combined with allow_credentials=True
        # effectively trusts every origin with credentials; confirm this is intended.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        print("CORS is enabled.")

    print(f"Starting server on {options.host}:{options.port}")
    uvicorn.run(app, host=options.host, port=options.port)
|