You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

98 lines
4.2 KiB

from typing import List, Tuple
from deepsearcher.agent.base import BaseAgent
from deepsearcher.llm.base import BaseLLM
from deepsearcher.utils import log
from deepsearcher.vector_db.base import BaseVectorDB
# Prompt template used by CollectionRouter.invoke to ask the LLM which
# collections are relevant to a question. It is filled with str.format():
#   {question}        - the user query
#   {collection_info} - a list of {"collection_name", "collection_description"} dicts
# The reply is expected to be a list of collection-name strings and nothing else.
COLLECTION_ROUTE_PROMPT = """
I provide you with collection_name(s) and corresponding collection_description(s). Please select the collection names that may be related to the question and return a python list of str. If there is no collection related to the question, you can return an empty list.
"QUESTION": {question}
"COLLECTION_INFO": {collection_info}
When you return, you can ONLY return a json convertable python list of str, WITHOUT any other additional content. Your selected collection name list is:
"""
class CollectionRouter(BaseAgent):
    """
    Routes queries to appropriate collections in the vector database.

    This class analyzes the content of a query and determines which collections
    in the vector database are most likely to contain relevant information.
    """

    def __init__(self, llm: BaseLLM, vector_db: BaseVectorDB, dim: int, **kwargs):
        """
        Initialize the CollectionRouter.

        Args:
            llm: The language model to use for analyzing queries.
            vector_db: The vector database containing the collections.
            dim: The dimension of the vector space to search in.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        self.llm = llm
        self.vector_db = vector_db
        # Names of the collections listed by the vector DB at construction time.
        self.all_collections = [
            collection_info.collection_name
            for collection_info in self.vector_db.list_collections(dim=dim)
        ]

    def invoke(self, query: str, dim: int, **kwargs) -> Tuple[List[str], int]:
        """
        Determine which collections are relevant for the given query.

        The collections are re-listed from the vector database on every call.
        With no collections an empty result is returned; with exactly one, that
        collection is returned without consulting the LLM. Otherwise the LLM is
        asked to pick collection names, and every collection that has no
        description, as well as the vector DB's default collection, is always
        added to the selection.

        Args:
            query (str): The query to analyze.
            dim (int): The dimension of the vector space to search in.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Tuple[List[str], int]: A tuple containing:
                - A de-duplicated list of selected collection names. Names
                  chosen by the LLM come first (in the order it returned
                  them), followed by the always-included collections in
                  listing order. Names the LLM returned that are not present
                  in the vector database are dropped.
                - The token usage for the routing operation (0 when the LLM
                  was not called).
        """
        collection_infos = self.vector_db.list_collections(dim=dim)

        if not collection_infos:
            log.color_print(
                "No collections found in the vector database. Please check the database connection."
            )
            return [], 0

        if len(collection_infos) == 1:
            # Nothing to route: skip the LLM call entirely.
            the_only_collection = collection_infos[0].collection_name
            log.color_print(
                f"<think> Perform search [{query}] on the vector DB collection: {the_only_collection} </think>\n"
            )
            return [the_only_collection], 0

        vector_db_search_prompt = COLLECTION_ROUTE_PROMPT.format(
            question=query,
            collection_info=[
                {
                    "collection_name": collection_info.collection_name,
                    "collection_description": collection_info.description,
                }
                for collection_info in collection_infos
            ],
        )
        chat_response = self.llm.chat(
            messages=[{"role": "user", "content": vector_db_search_prompt}]
        )
        llm_selection = self.llm.literal_eval(chat_response.content)
        consume_tokens = chat_response.total_tokens

        # The LLM output is untrusted: normalise it to a sequence of candidates
        # instead of assuming it is a list that supports .append().
        if isinstance(llm_selection, str):
            llm_selection = [llm_selection]
        elif not isinstance(llm_selection, (list, tuple, set)):
            llm_selection = []

        known_names = {
            collection_info.collection_name for collection_info in collection_infos
        }

        # A dict is used as an insertion-ordered set so the result order is
        # deterministic (list(set(...)) varies with string hash randomization).
        selected = {}
        unknown_names = []
        for name in llm_selection:
            if isinstance(name, str) and name in known_names:
                selected[name] = None
            else:
                unknown_names.append(name)

        if unknown_names:
            log.color_print(
                f"Ignoring collection names returned by the LLM that do not exist: {unknown_names}"
            )

        for collection_info in collection_infos:
            # A collection without a description cannot be judged by the LLM,
            # and the default collection is always searched, so include both.
            if (
                not collection_info.description
                or self.vector_db.default_collection == collection_info.collection_name
            ):
                selected[collection_info.collection_name] = None

        selected_collections = list(selected)
        log.color_print(
            f"<think> Perform search [{query}] on the vector DB collections: {selected_collections} </think>\n"
        )
        return selected_collections, consume_tokens