You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
128 lines
5.1 KiB
128 lines
5.1 KiB
from typing import List, Tuple
|
|
|
|
from deepsearcher.agent.base import RAGAgent
|
|
from deepsearcher.agent.collection_router import CollectionRouter
|
|
from deepsearcher.embedding.base import BaseEmbedding
|
|
from deepsearcher.llm.base import BaseLLM
|
|
from deepsearcher.utils import log
|
|
from deepsearcher.vector_db.base import BaseVectorDB, RetrievalResult, deduplicate_results
|
|
|
|
SUMMARY_PROMPT = """You are a AI content analysis expert, good at summarizing content. Please summarize a specific and detailed answer or report based on the previous queries and the retrieved document chunks.
|
|
|
|
Original Query: {query}
|
|
|
|
Related Chunks:
|
|
{mini_chunk_str}
|
|
"""
|
|
|
|
|
|
class NaiveRAG(RAGAgent):
|
|
"""
|
|
Naive Retrieval-Augmented Generation agent implementation.
|
|
|
|
This agent implements a straightforward RAG approach, retrieving relevant
|
|
documents and generating answers without complex processing or refinement steps.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
llm: BaseLLM,
|
|
embedding_model: BaseEmbedding,
|
|
vector_db: BaseVectorDB,
|
|
top_k: int = 10,
|
|
route_collection: bool = True,
|
|
text_window_splitter: bool = True,
|
|
**kwargs,
|
|
):
|
|
"""
|
|
Initialize the NaiveRAG agent.
|
|
|
|
Args:
|
|
llm: The language model to use for generating answers.
|
|
embedding_model: The embedding model to use for query embedding.
|
|
vector_db: The vector database to search for relevant documents.
|
|
**kwargs: Additional keyword arguments for customization.
|
|
"""
|
|
self.llm = llm
|
|
self.embedding_model = embedding_model
|
|
self.vector_db = vector_db
|
|
self.top_k = top_k
|
|
self.route_collection = route_collection
|
|
if self.route_collection:
|
|
self.collection_router = CollectionRouter(
|
|
llm=self.llm, vector_db=self.vector_db, dim=embedding_model.dimension
|
|
)
|
|
self.text_window_splitter = text_window_splitter
|
|
|
|
def retrieve(self, query: str, **kwargs) -> Tuple[List[RetrievalResult], int, dict]:
|
|
"""
|
|
Retrieve relevant documents from the knowledge base for the given query.
|
|
|
|
This method performs a basic search through the vector database to find
|
|
documents relevant to the query.
|
|
|
|
Args:
|
|
query (str): The query to search for.
|
|
**kwargs: Additional keyword arguments for customizing the retrieval.
|
|
|
|
Returns:
|
|
Tuple[List[RetrievalResult], int, dict]: A tuple containing:
|
|
- A list of retrieved document results
|
|
- The token usage for the retrieval operation
|
|
- Additional information about the retrieval process
|
|
"""
|
|
consume_tokens = 0
|
|
if self.route_collection:
|
|
selected_collections, n_token_route = self.collection_router.invoke(
|
|
query=query, dim=self.embedding_model.dimension
|
|
)
|
|
else:
|
|
selected_collections = self.collection_router.all_collections
|
|
n_token_route = 0
|
|
consume_tokens += n_token_route
|
|
all_retrieved_results = []
|
|
for collection in selected_collections:
|
|
retrieval_res = self.vector_db.search_data(
|
|
collection=collection,
|
|
vector=self.embedding_model.embed_query(query),
|
|
top_k=max(self.top_k // len(selected_collections), 1),
|
|
query_text=query,
|
|
)
|
|
all_retrieved_results.extend(retrieval_res)
|
|
all_retrieved_results = deduplicate_results(all_retrieved_results)
|
|
return all_retrieved_results, consume_tokens, {}
|
|
|
|
def query(self, query: str, **kwargs) -> Tuple[str, List[RetrievalResult], int]:
|
|
"""
|
|
Query the agent and generate an answer based on retrieved documents.
|
|
|
|
This method retrieves relevant documents and uses the language model
|
|
to generate a simple answer to the query.
|
|
|
|
Args:
|
|
query (str): The query to answer.
|
|
**kwargs: Additional keyword arguments for customizing the query process.
|
|
|
|
Returns:
|
|
Tuple[str, List[RetrievalResult], int]: A tuple containing:
|
|
- The generated answer
|
|
- A list of retrieved document results
|
|
- The total token usage
|
|
"""
|
|
all_retrieved_results, n_token_retrieval, _ = self.retrieve(query)
|
|
chunk_texts = []
|
|
for chunk in all_retrieved_results:
|
|
if self.text_window_splitter and "wider_text" in chunk.metadata:
|
|
chunk_texts.append(chunk.metadata["wider_text"])
|
|
else:
|
|
chunk_texts.append(chunk.text)
|
|
mini_chunk_str = ""
|
|
for i, chunk in enumerate(chunk_texts):
|
|
mini_chunk_str += f"""<chunk_{i}>\n{chunk}\n</chunk_{i}>\n"""
|
|
|
|
summary_prompt = SUMMARY_PROMPT.format(query=query, mini_chunk_str=mini_chunk_str)
|
|
char_response = self.llm.chat([{"role": "user", "content": summary_prompt}])
|
|
final_answer = char_response.content
|
|
log.color_print("\n==== FINAL ANSWER====\n")
|
|
log.color_print(final_answer)
|
|
return final_answer, all_retrieved_results, n_token_retrieval + char_response.total_tokens
|
|
|