diff --git a/deepsearcher/agent/rag_router.py b/deepsearcher/agent/rag_router.py
index 799bb72..dc0aabd 100644
--- a/deepsearcher/agent/rag_router.py
+++ b/deepsearcher/agent/rag_router.py
@@ -1,8 +1,6 @@
-from typing import List, Optional, Tuple
+from typing import List, Tuple
 
 from deepsearcher.agent import RAGAgent
-from deepsearcher.llm.base import BaseLLM
-from deepsearcher.utils import log
 from deepsearcher.vector_db import RetrievalResult
 
 RAG_ROUTER_PROMPT = """Given a list of agent indexes and corresponding descriptions, each agent has a specific function.
@@ -28,9 +26,7 @@ class RAGRouter(RAGAgent):
 
     def __init__(
         self,
-        llm: BaseLLM,
-        rag_agents: List[RAGAgent],
-        agent_descriptions: Optional[List[str]] = None,
+        agent: RAGAgent
     ):
         """
         Initialize the RAGRouter.
@@ -40,54 +36,12 @@ class RAGRouter(RAGAgent):
             rag_agents: A list of RAGAgent instances.
             agent_descriptions (list, optional): A list of descriptions for each agent.
         """
-        self.llm = llm
-        self.rag_agents = rag_agents
-        self.agent_descriptions = agent_descriptions
-        if not self.agent_descriptions:
-            try:
-                self.agent_descriptions = [
-                    agent.__class__.__description__ for agent in self.rag_agents
-                ]
-            except Exception:
-                raise AttributeError(
-                    "Please provide agent descriptions or set __description__ attribute for each agent class."
-                )
-
-    def _route(self, query: str) -> Tuple[RAGAgent, int]:
-        description_str = "\n".join(
-            [f"[{i + 1}]: {description}" for i, description in enumerate(self.agent_descriptions)]
-        )
-        prompt = RAG_ROUTER_PROMPT.format(query=query, description_str=description_str)
-        chat_response = self.llm.chat(messages=[{"role": "user", "content": prompt}])
-        try:
-            selected_agent_index = int(self.llm.remove_think(chat_response.content)) - 1
-        except ValueError:
-            # In some reasoning LLM, the output is not a number, but a explaination string with a number in the end.
-            log.warning(
-                "Parse int failed in RAGRouter, but will try to find the last digit as fallback."
-            )
-            selected_agent_index = (
-                int(self.find_last_digit(self.llm.remove_think(chat_response.content))) - 1
-            )
-
-        selected_agent = self.rag_agents[selected_agent_index]
-        log.color_print(
-            f"<think> Select agent [{selected_agent.__class__.__name__}] to answer the query [{query}] </think>\n"
-        )
-        return self.rag_agents[selected_agent_index], chat_response.total_tokens
+        self.agent = agent
 
     def retrieve(self, query: str, **kwargs) -> Tuple[List[RetrievalResult], int, dict]:
-        agent, n_token_router = self._route(query)
-        retrieved_results, n_token_retrieval, metadata = agent.retrieve(query, **kwargs)
-        return retrieved_results, n_token_router + n_token_retrieval, metadata
+        retrieved_results, n_token_retrieval, metadata = self.agent.retrieve(query, **kwargs)
+        return retrieved_results, n_token_retrieval, metadata
 
     def query(self, query: str, **kwargs) -> Tuple[str, List[RetrievalResult], int]:
-        agent, n_token_router = self._route(query)
-        answer, retrieved_results, n_token_retrieval = agent.query(query, **kwargs)
-        return answer, retrieved_results, n_token_router + n_token_retrieval
-
-    def find_last_digit(self, string):
-        for char in reversed(string):
-            if char.isdigit():
-                return char
-        raise ValueError("No digit found in the string")
+        answer, retrieved_results, n_token_retrieval = self.agent.query(query, **kwargs)
+        return answer, retrieved_results, n_token_retrieval
diff --git a/deepsearcher/configuration.py b/deepsearcher/configuration.py
index 36d0e87..896aa9e 100644
--- a/deepsearcher/configuration.py
+++ b/deepsearcher/configuration.py
@@ -3,7 +3,7 @@ from typing import Literal
 
 import yaml
 
-from deepsearcher.agent import ChainOfRAG, DeepSearch, NaiveRAG
+from deepsearcher.agent import DeepSearch, NaiveRAG
 from deepsearcher.agent.rag_router import RAGRouter
 from deepsearcher.embedding.base import BaseEmbedding
 from deepsearcher.llm.base import BaseLLM
@@ -210,25 +210,14 @@ def init_config(config: Configuration):
     vector_db = module_factory.create_vector_db()
 
     default_searcher = RAGRouter(
-        llm=llm,
-        rag_agents=[
-            DeepSearch(
-                llm=llm,
-                embedding_model=embedding_model,
-                vector_db=vector_db,
-                max_iter=config.query_settings["max_iter"],
-                route_collection=True,
-                text_window_splitter=True,
-            ),
-            ChainOfRAG(
-                llm=llm,
-                embedding_model=embedding_model,
-                vector_db=vector_db,
-                max_iter=config.query_settings["max_iter"],
-                route_collection=True,
-                text_window_splitter=True,
-            ),
-        ],
+        agent=DeepSearch(
+            llm=llm,
+            embedding_model=embedding_model,
+            vector_db=vector_db,
+            max_iter=config.query_settings["max_iter"],
+            route_collection=True,
+            text_window_splitter=True,
+        )
     )
     naive_rag = NaiveRAG(
         llm=llm,
diff --git a/deepsearcher/loader/file_loader/base.py b/deepsearcher/loader/file_loader/base.py
index 9fa2ce5..aa16c1b 100644
--- a/deepsearcher/loader/file_loader/base.py
+++ b/deepsearcher/loader/file_loader/base.py
@@ -37,7 +37,7 @@ class BaseLoader(ABC):
             In the metadata, it's recommended to include the reference to the file.
             e.g. return [Document(page_content=..., metadata={"reference": file_path})]
         """
-        return []
+        pass
 
     def load_directory(self, directory: str) -> List[Document]:
         """
@@ -55,9 +55,7 @@ class BaseLoader(ABC):
                 for suffix in self.supported_file_types:
                     if file.endswith(suffix):
                         full_path = os.path.join(root, file)
-                        loaded_docs = self.load_file(full_path)
-                        if loaded_docs is not None:
-                            documents.extend(loaded_docs)
+                        documents.extend(self.load_file(full_path))
                         break
         return documents
 
@@ -69,4 +67,4 @@ class BaseLoader(ABC):
         Returns:
             A list of supported file extensions (without the dot).
         """
-        return []
+        pass
diff --git a/deepsearcher/offline_loading.py b/deepsearcher/offline_loading.py
index 18e3065..9e2a2a2 100644
--- a/deepsearcher/offline_loading.py
+++ b/deepsearcher/offline_loading.py
@@ -1,5 +1,5 @@
-import os
 import hashlib
+import os
 from typing import List, Union
 
 from tqdm import tqdm