Browse Source

更换provider,修复语法错误

main
tanxing 1 week ago
parent
commit
2b20787e50
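  1. deepsearcher/config.yaml (12 lines changed)
  2. deepsearcher/embedding/base.py (8 lines changed)
  3. deepsearcher/embedding/openai_embedding.py (11 lines changed)
  4. deepsearcher/llm/openai_llm.py (2 lines changed)
  5. deepsearcher/offline_loading.py (9 lines changed)
  6. deepsearcher/online_query.py (50 lines changed)
  7. main.py (33 lines changed)
  8. test.py (2 lines changed)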

12
deepsearcher/config.yaml

@ -2,16 +2,16 @@ provide_settings:
llm: llm:
provider: "OpenAILLM" provider: "OpenAILLM"
config: config:
model: "Qwen/Qwen3-8B-FP8" model: "Qwen/Qwen3-30B-A3B"
api_key: "empty" api_key: "sk-fpzwvagjkhwysjsozfybvtjzongatcwqdihdxzuijnfdrjzt"
base_url: "http://localhost:8000/v1" base_url: "https://api.siliconflow.cn/v1"
embedding: embedding:
provider: "OpenAIEmbedding" provider: "OpenAIEmbedding"
config: config:
model: "Qwen/Qwen3-Embedding-0.6B" model: "Qwen/Qwen3-Embedding-0.6B"
api_key: "empty" api_key: "sk-fpzwvagjkhwysjsozfybvtjzongatcwqdihdxzuijnfdrjzt"
base_url: "http://localhost:8001/v1" base_url: "https://api.siliconflow.cn/v1"
dimension: 1024 dimension: 1024
dim_change: false dim_change: false

8
deepsearcher/embedding/base.py

@ -1,5 +1,3 @@
from typing import List
from tqdm import tqdm from tqdm import tqdm
from deepsearcher.loader.splitter import Chunk from deepsearcher.loader.splitter import Chunk
@ -14,7 +12,7 @@ class BaseEmbedding:
for the dimensionality of the embeddings. for the dimensionality of the embeddings.
""" """
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> list[float]:
""" """
Embed a single query text. Embed a single query text.
@ -26,7 +24,7 @@ class BaseEmbedding:
""" """
pass pass
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
""" """
Embed a list of document texts. Embed a list of document texts.
@ -41,7 +39,7 @@ class BaseEmbedding:
""" """
return [self.embed_query(text) for text in texts] return [self.embed_query(text) for text in texts]
def embed_chunks(self, chunks: List[Chunk], batch_size: int = 256) -> List[Chunk]: def embed_chunks(self, chunks: list[Chunk], batch_size: int = 256) -> list[Chunk]:
""" """
Embed a list of Chunk objects. Embed a list of Chunk objects.

11
deepsearcher/embedding/openai_embedding.py

@ -1,5 +1,4 @@
import os import os
from typing import List
from openai import OpenAI from openai import OpenAI
from openai._types import NOT_GIVEN from openai._types import NOT_GIVEN
@ -45,20 +44,20 @@ class OpenAIEmbedding(BaseEmbedding):
model = kwargs.pop("model_name") model = kwargs.pop("model_name")
if "dimension" in kwargs: if "dimension" in kwargs:
dimension = kwargs.pop("dimension") dimension = kwargs.pop("dimension")
else: else:
dimension = NOT_GIVEN dimension = NOT_GIVEN
if "dim_change" in kwargs: if "dim_change" in kwargs:
dim_change = kwargs.pop("dim_change") dim_change = kwargs.pop("dim_change")
self.dim = dimension self.dim = dimension
self.dim_change = dim_change self.dim_change = dim_change
self.model = model self.model = model
self.client = OpenAI(api_key=api_key, base_url=base_url, **kwargs) self.client = OpenAI(api_key=api_key, base_url=base_url, **kwargs)
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> list[float]:
""" """
Embed a single query text. Embed a single query text.
@ -75,7 +74,7 @@ class OpenAIEmbedding(BaseEmbedding):
return response.data[0].embedding return response.data[0].embedding
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
""" """
Embed a list of document texts. Embed a list of document texts.

2
deepsearcher/llm/openai_llm.py

@ -59,7 +59,9 @@ class OpenAILLM(BaseLLM):
for chunk in completion: for chunk in completion:
stream_response = chunk.choices[0].delta.content stream_response = chunk.choices[0].delta.content
if stream_response: if stream_response:
print(stream_response, end="", flush=True)
response += stream_response response += stream_response
if stream_callback: if stream_callback:
stream_callback(stream_response) stream_callback(stream_response)
print("\n")
return response return response

9
deepsearcher/offline_loading.py

@ -1,6 +1,5 @@
import hashlib import hashlib
import os import os
from typing import List, Union
from tqdm import tqdm from tqdm import tqdm
@ -10,7 +9,7 @@ from deepsearcher.loader.splitter import split_docs_to_chunks
def load_from_local_files( def load_from_local_files(
paths_or_directory: Union[str, List[str]], paths_or_directory: str | list[str],
collection_name: str = None, collection_name: str = None,
collection_description: str = None, collection_description: str = None,
force_new_collection: bool = False, force_new_collection: bool = False,
@ -50,7 +49,7 @@ def load_from_local_files(
description=collection_description, description=collection_description,
force_new_collection=force_new_collection, force_new_collection=force_new_collection,
) )
# 如果force_rebuild为True,则强制重建集合 # 如果force_rebuild为True,则强制重建集合
if force_rebuild: if force_rebuild:
vector_db.init_collection( vector_db.init_collection(
@ -59,7 +58,7 @@ def load_from_local_files(
description=collection_description, description=collection_description,
force_new_collection=True, force_new_collection=True,
) )
if isinstance(paths_or_directory, str): if isinstance(paths_or_directory, str):
paths_or_directory = [paths_or_directory] paths_or_directory = [paths_or_directory]
all_docs = [] all_docs = []
@ -92,7 +91,7 @@ def load_from_local_files(
def load_from_website( def load_from_website(
urls: Union[str, List[str]], urls: str | list[str],
collection_name: str = None, collection_name: str = None,
collection_description: str = None, collection_description: str = None,
force_new_collection: bool = False, force_new_collection: bool = False,

50
deepsearcher/online_query.py

@ -1,11 +1,9 @@
from typing import List, Tuple
# from deepsearcher.configuration import vector_db, embedding_model, llm # from deepsearcher.configuration import vector_db, embedding_model, llm
from deepsearcher import configuration from deepsearcher import configuration
from deepsearcher.vector_db.base import RetrievalResult from deepsearcher.vector_db.base import RetrievalResult
def query(original_query: str, max_iter: int = 3) -> Tuple[str, List[RetrievalResult]]: def query(original_query: str, max_iter: int = 3) -> tuple[str, list[RetrievalResult]]:
""" """
Query the knowledge base with a question and get an answer. Query the knowledge base with a question and get an answer.
@ -27,7 +25,7 @@ def query(original_query: str, max_iter: int = 3) -> Tuple[str, List[RetrievalRe
def retrieve( def retrieve(
original_query: str, max_iter: int = 3 original_query: str, max_iter: int = 3
) -> Tuple[List[RetrievalResult], List[str]]: ) -> tuple[list[RetrievalResult], list[str]]:
""" """
Retrieve relevant information from the knowledge base without generating an answer. Retrieve relevant information from the knowledge base without generating an answer.
@ -48,47 +46,3 @@ def retrieve(
original_query, max_iter=max_iter original_query, max_iter=max_iter
) )
return retrieved_results, [] return retrieved_results, []
def naive_retrieve(query: str, collection: str = None, top_k=10) -> List[RetrievalResult]:
"""
Perform a simple retrieval from the knowledge base using the naive RAG approach.
This function uses the naive RAG agent to retrieve information from the knowledge base
without any advanced techniques like iterative refinement.
Args:
query: The question or query to search for.
collection: The name of the collection to search in. If None, searches in all collections.
top_k: The maximum number of results to return.
Returns:
A list of retrieval results.
"""
naive_rag = configuration.naive_rag
all_retrieved_results, _ = naive_rag.retrieve(query)
return all_retrieved_results
def naive_rag_query(
query: str, collection: str = None, top_k=10
) -> Tuple[str, List[RetrievalResult]]:
"""
Query the knowledge base using the naive RAG approach and get an answer.
This function uses the naive RAG agent to query the knowledge base and generate
an answer based on the retrieved information, without any advanced techniques.
Args:
query: The question or query to search for.
collection: The name of the collection to search in. If None, searches in all collections.
top_k: The maximum number of results to consider.
Returns:
A tuple containing:
- The generated answer as a string
- A list of retrieval results that were used to generate the answer
"""
naive_rag = configuration.naive_rag
answer, retrieved_results = naive_rag.query(query)
return answer, retrieved_results

33
main.py

@ -1,5 +1,4 @@
import argparse import argparse
from typing import Dict, List, Union
import uvicorn import uvicorn
from fastapi import Body, FastAPI, HTTPException, Query from fastapi import Body, FastAPI, HTTPException, Query
@ -31,7 +30,7 @@ class ProviderConfigRequest(BaseModel):
feature: str feature: str
provider: str provider: str
config: Dict config: dict
@app.get("/", response_class=HTMLResponse) @app.get("/", response_class=HTMLResponse)
@ -43,15 +42,15 @@ async def read_root():
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
# 构建模板文件路径 - 修复路径问题 # 构建模板文件路径 - 修复路径问题
template_path = os.path.join(current_dir, "deepsearcher", "backend", "templates", "index.html") template_path = os.path.join(current_dir, "deepsearcher", "backend", "templates", "index.html")
# 读取HTML文件内容 # 读取HTML文件内容
try: try:
with open(template_path, "r", encoding="utf-8") as file: with open(template_path, encoding="utf-8") as file:
html_content = file.read() html_content = file.read()
return HTMLResponse(content=html_content, status_code=200) return HTMLResponse(content=html_content, status_code=200)
except FileNotFoundError: except FileNotFoundError:
# 如果找不到文件,提供一个简单的默认页面 # 如果找不到文件,提供一个简单的默认页面
default_html = """ default_html = f"""
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
<head> <head>
@ -71,7 +70,7 @@ async def read_root():
<div class="info"> <div class="info">
<p>欢迎使用 DeepSearcher 智能搜索系统!</p> <p>欢迎使用 DeepSearcher 智能搜索系统!</p>
<p>系统正在运行但未找到前端模板文件</p> <p>系统正在运行但未找到前端模板文件</p>
<p>请确认文件是否存在: {}</p> <p>请确认文件是否存在: {template_path}</p>
</div> </div>
<div class="info"> <div class="info">
<h2>API 接口</h2> <h2>API 接口</h2>
@ -86,7 +85,7 @@ async def read_root():
</div> </div>
</body> </body>
</html> </html>
""".format(template_path) """
return HTMLResponse(content=default_html, status_code=200) return HTMLResponse(content=default_html, status_code=200)
@ -118,7 +117,7 @@ def set_provider_config(request: ProviderConfigRequest):
@app.post("/load-files/") @app.post("/load-files/")
def load_files( def load_files(
paths: Union[str, List[str]] = Body( paths: str | list[str] = Body(
..., ...,
description="A list of file paths to be loaded.", description="A list of file paths to be loaded.",
examples=["/path/to/file1", "/path/to/file2", "/path/to/dir1"], examples=["/path/to/file1", "/path/to/file2", "/path/to/dir1"],
@ -160,7 +159,7 @@ def load_files(
paths_or_directory=paths, paths_or_directory=paths,
collection_name=collection_name, collection_name=collection_name,
collection_description=collection_description, collection_description=collection_description,
batch_size=batch_size if batch_size is not None else 256, # 提供默认值 batch_size=batch_size if batch_size is not None else 8, # 提供默认值
) )
return {"message": "Files loaded successfully."} return {"message": "Files loaded successfully."}
except Exception as e: except Exception as e:
@ -169,7 +168,7 @@ def load_files(
@app.post("/load-website/") @app.post("/load-website/")
def load_website( def load_website(
urls: Union[str, List[str]] = Body( urls: str | list[str] = Body(
..., ...,
description="A list of URLs of websites to be loaded.", description="A list of URLs of websites to be loaded.",
examples=["https://milvus.io/docs/overview.md"], examples=["https://milvus.io/docs/overview.md"],
@ -249,15 +248,15 @@ def perform_query(
# 清除之前的进度消息 # 清除之前的进度消息
from deepsearcher.utils.log import clear_progress_messages from deepsearcher.utils.log import clear_progress_messages
clear_progress_messages() clear_progress_messages()
result_text, _, consume_token = query(original_query, max_iter) result_text, _, consume_token = query(original_query, max_iter)
# 获取进度消息 # 获取进度消息
from deepsearcher.utils.log import get_progress_messages from deepsearcher.utils.log import get_progress_messages
progress_messages = get_progress_messages() progress_messages = get_progress_messages()
return { return {
"result": result_text, "result": result_text,
"consume_token": consume_token, "consume_token": consume_token,
"progress_messages": progress_messages "progress_messages": progress_messages
} }
@ -271,7 +270,7 @@ if __name__ == "__main__":
parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to") parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to")
parser.add_argument("--enable-cors", action="store_true", default=False, help="Enable CORS support") parser.add_argument("--enable-cors", action="store_true", default=False, help="Enable CORS support")
args = parser.parse_args() args = parser.parse_args()
if args.enable_cors: if args.enable_cors:
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@ -283,6 +282,6 @@ if __name__ == "__main__":
print("CORS is enabled.") print("CORS is enabled.")
else: else:
print("CORS is disabled.") print("CORS is disabled.")
print(f"Starting server on {args.host}:{args.port}") print(f"Starting server on {args.host}:{args.port}")
uvicorn.run(app, host=args.host, port=args.port) uvicorn.run(app, host=args.host, port=args.port)

2
test.py

@ -11,7 +11,7 @@ config.load_config_from_yaml("deepsearcher/config.yaml")
init_config(config = config) init_config(config = config)
# Load your local data # Load your local data
load_from_local_files(paths_or_directory="examples/data", force_rebuild=True) load_from_local_files(paths_or_directory="examples/data", force_rebuild=True, batch_size=8)
# (Optional) Load from web crawling (`FIRECRAWL_API_KEY` env variable required) # (Optional) Load from web crawling (`FIRECRAWL_API_KEY` env variable required)
# from deepsearcher.offline_loading import load_from_website # from deepsearcher.offline_loading import load_from_website

Loading…
Cancel
Save