You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

94 lines
3.7 KiB

2 weeks ago
from typing import List, Optional, Tuple
from deepsearcher.agent import RAGAgent
from deepsearcher.llm.base import BaseLLM
from deepsearcher.utils import log
from deepsearcher.vector_db import RetrievalResult
RAG_ROUTER_PROMPT = """Given a list of agent indexes and corresponding descriptions, each agent has a specific function.
Given a query, select only one agent that best matches the agent handling the query, and return the index without any other information.
## Question
{query}
## Agent Indexes and Descriptions
{description_str}
Only return one agent index number that best matches the agent handling the query:
"""
class RAGRouter(RAGAgent):
"""
Routes queries to the most appropriate RAG agent implementation.
This class analyzes the content and requirements of a query and determines
which RAG agent implementation is best suited to handle it.
"""
def __init__(
self,
llm: BaseLLM,
rag_agents: List[RAGAgent],
agent_descriptions: Optional[List[str]] = None,
):
"""
Initialize the RAGRouter.
Args:
llm: The language model to use for analyzing queries.
rag_agents: A list of RAGAgent instances.
agent_descriptions (list, optional): A list of descriptions for each agent.
"""
self.llm = llm
self.rag_agents = rag_agents
self.agent_descriptions = agent_descriptions
if not self.agent_descriptions:
try:
self.agent_descriptions = [
agent.__class__.__description__ for agent in self.rag_agents
]
except Exception:
raise AttributeError(
"Please provide agent descriptions or set __description__ attribute for each agent class."
)
def _route(self, query: str) -> Tuple[RAGAgent, int]:
description_str = "\n".join(
[f"[{i + 1}]: {description}" for i, description in enumerate(self.agent_descriptions)]
)
prompt = RAG_ROUTER_PROMPT.format(query=query, description_str=description_str)
chat_response = self.llm.chat(messages=[{"role": "user", "content": prompt}])
try:
selected_agent_index = int(self.llm.remove_think(chat_response.content)) - 1
except ValueError:
# In some reasoning LLM, the output is not a number, but a explaination string with a number in the end.
log.warning(
"Parse int failed in RAGRouter, but will try to find the last digit as fallback."
)
selected_agent_index = (
int(self.find_last_digit(self.llm.remove_think(chat_response.content))) - 1
)
selected_agent = self.rag_agents[selected_agent_index]
log.color_print(
f"<think> Select agent [{selected_agent.__class__.__name__}] to answer the query [{query}] </think>\n"
)
return self.rag_agents[selected_agent_index], chat_response.total_tokens
def retrieve(self, query: str, **kwargs) -> Tuple[List[RetrievalResult], int, dict]:
agent, n_token_router = self._route(query)
retrieved_results, n_token_retrieval, metadata = agent.retrieve(query, **kwargs)
return retrieved_results, n_token_router + n_token_retrieval, metadata
def query(self, query: str, **kwargs) -> Tuple[str, List[RetrievalResult], int]:
agent, n_token_router = self._route(query)
answer, retrieved_results, n_token_retrieval = agent.query(query, **kwargs)
return answer, retrieved_results, n_token_router + n_token_retrieval
def find_last_digit(self, string):
for char in reversed(string):
if char.isdigit():
return char
raise ValueError("No digit found in the string")