pd 1 day ago
parent
commit
0164b14d7e
  1. 7
      deepsearcher/agent/deep_search.py
  2. 3
      deepsearcher/online_query.py
  3. 12
      deepsearcher/templates/html/index.html
  4. 18
      deepsearcher/templates/static/css/styles.css
  5. 3
      deepsearcher/templates/static/js/app.js
  6. 8
      main.py

7
deepsearcher/agent/deep_search.py

@ -158,7 +158,7 @@ class DeepSearch(BaseAgent):
self.max_iter = max_iter self.max_iter = max_iter
self.route_collection = route_collection self.route_collection = route_collection
self.text_window_splitter = text_window_splitter self.text_window_splitter = text_window_splitter
self.web_search = WebSearch() if web_search else None self.web_search = False
def invoke(self, query: str, dim: int, **kwargs) -> list[str]: def invoke(self, query: str, dim: int, **kwargs) -> list[str]:
""" """
@ -223,7 +223,7 @@ class DeepSearch(BaseAgent):
content = self.llm.remove_think(content) content = self.llm.remove_think(content)
return self.llm.literal_eval(content) return self.llm.literal_eval(content)
def _search_chunks(self, query: str) -> list[RetrievalResult]: def _search_chunks(self, query: str, **kwargs) -> list[RetrievalResult]:
results = [] results = []
# 本地向量搜索 # 本地向量搜索
@ -248,6 +248,7 @@ class DeepSearch(BaseAgent):
send_info(f"本地向量搜索找到 {len(vector_results)} 个结果") send_info(f"本地向量搜索找到 {len(vector_results)} 个结果")
# 网页搜索 # 网页搜索
self.web_search = WebSearch() if kwargs.get('web_search', False) else None
if self.web_search: if self.web_search:
web_results = self.web_search.search_with_retry(query, size=2) web_results = self.web_search.search_with_retry(query, size=2)
if web_results: if web_results:
@ -362,7 +363,7 @@ class DeepSearch(BaseAgent):
# Execute all search tasks sequentially # Execute all search tasks sequentially
for query in sub_queries: for query in sub_queries:
results = self._search_chunks(query) results = self._search_chunks(query, **kwargs)
all_search_results.extend(results) all_search_results.extend(results)

3
deepsearcher/online_query.py

@ -21,7 +21,8 @@ def query(original_query: str, **kwargs) -> tuple[str, list[RetrievalResult]]:
""" """
default_searcher = configuration.default_searcher default_searcher = configuration.default_searcher
max_iter = kwargs.get("max_iter", 3) max_iter = kwargs.get("max_iter", 3)
return default_searcher.query(original_query, max_iter=max_iter) web_search = kwargs.get("web_search", False)
return default_searcher.query(original_query, max_iter=max_iter, web_search=web_search)
def retrieve(original_query: str, max_iter: int | None = None) -> tuple[list[RetrievalResult], list[str]]: def retrieve(original_query: str, max_iter: int | None = None) -> tuple[list[RetrievalResult], list[str]]:

12
deepsearcher/templates/html/index.html

@ -86,6 +86,18 @@
value="3" value="3"
/> />
</div> </div>
<div class="form-group">
<label class="checkbox-label">
<input
type="checkbox"
id="webSearch"
/>
启用网络搜索
</label>
<small style="color: var(--text-secondary); display: block; margin-top: 4px;">
启用后将同时搜索本地文档和网络内容,提供更全面的答案
</small>
</div>
<button id="queryBtn">执行查询</button> <button id="queryBtn">执行查询</button>
<button <button
id="clearMessagesBtn" id="clearMessagesBtn"

18
deepsearcher/templates/static/css/styles.css

@ -84,6 +84,24 @@ label {
font-weight: 500; font-weight: 500;
} }
.checkbox-label {
display: flex !important;
align-items: center;
cursor: pointer;
font-weight: 500;
margin-bottom: 8px;
flex-direction: row;
gap: 8px;
}
.checkbox-label input[type="checkbox"] {
width: 16px;
height: 16px;
margin: 0;
cursor: pointer;
flex-shrink: 0;
}
input, input,
textarea, textarea,
select { select {

3
deepsearcher/templates/static/js/app.js

@ -366,6 +366,7 @@ document
const button = this; const button = this;
const queryText = document.getElementById('queryText').value; const queryText = document.getElementById('queryText').value;
const maxIter = parseInt(document.getElementById('maxIter').value); const maxIter = parseInt(document.getElementById('maxIter').value);
const webSearch = document.getElementById('webSearch').checked;
if (!queryText) { if (!queryText) {
showStatus('queryStatus', '请输入查询问题', 'error'); showStatus('queryStatus', '请输入查询问题', 'error');
@ -407,7 +408,7 @@ document
const eventSource = new EventSource( const eventSource = new EventSource(
`/query-stream/?original_query=${encodeURIComponent( `/query-stream/?original_query=${encodeURIComponent(
queryText queryText
)}&max_iter=${maxIter}` )}&max_iter=${maxIter}&web_search=${webSearch}`
); );
// 保存EventSource引用以便后续关闭 // 保存EventSource引用以便后续关闭

8
main.py

@ -159,6 +159,11 @@ async def perform_query_stream(
ge=1, ge=1,
examples=[3], examples=[3],
), ),
web_search: bool = Query(
False,
description="Whether to enable web search functionality.",
examples=[True, False],
),
) -> StreamingResponse: ) -> StreamingResponse:
""" """
Perform a query with real-time streaming of progress and results. Perform a query with real-time streaming of progress and results.
@ -166,6 +171,7 @@ async def perform_query_stream(
Args: Args:
original_query (str): The user's question or query. original_query (str): The user's question or query.
max_iter (int, optional): Maximum number of iterations for reflection. Defaults to 3. max_iter (int, optional): Maximum number of iterations for reflection. Defaults to 3.
web_search (bool, optional): Whether to enable web search functionality. Defaults to False.
Returns: Returns:
StreamingResponse: A streaming response with real-time query progress and results. StreamingResponse: A streaming response with real-time query progress and results.
@ -199,7 +205,7 @@ async def perform_query_stream(
def run_query(): def run_query():
try: try:
print(f"Starting query: {original_query} with max_iter: {max_iter}") print(f"Starting query: {original_query} with max_iter: {max_iter}")
result_text, retrieval_results = query(original_query, max_iter=max_iter) result_text, retrieval_results = query(original_query, max_iter=max_iter, web_search=web_search)
print(f"Query completed with result length: {len(result_text) if result_text else 0}") print(f"Query completed with result length: {len(result_text) if result_text else 0}")
print(f"Retrieved {len(retrieval_results) if retrieval_results else 0} documents") print(f"Retrieved {len(retrieval_results) if retrieval_results else 0} documents")
return result_text, None return result_text, None

Loading…
Cancel
Save