You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

149 lines
5.4 KiB

2 weeks ago
import unittest
from unittest.mock import MagicMock
import numpy as np
from deepsearcher.llm.base import BaseLLM, ChatResponse
from deepsearcher.embedding.base import BaseEmbedding
from deepsearcher.vector_db.base import BaseVectorDB, RetrievalResult, CollectionInfo
class MockLLM(BaseLLM):
    """Mock LLM implementation for testing agents.

    Records the most recent ``chat`` call (``chat_called``, ``last_messages``)
    and answers from a table of predefined responses, falling back to canned
    default responses.
    """

    def __init__(self, predefined_responses=None):
        """
        Initialize the MockLLM.

        Args:
            predefined_responses: Dictionary mapping prompt substrings to
                responses. Keys are checked in dict order against the content
                of the first message; the first matching key wins.
        """
        self.chat_called = False
        self.last_messages = None
        self.predefined_responses = predefined_responses or {}

    def chat(self, messages, **kwargs):
        """Return a predefined response, a rerank-style default, or a generic answer.

        Args:
            messages: Chat messages; only ``messages[0]["content"]`` is inspected.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A ``ChatResponse`` with ``total_tokens=10``.
        """
        self.chat_called = True
        self.last_messages = messages
        # Compute the content unconditionally: the rerank fallback below needs it
        # even when no predefined responses are configured (the default).
        message_content = messages[0]["content"] if messages else ""

        for key, response in self.predefined_responses.items():
            if key in message_content:
                return ChatResponse(content=response, total_tokens=10)

        # Default response for RERANK_PROMPT - treat all chunks as relevant
        if "Based on the query questions and the retrieved chunks" in message_content:
            # One "YES" per "<chunk_" marker found in the prompt.
            chunk_count = message_content.count("<chunk_")
            return ChatResponse(content=str(["YES"] * chunk_count), total_tokens=10)

        return ChatResponse(content="This is a test answer", total_tokens=10)

    def literal_eval(self, text):
        """Parse ``text`` as a bracketed Python literal, else return a default list.

        Args:
            text: Raw LLM output.

        Returns:
            The parsed value when ``text`` (ignoring surrounding whitespace) is a
            valid literal wrapped in ``[...]``; otherwise ``["test_collection"]``.
            Override in specific tests if a different default is needed.
        """
        import ast

        stripped = text.strip()
        if stripped.startswith("[") and stripped.endswith("]"):
            try:
                return ast.literal_eval(stripped)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                # Malformed literal: fall through to the default below instead of
                # swallowing unrelated exceptions such as KeyboardInterrupt.
                pass
        return ["test_collection"]
class MockEmbedding(BaseEmbedding):
    """Stand-in embedding model for agent tests that emits random vectors."""

    def __init__(self, dimension=8):
        """Remember the length that every produced vector will have."""
        self._dimension = dimension

    @property
    def dimension(self):
        """Length of each vector this mock produces."""
        return self._dimension

    def _random_vector(self):
        """Draw one vector of ``self._dimension`` floats from ``np.random.random``."""
        return np.random.random(self._dimension).tolist()

    def embed_query(self, text):
        """Ignore ``text`` and hand back a freshly drawn random vector."""
        return self._random_vector()

    def embed_documents(self, documents):
        """Hand back one freshly drawn random vector per entry of ``documents``."""
        return [self._random_vector() for _doc in documents]
class MockVectorDB(BaseVectorDB):
    """Mock vector database implementation for testing agents.

    Records the arguments of the most recent ``search_data`` and
    ``insert_data`` calls so tests can assert on them.
    """

    def __init__(self, collections=None):
        """
        Initialize the MockVectorDB.

        Args:
            collections: List of collection names to initialize with. When
                empty or None, a single "test_collection" is created instead.
        """
        self.default_collection = "test_collection"
        self.search_called = False
        self.insert_called = False
        # Pre-declare the "last call" attributes so tests can inspect them before
        # any call has been made without hitting AttributeError.
        self.last_search_collection = None
        self.last_search_vector = None
        self.last_search_top_k = None
        self.last_insert_collection = None
        self.last_insert_chunks = None
        if collections:
            self._collections = [
                CollectionInfo(collection_name=name, description=f"Test collection {name}")
                for name in collections
            ]
        else:
            self._collections = [
                CollectionInfo(collection_name="test_collection", description="Test collection for testing")
            ]

    def search_data(self, collection, vector, top_k=10, **kwargs):
        """Return up to ``min(3, top_k)`` synthetic results for ``collection``.

        Each result echoes ``vector`` as its embedding and carries a
        ``wider_text`` entry in its metadata.
        """
        self.search_called = True
        self.last_search_collection = collection
        self.last_search_vector = vector
        self.last_search_top_k = top_k
        return [
            RetrievalResult(
                embedding=vector,
                text=f"Test result {i} for collection {collection}",
                reference=f"test_reference_{collection}_{i}",
                metadata={"a": i, "wider_text": f"Wider context for test result {i} in collection {collection}"},
            )
            for i in range(min(3, top_k))
        ]

    def insert_data(self, collection, chunks, **kwargs):
        """Record the insert call and report success.

        Extra keyword arguments are accepted (and ignored), matching
        ``search_data`` and ``init_collection``.
        """
        self.insert_called = True
        self.last_insert_collection = collection
        self.last_insert_chunks = chunks
        return True

    def init_collection(self, dim, collection, **kwargs):
        """Mock implementation of init_collection; always succeeds."""
        return True

    def list_collections(self, dim=None):
        """Return the collections configured at construction time."""
        return self._collections

    def clear_db(self, collection, **kwargs):
        """Mock implementation of clear_db; always succeeds and ignores extra kwargs."""
        return True
class BaseAgentTest(unittest.TestCase):
    """Common fixture base for agent test cases.

    unittest runs ``setUp`` before every test method, so each test receives
    freshly constructed mock collaborators and no recorded state (such as
    ``chat_called``) carries over between tests.
    """

    def setUp(self):
        """Build a new mock vector DB, embedding model, and LLM for this test."""
        super().setUp()
        self.vector_db = MockVectorDB()
        self.embedding_model = MockEmbedding(dimension=8)
        self.llm = MockLLM()