You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

421 lines
17 KiB

import unittest
from unittest.mock import MagicMock, patch
import os
class TestWatsonX(unittest.TestCase):
"""Test cases for WatsonX class."""
def setUp(self):
"""Set up test fixtures."""
pass
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_init_with_env_vars(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test initialization with environment variables."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX()
# Check that Credentials was called with correct parameters
mock_credentials_class.assert_called_once_with(
url='https://test.watsonx.com',
api_key='test-api-key'
)
# Check that ModelInference was called with correct parameters
mock_model_inference_class.assert_called_once_with(
model_id='ibm/granite-3-3-8b-instruct',
credentials=mock_credentials_instance,
project_id='test-project-id'
)
# Check default model
self.assertEqual(llm.model, 'ibm/granite-3-3-8b-instruct')
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com'
})
def test_init_with_space_id(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test initialization with space_id instead of project_id."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX(space_id='test-space-id')
# Check that ModelInference was called with space_id
mock_model_inference_class.assert_called_once_with(
model_id='ibm/granite-3-3-8b-instruct',
credentials=mock_credentials_instance,
space_id='test-space-id'
)
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_init_with_custom_model(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test initialization with custom model."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX(model='ibm/granite-13b-chat-v2')
# Check that ModelInference was called with custom model
mock_model_inference_class.assert_called_once_with(
model_id='ibm/granite-13b-chat-v2',
credentials=mock_credentials_instance,
project_id='test-project-id'
)
self.assertEqual(llm.model, 'ibm/granite-13b-chat-v2')
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_init_with_custom_params(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test initialization with custom generation parameters."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX(
max_new_tokens=500,
temperature=0.7,
top_p=0.9,
top_k=40
)
# Check that generation parameters were set correctly
expected_params = {
'max_new_tokens': 500,
'temperature': 0.7,
'top_p': 0.9,
'top_k': 40
}
self.assertEqual(llm.generation_params, expected_params)
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
def test_init_missing_api_key(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test initialization with missing API key."""
from deepsearcher.llm.watsonx import WatsonX
with patch.dict(os.environ, {}, clear=True):
with self.assertRaises(ValueError) as context:
WatsonX()
self.assertIn("WATSONX_APIKEY", str(context.exception))
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key'
})
def test_init_missing_url(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test initialization with missing URL."""
from deepsearcher.llm.watsonx import WatsonX
with self.assertRaises(ValueError) as context:
WatsonX()
self.assertIn("WATSONX_URL", str(context.exception))
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com'
})
def test_init_missing_project_and_space_id(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test initialization with missing both project_id and space_id."""
from deepsearcher.llm.watsonx import WatsonX
with self.assertRaises(ValueError) as context:
WatsonX()
self.assertIn("WATSONX_PROJECT_ID", str(context.exception))
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_chat_simple_message(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test chat with a simple message."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_model_inference_instance.generate_text.return_value = "This is a test response from WatsonX."
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX()
messages = [
{"role": "user", "content": "Hello, how are you?"}
]
response = llm.chat(messages)
# Check that generate_text was called
mock_model_inference_instance.generate_text.assert_called_once()
call_args = mock_model_inference_instance.generate_text.call_args
# Check the prompt format
expected_prompt = "Human: Hello, how are you?\n\nAssistant:"
self.assertEqual(call_args[1]['prompt'], expected_prompt)
# Check response
self.assertEqual(response.content, "This is a test response from WatsonX.")
self.assertIsInstance(response.total_tokens, int)
self.assertGreater(response.total_tokens, 0)
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_chat_with_system_message(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test chat with system and user messages."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_model_inference_instance.generate_text.return_value = "4"
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is 2+2?"}
]
response = llm.chat(messages)
# Check that generate_text was called
mock_model_inference_instance.generate_text.assert_called_once()
call_args = mock_model_inference_instance.generate_text.call_args
# Check the prompt format
expected_prompt = "System: You are a helpful assistant.\n\nHuman: What is 2+2?\n\nAssistant:"
self.assertEqual(call_args[1]['prompt'], expected_prompt)
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_chat_conversation_history(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test chat with conversation history."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_model_inference_instance.generate_text.return_value = "6"
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "2+2 equals 4."},
{"role": "user", "content": "What about 3+3?"}
]
response = llm.chat(messages)
# Check that generate_text was called
mock_model_inference_instance.generate_text.assert_called_once()
call_args = mock_model_inference_instance.generate_text.call_args
# Check the prompt format includes conversation history
expected_prompt = ("System: You are a helpful assistant.\n\n"
"Human: What is 2+2?\n\n"
"Assistant: 2+2 equals 4.\n\n"
"Human: What about 3+3?\n\n"
"Assistant:")
self.assertEqual(call_args[1]['prompt'], expected_prompt)
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_chat_error_handling(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test error handling in chat method."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_model_inference_instance.generate_text.side_effect = Exception("API Error")
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX()
messages = [{"role": "user", "content": "Hello"}]
# Test that the exception is properly wrapped
with self.assertRaises(RuntimeError) as context:
llm.chat(messages)
self.assertIn("Error generating response with WatsonX", str(context.exception))
@patch('deepsearcher.llm.watsonx.ModelInference')
@patch('deepsearcher.llm.watsonx.Credentials')
@patch('deepsearcher.llm.watsonx.GenTextParamsMetaNames')
@patch.dict('os.environ', {
'WATSONX_APIKEY': 'test-api-key',
'WATSONX_URL': 'https://test.watsonx.com',
'WATSONX_PROJECT_ID': 'test-project-id'
})
def test_messages_to_prompt(self, mock_gen_text_params_class, mock_credentials_class, mock_model_inference_class):
"""Test the _messages_to_prompt method."""
from deepsearcher.llm.watsonx import WatsonX
mock_credentials_instance = MagicMock()
mock_model_inference_instance = MagicMock()
mock_credentials_class.return_value = mock_credentials_instance
mock_model_inference_class.return_value = mock_model_inference_instance
# Mock the GenTextParamsMetaNames attributes
mock_gen_text_params_class.MAX_NEW_TOKENS = 'max_new_tokens'
mock_gen_text_params_class.TEMPERATURE = 'temperature'
mock_gen_text_params_class.TOP_P = 'top_p'
mock_gen_text_params_class.TOP_K = 'top_k'
llm = WatsonX()
messages = [
{"role": "system", "content": "System message"},
{"role": "user", "content": "User message"},
{"role": "assistant", "content": "Assistant message"},
{"role": "user", "content": "Another user message"}
]
prompt = llm._messages_to_prompt(messages)
expected_prompt = ("System: System message\n\n"
"Human: User message\n\n"
"Assistant: Assistant message\n\n"
"Human: Another user message\n\n"
"Assistant:")
self.assertEqual(prompt, expected_prompt)
if __name__ == '__main__':
unittest.main()