You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
421 lines
17 KiB
421 lines
17 KiB
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
import os
|
|
|
|
class TestWatsonX(unittest.TestCase):
    """Test cases for the WatsonX LLM wrapper.

    Every test patches the watsonx SDK entry points
    (``ModelInference``, ``Credentials``, ``GenTextParamsMetaNames``)
    so no network access or real credentials are needed, and imports
    ``WatsonX`` *inside* the test body so the patches are active at
    import time.
    """

    # Environment-variable sets reused by the @patch.dict decorators below.
    # Defined before the methods so they are visible while the class body
    # (and therefore each decorator expression) is being executed.
    _ENV_FULL = {
        'WATSONX_APIKEY': 'test-api-key',
        'WATSONX_URL': 'https://test.watsonx.com',
        'WATSONX_PROJECT_ID': 'test-project-id',
    }
    _ENV_NO_PROJECT = {
        'WATSONX_APIKEY': 'test-api-key',
        'WATSONX_URL': 'https://test.watsonx.com',
    }
    _ENV_KEY_ONLY = {
        'WATSONX_APIKEY': 'test-api-key',
    }

    # Default model id the wrapper is expected to use when none is given.
    _DEFAULT_MODEL = 'ibm/granite-3-3-8b-instruct'

    @staticmethod
    def _configure_mocks(mock_gen_text_params_class, mock_credentials_class,
                         mock_model_inference_class):
        """Wire up the standard mock plumbing shared by most tests.

        Sets the ``return_value`` of the patched ``Credentials`` and
        ``ModelInference`` classes and stubs the four
        ``GenTextParamsMetaNames`` attributes the wrapper reads.

        Returns:
            tuple: ``(mock_credentials_instance, mock_model_inference_instance)``.
        """
        mock_credentials_instance = MagicMock()
        mock_model_inference_instance = MagicMock()

        mock_credentials_class.return_value = mock_credentials_instance
        mock_model_inference_class.return_value = mock_model_inference_instance

        mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
        mock_gen_text_params_class.TEMPERATURE = 'temperature'
        mock_gen_text_params_class.TOP_P = 'top_p'
        mock_gen_text_params_class.TOP_K = 'top_k'

        return mock_credentials_instance, mock_model_inference_instance

    @patch('deepsearcher.llm.watsonx.ModelInference')
    @patch('deepsearcher.llm.watsonx.Credentials')
    @patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
    @patch.dict('os.environ', _ENV_FULL)
    def test_init_with_env_vars(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
        """Test initialization with environment variables."""
        from deepsearcher.llm.watsonx import WatsonX

        mock_credentials_instance, _ = self._configure_mocks(
            mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class)

        llm = WatsonX()

        # Check that Credentials was called with correct parameters
        mock_credentials_class.assert_called_once_with(
            url='https://test.watsonx.com',
            api_key='test-api-key'
        )

        # Check that ModelInference was called with correct parameters
        mock_model_inference_class.assert_called_once_with(
            model_id=self._DEFAULT_MODEL,
            credentials=mock_credentials_instance,
            project_id='test-project-id'
        )

        # Check default model
        self.assertEqual(llm.model, self._DEFAULT_MODEL)

    @patch('deepsearcher.llm.watsonx.ModelInference')
    @patch('deepsearcher.llm.watsonx.Credentials')
    @patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
    @patch.dict('os.environ', _ENV_NO_PROJECT)
    def test_init_with_space_id(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
        """Test initialization with space_id instead of project_id."""
        from deepsearcher.llm.watsonx import WatsonX

        mock_credentials_instance, _ = self._configure_mocks(
            mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class)

        llm = WatsonX(space_id='test-space-id')

        # Check that ModelInference was called with space_id
        mock_model_inference_class.assert_called_once_with(
            model_id=self._DEFAULT_MODEL,
            credentials=mock_credentials_instance,
            space_id='test-space-id'
        )

    @patch('deepsearcher.llm.watsonx.ModelInference')
    @patch('deepsearcher.llm.watsonx.Credentials')
    @patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
    @patch.dict('os.environ', _ENV_FULL)
    def test_init_with_custom_model(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
        """Test initialization with custom model."""
        from deepsearcher.llm.watsonx import WatsonX

        mock_credentials_instance, _ = self._configure_mocks(
            mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class)

        llm = WatsonX(model='ibm/granite-13b-chat-v2')

        # Check that ModelInference was called with custom model
        mock_model_inference_class.assert_called_once_with(
            model_id='ibm/granite-13b-chat-v2',
            credentials=mock_credentials_instance,
            project_id='test-project-id'
        )

        self.assertEqual(llm.model, 'ibm/granite-13b-chat-v2')

    @patch('deepsearcher.llm.watsonx.ModelInference')
    @patch('deepsearcher.llm.watsonx.Credentials')
    @patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
    @patch.dict('os.environ', _ENV_FULL)
    def test_init_with_custom_params(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
        """Test initialization with custom generation parameters."""
        from deepsearcher.llm.watsonx import WatsonX

        self._configure_mocks(
            mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class)

        llm = WatsonX(
            max_new_tokens=500,
            temperature=0.7,
            top_p=0.9,
            top_k=40
        )

        # Check that generation parameters were set correctly
        expected_params = {
            'max_new_tokens': 500,
            'temperature': 0.7,
            'top_p': 0.9,
            'top_k': 40
        }
        self.assertEqual(llm.generation_params, expected_params)

    @patch('deepsearcher.llm.watsonx.ModelInference')
    @patch('deepsearcher.llm.watsonx.Credentials')
    @patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
    def test_init_missing_api_key(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
        """Test initialization with missing API key."""
        from deepsearcher.llm.watsonx import WatsonX

        # Clear the environment entirely so no WATSONX_* variable is set.
        with patch.dict(os.environ, {}, clear=True):
            with self.assertRaises(ValueError) as context:
                WatsonX()

            self.assertIn("WATSONX_APIKEY", str(context.exception))

    @patch('deepsearcher.llm.watsonx.ModelInference')
    @patch('deepsearcher.llm.watsonx.Credentials')
    @patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
    @patch.dict('os.environ', _ENV_KEY_ONLY)
    def test_init_missing_url(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
        """Test initialization with missing URL."""
        from deepsearcher.llm.watsonx import WatsonX

        with self.assertRaises(ValueError) as context:
            WatsonX()

        self.assertIn("WATSONX_URL", str(context.exception))

    @patch('deepsearcher.llm.watsonx.ModelInference')
    @patch('deepsearcher.llm.watsonx.Credentials')
    @patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
    @patch.dict('os.environ', _ENV_NO_PROJECT)
    def test_init_missing_project_and_space_id(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
        """Test initialization with missing both project_id and space_id."""
        from deepsearcher.llm.watsonx import WatsonX

        with self.assertRaises(ValueError) as context:
            WatsonX()

        self.assertIn("WATSONX_PROJECT_ID", str(context.exception))

    @patch('deepsearcher.llm.watsonx.ModelInference')
    @patch('deepsearcher.llm.watsonx.Credentials')
    @patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
    @patch.dict('os.environ', _ENV_FULL)
    def test_chat_simple_message(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
        """Test chat with a simple message."""
        from deepsearcher.llm.watsonx import WatsonX

        _, mock_model_inference_instance = self._configure_mocks(
            mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class)
        mock_model_inference_instance.generate_text.return_value = "This is a test response from WatsonX."

        llm = WatsonX()

        messages = [
            {"role": "user", "content": "Hello, how are you?"}
        ]

        response = llm.chat(messages)

        # Check that generate_text was called
        mock_model_inference_instance.generate_text.assert_called_once()
        call_args = mock_model_inference_instance.generate_text.call_args

        # Check the prompt format
        expected_prompt = "Human: Hello, how are you?\n\nAssistant:"
        self.assertEqual(call_args[1]['prompt'], expected_prompt)

        # Check response
        self.assertEqual(response.content, "This is a test response from WatsonX.")
        self.assertIsInstance(response.total_tokens, int)
        self.assertGreater(response.total_tokens, 0)

    @patch('deepsearcher.llm.watsonx.ModelInference')
    @patch('deepsearcher.llm.watsonx.Credentials')
    @patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
    @patch.dict('os.environ', _ENV_FULL)
    def test_chat_with_system_message(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
        """Test chat with system and user messages."""
        from deepsearcher.llm.watsonx import WatsonX

        _, mock_model_inference_instance = self._configure_mocks(
            mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class)
        mock_model_inference_instance.generate_text.return_value = "4"

        llm = WatsonX()

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is 2+2?"}
        ]

        llm.chat(messages)

        # Check that generate_text was called
        mock_model_inference_instance.generate_text.assert_called_once()
        call_args = mock_model_inference_instance.generate_text.call_args

        # Check the prompt format
        expected_prompt = "System: You are a helpful assistant.\n\nHuman: What is 2+2?\n\nAssistant:"
        self.assertEqual(call_args[1]['prompt'], expected_prompt)

    @patch('deepsearcher.llm.watsonx.ModelInference')
    @patch('deepsearcher.llm.watsonx.Credentials')
    @patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
    @patch.dict('os.environ', _ENV_FULL)
    def test_chat_conversation_history(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
        """Test chat with conversation history."""
        from deepsearcher.llm.watsonx import WatsonX

        _, mock_model_inference_instance = self._configure_mocks(
            mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class)
        mock_model_inference_instance.generate_text.return_value = "6"

        llm = WatsonX()

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is 2+2?"},
            {"role": "assistant", "content": "2+2 equals 4."},
            {"role": "user", "content": "What about 3+3?"}
        ]

        llm.chat(messages)

        # Check that generate_text was called
        mock_model_inference_instance.generate_text.assert_called_once()
        call_args = mock_model_inference_instance.generate_text.call_args

        # Check the prompt format includes conversation history
        expected_prompt = ("System: You are a helpful assistant.\n\n"
                           "Human: What is 2+2?\n\n"
                           "Assistant: 2+2 equals 4.\n\n"
                           "Human: What about 3+3?\n\n"
                           "Assistant:")
        self.assertEqual(call_args[1]['prompt'], expected_prompt)

    @patch('deepsearcher.llm.watsonx.ModelInference')
    @patch('deepsearcher.llm.watsonx.Credentials')
    @patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
    @patch.dict('os.environ', _ENV_FULL)
    def test_chat_error_handling(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
        """Test error handling in chat method."""
        from deepsearcher.llm.watsonx import WatsonX

        _, mock_model_inference_instance = self._configure_mocks(
            mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class)
        mock_model_inference_instance.generate_text.side_effect = Exception("API Error")

        llm = WatsonX()

        messages = [{"role": "user", "content": "Hello"}]

        # Test that the exception is properly wrapped
        with self.assertRaises(RuntimeError) as context:
            llm.chat(messages)

        self.assertIn("Error generating response with WatsonX", str(context.exception))

    @patch('deepsearcher.llm.watsonx.ModelInference')
    @patch('deepsearcher.llm.watsonx.Credentials')
    @patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
    @patch.dict('os.environ', _ENV_FULL)
    def test_messages_to_prompt(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
        """Test the _messages_to_prompt method."""
        from deepsearcher.llm.watsonx import WatsonX

        self._configure_mocks(
            mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class)

        llm = WatsonX()

        messages = [
            {"role": "system", "content": "System message"},
            {"role": "user", "content": "User message"},
            {"role": "assistant", "content": "Assistant message"},
            {"role": "user", "content": "Another user message"}
        ]

        prompt = llm._messages_to_prompt(messages)

        expected_prompt = ("System: System message\n\n"
                           "Human: User message\n\n"
                           "Assistant: Assistant message\n\n"
                           "Human: Another user message\n\n"
                           "Assistant:")

        self.assertEqual(prompt, expected_prompt)
|
|
|
|
|
|
# Allow running this test module directly (e.g. `python test_watsonx.py`)
# instead of through a test runner.
if __name__ == '__main__':
    unittest.main()
|
|
|