Browse Source

更换provider,修复语法错误

main
tanxing 1 week ago
parent
commit
2b20787e50
  1. 12
      deepsearcher/config.yaml
  2. 8
      deepsearcher/embedding/base.py
  3. 11
      deepsearcher/embedding/openai_embedding.py
  4. 2
      deepsearcher/llm/openai_llm.py
  5. 9
      deepsearcher/offline_loading.py
  6. 50
      deepsearcher/online_query.py
  7. 33
      main.py
  8. 2
      test.py

12
deepsearcher/config.yaml

@ -2,16 +2,16 @@ provide_settings:
llm:
provider: "OpenAILLM"
config:
model: "Qwen/Qwen3-8B-FP8"
api_key: "empty"
base_url: "http://localhost:8000/v1"
model: "Qwen/Qwen3-30B-A3B"
api_key: "sk-fpzwvagjkhwysjsozfybvtjzongatcwqdihdxzuijnfdrjzt"
base_url: "https://api.siliconflow.cn/v1"
embedding:
provider: "OpenAIEmbedding"
config:
config:
model: "Qwen/Qwen3-Embedding-0.6B"
api_key: "empty"
base_url: "http://localhost:8001/v1"
api_key: "sk-fpzwvagjkhwysjsozfybvtjzongatcwqdihdxzuijnfdrjzt"
base_url: "https://api.siliconflow.cn/v1"
dimension: 1024
dim_change: false

8
deepsearcher/embedding/base.py

@ -1,5 +1,3 @@
from typing import List
from tqdm import tqdm
from deepsearcher.loader.splitter import Chunk
@ -14,7 +12,7 @@ class BaseEmbedding:
for the dimensionality of the embeddings.
"""
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""
Embed a single query text.
@ -26,7 +24,7 @@ class BaseEmbedding:
"""
pass
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""
Embed a list of document texts.
@ -41,7 +39,7 @@ class BaseEmbedding:
"""
return [self.embed_query(text) for text in texts]
def embed_chunks(self, chunks: List[Chunk], batch_size: int = 256) -> List[Chunk]:
def embed_chunks(self, chunks: list[Chunk], batch_size: int = 256) -> list[Chunk]:
"""
Embed a list of Chunk objects.

11
deepsearcher/embedding/openai_embedding.py

@ -1,5 +1,4 @@
import os
from typing import List
from openai import OpenAI
from openai._types import NOT_GIVEN
@ -45,20 +44,20 @@ class OpenAIEmbedding(BaseEmbedding):
model = kwargs.pop("model_name")
if "dimension" in kwargs:
dimension = kwargs.pop("dimension")
dimension = kwargs.pop("dimension")
else:
dimension = NOT_GIVEN
if "dim_change" in kwargs:
dim_change = kwargs.pop("dim_change")
self.dim = dimension
self.dim_change = dim_change
self.model = model
self.client = OpenAI(api_key=api_key, base_url=base_url, **kwargs)
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""
Embed a single query text.
@ -75,7 +74,7 @@ class OpenAIEmbedding(BaseEmbedding):
return response.data[0].embedding
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""
Embed a list of document texts.

2
deepsearcher/llm/openai_llm.py

@ -59,7 +59,9 @@ class OpenAILLM(BaseLLM):
for chunk in completion:
stream_response = chunk.choices[0].delta.content
if stream_response:
print(stream_response, end="", flush=True)
response += stream_response
if stream_callback:
stream_callback(stream_response)
print("\n")
return response

9
deepsearcher/offline_loading.py

