Browse Source

fix: 修复deep_search.py中新建collection无法获取的问题

fix: 修改错误的脚注样式
main
tanxing 3 days ago
parent
commit
8bab5b066e
  1. 29
      deepsearcher/agent/deep_search.py

29
deepsearcher/agent/deep_search.py

@ -92,14 +92,13 @@ SUMMARY_PROMPT = """
如果检索到的信息不足以回答问题你应该使用你的知识来进行扩展补充
注意不要逐个回答问题而是应该综合所有问题和信息生成一个完整的回答
同时你应该根据提供的信息生成文内引用"[^index]"(markdown文内引用)
如果你自己提供的信息需要使用"[^0]"引用即你提供的信息使用固定index=0
来自<chunk><reference>的引用序号从[^index]从index=1开始来源需要与前文<reference>中的"href"一致
不需要对每个<chunk>分配一个引用而是相同<reference><chunk>共用一个引用
来自<chunk><reference>的引用序号从[^index]从index=1开始来源需要与前文<reference>中的"id"一致
不需要对每个<chunk>分配一个引用而是相同<reference><chunk>共用引用并确保每一个<reference>都被引用
另外如果回答的内容文内引用需要引用多个<reference>请添加多个[^index]到句尾
<EXAMPLE>
"XGBoost是非常强大的集成学习模型。[^1]但是XGBoost的缺点是计算复杂度高,需要大量的计算资源。[^0]"
"XGBoost是非常强大的集成学习模型。[^1]但是XGBoost的缺点是计算复杂度高,需要大量的计算资源。[^2]"
</EXAMPLE>
@ -152,10 +151,6 @@ class DeepSearch(BaseAgent):
self.vector_db = vector_db
self.max_iter = max_iter
self.route_collection = route_collection
self.all_collections = [
collection_info.collection_name
for collection_info in self.vector_db.list_collections(dim=embedding_model.dimension)
]
self.text_window_splitter = text_window_splitter
def invoke(self, query: str, dim: int, **kwargs) -> list[str]:
@ -227,7 +222,10 @@ class DeepSearch(BaseAgent):
query=query, dim=self.embedding_model.dimension
)
else:
selected_collections = self.all_collections
selected_collections = [
collection_info.collection_name
for collection_info in self.vector_db.list_collections(dim=self.embedding_model.dimension)
]
all_retrieved_results = []
query_vector = self.embedding_model.embed_query(query)
@ -394,7 +392,7 @@ class DeepSearch(BaseAgent):
if not all_retrieved_results or len(all_retrieved_results) == 0:
send_info(f"'{original_query}'没能找到更多信息!")
return "", []
chunks, refs = self._format_chunks(all_retrieved_results)
chunks, refs = self._format_chunks(all_retrieved_results, with_chunk_id=False)
send_info(f"正在总结 {len(all_retrieved_results)} 个查找到的文档片段")
summary_prompt = SUMMARY_PROMPT.format(
original_query=original_query,
@ -406,13 +404,13 @@ class DeepSearch(BaseAgent):
send_answer(response)
return response, all_retrieved_results
def _format_chunks(self, retrieved_results: list[RetrievalResult]) -> tuple[str, str]:
def _format_chunks(self, retrieved_results: list[RetrievalResult], with_chunk_id: bool = True) -> tuple[str, str]:
# 以referecen为key,把chunk放到字典中
ref_dict = defaultdict(list)
for result in retrieved_results:
ref_dict[result.reference].append(result.text)
formated_chunks = []
formated_refs = ["\n\n[^0]: AI 生成\n"]
formated_refs = ["\n\n"]
chunk_count = 0
for i, reference in enumerate(ref_dict):
formated_chunk = "".join(
@ -421,6 +419,11 @@ class DeepSearch(BaseAgent):
f"<reference id='{i + 1}' href='{reference}'>" +
f"<chunk id='{j + 1 + chunk_count}'>\n{chunk}\n</chunk id='{j + 1 + chunk_count}'>" +
f"</reference id='{i + 1}'>\n"
)
if with_chunk_id else (
f"<reference id='{i + 1}' href='{reference}'>" +
f"<chunk>\n{chunk}\n</chunk>" +
f"</reference id='{i + 1}'>\n"
)
for j, chunk in enumerate(ref_dict[reference])
]
@ -428,7 +431,7 @@ class DeepSearch(BaseAgent):
print(formated_chunk)
formated_chunks.append(formated_chunk)
chunk_count += len(ref_dict[reference])
formated_refs.append(f"[^{i + 1}]: " + str(reference) + "\n")
formated_refs.append(f"[{i + 1}]: " + str(reference) + "\n")
formated_chunks = "".join(formated_chunks)
formated_refs = "".join(formated_refs)
return formated_chunks, formated_refs

Loading…
Cancel
Save