You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

426 lines
18 KiB

from deepsearcher.agent.base import BaseAgent, describe_class
from deepsearcher.embedding.base import BaseEmbedding
from deepsearcher.llm.base import BaseLLM
from deepsearcher.utils import log
from deepsearcher.utils.message_stream import send_info, send_answer
from deepsearcher.vector_db import RetrievalResult
from deepsearcher.vector_db.base import BaseVectorDB, deduplicate
from collections import defaultdict
COLLECTION_ROUTE_PROMPT = """
我现在给你提供collection_name(s)和对应的collection_description(s).
请你选择可能与问题相关的合集名称,并返回一个字符串列表。
如果没有相关合集,请返回一个空列表。
"问题": {query}
"合集信息": {collection_info}
使用的语言与问题相同
你需要返回的格式是 a python list of str without any addtional content:
"""
SUB_QUERY_PROMPT = """
为了能够全面的回答这个问题,请你尝试把原本的问题拆分或扩展为几个子问题
不可以太多,但是也不可以太少,请根据问题复杂程度来决定子问题的数量
如果原问题本身非常简单,没有必要进行拆分,则保留输出原问题本身
需要保证每个子问题都具体、清晰、不可分(原子性),最终返回一个字符串列表
原问题: {original_query}
<EXAMPLE>
示例输入:
"请解释机器学习"
示例输出(例子中的数量不是要求):
[
"什么是机器学习?",
"机器学习的使用目的是什么?",
"机器学习和深度学习的区别是什么?",
"机器学习的历史演进过程?"
]
</EXAMPLE>
使用的语言与原问题相同
你需要返回的是 a python list of str without any addtional content:
"""
RERANK_PROMPT = """
根据当前的问题和获取到的文档片段
请你对当前获取到的文档是否能帮助回答这个问题(直接或间接,全面或部分,都可以)给出一个快速判断
对于每一个文档片段,你只应该返回"YES"或者"NO"(需要注意顺序和数量)
问题: {query}
检索到的文档片段:
{chunks}
例如,给定4个chunks(实际检索到的文档片段不一定是4个),返回: ["YES", "NO", "YES", "YES"]
使用的语言与问题相同
你需要返回的是 a python list of str without any addtional content:
"""
REFLECT_PROMPT = """
根据原问题和子问题,以及获取到的文档片段,请你决定是否要生成更多的问题。
如果已经获得的文档片段没能覆盖所有的子问题,这意味着这些文档无法被检索到。
你可以尝试生成相似但些许不同的问题来尝试重新检索,但是也可以根据获得到的文档片段进行批评思考,生成新的问题来保证原问题的回答的准确和全面
如果没有真的必要继续研究(取决于你的判断),返回一个空列表
原问题: {original_query}
子问题: {all_sub_queries}
检索到的文档片段:
{chunks}
使用的语言与原问题相同
你需要返回的是 a python list of str without any addtional content:
"""
SUMMARY_PROMPT = """
你是一个内容分析专家,请你根据提供的问题和检索到的信息生成详尽的长文回答。
如果检索到的信息不足以回答问题或者必须添加额外信息才能能回答,你应该使用你的知识来进行补充,
这种情况下,你自己提供的信息需要使用例如"your knowledge here[^0]"引用,注意,这里的"[^0]"的序号0是固定的,表示你的知识,下文当中有文末引用的例子
同时,你应该根据提供的信息生成文内引用和文末参考资料列表,来自文档切片的reference引用从[^1]开始
如果多个片段是相同的来源或者一个片段可以回答多个问题,文内引用可以引用多次,但文末只引用一次来源,即文末的引用列表中不能有重复的来源。
例子:
<EXAMPLE>
文内引用示例(使用markdown脚注):
"XGBoost是非常强大的集成学习模型[^2]"
(必须使用 "[^index]",这里的index是对应的<reference>的id)
文末引用示例 (需要与前文reference的href一致,不需要对每个chunk分配一个引用,而是每一个referecen共用一个引用):
[^0]: AI Generated
[^2]: files/docs/chap_001_003_models.md
</EXAMPLE>
原问题: {original_query}
子问题: {all_sub_queries}
检索到的文档片段:
{chunks}
注意,你需要使用与原始问题的相同的语言来回答
"""
@describe_class(
"This agent is suitable for handling general and simple queries, such as given a topic and then writing a report, survey, or article."
)
class DeepSearch(BaseAgent):
"""
Deep Search agent implementation for comprehensive information retrieval.
This agent performs a thorough search through the knowledge base, analyzing
multiple aspects of the query to provide comprehensive and detailed answers.
"""
def __init__(
self,
llm: BaseLLM,
embedding_model: BaseEmbedding,
vector_db: BaseVectorDB,
max_iter: int,
route_collection: bool = False,
text_window_splitter: bool = True,
**kwargs,
):
"""
Initialize the DeepSearch agent.
Args:
llm: The language model to use for generating answers.
embedding_model: The embedding model to use for query embedding.
vector_db: The vector database to search for relevant documents.
max_iter: The maximum number of iterations for the search process.
route_collection: Whether to use a collection router for search.
text_window_splitter: Whether to use text_window splitter.
**kwargs: Additional keyword arguments for customization.
"""
self.llm = llm
self.embedding_model = embedding_model
self.vector_db = vector_db
self.max_iter = max_iter
self.route_collection = route_collection
self.all_collections = [
collection_info.collection_name
for collection_info in self.vector_db.list_collections(dim=embedding_model.dimension)
]
self.text_window_splitter = text_window_splitter
def invoke(self, query: str, dim: int, **kwargs) -> list[str]:
"""
Determine which collections are relevant for the given query.
This method analyzes the query content and selects collections that are
most likely to contain information relevant to answering the query.
Args:
query (str): The query to analyze.
dim (int): The dimension of the vector space to search in.
Returns:
List[str]: A list of selected collection names
"""
collection_infos = self.vector_db.list_collections(dim=dim)
if len(collection_infos) == 0:
log.color_print(
"No collection found in the vector database!"
)
return []
if len(collection_infos) == 1:
the_only_collection = collection_infos[0].collection_name
log.color_print(
f"Perform search [{query}] on the vector DB collection: {the_only_collection}\n"
)
return [the_only_collection]
vector_db_search_prompt = COLLECTION_ROUTE_PROMPT.format(
query=query,
collection_info=[
{
"collection_name": collection_info.collection_name,
"collection_description": collection_info.description,
}
for collection_info in collection_infos
],
)
response = self.llm.chat(
messages=[{"role": "user", "content": vector_db_search_prompt}]
)
selected_collections = self.llm.literal_eval(response)
for collection_info in collection_infos:
# If a collection description is not provided, use the query as the search query
if not collection_info.description:
selected_collections.append(collection_info.collection_name)
# If the default collection exists, use the query as the search query
if self.vector_db.default_collection == collection_info.collection_name:
selected_collections.append(collection_info.collection_name)
selected_collections = list(set(selected_collections))
log.color_print(
f"Perform search [{query}] on the vector DB collections: {selected_collections}\n"
)
return selected_collections
def _generate_sub_queries(self, original_query: str) -> tuple[list[str], int]:
content = self.llm.chat(
messages=[
{"role": "user", "content": SUB_QUERY_PROMPT.format(original_query=original_query)}
]
)
content = self.llm.remove_think(content)
return self.llm.literal_eval(content)
def _search_chunks_from_vectordb(self, query: str):
if self.route_collection:
selected_collections = self.invoke(
query=query, dim=self.embedding_model.dimension
)
else:
selected_collections = self.all_collections
all_retrieved_results = []
query_vector = self.embedding_model.embed_query(query)
for collection in selected_collections:
send_info(f"正在 [{collection}] 中搜索 [{query}] ...")
retrieved_results = self.vector_db.search_data(
collection=collection, vector=query_vector, query_text=query
)
if not retrieved_results or len(retrieved_results) == 0:
send_info(f"'{collection}' 中没有找到相关文档!")
continue
# Format all chunks for batch processing
chunks = self._format_chunks(retrieved_results)
# Batch process all chunks with a single LLM call
content = self.llm.chat(
messages=[
{
"role": "user",
"content": RERANK_PROMPT.format(
query=query,
chunks=chunks,
),
}
]
)
content = self.llm.remove_think(content).strip()
# Parse the response to determine which chunks are relevant
try:
relevance_list = self.llm.literal_eval(content)
if not isinstance(relevance_list, list):
raise ValueError("Response is not a list")
except (ValueError, SyntaxError):
# Fallback: if parsing fails, treat all chunks as relevant
log.color_print(f"Warning: Failed to parse relevance response. Treating all chunks as relevant. Response was: {content}")
relevance_list = ["YES"] * len(retrieved_results)
# Ensure we have enough relevance judgments for all chunks
while len(relevance_list) < len(retrieved_results):
relevance_list.append("YES") # Default to relevant if no judgment provided
# Filter relevant chunks based on LLM response
accepted_chunk_num = 0
references = set()
for i, retrieved_result in enumerate(retrieved_results):
# Check if we have a relevance judgment for this chunk
is_relevant = (
i < len(relevance_list) and
"YES" in relevance_list[i].upper() and
"NO" not in relevance_list[i].upper()) if i < len(relevance_list
) else True
if is_relevant:
all_retrieved_results.append(retrieved_result)
accepted_chunk_num += 1
references.add(retrieved_result.reference)
if accepted_chunk_num > 0:
send_info(f"采纳 {accepted_chunk_num} 个文档片段,来源:{list(references)}")
else:
send_info(f"没有采纳任何 '{collection}' 中找到的文档片段!")
return all_retrieved_results
def _generate_more_sub_queries(
self, original_query: str, all_sub_queries: list[str], all_retrieved_results: list[RetrievalResult]
) -> list[str]:
chunks = self._format_chunks(all_retrieved_results)
reflect_prompt = REFLECT_PROMPT.format(
original_query=original_query,
all_sub_queries=all_sub_queries,
chunks=chunks
if len(all_retrieved_results) > 0
else "NO RELATED CHUNKS FOUND.",
)
response = self.llm.chat([{"role": "user", "content": reflect_prompt}])
response = self.llm.remove_think(response)
return self.llm.literal_eval(response)
def retrieve(self, original_query: str, **kwargs) -> tuple[list[RetrievalResult], list[str]]:
"""
Retrieve relevant documents from the knowledge base for the given query.
This method performs a deep search through the vector database to find
the most relevant documents for answering the query.
Args:
original_query (str): The query to search for.
**kwargs: Additional keyword arguments for customizing the retrieval.
Returns:
Tuple[List[RetrievalResult], int, dict]: A tuple containing:
- A list of retrieved document results
- Additional information about the retrieval process
"""
# Get max_iter from kwargs or use default
max_iter = kwargs.get('max_iter', self.max_iter)
### SUB QUERIES ###
all_search_results = []
all_sub_queries = []
sub_queries = self._generate_sub_queries(original_query)
if not sub_queries:
log.color_print("No sub queries were generated by the LLM. Exiting.")
return [], {}
else:
send_info(f"原问题被拆分为这些子问题: {sub_queries}")
all_sub_queries.extend(sub_queries)
for it in range(max_iter):
send_info(f"{it + 1} 轮搜索:")
# Execute all search tasks sequentially
for query in sub_queries:
result = self._search_chunks_from_vectordb(query)
all_search_results.extend(result)
undeduped_len = len(all_search_results)
all_search_results = deduplicate(all_search_results)
deduped_len = len(all_search_results)
if undeduped_len - deduped_len != 0:
send_info(f"移除 {undeduped_len - deduped_len} 个重复文档片段")
# search_res_from_internet = deduplicate_results(search_res_from_internet)
# all_search_res.extend(search_res_from_vectordb + search_res_from_internet)
### REFLECTION & GET MORE SUB QUERIES ###
# Only generate more queries if we haven't reached the maximum iterations
if it + 1 < max_iter:
send_info("正在根据文档片段思考 ...")
sub_queries = self._generate_more_sub_queries(
original_query, all_sub_queries, all_search_results
)
if not sub_queries or len(sub_queries) == 0:
send_info("没能生成更多的子问题,正在退出 ....")
break
else:
send_info(f"下一轮搜索的子问题: {sub_queries}")
all_sub_queries.extend(sub_queries)
else:
send_info("已达到最大搜索轮数,正在退出 ...")
break
all_search_results = deduplicate(all_search_results)
return all_search_results, all_sub_queries
def query(self, original_query: str, **kwargs) -> tuple[str, list[RetrievalResult]]:
"""
Query the agent and generate an answer based on retrieved documents.
This method retrieves relevant documents and uses the language model
to generate a comprehensive answer to the query.
Args:
query (str): The query to answer.
**kwargs: Additional keyword arguments for customizing the query process.
Returns:
Tuple[str, List[RetrievalResult], int]: A tuple containing:
- The generated answer
- A list of retrieved document results
"""
all_retrieved_results, all_sub_queries = self.retrieve(original_query, **kwargs)
if not all_retrieved_results or len(all_retrieved_results) == 0:
send_info(f"'{original_query}'没能找到更多信息!")
return "", []
chunks = self._format_chunks(all_retrieved_results)
send_info(f"正在总结 {len(all_retrieved_results)} 个查找到的文档片段")
summary_prompt = SUMMARY_PROMPT.format(
original_query=original_query,
all_sub_queries=all_sub_queries,
chunks=chunks
)
response = self.llm.chat([{"role": "user", "content": summary_prompt}])
final_answer = self.llm.remove_think(response)
send_answer(final_answer)
return self.llm.remove_think(response), all_retrieved_results
def _format_chunks(self, retrieved_results: list[RetrievalResult]):
# 以referecen为key,把chunk放到字典中
references = defaultdict(list)
for result in retrieved_results:
references[result.reference].append(result.text)
chunks = []
chunk_count = 0
for i, reference in enumerate(references):
formated = f"<reference id='{i + 1}' href='{reference}'>\n" + "".join(
[
f"<chunk id='{j + 1 + chunk_count}'>\n{chunk}\n</chunk id='{j + 1 + chunk_count}'>\n"
for j, chunk in enumerate(references[reference])
]
) + f"</reference id='{i + 1}'>\n"
print(formated)
chunks.append(formated)
chunk_count += len(references[reference])
return "".join(chunks)