From d58749f825c50c41a36e6e99460d031e6ab1f37a Mon Sep 17 00:00:00 2001
From: tanxing
Date: Mon, 18 Aug 2025 11:28:04 +0800
Subject: [PATCH] feat: change web search logic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 deepsearcher/agent/deep_search.py | 164 ++++++++++++++++++------------
 deepsearcher/config.yaml          |  10 +-
 deepsearcher/configuration.py     |   1 +
 deepsearcher/web_search.py        | 113 ++++++++++++++++++++
 test_web_only.py                  |  40 ++++++++
 test_web_search.py                |  75 ++++++++++++++
 6 files changed, 335 insertions(+), 68 deletions(-)
 create mode 100644 deepsearcher/web_search.py
 create mode 100644 test_web_only.py
 create mode 100644 test_web_search.py

diff --git a/deepsearcher/agent/deep_search.py b/deepsearcher/agent/deep_search.py
index f912084..31bccdf 100644
--- a/deepsearcher/agent/deep_search.py
+++ b/deepsearcher/agent/deep_search.py
@@ -5,6 +5,7 @@
 from deepsearcher.utils import log
 from deepsearcher.utils.message_stream import send_info, send_answer
 from deepsearcher.vector_db import RetrievalResult
 from deepsearcher.vector_db.base import BaseVectorDB, deduplicate
+from deepsearcher.web_search import WebSearch
 from collections import defaultdict
 from pathlib import Path
@@ -135,6 +136,7 @@ class DeepSearch(BaseAgent):
         max_iter: int,
         route_collection: bool = False,
         text_window_splitter: bool = True,
+        enable_web_search: bool = True,
         **kwargs,
     ):
         """
@@ -147,6 +149,7 @@
             max_iter: The maximum number of iterations for the search process.
             route_collection: Whether to use a collection router for search.
             text_window_splitter: Whether to use text_window splitter.
+            enable_web_search: Whether to enable web search functionality.
             **kwargs: Additional keyword arguments for customization.
         """
         self.llm = llm
@@ -155,6 +158,7 @@
         self.max_iter = max_iter
         self.route_collection = route_collection
         self.text_window_splitter = text_window_splitter
+        self.web_search = WebSearch() if enable_web_search else None
 
     def invoke(self, query: str, dim: int, **kwargs) -> list[str]:
         """
@@ -219,7 +223,10 @@
         content = self.llm.remove_think(content)
         return self.llm.literal_eval(content)
 
-    def _search_chunks_from_vectordb(self, query: str):
+    def _search_chunks(self, query: str) -> list[RetrievalResult]:
+        results = []
+
+        # Local vector search
         if self.route_collection:
             selected_collections = self.invoke(
                 query=query, dim=self.embedding_model.dimension
@@ -229,70 +236,78 @@
                 collection_info.collection_name
                 for collection_info in self.vector_db.list_collections(dim=self.embedding_model.dimension)
             ]
-
-        all_retrieved_results = []
         query_vector = self.embedding_model.embed_query(query)
+        vector_results = []
         for collection in selected_collections:
             send_info(f"正在 [{collection}] 中搜索 [{query}] ...")
-            retrieved_results = self.vector_db.search_data(
+            collection_results = self.vector_db.search_data(
                 collection=collection, vector=query_vector, query_text=query
             )
-            if not retrieved_results or len(retrieved_results) == 0:
+            if not collection_results:
                 send_info(f"'{collection}' 中没有找到相关文档!")
                 continue
+            send_info(f"本地向量搜索找到 {len(collection_results)} 个结果")
+            vector_results.extend(collection_results)
 
-            # Format all chunks for batch processing
-            chunks, _ = self._format_chunks(retrieved_results)
-
-            # Batch process all chunks with a single LLM call
-            content = self.llm.chat(
-                messages=[
-                    {
-                        "role": "user",
-                        "content": RERANK_PROMPT.format(
-                            query=query,
-                            chunks=chunks,
-                        ),
-                    }
-                ]
-            )
-            content = self.llm.remove_think(content).strip()
-
-            # Parse the response to determine which chunks are relevant
-            try:
-                relevance_list = self.llm.literal_eval(content)
-                if not isinstance(relevance_list, list):
-                    raise ValueError("Response is not a list")
-            except Exception as _:
-                # Fallback: if parsing fails, treat all chunks as relevant
-                log.color_print(f"Warning: Failed to parse relevance response. Treating all chunks as relevant. Response was: {content}")
-                relevance_list = ["True"] * len(retrieved_results)
-
-            # Ensure we have enough relevance judgments for all chunks
-            while len(relevance_list) < len(retrieved_results):
-                relevance_list.append("True")  # Default to relevant if no judgment provided
-
-            # Filter relevant chunks based on LLM response
-            accepted_chunk_num = 0
-            references = set()
-            for i, retrieved_result in enumerate(retrieved_results):
-                # Check if we have a relevance judgment for this chunk
-                is_relevant = (
-                    i < len(relevance_list) and
-                    "True" in relevance_list[i] and
-                    "False" not in relevance_list[i]
-                ) if i < len(relevance_list) else True
-
-                if is_relevant:
-                    all_retrieved_results.append(retrieved_result)
-                    accepted_chunk_num += 1
-                    references.add(retrieved_result.reference)
-
-            if accepted_chunk_num > 0:
-                send_info(f"采纳 {accepted_chunk_num} 个文档片段,来源:{list(references)}")
-            else:
-                send_info(f"没有采纳任何 '{collection}' 中找到的文档片段!")
-        return all_retrieved_results
+        # Web search
+        web_results = []
+        if self.web_search:
+            web_results = self.web_search.search_with_retry(query, size=2)
+            if web_results:
+                send_info(f"网页搜索找到 {len(web_results)} 个结果")
+            else:
+                send_info("网页搜索未找到相关结果")
+
+        retrieved_results = vector_results + web_results
+        # Format all chunks for batch processing
+        chunks, _ = self._format_chunks(retrieved_results)
+
+        # Batch process all chunks with a single LLM call
+        content = self.llm.chat(
+            messages=[
+                {
+                    "role": "user",
+                    "content": RERANK_PROMPT.format(
+                        query=query,
+                        chunks=chunks,
+                    ),
+                }
+            ]
+        )
+        content = self.llm.remove_think(content).strip()
+
+        # Parse the response to determine which chunks are relevant
+        try:
+            relevance_list = self.llm.literal_eval(content)
+            if not isinstance(relevance_list, list):
+                raise ValueError("Response is not a list")
+        except Exception as _:
+            # Fallback: if parsing fails, treat all chunks as relevant
+            log.color_print(f"Warning: Failed to parse relevance response. Treating all chunks as relevant. Response was: {content}")
+            relevance_list = ["True"] * len(retrieved_results)
+
+        # Ensure we have enough relevance judgments for all chunks
+        while len(relevance_list) < len(retrieved_results):
+            relevance_list.append("True")  # Default to relevant if no judgment provided
+
+        # Filter relevant chunks based on LLM response
+        accepted_chunk_num = 0
+        references = set()
+        for i, retrieved_result in enumerate(retrieved_results):
+            # Check if we have a relevance judgment for this chunk
+            is_relevant = (
+                i < len(relevance_list) and
+                "True" in relevance_list[i] and
+                "False" not in relevance_list[i]
+            ) if i < len(relevance_list) else True
+
+            if is_relevant:
+                results.append(retrieved_result)
+                accepted_chunk_num += 1
+                references.add(retrieved_result.reference)
+
+        if accepted_chunk_num > 0:
+            send_info(f"采纳 {accepted_chunk_num} 个文档片段,来源:{list(references)}")
+        else:
+            send_info("没有采纳任何找到的文档片段!")
+        return results
 
     def _generate_more_sub_queries(
         self, original_query: str, all_sub_queries: list[str], all_retrieved_results: list[RetrievalResult]
@@ -345,15 +360,16 @@
         # Execute all search tasks sequentially
         for query in sub_queries:
-            result = self._search_chunks_from_vectordb(query)
-            all_search_results.extend(result)
+            results = self._search_chunks(query)
+            all_search_results.extend(results)
+
+
+        # Deduplicate the merged results
         undeduped_len = len(all_search_results)
         all_search_results = deduplicate(all_search_results)
         deduped_len = len(all_search_results)
         if undeduped_len - deduped_len != 0:
             send_info(f"移除 {undeduped_len - deduped_len} 个重复文档片段")
-        # search_res_from_internet = deduplicate_results(search_res_from_internet)
-        # all_search_res.extend(search_res_from_vectordb + search_res_from_internet)
 
         ### REFLECTION & GET MORE SUB QUERIES ###
         # Only generate more queries if we haven't reached the maximum iterations
@@ -416,16 +432,23 @@
         formated_refs = ["\n\n"]
         chunk_count = 0
         for i, reference in enumerate(ref_dict):
+            # Check whether this reference comes from a web search result
+            is_web_result = any(
+                result.metadata and result.metadata.get("source") == "webpage"
+                for result in retrieved_results
+                if result.reference == reference
+            )
+
             formated_chunk = "".join(
                 [
                     (
-                        f"" +
-                        f"\n{chunk}\n" +
+                        f"\n" +
+                        f"\n{chunk}\n\n" +
                         f"\n"
                     ) if with_chunk_id else (
-                        f"" +
-                        f"\n{chunk}\n" +
+                        f"\n" +
+                        f"\n{chunk}\n\n" +
                         f"\n"
                     )
                     for j, chunk in enumerate(ref_dict[reference])
                 ]
@@ -434,7 +457,18 @@
             print(formated_chunk)
             formated_chunks.append(formated_chunk)
             chunk_count += len(ref_dict[reference])
-            formated_refs.append(f"[^{i + 1}]: " + Path(str(Path(reference).resolve())).as_uri() + "\n")
+
+            # Use a different reference format depending on the source type
+            if is_web_result:
+                # Web search results are cited by their URL directly
+                formated_refs.append(f"[^{i + 1}]: {reference}\n")
+            else:
+                # Local files are cited by file URI
+                try:
+                    formated_refs.append(f"[^{i + 1}]: " + Path(str(Path(reference).resolve())).as_uri() + "\n")
+                except Exception as _:
+                    formated_refs.append(f"[^{i + 1}]: {reference}\n")
+
         formated_chunks = "".join(formated_chunks)
         formated_refs = "".join(formated_refs)
         return formated_chunks, formated_refs
Response was: {content}") + relevance_list = ["True"] * len(retrieved_results) + + # Ensure we have enough relevance judgments for all chunks + while len(relevance_list) < len(retrieved_results): + relevance_list.append("True") # Default to relevant if no judgment provided + + # Filter relevant chunks based on LLM response + accepted_chunk_num = 0 + references = set() + for i, retrieved_result in enumerate(retrieved_results): + # Check if we have a relevance judgment for this chunk + is_relevant = ( + i < len(relevance_list) and + "True" in relevance_list[i] and + "False" not in relevance_list[i] + ) if i < len(relevance_list) else True + + if is_relevant: + results.append(retrieved_result) + accepted_chunk_num += 1 + references.add(retrieved_result.reference) + + if accepted_chunk_num > 0: + send_info(f"采纳 {accepted_chunk_num} 个文档片段,来源:{list(references)}") + else: + send_info("没有采纳任何找到的文档片段!") + return results def _generate_more_sub_queries( self, original_query: str, all_sub_queries: list[str], all_retrieved_results: list[RetrievalResult] @@ -345,15 +360,16 @@ class DeepSearch(BaseAgent): # Execute all search tasks sequentially for query in sub_queries: - result = self._search_chunks_from_vectordb(query) - all_search_results.extend(result) + results = self._search_chunks(query) + all_search_results.extend(results) + + + # 去重处理 undeduped_len = len(all_search_results) all_search_results = deduplicate(all_search_results) deduped_len = len(all_search_results) if undeduped_len - deduped_len != 0: send_info(f"移除 {undeduped_len - deduped_len} 个重复文档片段") - # search_res_from_internet = deduplicate_results(search_res_from_internet) - # all_search_res.extend(search_res_from_vectordb + search_res_from_internet) ### REFLECTION & GET MORE SUB QUERIES ### # Only generate more queries if we haven't reached the maximum iterations @@ -416,16 +432,23 @@ class DeepSearch(BaseAgent): formated_refs = ["\n\n"] chunk_count = 0 for i, reference in enumerate(ref_dict): + # 检查是否为网页搜索结果 + is_web_result = any( + result.metadata and result.metadata.get("source") == "webpage" + for result in retrieved_results + if result.reference == reference + ) + formated_chunk = "".join( [ ( - f"" + - f"\n{chunk}\n" + + f"\n" + + f"\n{chunk}\n\n" + f"\n" ) if with_chunk_id else ( - f"" + - f"\n{chunk}\n" + + f"\n" + + f"\n{chunk}\n\n" + f"\n" ) for j, chunk in enumerate(ref_dict[reference]) @@ -434,7 +457,18 @@ class DeepSearch(BaseAgent): print(formated_chunk) formated_chunks.append(formated_chunk) chunk_count += len(ref_dict[reference]) - formated_refs.append(f"[^{i + 1}]: " + Path(str(Path(reference).resolve())).as_uri() + "\n") + + # 根据来源类型生成不同的引用格式 + if is_web_result: + # 网页搜索结果直接使用URL + formated_refs.append(f"[^{i + 1}]: {reference}\n") + else: + # 本地文件使用文件URI + try: + formated_refs.append(f"[^{i + 1}]: " + Path(str(Path(reference).resolve())).as_uri() + "\n") + except Exception as _: + formated_refs.append(f"[^{i + 1}]: {reference}\n") + formated_chunks = "".join(formated_chunks) formated_refs = "".join(formated_refs) return formated_chunks, formated_refs diff --git a/deepsearcher/config.yaml b/deepsearcher/config.yaml index 5b59538..88af5bf 100644 --- a/deepsearcher/config.yaml +++ b/deepsearcher/config.yaml @@ -2,9 +2,12 @@ provide_settings: llm: provider: "OpenAILLM" config: - model: "Qwen/Qwen3-32B" - api_key: "sk-fpzwvagjkhwysjsozfybvtjzongatcwqdihdxzuijnfdrjzt" - base_url: "https://api.siliconflow.cn/v1" + # model: "Qwen/Qwen3-32B" + # api_key: "sk-fpzwvagjkhwysjsozfybvtjzongatcwqdihdxzuijnfdrjzt" + #base_url: 
"https://api.siliconflow.cn/v1" + model: qwen3-32b + api_key: sk-14f39f0c530d4aa0b5588454bff859d6 + base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 embedding: provider: "OpenAIEmbedding" @@ -81,6 +84,7 @@ provide_settings: query_settings: max_iter: 3 + enable_web_search: true load_settings: chunk_size: 2048 diff --git a/deepsearcher/configuration.py b/deepsearcher/configuration.py index 6a6f44c..56763a7 100644 --- a/deepsearcher/configuration.py +++ b/deepsearcher/configuration.py @@ -212,4 +212,5 @@ def init_config(config: Configuration): max_iter=config.query_settings["max_iter"], route_collection=False, text_window_splitter=True, + enable_web_search=config.query_settings.get("enable_web_search", True), ) diff --git a/deepsearcher/web_search.py b/deepsearcher/web_search.py new file mode 100644 index 0000000..11ce773 --- /dev/null +++ b/deepsearcher/web_search.py @@ -0,0 +1,113 @@ +import http.client +import json +import time +from deepsearcher.vector_db import RetrievalResult +from deepsearcher.utils import log + + +class WebSearch: + """网页搜索类,用于调用metaso.cn API进行网页搜索""" + def __init__(self, api_key: str = "mk-CCEA085159C048597435780530A55403"): + """ + 初始化网页搜索 + Args: + api_key (str): metaso.cn API密钥 + """ + self.api_key = api_key + self.base_url = "metaso.cn" + self.endpoint = "/api/v1/search" + + def search(self, query: str, size: int = 4) -> list[RetrievalResult]: + """ + 执行网页搜索 + Args: + query (str): 搜索查询 + size (int): 返回结果数量,默认为4 + Returns: + List[RetrievalResult]: 搜索结果列表 + """ + try: + # 构建请求数据 + payload = json.dumps({ + "q": query, + "scope": "webpage", + "includeSummary": False, + "size": str(size), + "includeRawContent": True, + "conciseSnippet": True + }) + + headers = { + 'Authorization': f'Bearer {self.api_key}', + 'Accept': 'application/json', + 'Content-Type': 'application/json' + } + + # 发送请求 + conn = http.client.HTTPSConnection(self.base_url) + conn.request("POST", self.endpoint, payload, headers) + res = conn.getresponse() + data = res.read() + + if res.status != 200: + log.error(f"网页搜索请求失败: {res.status} - {data.decode('utf-8')}") + return [] + + response_data = json.loads(data.decode("utf-8")) + + # 解析搜索结果 + results = [] + if "webpages" in response_data: + for i, webpage in enumerate(response_data["webpages"]): + # 使用content字段作为主要文本内容 + content = webpage.get("content", "") + if not content: + content = webpage.get("snippet", "") + + # 创建RetrievalResult对象 + result = RetrievalResult( + embedding=None, # 网页搜索结果没有向量 + text=content, + reference=webpage.get("link", ""), + score=1.0 - (i * (1 / size)), # 根据位置计算分数 + metadata={ + "title": webpage.get("title", ""), + "date": webpage.get("date", ""), + "authors": webpage.get("authors", []), + "position": webpage.get("position", i + 1), + "source": "webpage" + } + ) + results.append(result) + + log.info(f"网页搜索成功,找到 {len(results)} 个结果") + return results + + except Exception as e: + log.error(f"网页搜索出错: {str(e)}") + return [] + finally: + if 'conn' in locals(): + conn.close() + + def search_with_retry(self, query: str, size: int = 4, max_retries: int = 3) -> list[RetrievalResult]: + """ + 带重试机制的网页搜索 + Args: + query (str): 搜索查询 + size (int): 返回结果数量 + max_retries (int): 最大重试次数 + Returns: + List[RetrievalResult]: 搜索结果列表 + """ + for attempt in range(max_retries): + try: + results = self.search(query, size) + if results: + return results + except Exception as e: + log.warning(f"网页搜索第 {attempt + 1} 次尝试失败: {str(e)}") + if attempt < max_retries - 1: + time.sleep(1) # 等待1秒后重试 + log.error(f"网页搜索在 {max_retries} 次尝试后仍然失败") + return [] 
diff --git a/test_web_only.py b/test_web_only.py
new file mode 100644
index 0000000..5c6fc86
--- /dev/null
+++ b/test_web_only.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+"""
+Standalone test for the web search feature only.
+"""
+
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+from deepsearcher.web_search import WebSearch
+
+def test_web_search():
+    """Test the web search feature."""
+    print("=== 测试网页搜索功能 ===")
+
+    # Initialize the web search client
+    web_search = WebSearch()
+
+    # Test query
+    test_query = "Milvus是什么"
+    print(f"测试查询: {test_query}")
+
+    # Run the search
+    results = web_search.search_with_retry(test_query, size=4)
+
+    if results:
+        print(f"✅ 成功找到 {len(results)} 个搜索结果:")
+        for i, result in enumerate(results, 1):
+            print(f"\n--- 结果 {i} ---")
+            print(f"标题: {result.metadata.get('title', 'N/A')}")
+            print(f"链接: {result.reference}")
+            print(f"分数: {result.score}")
+            print(f"内容长度: {len(result.text)} 字符")
+            print(f"内容预览: {result.text[:200]}...")
+            print(f"来源: {result.metadata.get('source', 'N/A')}")
+    else:
+        print("❌ 未找到搜索结果")
+
+if __name__ == "__main__":
+    test_web_search()
diff --git a/test_web_search.py b/test_web_search.py
new file mode 100644
index 0000000..4387926
--- /dev/null
+++ b/test_web_search.py
@@ -0,0 +1,75 @@
+#!/usr/bin/env python3
+"""
+Tests for the web search feature and its DeepSearch integration.
+"""
+
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+from deepsearcher.web_search import WebSearch
+from deepsearcher import configuration
+
+def test_web_search():
+    """Test the web search feature on its own."""
+    print("=== 测试网页搜索功能 ===")
+
+    # Initialize the web search client
+    web_search = WebSearch()
+
+    # Test query
+    test_query = "Milvus是什么"
+    print(f"测试查询: {test_query}")
+
+    # Run the search
+    results = web_search.search_with_retry(test_query, size=4)
+
+    if results:
+        print(f"找到 {len(results)} 个搜索结果:")
+        for i, result in enumerate(results, 1):
+            print(f"\n--- 结果 {i} ---")
+            print(f"标题: {result.metadata.get('title', 'N/A')}")
+            print(f"链接: {result.reference}")
+            print(f"分数: {result.score}")
+            print(f"内容长度: {len(result.text)} 字符")
+            print(f"内容预览: {result.text[:200]}...")
+    else:
+        print("未找到搜索结果")
+
+def test_integration():
+    """Test the integration with DeepSearch."""
+    print("\n=== 测试与DeepSearch的集成 ===")
+
+    # Initialize the configuration
+    configuration.init_config(configuration.config)
+
+    # Create a DeepSearch instance with web search enabled
+    from deepsearcher.agent.deep_search import DeepSearch
+
+    searcher = DeepSearch(
+        llm=configuration.llm,
+        embedding_model=configuration.embedding_model,
+        vector_db=configuration.vector_db,
+        max_iter=2,
+        enable_web_search=True
+    )
+
+    # Test query
+    test_query = "Milvus是什么"
+    print(f"测试查询: {test_query}")
+
+    # Run retrieval
+    results, sub_queries = searcher.retrieve(test_query, max_iter=2)
+
+    print(f"生成的子问题: {sub_queries}")
+    print(f"找到 {len(results)} 个搜索结果")
+
+    # Show result statistics
+    web_results = [r for r in results if r.metadata and r.metadata.get("source") == "webpage"]
+    vector_results = [r for r in results if not r.metadata or r.metadata.get("source") != "webpage"]
+
+    print(f"网页搜索结果: {len(web_results)} 个")
+    print(f"向量数据库结果: {len(vector_results)} 个")
+
+if __name__ == "__main__":
+    test_web_search()
+    test_integration()
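Both test scripts hit the live metaso.cn API. A hedged sketch of exercising the web-search path offline by stubbing WebSearch with unittest.mock; the canned RetrievalResult is illustrative and reuses the same keyword arguments web_search.py passes to the constructor.

from unittest.mock import patch

from deepsearcher.vector_db import RetrievalResult
from deepsearcher.web_search import WebSearch

fake_result = RetrievalResult(
    embedding=None,
    text="Milvus is an open-source vector database.",
    reference="https://example.com/milvus",
    score=1.0,
    metadata={"title": "Milvus", "source": "webpage"},
)

with patch.object(WebSearch, "search_with_retry", return_value=[fake_result]):
    results = WebSearch().search_with_retry("Milvus是什么", size=4)
    print(len(results), results[0].reference)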