
feat: add a web access endpoint for local files

main
tanxing, 5 days ago
commit ffc97396dc
  1. deepsearcher/agent/deep_search.py (9 changed lines)
  2. deepsearcher/configuration.py (13 changed lines)
  3. deepsearcher/offline_loading.py (50 changed lines)
  4. main.py (59 changed lines)
  5. test.py (8 changed lines)

deepsearcher/agent/deep_search.py (9 changed lines)

@@ -463,9 +463,14 @@ class DeepSearch(BaseAgent):
            # Web search results use the URL directly
            formated_refs.append(f"[^{i + 1}]: {reference}\n")
        else:
            # Local files use a file URI
            # Local files use the server's file-access endpoint
            try:
                formated_refs.append(f"[^{i + 1}]: " + Path(str(Path(reference).resolve())).as_uri() + "\n")
                # Get the absolute path and URL-encode it
                import urllib.parse
                absolute_path = str(Path(reference).resolve())
                encoded_path = urllib.parse.quote(absolute_path, safe='')
                # Use a server-relative link so it works under different server configurations
                formated_refs.append(f"[^{i + 1}]: /file/{encoded_path}\n")
            except Exception as _:
                formated_refs.append(f"[^{i + 1}]: {reference}\n")
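For illustration only (not part of the commit): a minimal standalone sketch of what the new footnote formatting produces. The /file/ prefix and the safe='' quoting are taken from the hunk above; the reference path is hypothetical.

# Standalone sketch; "./data/report.pdf" is a made-up local reference
import urllib.parse
from pathlib import Path

reference = "./data/report.pdf"
absolute_path = str(Path(reference).resolve())              # e.g. /home/user/data/report.pdf
encoded_path = urllib.parse.quote(absolute_path, safe='')   # %2Fhome%2Fuser%2Fdata%2Freport.pdf
print(f"[^1]: /file/{encoded_path}")                        # [^1]: /file/%2Fhome%2Fuser%2Fdata%2Freport.pdf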

deepsearcher/configuration.py (13 changed lines)

@@ -12,7 +12,7 @@ from deepsearcher.vector_db.base import BaseVectorDB
current_dir = os.path.dirname(os.path.abspath(__file__))
DEFAULT_CONFIG_YAML_PATH = os.path.join(current_dir, "config.yaml")
FeatureType = Literal["llm", "embedding", "file_loader", "web_crawler", "vector_db"]
FeatureType = Literal["llm", "embedding", "file_loader", "vector_db"]
class Configuration:
@@ -55,7 +55,7 @@ class Configuration:
Set the provider and its configurations for a given feature.
Args:
feature: The feature to configure (e.g., 'llm', 'file_loader', 'web_crawler').
feature: The feature to configure (e.g., 'llm', 'file_loader').
provider: The provider name (e.g., 'openai', 'deepseek').
provider_configs: A dictionary with configurations specific to the provider.
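A usage sketch of the method documented above (assuming it is Configuration.set_provider_config; the "model" value is an assumed example, not taken from this commit):

# Hypothetical provider setup; the config keys depend on the chosen provider
from deepsearcher.configuration import Configuration, init_config

config = Configuration()
config.set_provider_config("llm", "openai", {"model": "gpt-4o-mini"})  # assumed example config
init_config(config=config)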
@@ -73,7 +73,7 @@ class Configuration:
Get the current provider and configuration for a given feature.
Args:
feature: The feature to retrieve (e.g., 'llm', 'file_loader', 'web_crawler').
feature: The feature to retrieve (e.g., 'llm', 'file_loader').
Returns:
A dictionary with provider and its configurations.
@@ -91,7 +91,7 @@ class ModuleFactory:
"""
Factory class for creating instances of various modules in the DeepSearcher system.
This class creates instances of LLMs, embedding models, file loaders, web crawlers,
This class creates instances of LLMs, embedding models, file loaders
and vector databases based on the configuration settings.
"""
@@ -175,7 +175,7 @@ def init_config(config: Configuration):
Initialize the global configuration and create instances of all required modules.
This function initializes the global variables for the LLM, embedding model,
file loader, web crawler, vector database, and RAG agents.
file loader, vector database, and RAG agents.
Args:
config: The Configuration object to use for initialization.
@@ -199,6 +199,5 @@ def init_config(config: Configuration):
vector_db=vector_db,
max_iter=config.query_settings["max_iter"],
route_collection=False,
text_window_splitter=True,
enable_web_search=config.query_settings.get("enable_web_search", True),
text_window_splitter=True
)

deepsearcher/offline_loading.py (50 changed lines)

@@ -80,53 +80,3 @@ def load_from_local_files(
    unique_chunks = embedding_model.embed_chunks(unique_chunks, batch_size=batch_size)
    vector_db.insert_data(collection=collection_name, chunks=unique_chunks)


def load_from_website(
    urls: str | list[str],
    collection_name: str = None,
    collection_description: str = None,
    force_rebuild: bool = False,
    chunk_size: int = 1500,
    chunk_overlap: int = 100,
    batch_size: int = 256,
    **crawl_kwargs,
):
    """
    Load knowledge from websites into the vector database.

    This function crawls the specified URLs, processes the content,
    splits it into chunks, embeds the chunks, and stores them in the vector database.

    Args:
        urls: A single URL or a list of URLs to crawl.
        collection_name: Name of the collection to store the data in. If None, uses the default collection.
        collection_description: Description of the collection. If None, no description is set.
        force_rebuild: If True, drops the existing collection and creates a new one.
        chunk_size: Size of each chunk in characters.
        chunk_overlap: Number of characters to overlap between chunks.
        batch_size: Number of chunks to process at once during embedding.
        **crawl_kwargs: Additional keyword arguments to pass to the web crawler.
    """
    if isinstance(urls, str):
        urls = [urls]
    vector_db = configuration.vector_db
    embedding_model = configuration.embedding_model
    web_crawler = configuration.web_crawler

    vector_db.init_collection(
        dim=embedding_model.dimension,
        collection=collection_name,
        description=collection_description,
        force_rebuild=force_rebuild,
    )
    all_docs = web_crawler.crawl_urls(urls, **crawl_kwargs)
    chunks = split_docs_to_chunks(
        all_docs,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    chunks = embedding_model.embed_chunks(chunks, batch_size=batch_size)
    vector_db.insert_data(collection=collection_name, chunks=chunks)

main.py (59 changed lines)

@@ -3,7 +3,7 @@ import argparse
import uvicorn
from fastapi import Body, FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.responses import HTMLResponse, StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import os
@@ -375,6 +375,63 @@ def clear_messages():
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/file/{file_path:path}")
def serve_file(file_path: str):
    """
    Serve local files referenced in generated reports.

    This endpoint allows accessing local files that are referenced in the
    generated reports. The file_path parameter should be the URL-encoded
    path to the file.

    Args:
        file_path (str): The URL-encoded file path

    Returns:
        FileResponse: The file content or an error response

    Raises:
        HTTPException: If the file is not found or access is denied
    """
    import urllib.parse
    from pathlib import Path

    try:
        # URL-decode the file path
        decoded_path = urllib.parse.unquote(file_path)
        # Convert it to a Path object
        file_path_obj = Path(decoded_path)

        # Safety check: only absolute file paths are allowed
        if not file_path_obj.is_absolute():
            raise HTTPException(status_code=400, detail="Only absolute file paths are allowed")

        # Safety check: the file must exist
        if not file_path_obj.exists():
            raise HTTPException(status_code=404, detail=f"File not found: {decoded_path}")

        # Safety check: the path must be a file, not a directory
        if not file_path_obj.is_file():
            raise HTTPException(status_code=400, detail=f"Path is not a file: {decoded_path}")

        # Try to read the file and return it
        try:
            return FileResponse(
                path=str(file_path_obj),
                filename=file_path_obj.name,
                media_type='application/octet-stream'
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error reading file: {str(e)}")

    except HTTPException:
        # Re-raise HTTP exceptions unchanged
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing file request: {str(e)}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FastAPI Server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind the server to")
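As a usage note (not part of the commit), a client-side sketch of fetching a referenced file through the new endpoint. The host, port, and file path are assumptions for illustration; the encoding mirrors the link format generated in deep_search.py above.

# Client-side sketch; http://localhost:8000 and the path are assumed values
import urllib.parse
import urllib.request

local_path = "/home/user/data/report.pdf"                   # hypothetical absolute path from a report footnote
url = "http://localhost:8000/file/" + urllib.parse.quote(local_path, safe='')

with urllib.request.urlopen(url) as resp:                   # the server responds with application/octet-stream
    data = resp.read()
print(f"downloaded {len(data)} bytes")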

test.py (8 changed lines)

@@ -5,8 +5,7 @@ from deepsearcher.online_query import query
config = Configuration()
# Customize your config here,
# for more configuration options, see the Configuration Details section below.
config.load_config_from_yaml("deepsearcher/config.yaml")
init_config(config = config)
@@ -19,10 +18,5 @@ load_from_local_files(
force_rebuild=True, batch_size=16
)
# (Optional) Load from web crawling (`FIRECRAWL_API_KEY` env variable required)
# from deepsearcher.offline_loading import load_from_website
# load_from_website(urls=website_url)
# Query
result = query("Write a comprehensive report about Milvus.", max_iter=1) # Your question here
