You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

154 lines
6.3 KiB

from unittest.mock import MagicMock, patch
from deepsearcher.agent.collection_router import CollectionRouter
from deepsearcher.llm.base import ChatResponse
from deepsearcher.vector_db.base import CollectionInfo
from tests.agent.test_base import BaseAgentTest
class TestCollectionRouter(BaseAgentTest):
    """Unit tests for ``CollectionRouter``.

    Every test drives ``CollectionRouter.invoke`` against a mocked vector
    database and a mocked LLM, then checks which collection names and how
    many tokens come back.
    """

    def setUp(self):
        """Build a router backed by three mock collections.

        ``self.llm`` and ``self.vector_db`` are expected to be provided by
        ``BaseAgentTest.setUp()`` (not visible in this file).
        """
        super().setUp()

        # Collections the mocked vector DB will report.
        self.collection_infos = [
            CollectionInfo(collection_name="books", description="Collection of book summaries"),
            CollectionInfo(collection_name="science", description="Scientific articles and papers"),
            CollectionInfo(collection_name="news", description="Recent news articles"),
        ]

        # Configure the vector_db mock.
        self.vector_db.list_collections = MagicMock(return_value=self.collection_infos)
        self.vector_db.default_collection = "books"

        # Router under test.
        self.collection_router = CollectionRouter(
            llm=self.llm,
            vector_db=self.vector_db,
            dim=8,
        )

    def test_init(self):
        """The router keeps its collaborators and records all collection names."""
        self.assertEqual(self.collection_router.llm, self.llm)
        self.assertEqual(self.collection_router.vector_db, self.vector_db)
        self.assertEqual(
            self.collection_router.all_collections,
            ["books", "science", "news"],
        )

    def test_invoke_with_multiple_collections(self):
        """LLM-selected collections are returned together with the default one."""
        query = "What are the latest scientific breakthroughs?"

        # The LLM picks two of the three collections.
        self.llm.chat = MagicMock(return_value=ChatResponse(
            content='["science", "news"]',
            total_tokens=10,
        ))

        # Silence log output while invoking.
        with patch("deepsearcher.utils.log.color_print"):
            selected_collections, tokens = self.collection_router.invoke(query, dim=8)

        # assertIn reports the actual list on failure, unlike assertTrue(x in y).
        self.assertIn("science", selected_collections)
        self.assertIn("news", selected_collections)
        # The default collection is expected even though the LLM did not pick it.
        self.assertIn("books", selected_collections)
        self.assertEqual(tokens, 10)

        # The LLM must be called exactly once, with a prompt that carries the
        # query and the collection metadata.
        self.llm.chat.assert_called_once()
        prompt = self.llm.chat.call_args[1]["messages"][0]["content"]
        self.assertIn(query, prompt)
        self.assertIn("collection_name", prompt)

    def test_invoke_with_empty_response(self):
        """An empty LLM selection leaves only the default collection."""
        query = "Something completely unrelated"

        # The LLM selects nothing.
        self.llm.chat = MagicMock(return_value=ChatResponse(
            content='[]',
            total_tokens=5,
        ))

        # Silence log output while invoking.
        with patch("deepsearcher.utils.log.color_print"):
            selected_collections, tokens = self.collection_router.invoke(query, dim=8)

        # Only the default collection should be included.
        self.assertEqual(selected_collections, ["books"])
        self.assertEqual(tokens, 5)

    def test_invoke_with_no_collections(self):
        """With no collections available, invoke returns ``([], 0)``."""
        query = "Test query"

        # The vector DB reports no collections at all.
        self.vector_db.list_collections = MagicMock(return_value=[])

        # Silence both warnings and regular log output.
        with patch("deepsearcher.utils.log.warning"), \
                patch("deepsearcher.utils.log.color_print"):
            selected_collections, tokens = self.collection_router.invoke(query, dim=8)

        # Should return an empty list and zero tokens.
        self.assertEqual(selected_collections, [])
        self.assertEqual(tokens, 0)

    def test_invoke_with_single_collection(self):
        """A lone collection is returned directly, without consulting the LLM."""
        query = "Test query"

        # Fresh chat mock so we can verify it is never called.
        mock_chat = MagicMock(return_value=ChatResponse(content='[]', total_tokens=0))
        self.llm.chat = mock_chat

        # The vector DB reports exactly one collection.
        single_collection = [
            CollectionInfo(collection_name="single", description="The only collection"),
        ]
        self.vector_db.list_collections = MagicMock(return_value=single_collection)

        # Silence log output while invoking.
        with patch("deepsearcher.utils.log.color_print"):
            selected_collections, tokens = self.collection_router.invoke(query, dim=8)

        # Should return the only collection without calling the LLM.
        self.assertEqual(selected_collections, ["single"])
        self.assertEqual(tokens, 0)
        mock_chat.assert_not_called()

    def test_invoke_with_no_description(self):
        """A collection with an empty description is included in the result."""
        query = "Test query"

        # Two collections, one of which has an empty description.
        collections_with_no_desc = [
            CollectionInfo(collection_name="with_desc", description="Has description"),
            CollectionInfo(collection_name="no_desc", description=""),
        ]
        self.vector_db.list_collections = MagicMock(return_value=collections_with_no_desc)
        self.vector_db.default_collection = "with_desc"

        # The LLM selects only the described collection.
        self.llm.chat = MagicMock(return_value=ChatResponse(
            content='["with_desc"]',
            total_tokens=5,
        ))

        # Silence log output while invoking.
        with patch("deepsearcher.utils.log.color_print"):
            selected_collections, tokens = self.collection_router.invoke(query, dim=8)

        # Both collections should be included (one from the LLM, one with no
        # description); order is not checked, hence the set comparison.
        self.assertEqual(set(selected_collections), {"with_desc", "no_desc"})
        self.assertEqual(tokens, 5)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    import unittest

    unittest.main()