You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

237 lines
10 KiB

from unittest.mock import MagicMock, patch
from deepsearcher.agent import ChainOfRAG
from deepsearcher.vector_db.base import RetrievalResult
from deepsearcher.llm.base import ChatResponse
from tests.agent.test_base import BaseAgentTest
class TestChainOfRAG(BaseAgentTest):
"""Test class for ChainOfRAG agent."""
def setUp(self):
"""Set up test fixtures for ChainOfRAG tests."""
super().setUp()
# Set up predefined responses for the LLM for exact prompt substrings
self.llm.predefined_responses = {
"previous queries and answers, generate a new simple follow-up question": "What is the significance of deep learning?",
"Given the following documents, generate an appropriate answer": "Deep learning is a subset of machine learning that uses neural networks with multiple layers.",
"given the following intermediate queries and answers, judge whether you have enough information": "Yes",
"Given a list of agent indexes and corresponding descriptions": "1",
"Given the following documents, select the ones that are support the Q-A pair": "[0, 1]",
"Given the following intermediate queries and answers, generate a final answer": "Deep learning is an advanced subset of machine learning that uses neural networks with multiple layers."
}
self.chain_of_rag = ChainOfRAG(
llm=self.llm,
embedding_model=self.embedding_model,
vector_db=self.vector_db,
max_iter=3,
early_stopping=True,
route_collection=True,
text_window_splitter=True
)
def test_init(self):
"""Test the initialization of ChainOfRAG."""
self.assertEqual(self.chain_of_rag.llm, self.llm)
self.assertEqual(self.chain_of_rag.embedding_model, self.embedding_model)
self.assertEqual(self.chain_of_rag.vector_db, self.vector_db)
self.assertEqual(self.chain_of_rag.max_iter, 3)
self.assertEqual(self.chain_of_rag.early_stopping, True)
self.assertEqual(self.chain_of_rag.route_collection, True)
self.assertEqual(self.chain_of_rag.text_window_splitter, True)
def test_reflect_get_subquery(self):
"""Test the _reflect_get_subquery method."""
query = "What is deep learning?"
intermediate_context = ["Previous query: What is AI?", "Previous answer: AI is artificial intelligence."]
# Direct mock for this specific method
self.llm.chat = MagicMock(return_value=ChatResponse(
content="What is the significance of deep learning?",
total_tokens=10
))
subquery, tokens = self.chain_of_rag._reflect_get_subquery(query, intermediate_context)
self.assertEqual(subquery, "What is the significance of deep learning?")
self.assertEqual(tokens, 10)
self.assertTrue(self.llm.chat.called)
def test_retrieve_and_answer(self):
"""Test the _retrieve_and_answer method."""
query = "What is deep learning?"
# Mock the collection_router.invoke method
self.chain_of_rag.collection_router.invoke = MagicMock(return_value=(["test_collection"], 5))
# Direct mock for this specific method
self.llm.chat = MagicMock(return_value=ChatResponse(
content="Deep learning is a subset of machine learning that uses neural networks with multiple layers.",
total_tokens=10
))
answer, results, tokens = self.chain_of_rag._retrieve_and_answer(query)
# Check if correct methods were called
self.chain_of_rag.collection_router.invoke.assert_called_once()
self.assertTrue(self.vector_db.search_called)
# Check the results
self.assertEqual(answer, "Deep learning is a subset of machine learning that uses neural networks with multiple layers.")
self.assertEqual(tokens, 15) # 5 from collection_router + 10 from LLM
def test_get_supported_docs(self):
"""Test the _get_supported_docs method."""
results = [
RetrievalResult(
embedding=[0.1] * 8,
text=f"Test result {i}",
reference="test_reference",
metadata={"a": i}
)
for i in range(3)
]
query = "What is deep learning?"
answer = "Deep learning is a subset of machine learning that uses neural networks with multiple layers."
# Mock the literal_eval method to return indices as integers
self.llm.literal_eval = MagicMock(return_value=[0, 1])
supported_docs, tokens = self.chain_of_rag._get_supported_docs(results, query, answer)
self.assertEqual(len(supported_docs), 2) # Based on our mock response of [0, 1]
self.assertEqual(tokens, 10)
def test_check_has_enough_info(self):
"""Test the _check_has_enough_info method."""
query = "What is deep learning?"
intermediate_contexts = [
"Intermediate query1: What is deep learning?",
"Intermediate answer1: Deep learning is a subset of machine learning that uses neural networks with multiple layers."
]
# Direct mock for this specific method
self.llm.chat = MagicMock(return_value=ChatResponse(
content="Yes",
total_tokens=10
))
has_enough, tokens = self.chain_of_rag._check_has_enough_info(query, intermediate_contexts)
self.assertTrue(has_enough) # Based on our mock response of "Yes"
self.assertEqual(tokens, 10)
def test_retrieve(self):
"""Test the retrieve method."""
query = "What is deep learning?"
# Mock all the methods that retrieve calls
self.chain_of_rag._reflect_get_subquery = MagicMock(return_value=("What is the significance of deep learning?", 5))
self.chain_of_rag._retrieve_and_answer = MagicMock(
return_value=("Deep learning is important in AI", [RetrievalResult(
embedding=[0.1] * 8,
text="Test result",
reference="test_reference",
metadata={"a": 1}
)], 10)
)
self.chain_of_rag._get_supported_docs = MagicMock(return_value=([RetrievalResult(
embedding=[0.1] * 8,
text="Test result",
reference="test_reference",
metadata={"a": 1}
)], 5))
self.chain_of_rag._check_has_enough_info = MagicMock(return_value=(True, 5))
results, tokens, metadata = self.chain_of_rag.retrieve(query)
# Check if methods were called
self.chain_of_rag._reflect_get_subquery.assert_called_once()
self.chain_of_rag._retrieve_and_answer.assert_called_once()
self.chain_of_rag._get_supported_docs.assert_called_once()
# With early stopping, it should check if we have enough info
self.chain_of_rag._check_has_enough_info.assert_called_once()
# Check results
self.assertEqual(len(results), 1)
self.assertEqual(tokens, 25) # 5 + 10 + 5 + 5
self.assertIn("intermediate_context", metadata)
def test_query(self):
"""Test the query method."""
query = "What is deep learning?"
# Mock the retrieve method
retrieved_results = [
RetrievalResult(
embedding=[0.1] * 8,
text=f"Test result {i}",
reference="test_reference",
metadata={"a": i, "wider_text": f"Wider context for test result {i}"}
)
for i in range(3)
]
self.chain_of_rag.retrieve = MagicMock(
return_value=(retrieved_results, 20, {"intermediate_context": ["Some context"]})
)
# Direct mock for this specific method
self.llm.chat = MagicMock(return_value=ChatResponse(
content="Deep learning is an advanced subset of machine learning that uses neural networks with multiple layers.",
total_tokens=10
))
answer, results, tokens = self.chain_of_rag.query(query)
# Check if methods were called
self.chain_of_rag.retrieve.assert_called_once_with(query)
self.assertTrue(self.llm.chat.called)
# Check results
self.assertEqual(answer, "Deep learning is an advanced subset of machine learning that uses neural networks with multiple layers.")
self.assertEqual(results, retrieved_results)
self.assertEqual(tokens, 30) # 20 from retrieve + 10 from LLM
def test_format_retrieved_results(self):
"""Test the _format_retrieved_results method."""
retrieved_results = [
RetrievalResult(
embedding=[0.1] * 8,
text="Test result 1",
reference="test_reference",
metadata={"a": 1, "wider_text": "Wider context for test result 1"}
),
RetrievalResult(
embedding=[0.1] * 8,
text="Test result 2",
reference="test_reference",
metadata={"a": 2, "wider_text": "Wider context for test result 2"}
)
]
# Test with text_window_splitter enabled
self.chain_of_rag.text_window_splitter = True
formatted = self.chain_of_rag._format_retrieved_results(retrieved_results)
self.assertIn("Wider context for test result 1", formatted)
self.assertIn("Wider context for test result 2", formatted)
# Test with text_window_splitter disabled
self.chain_of_rag.text_window_splitter = False
formatted = self.chain_of_rag._format_retrieved_results(retrieved_results)
self.assertIn("Test result 1", formatted)
self.assertIn("Test result 2", formatted)
self.assertNotIn("Wider context for test result 1", formatted)
if __name__ == "__main__":
import unittest
unittest.main()