
Remove token usage statistics and simplify message sending

Branch: main
Author: tanxing, 2 weeks ago
Parent commit: d4484950df
1. deepsearcher/agent/base.py (9 changed lines)
2. deepsearcher/agent/collection_router.py (23 changed lines)
3. deepsearcher/agent/deep_search.py (111 changed lines)
4. deepsearcher/agent/rag_router.py (16 changed lines)
5. deepsearcher/config.yaml (8 changed lines)
6. deepsearcher/llm/base.py (88 changed lines)
7. deepsearcher/llm/openai_llm.py (26 changed lines)
8. deepsearcher/online_query.py (12 changed lines)
9. pyproject.toml (10 changed lines)
10. test.py (2 changed lines)

deepsearcher/agent/base.py (9 changed lines)

@@ -1,5 +1,4 @@
 from abc import ABC
-from typing import Any, List, Tuple
 from deepsearcher.vector_db import RetrievalResult
@@ -42,7+41,7 @@ class BaseAgent(ABC):
 """
 pass
-def invoke(self, query: str, **kwargs) -> Any:
+def invoke(self, query: str, **kwargs) -> any:
 """
 Invoke the agent and return the result.
@@ -72,7 +71,7 @@ class RAGAgent(BaseAgent):
 """
 pass
-def retrieve(self, query: str, **kwargs) -> Tuple[List[RetrievalResult], int, dict]:
+def retrieve(self, query: str, **kwargs) -> tuple[list[RetrievalResult], dict]:
 """
 Retrieve document results from the knowledge base.
@@ -83,11 +82,10 @@ class RAGAgent(BaseAgent):
 Returns:
 A tuple containing:
 - the retrieved results
-- the total number of token usages of the LLM
 - any additional metadata, which can be an empty dictionary
 """
-def query(self, query: str, **kwargs) -> Tuple[str, List[RetrievalResult], int]:
+def query(self, query: str, **kwargs) -> tuple[str, list[RetrievalResult]]:
 """
 Query the agent and return the answer.
@@ -99,5 +97,4 @@ class RAGAgent(BaseAgent):
 A tuple containing:
 - the result generated from LLM
 - the retrieved document results
-- the total number of token usages of the LLM
 """

deepsearcher/agent/collection_router.py (23 changed lines)

@@ -1,17 +1,18 @@
-from typing import List, Tuple
 from deepsearcher.agent.base import BaseAgent
 from deepsearcher.llm.base import BaseLLM
 from deepsearcher.utils import log
 from deepsearcher.vector_db.base import BaseVectorDB
 COLLECTION_ROUTE_PROMPT = """
-I provide you with collection_name(s) and corresponding collection_description(s). Please select the collection names that may be related to the question and return a python list of str. If there is no collection related to the question, you can return an empty list.
+I provide you with collection_name(s) and corresponding collection_description(s).
+Please select the collection names that may be related to the question and return a python list of str.
+If there is no collection related to the question, you can return an empty list.
 "QUESTION": {question}
 "COLLECTION_INFO": {collection_info}
-When you return, you can ONLY return a json convertable python list of str, WITHOUT any other additional content. Your selected collection name list is:
+When you return, you can ONLY return a json convertable python list of str, WITHOUT any other additional content.
+Your selected collection name list is:
 """
@@ -39,7 +40,7 @@ class CollectionRouter(BaseAgent):
 for collection_info in self.vector_db.list_collections(dim=dim)
 ]
-def invoke(self, query: str, dim: int, **kwargs) -> Tuple[List[str], int]:
+def invoke(self, query: str, dim: int, **kwargs) -> list[str]:
 """
 Determine which collections are relevant for the given query.
@@ -51,11 +52,8 @@ class CollectionRouter(BaseAgent):
 dim (int): The dimension of the vector space to search in.
 Returns:
-Tuple[List[str], int]: A tuple containing:
-- A list of selected collection names
-- The token usage for the routing operation
+List[str]: A list of selected collection names
 """
-consume_tokens = 0
 collection_infos = self.vector_db.list_collections(dim=dim)
 if len(collection_infos) == 0:
 log.color_print(
@@ -78,11 +76,10 @@ class CollectionRouter(BaseAgent):
 for collection_info in collection_infos
 ],
 )
-chat_response = self.llm.chat(
+response = self.llm.chat(
 messages=[{"role": "user", "content": vector_db_search_prompt}]
 )
-selected_collections = self.llm.literal_eval(chat_response.content)
-consume_tokens += chat_response.total_tokens
+selected_collections = self.llm.literal_eval(response)
 for collection_info in collection_infos:
 # If a collection description is not provided, use the query as the search query
@@ -95,4 +92,4 @@ class CollectionRouter(BaseAgent):
 log.color_print(
 f"<think> Perform search [{query}] on the vector DB collections: {selected_collections} </think>\n"
 )
-return selected_collections, consume_tokens
+return selected_collections
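
CollectionRouter.invoke now returns only the selected collection names, with no token count to unpack. A small hypothetical caller showing the new contract (the helper function and its fallback are illustrative, not code from the repo):

```python
from deepsearcher.agent.collection_router import CollectionRouter

def pick_collections(router: CollectionRouter, question: str, dim: int) -> list[str]:
    # invoke() now returns just the list of collection names; the former
    # (names, token_count) tuple is gone.
    selected = router.invoke(query=question, dim=dim)
    # Illustrative fallback: search every known collection when none is selected.
    return selected or router.all_collections
```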

deepsearcher/agent/deep_search.py (111 changed lines)

@@ -1,5 +1,4 @@
 import asyncio
-from typing import List, Tuple
 from deepsearcher.agent.base import RAGAgent, describe_class
 from deepsearcher.agent.collection_router import CollectionRouter
@@ -33,7 +32,9 @@ Example output:
 Provide your response in a python code list of str format:
 """
-RERANK_PROMPT = """Based on the query questions and the retrieved chunks, determine whether each chunk is helpful in answering any of the query questions. For each chunk, you must return "YES" or "NO" without any other information.
+RERANK_PROMPT = """
+Based on the query questions and the retrieved chunks, determine whether each chunk is helpful in answering any of the query questions.
+For each chunk, you must return "YES" or "NO" without any other information.
 Query Questions: {query}
@@ -43,7 +44,9 @@ Retrieved Chunks:
 Respond with a list of "YES" or "NO" values, one for each chunk, in the same order as the chunks are listed. For example a list of chunks of three: ["YES", "NO", "YES"]"""
-REFLECT_PROMPT = """Determine whether additional search queries are needed based on the original query, previous sub queries, and all retrieved document chunks. If further research is required, provide a Python list of up to 3 search queries. If no further research is required, return an empty list.
+REFLECT_PROMPT = """
+Determine whether additional search queries are needed based on the original query, previous sub queries, and all retrieved document chunks.
+If further research is required, provide a Python list of up to 3 search queries. If no further research is required, return an empty list.
 If the original query is to write a report, then you prefer to generate some further queries, instead return an empty list.
@@ -51,19 +54,21 @@ Original Query: {question}
 Previous Sub Queries: {mini_questions}
 Related Chunks:
 {mini_chunk_str}
 Respond exclusively in valid List of str format without any other text."""
-SUMMARY_PROMPT = """You are a AI content analysis expert, good at summarizing content. Please summarize a long, specific and detailed answer or report based on the previous queries and the retrieved document chunks.
+SUMMARY_PROMPT = """
+You are a AI content analysis expert, good at summarizing content.
+Please summarize a long, specific and detailed answer or report based on the previous queries and the retrieved document chunks.
 Original Query: {question}
 Previous Sub Queries: {mini_questions}
 Related Chunks:
 {mini_chunk_str}
 """
@@ -112,25 +117,22 @@ class DeepSearch(RAGAgent):
 )
 self.text_window_splitter = text_window_splitter
-def _generate_sub_queries(self, original_query: str) -> Tuple[List[str], int]:
-chat_response = self.llm.chat(
+def _generate_sub_queries(self, original_query: str) -> tuple[list[str], int]:
+content = self.llm.chat(
 messages=[
 {"role": "user", "content": SUB_QUERY_PROMPT.format(original_query=original_query)}
 ]
 )
-response_content = self.llm.remove_think(chat_response.content)
-return self.llm.literal_eval(response_content), chat_response.total_tokens
+content = self.llm.remove_think(content)
+return self.llm.literal_eval(content)
 async def _search_chunks_from_vectordb(self, query: str):
-consume_tokens = 0
 if self.route_collection:
-selected_collections, n_token_route = self.collection_router.invoke(
+selected_collections = self.collection_router.invoke(
 query=query, dim=self.embedding_model.dimension
 )
 else:
 selected_collections = self.collection_router.all_collections
-n_token_route = 0
-consume_tokens += n_token_route
 all_retrieved_results = []
 query_vector = self.embedding_model.embed_query(query)
@@ -144,14 +146,14 @@ class DeepSearch(RAGAgent):
 f"<search> No relevant document chunks found in '{collection}'! </search>\n"
 )
 continue
 # Format all chunks for batch processing
 formatted_chunks = ""
 for i, retrieved_result in enumerate(retrieved_results):
 formatted_chunks += f"<chunk_{i}>\n{retrieved_result.text}\n</chunk_{i}>\n"
 # Batch process all chunks with a single LLM call
-chat_response = self.llm.chat(
+content = self.llm.chat(
 messages=[
 {
 "role": "user",
@@ -162,37 +164,38 @@ class DeepSearch(RAGAgent):
 }
 ]
 )
-consume_tokens += chat_response.total_tokens
-response_content = self.llm.remove_think(chat_response.content).strip()
+content = self.llm.remove_think(content).strip()
 # Parse the response to determine which chunks are relevant
 try:
-relevance_list = self.llm.literal_eval(response_content)
+relevance_list = self.llm.literal_eval(content)
 if not isinstance(relevance_list, list):
 raise ValueError("Response is not a list")
 except (ValueError, SyntaxError):
 # Fallback: if parsing fails, treat all chunks as relevant
-log.color_print(f"Warning: Failed to parse relevance response. Treating all chunks as relevant. Response was: {response_content}")
+log.color_print(f"Warning: Failed to parse relevance response. Treating all chunks as relevant. Response was: {content}")
 relevance_list = ["YES"] * len(retrieved_results)
 # Ensure we have enough relevance judgments for all chunks
 while len(relevance_list) < len(retrieved_results):
 relevance_list.append("YES")  # Default to relevant if no judgment provided
 # Filter relevant chunks based on LLM response
 accepted_chunk_num = 0
 references = set()
 for i, retrieved_result in enumerate(retrieved_results):
 # Check if we have a relevance judgment for this chunk
-is_relevant = (i < len(relevance_list) and
-"YES" in relevance_list[i].upper() and
-"NO" not in relevance_list[i].upper()) if i < len(relevance_list) else True
+is_relevant = (
+i < len(relevance_list) and
+"YES" in relevance_list[i].upper() and
+"NO" not in relevance_list[i].upper()
+) if i < len(relevance_list) else True
 if is_relevant:
 all_retrieved_results.append(retrieved_result)
 accepted_chunk_num += 1
 references.add(retrieved_result.reference)
 if accepted_chunk_num > 0:
 log.color_print(
 f"<search> Accept {accepted_chunk_num} document chunk(s) from references: {list(references)} </search>\n"
@@ -201,11 +204,11 @@ class DeepSearch(RAGAgent):
 log.color_print(
 f"<search> No document chunk accepted from '{collection}'! </search>\n"
 )
-return all_retrieved_results, consume_tokens
+return all_retrieved_results
 def _generate_gap_queries(
-self, original_query: str, all_sub_queries: List[str], all_chunks: List[RetrievalResult]
-) -> Tuple[List[str], int]:
+self, original_query: str, all_sub_queries: list[str], all_chunks: list[RetrievalResult]
+) -> list[str]:
 reflect_prompt = REFLECT_PROMPT.format(
 question=original_query,
 mini_questions=all_sub_queries,
@@ -213,11 +216,11 @@ class DeepSearch(RAGAgent):
 if len(all_chunks) > 0
 else "NO RELATED CHUNKS FOUND.",
 )
-chat_response = self.llm.chat([{"role": "user", "content": reflect_prompt}])
-response_content = self.llm.remove_think(chat_response.content)
-return self.llm.literal_eval(response_content), chat_response.total_tokens
-def retrieve(self, original_query: str, **kwargs) -> Tuple[List[RetrievalResult], int, dict]:
+response = self.llm.chat([{"role": "user", "content": reflect_prompt}])
+response = self.llm.remove_think(response)
+return self.llm.literal_eval(response)
+def retrieve(self, original_query: str, **kwargs) -> tuple[list[RetrievalResult], dict]:
 """
 Retrieve relevant documents from the knowledge base for the given query.
@@ -231,26 +234,23 @@ class DeepSearch(RAGAgent):
 Returns:
 Tuple[List[RetrievalResult], int, dict]: A tuple containing:
 - A list of retrieved document results
-- The token usage for the retrieval operation
 - Additional information about the retrieval process
 """
 return asyncio.run(self.async_retrieve(original_query, **kwargs))
 async def async_retrieve(
 self, original_query: str, **kwargs
-) -> Tuple[List[RetrievalResult], int, dict]:
+) -> tuple[list[RetrievalResult], dict]:
 max_iter = kwargs.pop("max_iter", self.max_iter)
 ### SUB QUERIES ###
 log.color_print(f"<query> {original_query} </query>\n")
 all_search_res = []
 all_sub_queries = []
-total_tokens = 0
-sub_queries, used_token = self._generate_sub_queries(original_query)
-total_tokens += used_token
+sub_queries = self._generate_sub_queries(original_query)
 if not sub_queries:
 log.color_print("No sub queries were generated by the LLM. Exiting.")
-return [], total_tokens, {}
+return [], {}
 else:
 log.color_print(
 f"<think> Break down the original query into new sub queries: {sub_queries}</think>\n"
@@ -272,8 +272,7 @@ class DeepSearch(RAGAgent):
 search_results = await asyncio.gather(*search_tasks)
 # Merge all results
 for result in search_results:
-search_res, consumed_token = result
-total_tokens += consumed_token
+search_res = result
 search_res_from_vectordb.extend(search_res)
 search_res_from_vectordb = deduplicate_results(search_res_from_vectordb)
@@ -284,10 +283,9 @@ class DeepSearch(RAGAgent):
 break
 ### REFLECTION & GET GAP QUERIES ###
 log.color_print("<think> Reflecting on the search results... </think>\n")
-sub_gap_queries, consumed_token = self._generate_gap_queries(
+sub_gap_queries = self._generate_gap_queries(
 original_query, all_sub_queries, all_search_res
 )
-total_tokens += consumed_token
 if not sub_gap_queries or len(sub_gap_queries) == 0:
 log.color_print("<think> No new search queries were generated. Exiting. </think>\n")
 break
@@ -299,9 +297,9 @@ class DeepSearch(RAGAgent):
 all_search_res = deduplicate_results(all_search_res)
 additional_info = {"all_sub_queries": all_sub_queries}
-return all_search_res, total_tokens, additional_info
+return all_search_res, additional_info
-def query(self, query: str, **kwargs) -> Tuple[str, List[RetrievalResult], int]:
+def query(self, query: str, **kwargs) -> tuple[str, list[RetrievalResult]]:
 """
 Query the agent and generate an answer based on retrieved documents.
@@ -316,11 +314,10 @@ class DeepSearch(RAGAgent):
 Tuple[str, List[RetrievalResult], int]: A tuple containing:
 - The generated answer
 - A list of retrieved document results
-- The total token usage
 """
-all_retrieved_results, n_token_retrieval, additional_info = self.retrieve(query, **kwargs)
+all_retrieved_results, additional_info = self.retrieve(query, **kwargs)
 if not all_retrieved_results or len(all_retrieved_results) == 0:
-return f"No relevant information found for query '{query}'.", [], n_token_retrieval
+return f"No relevant information found for query '{query}'.", []
 all_sub_queries = additional_info["all_sub_queries"]
 chunk_texts = []
 for chunk in all_retrieved_results:
@@ -336,16 +333,12 @@ class DeepSearch(RAGAgent):
 mini_questions=all_sub_queries,
 mini_chunk_str=self._format_chunk_texts(chunk_texts),
 )
-chat_response = self.llm.chat([{"role": "user", "content": summary_prompt}])
+response = self.llm.chat([{"role": "user", "content": summary_prompt}])
 log.color_print("\n==== FINAL ANSWER====\n")
-log.color_print(self.llm.remove_think(chat_response.content))
-return (
-self.llm.remove_think(chat_response.content),
-all_retrieved_results,
-n_token_retrieval + chat_response.total_tokens,
-)
-def _format_chunk_texts(self, chunk_texts: List[str]) -> str:
+log.color_print(self.llm.remove_think(response))
+return self.llm.remove_think(response), all_retrieved_results
+def _format_chunk_texts(self, chunk_texts: list[str]) -> str:
 chunk_str = ""
 for i, chunk in enumerate(chunk_texts):
 chunk_str += f"""<chunk_{i}>\n{chunk}\n</chunk_{i}>\n"""
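
The rerank step in _search_chunks_from_vectordb asks the LLM for one batched YES/NO list per collection and, when that list cannot be parsed, falls back to accepting every chunk. A minimal self-contained sketch of that filtering logic (the chunk texts and the raw LLM reply are invented placeholders):

```python
import ast

def filter_chunks(chunks: list[str], raw_llm_reply: str) -> list[str]:
    """Keep only the chunks the model marked YES; accept everything on a parse failure."""
    try:
        relevance = ast.literal_eval(raw_llm_reply.strip())
        if not isinstance(relevance, list):
            raise ValueError("Response is not a list")
    except (ValueError, SyntaxError):
        # Fallback mirrors the diff: treat all chunks as relevant
        relevance = ["YES"] * len(chunks)
    # Pad with YES if the model returned fewer judgments than chunks
    relevance += ["YES"] * (len(chunks) - len(relevance))
    return [c for c, r in zip(chunks, relevance)
            if "YES" in r.upper() and "NO" not in r.upper()]

# Hypothetical example: the second chunk is rejected
print(filter_chunks(["chunk a", "chunk b", "chunk c"], '["YES", "NO", "YES"]'))
```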

deepsearcher/agent/rag_router.py (16 changed lines)

@@ -1,9 +1,7 @@
-from typing import List, Tuple
 from deepsearcher.agent import RAGAgent
 from deepsearcher.vector_db import RetrievalResult
 RAG_ROUTER_PROMPT = """Given a list of agent indexes and corresponding descriptions, each agent has a specific function.
 Given a query, select only one agent that best matches the agent handling the query, and return the index without any other information.
 ## Question
@@ -38,10 +36,10 @@ class RAGRouter(RAGAgent):
 """
 self.agent = agent
-def retrieve(self, query: str, **kwargs) -> Tuple[List[RetrievalResult], int, dict]:
-retrieved_results, n_token_retrieval, metadata = self.agent.retrieve(query, **kwargs)
-return retrieved_results, n_token_retrieval, metadata
-def query(self, query: str, **kwargs) -> Tuple[str, List[RetrievalResult], int]:
-answer, retrieved_results, n_token_retrieval = self.agent.query(query, **kwargs)
-return answer, retrieved_results, n_token_retrieval
+def retrieve(self, query: str, **kwargs) -> tuple[list[RetrievalResult], dict]:
+retrieved_results, metadata = self.agent.retrieve(query, **kwargs)
+return retrieved_results, metadata
+def query(self, query: str, **kwargs) -> tuple[str, list[RetrievalResult]]:
+answer, retrieved_results = self.agent.query(query, **kwargs)
+return answer, retrieved_results

deepsearcher/config.yaml (8 changed lines)

@@ -19,13 +19,13 @@ provide_settings:
 provider: "PDFLoader"
 config: {}
+provider: "TextLoader"
+config: {}
 # provider: "JsonFileLoader"
 # config:
 # text_key: ""
-provider: "TextLoader"
-config: {}
 # provider: "UnstructuredLoader"
 # config: {}
@@ -83,5 +83,5 @@ query_settings:
 max_iter: 2
 load_settings:
-chunk_size: 1024
+chunk_size: 2048
 chunk_overlap: 128

deepsearcher/llm/base.py (88 changed lines)

@@ -1,40 +1,6 @@
 import ast
 import re
 from abc import ABC
-from typing import Dict, List
-class ChatResponse(ABC):
-"""
-Represents a response from a chat model.
-This class encapsulates the content of a response from a chat model
-along with information about token usage.
-Attributes:
-content: The text content of the response.
-total_tokens: The total number of tokens used in the request and response.
-"""
-def __init__(self, content: str, total_tokens: int) -> None:
-"""
-Initialize a ChatResponse object.
-Args:
-content: The text content of the response.
-total_tokens: The total number of tokens used in the request and response.
-"""
-self.content = content
-self.total_tokens = total_tokens
-def __repr__(self) -> str:
-"""
-Return a string representation of the ChatResponse.
-Returns:
-A string representation of the ChatResponse object.
-"""
-return f"ChatResponse(content={self.content}, total_tokens={self.total_tokens})"
 class BaseLLM(ABC):
@@ -51,21 +17,23 @@ class BaseLLM(ABC):
 """
 pass
-def chat(self, messages: List[Dict]) -> ChatResponse:
+def chat(self, messages: list[dict]) -> str:
 """
 Send a chat message to the language model and get a response.
 Args:
-messages: A list of message dictionaries, typically in the format
-[{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]
+messages:
+A list of message dictionaries, typically in the format
+[{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]
 Returns:
-A ChatResponse object containing the model's response.
+response (str)
+The content of the llm response.
 """
 pass
 @staticmethod
-def literal_eval(response_content: str):
+def literal_eval(response: str) -> str:
 """
 Parse a string response into a Python object using ast.literal_eval.
@@ -73,37 +41,37 @@ class BaseLLM(ABC):
 handling various formats like code blocks and special tags.
 Args:
-response_content: The string content to parse.
+response: The string content to parse.
 Returns:
-The parsed Python object.
+The processed and parsed Python object.
 Raises:
 ValueError: If the response content cannot be parsed.
 """
-response_content = response_content.strip()
-response_content = BaseLLM.remove_think(response_content)
+response = response.strip()
+response = BaseLLM.remove_think(response)
 try:
-if response_content.startswith("```") and response_content.endswith("```"):
-if response_content.startswith("```python"):
-response_content = response_content[9:-3]
-elif response_content.startswith("```json"):
-response_content = response_content[7:-3]
-elif response_content.startswith("```str"):
-response_content = response_content[6:-3]
-elif response_content.startswith("```\n"):
-response_content = response_content[4:-3]
+if response.startswith("```") and response.endswith("```"):
+if response.startswith("```python"):
+response = response[9:-3]
+elif response.startswith("```json"):
+response = response[7:-3]
+elif response.startswith("```str"):
+response = response[6:-3]
+elif response.startswith("```\n"):
+response = response[4:-3]
 else:
 raise ValueError("Invalid code block format")
-result = ast.literal_eval(response_content.strip())
+result = ast.literal_eval(response.strip())
 except Exception:
-matches = re.findall(r"(\[.*?\]|\{.*?\})", response_content, re.DOTALL)
+matches = re.findall(r"(\[.*?\]|\{.*?\})", response, re.DOTALL)
 if len(matches) != 1:
 raise ValueError(
-f"Invalid JSON/List format for response content:\n{response_content}"
+f"Invalid JSON/List format for response content:\n{response}"
 )
 json_part = matches[0]
@@ -112,9 +80,9 @@ class BaseLLM(ABC):
 return result
 @staticmethod
-def remove_think(response_content: str) -> str:
+def remove_think(response: str) -> str:
 # remove content between <think> and </think>, especial for reasoning model
-if "<think>" in response_content and "</think>" in response_content:
-end_of_think = response_content.find("</think>") + len("</think>")
-response_content = response_content[end_of_think:]
-return response_content.strip()
+if "<think>" in response and "</think>" in response:
+end_of_think = response.find("</think>") + len("</think>")
+response = response[end_of_think:]
+return response.strip()
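
Since chat() now returns a plain string instead of a ChatResponse, callers rely on the static helpers kept in BaseLLM to clean it up. A small sketch of how they behave (the sample response strings are invented):

```python
from deepsearcher.llm.base import BaseLLM

# remove_think strips a reasoning model's <think>...</think> prefix
raw = "<think>compare the two collections</think>[\"milvus_docs\"]"
print(BaseLLM.remove_think(raw))      # -> ["milvus_docs"]

# literal_eval also unwraps fenced code blocks before parsing
fenced = "```json\n[\"YES\", \"NO\", \"YES\"]\n```"
print(BaseLLM.literal_eval(fenced))   # -> ['YES', 'NO', 'YES']
```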

deepsearcher/llm/openai_llm.py (26 changed lines)

@@ -1,7 +1,6 @@
 import os
-from typing import Dict, List
-from deepsearcher.llm.base import BaseLLM, ChatResponse
+from deepsearcher.llm.base import BaseLLM
 class OpenAILLM(BaseLLM):
@@ -39,23 +38,28 @@ class OpenAILLM(BaseLLM):
 base_url = os.getenv("OPENAI_BASE_URL")
 self.client = OpenAI(api_key=api_key, base_url=base_url, **kwargs)
-def chat(self, messages: List[Dict]) -> ChatResponse:
+def chat(self, messages: list[dict], stream_callback = None) -> str:
 """
 Send a chat message to the OpenAI model and get a response.
 Args:
-messages (List[Dict]): A list of message dictionaries, typically in the format
-[{"role": "system", "content": "..."},
-{"role": "user", "content": "..."}]
+messages (List[Dict]):
+A list of message dictionaries, typically in the format
+[{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]
 Returns:
-ChatResponse: An object containing the model's response and token usage information.
+response (str)
 """
 completion = self.client.chat.completions.create(
 model=self.model,
 messages=messages,
+stream=True
 )
-return ChatResponse(
-content=completion.choices[0].message.content,
-total_tokens=completion.usage.total_tokens,
-)
+response = ""
+for chunk in completion:
+stream_response = chunk.choices[0].delta.content
+if stream_response:
+response += stream_response
+if stream_callback:
+stream_callback(stream_response)
+return response
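
Because chat() now streams internally and accepts an optional stream_callback, a caller can observe tokens as they arrive while still getting the full string back. A rough usage sketch: the constructor arguments, model name, and prompt below are illustrative assumptions, and the OpenAI credentials are read from the environment as the diff shows.

```python
from deepsearcher.llm.openai_llm import OpenAILLM

# model name is a placeholder; requires OPENAI_API_KEY / OPENAI_BASE_URL in the environment
llm = OpenAILLM(model="gpt-4o-mini")

def on_token(token: str) -> None:
    # called once per streamed delta; here we simply echo it
    print(token, end="", flush=True)

answer = llm.chat(
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    stream_callback=on_token,
)
print("\nfull response:", answer)
```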

deepsearcher/online_query.py (12 changed lines)

@@ -5,7 +5,7 @@ from deepsearcher import configuration
 from deepsearcher.vector_db.base import RetrievalResult
-def query(original_query: str, max_iter: int = 3) -> Tuple[str, List[RetrievalResult], int]:
+def query(original_query: str, max_iter: int = 3) -> Tuple[str, List[RetrievalResult]]:
 """
 Query the knowledge base with a question and get an answer.
@@ -20,7 +20,6 @@ def query(original_query: str, max_iter: int = 3) -> Tuple[str, List[RetrievalRe
 A tuple containing:
 - The generated answer as a string
 - A list of retrieval results that were used to generate the answer
-- The number of tokens consumed during the process
 """
 default_searcher = configuration.default_searcher
 return default_searcher.query(original_query, max_iter=max_iter)
@@ -28,7 +27,7 @@ def query(original_query: str, max_iter: int = 3) -> Tuple[str, List[RetrievalRe
 def retrieve(
 original_query: str, max_iter: int = 3
-) -> Tuple[List[RetrievalResult], List[str], int]:
+) -> Tuple[List[RetrievalResult], List[str]]:
 """
 Retrieve relevant information from the knowledge base without generating an answer.
@@ -43,13 +42,12 @@ def retrieve(
 A tuple containing:
 - A list of retrieval results
 - An empty list (placeholder for future use)
-- The number of tokens consumed during the process
 """
 default_searcher = configuration.default_searcher
 retrieved_results, consume_tokens, metadata = default_searcher.retrieve(
 original_query, max_iter=max_iter
 )
-return retrieved_results, [], consume_tokens
+return retrieved_results, []
 def naive_retrieve(query: str, collection: str = None, top_k=10) -> List[RetrievalResult]:
@@ -68,7 +66,7 @@ def naive_retrieve(query: str, collection: str = None, top_k=10) -> List[Retriev
 A list of retrieval results.
 """
 naive_rag = configuration.naive_rag
-all_retrieved_results, consume_tokens, _ = naive_rag.retrieve(query)
+all_retrieved_results, _ = naive_rag.retrieve(query)
 return all_retrieved_results
@@ -92,5 +90,5 @@
 - A list of retrieval results that were used to generate the answer
 """
 naive_rag = configuration.naive_rag
-answer, retrieved_results, consume_tokens = naive_rag.query(query)
+answer, retrieved_results = naive_rag.query(query)
 return answer, retrieved_results
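
With the token counts gone, the public helpers in online_query return two-element tuples. A minimal calling sketch; it assumes the configuration has been initialized and documents loaded beforehand, as test.py does:

```python
from deepsearcher.online_query import query, retrieve

# query() now returns (answer, retrieved_results) with no token count
answer, results = query("Write a report about Milvus.", max_iter=2)
print(answer)

# retrieve() now returns (retrieved_results, []) instead of a 3-tuple
results, _ = retrieve("What is Milvus?", max_iter=2)
print(len(results), "chunks retrieved")
```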

pyproject.toml (10 changed lines)

@@ -19,11 +19,7 @@ dependencies = [
 "uvicorn>=0.34.2",
 ]
 description = "None"
-license = { file = "LICENSE"}
-authors = [
-{ name = "Cheney Zhang", email = "277584121@qq.com" },
-{ name = "SimFG", email = "bang.fu@zilliz.com" }
-]
 [project.urls]
 Homepage = "https://github.com/zilliztech/deep-searcher"
@@ -156,7 +152,7 @@ exclude = [
 ]
 # Same as Black.
-line-length = 100
+line-length = 256
 indent-width = 4
 # Assume Python 3.10
@@ -167,7 +163,7 @@ show-fixes = true
 [tool.ruff.lint]
 # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
 # Enable isort (`I`)
-select = ["E4", "E7", "E9", "F", "I"]
+select = ["E", "W", "UP", "F"]
 ignore = []
 # Allow fix for all enabled rules (when `--fix`) is provided.

test.py (2 changed lines)

@@ -18,4 +18,4 @@ load_from_local_files(paths_or_directory="examples/data", force_rebuild=True)
 # load_from_website(urls=website_url)
 # Query
 result = query("Write a report about Milvus.")  # Your question here
