
feat: change the web search logic

Branch: main · tanxing, 5 days ago · commit d58749f825
  1. deepsearcher/agent/deep_search.py (164 changed lines)
  2. deepsearcher/config.yaml (10 changed lines)
  3. deepsearcher/configuration.py (1 changed line)
  4. deepsearcher/web_search.py (113 added lines, new file)
  5. test_web_only.py (40 added lines, new file)
  6. test_web_search.py (75 added lines, new file)

deepsearcher/agent/deep_search.py (164 changed lines)

@@ -5,6 +5,7 @@ from deepsearcher.utils import log
 from deepsearcher.utils.message_stream import send_info, send_answer
 from deepsearcher.vector_db import RetrievalResult
 from deepsearcher.vector_db.base import BaseVectorDB, deduplicate
+from deepsearcher.web_search import WebSearch
 from collections import defaultdict
 from pathlib import Path
@@ -135,6 +136,7 @@ class DeepSearch(BaseAgent):
         max_iter: int,
         route_collection: bool = False,
         text_window_splitter: bool = True,
+        enable_web_search: bool = True,
         **kwargs,
     ):
         """
@@ -147,6 +149,7 @@ class DeepSearch(BaseAgent):
             max_iter: The maximum number of iterations for the search process.
             route_collection: Whether to use a collection router for search.
             text_window_splitter: Whether to use text_window splitter.
+            enable_web_search: Whether to enable web search functionality.
             **kwargs: Additional keyword arguments for customization.
         """
         self.llm = llm
@@ -155,6 +158,7 @@ class DeepSearch(BaseAgent):
         self.max_iter = max_iter
         self.route_collection = route_collection
         self.text_window_splitter = text_window_splitter
+        self.web_search = WebSearch() if enable_web_search else None

     def invoke(self, query: str, dim: int, **kwargs) -> list[str]:
         """
@@ -219,7 +223,10 @@
         content = self.llm.remove_think(content)
         return self.llm.literal_eval(content)

-    def _search_chunks_from_vectordb(self, query: str):
+    def _search_chunks(self, query: str) -> list[RetrievalResult]:
+        results = []
+        # Local vector search
         if self.route_collection:
             selected_collections = self.invoke(
                 query=query, dim=self.embedding_model.dimension
@@ -229,70 +236,78 @@ class DeepSearch(BaseAgent):
                 collection_info.collection_name
                 for collection_info in self.vector_db.list_collections(dim=self.embedding_model.dimension)
             ]
-        all_retrieved_results = []
         query_vector = self.embedding_model.embed_query(query)
         for collection in selected_collections:
             send_info(f"正在 [{collection}] 中搜索 [{query}] ...")
-            retrieved_results = self.vector_db.search_data(
+            vector_results = self.vector_db.search_data(
                 collection=collection, vector=query_vector, query_text=query
             )
-            if not retrieved_results or len(retrieved_results) == 0:
+            if not vector_results or len(vector_results) == 0:
                 send_info(f"'{collection}' 中没有找到相关文档!")
                 continue
-            # Format all chunks for batch processing
-            chunks, _ = self._format_chunks(retrieved_results)
-
-            # Batch process all chunks with a single LLM call
-            content = self.llm.chat(
-                messages=[
-                    {
-                        "role": "user",
-                        "content": RERANK_PROMPT.format(
-                            query=query,
-                            chunks=chunks,
-                        ),
-                    }
-                ]
-            )
-            content = self.llm.remove_think(content).strip()
-
-            # Parse the response to determine which chunks are relevant
-            try:
-                relevance_list = self.llm.literal_eval(content)
-                if not isinstance(relevance_list, list):
-                    raise ValueError("Response is not a list")
-            except Exception as _:
-                # Fallback: if parsing fails, treat all chunks as relevant
-                log.color_print(f"Warning: Failed to parse relevance response. Treating all chunks as relevant. Response was: {content}")
-                relevance_list = ["True"] * len(retrieved_results)
-
-            # Ensure we have enough relevance judgments for all chunks
-            while len(relevance_list) < len(retrieved_results):
-                relevance_list.append("True")  # Default to relevant if no judgment provided
-
-            # Filter relevant chunks based on LLM response
-            accepted_chunk_num = 0
-            references = set()
-            for i, retrieved_result in enumerate(retrieved_results):
-                # Check if we have a relevance judgment for this chunk
-                is_relevant = (
-                    i < len(relevance_list) and
-                    "True" in relevance_list[i] and
-                    "False" not in relevance_list[i]
-                ) if i < len(relevance_list) else True
-                if is_relevant:
-                    all_retrieved_results.append(retrieved_result)
-                    accepted_chunk_num += 1
-                    references.add(retrieved_result.reference)
-            if accepted_chunk_num > 0:
-                send_info(f"采纳 {accepted_chunk_num} 个文档片段,来源:{list(references)}")
-            else:
-                send_info(f"没有采纳任何 '{collection}' 中找到的文档片段!")
-        return all_retrieved_results
+            send_info(f"本地向量搜索找到 {len(vector_results)} 个结果")
+
+            # Web search (web_results stays an empty list when web search is
+            # disabled, so the merge below is safe either way)
+            web_results = []
+            if self.web_search:
+                web_results = self.web_search.search_with_retry(query, size=2)
+                if web_results:
+                    send_info(f"网页搜索找到 {len(web_results)} 个结果")
+                else:
+                    send_info("网页搜索未找到相关结果")
+
+            # Merge local vector results and web results
+            retrieved_results = vector_results + web_results
+
+            # Format all chunks for batch processing
+            chunks, _ = self._format_chunks(retrieved_results)
+
+            # Batch process all chunks with a single LLM call
+            content = self.llm.chat(
+                messages=[
+                    {
+                        "role": "user",
+                        "content": RERANK_PROMPT.format(
+                            query=query,
+                            chunks=chunks,
+                        ),
+                    }
+                ]
+            )
+            content = self.llm.remove_think(content).strip()
+
+            # Parse the response to determine which chunks are relevant
+            try:
+                relevance_list = self.llm.literal_eval(content)
+                if not isinstance(relevance_list, list):
+                    raise ValueError("Response is not a list")
+            except Exception as _:
+                # Fallback: if parsing fails, treat all chunks as relevant
+                log.color_print(f"Warning: Failed to parse relevance response. Treating all chunks as relevant. Response was: {content}")
+                relevance_list = ["True"] * len(retrieved_results)
+
+            # Ensure we have enough relevance judgments for all chunks
+            while len(relevance_list) < len(retrieved_results):
+                relevance_list.append("True")  # Default to relevant if no judgment provided
+
+            # Filter relevant chunks based on LLM response
+            accepted_chunk_num = 0
+            references = set()
+            for i, retrieved_result in enumerate(retrieved_results):
+                # Check if we have a relevance judgment for this chunk
+                is_relevant = (
+                    i < len(relevance_list) and
+                    "True" in relevance_list[i] and
+                    "False" not in relevance_list[i]
+                ) if i < len(relevance_list) else True
+                if is_relevant:
+                    results.append(retrieved_result)
+                    accepted_chunk_num += 1
+                    references.add(retrieved_result.reference)
+            if accepted_chunk_num > 0:
+                send_info(f"采纳 {accepted_chunk_num} 个文档片段,来源:{list(references)}")
+            else:
+                send_info("没有采纳任何找到的文档片段!")
+        return results

     def _generate_more_sub_queries(
         self, original_query: str, all_sub_queries: list[str], all_retrieved_results: list[RetrievalResult]
@@ -345,15 +360,16 @@ class DeepSearch(BaseAgent):
            # Execute all search tasks sequentially
            for query in sub_queries:
-                result = self._search_chunks_from_vectordb(query)
-                all_search_results.extend(result)
+                results = self._search_chunks(query)
+                all_search_results.extend(results)

+            # Deduplicate
            undeduped_len = len(all_search_results)
            all_search_results = deduplicate(all_search_results)
            deduped_len = len(all_search_results)
            if undeduped_len - deduped_len != 0:
                send_info(f"移除 {undeduped_len - deduped_len} 个重复文档片段")
-            # search_res_from_internet = deduplicate_results(search_res_from_internet)
-            # all_search_res.extend(search_res_from_vectordb + search_res_from_internet)

            ### REFLECTION & GET MORE SUB QUERIES ###
            # Only generate more queries if we haven't reached the maximum iterations
@@ -416,16 +432,23 @@ class DeepSearch(BaseAgent):
         formated_refs = ["\n\n"]
         chunk_count = 0
         for i, reference in enumerate(ref_dict):
+            # Check whether this reference comes from a web search result
+            is_web_result = any(
+                result.metadata and result.metadata.get("source") == "webpage"
+                for result in retrieved_results
+                if result.reference == reference
+            )
+
             formated_chunk = "".join(
                 [
                     (
-                        f"<reference id='{i + 1}' href='{reference}'>" +
-                        f"<chunk id='{j + 1 + chunk_count}'>\n{chunk}\n</chunk id='{j + 1 + chunk_count}'>" +
+                        f"<reference id='{i + 1}' href='{reference}'>\n" +
+                        f"<chunk id='{j + 1 + chunk_count}'>\n{chunk}\n</chunk id='{j + 1 + chunk_count}'>\n" +
                         f"</reference id='{i + 1}'>\n"
                     )
                     if with_chunk_id else (
-                        f"<reference id='{i + 1}' href='{reference}'>" +
-                        f"<chunk>\n{chunk}\n</chunk>" +
+                        f"<reference id='{i + 1}' href='{reference}'>\n" +
+                        f"<chunk>\n{chunk}\n</chunk>\n" +
                         f"</reference id='{i + 1}'>\n"
                     )
                     for j, chunk in enumerate(ref_dict[reference])
@@ -434,7 +457,18 @@ class DeepSearch(BaseAgent):
             print(formated_chunk)
             formated_chunks.append(formated_chunk)
             chunk_count += len(ref_dict[reference])
-            formated_refs.append(f"[^{i + 1}]: " + Path(str(Path(reference).resolve())).as_uri() + "\n")
+
+            # Build the reference entry according to its source type
+            if is_web_result:
+                # Web search results are cited with their URL as-is
+                formated_refs.append(f"[^{i + 1}]: {reference}\n")
+            else:
+                # Local files are cited as file URIs
+                try:
+                    formated_refs.append(f"[^{i + 1}]: " + Path(str(Path(reference).resolve())).as_uri() + "\n")
+                except Exception as _:
+                    formated_refs.append(f"[^{i + 1}]: {reference}\n")
         formated_chunks = "".join(formated_chunks)
         formated_refs = "".join(formated_refs)
         return formated_chunks, formated_refs
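
To illustrate the citation formats introduced above: web search results are cited with their URL as-is, while local references are converted to file:// URIs via pathlib, falling back to the raw string when conversion fails. A minimal standalone sketch (format_ref is a hypothetical helper for illustration, not code from this commit):

from pathlib import Path

def format_ref(reference: str, is_web_result: bool) -> str:
    # Hypothetical illustration of the branching above, not part of the commit.
    if is_web_result:
        return reference  # web result: keep the URL unchanged
    try:
        # local file: resolve to an absolute path and emit a file:// URI
        return Path(str(Path(reference).resolve())).as_uri()
    except Exception:
        return reference  # fall back to the raw reference string

print(format_ref("https://milvus.io/docs/overview.md", is_web_result=True))
print(format_ref("docs/notes.md", is_web_result=False))  # file:///.../docs/notes.md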

deepsearcher/config.yaml (10 changed lines)

@@ -2,9 +2,12 @@ provide_settings:
   llm:
     provider: "OpenAILLM"
     config:
-      model: "Qwen/Qwen3-32B"
-      api_key: "sk-fpzwvagjkhwysjsozfybvtjzongatcwqdihdxzuijnfdrjzt"
-      base_url: "https://api.siliconflow.cn/v1"
+      # model: "Qwen/Qwen3-32B"
+      # api_key: "sk-fpzwvagjkhwysjsozfybvtjzongatcwqdihdxzuijnfdrjzt"
+      # base_url: "https://api.siliconflow.cn/v1"
+      model: qwen3-32b
+      api_key: sk-14f39f0c530d4aa0b5588454bff859d6
+      base_url: https://dashscope.aliyuncs.com/compatible-mode/v1

   embedding:
     provider: "OpenAIEmbedding"
@@ -81,6 +84,7 @@ provide_settings:

 query_settings:
   max_iter: 3
+  enable_web_search: true

 load_settings:
   chunk_size: 2048
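
For orientation, the new flag sits under query_settings in config.yaml. A minimal sketch of reading it with PyYAML (assumption: PyYAML is installed; the project's own Configuration class performs the actual loading):

import yaml

with open("deepsearcher/config.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# True with the setting added above; init_config falls back to True if the key is missing.
print(cfg["query_settings"]["enable_web_search"])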

deepsearcher/configuration.py (1 changed line)

@@ -212,4 +212,5 @@ def init_config(config: Configuration):
         max_iter=config.query_settings["max_iter"],
         route_collection=False,
         text_window_splitter=True,
+        enable_web_search=config.query_settings.get("enable_web_search", True),
     )
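
The flag then flows from query_settings into the DeepSearch constructor, which either builds a WebSearch client or stores None. A minimal sketch of that intended behaviour (build_web_search is a hypothetical helper mirroring the two lines above, not part of the commit):

from typing import Optional

from deepsearcher.web_search import WebSearch

def build_web_search(query_settings: dict) -> Optional[WebSearch]:
    # Mirrors init_config's .get(..., True) default and DeepSearch.__init__'s conditional.
    enable_web_search = query_settings.get("enable_web_search", True)
    return WebSearch() if enable_web_search else None

print(build_web_search({"enable_web_search": False}))   # None -> web search disabled
print(type(build_web_search({})).__name__)              # WebSearch -> enabled by default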

deepsearcher/web_search.py (new file, 113 lines)

@@ -0,0 +1,113 @@
import http.client
import json
import time

from deepsearcher.vector_db import RetrievalResult
from deepsearcher.utils import log


class WebSearch:
    """Web search class that calls the metaso.cn API to search the web."""

    def __init__(self, api_key: str = "mk-CCEA085159C048597435780530A55403"):
        """
        Initialize the web search client.

        Args:
            api_key (str): metaso.cn API key.
        """
        self.api_key = api_key
        self.base_url = "metaso.cn"
        self.endpoint = "/api/v1/search"

    def search(self, query: str, size: int = 4) -> list[RetrievalResult]:
        """
        Run a web search.

        Args:
            query (str): Search query.
            size (int): Number of results to return, 4 by default.

        Returns:
            List[RetrievalResult]: List of search results.
        """
        try:
            # Build the request payload
            payload = json.dumps({
                "q": query,
                "scope": "webpage",
                "includeSummary": False,
                "size": str(size),
                "includeRawContent": True,
                "conciseSnippet": True
            })
            headers = {
                'Authorization': f'Bearer {self.api_key}',
                'Accept': 'application/json',
                'Content-Type': 'application/json'
            }

            # Send the request
            conn = http.client.HTTPSConnection(self.base_url)
            conn.request("POST", self.endpoint, payload, headers)
            res = conn.getresponse()
            data = res.read()

            if res.status != 200:
                log.error(f"网页搜索请求失败: {res.status} - {data.decode('utf-8')}")
                return []

            response_data = json.loads(data.decode("utf-8"))

            # Parse the search results
            results = []
            if "webpages" in response_data:
                for i, webpage in enumerate(response_data["webpages"]):
                    # Prefer the content field as the main text, fall back to the snippet
                    content = webpage.get("content", "")
                    if not content:
                        content = webpage.get("snippet", "")

                    # Build a RetrievalResult object
                    result = RetrievalResult(
                        embedding=None,  # web search results have no embedding vector
                        text=content,
                        reference=webpage.get("link", ""),
                        score=1.0 - (i * (1 / size)),  # score decays with result rank
                        metadata={
                            "title": webpage.get("title", ""),
                            "date": webpage.get("date", ""),
                            "authors": webpage.get("authors", []),
                            "position": webpage.get("position", i + 1),
                            "source": "webpage"
                        }
                    )
                    results.append(result)

            log.info(f"网页搜索成功,找到 {len(results)} 个结果")
            return results

        except Exception as e:
            log.error(f"网页搜索出错: {str(e)}")
            return []
        finally:
            if 'conn' in locals():
                conn.close()

    def search_with_retry(self, query: str, size: int = 4, max_retries: int = 3) -> list[RetrievalResult]:
        """
        Web search with a retry mechanism.

        Args:
            query (str): Search query.
            size (int): Number of results to return.
            max_retries (int): Maximum number of attempts.

        Returns:
            List[RetrievalResult]: List of search results.
        """
        for attempt in range(max_retries):
            try:
                results = self.search(query, size)
                if results:
                    return results
            except Exception as e:
                log.warning(f"网页搜索第 {attempt + 1} 次尝试失败: {str(e)}")

            if attempt < max_retries - 1:
                time.sleep(1)  # wait one second before retrying

        log.error(f"网页搜索在 {max_retries} 次尝试后仍然失败")
        return []
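
For reference, a sketch of the response shape that search() assumes from the metaso.cn endpoint and how one entry maps onto RetrievalResult. The sample body below is invented for illustration; the real API schema may differ:

from deepsearcher.vector_db import RetrievalResult

# Invented sample body shaped like what search() expects (a top-level "webpages" list).
sample_response = {
    "webpages": [
        {
            "title": "What is Milvus?",
            "link": "https://milvus.io/docs/overview.md",
            "content": "Milvus is an open-source vector database built for ...",
            "snippet": "Milvus is an open-source vector database",
            "date": "2024-01-01",
            "authors": [],
            "position": 1,
        }
    ]
}

size = 4
webpage = sample_response["webpages"][0]
result = RetrievalResult(
    embedding=None,                                         # web results carry no vector
    text=webpage.get("content") or webpage.get("snippet", ""),
    reference=webpage.get("link", ""),
    score=1.0 - (0 * (1 / size)),                           # first result -> score 1.0
    metadata={"title": webpage["title"], "source": "webpage"},
)
print(result.reference, result.score)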

test_web_only.py (new file, 40 lines)

@@ -0,0 +1,40 @@
#!/usr/bin/env python3
"""
Standalone test for the web search feature only.
"""
import sys
import os

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from deepsearcher.web_search import WebSearch


def test_web_search():
    """Test the web search feature."""
    print("=== 测试网页搜索功能 ===")

    # Initialize the web search client
    web_search = WebSearch()

    # Test query
    test_query = "Milvus是什么"
    print(f"测试查询: {test_query}")

    # Run the search
    results = web_search.search_with_retry(test_query, size=4)

    if results:
        print(f"✅ 成功找到 {len(results)} 个搜索结果:")
        for i, result in enumerate(results, 1):
            print(f"\n--- 结果 {i} ---")
            print(f"标题: {result.metadata.get('title', 'N/A')}")
            print(f"链接: {result.reference}")
            print(f"分数: {result.score}")
            print(f"内容长度: {len(result.text)} 字符")
            print(f"内容预览: {result.text[:200]}...")
            print(f"来源: {result.metadata.get('source', 'N/A')}")
    else:
        print("❌ 未找到搜索结果")


if __name__ == "__main__":
    test_web_search()

test_web_search.py (new file, 75 lines)

@@ -0,0 +1,75 @@
#!/usr/bin/env python3
"""
Test the web search feature.
"""
import sys
import os

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from deepsearcher.web_search import WebSearch
from deepsearcher import configuration


def test_web_search():
    """Test the web search feature."""
    print("=== 测试网页搜索功能 ===")

    # Initialize the web search client
    web_search = WebSearch()

    # Test query
    test_query = "Milvus是什么"
    print(f"测试查询: {test_query}")

    # Run the search
    results = web_search.search_with_retry(test_query, size=4)

    if results:
        print(f"找到 {len(results)} 个搜索结果:")
        for i, result in enumerate(results, 1):
            print(f"\n--- 结果 {i} ---")
            print(f"标题: {result.metadata.get('title', 'N/A')}")
            print(f"链接: {result.reference}")
            print(f"分数: {result.score}")
            print(f"内容长度: {len(result.text)} 字符")
            print(f"内容预览: {result.text[:200]}...")
    else:
        print("未找到搜索结果")


def test_integration():
    """Test the integration with DeepSearch."""
    print("\n=== 测试与DeepSearch的集成 ===")

    # Initialize the configuration
    configuration.init_config(configuration.config)

    # Create a DeepSearch instance with web search enabled
    from deepsearcher.agent.deep_search import DeepSearch
    searcher = DeepSearch(
        llm=configuration.llm,
        embedding_model=configuration.embedding_model,
        vector_db=configuration.vector_db,
        max_iter=2,
        enable_web_search=True
    )

    # Test query
    test_query = "Milvus是什么"
    print(f"测试查询: {test_query}")

    # Run the retrieval
    results, sub_queries = searcher.retrieve(test_query, max_iter=2)

    print(f"生成的子问题: {sub_queries}")
    print(f"找到 {len(results)} 个搜索结果")

    # Result statistics by source
    web_results = [r for r in results if r.metadata and r.metadata.get("source") == "webpage"]
    vector_results = [r for r in results if not r.metadata or r.metadata.get("source") != "webpage"]
    print(f"网页搜索结果: {len(web_results)}")
    print(f"向量数据库结果: {len(vector_results)}")


if __name__ == "__main__":
    test_web_search()
    test_integration()