
Remove the other nested agents to reduce complexity

main
tanxing 1 week ago
parent · commit 64296cae96
  1. deepsearcher/agent/__init__.py (10 lines changed)
  2. deepsearcher/agent/base.py (18 lines changed)
  3. deepsearcher/agent/chain_of_rag.py (326 lines changed)
  4. deepsearcher/agent/collection_router.py (95 lines changed)
  5. deepsearcher/agent/deep_search.py (82 lines changed)
  6. deepsearcher/agent/naive_rag.py (128 lines changed)
  7. deepsearcher/agent/rag_router.py (45 lines changed)
  8. deepsearcher/configuration.py (23 lines changed)

deepsearcher/agent/__init__.py (10 lines changed)

@@ -1,12 +1,6 @@
from .base import BaseAgent, RAGAgent
from .chain_of_rag import ChainOfRAG
from .deep_search import DeepSearch
from .naive_rag import NaiveRAG
from .deep_search import BaseAgent
__all__ = [
"ChainOfRAG",
"DeepSearch",
"NaiveRAG",
"BaseAgent",
"RAGAgent",
"DeepSearch", "BaseAgent"
]
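After this change the package exports shrink accordingly. A minimal import sketch, assuming only the names left in __all__ remain available:

from deepsearcher.agent import BaseAgent, DeepSearch  # ChainOfRAG and NaiveRAG are no longer exported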

deepsearcher/agent/base.py (18 lines changed)

@@ -53,24 +53,6 @@ class BaseAgent(ABC):
The result of invoking the agent.
"""
class RAGAgent(BaseAgent):
"""
Abstract base class for Retrieval-Augmented Generation (RAG) agents.
This class extends BaseAgent with methods specific to RAG, including
retrieval and query methods.
"""
def __init__(self, **kwargs):
"""
Initialize a RAGAgent object.
Args:
**kwargs: Arbitrary keyword arguments.
"""
pass
def retrieve(self, query: str, **kwargs) -> tuple[list[RetrievalResult], dict]:
"""
Retrieve document results from the knowledge base.
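With RAGAgent gone, new agents would subclass BaseAgent directly. A hedged sketch (the invoke signature is inferred from the surrounding docstring and may differ from the actual abstract method):

from deepsearcher.agent import BaseAgent

class EchoAgent(BaseAgent):
    # Toy agent, only to illustrate subclassing after RAGAgent's removal.
    def invoke(self, query: str, **kwargs):
        # A real agent would retrieve documents and generate an answer here.
        return query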

deepsearcher/agent/chain_of_rag.py (326 lines changed)

@@ -1,326 +0,0 @@
from typing import List, Tuple
from deepsearcher.agent.base import RAGAgent, describe_class
from deepsearcher.agent.collection_router import CollectionRouter
from deepsearcher.embedding.base import BaseEmbedding
from deepsearcher.llm.base import BaseLLM
from deepsearcher.utils import log
from deepsearcher.vector_db import RetrievalResult
from deepsearcher.vector_db.base import BaseVectorDB, deduplicate_results
FOLLOWUP_QUERY_PROMPT = """You are using a search tool to answer the main query by iteratively searching the database. Given the following intermediate queries and answers, generate a new simple follow-up question that can help answer the main query. You may rephrase or decompose the main query when previous answers are not helpful. Ask simple follow-up questions only as the search tool may not understand complex questions.
## Previous intermediate queries and answers
{intermediate_context}
## Main query to answer
{query}
Respond with a simple follow-up question that will help answer the main query, do not explain yourself or output anything else.
"""
INTERMEDIATE_ANSWER_PROMPT = """Given the following documents, generate an appropriate answer for the query. DO NOT hallucinate any information, only use the provided documents to generate the answer. Respond "No relevant information found" if the documents do not contain useful information.
## Documents
{retrieved_documents}
## Query
{sub_query}
Respond with a concise answer only, do not explain yourself or output anything else.
"""
FINAL_ANSWER_PROMPT = """Given the following intermediate queries and answers, generate a final answer for the main query by combining relevant information. Note that intermediate answers are generated by an LLM and may not always be accurate.
## Documents
{retrieved_documents}
## Intermediate queries and answers
{intermediate_context}
## Main query
{query}
Respond with an appropriate answer only, do not explain yourself or output anything else.
"""
REFLECTION_PROMPT = """Given the following intermediate queries and answers, judge whether you have enough information to answer the main query. If you believe you have enough information, respond with "Yes", otherwise respond with "No".
## Intermediate queries and answers
{intermediate_context}
## Main query
{query}
Respond with "Yes" or "No" only, do not explain yourself or output anything else.
"""
GET_SUPPORTED_DOCS_PROMPT = """Given the following documents, select the ones that are support the Q-A pair.
## Documents
{retrieved_documents}
## Q-A Pair
### Question
{query}
### Answer
{answer}
Respond with a python list of indices of the selected documents.
"""
@describe_class(
"This agent can decompose complex queries and gradually find the fact information of sub-queries. "
"It is very suitable for handling concrete factual queries and multi-hop questions."
)
class ChainOfRAG(RAGAgent):
"""
Chain of Retrieval-Augmented Generation (RAG) agent implementation.
This agent implements a multi-step RAG process where each step can refine
the query and retrieval process based on previous results, creating a chain
of increasingly focused and relevant information retrieval and generation.
Inspired by: https://arxiv.org/pdf/2501.14342
"""
def __init__(
self,
llm: BaseLLM,
embedding_model: BaseEmbedding,
vector_db: BaseVectorDB,
max_iter: int = 4,
early_stopping: bool = False,
route_collection: bool = True,
text_window_splitter: bool = True,
**kwargs,
):
"""
Initialize the ChainOfRAG agent with configuration parameters.
Args:
llm (BaseLLM): The language model to use for generating answers.
embedding_model (BaseEmbedding): The embedding model to use for embedding queries.
vector_db (BaseVectorDB): The vector database to search for relevant documents.
max_iter (int, optional): The maximum number of iterations for the RAG process. Defaults to 4.
early_stopping (bool, optional): Whether to use early stopping. Defaults to False.
route_collection (bool, optional): Whether to route the query to specific collections. Defaults to True.
text_window_splitter (bool, optional): Whether use text_window splitter. Defaults to True.
"""
self.llm = llm
self.embedding_model = embedding_model
self.vector_db = vector_db
self.max_iter = max_iter
self.early_stopping = early_stopping
self.route_collection = route_collection
self.collection_router = CollectionRouter(
llm=self.llm, vector_db=self.vector_db, dim=embedding_model.dimension
)
self.text_window_splitter = text_window_splitter
def _reflect_get_subquery(self, query: str, intermediate_context: List[str]) -> Tuple[str, int]:
chat_response = self.llm.chat(
[
{
"role": "user",
"content": FOLLOWUP_QUERY_PROMPT.format(
query=query,
intermediate_context="\n".join(intermediate_context),
),
}
]
)
return self.llm.remove_think(chat_response.content), chat_response.total_tokens
def _retrieve_and_answer(self, query: str) -> Tuple[str, List[RetrievalResult], int]:
consume_tokens = 0
if self.route_collection:
selected_collections, n_token_route = self.collection_router.invoke(
query=query, dim=self.embedding_model.dimension
)
else:
selected_collections = self.collection_router.all_collections
n_token_route = 0
consume_tokens += n_token_route
all_retrieved_results = []
for collection in selected_collections:
log.color_print(f"<search> Search [{query}] in [{collection}]... </search>\n")
query_vector = self.embedding_model.embed_query(query)
retrieved_results = self.vector_db.search_data(
collection=collection, vector=query_vector, query_text=query
)
all_retrieved_results.extend(retrieved_results)
all_retrieved_results = deduplicate_results(all_retrieved_results)
chat_response = self.llm.chat(
[
{
"role": "user",
"content": INTERMEDIATE_ANSWER_PROMPT.format(
retrieved_documents=self._format_retrieved_results(all_retrieved_results),
sub_query=query,
),
}
]
)
return (
self.llm.remove_think(chat_response.content),
all_retrieved_results,
consume_tokens + chat_response.total_tokens,
)
def _get_supported_docs(
self,
retrieved_results: List[RetrievalResult],
query: str,
intermediate_answer: str,
) -> Tuple[List[RetrievalResult], int]:
supported_retrieved_results = []
token_usage = 0
if "No relevant information found" not in intermediate_answer:
chat_response = self.llm.chat(
[
{
"role": "user",
"content": GET_SUPPORTED_DOCS_PROMPT.format(
retrieved_documents=self._format_retrieved_results(retrieved_results),
query=query,
answer=intermediate_answer,
),
}
]
)
supported_doc_indices = self.llm.literal_eval(chat_response.content)
supported_retrieved_results = [
retrieved_results[int(i)]
for i in supported_doc_indices
if int(i) < len(retrieved_results)
]
token_usage = chat_response.total_tokens
return supported_retrieved_results, token_usage
def _check_has_enough_info(
self, query: str, intermediate_contexts: List[str]
) -> Tuple[bool, int]:
if not intermediate_contexts:
return False, 0
chat_response = self.llm.chat(
[
{
"role": "user",
"content": REFLECTION_PROMPT.format(
query=query,
intermediate_context="\n".join(intermediate_contexts),
),
}
]
)
has_enough_info = self.llm.remove_think(chat_response.content).strip().lower() == "yes"
return has_enough_info, chat_response.total_tokens
def retrieve(self, query: str, **kwargs) -> Tuple[List[RetrievalResult], int, dict]:
"""
Retrieves relevant documents based on the input query and iteratively refines the search.
This method iteratively refines the search query based on intermediate results, retrieves documents,
and filters out supported documents. It keeps track of the intermediate contexts and token usage.
Args:
query (str): The initial search query.
**kwargs: Additional keyword arguments.
- max_iter (int, optional): The maximum number of iterations for refinement. Defaults to self.max_iter.
Returns:
Tuple[List[RetrievalResult], int, dict]: A tuple containing:
- List[RetrievalResult]: The list of all retrieved and deduplicated results.
- int: The total token usage across all iterations.
- dict: A dictionary containing additional information, including the intermediate contexts.
"""
max_iter = kwargs.pop("max_iter", self.max_iter)
intermediate_contexts = []
all_retrieved_results = []
token_usage = 0
for iter in range(max_iter):
log.color_print(f">> Iteration: {iter + 1}\n")
followup_query, n_token0 = self._reflect_get_subquery(query, intermediate_contexts)
intermediate_answer, retrieved_results, n_token1 = self._retrieve_and_answer(
followup_query
)
supported_retrieved_results, n_token2 = self._get_supported_docs(
retrieved_results, followup_query, intermediate_answer
)
all_retrieved_results.extend(supported_retrieved_results)
intermediate_idx = len(intermediate_contexts) + 1
intermediate_contexts.append(
f"Intermediate query{intermediate_idx}: {followup_query}\nIntermediate answer{intermediate_idx}: {intermediate_answer}"
)
token_usage += n_token0 + n_token1 + n_token2
if self.early_stopping:
has_enough_info, n_token_check = self._check_has_enough_info(
query, intermediate_contexts
)
token_usage += n_token_check
if has_enough_info:
log.color_print(
f"<think> Early stopping after iteration {iter + 1}: Have enough information to answer the main query. </think>\n"
)
break
all_retrieved_results = deduplicate_results(all_retrieved_results)
additional_info = {"intermediate_context": intermediate_contexts}
return all_retrieved_results, token_usage, additional_info
def query(self, query: str, **kwargs) -> Tuple[str, List[RetrievalResult], int]:
"""
Executes a query and returns the final answer along with all retrieved results and total token usage.
This method initiates a query, retrieves relevant documents, and then summarizes the answer based on the retrieved documents and intermediate contexts. It logs the final answer and returns the answer content, all retrieved results, and the total token usage including the tokens used for the final answer.
Args:
query (str): The initial query to execute.
**kwargs: Additional keyword arguments to pass to the `retrieve` method.
Returns:
Tuple[str, List[RetrievalResult], int]: A tuple containing:
- str: The final answer content.
- List[RetrievalResult]: The list of all retrieved and deduplicated results.
- int: The total token usage across all iterations, including the final answer.
"""
all_retrieved_results, n_token_retrieval, additional_info = self.retrieve(query, **kwargs)
intermediate_context = additional_info["intermediate_context"]
log.color_print(
f"<think> Summarize answer from all {len(all_retrieved_results)} retrieved chunks... </think>\n"
)
chat_response = self.llm.chat(
[
{
"role": "user",
"content": FINAL_ANSWER_PROMPT.format(
retrieved_documents=self._format_retrieved_results(all_retrieved_results),
intermediate_context="\n".join(intermediate_context),
query=query,
),
}
]
)
log.color_print("\n==== FINAL ANSWER====\n")
log.color_print(self.llm.remove_think(chat_response.content))
return (
self.llm.remove_think(chat_response.content),
all_retrieved_results,
n_token_retrieval + chat_response.total_tokens,
)
def _format_retrieved_results(self, retrieved_results: List[RetrievalResult]) -> str:
formatted_documents = []
for i, result in enumerate(retrieved_results):
if self.text_window_splitter and "wider_text" in result.metadata:
text = result.metadata["wider_text"]
else:
text = result.text
formatted_documents.append(f"<Document {i}>\n{text}\n<\Document {i}>")
return "\n".join(formatted_documents)

deepsearcher/agent/collection_router.py (95 lines changed)

@@ -1,95 +0,0 @@
from deepsearcher.agent.base import BaseAgent
from deepsearcher.llm.base import BaseLLM
from deepsearcher.utils import log
from deepsearcher.vector_db.base import BaseVectorDB
COLLECTION_ROUTE_PROMPT = """
I provide you with collection_name(s) and corresponding collection_description(s).
Please select the collection names that may be related to the question and return a python list of str.
If there is no collection related to the question, you can return an empty list.
"QUESTION": {question}
"COLLECTION_INFO": {collection_info}
When you return, you can ONLY return a json convertable python list of str, WITHOUT any other additional content.
Your selected collection name list is:
"""
class CollectionRouter(BaseAgent):
"""
Routes queries to appropriate collections in the vector database.
This class analyzes the content of a query and determines which collections
in the vector database are most likely to contain relevant information.
"""
def __init__(self, llm: BaseLLM, vector_db: BaseVectorDB, dim: int, **kwargs):
"""
Initialize the CollectionRouter.
Args:
llm: The language model to use for analyzing queries.
vector_db: The vector database containing the collections.
dim: The dimension of the vector space to search in.
"""
self.llm = llm
self.vector_db = vector_db
self.all_collections = [
collection_info.collection_name
for collection_info in self.vector_db.list_collections(dim=dim)
]
def invoke(self, query: str, dim: int, **kwargs) -> list[str]:
"""
Determine which collections are relevant for the given query.
This method analyzes the query content and selects collections that are
most likely to contain information relevant to answering the query.
Args:
query (str): The query to analyze.
dim (int): The dimension of the vector space to search in.
Returns:
List[str]: A list of selected collection names
"""
collection_infos = self.vector_db.list_collections(dim=dim)
if len(collection_infos) == 0:
log.color_print(
"No collections found in the vector database. Please check the database connection."
)
return [], 0
if len(collection_infos) == 1:
the_only_collection = collection_infos[0].collection_name
log.color_print(
f"<think> Perform search [{query}] on the vector DB collection: {the_only_collection} </think>\n"
)
return [the_only_collection], 0
vector_db_search_prompt = COLLECTION_ROUTE_PROMPT.format(
question=query,
collection_info=[
{
"collection_name": collection_info.collection_name,
"collection_description": collection_info.description,
}
for collection_info in collection_infos
],
)
response = self.llm.chat(
messages=[{"role": "user", "content": vector_db_search_prompt}]
)
selected_collections = self.llm.literal_eval(response)
for collection_info in collection_infos:
# If a collection description is not provided, use the query as the search query
if not collection_info.description:
selected_collections.append(collection_info.collection_name)
# If the default collection exists, use the query as the search query
if self.vector_db.default_collection == collection_info.collection_name:
selected_collections.append(collection_info.collection_name)
selected_collections = list(set(selected_collections))
log.color_print(
f"<think> Perform search [{query}] on the vector DB collections: {selected_collections} </think>\n"
)
return selected_collections
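The other agents called the removed router in two steps: build it with the vector-DB dimension, then invoke it per query. A pre-removal usage sketch (the instances are assumed to be configured elsewhere; the tuple unpacking follows the caller shown in chain_of_rag.py):

router = CollectionRouter(llm=llm, vector_db=vector_db, dim=embedding_model.dimension)
selected_collections, n_token_route = router.invoke(
    query="How are documents chunked?", dim=embedding_model.dimension
)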

deepsearcher/agent/deep_search.py (82 lines changed)

@@ -1,13 +1,24 @@
import asyncio
from deepsearcher.agent.base import RAGAgent, describe_class
from deepsearcher.agent.collection_router import CollectionRouter
from deepsearcher.agent.base import BaseAgent, describe_class
from deepsearcher.embedding.base import BaseEmbedding
from deepsearcher.llm.base import BaseLLM
from deepsearcher.utils import log
from deepsearcher.vector_db import RetrievalResult
from deepsearcher.vector_db.base import BaseVectorDB, deduplicate_results
COLLECTION_ROUTE_PROMPT = """
I provide you with collection_name(s) and corresponding collection_description(s).
Please select the collection names that may be related to the question and return a python list of str.
If there is no collection related to the question, you can return an empty list.
"QUESTION": {question}
"COLLECTION_INFO": {collection_info}
When you return, you can ONLY return a json convertable python list of str, WITHOUT any other additional content.
Your selected collection name list is:
"""
SUB_QUERY_PROMPT = """
To answer this question more comprehensively, please break down the original question into few numbers of sub-questions (more if necessary).
If this is a very simple question and no decomposition is necessary, then keep the only one original question.
@@ -77,7 +88,7 @@ Related Chunks:
@describe_class(
"This agent is suitable for handling general and simple queries, such as given a topic and then writing a report, survey, or article."
)
class DeepSearch(RAGAgent):
class DeepSearch(BaseAgent):
"""
Deep Search agent implementation for comprehensive information retrieval.
@@ -112,11 +123,66 @@ class DeepSearch(RAGAgent):
self.vector_db = vector_db
self.max_iter = max_iter
self.route_collection = route_collection
self.collection_router = CollectionRouter(
llm=self.llm, vector_db=self.vector_db, dim=embedding_model.dimension
)
self.all_collections = [
collection_info.collection_name
for collection_info in self.vector_db.list_collections(dim=embedding_model.dimension)
]
self.text_window_splitter = text_window_splitter
def invoke(self, query: str, dim: int, **kwargs) -> list[str]:
"""
Determine which collections are relevant for the given query.
This method analyzes the query content and selects collections that are
most likely to contain information relevant to answering the query.
Args:
query (str): The query to analyze.
dim (int): The dimension of the vector space to search in.
Returns:
List[str]: A list of selected collection names
"""
collection_infos = self.vector_db.list_collections(dim=dim)
if len(collection_infos) == 0:
log.color_print(
"No collections found in the vector database. Please check the database connection."
)
return [], 0
if len(collection_infos) == 1:
the_only_collection = collection_infos[0].collection_name
log.color_print(
f"<think> Perform search [{query}] on the vector DB collection: {the_only_collection} </think>\n"
)
return [the_only_collection], 0
vector_db_search_prompt = COLLECTION_ROUTE_PROMPT.format(
question=query,
collection_info=[
{
"collection_name": collection_info.collection_name,
"collection_description": collection_info.description,
}
for collection_info in collection_infos
],
)
response = self.llm.chat(
messages=[{"role": "user", "content": vector_db_search_prompt}]
)
selected_collections = self.llm.literal_eval(response)
for collection_info in collection_infos:
# If a collection description is not provided, use the query as the search query
if not collection_info.description:
selected_collections.append(collection_info.collection_name)
# If the default collection exists, use the query as the search query
if self.vector_db.default_collection == collection_info.collection_name:
selected_collections.append(collection_info.collection_name)
selected_collections = list(set(selected_collections))
log.color_print(
f"<think> Perform search [{query}] on the vector DB collections: {selected_collections} </think>\n"
)
return selected_collections
def _generate_sub_queries(self, original_query: str) -> tuple[list[str], int]:
content = self.llm.chat(
messages=[
@@ -128,11 +194,11 @@ class DeepSearch(RAGAgent):
async def _search_chunks_from_vectordb(self, query: str):
if self.route_collection:
selected_collections = self.collection_router.invoke(
selected_collections = self.invoke(
query=query, dim=self.embedding_model.dimension
)
else:
selected_collections = self.collection_router.all_collections
selected_collections = self.all_collections
all_retrieved_results = []
query_vector = self.embedding_model.embed_query(query)
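After this commit the routing step is a method on DeepSearch itself, so no separate router object is constructed. A hedged sketch of the new call path (the deep_search instance and the query are illustrative):

# Collection routing now happens on the agent; this mirrors the internal caller above.
selected_collections = deep_search.invoke(
    query="What is vector search?", dim=deep_search.embedding_model.dimension
)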

deepsearcher/agent/naive_rag.py (128 lines changed)

@@ -1,128 +0,0 @@
from typing import List, Tuple
from deepsearcher.agent.base import RAGAgent
from deepsearcher.agent.collection_router import CollectionRouter
from deepsearcher.embedding.base import BaseEmbedding
from deepsearcher.llm.base import BaseLLM
from deepsearcher.utils import log
from deepsearcher.vector_db.base import BaseVectorDB, RetrievalResult, deduplicate_results
SUMMARY_PROMPT = """You are a AI content analysis expert, good at summarizing content. Please summarize a specific and detailed answer or report based on the previous queries and the retrieved document chunks.
Original Query: {query}
Related Chunks:
{mini_chunk_str}
"""
class NaiveRAG(RAGAgent):
"""
Naive Retrieval-Augmented Generation agent implementation.
This agent implements a straightforward RAG approach, retrieving relevant
documents and generating answers without complex processing or refinement steps.
"""
def __init__(
self,
llm: BaseLLM,
embedding_model: BaseEmbedding,
vector_db: BaseVectorDB,
top_k: int = 10,
route_collection: bool = True,
text_window_splitter: bool = True,
**kwargs,
):
"""
Initialize the NaiveRAG agent.
Args:
llm: The language model to use for generating answers.
embedding_model: The embedding model to use for query embedding.
vector_db: The vector database to search for relevant documents.
**kwargs: Additional keyword arguments for customization.
"""
self.llm = llm
self.embedding_model = embedding_model
self.vector_db = vector_db
self.top_k = top_k
self.route_collection = route_collection
if self.route_collection:
self.collection_router = CollectionRouter(
llm=self.llm, vector_db=self.vector_db, dim=embedding_model.dimension
)
self.text_window_splitter = text_window_splitter
def retrieve(self, query: str, **kwargs) -> Tuple[List[RetrievalResult], int, dict]:
"""
Retrieve relevant documents from the knowledge base for the given query.
This method performs a basic search through the vector database to find
documents relevant to the query.
Args:
query (str): The query to search for.
**kwargs: Additional keyword arguments for customizing the retrieval.
Returns:
Tuple[List[RetrievalResult], int, dict]: A tuple containing:
- A list of retrieved document results
- The token usage for the retrieval operation
- Additional information about the retrieval process
"""
consume_tokens = 0
if self.route_collection:
selected_collections, n_token_route = self.collection_router.invoke(
query=query, dim=self.embedding_model.dimension
)
else:
selected_collections = self.collection_router.all_collections
n_token_route = 0
consume_tokens += n_token_route
all_retrieved_results = []
for collection in selected_collections:
retrieval_res = self.vector_db.search_data(
collection=collection,
vector=self.embedding_model.embed_query(query),
top_k=max(self.top_k // len(selected_collections), 1),
query_text=query,
)
all_retrieved_results.extend(retrieval_res)
all_retrieved_results = deduplicate_results(all_retrieved_results)
return all_retrieved_results, consume_tokens, {}
def query(self, query: str, **kwargs) -> Tuple[str, List[RetrievalResult], int]:
"""
Query the agent and generate an answer based on retrieved documents.
This method retrieves relevant documents and uses the language model
to generate a simple answer to the query.
Args:
query (str): The query to answer.
**kwargs: Additional keyword arguments for customizing the query process.
Returns:
Tuple[str, List[RetrievalResult], int]: A tuple containing:
- The generated answer
- A list of retrieved document results
- The total token usage
"""
all_retrieved_results, n_token_retrieval, _ = self.retrieve(query)
chunk_texts = []
for chunk in all_retrieved_results:
if self.text_window_splitter and "wider_text" in chunk.metadata:
chunk_texts.append(chunk.metadata["wider_text"])
else:
chunk_texts.append(chunk.text)
mini_chunk_str = ""
for i, chunk in enumerate(chunk_texts):
mini_chunk_str += f"""<chunk_{i}>\n{chunk}\n</chunk_{i}>\n"""
summary_prompt = SUMMARY_PROMPT.format(query=query, mini_chunk_str=mini_chunk_str)
char_response = self.llm.chat([{"role": "user", "content": summary_prompt}])
final_answer = char_response.content
log.color_print("\n==== FINAL ANSWER====\n")
log.color_print(final_answer)
return final_answer, all_retrieved_results, n_token_retrieval + char_response.total_tokens
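For reference, the deleted NaiveRAG was a single-shot retrieve-then-summarize agent. A pre-removal usage sketch reconstructed from its signatures (the constructor arguments are assumed to come from the project's configuration):

rag = NaiveRAG(
    llm=llm,
    embedding_model=embedding_model,
    vector_db=vector_db,
    top_k=10,              # retrieval budget, split across the selected collections
    route_collection=True,
)
final_answer, retrieved_results, total_tokens = rag.query("Summarize the ingested documents.")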

deepsearcher/agent/rag_router.py (45 lines changed)

@@ -1,45 +0,0 @@
from deepsearcher.agent import RAGAgent
from deepsearcher.vector_db import RetrievalResult
RAG_ROUTER_PROMPT = """Given a list of agent indexes and corresponding descriptions, each agent has a specific function.
Given a query, select only one agent that best matches the agent handling the query, and return the index without any other information.
## Question
{query}
## Agent Indexes and Descriptions
{description_str}
Only return one agent index number that best matches the agent handling the query:
"""
class RAGRouter(RAGAgent):
"""
Routes queries to the most appropriate RAG agent implementation.
This class analyzes the content and requirements of a query and determines
which RAG agent implementation is best suited to handle it.
"""
def __init__(
self,
agent: RAGAgent
):
"""
Initialize the RAGRouter.
Args:
llm: The language model to use for analyzing queries.
rag_agents: A list of RAGAgent instances.
agent_descriptions (list, optional): A list of descriptions for each agent.
"""
self.agent = agent
def retrieve(self, query: str, **kwargs) -> tuple[list[RetrievalResult], dict]:
retrieved_results, metadata = self.agent.retrieve(query, **kwargs)
return retrieved_results, metadata
def query(self, query: str, **kwargs) -> tuple[str, list[RetrievalResult]]:
answer, retrieved_results = self.agent.query(query, **kwargs)
return answer, retrieved_results
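The removed RAGRouter had already been reduced to a pass-through wrapper around one agent, which is why it could be dropped. A sketch of its pre-removal usage per the annotated signatures above (the wrapped DeepSearch construction is an assumption):

router = RAGRouter(
    agent=DeepSearch(llm=llm, embedding_model=embedding_model, vector_db=vector_db)
)
retrieved_results, metadata = router.retrieve("Which file loaders are supported?")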

deepsearcher/configuration.py (23 lines changed)

@@ -3,8 +3,7 @@ from typing import Literal
import yaml
from deepsearcher.agent import DeepSearch, NaiveRAG
from deepsearcher.agent.rag_router import RAGRouter
from deepsearcher.agent import BaseAgent, DeepSearch
from deepsearcher.embedding.base import BaseEmbedding
from deepsearcher.llm.base import BaseLLM
from deepsearcher.loader.file_loader.base import BaseLoader
@@ -49,7 +48,7 @@ class Configuration:
Returns:
The loaded configuration data as a dictionary.
"""
with open(config_path, "r") as file:
with open(config_path) as file:
return yaml.safe_load(file)
def set_provider_config(self, feature: FeatureType, provider: str, provider_configs: dict):
@@ -179,9 +178,7 @@ embedding_model: BaseEmbedding = None
file_loader: BaseLoader = None
vector_db: BaseVectorDB = None
web_crawler: BaseCrawler = None
default_searcher: RAGRouter = None
naive_rag: NaiveRAG = None
default_searcher: BaseAgent = None
def init_config(config: Configuration):
"""
@@ -209,21 +206,11 @@ def init_config(config: Configuration):
web_crawler = module_factory.create_web_crawler()
vector_db = module_factory.create_vector_db()
default_searcher = RAGRouter(
agent=DeepSearch(
llm=llm,
embedding_model=embedding_model,
vector_db=vector_db,
max_iter=config.query_settings["max_iter"],
route_collection=True,
text_window_splitter=True,
)
)
naive_rag = NaiveRAG(
default_searcher = DeepSearch(
llm=llm,
embedding_model=embedding_model,
vector_db=vector_db,
top_k=10,
max_iter=config.query_settings["max_iter"],
route_collection=True,
text_window_splitter=True,
)
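With this change, init_config wires a single DeepSearch instance as default_searcher instead of a RAGRouter wrapping one. A hedged sketch of driving the simplified configuration (the bare Configuration() construction is an assumption; adjust to the project's actual API):

from deepsearcher import configuration
from deepsearcher.configuration import Configuration, init_config

config = Configuration()                    # assumed to load the bundled default YAML
init_config(config)
searcher = configuration.default_searcher   # now a DeepSearch instance, no RAGRouter wrapper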
