You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
162 lines
6.5 KiB
162 lines
6.5 KiB
from unittest.mock import MagicMock, patch
|
|
|
|
from deepsearcher.agent import NaiveRAG, ChainOfRAG, DeepSearch
|
|
from deepsearcher.agent.rag_router import RAGRouter
|
|
from deepsearcher.vector_db.base import RetrievalResult
|
|
from deepsearcher.llm.base import ChatResponse
|
|
|
|
from tests.agent.test_base import BaseAgentTest
|
|
|
|
|
|
class TestRAGRouter(BaseAgentTest):
|
|
"""Test class for RAGRouter agent."""
|
|
|
|
def setUp(self):
|
|
"""Set up test fixtures for RAGRouter tests."""
|
|
super().setUp()
|
|
|
|
# Create mock agent instances
|
|
self.naive_rag = MagicMock(spec=NaiveRAG)
|
|
self.chain_of_rag = MagicMock(spec=ChainOfRAG)
|
|
|
|
# Create the RAGRouter with the mock agents
|
|
self.rag_router = RAGRouter(
|
|
llm=self.llm,
|
|
rag_agents=[self.naive_rag, self.chain_of_rag],
|
|
agent_descriptions=[
|
|
"This agent is suitable for simple factual queries",
|
|
"This agent is suitable for complex multi-hop questions"
|
|
]
|
|
)
|
|
|
|
def test_init(self):
|
|
"""Test the initialization of RAGRouter."""
|
|
self.assertEqual(self.rag_router.llm, self.llm)
|
|
self.assertEqual(len(self.rag_router.rag_agents), 2)
|
|
self.assertEqual(len(self.rag_router.agent_descriptions), 2)
|
|
self.assertEqual(self.rag_router.agent_descriptions[0], "This agent is suitable for simple factual queries")
|
|
self.assertEqual(self.rag_router.agent_descriptions[1], "This agent is suitable for complex multi-hop questions")
|
|
|
|
def test_route(self):
|
|
"""Test the _route method."""
|
|
query = "What is the capital of France?"
|
|
|
|
# Directly mock the chat method to return a numeric response
|
|
self.llm.chat = MagicMock(return_value=ChatResponse(content="1", total_tokens=10))
|
|
|
|
agent, tokens = self.rag_router._route(query)
|
|
|
|
# Should select the first agent based on our mock response
|
|
self.assertEqual(agent, self.naive_rag)
|
|
self.assertEqual(tokens, 10)
|
|
self.assertTrue(self.llm.chat.called)
|
|
|
|
def test_route_with_non_numeric_response(self):
|
|
"""Test the _route method with a non-numeric response from LLM."""
|
|
query = "What is the history of deep learning?"
|
|
|
|
# Mock the LLM to return a response with a trailing digit
|
|
self.llm.chat = MagicMock(return_value=ChatResponse(content="I recommend agent 2", total_tokens=10))
|
|
self.rag_router.find_last_digit = MagicMock(return_value="2")
|
|
|
|
agent, tokens = self.rag_router._route(query)
|
|
|
|
# Should select the second agent based on our mock response
|
|
self.assertEqual(agent, self.chain_of_rag)
|
|
self.assertTrue(self.rag_router.find_last_digit.called)
|
|
|
|
def test_retrieve(self):
|
|
"""Test the retrieve method."""
|
|
query = "What is the capital of France?"
|
|
|
|
# Mock the _route method to return the first agent
|
|
mock_retrieved_results = [
|
|
RetrievalResult(
|
|
embedding=[0.1] * 8,
|
|
text="Paris is the capital of France",
|
|
reference="test_reference",
|
|
metadata={"a": 1}
|
|
)
|
|
]
|
|
self.rag_router._route = MagicMock(return_value=(self.naive_rag, 5))
|
|
self.naive_rag.retrieve = MagicMock(return_value=(mock_retrieved_results, 10, {"some": "metadata"}))
|
|
|
|
results, tokens, metadata = self.rag_router.retrieve(query)
|
|
|
|
# Check if methods were called
|
|
self.rag_router._route.assert_called_once_with(query)
|
|
self.naive_rag.retrieve.assert_called_once_with(query)
|
|
|
|
# Check results
|
|
self.assertEqual(results, mock_retrieved_results)
|
|
self.assertEqual(tokens, 15) # 5 from route + 10 from retrieve
|
|
self.assertEqual(metadata, {"some": "metadata"})
|
|
|
|
def test_query(self):
|
|
"""Test the query method."""
|
|
query = "What is the capital of France?"
|
|
|
|
# Mock the _route method to return the first agent
|
|
mock_retrieved_results = [
|
|
RetrievalResult(
|
|
embedding=[0.1] * 8,
|
|
text="Paris is the capital of France",
|
|
reference="test_reference",
|
|
metadata={"a": 1}
|
|
)
|
|
]
|
|
self.rag_router._route = MagicMock(return_value=(self.naive_rag, 5))
|
|
self.naive_rag.query = MagicMock(return_value=("Paris is the capital of France", mock_retrieved_results, 10))
|
|
|
|
answer, results, tokens = self.rag_router.query(query)
|
|
|
|
# Check if methods were called
|
|
self.rag_router._route.assert_called_once_with(query)
|
|
self.naive_rag.query.assert_called_once_with(query)
|
|
|
|
# Check results
|
|
self.assertEqual(answer, "Paris is the capital of France")
|
|
self.assertEqual(results, mock_retrieved_results)
|
|
self.assertEqual(tokens, 15) # 5 from route + 10 from query
|
|
|
|
def test_find_last_digit(self):
|
|
"""Test the find_last_digit method."""
|
|
self.assertEqual(self.rag_router.find_last_digit("Agent 2 is better"), "2")
|
|
self.assertEqual(self.rag_router.find_last_digit("I recommend agent number 1"), "1")
|
|
self.assertEqual(self.rag_router.find_last_digit("Choose 3"), "3")
|
|
|
|
# Test with no digit
|
|
with self.assertRaises(ValueError):
|
|
self.rag_router.find_last_digit("No digits here")
|
|
|
|
def test_auto_description_fallback(self):
|
|
"""Test that RAGRouter falls back to __description__ when no descriptions provided."""
|
|
# Create classes with __description__ attribute
|
|
class MockAgent1:
|
|
__description__ = "Auto description 1"
|
|
|
|
class MockAgent2:
|
|
__description__ = "Auto description 2"
|
|
|
|
# Create instances of these classes
|
|
mock_agent1 = MagicMock(spec=MockAgent1)
|
|
mock_agent1.__class__ = MockAgent1
|
|
|
|
mock_agent2 = MagicMock(spec=MockAgent2)
|
|
mock_agent2.__class__ = MockAgent2
|
|
|
|
# Create the RAGRouter without explicit descriptions
|
|
router = RAGRouter(
|
|
llm=self.llm,
|
|
rag_agents=[mock_agent1, mock_agent2]
|
|
)
|
|
|
|
# Check that descriptions were pulled from the class attributes
|
|
self.assertEqual(len(router.agent_descriptions), 2)
|
|
self.assertEqual(router.agent_descriptions[0], "Auto description 1")
|
|
self.assertEqual(router.agent_descriptions[1], "Auto description 2")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import unittest
|
|
unittest.main()
|