|
@ -158,7 +158,7 @@ class DeepSearch(BaseAgent): |
|
|
self.max_iter = max_iter |
|
|
self.max_iter = max_iter |
|
|
self.route_collection = route_collection |
|
|
self.route_collection = route_collection |
|
|
self.text_window_splitter = text_window_splitter |
|
|
self.text_window_splitter = text_window_splitter |
|
|
self.web_search = WebSearch() if web_search else None |
|
|
self.web_search = False |
|
|
|
|
|
|
|
|
def invoke(self, query: str, dim: int, **kwargs) -> list[str]: |
|
|
def invoke(self, query: str, dim: int, **kwargs) -> list[str]: |
|
|
""" |
|
|
""" |
|
@ -223,7 +223,7 @@ class DeepSearch(BaseAgent): |
|
|
content = self.llm.remove_think(content) |
|
|
content = self.llm.remove_think(content) |
|
|
return self.llm.literal_eval(content) |
|
|
return self.llm.literal_eval(content) |
|
|
|
|
|
|
|
|
def _search_chunks(self, query: str) -> list[RetrievalResult]: |
|
|
def _search_chunks(self, query: str, **kwargs) -> list[RetrievalResult]: |
|
|
results = [] |
|
|
results = [] |
|
|
|
|
|
|
|
|
# 本地向量搜索 |
|
|
# 本地向量搜索 |
|
@ -248,6 +248,7 @@ class DeepSearch(BaseAgent): |
|
|
send_info(f"本地向量搜索找到 {len(vector_results)} 个结果") |
|
|
send_info(f"本地向量搜索找到 {len(vector_results)} 个结果") |
|
|
|
|
|
|
|
|
# 网页搜索 |
|
|
# 网页搜索 |
|
|
|
|
|
self.web_search = WebSearch() if kwargs.get('web_search', False) else None |
|
|
if self.web_search: |
|
|
if self.web_search: |
|
|
web_results = self.web_search.search_with_retry(query, size=2) |
|
|
web_results = self.web_search.search_with_retry(query, size=2) |
|
|
if web_results: |
|
|
if web_results: |
|
@ -362,7 +363,7 @@ class DeepSearch(BaseAgent): |
|
|
|
|
|
|
|
|
# Execute all search tasks sequentially |
|
|
# Execute all search tasks sequentially |
|
|
for query in sub_queries: |
|
|
for query in sub_queries: |
|
|
results = self._search_chunks(query) |
|
|
results = self._search_chunks(query, **kwargs) |
|
|
all_search_results.extend(results) |
|
|
all_search_results.extend(results) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|