from collections import defaultdict

from deepsearcher.agent.base import BaseAgent, describe_class
from deepsearcher.embedding.base import BaseEmbedding
from deepsearcher.llm.base import BaseLLM
from deepsearcher.utils import log
from deepsearcher.utils.message_stream import send_info, send_answer
from deepsearcher.vector_db import RetrievalResult
from deepsearcher.vector_db.base import BaseVectorDB, deduplicate


COLLECTION_ROUTE_PROMPT = """
|
|
我现在给你提供collection_name(s)和对应的collection_description(s).
|
|
请你选择可能与问题相关的合集名称,并返回一个字符串列表。
|
|
如果没有相关合集,请返回一个空列表。
|
|
|
|
"问题": {query}
|
|
"合集信息": {collection_info}
|
|
|
|
使用的语言与问题相同
|
|
你需要返回的格式是 a python list of str without any addtional content:
|
|
"""
|
|
|
|
|
|
SUB_QUERY_PROMPT = """
|
|
为了能够全面的回答这个问题,请你尝试把原本的问题拆分或扩展为几个子问题
|
|
请你使用自顶向下和自底向上两种方向来思考如何拆分问题
|
|
子问题的数量不可以太多,但是也不可以太少,应当保证问题的回答全面性,请根据问题复杂程度来决定子问题的数量
|
|
如果原问题本身非常简单,没有必要进行拆分,则保留输出原问题本身
|
|
需要保证每个子问题都具体、清晰、不可分(原子性,即不可以再包含更细分的子问题),子问题中不要包含"请你回答"、"请你总结"、"请你分析"等祈使类型词语
|
|
你需要最终返回一个字符串列表
|
|
|
|
原问题: {original_query}
|
|
|
|
<EXAMPLE>
|
|
示例输入:
|
|
"请解释机器学习"
|
|
|
|
示例输出(例子中的数量不是要求):
|
|
[
|
|
"什么是机器学习?",
|
|
"机器学习的使用目的",
|
|
"机器学习的常用算法",
|
|
"机器学习的历史演进过程",
|
|
"机器学习和深度学习的区别是什么?"
|
|
]
|
|
|
|
</EXAMPLE>
|
|
|
|
使用的语言与原问题相同
|
|
你需要返回的是 a python list of str without any addtional content:
|
|
"""
|
|
|
|
|
|
RERANK_PROMPT = """
|
|
根据当前的问题和获取到的文档片段(文档片段包裹都在<reference></reference>和<chunk></chunk>标签中并有对应的连续的id)
|
|
请你对当前获取到的文档片段是否能帮助回答这个问题(直接或间接、全面或部分,都可以,但需要有实际有效内容)给出一个快速判断
|
|
对于每一个文档片段,你只应该返回"True"或者"False"(需要注意顺序和数量)
|
|
|
|
问题: {query}
|
|
|
|
检索到的文档片段:
|
|
{chunks}
|
|
|
|
例如,假如给出4个chunks(实际检索到的文档片段不一定是这么多),返回4个"True"或者"False"(注意这只是一个示例,不代表实际判断): ["True", "False", "True", "True"]
|
|
使用的语言与问题相同
|
|
你需要返回的是 a python list of str without any addtional content:
|
|
"""
|
|
|
|
|
|
REFLECT_PROMPT = """
|
|
根据原问题和子问题,以及获取到的文档片段,请你决定是否要生成更多的问题,这些问题将被用于后续的思考和搜索。
|
|
你应该根据已经获得到的文档片段进行批评思考,生成其他新的问题来保证原问题的回答的准确和全面,请你使用自顶向下和自底向上两种方向来思考如何生成新问题。
|
|
如果已经获得的文档片段没能覆盖所有的子问题,这意味着有关这些问题的文档无法被检索到,你应该根据你自己的知识补充思考。
|
|
需要保证每个新的问题都具体、清晰、不可分(原子性)并且不可以和之前的问题重复,新的问题中不要包含"请你回答"、"请你总结"、"请你分析"等祈使类型词语
|
|
如果没有真的必要继续研究(取决于你的判断),返回一个空列表。
|
|
|
|
原问题: {original_query}
|
|
|
|
子问题: {all_sub_queries}
|
|
|
|
检索到的文档片段:
|
|
{chunks}
|
|
|
|
使用的语言与原问题相同
|
|
你需要返回的是 a python list of str without any addtional content:
|
|
"""
|
|
|
|
|
|
SUMMARY_PROMPT = """
|
|
你是一个内容分析专家,请你根据提供的问题和检索到的信息,生成详细、层次分明、尽可能长的回答。
|
|
如果检索到的信息不足以回答问题,你应该使用你的知识来进行扩展补充。
|
|
注意,不要一个子问题一个子问题的回答,而是应该仔细分析子问题之间的关系、子问题和原问题之间的关系。
|
|
同时,你应该根据提供的信息生成文内引用和文末参考资料列表,使用markdown脚注。
|
|
如果你自己提供的信息需要使用"your knowledge here[^0]"引用。
|
|
注意,这里的"[^0]"的序号0是固定的,表示你的知识,文末引用使用"[^0]: AI 生成",
|
|
来自<chunk><reference>的引用序号从[^1]开始,来源需要与前文<reference>中的"href"一致,不需要对每个<chunk>分配一个引用,而是相同<reference>的<chunk>共用一个引用
|
|
另外,如果回答的内容文内引用需要引用多个<reference>,请添加多个[^index]到句尾。
|
|
如果多个片段是相同的来源或者一个片段可以回答多个问题,文内引用可以引用多次,但文末只引用一次来源,即文末的引用列表中不能有重复。
|
|
|
|
例子:
|
|
<EXAMPLE>
|
|
|
|
文内引用示例:
|
|
"XGBoost是非常强大的集成学习模型[^2]"
|
|
|
|
|
|
文末引用示例:
|
|
正确例子:
|
|
[^0]: AI 生成
|
|
[^1]: files/docs/machine_learning.md
|
|
[^2]: files/docs/chap_001_003_models.md
|
|
|
|
错误例子:
|
|
[^0]: AI 生成
|
|
[^1]: files/docs/chap_001_003_models.md
|
|
[^2]: files/docs/machine_learning.md
|
|
[^3]: files/docs/chap_001_003_models.md(错误,这是重复引用)
|
|
[^5]: files/docs/machine_learning.md(错误,也这是重复引用)
|
|
|
|
</EXAMPLE>
|
|
|
|
原问题: {original_query}
|
|
|
|
子问题: {all_sub_queries}
|
|
|
|
检索到的文档片段:
|
|
{chunks}
|
|
|
|
注意,你需要使用与原始问题的相同的语言来回答
|
|
"""
|
|
|
|
|
|
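
# Note: the {chunks} placeholder in RERANK_PROMPT, REFLECT_PROMPT and SUMMARY_PROMPT is
# filled by DeepSearch._format_chunks below, which wraps every chunk in a <reference>
# block whose id is shared by chunks from the same source, e.g. (illustrative, the
# href reuses a path from the prompt example above):
#
#     <reference id='1' href='files/docs/machine_learning.md'><chunk id='1'>
#     ...chunk text...
#     </chunk id='1'></reference id='1'>

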
@describe_class(
    "This agent is suitable for handling general and simple queries, such as writing a report, survey, or article on a given topic."
)
class DeepSearch(BaseAgent):
    """
    Deep Search agent implementation for comprehensive information retrieval.

    This agent performs a thorough search through the knowledge base, analyzing
    multiple aspects of the query to provide comprehensive and detailed answers.
    A usage sketch is given at the end of this module.
    """

    def __init__(
        self,
        llm: BaseLLM,
        embedding_model: BaseEmbedding,
        vector_db: BaseVectorDB,
        max_iter: int,
        route_collection: bool = False,
        text_window_splitter: bool = True,
        **kwargs,
    ):
        """
        Initialize the DeepSearch agent.

        Args:
            llm: The language model to use for generating answers.
            embedding_model: The embedding model to use for query embedding.
            vector_db: The vector database to search for relevant documents.
            max_iter: The maximum number of iterations for the search process.
            route_collection: Whether to use a collection router for search.
            text_window_splitter: Whether to use the text window splitter.
            **kwargs: Additional keyword arguments for customization.
        """
        self.llm = llm
        self.embedding_model = embedding_model
        self.vector_db = vector_db
        self.max_iter = max_iter
        self.route_collection = route_collection
        self.all_collections = [
            collection_info.collection_name
            for collection_info in self.vector_db.list_collections(dim=embedding_model.dimension)
        ]
        self.text_window_splitter = text_window_splitter

    def invoke(self, query: str, dim: int, **kwargs) -> list[str]:
        """
        Determine which collections are relevant for the given query.

        This method analyzes the query content and selects collections that are
        most likely to contain information relevant to answering the query.

        Args:
            query (str): The query to analyze.
            dim (int): The dimension of the vector space to search in.

        Returns:
            List[str]: A list of selected collection names.
        """
        collection_infos = self.vector_db.list_collections(dim=dim)
        if len(collection_infos) == 0:
            log.color_print("No collection found in the vector database!")
            return []
        if len(collection_infos) == 1:
            the_only_collection = collection_infos[0].collection_name
            log.color_print(
                f"Perform search [{query}] on the vector DB collection: {the_only_collection}\n"
            )
            return [the_only_collection]
        vector_db_search_prompt = COLLECTION_ROUTE_PROMPT.format(
            query=query,
            collection_info=[
                {
                    "collection_name": collection_info.collection_name,
                    "collection_description": collection_info.description,
                }
                for collection_info in collection_infos
            ],
        )
        response = self.llm.chat(
            messages=[{"role": "user", "content": vector_db_search_prompt}]
        )
        selected_collections = self.llm.literal_eval(response)

        for collection_info in collection_infos:
            # Always include collections without a description, since the router
            # cannot judge their relevance.
            if not collection_info.description:
                selected_collections.append(collection_info.collection_name)
            # Always include the default collection if it exists.
            if self.vector_db.default_collection == collection_info.collection_name:
                selected_collections.append(collection_info.collection_name)
        selected_collections = list(set(selected_collections))
        log.color_print(
            f"Perform search [{query}] on the vector DB collections: {selected_collections}\n"
        )
        return selected_collections

    def _generate_sub_queries(self, original_query: str) -> list[str]:
        content = self.llm.chat(
            messages=[
                {"role": "user", "content": SUB_QUERY_PROMPT.format(original_query=original_query)}
            ]
        )
        content = self.llm.remove_think(content)
        return self.llm.literal_eval(content)

    def _search_chunks_from_vectordb(self, query: str):
        if self.route_collection:
            selected_collections = self.invoke(
                query=query, dim=self.embedding_model.dimension
            )
        else:
            selected_collections = self.all_collections

        all_retrieved_results = []
        query_vector = self.embedding_model.embed_query(query)
        for collection in selected_collections:
            send_info(f"Searching [{collection}] for [{query}] ...")
            retrieved_results = self.vector_db.search_data(
                collection=collection, vector=query_vector, query_text=query
            )
            if not retrieved_results:
                send_info(f"No relevant documents found in '{collection}'!")
                continue

            # Format all chunks for batch processing
            chunks = self._format_chunks(retrieved_results)

            # Batch process all chunks with a single LLM call
            content = self.llm.chat(
                messages=[
                    {
                        "role": "user",
                        "content": RERANK_PROMPT.format(
                            query=query,
                            chunks=chunks,
                        ),
                    }
                ]
            )
            content = self.llm.remove_think(content).strip()

            # Parse the response to determine which chunks are relevant
            try:
                relevance_list = self.llm.literal_eval(content)
                if not isinstance(relevance_list, list):
                    raise ValueError("Response is not a list")
            except Exception:
                # Fallback: if parsing fails, treat all chunks as relevant
                log.color_print(
                    f"Warning: Failed to parse relevance response. Treating all chunks as relevant. Response was: {content}"
                )
                relevance_list = ["True"] * len(retrieved_results)

            # Ensure we have a relevance judgment for every chunk
            while len(relevance_list) < len(retrieved_results):
                relevance_list.append("True")  # Default to relevant if no judgment provided

            # Filter relevant chunks based on the LLM response
            accepted_chunk_num = 0
            references = set()
            for i, retrieved_result in enumerate(retrieved_results):
                # relevance_list is padded above, so every chunk has a judgment
                is_relevant = (
                    "True" in relevance_list[i] and "False" not in relevance_list[i]
                )
                if is_relevant:
                    all_retrieved_results.append(retrieved_result)
                    accepted_chunk_num += 1
                    references.add(retrieved_result.reference)

            if accepted_chunk_num > 0:
                send_info(f"Accepted {accepted_chunk_num} chunk(s), sources: {list(references)}")
            else:
                send_info(f"None of the chunks found in '{collection}' were accepted!")
        return all_retrieved_results

    def _generate_more_sub_queries(
        self,
        original_query: str,
        all_sub_queries: list[str],
        all_retrieved_results: list[RetrievalResult],
    ) -> list[str]:
        chunks = self._format_chunks(all_retrieved_results)
        reflect_prompt = REFLECT_PROMPT.format(
            original_query=original_query,
            all_sub_queries=all_sub_queries,
            chunks=chunks if len(all_retrieved_results) > 0 else "NO RELATED CHUNKS FOUND.",
        )
        response = self.llm.chat([{"role": "user", "content": reflect_prompt}])
        response = self.llm.remove_think(response)
        return self.llm.literal_eval(response)

    def retrieve(self, original_query: str, **kwargs) -> tuple[list[RetrievalResult], list[str]]:
        """
        Retrieve relevant documents from the knowledge base for the given query.

        This method performs a deep search through the vector database to find
        the most relevant documents for answering the query.

        Args:
            original_query (str): The query to search for.
            **kwargs: Additional keyword arguments for customizing the retrieval.

        Returns:
            Tuple[List[RetrievalResult], List[str]]: A tuple containing:
                - A list of retrieved document results
                - A list of all sub-queries used during the search
        """
        # Get max_iter from kwargs or use the default
        max_iter = kwargs.get("max_iter", self.max_iter)

        ### SUB QUERIES ###
        all_search_results = []
        all_sub_queries = []

        sub_queries = self._generate_sub_queries(original_query)
        if not sub_queries:
            log.color_print("No sub queries were generated by the LLM. Exiting.")
            return [], []
        else:
            send_info(f"The original question was split into these sub-queries: {sub_queries}")
            all_sub_queries.extend(sub_queries)

        for it in range(max_iter):
            send_info(f"Search round {it + 1}:")

            # Execute all search tasks sequentially
            for query in sub_queries:
                result = self._search_chunks_from_vectordb(query)
                all_search_results.extend(result)
            undeduped_len = len(all_search_results)
            all_search_results = deduplicate(all_search_results)
            deduped_len = len(all_search_results)
            if undeduped_len - deduped_len != 0:
                send_info(f"Removed {undeduped_len - deduped_len} duplicate chunk(s)")
            # search_res_from_internet = deduplicate_results(search_res_from_internet)
            # all_search_res.extend(search_res_from_vectordb + search_res_from_internet)

            ### REFLECTION & GET MORE SUB QUERIES ###
            # Only generate more queries if we haven't reached the maximum iterations
            if it + 1 < max_iter:
                send_info("Reflecting on the retrieved chunks ...")
                sub_queries = self._generate_more_sub_queries(
                    original_query, all_sub_queries, all_search_results
                )
                if not sub_queries or len(sub_queries) == 0:
                    send_info("No more sub-queries could be generated, exiting ...")
                    break
                else:
                    send_info(f"Sub-queries for the next round: {sub_queries}")
                    all_sub_queries.extend(sub_queries)
            else:
                send_info("Reached the maximum number of search rounds, exiting ...")
                break

        all_search_results = deduplicate(all_search_results)
        return all_search_results, all_sub_queries

    def query(self, original_query: str, **kwargs) -> tuple[str, list[RetrievalResult]]:
        """
        Query the agent and generate an answer based on retrieved documents.

        This method retrieves relevant documents and uses the language model
        to generate a comprehensive answer to the query.

        Args:
            original_query (str): The query to answer.
            **kwargs: Additional keyword arguments for customizing the query process.

        Returns:
            Tuple[str, List[RetrievalResult]]: A tuple containing:
                - The generated answer
                - A list of retrieved document results
        """
        all_retrieved_results, all_sub_queries = self.retrieve(original_query, **kwargs)
        if not all_retrieved_results or len(all_retrieved_results) == 0:
            send_info(f"No further information could be found for '{original_query}'!")
            return "", []
        chunks = self._format_chunks(all_retrieved_results)
        send_info(f"Summarizing {len(all_retrieved_results)} retrieved chunk(s)")
        summary_prompt = SUMMARY_PROMPT.format(
            original_query=original_query,
            all_sub_queries=all_sub_queries,
            chunks=chunks,
        )
        response = self.llm.chat([{"role": "user", "content": summary_prompt}])
        final_answer = self.llm.remove_think(response)
        send_answer(final_answer)
        return final_answer, all_retrieved_results

    def _format_chunks(self, retrieved_results: list[RetrievalResult]):
        # Group chunk texts by their reference (source), so chunks from the same
        # source share one reference id in the formatted output.
        references = defaultdict(list)
        for result in retrieved_results:
            references[result.reference].append(result.text)
        chunks = []
        chunk_count = 0
        for i, reference in enumerate(references):
            formatted = "".join(
                [
                    (
                        f"<reference id='{i + 1}' href='{reference}'>"
                        + f"<chunk id='{j + 1 + chunk_count}'>\n{chunk}\n</chunk id='{j + 1 + chunk_count}'>"
                        + f"</reference id='{i + 1}'>\n"
                    )
                    for j, chunk in enumerate(references[reference])
                ]
            )
            print(formatted)
            chunks.append(formatted)
            chunk_count += len(references[reference])
        return "".join(chunks)
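

# Illustrative usage sketch (not part of the original module). It assumes concrete
# BaseLLM, BaseEmbedding, and BaseVectorDB implementations are already configured;
# the variable names below are placeholders, not real deepsearcher identifiers.
#
#     agent = DeepSearch(
#         llm=my_llm,                    # any BaseLLM implementation
#         embedding_model=my_embedding,  # any BaseEmbedding implementation
#         vector_db=my_vector_db,        # any BaseVectorDB implementation
#         max_iter=3,
#         route_collection=True,
#     )
#     answer, sources = agent.query("What is machine learning?")
#     # `answer` is the generated markdown answer with footnote citations;
#     # `sources` is the list of RetrievalResult objects it was built from.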