Compare commits

...

4 Commits

  1. 11
      deepsearcher/agent/deep_search.py
  2. 246
      deepsearcher/backend/templates/index.html
  3. 2
      deepsearcher/config.yaml
  4. 15
      deepsearcher/offline_loading.py
  5. 10
      deepsearcher/utils/message_stream.py
  6. 3
      deepsearcher/vector_db/milvus.py
  7. 210
      main.py
  8. 31
      test_iteration.py

11
deepsearcher/agent/deep_search.py

@ -323,6 +323,9 @@ class DeepSearch(BaseAgent):
- A list of retrieved document results
- Additional information about the retrieval process
"""
# Get max_iter from kwargs or use default
max_iter = kwargs.get('max_iter', self.max_iter)
### SUB QUERIES ###
send_think(f"<query> {original_query} </query>")
all_search_results = []
@ -336,7 +339,7 @@ class DeepSearch(BaseAgent):
send_think(f"Break down the original query into new sub queries: {sub_queries}")
all_sub_queries.extend(sub_queries)
for it in range(self.max_iter):
for it in range(max_iter):
send_think(f">> Iteration: {it + 1}")
# Execute all search tasks sequentially
@ -350,10 +353,10 @@ class DeepSearch(BaseAgent):
send_search(f"Remove {undeduped_len - deduped_len} duplicates")
# search_res_from_internet = deduplicate_results(search_res_from_internet)
# all_search_res.extend(search_res_from_vectordb + search_res_from_internet)
### REFLECTION & GET MORE SUB QUERIES ###
# Only generate more queries if we haven't reached the maximum iterations
if it + 1 < self.max_iter:
if it + 1 < max_iter:
send_think("Reflecting on the search results...")
sub_queries = self._generate_more_sub_queries(
original_query, all_sub_queries, all_search_results
@ -400,7 +403,7 @@ class DeepSearch(BaseAgent):
)
response = self.llm.chat([{"role": "user", "content": summary_prompt}])
final_answer = self.llm.remove_think(response)
send_answer("==== FINAL ANSWER====")
# 直接发送最终答案,不发送占位符
send_answer(final_answer)
return self.llm.remove_think(response), all_retrieved_results

246
deepsearcher/backend/templates/index.html

@ -330,6 +330,10 @@
</div>
<script>
// 全局变量
let eventSource = null;
let isStreaming = false;
// 工具函数:显示状态信息
function showStatus(elementId, message, type) {
const statusElement = document.getElementById(elementId);
@ -355,26 +359,50 @@
container.innerHTML = '';
messages.forEach(message => {
const messageElement = document.createElement('div');
messageElement.className = `message message-${message.type}`;
const contentElement = document.createElement('div');
contentElement.textContent = message.content;
const timestampElement = document.createElement('div');
timestampElement.className = 'message-timestamp';
const date = new Date(message.timestamp * 1000);
timestampElement.textContent = date.toLocaleTimeString();
messageElement.appendChild(contentElement);
messageElement.appendChild(timestampElement);
container.appendChild(messageElement);
addMessageToContainer(message);
});
// 滚动到底部
container.scrollTop = container.scrollHeight;
}
// Helper: append one message entry to the message feed container
function addMessageToContainer(message) {
    console.log('Adding message to container:', message);
    const container = document.getElementById('messageContainer');
    if (!container) {
        console.error('Message container not found!');
        return;
    }

    // Build the entry: a content node followed by a formatted timestamp node.
    const messageElement = document.createElement('div');
    messageElement.className = `message message-${message.type}`;

    const contentElement = document.createElement('div');
    contentElement.textContent = message.content;

    const timestampElement = document.createElement('div');
    timestampElement.className = 'message-timestamp';
    timestampElement.textContent = new Date(message.timestamp * 1000).toLocaleTimeString();

    messageElement.append(contentElement, timestampElement);
    container.appendChild(messageElement);

    // Make sure the surrounding result panel is visible.
    const resultContainer = document.getElementById('queryResult');
    if (resultContainer && !resultContainer.classList.contains('visible')) {
        resultContainer.classList.add('visible');
    }

    // Keep the newest message in view.
    container.scrollTop = container.scrollHeight;
    console.log('Message added successfully, container now has', container.children.length, 'messages');
}
// 工具函数:隐藏状态信息
function hideStatus(elementId) {
const statusElement = document.getElementById(elementId);
@ -418,6 +446,119 @@
}
}
// Helper: close any open EventSource connections and clear the streaming flag
function closeEventSource() {
    const shutDown = (src, note) => {
        console.log(note);
        src.close();
    };
    if (eventSource) {
        shutDown(eventSource, 'Closing eventSource in closeEventSource function');
        eventSource = null;
    }
    if (window.currentEventSource) {
        shutDown(window.currentEventSource, 'Closing currentEventSource in closeEventSource function');
        window.currentEventSource = null;
    }
    isStreaming = false;
}
// Helper: dispatch one incoming SSE payload to the appropriate UI handler
function handleStreamMessage(data) {
    // Shared teardown for terminal events (complete / query_error / stream_error):
    // close the per-query EventSource, clear the streaming flag, re-enable the
    // query button. Extracted to remove the triplicated cleanup in the switch arms.
    function finishQuery(closeLogMessage) {
        if (window.currentEventSource) {
            if (closeLogMessage) {
                console.log(closeLogMessage);
            }
            window.currentEventSource.close();
            window.currentEventSource = null;
        }
        isStreaming = false;
        setButtonLoading(document.getElementById('queryBtn'), false);
    }

    try {
        const message = JSON.parse(data);
        switch (message.type) {
            case 'connection':
                console.log('Connected to message stream:', message.message);
                break;
            case 'heartbeat':
                // Keep-alive only; nothing to do.
                break;
            case 'query_start':
                console.log('Query started:', message.query);
                showStatus('queryStatus', '查询已开始,正在处理...', 'loading');
                break;
            case 'complete':
                console.log('Query completed - closing connection');
                showStatus('queryStatus', '查询完成', 'success');
                finishQuery('Closing currentEventSource');
                console.log('Query completed - connection closed, isStreaming set to false');
                break;
            case 'query_error':
                console.error('Query error:', message.error);
                showStatus('queryStatus', `查询失败: ${message.error}`, 'error');
                finishQuery(null);
                break;
            case 'stream_error':
                console.error('Stream error:', message.error);
                showStatus('queryStatus', `流错误: ${message.error}`, 'error');
                finishQuery(null);
                break;
            case 'search':
            case 'think':
                // BUGFIX: guard the debug substring — a missing content field used to
                // throw here and silently drop the message via the outer catch.
                console.log('Processing message type:', message.type, 'with content:', (message.content || '').substring(0, 100) + '...');
                addMessageToContainer(message);
                break;
            case 'answer':
                console.log('Processing message type:', message.type, 'with content:', (message.content || '').substring(0, 100) + '...');
                // Show the final answer in the result panel, skipping the legacy placeholder.
                if (message.content && message.content !== "==== FINAL ANSWER====") {
                    document.getElementById('resultText').textContent = message.content;
                    showResult();
                }
                // Also mirror the answer into the message feed.
                addMessageToContainer(message);
                break;
            default:
                console.log('Unknown message type:', message.type);
        }
    } catch (error) {
        console.error('Error parsing message:', error);
    }
}
// Helper: open the realtime message stream over Server-Sent Events
function startMessageStream() {
    // Drop any previous connection before opening a new one.
    closeEventSource();

    eventSource = new EventSource('/stream-messages/');

    eventSource.onopen = function(event) {
        console.log('EventSource connection opened');
    };
    eventSource.onmessage = (event) => handleStreamMessage(event.data);
    eventSource.onerror = function(event) {
        console.error('EventSource error:', event);
        if (eventSource.readyState === EventSource.CLOSED) {
            console.log('EventSource connection closed');
        }
    };
}
// 加载文件功能
document.getElementById('loadFilesBtn').addEventListener('click', async function() {
const button = this;
@ -530,7 +671,7 @@
}
});
// 查询功能
// 查询功能 - 使用实时流
document.getElementById('queryBtn').addEventListener('click', async function() {
const button = this;
const queryText = document.getElementById('queryText').value;
@ -546,39 +687,72 @@
return;
}
if (isStreaming) {
console.log('Query already in progress, isStreaming:', isStreaming);
showStatus('queryStatus', '查询正在进行中,请等待完成', 'error');
return;
}
setButtonLoading(button, true);
showStatus('queryStatus', '正在处理查询...', 'loading');
showStatus('queryStatus', '正在启动查询...', 'loading');
hideResult();
// 清空消息容器
const container = document.getElementById('messageContainer');
container.innerHTML = '';
try {
const response = await fetch(`/query/?original_query=${encodeURIComponent(queryText)}&max_iter=${maxIter}`, {
method: 'GET',
headers: {
'Content-Type': 'application/json'
}
});
console.log('Starting new query, setting isStreaming to true');
isStreaming = true;
const data = await response.json();
// 确保没有其他连接存在
if (window.currentEventSource) {
console.log('Closing existing EventSource connection');
window.currentEventSource.close();
window.currentEventSource = null;
}
if (response.ok) {
showStatus('queryStatus', '查询完成', 'success');
document.getElementById('resultText').textContent = data.result;
// 显示消息流
if (data.messages && data.messages.length > 0) {
displayMessages(data.messages);
// 使用EventSource直接连接到查询流
const eventSource = new EventSource(`/query-stream/?original_query=${encodeURIComponent(queryText)}&max_iter=${maxIter}`);
// 保存EventSource引用以便后续关闭
window.currentEventSource = eventSource;
eventSource.onopen = function(event) {
console.log('EventSource connection opened for query');
showStatus('queryStatus', '查询已开始,正在处理...', 'loading');
};
eventSource.onmessage = function(event) {
console.log('Received message:', event.data);
handleStreamMessage(event.data);
};
eventSource.onerror = function(event) {
console.error('EventSource error:', event);
if (eventSource.readyState === EventSource.CLOSED) {
console.log('EventSource connection closed due to error');
isStreaming = false;
setButtonLoading(button, false);
window.currentEventSource = null;
}
showResult();
} else {
showStatus('queryStatus', `查询失败: ${data.detail}`, 'error');
}
};
} catch (error) {
console.error('Query error:', error);
showStatus('queryStatus', `请求失败: ${error.message}`, 'error');
} finally {
isStreaming = false;
setButtonLoading(button, false);
}
});
// Tear down the SSE connection when the page is being unloaded
window.addEventListener('beforeunload', () => {
    const src = window.currentEventSource;
    if (src) {
        src.close();
        window.currentEventSource = null;
    }
});
</script>
</body>
</html>

2
deepsearcher/config.yaml

@ -80,7 +80,7 @@ provide_settings:
# port: 6333
query_settings:
max_iter: 1
max_iter: 3
load_settings:
chunk_size: 1024

15
deepsearcher/offline_loading.py

@ -43,14 +43,13 @@ def load_from_local_files(
embedding_model = configuration.embedding_model
file_loader = configuration.file_loader
# 如果force_rebuild为True,则强制重建集合
if force_rebuild:
vector_db.init_collection(
dim=embedding_model.dimension,
collection=collection_name,
description=collection_description,
force_rebuild=True,
)
# 初始化集合(如果不存在则创建,如果force_rebuild为True则重建)
vector_db.init_collection(
dim=embedding_model.dimension,
collection=collection_name,
description=collection_description,
force_rebuild=force_rebuild,
)
if isinstance(paths_or_directory, str):
paths_or_directory = [paths_or_directory]

10
deepsearcher/utils/message_stream.py

@ -11,6 +11,7 @@ class MessageType(Enum):
SEARCH = "search"
THINK = "think"
ANSWER = "answer"
COMPLETE = "complete"
@dataclass
@ -93,6 +94,10 @@ class MessageStream:
"""发送答案消息"""
self.send_message(MessageType.ANSWER, content, metadata)
def send_complete(self, metadata: Optional[Dict[str, Any]] = None):
"""发送完成消息"""
self.send_message(MessageType.COMPLETE, "", metadata)
def get_messages(self) -> List[Message]:
"""获取所有消息"""
return self._messages.copy()
@ -133,6 +138,11 @@ def send_answer(content: str, metadata: Optional[Dict[str, Any]] = None):
message_stream.send_answer(content, metadata)
def send_complete(metadata: Optional[Dict[str, Any]] = None):
"""全局完成消息发送函数"""
message_stream.send_complete(metadata)
def get_message_stream() -> MessageStream:
"""获取全局消息流实例"""
return message_stream

3
deepsearcher/vector_db/milvus.py

@ -18,6 +18,7 @@ class Milvus(BaseVectorDB):
user: str = "",
password: str = "",
db: str = "default",
default_collection: str = "deepsearcher",
**kwargs,
):
"""
@ -29,9 +30,11 @@ class Milvus(BaseVectorDB):
user (str, optional): Username for authentication. Defaults to "".
password (str, optional): Password for authentication. Defaults to "".
db (str, optional): Database name. Defaults to "default".
default_collection (str, optional): Default collection name. Defaults to "deepsearcher".
**kwargs: Additional keyword arguments to pass to the MilvusClient.
"""
super().__init__()
self.default_collection = default_collection
self.client = MilvusClient(
uri=uri, user=user, password=password, token=token, db_name=db, timeout=30, **kwargs
)

210
main.py

@ -3,9 +3,13 @@ import argparse
import uvicorn
from fastapi import Body, FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel
import os
import asyncio
import json
import queue
from typing import AsyncGenerator
from deepsearcher.configuration import Configuration, init_config
from deepsearcher.offline_loading import load_from_local_files, load_from_website
@ -18,6 +22,9 @@ config = Configuration()
init_config(config)
# 全局变量用于存储消息流回调
message_callbacks = []
class ProviderConfigRequest(BaseModel):
"""
@ -98,6 +105,11 @@ def load_files(
description="Optional batch size for the collection.",
examples=[256],
),
force_rebuild: bool = Body(
False,
description="Whether to force rebuild the collection if it already exists.",
examples=[False],
),
):
"""
Load files into the vector database.
@ -107,6 +119,7 @@ def load_files(
collection_name (str, optional): Name for the collection. Defaults to None.
collection_description (str, optional): Description for the collection. Defaults to None.
batch_size (int, optional): Batch size for processing. Defaults to None.
force_rebuild (bool, optional): Whether to force rebuild the collection. Defaults to False.
Returns:
dict: A dictionary containing a success message.
@ -120,6 +133,7 @@ def load_files(
collection_name=collection_name,
collection_description=collection_description,
batch_size=batch_size if batch_size is not None else 8,
force_rebuild=force_rebuild,
)
return {"message": "Files loaded successfully."}
except Exception as e:
@ -148,6 +162,11 @@ def load_website(
description="Optional batch size for the collection.",
examples=[256],
),
force_rebuild: bool = Body(
False,
description="Whether to force rebuild the collection if it already exists.",
examples=[False],
),
):
"""
Load website content into the vector database.
@ -157,6 +176,7 @@ def load_website(
collection_name (str, optional): Name for the collection. Defaults to None.
collection_description (str, optional): Description for the collection. Defaults to None.
batch_size (int, optional): Batch size for processing. Defaults to None.
force_rebuild (bool, optional): Whether to force rebuild the collection. Defaults to False.
Returns:
dict: A dictionary containing a success message.
@ -170,6 +190,7 @@ def load_website(
collection_name=collection_name,
collection_description=collection_description,
batch_size=batch_size if batch_size is not None else 8,
force_rebuild=force_rebuild,
)
return {"message": "Website loaded successfully."}
except Exception as e:
@ -218,6 +239,125 @@ def perform_query(
raise HTTPException(status_code=500, detail=str(e))
@app.get("/query-stream/")
async def perform_query_stream(
    original_query: str = Query(
        ...,
        description="Your question here.",
        examples=["Write a report about Milvus."],
    ),
    max_iter: int = Query(
        3,
        description="The maximum number of iterations for reflection.",
        ge=1,
        examples=[3],
    ),
) -> StreamingResponse:
    """
    Perform a query with real-time streaming of progress and results.

    Args:
        original_query (str): The user's question or query.
        max_iter (int, optional): Maximum number of iterations for reflection. Defaults to 3.

    Returns:
        StreamingResponse: A Server-Sent Events stream of query progress messages,
        terminated by a ``complete`` (or ``query_error`` / ``stream_error``) event.
    """

    async def query_stream_generator() -> AsyncGenerator[str, None]:
        """Run the query in a worker thread and relay its messages as SSE events."""
        # Bind these before the try so the finally clause can always reference them.
        message_stream = get_message_stream()
        message_callback = None
        try:
            # Drop messages from any previous query so the client only sees this one.
            message_stream.clear_messages()

            yield f"data: {json.dumps({'type': 'query_start', 'query': original_query, 'max_iter': max_iter}, ensure_ascii=False)}\n\n"

            # Thread-safe queue bridging the worker thread and this async generator.
            message_queue = queue.Queue()

            def message_callback(message):
                """Forward each streamed message into the thread-safe queue."""
                try:
                    message_queue.put(message)
                except Exception as e:
                    print(f"Error in message callback: {e}")

            message_stream.add_callback(message_callback)

            def run_query():
                """Execute the blocking query; return (result_text, error_message)."""
                try:
                    print(f"Starting query: {original_query} with max_iter: {max_iter}")
                    result_text, _ = query(original_query, max_iter)
                    print(f"Query completed with result length: {len(result_text) if result_text else 0}")
                    return result_text, None
                except Exception as e:
                    print(f"Query failed with error: {e}")
                    return None, str(e)

            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor() as executor:
                query_future = executor.submit(run_query)
                # Relay queued messages until the query finishes (or errors out).
                while True:
                    try:
                        if query_future.done():
                            result_text, error = query_future.result()
                            if error:
                                print(f"Query error: {error}")
                                yield f"data: {json.dumps({'type': 'query_error', 'error': error}, ensure_ascii=False)}\n\n"
                                return
                            print("Query completed successfully, sending complete message")
                            yield f"data: {json.dumps({'type': 'complete'}, ensure_ascii=False)}\n\n"
                            print("Complete message sent, ending stream")
                            return
                        try:
                            # BUGFIX: the original called message_queue.get(timeout=0.5)
                            # directly inside this async generator, blocking the event
                            # loop for up to 0.5 s per iteration and starving every other
                            # request. Run the blocking get in a worker thread instead.
                            message = await asyncio.to_thread(message_queue.get, timeout=0.5)
                            message_data = {
                                'type': message.type.value,
                                'content': message.content,
                                'timestamp': message.timestamp,
                                'metadata': message.metadata or {},
                            }
                            yield f"data: {json.dumps(message_data, ensure_ascii=False)}\n\n"
                        except queue.Empty:
                            # No message within the poll window; re-check the future.
                            continue
                    except Exception as e:
                        print(f"Error in stream loop: {e}")
                        yield f"data: {json.dumps({'type': 'stream_error', 'error': str(e)}, ensure_ascii=False)}\n\n"
                        return
        except Exception as e:
            yield f"data: {json.dumps({'type': 'stream_error', 'error': str(e)}, ensure_ascii=False)}\n\n"
        finally:
            # Always detach the callback so finished streams don't accumulate.
            if message_callback:
                message_stream.remove_callback(message_callback)

    return StreamingResponse(
        query_stream_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control",
        },
    )
@app.get("/messages/")
def get_messages():
"""
@ -235,6 +375,74 @@ def get_messages():
raise HTTPException(status_code=500, detail=str(e))
@app.get("/stream-messages/")
async def stream_messages() -> StreamingResponse:
    """
    Stream messages in real-time using Server-Sent Events.

    Returns:
        StreamingResponse: A streaming response with real-time messages,
        with a heartbeat event emitted roughly once per idle second.
    """

    async def event_generator() -> AsyncGenerator[str, None]:
        """Yield one SSE event per message pushed to the global message stream."""
        message_queue = asyncio.Queue()
        # Capture the running loop so the callback can hand messages over safely
        # even when it fires on a worker thread: asyncio.Queue is not thread-safe,
        # and asyncio.create_task requires a running loop in the *current* thread.
        loop = asyncio.get_running_loop()

        def message_callback(message):
            """Forward a message into the queue; safe to call from any thread."""
            try:
                # BUGFIX: the original used asyncio.create_task(message_queue.put(...)),
                # which raises RuntimeError when invoked from the query worker thread
                # (no running loop there) and silently dropped those messages.
                loop.call_soon_threadsafe(message_queue.put_nowait, message)
            except Exception as e:
                print(f"Error in message callback: {e}")

        # Register with both the module-level registry and the message stream.
        message_callbacks.append(message_callback)
        message_stream = get_message_stream()
        message_stream.add_callback(message_callback)
        try:
            # Tell the client the stream is live.
            yield f"data: {json.dumps({'type': 'connection', 'message': 'Connected to message stream'}, ensure_ascii=False)}\n\n"
            while True:
                try:
                    # Wait for the next message; time out so we can emit heartbeats.
                    message = await asyncio.wait_for(message_queue.get(), timeout=1.0)
                    message_data = {
                        'type': message.type.value,
                        'content': message.content,
                        'timestamp': message.timestamp,
                        'metadata': message.metadata or {},
                    }
                    yield f"data: {json.dumps(message_data, ensure_ascii=False)}\n\n"
                except asyncio.TimeoutError:
                    # Heartbeat keeps proxies from closing an idle connection.
                    yield f"data: {json.dumps({'type': 'heartbeat', 'timestamp': loop.time()}, ensure_ascii=False)}\n\n"
        except Exception as e:
            print(f"Error in event generator: {e}")
        finally:
            # Detach the callback so closed connections don't leak handlers.
            if message_callback in message_callbacks:
                message_callbacks.remove(message_callback)
            message_stream.remove_callback(message_callback)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "Cache-Control",
        },
    )
@app.post("/clear-messages/")
def clear_messages():
"""

31
test_iteration.py

@ -1,31 +0,0 @@
#!/usr/bin/env python3
"""Script for exercising the search/reflect iteration logic."""


def test_iteration_logic():
    """Simulate the iteration loop and print what each round would do."""
    max_iter = 3
    print(f"测试最大迭代次数: {max_iter}")
    print("-" * 40)

    for it in range(max_iter):
        print(f">> Iteration: {it + 1}")
        # Simulated search phase for this round.
        print(" 执行搜索任务...")
        # Reflect and queue another round only while iterations remain.
        if it + 1 < max_iter:
            print(" 反思并生成新的子查询...")
            print(" 准备下一次迭代")
        else:
            print(" 达到最大迭代次数,退出")
            break

    print("-" * 40)
    print("测试完成!")


if __name__ == "__main__":
    test_iteration_logic()
Loading…
Cancel
Save