from collections import defaultdict

from deepsearcher.agent.base import BaseAgent, describe_class
from deepsearcher.embedding.base import BaseEmbedding
from deepsearcher.llm.base import BaseLLM
from deepsearcher.utils import log
from deepsearcher.utils.message_stream import send_answer, send_info
from deepsearcher.vector_db import RetrievalResult
from deepsearcher.vector_db.base import BaseVectorDB, deduplicate

COLLECTION_ROUTE_PROMPT = """
I will provide you with collection_name(s) and the corresponding collection_description(s).
Select the collection names that may be relevant to the question and return them as a list of strings.
If no collection is relevant, return an empty list.

"Question": {query}
"Collection info": {collection_info}

Respond in the same language as the question.
The format you must return is a python list of str without any additional content:
"""
SUB_QUERY_PROMPT = """
To answer the question comprehensively, break the original question down into, or expand it into, several sub-questions.
Do not create too many or too few; let the complexity of the question decide how many.
If the original question is very simple and there is no need to split it, output the original question itself.
Each sub-question must be specific, clear, and atomic (not further divisible). Finally, return a list of strings.

Original question: {original_query}

<EXAMPLE>
Example input:
"Explain machine learning"

Example output (the number of items in the example is not a requirement):

[
    "What is machine learning?",
    "What is machine learning used for?",
    "What is the difference between machine learning and deep learning?",
    "How has machine learning evolved over time?"
]

</EXAMPLE>
Respond in the same language as the original question.
You must return a python list of str without any additional content:

"""
RERANK_PROMPT = """
Based on the current question and the retrieved document chunks,
make a quick judgment for each chunk on whether it helps answer the question (directly or indirectly, fully or partially).
For each chunk, return only "YES" or "NO", minding both the order and the count.

Question: {query}

Retrieved document chunks:
{chunks}

For example, given 4 chunks (the actual number retrieved may differ), return: ["YES", "NO", "YES", "YES"]
Respond in the same language as the question.
You must return a python list of str without any additional content:
"""
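# Note: `_search_chunks_from_vectordb` below parses this prompt's reply leniently --
# a reply that is not a valid Python list falls back to treating every chunk as
# relevant, and a too-short list is padded with "YES" so each chunk gets a judgment.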
REFLECT_PROMPT = """
Based on the original question, the sub-questions, and the retrieved document chunks, decide whether more queries need to be generated.
If the retrieved chunks fail to cover all of the sub-questions, it means documents for those sub-questions could not be retrieved;
you may generate similar but slightly different queries to retry the retrieval.
You may also reason critically about the retrieved chunks and generate new questions that make the answer to the original question accurate and comprehensive.
If further research is not genuinely necessary (this is your judgment call), return an empty list.

Original question: {original_query}

Sub-questions: {all_sub_queries}

Retrieved document chunks:
{chunks}

Respond in the same language as the original question.
You must return a python list of str without any additional content:
"""
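# Contract note: `_generate_more_sub_queries` below parses this reply with
# `literal_eval`; an empty list is the stop signal that ends the retrieval loop,
# while a non-empty list triggers another search round with the new sub-queries.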
SUMMARY_PROMPT = """
You are a content-analysis expert. Based on the provided question and the retrieved information, generate a detailed, long-form answer.
If the retrieved information is insufficient to answer the question, or additional information is required, supplement it with your own knowledge.
In that case, cite the information you supply yourself like "your knowledge here[^0]"; note that the index 0 in "[^0]" is fixed and always denotes your own knowledge (an end-of-text citation example is given below).
You should also generate in-text citations and an end-of-text reference list from the provided information; citations drawn from the document chunks' references start at [^1].
If several chunks share the same source, or one chunk answers several questions, the in-text citation may appear multiple times, but each source is listed only once at the end; the end-of-text reference list must not contain duplicate sources.
Example:
<EXAMPLE>
In-text citation example (use markdown footnotes):
"XGBoost is a very powerful ensemble learning model[^2]"
(You must use "[^index]", where index is the id of the corresponding <reference>.)
End-of-text citation example (the target must match the href of the reference above; do not assign a citation to every chunk, all chunks of one reference share a single citation):
[^0]: AI Generated
[^2]: files/docs/chap_001_003_models.md
</EXAMPLE>

Original question: {original_query}

Sub-questions: {all_sub_queries}

Retrieved document chunks:
{chunks}

Note: you must answer in the same language as the original question.
"""
@describe_class(
"This agent is suitable for handling general and simple queries, such as given a topic and then writing a report, survey, or article."
)
class DeepSearch(BaseAgent):
"""
Deep Search agent implementation for comprehensive information retrieval.
This agent performs a thorough search through the knowledge base, analyzing
multiple aspects of the query to provide comprehensive and detailed answers.
"""
def __init__(
self,
llm: BaseLLM,
embedding_model: BaseEmbedding,
vector_db: BaseVectorDB,
max_iter: int,
route_collection: bool = False,
text_window_splitter: bool = True,
**kwargs,
):
"""
Initialize the DeepSearch agent.
Args:
llm: The language model to use for generating answers.
embedding_model: The embedding model to use for query embedding.
vector_db: The vector database to search for relevant documents.
max_iter: The maximum number of iterations for the search process.
route_collection: Whether to use a collection router for search.
            text_window_splitter: Whether to use the text window splitter.
**kwargs: Additional keyword arguments for customization.
"""
self.llm = llm
self.embedding_model = embedding_model
self.vector_db = vector_db
self.max_iter = max_iter
self.route_collection = route_collection
self.all_collections = [
collection_info.collection_name
for collection_info in self.vector_db.list_collections(dim=embedding_model.dimension)
]
self.text_window_splitter = text_window_splitter
def invoke(self, query: str, dim: int, **kwargs) -> list[str]:
"""
Determine which collections are relevant for the given query.
This method analyzes the query content and selects collections that are
most likely to contain information relevant to answering the query.
Args:
query (str): The query to analyze.
dim (int): The dimension of the vector space to search in.
Returns:
            List[str]: A list of selected collection names.
"""
collection_infos = self.vector_db.list_collections(dim=dim)
if len(collection_infos) == 0:
log.color_print(
"No collection found in the vector database!"
)
return []
if len(collection_infos) == 1:
the_only_collection = collection_infos[0].collection_name
log.color_print(
f"Perform search [{query}] on the vector DB collection: {the_only_collection}\n"
)
return [the_only_collection]
vector_db_search_prompt = COLLECTION_ROUTE_PROMPT.format(
query=query,
collection_info=[
{
"collection_name": collection_info.collection_name,
"collection_description": collection_info.description,
}
for collection_info in collection_infos
],
)
response = self.llm.chat(
messages=[{"role": "user", "content": vector_db_search_prompt}]
)
        selected_collections = self.llm.literal_eval(self.llm.remove_think(response))
for collection_info in collection_infos:
            # A collection without a description cannot be judged by the router, so include it unconditionally
if not collection_info.description:
selected_collections.append(collection_info.collection_name)
            # Always include the default collection when it exists
if self.vector_db.default_collection == collection_info.collection_name:
selected_collections.append(collection_info.collection_name)
selected_collections = list(set(selected_collections))
log.color_print(
f"Perform search [{query}] on the vector DB collections: {selected_collections}\n"
)
return selected_collections
    def _generate_sub_queries(self, original_query: str) -> list[str]:
content = self.llm.chat(
messages=[
{"role": "user", "content": SUB_QUERY_PROMPT.format(original_query=original_query)}
]
)
content = self.llm.remove_think(content)
return self.llm.literal_eval(content)
def _search_chunks_from_vectordb(self, query: str):
if self.route_collection:
selected_collections = self.invoke(
query=query, dim=self.embedding_model.dimension
)
else:
selected_collections = self.all_collections
all_retrieved_results = []
query_vector = self.embedding_model.embed_query(query)
for collection in selected_collections:
send_info(f"正在 [{collection}] 中搜索 [{query}] ...")
retrieved_results = self.vector_db.search_data(
collection=collection, vector=query_vector, query_text=query
)
            if not retrieved_results:
send_info(f"'{collection}' 中没有找到相关文档!")
continue
# Format all chunks for batch processing
chunks = self._format_chunks(retrieved_results)
# Batch process all chunks with a single LLM call
content = self.llm.chat(
messages=[
{
"role": "user",
"content": RERANK_PROMPT.format(
query=query,
chunks=chunks,
),
}
]
)
content = self.llm.remove_think(content).strip()
# Parse the response to determine which chunks are relevant
try:
relevance_list = self.llm.literal_eval(content)
if not isinstance(relevance_list, list):
raise ValueError("Response is not a list")
except (ValueError, SyntaxError):
# Fallback: if parsing fails, treat all chunks as relevant
log.color_print(f"Warning: Failed to parse relevance response. Treating all chunks as relevant. Response was: {content}")
relevance_list = ["YES"] * len(retrieved_results)
# Ensure we have enough relevance judgments for all chunks
while len(relevance_list) < len(retrieved_results):
relevance_list.append("YES") # Default to relevant if no judgment provided
# Filter relevant chunks based on LLM response
accepted_chunk_num = 0
references = set()
for i, retrieved_result in enumerate(retrieved_results):
# Check if we have a relevance judgment for this chunk
                is_relevant = (
                    "YES" in relevance_list[i].upper()
                    and "NO" not in relevance_list[i].upper()
                    if i < len(relevance_list)
                    else True
                )
if is_relevant:
all_retrieved_results.append(retrieved_result)
accepted_chunk_num += 1
references.add(retrieved_result.reference)
            if accepted_chunk_num > 0:
                send_info(f"Accepted {accepted_chunk_num} chunk(s) from: {list(references)}")
            else:
                send_info(f"No chunk retrieved from '{collection}' was accepted!")
return all_retrieved_results
def _generate_more_sub_queries(
self, original_query: str, all_sub_queries: list[str], all_retrieved_results: list[RetrievalResult]
) -> list[str]:
chunks = self._format_chunks(all_retrieved_results)
reflect_prompt = REFLECT_PROMPT.format(
original_query=original_query,
all_sub_queries=all_sub_queries,
            chunks=chunks if len(all_retrieved_results) > 0 else "NO RELATED CHUNKS FOUND.",
)
response = self.llm.chat([{"role": "user", "content": reflect_prompt}])
response = self.llm.remove_think(response)
return self.llm.literal_eval(response)
def retrieve(self, original_query: str, **kwargs) -> tuple[list[RetrievalResult], list[str]]:
"""
Retrieve relevant documents from the knowledge base for the given query.
This method performs a deep search through the vector database to find
the most relevant documents for answering the query.
Args:
original_query (str): The query to search for.
**kwargs: Additional keyword arguments for customizing the retrieval.
Returns:
            Tuple[List[RetrievalResult], List[str]]: A tuple containing:
                - A list of retrieved document results
                - A list of all sub-queries generated during the search
"""
# Get max_iter from kwargs or use default
        max_iter = kwargs.get("max_iter", self.max_iter)
### SUB QUERIES ###
all_search_results = []
all_sub_queries = []
sub_queries = self._generate_sub_queries(original_query)
if not sub_queries:
log.color_print("No sub queries were generated by the LLM. Exiting.")
            return [], []
else:
send_info(f"原问题被拆分为这些子问题: {sub_queries}")
all_sub_queries.extend(sub_queries)
for it in range(max_iter):
send_info(f"{it + 1} 轮搜索:")
# Execute all search tasks sequentially
for query in sub_queries:
result = self._search_chunks_from_vectordb(query)
all_search_results.extend(result)
undeduped_len = len(all_search_results)
all_search_results = deduplicate(all_search_results)
deduped_len = len(all_search_results)
if undeduped_len - deduped_len != 0:
send_info(f"移除 {undeduped_len - deduped_len} 个重复文档片段")
# search_res_from_internet = deduplicate_results(search_res_from_internet)
# all_search_res.extend(search_res_from_vectordb + search_res_from_internet)
### REFLECTION & GET MORE SUB QUERIES ###
# Only generate more queries if we haven't reached the maximum iterations
if it + 1 < max_iter:
send_info("正在根据文档片段思考 ...")
sub_queries = self._generate_more_sub_queries(
original_query, all_sub_queries, all_search_results
)
                if not sub_queries:
                    send_info("No new sub-queries were generated. Exiting ...")
break
else:
send_info(f"下一轮搜索的子问题: {sub_queries}")
all_sub_queries.extend(sub_queries)
else:
send_info("已达到最大搜索轮数,正在退出 ...")
break
all_search_results = deduplicate(all_search_results)
return all_search_results, all_sub_queries
def query(self, original_query: str, **kwargs) -> tuple[str, list[RetrievalResult]]:
"""
Query the agent and generate an answer based on retrieved documents.
This method retrieves relevant documents and uses the language model
to generate a comprehensive answer to the query.
Args:
            original_query (str): The query to answer.
**kwargs: Additional keyword arguments for customizing the query process.
Returns:
            Tuple[str, List[RetrievalResult]]: A tuple containing:
- The generated answer
- A list of retrieved document results
"""
all_retrieved_results, all_sub_queries = self.retrieve(original_query, **kwargs)
        if not all_retrieved_results:
            send_info(f"No relevant information found for '{original_query}'!")
return "", []
chunks = self._format_chunks(all_retrieved_results)
send_info(f"正在总结 {len(all_retrieved_results)} 个查找到的文档片段")
summary_prompt = SUMMARY_PROMPT.format(
original_query=original_query,
all_sub_queries=all_sub_queries,
            chunks=chunks,
        )
response = self.llm.chat([{"role": "user", "content": summary_prompt}])
final_answer = self.llm.remove_think(response)
send_answer(final_answer)
        return final_answer, all_retrieved_results
    def _format_chunks(self, retrieved_results: list[RetrievalResult]) -> str:
        # Group chunk texts in a dict keyed by their reference (source)
references = defaultdict(list)
for result in retrieved_results:
references[result.reference].append(result.text)
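        # Illustrative output shape (the href below is a made-up example; chunk ids
        # are 1-based and run consecutively across references):
        #   <reference id='1' href='files/docs/a.md'>
        #   <chunk id='1'>
        #   ...chunk text...
        #   </chunk>
        #   </reference>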
chunks = []
chunk_count = 0
        for i, reference in enumerate(references):
            formatted = (
                f"<reference id='{i + 1}' href='{reference}'>\n"
                + "".join(
                    f"<chunk id='{j + 1 + chunk_count}'>\n{chunk}\n</chunk>\n"
                    for j, chunk in enumerate(references[reference])
                )
                + "</reference>\n"
            )
            chunks.append(formatted)
            chunk_count += len(references[reference])
return "".join(chunks)