You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

130 lines
4.8 KiB

from unittest.mock import MagicMock
from deepsearcher.agent import NaiveRAG
from deepsearcher.vector_db.base import RetrievalResult
from tests.agent.test_base import BaseAgentTest
class TestNaiveRAG(BaseAgentTest):
    """Unit tests for the NaiveRAG agent.

    The ``llm``, ``embedding_model`` and ``vector_db`` fixtures are not defined
    in this class; they are expected to be created by ``BaseAgentTest.setUp``.
    """

    def setUp(self):
        """Create a NaiveRAG with collection routing and window splitting enabled."""
        super().setUp()
        self.naive_rag = NaiveRAG(
            llm=self.llm,
            embedding_model=self.embedding_model,
            vector_db=self.vector_db,
            top_k=5,
            route_collection=True,
            text_window_splitter=True,
        )

    @staticmethod
    def _make_retrieval_results(count=3):
        """Return ``count`` RetrievalResult fixtures with distinct text and wider_text.

        Each result carries both a short ``text`` and a ``wider_text`` metadata
        entry so tests can tell which of the two ended up in the LLM prompt.
        """
        return [
            RetrievalResult(
                embedding=[0.1] * 8,
                text=f"Test result {i}",
                reference="test_reference",
                metadata={"a": i, "wider_text": f"Wider context for test result {i}"},
            )
            for i in range(count)
        ]

    def test_init(self):
        """Constructor arguments are stored on the agent unchanged."""
        self.assertEqual(self.naive_rag.llm, self.llm)
        self.assertEqual(self.naive_rag.embedding_model, self.embedding_model)
        self.assertEqual(self.naive_rag.vector_db, self.vector_db)
        self.assertEqual(self.naive_rag.top_k, 5)
        self.assertEqual(self.naive_rag.route_collection, True)
        self.assertEqual(self.naive_rag.text_window_splitter, True)

    def test_retrieve(self):
        """retrieve() consults the collection router and searches the vector DB."""
        query = "Test query"

        # Replace the router so the selected collections and token cost are known.
        router_invoke = MagicMock(return_value=(["test_collection"], 5))
        self.naive_rag.collection_router.invoke = router_invoke

        results, tokens, _ = self.naive_rag.retrieve(query)

        # The router must be consulted exactly once and the DB must be searched.
        router_invoke.assert_called_once()
        self.assertTrue(self.vector_db.search_called)

        # The vector DB fixture is expected to return 3 results.
        self.assertIsInstance(results, list)
        self.assertEqual(len(results), 3)
        for result in results:
            self.assertIsInstance(result, RetrievalResult)

        # Token count comes from the mocked collection_router.invoke.
        self.assertEqual(tokens, 5)

    def test_retrieve_without_routing(self):
        """retrieve() skips the collection router when routing is disabled."""
        self.naive_rag.route_collection = False
        query = "Test query without routing"

        # Spy on the router so we can prove it is bypassed; previously this test
        # only claimed to check that and never asserted it.
        router_invoke = MagicMock(return_value=(["test_collection"], 5))
        self.naive_rag.collection_router.invoke = router_invoke

        results, tokens, _ = self.naive_rag.retrieve(query)

        # Routing must not be called, but the vector DB must still be searched.
        router_invoke.assert_not_called()
        self.assertTrue(self.vector_db.search_called)

        self.assertIsInstance(results, list)
        for result in results:
            self.assertIsInstance(result, RetrievalResult)

        # No tokens are used for routing.
        self.assertEqual(tokens, 0)

    def test_query(self):
        """query() retrieves, prompts the LLM with the query, and sums token usage."""
        query = "Test query for full RAG"

        # Stub retrieve() so only the answer-generation path is exercised.
        mock_results = self._make_retrieval_results()
        self.naive_rag.retrieve = MagicMock(return_value=(mock_results, 5, {}))

        answer, retrieved_results, tokens = self.naive_rag.query(query)

        # Check that the expected collaborators were called.
        self.naive_rag.retrieve.assert_called_once_with(query)
        self.assertTrue(self.llm.chat_called)

        # The first message sent to the LLM must contain the query text.
        first_message = self.llm.last_messages[0]
        self.assertIn("content", first_message)
        self.assertIn(query, first_message["content"])

        self.assertEqual(answer, "This is a test answer")
        self.assertEqual(retrieved_results, mock_results)
        # 5 from retrieve + 10 from LLM.
        self.assertEqual(tokens, 15)

    def test_with_window_splitter_disabled(self):
        """With the window splitter off, the prompt uses text rather than wider_text."""
        self.naive_rag.text_window_splitter = False
        query = "Test query with window splitter off"

        # Stub retrieve() with results that carry both text and wider_text.
        mock_results = self._make_retrieval_results()
        self.naive_rag.retrieve = MagicMock(return_value=(mock_results, 5, {}))

        # Only the prompt is inspected here; the return value is not needed.
        self.naive_rag.query(query)

        prompt = self.llm.last_messages[0]["content"]
        self.assertIn("Test result 0", prompt)
        self.assertNotIn("Wider context", prompt)
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    from unittest import main as run_unittest_main

    run_unittest_main()