You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
149 lines
5.4 KiB
149 lines
5.4 KiB
import unittest
|
|
from unittest.mock import MagicMock
|
|
import numpy as np
|
|
|
|
from deepsearcher.llm.base import BaseLLM, ChatResponse
|
|
from deepsearcher.embedding.base import BaseEmbedding
|
|
from deepsearcher.vector_db.base import BaseVectorDB, RetrievalResult, CollectionInfo
|
|
|
|
|
|
class MockLLM(BaseLLM):
    """Mock LLM implementation for testing agents.

    Records that ``chat`` was called (and with which messages) and answers
    with canned ``ChatResponse`` objects, so agent code can be exercised
    without a real model.
    """

    def __init__(self, predefined_responses=None):
        """
        Initialize the MockLLM.

        Args:
            predefined_responses: Dictionary mapping prompt substrings to responses.
                When a key occurs in the first message's content, ``chat``
                returns the mapped value as the response content. The dict is
                stored as-is (not copied), so later mutations by the caller
                are visible to the mock.
        """
        self.chat_called = False
        self.last_messages = None
        # Use an explicit None check so a caller-supplied empty dict keeps its
        # identity instead of being silently replaced by a fresh one.
        self.predefined_responses = {} if predefined_responses is None else predefined_responses

    def chat(self, messages, **kwargs):
        """Mock implementation of chat that returns predefined responses or a default response.

        Resolution order:
          1. If ``predefined_responses`` is non-empty, the first key (in dict
             iteration order) that is a substring of ``messages[0]["content"]``
             selects the response.
          2. Still inside that non-empty case, a prompt containing the rerank
             marker text is answered with one "YES" per ``<chunk_`` occurrence.
          3. Otherwise "This is a test answer" is returned.

        Every response reports ``total_tokens=10``. Extra keyword arguments are
        accepted and ignored.
        """
        self.chat_called = True
        self.last_messages = messages

        if self.predefined_responses:
            # Only the first message is inspected for matching.
            message_content = messages[0]["content"] if messages else ""
            for key, response in self.predefined_responses.items():
                if key in message_content:
                    return ChatResponse(content=response, total_tokens=10)

            # Default response for RERANK_PROMPT - treat all chunks as relevant.
            # NOTE(review): this branch is only reachable when predefined_responses
            # is non-empty; a bare MockLLM() answers rerank prompts with the
            # generic fallback below. Confirm that is intended before changing it.
            if "Based on the query questions and the retrieved chunks" in message_content:
                # One "YES" per chunk tag present in the prompt.
                chunk_count = message_content.count("<chunk_")
                return ChatResponse(content=str(["YES"] * chunk_count), total_tokens=10)

        return ChatResponse(content="This is a test answer", total_tokens=10)

    def literal_eval(self, text):
        """Mock implementation of literal_eval.

        If ``text`` (ignoring surrounding whitespace) is bracketed like a list,
        it is parsed with ``ast.literal_eval`` and the parsed value returned.
        When the text is not bracketed, or parsing fails, the default
        ``["test_collection"]`` is returned. Override this in specific tests
        if needed.
        """
        import ast

        stripped = text.strip()
        if stripped.startswith("[") and stripped.endswith("]"):
            try:
                return ast.literal_eval(text)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                # Malformed literal: fall through to the default below. These are
                # the exceptions ast.literal_eval documents; anything else
                # (e.g. KeyboardInterrupt) is no longer swallowed.
                pass

        return ["test_collection"]
|
|
|
|
|
|
class MockEmbedding(BaseEmbedding):
    """Stand-in embedding model for agent tests.

    Every call produces fresh random vectors drawn from NumPy's global random
    generator, so results are not reproducible between calls. The vector
    length is fixed at construction time.
    """

    def __init__(self, dimension=8):
        """Remember how long each produced vector should be."""
        self._dimension = dimension

    @property
    def dimension(self):
        """Length of every vector this mock returns."""
        return self._dimension

    def _random_vector(self):
        # A single draw of uniform floats in [0.0, 1.0), as a plain Python list.
        return np.random.random(self._dimension).tolist()

    def embed_query(self, text):
        """Ignore ``text`` and return one random vector of length ``dimension``."""
        return self._random_vector()

    def embed_documents(self, documents):
        """Return one independent random vector for each item in ``documents``."""
        return [self._random_vector() for _document in documents]
|
|
|
|
|
|
class MockVectorDB(BaseVectorDB):
    """Mock vector database implementation for testing agents.

    Records calls to ``search_data`` and ``insert_data`` (a boolean flag plus
    the arguments of the most recent call) so tests can assert on how an
    agent used the database. All other operations are no-ops returning True.
    """

    def __init__(self, collections=None):
        """
        Initialize the MockVectorDB.

        Args:
            collections: List of collection names to initialize with. When
                omitted (or empty), a single "test_collection" is registered.
        """
        self.default_collection = "test_collection"
        self.search_called = False
        self.insert_called = False
        # Pre-declare the "most recent call" attributes so they read as None
        # until the corresponding method runs, instead of raising AttributeError.
        self.last_search_collection = None
        self.last_search_vector = None
        self.last_search_top_k = None
        self.last_insert_collection = None
        self.last_insert_chunks = None

        if collections:
            self._collections = [
                CollectionInfo(collection_name=name, description=f"Test collection {name}")
                for name in collections
            ]
        else:
            self._collections = [
                CollectionInfo(collection_name="test_collection", description="Test collection for testing")
            ]

    def search_data(self, collection, vector, top_k=10, **kwargs):
        """Record the search arguments and return up to three synthetic results.

        Returns ``min(3, top_k)`` ``RetrievalResult`` objects (none when
        ``top_k`` <= 0). Each result echoes ``vector`` as its embedding and
        carries text, reference and metadata derived from ``collection`` and
        its index. Extra keyword arguments are ignored.
        """
        self.search_called = True
        self.last_search_collection = collection
        self.last_search_vector = vector
        self.last_search_top_k = top_k

        return [
            RetrievalResult(
                embedding=vector,
                text=f"Test result {i} for collection {collection}",
                reference=f"test_reference_{collection}_{i}",
                metadata={"a": i, "wider_text": f"Wider context for test result {i} in collection {collection}"},
            )
            for i in range(min(3, top_k))
        ]

    def insert_data(self, collection, chunks):
        """Record the insert arguments without storing anything; always returns True."""
        self.insert_called = True
        self.last_insert_collection = collection
        self.last_insert_chunks = chunks
        return True

    def init_collection(self, dim, collection, **kwargs):
        """No-op collection creation; always returns True."""
        return True

    def list_collections(self, dim=None):
        """Return the registered collections (``dim`` is ignored).

        The internal list itself is returned, not a copy.
        """
        return self._collections

    def clear_db(self, collection):
        """No-op clear; always returns True."""
        return True
|
|
|
|
|
|
class BaseAgentTest(unittest.TestCase):
    """Common fixture for agent test cases.

    Because ``setUp`` runs before every test method, each test receives
    freshly constructed mock LLM, embedding and vector-DB collaborators.
    """

    def setUp(self):
        """Build new mock collaborators for the upcoming test."""
        # The three mocks are constructed independently of one another.
        self.vector_db = MockVectorDB()
        self.embedding_model = MockEmbedding(dimension=8)
        self.llm = MockLLM()
|