You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
45 lines
1.5 KiB
45 lines
1.5 KiB
from deepsearcher.agent import RAGAgent
|
|
from deepsearcher.vector_db import RetrievalResult
|
|
|
|
# Prompt template asking an LLM to pick exactly one agent index for a query.
# It carries two named placeholders, `{query}` and `{description_str}`, in
# str.format style.
# NOTE(review): nothing in the visible RAGRouter code references this
# constant -- confirm whether it is still used elsewhere before removing it.
RAG_ROUTER_PROMPT = """Given a list of agent indexes and corresponding descriptions, each agent has a specific function.
Given a query, select only one agent that best matches the agent handling the query, and return the index without any other information.

## Question
{query}

## Agent Indexes and Descriptions
{description_str}

Only return one agent index number that best matches the agent handling the query:
"""
|
|
|
|
|
|
class RAGRouter(RAGAgent):
    """
    Thin wrapper that forwards retrieval and query calls to one RAG agent.

    The router stores a single `RAGAgent` and delegates `retrieve` and
    `query` to it unchanged; no agent selection happens in this class.
    """

    def __init__(self, agent: RAGAgent):
        """
        Initialize the RAGRouter.

        Args:
            agent: The RAGAgent instance that every call is delegated to.
        """
        # NOTE(review): the base-class initializer is not invoked here;
        # confirm RAGAgent does not require it.
        self.agent = agent

    def retrieve(self, query: str, **kwargs) -> tuple[list[RetrievalResult], dict]:
        """
        Retrieve results for a query by delegating to the wrapped agent.

        Args:
            query: The query string passed through to the agent.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            The two-element result of `self.agent.retrieve`: the retrieved
            results and the accompanying metadata.
        """
        # Unpacking enforces that the agent returns exactly two values.
        retrieved_results, metadata = self.agent.retrieve(query, **kwargs)
        return retrieved_results, metadata

    def query(self, query: str, **kwargs) -> tuple[str, list[RetrievalResult]]:
        """
        Answer a query by delegating to the wrapped agent.

        Args:
            query: The query string passed through to the agent.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            The two-element result of `self.agent.query`: the answer and the
            retrieved results it was based on.
        """
        # Unpacking enforces that the agent returns exactly two values.
        answer, retrieved_results = self.agent.query(query, **kwargs)
        return answer, retrieved_results
|
|
|