@ -1,6 +1,5 @@
import hashlib
import os
from typing import List, Union
from tqdm import tqdm
@ -10,7 +9,7 @@ from deepsearcher.loader.splitter import split_docs_to_chunks
def load_from_local_files(
paths_or_directory: Union[str, List[str]],
paths_or_directory: str | list[str],
collection_name: str = None,
collection_description: str = None,
force_new_collection: bool = False,
@ -50,7 +49,7 @@ def load_from_local_files(
description=collection_description,
force_new_collection=force_new_collection,
)
# 如果force_rebuild为True,则强制重建集合
if force_rebuild:
vector_db.init_collection(
@ -59,7 +58,7 @@ def load_from_local_files(
description=collection_description,
force_new_collection=True,
)
if isinstance(paths_or_directory, str):
paths_or_directory = [paths_or_directory]
all_docs = []
@ -92,7 +91,7 @@ def load_from_local_files(
def load_from_website(
urls: Union[str, List[str]],
urls: str | list[str],
collection_name: str = None,
collection_description: str = None,
force_new_collection: bool = False,

50
deepsearcher/online_query.py

@ -1,11 +1,9 @@
from typing import List, Tuple
# from deepsearcher.configuration import vector_db, embedding_model, llm
from deepsearcher import configuration
from deepsearcher.vector_db.base import RetrievalResult
def query(original_query: str, max_iter: int = 3) -> Tuple[str, List[RetrievalResult]]:
def query(original_query: str, max_iter: int = 3) -> tuple[str, list[RetrievalResult]]:
"""
Query the knowledge base with a question and get an answer.
@ -27,7 +25,7 @@ def query(original_query: str, max_iter: int = 3) -> Tuple[str, List[RetrievalRe
def retrieve(
original_query: str, max_iter: int = 3
) -> Tuple[List[RetrievalResult], List[str]]:
) -> tuple[list[RetrievalResult], list[str]]:
"""
Retrieve relevant information from the knowledge base without generating an answer.
@ -48,47 +46,3 @@ def retrieve(
original_query, max_iter=max_iter
)
return retrieved_results, []
def naive_retrieve(query: str, collection: str = None, top_k=10) -> List[RetrievalResult]:
"""
Perform a simple retrieval from the knowledge base using the naive RAG approach.
This function uses the naive RAG agent to retrieve information from the knowledge base
without any advanced techniques like iterative refinement.
Args:
query: The question or query to search for.
collection: The name of the collection to search in. If None, searches in all collections.
top_k: The maximum number of results to return.
Returns:
A list of retrieval results.
"""
naive_rag = configuration.naive_rag
all_retrieved_results, _ = naive_rag.retrieve(query)
return all_retrieved_results
def naive_rag_query(
query: str, collection: str = None, top_k=10
) -> Tuple[str, List[RetrievalResult]]:
"""
Query the knowledge base using the naive RAG approach and get an answer.
This function uses the naive RAG agent to query the knowledge base and generate
an answer based on the retrieved information, without any advanced techniques.
Args:
query: The question or query to search for.
collection: The name of the collection to search in. If None, searches in all collections.
top_k: The maximum number of results to consider.
Returns:
A tuple containing:
- The generated answer as a string
- A list of retrieval results that were used to generate the answer
"""
naive_rag = configuration.naive_rag
answer, retrieved_results = naive_rag.query(query)
return answer, retrieved_results

33
main.py

@ -1,5 +1,4 @@
import argparse
from typing import Dict, List, Union
import uvicorn
from fastapi import Body, FastAPI, HTTPException, Query
@ -31,7 +30,7 @@ class ProviderConfigRequest(BaseModel):
feature: str
provider: str
config: Dict
config: dict
@app.get("/", response_class=HTMLResponse)
@ -43,15 +42,15 @@ async def read_root():
current_dir = os.path.dirname(os.path.abspath(__file__))
# 构建模板文件路径 - 修复路径问题
template_path = os.path.join(current_dir, "deepsearcher", "backend", "templates", "index.html")
# 读取HTML文件内容
try:
with open(template_path, "r", encoding="utf-8") as file:
with open(template_path, encoding="utf-8") as file:
html_content = file.read()
return HTMLResponse(content=html_content, status_code=200)
except FileNotFoundError:
# 如果找不到文件,提供一个简单的默认页面
default_html = """
default_html = f"""
<!DOCTYPE html>
<html>
<head>
@ -71,7 +70,7 @@ async def read_root():
<div class="info">
<p>欢迎使用 DeepSearcher 智能搜索系统!</p>
<p>系统正在运行但未找到前端模板文件</p>
<p>请确认文件是否存在: {}</p>
<p>请确认文件是否存在: {template_path}</p>
</div>
<div class="info">
<h2>API 接口</h2>
@ -86,7 +85,7 @@ async def read_root():
</div>
</body>
</html>
""".format(template_path)
"""
return HTMLResponse(content=default_html, status_code=200)
@ -118,7 +117,7 @@ def set_provider_config(request: ProviderConfigRequest):
@app.post("/load-files/")
def load_files(
paths: Union[str, List[str]] = Body(
paths: str | list[str] = Body(
...,
description="A list of file paths to be loaded.",
examples=["/path/to/file1", "/path/to/file2", "/path/to/dir1"],
@ -160,7 +159,7 @@ def load_files(
paths_or_directory=paths,
collection_name=collection_name,
collection_description=collection_description,
batch_size=batch_size if batch_size is not None else 256, # 提供默认值
batch_size=batch_size if batch_size is not None else 8, # 提供默认值
)
return {"message": "Files loaded successfully."}
except Exception as e:
@ -169,7 +168,7 @@ def load_files(
@app.post("/load-website/")
def load_website(
urls: Union[str, List[str]] = Body(
urls: str | list[str] = Body(
...,
description="A list of URLs of websites to be loaded.",
examples=["https://milvus.io/docs/overview.md"],
@ -249,15 +248,15 @@ def perform_query(
# 清除之前的进度消息
from deepsearcher.utils.log import clear_progress_messages
clear_progress_messages()
result_text, _, consume_token = query(original_query, max_iter)
# 获取进度消息
from deepsearcher.utils.log import get_progress_messages
progress_messages = get_progress_messages()
return {
"result": result_text,
"result": result_text,
"consume_token": consume_token,
"progress_messages": progress_messages
}
@ -271,7 +270,7 @@ if __name__ == "__main__":
parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to")
parser.add_argument("--enable-cors", action="store_true", default=False, help="Enable CORS support")
args = parser.parse_args()
if args.enable_cors:
app.add_middleware(
CORSMiddleware,
@ -283,6 +282,6 @@ if __name__ == "__main__":
print("CORS is enabled.")
else:
print("CORS is disabled.")
print(f"Starting server on {args.host}:{args.port}")
uvicorn.run(app, host=args.host, port=args.port)
uvicorn.run(app, host=args.host, port=args.port)

2
test.py

@ -11,7 +11,7 @@ config.load_config_from_yaml("deepsearcher/config.yaml")
init_config(config = config)
# Load your local data
load_from_local_files(paths_or_directory="examples/data", force_rebuild=True)
load_from_local_files(paths_or_directory="examples/data", force_rebuild=True, batch_size=8)
# (Optional) Load from web crawling (`FIRECRAWL_API_KEY` env variable required)
# from deepsearcher.offline_loading import load_from_website

Loading…
Cancel
Save