Browse Source

移除路由机制,所有任务都使用一个llm完成

main
tanxing 2 weeks ago
parent
commit
8d83061233
  1. 60
      deepsearcher/agent/rag_router.py
  2. 29
      deepsearcher/configuration.py
  3. 8
      deepsearcher/loader/file_loader/base.py
  4. 2
      deepsearcher/offline_loading.py

60
deepsearcher/agent/rag_router.py

@ -1,8 +1,6 @@
from typing import List, Optional, Tuple from typing import List, Tuple
from deepsearcher.agent import RAGAgent from deepsearcher.agent import RAGAgent
from deepsearcher.llm.base import BaseLLM
from deepsearcher.utils import log
from deepsearcher.vector_db import RetrievalResult from deepsearcher.vector_db import RetrievalResult
RAG_ROUTER_PROMPT = """Given a list of agent indexes and corresponding descriptions, each agent has a specific function. RAG_ROUTER_PROMPT = """Given a list of agent indexes and corresponding descriptions, each agent has a specific function.
@ -28,9 +26,7 @@ class RAGRouter(RAGAgent):
def __init__( def __init__(
self, self,
llm: BaseLLM, agent: RAGAgent
rag_agents: List[RAGAgent],
agent_descriptions: Optional[List[str]] = None,
): ):
""" """
Initialize the RAGRouter. Initialize the RAGRouter.
@ -40,54 +36,12 @@ class RAGRouter(RAGAgent):
rag_agents: A list of RAGAgent instances. rag_agents: A list of RAGAgent instances.
agent_descriptions (list, optional): A list of descriptions for each agent. agent_descriptions (list, optional): A list of descriptions for each agent.
""" """
self.llm = llm self.agent = agent
self.rag_agents = rag_agents
self.agent_descriptions = agent_descriptions
if not self.agent_descriptions:
try:
self.agent_descriptions = [
agent.__class__.__description__ for agent in self.rag_agents
]
except Exception:
raise AttributeError(
"Please provide agent descriptions or set __description__ attribute for each agent class."
)
def _route(self, query: str) -> Tuple[RAGAgent, int]:
description_str = "\n".join(
[f"[{i + 1}]: {description}" for i, description in enumerate(self.agent_descriptions)]
)
prompt = RAG_ROUTER_PROMPT.format(query=query, description_str=description_str)
chat_response = self.llm.chat(messages=[{"role": "user", "content": prompt}])
try:
selected_agent_index = int(self.llm.remove_think(chat_response.content)) - 1
except ValueError:
# In some reasoning LLM, the output is not a number, but a explaination string with a number in the end.
log.warning(
"Parse int failed in RAGRouter, but will try to find the last digit as fallback."
)
selected_agent_index = (
int(self.find_last_digit(self.llm.remove_think(chat_response.content))) - 1
)
selected_agent = self.rag_agents[selected_agent_index]
log.color_print(
f"<think> Select agent [{selected_agent.__class__.__name__}] to answer the query [{query}] </think>\n"
)
return self.rag_agents[selected_agent_index], chat_response.total_tokens
def retrieve(self, query: str, **kwargs) -> Tuple[List[RetrievalResult], int, dict]: def retrieve(self, query: str, **kwargs) -> Tuple[List[RetrievalResult], int, dict]:
agent, n_token_router = self._route(query) retrieved_results, n_token_retrieval, metadata = self.agent.retrieve(query, **kwargs)
retrieved_results, n_token_retrieval, metadata = agent.retrieve(query, **kwargs) return retrieved_results, n_token_retrieval, metadata
return retrieved_results, n_token_router + n_token_retrieval, metadata
def query(self, query: str, **kwargs) -> Tuple[str, List[RetrievalResult], int]: def query(self, query: str, **kwargs) -> Tuple[str, List[RetrievalResult], int]:
agent, n_token_router = self._route(query) answer, retrieved_results, n_token_retrieval = self.agent.query(query, **kwargs)
answer, retrieved_results, n_token_retrieval = agent.query(query, **kwargs) return answer, retrieved_results, n_token_retrieval
return answer, retrieved_results, n_token_router + n_token_retrieval
def find_last_digit(self, string):
for char in reversed(string):
if char.isdigit():
return char
raise ValueError("No digit found in the string")

29
deepsearcher/configuration.py

@ -3,7 +3,7 @@ from typing import Literal
import yaml import yaml
from deepsearcher.agent import ChainOfRAG, DeepSearch, NaiveRAG from deepsearcher.agent import DeepSearch, NaiveRAG
from deepsearcher.agent.rag_router import RAGRouter from deepsearcher.agent.rag_router import RAGRouter
from deepsearcher.embedding.base import BaseEmbedding from deepsearcher.embedding.base import BaseEmbedding
from deepsearcher.llm.base import BaseLLM from deepsearcher.llm.base import BaseLLM
@ -210,25 +210,14 @@ def init_config(config: Configuration):
vector_db = module_factory.create_vector_db() vector_db = module_factory.create_vector_db()
default_searcher = RAGRouter( default_searcher = RAGRouter(
llm=llm, agent=DeepSearch(
rag_agents=[ llm=llm,
DeepSearch( embedding_model=embedding_model,
llm=llm, vector_db=vector_db,
embedding_model=embedding_model, max_iter=config.query_settings["max_iter"],
vector_db=vector_db, route_collection=True,
max_iter=config.query_settings["max_iter"], text_window_splitter=True,
route_collection=True, )
text_window_splitter=True,
),
ChainOfRAG(
llm=llm,
embedding_model=embedding_model,
vector_db=vector_db,
max_iter=config.query_settings["max_iter"],
route_collection=True,
text_window_splitter=True,
),
],
) )
naive_rag = NaiveRAG( naive_rag = NaiveRAG(
llm=llm, llm=llm,

8
deepsearcher/loader/file_loader/base.py

@ -37,7 +37,7 @@ class BaseLoader(ABC):
In the metadata, it's recommended to include the reference to the file. In the metadata, it's recommended to include the reference to the file.
e.g. return [Document(page_content=..., metadata={"reference": file_path})] e.g. return [Document(page_content=..., metadata={"reference": file_path})]
""" """
return [] pass
def load_directory(self, directory: str) -> List[Document]: def load_directory(self, directory: str) -> List[Document]:
""" """
@ -55,9 +55,7 @@ class BaseLoader(ABC):
for suffix in self.supported_file_types: for suffix in self.supported_file_types:
if file.endswith(suffix): if file.endswith(suffix):
full_path = os.path.join(root, file) full_path = os.path.join(root, file)
loaded_docs = self.load_file(full_path) documents.extend(self.load_file(full_path))
if loaded_docs is not None:
documents.extend(loaded_docs)
break break
return documents return documents
@ -69,4 +67,4 @@ class BaseLoader(ABC):
Returns: Returns:
A list of supported file extensions (without the dot). A list of supported file extensions (without the dot).
""" """
return [] pass

2
deepsearcher/offline_loading.py

@ -1,5 +1,5 @@
import os
import hashlib import hashlib
import os
from typing import List, Union from typing import List, Union
from tqdm import tqdm from tqdm import tqdm

Loading…
Cancel
Save