You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
130 lines
4.8 KiB
130 lines
4.8 KiB
from unittest.mock import MagicMock
|
|
|
|
from deepsearcher.agent import NaiveRAG
|
|
from deepsearcher.vector_db.base import RetrievalResult
|
|
|
|
from tests.agent.test_base import BaseAgentTest
|
|
|
|
|
|
class TestNaiveRAG(BaseAgentTest):
    """Test class for NaiveRAG agent."""

    def setUp(self):
        """Set up test fixtures for NaiveRAG tests.

        Builds a NaiveRAG instance from the ``llm``, ``embedding_model`` and
        ``vector_db`` attributes available after ``super().setUp()``
        (presumably test doubles created by BaseAgentTest -- confirm there),
        with collection routing and the text window splitter both enabled.
        """
        super().setUp()
        self.naive_rag = NaiveRAG(
            llm=self.llm,
            embedding_model=self.embedding_model,
            vector_db=self.vector_db,
            top_k=5,
            route_collection=True,
            text_window_splitter=True,
        )

    @staticmethod
    def _make_mock_results(count=3):
        """Build ``count`` RetrievalResult fixtures with text and wider_text.

        Each result carries a short ``text`` ("Test result {i}") and a
        distinct ``metadata["wider_text"]`` ("Wider context for test result
        {i}") so tests can tell which of the two ended up in the LLM prompt.
        """
        return [
            RetrievalResult(
                embedding=[0.1] * 8,
                text=f"Test result {i}",
                reference="test_reference",
                metadata={"a": i, "wider_text": f"Wider context for test result {i}"},
            )
            for i in range(count)
        ]

    def test_init(self):
        """Test the initialization of NaiveRAG."""
        self.assertEqual(self.naive_rag.llm, self.llm)
        self.assertEqual(self.naive_rag.embedding_model, self.embedding_model)
        self.assertEqual(self.naive_rag.vector_db, self.vector_db)
        self.assertEqual(self.naive_rag.top_k, 5)
        self.assertTrue(self.naive_rag.route_collection)
        self.assertTrue(self.naive_rag.text_window_splitter)

    def test_retrieve(self):
        """Test the retrieve method."""
        query = "Test query"

        # Mock the collection_router.invoke method
        self.naive_rag.collection_router.invoke = MagicMock(
            return_value=(["test_collection"], 5)
        )

        results, tokens, _ = self.naive_rag.retrieve(query)

        # Check if correct methods were called
        self.naive_rag.collection_router.invoke.assert_called_once()
        self.assertTrue(self.vector_db.search_called)

        # Check the results
        self.assertIsInstance(results, list)
        self.assertEqual(len(results), 3)  # Should match our mock return of 3 results
        for result in results:
            self.assertIsInstance(result, RetrievalResult)

        # Check token count
        self.assertEqual(tokens, 5)  # From our mocked collection_router.invoke

    def test_retrieve_without_routing(self):
        """Test retrieve method with routing disabled."""
        self.naive_rag.route_collection = False
        query = "Test query without routing"

        # Replace the router's invoke with a spy so we can prove it is skipped.
        # It still returns a well-formed (collections, tokens) pair so that an
        # unexpected call surfaces as the assertion below, not an unpack error.
        self.naive_rag.collection_router.invoke = MagicMock(
            return_value=(["test_collection"], 0)
        )

        results, tokens, _ = self.naive_rag.retrieve(query)

        # Check that routing was not called, but the search still ran
        self.naive_rag.collection_router.invoke.assert_not_called()
        self.assertTrue(self.vector_db.search_called)

        # Check the results
        self.assertIsInstance(results, list)
        for result in results:
            self.assertIsInstance(result, RetrievalResult)

        # Check token count
        self.assertEqual(tokens, 0)  # No tokens used for routing

    def test_query(self):
        """Test the query method."""
        query = "Test query for full RAG"

        # Mock the retrieve method
        mock_results = self._make_mock_results()
        self.naive_rag.retrieve = MagicMock(return_value=(mock_results, 5, {}))

        answer, retrieved_results, tokens = self.naive_rag.query(query)

        # Check if correct methods were called
        self.naive_rag.retrieve.assert_called_once_with(query)
        self.assertTrue(self.llm.chat_called)

        # Check the messages sent to LLM
        self.assertIn("content", self.llm.last_messages[0])
        self.assertIn(query, self.llm.last_messages[0]["content"])

        # Check the results
        self.assertEqual(answer, "This is a test answer")
        self.assertEqual(retrieved_results, mock_results)
        self.assertEqual(tokens, 15)  # 5 from retrieve + 10 from LLM

    def test_with_window_splitter_disabled(self):
        """Test with text window splitter disabled."""
        self.naive_rag.text_window_splitter = False
        query = "Test query with window splitter off"

        # Mock the retrieve method
        mock_results = self._make_mock_results()
        self.naive_rag.retrieve = MagicMock(return_value=(mock_results, 5, {}))

        # Only the prompt sent to the LLM matters here; the return value is unused.
        self.naive_rag.query(query)

        # Check that regular text is used instead of wider_text
        prompt = self.llm.last_messages[0]["content"]
        self.assertIn("Test result 0", prompt)
        self.assertNotIn("Wider context", prompt)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    from unittest import main as run_unittest_main

    run_unittest_main()
|