You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
237 lines
10 KiB
237 lines
10 KiB
from unittest.mock import MagicMock, patch
|
|
|
|
from deepsearcher.agent import ChainOfRAG
|
|
from deepsearcher.vector_db.base import RetrievalResult
|
|
from deepsearcher.llm.base import ChatResponse
|
|
|
|
from tests.agent.test_base import BaseAgentTest
|
|
|
|
|
|
class TestChainOfRAG(BaseAgentTest):
    """Unit tests for the ChainOfRAG agent.

    Every test runs against a fresh agent built in ``setUp`` from the
    ``llm``, ``embedding_model`` and ``vector_db`` fixtures that
    ``BaseAgentTest.setUp`` is expected to provide (defined outside this
    module).
    """

    @staticmethod
    def _build_result(text, metadata):
        """Return a RetrievalResult using the embedding/reference shared by all tests.

        Args:
            text: Chunk text stored on the result.
            metadata: Metadata dict stored on the result.
        """
        return RetrievalResult(
            embedding=[0.1] * 8,
            text=text,
            reference="test_reference",
            metadata=metadata,
        )

    def setUp(self):
        """Set up test fixtures for ChainOfRAG tests."""
        super().setUp()

        # Canned LLM replies keyed by exact prompt substrings. The keys must
        # stay verbatim (odd wording included): presumably the mock LLM from
        # BaseAgentTest looks them up inside the prompt text — confirm there.
        self.llm.predefined_responses = {
            "previous queries and answers, generate a new simple follow-up question": (
                "What is the significance of deep learning?"
            ),
            "Given the following documents, generate an appropriate answer": (
                "Deep learning is a subset of machine learning that uses neural networks with multiple layers."
            ),
            "given the following intermediate queries and answers, judge whether you have enough information": (
                "Yes"
            ),
            "Given a list of agent indexes and corresponding descriptions": "1",
            "Given the following documents, select the ones that are support the Q-A pair": "[0, 1]",
            "Given the following intermediate queries and answers, generate a final answer": (
                "Deep learning is an advanced subset of machine learning that uses neural networks with multiple layers."
            ),
        }

        self.chain_of_rag = ChainOfRAG(
            llm=self.llm,
            embedding_model=self.embedding_model,
            vector_db=self.vector_db,
            max_iter=3,
            early_stopping=True,
            route_collection=True,
            text_window_splitter=True,
        )

    def test_init(self):
        """Constructor arguments are exposed as attributes of the agent."""
        agent = self.chain_of_rag

        self.assertEqual(agent.llm, self.llm)
        self.assertEqual(agent.embedding_model, self.embedding_model)
        self.assertEqual(agent.vector_db, self.vector_db)
        self.assertEqual(agent.max_iter, 3)
        self.assertEqual(agent.early_stopping, True)
        self.assertEqual(agent.route_collection, True)
        self.assertEqual(agent.text_window_splitter, True)

    def test_reflect_get_subquery(self):
        """Test the _reflect_get_subquery method."""
        query = "What is deep learning?"
        intermediate_context = [
            "Previous query: What is AI?",
            "Previous answer: AI is artificial intelligence.",
        ]

        # Replace chat outright so both the content and the token count
        # returned to the agent are fully controlled by this test.
        self.llm.chat = MagicMock(
            return_value=ChatResponse(
                content="What is the significance of deep learning?",
                total_tokens=10,
            )
        )

        subquery, tokens = self.chain_of_rag._reflect_get_subquery(query, intermediate_context)

        self.assertEqual(subquery, "What is the significance of deep learning?")
        self.assertEqual(tokens, 10)
        self.llm.chat.assert_called()

    def test_retrieve_and_answer(self):
        """Test the _retrieve_and_answer method."""
        query = "What is deep learning?"
        expected_answer = (
            "Deep learning is a subset of machine learning that uses neural networks with multiple layers."
        )

        # The router mock reports one collection and 5 consumed tokens.
        self.chain_of_rag.collection_router.invoke = MagicMock(
            return_value=(["test_collection"], 5)
        )

        # Direct mock of the LLM call made by the method under test.
        self.llm.chat = MagicMock(
            return_value=ChatResponse(content=expected_answer, total_tokens=10)
        )

        # The retrieved results are not inspected by this test.
        answer, _results, tokens = self.chain_of_rag._retrieve_and_answer(query)

        # Check that the expected collaborators were used.
        self.chain_of_rag.collection_router.invoke.assert_called_once()
        self.assertTrue(self.vector_db.search_called)

        # Check the results.
        self.assertEqual(answer, expected_answer)
        self.assertEqual(tokens, 15)  # 5 from collection_router + 10 from LLM

    def test_get_supported_docs(self):
        """Test the _get_supported_docs method."""
        results = [self._build_result(f"Test result {i}", {"a": i}) for i in range(3)]
        query = "What is deep learning?"
        answer = "Deep learning is a subset of machine learning that uses neural networks with multiple layers."

        # Mock literal_eval so the selected indices come back as integers.
        self.llm.literal_eval = MagicMock(return_value=[0, 1])

        supported_docs, tokens = self.chain_of_rag._get_supported_docs(results, query, answer)

        self.assertEqual(len(supported_docs), 2)  # Based on our mock response of [0, 1]
        # chat is not mocked in this test, so the 10 presumably comes from the
        # base mock LLM's response — confirm against BaseAgentTest.
        self.assertEqual(tokens, 10)

    def test_check_has_enough_info(self):
        """Test the _check_has_enough_info method."""
        query = "What is deep learning?"
        intermediate_contexts = [
            "Intermediate query1: What is deep learning?",
            "Intermediate answer1: Deep learning is a subset of machine learning that uses neural networks with multiple layers.",
        ]

        # Direct mock of the LLM call made by the method under test.
        self.llm.chat = MagicMock(
            return_value=ChatResponse(content="Yes", total_tokens=10)
        )

        has_enough, tokens = self.chain_of_rag._check_has_enough_info(query, intermediate_contexts)

        self.assertTrue(has_enough)  # Based on our mock response of "Yes"
        self.assertEqual(tokens, 10)

    def test_retrieve(self):
        """Test the retrieve method with every internal step mocked."""
        query = "What is deep learning?"
        agent = self.chain_of_rag

        # Mock all the methods that retrieve calls. Two distinct (but equal)
        # result objects are built on purpose, one per mocked step.
        agent._reflect_get_subquery = MagicMock(
            return_value=("What is the significance of deep learning?", 5)
        )
        agent._retrieve_and_answer = MagicMock(
            return_value=(
                "Deep learning is important in AI",
                [self._build_result("Test result", {"a": 1})],
                10,
            )
        )
        agent._get_supported_docs = MagicMock(
            return_value=([self._build_result("Test result", {"a": 1})], 5)
        )
        agent._check_has_enough_info = MagicMock(return_value=(True, 5))

        results, tokens, metadata = agent.retrieve(query)

        # Check if methods were called.
        agent._reflect_get_subquery.assert_called_once()
        agent._retrieve_and_answer.assert_called_once()
        agent._get_supported_docs.assert_called_once()

        # With early stopping, it should check if we have enough info.
        agent._check_has_enough_info.assert_called_once()

        # Check results.
        self.assertEqual(len(results), 1)
        self.assertEqual(tokens, 25)  # 5 + 10 + 5 + 5
        self.assertIn("intermediate_context", metadata)

    def test_query(self):
        """Test the query method with retrieve and the LLM mocked."""
        query = "What is deep learning?"
        final_answer = (
            "Deep learning is an advanced subset of machine learning that uses neural networks with multiple layers."
        )

        # Mock the retrieve method.
        retrieved_results = [
            self._build_result(
                f"Test result {i}",
                {"a": i, "wider_text": f"Wider context for test result {i}"},
            )
            for i in range(3)
        ]
        self.chain_of_rag.retrieve = MagicMock(
            return_value=(retrieved_results, 20, {"intermediate_context": ["Some context"]})
        )

        # Direct mock of the LLM call that produces the final answer.
        self.llm.chat = MagicMock(
            return_value=ChatResponse(content=final_answer, total_tokens=10)
        )

        answer, results, tokens = self.chain_of_rag.query(query)

        # Check if methods were called.
        self.chain_of_rag.retrieve.assert_called_once_with(query)
        self.llm.chat.assert_called()

        # Check results.
        self.assertEqual(answer, final_answer)
        self.assertEqual(results, retrieved_results)
        self.assertEqual(tokens, 30)  # 20 from retrieve + 10 from LLM

    def test_format_retrieved_results(self):
        """Test the _format_retrieved_results method in both splitter modes."""
        retrieved_results = [
            self._build_result(
                "Test result 1",
                {"a": 1, "wider_text": "Wider context for test result 1"},
            ),
            self._build_result(
                "Test result 2",
                {"a": 2, "wider_text": "Wider context for test result 2"},
            ),
        ]

        # Test with text_window_splitter enabled.
        self.chain_of_rag.text_window_splitter = True
        formatted = self.chain_of_rag._format_retrieved_results(retrieved_results)

        self.assertIn("Wider context for test result 1", formatted)
        self.assertIn("Wider context for test result 2", formatted)

        # Test with text_window_splitter disabled.
        self.chain_of_rag.text_window_splitter = False
        formatted = self.chain_of_rag._format_retrieved_results(retrieved_results)

        self.assertIn("Test result 1", formatted)
        self.assertIn("Test result 2", formatted)
        # The wider text must be absent for every result, not just the first.
        self.assertNotIn("Wider context for test result 1", formatted)
        self.assertNotIn("Wider context for test result 2", formatted)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    from unittest import main as run_unittest_main

    run_unittest_main()
|