You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
235 lines
9.2 KiB
235 lines
9.2 KiB
from unittest.mock import MagicMock, patch
|
|
import asyncio
|
|
|
|
from deepsearcher.agent import DeepSearch
|
|
from deepsearcher.vector_db.base import RetrievalResult
|
|
|
|
from tests.agent.test_base import BaseAgentTest
|
|
|
|
|
|
class TestDeepSearch(BaseAgentTest):
|
|
"""Test class for DeepSearch agent."""
|
|
|
|
def setUp(self):
|
|
"""Set up test fixtures for DeepSearch tests."""
|
|
super().setUp()
|
|
|
|
# Set up predefined responses for the LLM for exact prompt substrings
|
|
self.llm.predefined_responses = {
|
|
"Original Question:": '["What is deep learning?", "How does deep learning work?", "What are applications of deep learning?"]',
|
|
"Is the chunk helpful": "YES",
|
|
"Respond exclusively in valid List": '["What are limitations of deep learning?"]',
|
|
"You are a AI content analysis expert": "Deep learning is a subset of machine learning that uses neural networks with multiple layers."
|
|
}
|
|
|
|
self.deep_search = DeepSearch(
|
|
llm=self.llm,
|
|
embedding_model=self.embedding_model,
|
|
vector_db=self.vector_db,
|
|
max_iter=2,
|
|
route_collection=True,
|
|
text_window_splitter=True
|
|
)
|
|
|
|
def test_init(self):
|
|
"""Test the initialization of DeepSearch."""
|
|
self.assertEqual(self.deep_search.llm, self.llm)
|
|
self.assertEqual(self.deep_search.embedding_model, self.embedding_model)
|
|
self.assertEqual(self.deep_search.vector_db, self.vector_db)
|
|
self.assertEqual(self.deep_search.max_iter, 2)
|
|
self.assertEqual(self.deep_search.route_collection, True)
|
|
self.assertEqual(self.deep_search.text_window_splitter, True)
|
|
|
|
def test_generate_sub_queries(self):
|
|
"""Test the _generate_sub_queries method."""
|
|
query = "Tell me about deep learning"
|
|
|
|
sub_queries, tokens = self.deep_search._generate_sub_queries(query)
|
|
|
|
self.assertEqual(len(sub_queries), 3)
|
|
self.assertEqual(sub_queries[0], "What is deep learning?")
|
|
self.assertEqual(sub_queries[1], "How does deep learning work?")
|
|
self.assertEqual(sub_queries[2], "What are applications of deep learning?")
|
|
self.assertEqual(tokens, 10)
|
|
self.assertTrue(self.llm.chat_called)
|
|
|
|
def test_search_chunks_from_vectordb(self):
|
|
"""Test the _search_chunks_from_vectordb method."""
|
|
query = "What is deep learning?"
|
|
sub_queries = ["What is deep learning?", "How does deep learning work?"]
|
|
|
|
# Mock the collection_router.invoke method
|
|
self.deep_search.collection_router.invoke = MagicMock(return_value=(["test_collection"], 5))
|
|
|
|
# Run the async method using asyncio.run
|
|
results, tokens = asyncio.run(
|
|
self.deep_search._search_chunks_from_vectordb(query, sub_queries)
|
|
)
|
|
|
|
# Check if correct methods were called
|
|
self.deep_search.collection_router.invoke.assert_called_once()
|
|
self.assertTrue(self.vector_db.search_called)
|
|
self.assertTrue(self.llm.chat_called)
|
|
|
|
# With our mock returning "YES" for RERANK_PROMPT, all chunks should be accepted
|
|
self.assertEqual(len(results), 3) # 3 mock results from MockVectorDB
|
|
self.assertEqual(tokens, 15) # 5 from collection_router + 10*1 from LLM calls for reranking (batch)
|
|
|
|
def test_generate_gap_queries(self):
|
|
"""Test the _generate_gap_queries method."""
|
|
query = "Tell me about deep learning"
|
|
all_sub_queries = ["What is deep learning?", "How does deep learning work?"]
|
|
all_chunks = [
|
|
RetrievalResult(
|
|
embedding=[0.1] * 8,
|
|
text="Deep learning is a subset of machine learning",
|
|
reference="test_reference",
|
|
metadata={"a": 1}
|
|
),
|
|
RetrievalResult(
|
|
embedding=[0.1] * 8,
|
|
text="Deep learning uses neural networks",
|
|
reference="test_reference",
|
|
metadata={"a": 2}
|
|
)
|
|
]
|
|
|
|
gap_queries, tokens = self.deep_search._generate_gap_queries(query, all_sub_queries, all_chunks)
|
|
|
|
self.assertEqual(len(gap_queries), 1)
|
|
self.assertEqual(gap_queries[0], "What are limitations of deep learning?")
|
|
self.assertEqual(tokens, 10)
|
|
|
|
def test_retrieve(self):
|
|
"""Test the retrieve method."""
|
|
query = "Tell me about deep learning"
|
|
|
|
# Mock async method to run synchronously
|
|
async def mock_async_retrieve(*args, **kwargs):
|
|
# Create some test results
|
|
results = [
|
|
RetrievalResult(
|
|
embedding=[0.1] * 8,
|
|
text="Deep learning is a subset of machine learning",
|
|
reference="test_reference",
|
|
metadata={"a": 1}
|
|
),
|
|
RetrievalResult(
|
|
embedding=[0.1] * 8,
|
|
text="Deep learning uses neural networks",
|
|
reference="test_reference",
|
|
metadata={"a": 2}
|
|
)
|
|
]
|
|
# Return the results, token count, and additional info
|
|
return results, 30, {"all_sub_queries": ["What is deep learning?", "How does deep learning work?"]}
|
|
|
|
# Replace the async method with our mock
|
|
self.deep_search.async_retrieve = mock_async_retrieve
|
|
|
|
results, tokens, metadata = self.deep_search.retrieve(query)
|
|
|
|
# Check results
|
|
self.assertEqual(len(results), 2)
|
|
self.assertEqual(tokens, 30)
|
|
self.assertIn("all_sub_queries", metadata)
|
|
self.assertEqual(len(metadata["all_sub_queries"]), 2)
|
|
|
|
def test_async_retrieve(self):
|
|
"""Test the async_retrieve method."""
|
|
query = "Tell me about deep learning"
|
|
|
|
# Create mock results
|
|
mock_results = [
|
|
RetrievalResult(
|
|
embedding=[0.1] * 8,
|
|
text="Deep learning is a subset of machine learning",
|
|
reference="test_reference",
|
|
metadata={"a": 1}
|
|
)
|
|
]
|
|
|
|
# Create a mock async_retrieve result
|
|
mock_retrieve_result = (
|
|
mock_results,
|
|
20,
|
|
{"all_sub_queries": ["What is deep learning?", "How does deep learning work?"]}
|
|
)
|
|
|
|
# Mock the async_retrieve method
|
|
async def mock_async_retrieve(*args, **kwargs):
|
|
return mock_retrieve_result
|
|
|
|
self.deep_search.async_retrieve = mock_async_retrieve
|
|
|
|
# Run the async method using asyncio.run
|
|
results, tokens, metadata = asyncio.run(self.deep_search.async_retrieve(query))
|
|
|
|
# Check results
|
|
self.assertEqual(len(results), 1)
|
|
self.assertEqual(tokens, 20)
|
|
self.assertIn("all_sub_queries", metadata)
|
|
|
|
def test_query(self):
|
|
"""Test the query method."""
|
|
query = "Tell me about deep learning"
|
|
|
|
# Mock the retrieve method
|
|
retrieved_results = [
|
|
RetrievalResult(
|
|
embedding=[0.1] * 8,
|
|
text=f"Test result {i}",
|
|
reference="test_reference",
|
|
metadata={"a": i, "wider_text": f"Wider context for test result {i}"}
|
|
)
|
|
for i in range(3)
|
|
]
|
|
|
|
self.deep_search.retrieve = MagicMock(
|
|
return_value=(retrieved_results, 20, {"all_sub_queries": ["What is deep learning?"]})
|
|
)
|
|
|
|
answer, results, tokens = self.deep_search.query(query)
|
|
|
|
# Check if methods were called
|
|
self.deep_search.retrieve.assert_called_once_with(query)
|
|
self.assertTrue(self.llm.chat_called)
|
|
|
|
# Check results
|
|
self.assertEqual(answer, "Deep learning is a subset of machine learning that uses neural networks with multiple layers.")
|
|
self.assertEqual(results, retrieved_results)
|
|
self.assertEqual(tokens, 30) # 20 from retrieve + 10 from LLM
|
|
|
|
def test_query_no_results(self):
|
|
"""Test the query method when no results are found."""
|
|
query = "Tell me about deep learning"
|
|
|
|
# Mock the retrieve method to return no results
|
|
self.deep_search.retrieve = MagicMock(
|
|
return_value=([], 10, {"all_sub_queries": ["What is deep learning?"]})
|
|
)
|
|
|
|
answer, results, tokens = self.deep_search.query(query)
|
|
|
|
# Should return a message saying no results found
|
|
self.assertIn("No relevant information found", answer)
|
|
self.assertEqual(results, [])
|
|
self.assertEqual(tokens, 10) # Only tokens from retrieve
|
|
|
|
def test_format_chunk_texts(self):
|
|
"""Test the _format_chunk_texts method."""
|
|
chunk_texts = ["Text 1", "Text 2", "Text 3"]
|
|
|
|
formatted = self.deep_search._format_chunk_texts(chunk_texts)
|
|
|
|
self.assertIn("<chunk_0>", formatted)
|
|
self.assertIn("Text 1", formatted)
|
|
self.assertIn("<chunk_1>", formatted)
|
|
self.assertIn("Text 2", formatted)
|
|
self.assertIn("<chunk_2>", formatted)
|
|
self.assertIn("Text 3", formatted)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import unittest
|
|
unittest.main()
|