diff --git a/app/README/alert_actions.conf.spec b/app/README/alert_actions.conf.spec index f9f0f55..0d63a86 100644 --- a/app/README/alert_actions.conf.spec +++ b/app/README/alert_actions.conf.spec @@ -4,6 +4,7 @@ param._cam = Active response parameters. param.notebook = Notebook param.paramOne = Field Name for Parameter One param.paramTwo = Field Name for Parameter Two +param.account_name = Account Name [launch_notebook] python.version = python3 @@ -12,4 +13,5 @@ param.revision_timestamp = Revision Timestamp. param.notebook_parameters = Notebook Parameters. param.cluster_name = Cluster Name. param.run_name = Run Name. +param.account_name = Account Name. param._cam = Active response parameters. \ No newline at end of file diff --git a/app/README/ta_databricks_account.conf.spec b/app/README/ta_databricks_account.conf.spec index 2eec475..a8ea35e 100644 --- a/app/README/ta_databricks_account.conf.spec +++ b/app/README/ta_databricks_account.conf.spec @@ -1,9 +1,14 @@ [] databricks_instance = +auth_type = aad_client_id = aad_tenant_id = aad_client_secret = aad_access_token = +oauth_client_id = +oauth_client_secret = +oauth_access_token = +oauth_token_expiration = config_for_dbquery = cluster_name = warehouse_id = diff --git a/app/README/ta_databricks_settings.conf.spec b/app/README/ta_databricks_settings.conf.spec index 7d3d6fd..0634a42 100644 --- a/app/README/ta_databricks_settings.conf.spec +++ b/app/README/ta_databricks_settings.conf.spec @@ -14,4 +14,5 @@ loglevel = [additional_parameters] admin_command_timeout = query_result_limit = -index = \ No newline at end of file +index = +thread_count = \ No newline at end of file diff --git a/app/bin/TA_Databricks_rh_account.py b/app/bin/TA_Databricks_rh_account.py index c76932e..b9047ce 100644 --- a/app/bin/TA_Databricks_rh_account.py +++ b/app/bin/TA_Databricks_rh_account.py @@ -129,6 +129,13 @@ 'aad_access_token', required=False, encrypted=True + ), + field.RestField( + 'aad_token_expiration', + required=False, + encrypted=False, + default='', + validator=None ) ] model_databricks_credentials = RestModel(fields, name=None) diff --git a/app/bin/databricks_com.py b/app/bin/databricks_com.py index 12976c1..418111e 100755 --- a/app/bin/databricks_com.py +++ b/app/bin/databricks_com.py @@ -49,6 +49,8 @@ def __init__(self, account_name, session_key): self.aad_client_id = databricks_configs.get("aad_client_id") self.aad_tenant_id = databricks_configs.get("aad_tenant_id") self.aad_client_secret = databricks_configs.get("aad_client_secret") + aad_token_expiration_str = databricks_configs.get("aad_token_expiration") + self.aad_token_expiration = float(aad_token_expiration_str) if aad_token_expiration_str else 0 elif self.auth_type == "OAUTH_M2M": self.databricks_token = databricks_configs.get("oauth_access_token") self.oauth_client_id = databricks_configs.get("oauth_client_id") @@ -104,6 +106,22 @@ def get_requests_retry_session(self): session.mount("https://", adapter) return session + def should_refresh_aad_token(self): + """ + Check if AAD token should be refreshed proactively. + + :return: Boolean - True if token expires within 5 minutes + """ + if not hasattr(self, 'aad_token_expiration') or self.aad_token_expiration == 0: + return False + + import time + current_time = time.time() + time_until_expiry = self.aad_token_expiration - current_time + + # Refresh if token expires within 5 minutes (300 seconds) + return time_until_expiry < 300 + def should_refresh_oauth_token(self): """ Check if OAuth token should be refreshed proactively. @@ -120,6 +138,74 @@ def should_refresh_oauth_token(self): # Refresh if token expires within 5 minutes (300 seconds) return time_until_expiry < 300 + def _is_token_expired_response(self, response): + """ + Check if the API response indicates an expired token. + + Databricks API may return different status codes for expired tokens: + - 403 Forbidden (legacy) + - 401 Unauthorized with "Token is expired" + - 400 Bad Request with "ExpiredJwtException" + + :param response: requests.Response object + :return: Boolean - True if response indicates expired token + """ + if response is None: + return False + + status_code = response.status_code + + # 403 Forbidden - legacy expired token response + if status_code == 403: + return True + + # 401 Unauthorized - check for token expiry message + if status_code == 401: + try: + response_text = response.text.lower() + if "token is expired" in response_text or "expired" in response_text: + return True + except Exception: + pass + return True # Treat all 401s as potentially expired tokens + + # 400 Bad Request - check for ExpiredJwtException + if status_code == 400: + try: + response_text = response.text.lower() + if "expiredjwtexception" in response_text or "expired" in response_text: + return True + except Exception: + pass + + return False + + def _refresh_aad_token(self): + """Refresh AAD access token and update session headers.""" + databricks_configs = utils.get_databricks_configs(self.session_key, self.account_name) + proxy_config = databricks_configs.get("proxy_uri") + + result = utils.get_aad_access_token( + self.session_key, + self.account_name, + self.aad_tenant_id, + self.aad_client_id, + self.aad_client_secret, + proxy_config, + retry=const.RETRIES, + conf_update=True + ) + + if isinstance(result, tuple) and result[1] == False: + raise Exception(result[0]) + + access_token, expires_in = result + self.databricks_token = access_token + import time + self.aad_token_expiration = time.time() + expires_in + self.request_headers["Authorization"] = "Bearer {}".format(self.databricks_token) + self.session.headers.update(self.request_headers) + def _refresh_oauth_token(self): """Refresh OAuth M2M access token and update session headers.""" databricks_configs = utils.get_databricks_configs(self.session_key, self.account_name) @@ -156,8 +242,11 @@ def databricks_api(self, method, endpoint, data=None, args=None): :param args: Arguments to be add into the url :return: response in the form of dictionary """ - # Proactive OAuth token refresh - if self.auth_type == "OAUTH_M2M" and self.should_refresh_oauth_token(): + # Proactive token refresh + if self.auth_type == "AAD" and self.should_refresh_aad_token(): + _LOGGER.info("AAD token expiring soon, refreshing proactively.") + self._refresh_aad_token() + elif self.auth_type == "OAUTH_M2M" and self.should_refresh_oauth_token(): _LOGGER.info("OAuth token expiring soon, refreshing proactively.") self._refresh_oauth_token() @@ -172,32 +261,16 @@ def databricks_api(self, method, endpoint, data=None, args=None): _LOGGER.info("Executing REST call: {} Payload: {}.".format(endpoint, str(data))) response = self.session.post(request_url, params=args, json=data, timeout=self.session.timeout) status_code = response.status_code - if status_code == 403 and self.auth_type == "AAD" and run_again: + # Check if token is expired and needs refresh (handles 403, 401, 400 with expiry messages) + if self._is_token_expired_response(response) and self.auth_type == "AAD" and run_again: + _LOGGER.info("Token expired (status: {}). Refreshing AAD token.".format(status_code)) response = None run_again = False - _LOGGER.info("Refreshing AAD token.") - databricks_configs = utils.get_databricks_configs(self.session_key, self.account_name) - proxy_config = databricks_configs.get("proxy_uri") - db_token = utils.get_aad_access_token( - self.session_key, - self.account_name, - self.aad_tenant_id, - self.aad_client_id, - self.aad_client_secret, - proxy_config, # Using the reinit proxy. As proxy is getting updated on Line no: 43, 45 - retry=const.RETRIES, # based on the condition and for this call we will always need proxy. - conf_update=True, # By passing True, the AAD access token will be updated in conf - ) - if isinstance(db_token, tuple): - raise Exception(db_token[0]) - else: - self.databricks_token = db_token - self.request_headers["Authorization"] = "Bearer {}".format(self.databricks_token) - self.session.headers.update(self.request_headers) - elif status_code == 403 and self.auth_type == "OAUTH_M2M" and run_again: + self._refresh_aad_token() + elif self._is_token_expired_response(response) and self.auth_type == "OAUTH_M2M" and run_again: + _LOGGER.info("Token expired (status: {}). Refreshing OAuth M2M token.".format(status_code)) response = None run_again = False - _LOGGER.info("Refreshing OAuth M2M token.") self._refresh_oauth_token() elif status_code != 200: response.raise_for_status() @@ -215,6 +288,7 @@ def databricks_api(self, method, endpoint, data=None, args=None): if "response" in locals() and response is not None: status_code_messages = { 400: response.json().get("message", "Bad request. The request is malformed."), + 401: "Unauthorized. Access token may be invalid or expired.", 403: "Invalid access token. Please enter the valid access token.", 404: "Invalid API endpoint.", 429: "API limit exceeded. Please try again after some time.", diff --git a/app/bin/databricks_common_utils.py b/app/bin/databricks_common_utils.py index bab12ac..55a4a83 100755 --- a/app/bin/databricks_common_utils.py +++ b/app/bin/databricks_common_utils.py @@ -83,16 +83,25 @@ def get_databricks_configs(session_key, account_name): return configs_dict -def save_databricks_aad_access_token(account_name, session_key, access_token, client_sec): +def save_databricks_aad_access_token(account_name, session_key, access_token, client_sec, expires_in=None): """ - Method to store new AAD access token. + Method to store new AAD access token with expiration timestamp. + :param account_name: Account name + :param session_key: Splunk session key + :param access_token: AAD access token + :param client_sec: AAD client secret + :param expires_in: Token lifetime in seconds (optional, defaults to 3600) :return: None """ + import time + if expires_in is None: + expires_in = 3600 # Default to 1 hour if not provided new_creds = { "name": account_name, "aad_client_secret": client_sec, "aad_access_token": access_token, + "aad_token_expiration": str(time.time() + expires_in), "update_token": True } try: @@ -434,7 +443,14 @@ def get_aad_access_token( Method to acquire a new AAD access token. :param session_key: Splunk session key - :return: access token + :param account_name: Account name for configuration storage + :param aad_tenant_id: Azure AD tenant ID + :param aad_client_id: Azure AD client ID + :param aad_client_secret: Azure AD client secret + :param proxy_settings: Proxy configuration dict + :param retry: Number of retry attempts + :param conf_update: If True, store token in configuration + :return: tuple (access_token, expires_in) or (error_message, False) """ token_url = const.AAD_TOKEN_ENDPOINT.format(aad_tenant_id) headers = { @@ -464,11 +480,12 @@ def get_aad_access_token( resp.raise_for_status() response = resp.json() aad_access_token = response.get("access_token") + expires_in = response.get("expires_in", 3600) # Default to 1 hour if conf_update: save_databricks_aad_access_token( - account_name, session_key, aad_access_token, aad_client_secret + account_name, session_key, aad_access_token, aad_client_secret, expires_in ) - return aad_access_token + return aad_access_token, expires_in except Exception as e: retry -= 1 if "resp" in locals(): diff --git a/app/bin/databricks_const.py b/app/bin/databricks_const.py index f4a058f..d3f740b 100755 --- a/app/bin/databricks_const.py +++ b/app/bin/databricks_const.py @@ -49,11 +49,12 @@ "700016": "Invalid Client ID provided.", "900023": "Invalid Tenant ID provided.", "7000215": "Invalid Client Secret provided.", + "400": "Bad request. The request is malformed.", + "401": "Unauthorized. Access token may be invalid or expired.", "403": "Client secret may have expired. Please configure a valid Client secret.", "404": "Invalid API endpoint.", - "500": "Internal server error.", - "400": "Bad request. The request is malformed.", "429": "API limit exceeded. Please try again after some time.", + "500": "Internal server error.", "invalid_client": "Invalid OAuth Client ID or Client Secret provided.", "unauthorized_client": "Service principal is not authorized for this workspace.", "invalid_grant": "The provided OAuth credentials are invalid or expired.", diff --git a/app/bin/databricks_get_credentials.py b/app/bin/databricks_get_credentials.py index 845b762..b3562c9 100644 --- a/app/bin/databricks_get_credentials.py +++ b/app/bin/databricks_get_credentials.py @@ -53,7 +53,12 @@ def handle(self, in_string): _LOGGER.info("Saving databricks AAD access token.") client_sec = form_data.get("aad_client_secret") access_token = form_data.get("aad_access_token") - new_creds = json.dumps({"aad_client_secret": client_sec, "aad_access_token": access_token}) + token_expiration = form_data.get("aad_token_expiration") + new_creds = json.dumps({ + "aad_client_secret": client_sec, + "aad_access_token": access_token, + "aad_token_expiration": token_expiration + }) success_msg = 'Saved AAD access token successfully.' elif form_data.get("oauth_access_token"): _LOGGER.info("Saving databricks OAuth access token.") @@ -96,6 +101,7 @@ def handle(self, in_string): 'aad_tenant_id': None, 'aad_client_secret': None, 'aad_access_token': None, + 'aad_token_expiration': None, 'oauth_client_id': None, 'oauth_client_secret': None, 'oauth_access_token': None, @@ -156,6 +162,7 @@ def handle(self, in_string): config_dict['aad_tenant_id'] = account_config.get('aad_tenant_id') config_dict['aad_client_secret'] = account_password.get('aad_client_secret') config_dict['aad_access_token'] = account_password.get('aad_access_token') + config_dict['aad_token_expiration'] = account_password.get('aad_token_expiration') elif config_dict['auth_type'] == 'OAUTH_M2M': config_dict['oauth_client_id'] = account_config.get('oauth_client_id') config_dict['oauth_client_secret'] = account_password.get('oauth_client_secret') diff --git a/app/bin/databricks_validators.py b/app/bin/databricks_validators.py index 0035488..c89c99e 100755 --- a/app/bin/databricks_validators.py +++ b/app/bin/databricks_validators.py @@ -46,23 +46,28 @@ def validate_aad(self, data): :param data: Dictionary containing values from configuration UI. :return: Boolean depending on the sucess of the connection """ + import time _LOGGER.info('Obtaining Azure Active Directory access token') aad_client_id = data.get("aad_client_id").strip() client_sec = data.get("aad_client_secret").strip() aad_tenant_id = data.get("aad_tenant_id").strip() account_name = data.get("name") - aad_access_token = utils.get_aad_access_token( + result = utils.get_aad_access_token( self._splunk_session_key, account_name, aad_tenant_id, aad_client_id, client_sec, self._proxy_settings) - if isinstance(aad_access_token, tuple): - _LOGGER.error(aad_access_token[0]) - self.put_msg(aad_access_token[0]) + + if isinstance(result, tuple) and result[1] == False: + _LOGGER.error(result[0]) + self.put_msg(result[0]) return False + + access_token, expires_in = result _LOGGER.info('Obtained Azure Active Directory access token Successfully.') databricks_instance = data.get("databricks_instance").strip("/") - valid_instance = self.validate_db_instance(databricks_instance, aad_access_token) + valid_instance = self.validate_db_instance(databricks_instance, access_token) if valid_instance: - data["aad_access_token"] = aad_access_token + data["aad_access_token"] = access_token + data["aad_token_expiration"] = str(time.time() + expires_in) data["databricks_pat"] = "" return True else: diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..2b83836 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,225 @@ +""" +Shared pytest fixtures and test configuration for Databricks Add-on tests. + +This module provides common mocking utilities and test fixtures to reduce +code duplication across test files. +""" +import pytest +from mock import patch, MagicMock + + +# ============================================================================= +# Module Mock Lists - Predefined sets of modules to mock for different tests +# ============================================================================= + +# Core Splunk modules needed by most tests +CORE_SPLUNK_MODULES = [ + 'log_manager', + 'splunk', + 'splunk.rest', + 'splunk.admin', + 'splunk.clilib', + 'splunk.clilib.cli_common', +] + +# Modules for DatabricksClient and API tests +DATABRICKS_COM_MODULES = CORE_SPLUNK_MODULES + [ + 'solnlib.server_info', +] + +# Modules for validator tests +VALIDATOR_MODULES = CORE_SPLUNK_MODULES + [ + 'solnlib.server_info', + 'splunk_aoblib', + 'splunk_aoblib.rest_migration', +] + +# Modules for command/query tests +COMMAND_MODULES = CORE_SPLUNK_MODULES + [ + 'solnlib.server_info', + 'splunk_aoblib', + 'splunk_aoblib.rest_migration', +] + +# Modules for credentials handler tests +CREDENTIALS_MODULES = [ + 'log_manager', + 'splunk', + 'splunk.persistconn.application', + 'splunk.rest', +] + +# Modules for common utils tests +COMMON_UTILS_MODULES = CORE_SPLUNK_MODULES + [ + 'splunklib.client', + 'splunklib.results', +] + +# Modules for run status tests +RUN_STATUS_MODULES = CORE_SPLUNK_MODULES + [ + 'solnlib', + 'solnlib.server_info', + 'solnlib.utils', + 'solnlib.credentials', +] + +# Modules for alert tests +ALERT_MODULES = [ + 'splunk', + 'splunk.rest', + 'splunk.clilib', + 'solnlib.server_info', + 'splunk_aoblib', + 'splunk_aoblib.rest_migration', + 'solnlib.splunkenv', + 'splunklib', +] + + +# ============================================================================= +# Module Mocking Functions +# ============================================================================= + +def create_module_mocks(modules_to_mock, special_handlers=None): + """ + Create MagicMock objects for a list of modules and patch sys.modules. + + Args: + modules_to_mock: List of module names to mock + special_handlers: Dict mapping module names to special setup functions + e.g., {'splunk.persistconn.application': setup_persistconn} + + Returns: + Dict of mocked modules + """ + mocked_modules = {module: MagicMock() for module in modules_to_mock} + + # Apply special handlers if provided + if special_handlers: + for module_name, handler in special_handlers.items(): + if module_name in mocked_modules: + handler(mocked_modules[module_name]) + + # Patch sys.modules + for module, magicmock in mocked_modules.items(): + patch.dict('sys.modules', **{module: magicmock}).start() + + return mocked_modules + + +def setup_persistconn_mock(mock_module): + """Setup special mock for splunk.persistconn.application.""" + mock_module.PersistentServerConnectionApplication = object + + +def teardown_module_mocks(): + """Stop all patches - call in tearDownModule.""" + patch.stopall() + + +# ============================================================================= +# Common Test Configurations +# ============================================================================= + +def get_pat_config(instance="123", token="token", proxy_uri=None): + """Get a PAT authentication configuration dict.""" + return { + "databricks_instance": instance, + "auth_type": "PAT", + "databricks_pat": token, + "proxy_uri": proxy_uri, + } + + +def get_aad_config( + instance="123", + token="token", + client_id="client_id", + tenant_id="tenant_id", + client_secret="client_secret", + token_expiration="9999999999.0", + proxy_uri=None +): + """Get an AAD authentication configuration dict.""" + return { + "databricks_instance": instance, + "auth_type": "AAD", + "aad_access_token": token, + "aad_client_id": client_id, + "aad_tenant_id": tenant_id, + "aad_client_secret": client_secret, + "aad_token_expiration": token_expiration, + "proxy_uri": proxy_uri, + } + + +def get_oauth_config( + instance="123", + token="token", + client_id="client_id", + client_secret="client_secret", + token_expiration="9999999999.0", + proxy_uri=None +): + """Get an OAuth M2M authentication configuration dict.""" + return { + "databricks_instance": instance, + "auth_type": "OAUTH_M2M", + "oauth_access_token": token, + "oauth_client_id": client_id, + "oauth_client_secret": client_secret, + "oauth_token_expiration": token_expiration, + "proxy_uri": proxy_uri, + } + + +def get_proxy_config( + http="http://proxy:8080", + https=None, + use_for_oauth="0" +): + """Get a proxy configuration dict.""" + config = {"http": http, "use_for_oauth": use_for_oauth} + if https: + config["https"] = https + return config + + +# ============================================================================= +# Common Test Data +# ============================================================================= + +CLUSTER_LIST = { + "clusters": [ + {"cluster_name": "test1", "cluster_id": "123", "state": "running"}, + {"cluster_name": "test2", "cluster_id": "345", "state": "pending"}, + ] +} + + +# ============================================================================= +# Pytest Fixtures (optional - for pytest-native tests) +# ============================================================================= + +@pytest.fixture +def mock_logger(): + """Fixture to provide a mocked logger.""" + return MagicMock() + + +@pytest.fixture +def pat_config(): + """Fixture for PAT configuration.""" + return get_pat_config() + + +@pytest.fixture +def aad_config(): + """Fixture for AAD configuration.""" + return get_aad_config() + + +@pytest.fixture +def oauth_config(): + """Fixture for OAuth M2M configuration.""" + return get_oauth_config() diff --git a/tests/test_databricks_com.py b/tests/test_databricks_com.py index 8e983c8..140a1d0 100644 --- a/tests/test_databricks_com.py +++ b/tests/test_databricks_com.py @@ -1,60 +1,56 @@ import declare -import os -import sys import unittest -import json from utility import Response from importlib import import_module from mock import patch, MagicMock -CLUSTER_LIST = {"clusters": [{"cluster_name": "test1", "cluster_id": "123","state":"running"}, {"cluster_name": "test2", "cluster_id": "345","state":"pending"}]} +# Import shared test utilities +from conftest import ( + create_module_mocks, + teardown_module_mocks, + DATABRICKS_COM_MODULES, + CLUSTER_LIST, + get_pat_config, + get_aad_config, + get_oauth_config, +) -mocked_modules = {} -def setUpModule(): - global mocked_modules - module_to_be_mocked = [ - 'log_manager', - 'splunk', - 'splunk.rest', - 'splunk.admin', - 'splunk.clilib', - 'splunk.clilib.cli_common', - 'solnlib.server_info', - ] +mocked_modules = {} - mocked_modules = {module: MagicMock() for module in module_to_be_mocked} - for module, magicmock in mocked_modules.items(): - patch.dict('sys.modules', **{module: magicmock}).start() +def setUpModule(): + global mocked_modules + mocked_modules = create_module_mocks(DATABRICKS_COM_MODULES) def tearDownModule(): - patch.stopall() + teardown_module_mocks() class TestDatabricksUtils(unittest.TestCase): """Test Databricks utils.""" + @patch("solnlib.server_info", return_value=MagicMock()) @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) @patch("databricks_com.utils.get_databricks_configs", autospec=True) def test_get_object(self, mock_conf, mock_session, mock_version): db_com = import_module('databricks_com') db_com._LOGGER = MagicMock() - mock_conf.return_value = {"databricks_instance" : "123", "auth_type" : "PAT", "databricks_pat" : "token", "proxy_uri" : {"use_for_oauth": '0', 'http':'uri'}} + mock_conf.return_value = get_pat_config(proxy_uri={"use_for_oauth": '0', 'http': 'uri'}) obj = db_com.DatabricksClient("account_name", "session_key") - self.assertIsInstance(obj,db_com.DatabricksClient) + self.assertIsInstance(obj, db_com.DatabricksClient) db_com._LOGGER.info.assert_called_with("Proxy is configured. Using proxy to execute the request.") - + @patch("solnlib.server_info", return_value=MagicMock()) @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) @patch("databricks_com.utils.get_databricks_configs", autospec=True) def test_skipping_proxy(self, mock_conf, mock_session, mock_version): db_com = import_module('databricks_com') db_com._LOGGER = MagicMock() - mock_conf.return_value = {"databricks_instance" : "123", "auth_type" : "PAT", "databricks_pat" : "token", "proxy_uri" : {"use_for_oauth": '1', 'http':'uri'}} + mock_conf.return_value = get_pat_config(proxy_uri={"use_for_oauth": '1', 'http': 'uri'}) obj = db_com.DatabricksClient("account_name", "session_key") - self.assertIsInstance(obj,db_com.DatabricksClient) + self.assertIsInstance(obj, db_com.DatabricksClient) db_com._LOGGER.info.assert_called_with("Skipping the usage of proxy for running query as 'Use Proxy for OAuth' parameter is checked.") @patch("solnlib.server_info", return_value=MagicMock()) @@ -62,7 +58,7 @@ def test_skipping_proxy(self, mock_conf, mock_session, mock_version): @patch("databricks_com.utils.get_databricks_configs", autospec=True) def test_get_object_error(self, mock_conf, mock_session, mock_version): db_com = import_module('databricks_com') - mock_conf.return_value = {"databricks_instance" : "123", "auth_type" : "PAT", "databricks_pat" : None, "proxy_uri" : None} + mock_conf.return_value = get_pat_config(token=None) with self.assertRaises(Exception) as context: obj = db_com.DatabricksClient("account_name", "session_key") self.assertEqual( @@ -157,14 +153,24 @@ def test_get_api_response_429(self, mock_conf, mock_session, mock_version): "API limit exceeded. Please try again after some time.", str(context.exception)) - @patch("databricks_com.utils.get_aad_access_token", return_value="new_access_token") + @patch("databricks_com.utils.get_aad_access_token", return_value=("new_access_token", 3600)) @patch("solnlib.server_info", return_value=MagicMock()) @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) @patch("databricks_com.utils.get_databricks_configs", autospec=True) @patch("databricks_com.utils.get_proxy_uri", return_value=None) def test_get_api_response_refresh_token(self, mock_proxy, mock_conf, mock_session, mock_version, mock_refresh): + """Test AAD token refresh on 403 response.""" db_com = import_module('databricks_com') - mock_conf.return_value = {"databricks_instance" : "123", "auth_type" : "AAD", "aad_access_token" : "token", "proxy_uri" : None} + mock_conf.return_value = { + "databricks_instance": "123", + "auth_type": "AAD", + "aad_access_token": "token", + "aad_client_id": "client_id", + "aad_tenant_id": "tenant_id", + "aad_client_secret": "client_secret", + "aad_token_expiration": "9999999999.0", + "proxy_uri": None + } obj = db_com.DatabricksClient("account_name", "session_key") obj.session.post.side_effect = [Response(403), Response(200)] resp = obj.databricks_api("post", "endpoint", args="123", data={"p1": "v1"}) @@ -172,14 +178,24 @@ def test_get_api_response_refresh_token(self, mock_proxy, mock_conf, mock_sessio self.assertEqual(resp, {"status_code": 200}) - @patch("databricks_com.utils.get_aad_access_token", return_value="new_token") + @patch("databricks_com.utils.get_aad_access_token", return_value=("new_token", 3600)) @patch("solnlib.server_info", return_value=MagicMock()) @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) @patch("databricks_com.utils.get_databricks_configs", autospec=True) @patch("databricks_com.utils.get_proxy_uri", return_value=None) def test_get_api_response_refresh_token_error(self, mock_proxy, mock_conf, mock_session, mock_version, mock_refresh): + """Test AAD token refresh when second request still fails.""" db_com = import_module('databricks_com') - mock_conf.return_value = {"databricks_instance" : "123", "auth_type" : "AAD" , "aad_access_token" : "token", "proxy_uri" : None} + mock_conf.return_value = { + "databricks_instance": "123", + "auth_type": "AAD", + "aad_access_token": "token", + "aad_client_id": "client_id", + "aad_tenant_id": "tenant_id", + "aad_client_secret": "client_secret", + "aad_token_expiration": "9999999999.0", + "proxy_uri": None + } obj = db_com.DatabricksClient("account_name", "session_key") obj.session.post.side_effect = [Response(403), Response(403)] with self.assertRaises(Exception) as context: @@ -508,6 +524,251 @@ def test_external_api_post(self, mock_conf, mock_session, mock_version): self.assertEqual(obj.external_session.post.call_count, 1) self.assertEqual(resp, {"result": "success"}) + # ========================================================================= + # AAD Token Expiration Tests + # ========================================================================= + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + def test_get_object_aad_with_expiration(self, mock_conf, mock_session, mock_version): + """Test DatabricksClient initialization with AAD auth and token expiration.""" + db_com = import_module('databricks_com') + db_com._LOGGER = MagicMock() + mock_conf.return_value = { + "databricks_instance": "123", + "auth_type": "AAD", + "aad_access_token": "aad_token", + "aad_client_id": "client_id", + "aad_tenant_id": "tenant_id", + "aad_client_secret": "client_secret", + "aad_token_expiration": "9999999999.0", + "proxy_uri": None + } + obj = db_com.DatabricksClient("account_name", "session_key") + + self.assertIsInstance(obj, db_com.DatabricksClient) + self.assertEqual(obj.auth_type, "AAD") + self.assertEqual(obj.databricks_token, "aad_token") + self.assertEqual(obj.aad_client_id, "client_id") + self.assertEqual(obj.aad_tenant_id, "tenant_id") + self.assertEqual(obj.aad_token_expiration, 9999999999.0) + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + @patch("time.time", return_value=1000000.0) + def test_should_refresh_aad_token_within_threshold(self, mock_time, mock_conf, mock_session, mock_version): + """Test should_refresh_aad_token returns True when token expires within 5 minutes.""" + db_com = import_module('databricks_com') + mock_conf.return_value = { + "databricks_instance": "123", + "auth_type": "AAD", + "aad_access_token": "token", + "aad_client_id": "client_id", + "aad_tenant_id": "tenant_id", + "aad_client_secret": "client_secret", + "aad_token_expiration": "1000200.0", # 200 seconds from now (< 5 min) + "proxy_uri": None + } + obj = db_com.DatabricksClient("account_name", "session_key") + + self.assertTrue(obj.should_refresh_aad_token()) + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + @patch("time.time", return_value=1000000.0) + def test_should_refresh_aad_token_outside_threshold(self, mock_time, mock_conf, mock_session, mock_version): + """Test should_refresh_aad_token returns False when token has > 5 minutes validity.""" + db_com = import_module('databricks_com') + mock_conf.return_value = { + "databricks_instance": "123", + "auth_type": "AAD", + "aad_access_token": "token", + "aad_client_id": "client_id", + "aad_tenant_id": "tenant_id", + "aad_client_secret": "client_secret", + "aad_token_expiration": "1003700.0", # 3700 seconds from now (> 5 min) + "proxy_uri": None + } + obj = db_com.DatabricksClient("account_name", "session_key") + + self.assertFalse(obj.should_refresh_aad_token()) + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + def test_should_refresh_aad_token_no_expiration(self, mock_conf, mock_session, mock_version): + """Test should_refresh_aad_token returns False when no expiration is set.""" + db_com = import_module('databricks_com') + mock_conf.return_value = { + "databricks_instance": "123", + "auth_type": "AAD", + "aad_access_token": "token", + "aad_client_id": "client_id", + "aad_tenant_id": "tenant_id", + "aad_client_secret": "client_secret", + "aad_token_expiration": None, # No expiration set + "proxy_uri": None + } + obj = db_com.DatabricksClient("account_name", "session_key") + + # Should return False when expiration is 0 (not set) + self.assertFalse(obj.should_refresh_aad_token()) + + @patch("databricks_com.utils.get_aad_access_token", return_value=("new_aad_token", 3600)) + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + @patch("databricks_com.utils.get_proxy_uri", return_value=None) + @patch("time.time", return_value=1000000.0) + def test_proactive_aad_token_refresh(self, mock_time, mock_proxy, mock_conf, mock_session, mock_version, mock_refresh): + """Test proactive AAD token refresh when token is about to expire.""" + db_com = import_module('databricks_com') + # Token expires in 4 minutes (240 seconds) - should trigger refresh + mock_conf.return_value = { + "databricks_instance": "123", + "auth_type": "AAD", + "aad_access_token": "token", + "aad_client_id": "client_id", + "aad_tenant_id": "tenant_id", + "aad_client_secret": "client_secret", + "aad_token_expiration": "1000240.0", # 240 seconds from now + "proxy_uri": None + } + obj = db_com.DatabricksClient("account_name", "session_key") + obj.session.post.return_value = Response(200) + + resp = obj.databricks_api("post", "endpoint", args="123", data={"p1": "v1"}) + + # Token should have been refreshed proactively + mock_refresh.assert_called_once() + self.assertEqual(resp, {"status_code": 200}) + + # ========================================================================= + # Token Expired Response Detection Tests (_is_token_expired_response) + # ========================================================================= + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + def test_is_token_expired_response_403(self, mock_conf, mock_session, mock_version): + """Test _is_token_expired_response returns True for 403 status.""" + db_com = import_module('databricks_com') + mock_conf.return_value = {"databricks_instance": "123", "auth_type": "PAT", "databricks_pat": "token", "proxy_uri": None} + obj = db_com.DatabricksClient("account_name", "session_key") + + response = Response(403) + self.assertTrue(obj._is_token_expired_response(response)) + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + def test_is_token_expired_response_401(self, mock_conf, mock_session, mock_version): + """Test _is_token_expired_response returns True for 401 status.""" + db_com = import_module('databricks_com') + mock_conf.return_value = {"databricks_instance": "123", "auth_type": "PAT", "databricks_pat": "token", "proxy_uri": None} + obj = db_com.DatabricksClient("account_name", "session_key") + + response = Response(401) + self.assertTrue(obj._is_token_expired_response(response)) + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + def test_is_token_expired_response_200(self, mock_conf, mock_session, mock_version): + """Test _is_token_expired_response returns False for 200 status.""" + db_com = import_module('databricks_com') + mock_conf.return_value = {"databricks_instance": "123", "auth_type": "PAT", "databricks_pat": "token", "proxy_uri": None} + obj = db_com.DatabricksClient("account_name", "session_key") + + response = Response(200) + self.assertFalse(obj._is_token_expired_response(response)) + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + def test_is_token_expired_response_none(self, mock_conf, mock_session, mock_version): + """Test _is_token_expired_response returns False for None response.""" + db_com = import_module('databricks_com') + mock_conf.return_value = {"databricks_instance": "123", "auth_type": "PAT", "databricks_pat": "token", "proxy_uri": None} + obj = db_com.DatabricksClient("account_name", "session_key") + + self.assertFalse(obj._is_token_expired_response(None)) + + # ========================================================================= + # 401 Status Code Handling Tests + # ========================================================================= + + @patch("databricks_com.utils.get_aad_access_token", return_value=("new_token", 3600)) + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + @patch("databricks_com.utils.get_proxy_uri", return_value=None) + def test_get_api_response_401_triggers_refresh(self, mock_proxy, mock_conf, mock_session, mock_version, mock_refresh): + """Test 401 response triggers AAD token refresh.""" + db_com = import_module('databricks_com') + mock_conf.return_value = { + "databricks_instance": "123", + "auth_type": "AAD", + "aad_access_token": "token", + "aad_client_id": "client_id", + "aad_tenant_id": "tenant_id", + "aad_client_secret": "client_secret", + "aad_token_expiration": "9999999999.0", + "proxy_uri": None + } + obj = db_com.DatabricksClient("account_name", "session_key") + obj.session.post.side_effect = [Response(401), Response(200)] + + resp = obj.databricks_api("post", "endpoint", args="123", data={"p1": "v1"}) + + self.assertEqual(obj.session.post.call_count, 2) + self.assertEqual(resp, {"status_code": 200}) + mock_refresh.assert_called_once() + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + def test_get_api_response_401_error_message(self, mock_conf, mock_session, mock_version): + """Test 401 response returns correct error message for PAT auth.""" + db_com = import_module('databricks_com') + mock_conf.return_value = {"databricks_instance": "123", "auth_type": "PAT", "databricks_pat": "token", "proxy_uri": None} + obj = db_com.DatabricksClient("account_name", "session_key") + obj.session.get.return_value = Response(401, {}) + + with self.assertRaises(Exception) as context: + obj.databricks_api("get", "endpoint") + + self.assertEqual("Unauthorized. Access token may be invalid or expired.", str(context.exception)) + + @patch("databricks_com.utils.get_aad_access_token", return_value=("Token refresh failed", False)) + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + @patch("databricks_com.utils.get_proxy_uri", return_value=None) + def test_get_api_response_aad_refresh_failure(self, mock_proxy, mock_conf, mock_session, mock_version, mock_refresh): + """Test AAD token refresh failure returns proper error.""" + db_com = import_module('databricks_com') + mock_conf.return_value = { + "databricks_instance": "123", + "auth_type": "AAD", + "aad_access_token": "token", + "aad_client_id": "client_id", + "aad_tenant_id": "tenant_id", + "aad_client_secret": "client_secret", + "aad_token_expiration": "9999999999.0", + "proxy_uri": None + } + obj = db_com.DatabricksClient("account_name", "session_key") + obj.session.post.side_effect = [Response(403)] + + with self.assertRaises(Exception) as context: + obj.databricks_api("post", "endpoint", args="123", data={"p1": "v1"}) + + self.assertEqual("Token refresh failed", str(context.exception)) + diff --git a/tests/test_databricks_common_utils.py b/tests/test_databricks_common_utils.py index 659fc21..869469f 100644 --- a/tests/test_databricks_common_utils.py +++ b/tests/test_databricks_common_utils.py @@ -1,36 +1,28 @@ import declare -import os -import sys import unittest import json from utility import Response - from importlib import import_module from mock import patch, MagicMock -mocked_modules = {} -def setUpModule(): - global mocked_modules +# Import shared test utilities +from conftest import ( + create_module_mocks, + teardown_module_mocks, + COMMON_UTILS_MODULES, +) - module_to_be_mocked = [ - 'log_manager', - 'splunk', - 'splunk.rest', - 'splunk.admin', - 'splunk.clilib', - 'splunk.clilib.cli_common', - 'splunklib.client', - 'splunklib.results' - ] - mocked_modules = {module: MagicMock() for module in module_to_be_mocked} +mocked_modules = {} + - for module, magicmock in mocked_modules.items(): - patch.dict('sys.modules', **{module: magicmock}).start() +def setUpModule(): + global mocked_modules + mocked_modules = create_module_mocks(COMMON_UTILS_MODULES) def tearDownModule(): - patch.stopall() + teardown_module_mocks() class TestDatabricksUtils(unittest.TestCase): """Test Databricks utils.""" @@ -144,14 +136,16 @@ def test_format_to_json_parameters_exception(self): @patch("databricks_common_utils.save_databricks_aad_access_token") @patch("databricks_common_utils.requests.post") def test_get_aad_access_token_200(self, mock_post, mock_save, mock_conf, mock_proxy): + """Test successful AAD token acquisition returns tuple (token, expires_in).""" db_utils = import_module('databricks_common_utils') mock_save.return_value = MagicMock() mock_conf.return_value = MagicMock() mock_proxy.return_value = MagicMock() - mock_post.return_value.json.return_value = {"access_token": "123"} + mock_post.return_value.json.return_value = {"access_token": "123", "expires_in": 3600} mock_post.return_value.status_code = 200 return_val = db_utils.get_aad_access_token("session_key", "user_agent", "account_name", "aad_client_id", "aad_client_secret") - self.assertEqual(return_val, "123") + # Now returns tuple (access_token, expires_in) + self.assertEqual(return_val, ("123", 3600)) @patch("databricks_common_utils.get_proxy_uri") @@ -528,23 +522,24 @@ def test_get_databricks_configs_exception(self, mock_request): @patch("databricks_common_utils.get_current_user") @patch("databricks_common_utils.requests.post") def test_get_aad_access_token_with_conf_update(self, mock_post, mock_user): - """Test AAD token acquisition with configuration update.""" + """Test AAD token acquisition with configuration update returns tuple.""" db_utils = import_module('databricks_common_utils') db_utils._LOGGER = MagicMock() mock_user.return_value = "test_user" mock_response = MagicMock() - mock_response.json.return_value = {"access_token": "aad_token_123"} + mock_response.json.return_value = {"access_token": "aad_token_123", "expires_in": 3600} mock_response.status_code = 200 mock_post.return_value = mock_response with patch("databricks_common_utils.save_databricks_aad_access_token") as mock_save: - token = db_utils.get_aad_access_token( + result = db_utils.get_aad_access_token( "session_key", "account_name", "tenant_id", "client_id", "client_secret", conf_update=True ) - self.assertEqual(token, "aad_token_123") + # Now returns tuple (access_token, expires_in) + self.assertEqual(result, ("aad_token_123", 3600)) mock_save.assert_called_once() @patch("databricks_common_utils.get_current_user") @@ -556,13 +551,13 @@ def test_get_aad_access_token_with_proxy_settings(self, mock_post, mock_user): mock_user.return_value = "test_user" mock_response = MagicMock() - mock_response.json.return_value = {"access_token": "aad_token_123"} + mock_response.json.return_value = {"access_token": "aad_token_123", "expires_in": 3600} mock_response.status_code = 200 mock_post.return_value = mock_response proxy_settings = {"http": "http://proxy:8080", "https": "http://proxy:8080", "use_for_oauth": "1"} - token = db_utils.get_aad_access_token( + result = db_utils.get_aad_access_token( "session_key", "account_name", "tenant_id", "client_id", "client_secret", proxy_settings=proxy_settings ) @@ -570,6 +565,8 @@ def test_get_aad_access_token_with_proxy_settings(self, mock_post, mock_user): # Verify use_for_oauth key is removed from proxy_settings call_args = mock_post.call_args self.assertNotIn("use_for_oauth", call_args[1]['proxies']) + # Now returns tuple (access_token, expires_in) + self.assertEqual(result, ("aad_token_123", 3600)) @patch("databricks_common_utils.get_current_user") @patch("databricks_common_utils.requests.post") diff --git a/tests/test_databricks_get_credentials.py b/tests/test_databricks_get_credentials.py index 48d3af2..6edb37f 100644 --- a/tests/test_databricks_get_credentials.py +++ b/tests/test_databricks_get_credentials.py @@ -1,38 +1,32 @@ import declare -import os -import sys import unittest -from utility import Response import json -import traceback -import base64 +from utility import Response from importlib import import_module from mock import patch, MagicMock +# Import shared test utilities +from conftest import ( + create_module_mocks, + teardown_module_mocks, + setup_persistconn_mock, + CREDENTIALS_MODULES, +) + mocked_modules = {} def setUpModule(): global mocked_modules - - module_to_be_mocked = [ - "log_manager", - "splunk", - "splunk.persistconn.application", - "splunk.rest", - ] - - mocked_modules = {module: MagicMock() for module in module_to_be_mocked} - - for module, magicmock in mocked_modules.items(): - if module == "splunk.persistconn.application": - magicmock.PersistentServerConnectionApplication = object - patch.dict("sys.modules", **{module: magicmock}).start() + special_handlers = { + "splunk.persistconn.application": setup_persistconn_mock, + } + mocked_modules = create_module_mocks(CREDENTIALS_MODULES, special_handlers) def tearDownModule(): - patch.stopall() + teardown_module_mocks() class TestDatabricksGetCredentials(unittest.TestCase): @@ -44,6 +38,7 @@ def test_get_credentials_object(self): self.assertIsInstance(obj1, db_cm.DatabricksGetCredentials) def test_handle_save_access_token_success(self): + """Test successful saving of AAD access token with expiration.""" db_cm = import_module("databricks_get_credentials") input_string = json.dumps({ "system_authtoken": "dummy_token", @@ -51,11 +46,12 @@ def test_handle_save_access_token_success(self): "name": "test", "update_token": "1", "aad_client_secret": "client_secret", - "aad_access_token": "access_token" + "aad_access_token": "access_token", + "aad_token_expiration": "1234567890.0" } }) - # mock the CredentialManager to raise an exception when set_password is called + # mock the CredentialManager credential_manager_mock = MagicMock() credential_manager_mock.set_password.return_value = "Saved AAD access token successfully." db_cm.CredentialManager = MagicMock(return_value=credential_manager_mock) @@ -65,8 +61,12 @@ def test_handle_save_access_token_success(self): assert result["payload"] == "Saved AAD access token successfully." assert result["status"] == 200 + + # Verify the credential manager was called with correct parameters including expiration + credential_manager_mock.set_password.assert_called_once() def test_handle_save_access_token_failure(self): + """Test failure when saving AAD access token.""" db_cm = import_module("databricks_get_credentials") input_string = json.dumps({ "system_authtoken": "dummy_token", @@ -74,7 +74,8 @@ def test_handle_save_access_token_failure(self): "name": "test", "update_token": "1", "aad_client_secret": "client_secret", - "aad_access_token": "access_token" + "aad_access_token": "access_token", + "aad_token_expiration": "1234567890.0" } }) @@ -335,6 +336,85 @@ def mock_request_side_effect(*args, **kwargs): assert payload["auth_type"] == "PAT" assert payload["databricks_pat"] == "pat_token_value" + @patch("databricks_get_credentials.rest.simpleRequest") + @patch("databricks_get_credentials.CredentialManager") + def test_handle_retrieve_aad_config(self, mock_cred_manager, mock_request): + """Test retrieving AAD configuration with token expiration.""" + db_cm = import_module("databricks_get_credentials") + db_cm._LOGGER = MagicMock() + + # Mock account manager + account_manager_mock = MagicMock() + account_manager_mock.get_password.return_value = json.dumps({ + "aad_client_secret": "aad_secret", + "aad_access_token": "aad_token", + "aad_token_expiration": "1234567890.0" + }) + + # Mock proxy manager + proxy_manager_mock = MagicMock() + proxy_manager_mock.get_password.return_value = json.dumps({}) + + mock_cred_manager.side_effect = [account_manager_mock, proxy_manager_mock] + + input_string = json.dumps({ + "system_authtoken": "dummy_token", + "form": { + "name": "test_account" + } + }) + + # Mock REST calls + def mock_request_side_effect(*args, **kwargs): + if "ta_databricks_account" in args[0]: + return (200, json.dumps({ + "entry": [{ + "content": { + "auth_type": "AAD", + "databricks_instance": "test.databricks.azure.net", + "aad_client_id": "aad_client_id", + "aad_tenant_id": "aad_tenant_id", + "config_for_dbquery": "warehouse", + "warehouse_id": "warehouse123", + "cluster_name": None + } + }] + })) + elif "proxy" in args[0]: + return (200, json.dumps({ + "entry": [{ + "content": { + "proxy_enabled": "0", + "proxy_password": "" + } + }] + })) + elif "additional_parameters" in args[0]: + return (200, json.dumps({ + "entry": [{ + "content": { + "admin_command_timeout": "300", + "query_result_limit": "1000", + "index": "main", + "thread_count": "4" + } + }] + })) + + mock_request.side_effect = mock_request_side_effect + + obj1 = db_cm.DatabricksGetCredentials("command_line", "command_args") + result = obj1.handle(input_string) + + assert result["status"] == 200 + payload = result["payload"] + assert payload["auth_type"] == "AAD" + assert payload["aad_client_id"] == "aad_client_id" + assert payload["aad_tenant_id"] == "aad_tenant_id" + assert payload["aad_access_token"] == "aad_token" + assert payload["aad_token_expiration"] == "1234567890.0" + assert payload["aad_client_secret"] == "aad_secret" + @patch("databricks_get_credentials.rest.simpleRequest") def test_handle_retrieve_config_error(self, mock_request): """Test error handling when retrieving configuration fails.""" diff --git a/tests/test_databricks_validators.py b/tests/test_databricks_validators.py index 0073870..53e84a7 100644 --- a/tests/test_databricks_validators.py +++ b/tests/test_databricks_validators.py @@ -1,39 +1,28 @@ import declare -import os -import sys import unittest -import json from utility import Response from importlib import import_module from mock import patch, MagicMock +# Import shared test utilities +from conftest import ( + create_module_mocks, + teardown_module_mocks, + VALIDATOR_MODULES, +) mocked_modules = {} -def setUpModule(): - global mocked_modules - - module_to_be_mocked = [ - 'log_manager', - 'splunk', - 'splunk.rest', - 'splunk.admin', - 'splunk.clilib', - 'splunk.clilib.cli_common', - 'solnlib.server_info', - 'splunk_aoblib', - 'splunk_aoblib.rest_migration' - ] - mocked_modules = {module: MagicMock() for module in module_to_be_mocked} - for module, magicmock in mocked_modules.items(): - patch.dict('sys.modules', **{module: magicmock}).start() +def setUpModule(): + global mocked_modules + mocked_modules = create_module_mocks(VALIDATOR_MODULES) def tearDownModule(): - patch.stopall() + teardown_module_mocks() class TestDatabricksUtils(unittest.TestCase): """Test Databricks Validators.""" @@ -128,10 +117,11 @@ def test_validate_pat_function(self, mock_valid_inst, mock_validator): db_val_obj.validate_pat({"auth_type": "PAT", "databricks_pat": "pat_token", "databricks_instance": "db_instance"}) mock_valid_inst.assert_called_once_with(db_val_obj, "db_instance", "pat_token") - @patch("databricks_validators.utils.get_aad_access_token", return_value="access_token") + @patch("databricks_validators.utils.get_aad_access_token", return_value=("access_token", 3600)) @patch("databricks_validators.Validator", autospec=True) @patch("databricks_validators.ValidateDatabricksInstance.validate_db_instance", autospec=True) def test_validate_aad_function(self, mock_valid_inst, mock_validator, mock_access): + """Test successful AAD validation flow with token expiration.""" db_val = import_module('databricks_validators') db_val._LOGGER = MagicMock() db_val_obj = db_val.ValidateDatabricksInstance() @@ -139,8 +129,24 @@ def test_validate_aad_function(self, mock_valid_inst, mock_validator, mock_acces db_val_obj._splunk_version = "splunk_version" db_val_obj._proxy_settings = {} mock_valid_inst.return_value = True - db_val_obj.validate_aad({"auth_type": "AAD", "aad_client_id": "cl_id", "aad_tenant_id": "tenant_id", "aad_client_secret": "client_secret", "databricks_instance": "db_instance"}) + + data = { + "auth_type": "AAD", + "aad_client_id": "cl_id", + "aad_tenant_id": "tenant_id", + "aad_client_secret": "client_secret", + "databricks_instance": "db_instance", + "name": "test_account" + } + result = db_val_obj.validate_aad(data) + mock_valid_inst.assert_called_once_with(db_val_obj, "db_instance", "access_token") + self.assertTrue(result) + self.assertEqual(data["aad_access_token"], "access_token") + self.assertEqual(data["databricks_pat"], "") + # Verify aad_token_expiration is set + self.assertIn("aad_token_expiration", data) + self.assertTrue(float(data["aad_token_expiration"]) > 0) @patch("databricks_validators.Validator.put_msg", return_value=MagicMock()) @patch("databricks_validators.utils.get_aad_access_token", return_value=("test", False)) @@ -298,7 +304,7 @@ def test_validate_oauth_function_instance_validation_failure(self, mock_valid_in # Additional AAD Validation Tests # ========================================================================= - @patch("databricks_validators.utils.get_aad_access_token", return_value="access_token") + @patch("databricks_validators.utils.get_aad_access_token", return_value=("access_token", 3600)) @patch("databricks_validators.Validator", autospec=True) @patch("databricks_validators.ValidateDatabricksInstance.validate_db_instance", autospec=True) def test_validate_aad_function_instance_failure(self, mock_valid_inst, mock_validator, mock_access): @@ -316,7 +322,8 @@ def test_validate_aad_function_instance_failure(self, mock_valid_inst, mock_vali "aad_client_id": "cl_id", "aad_tenant_id": "tenant_id", "aad_client_secret": "client_secret", - "databricks_instance": "db_instance" + "databricks_instance": "db_instance", + "name": "test_account" }) mock_valid_inst.assert_called_once() diff --git a/tests/utility.py b/tests/utility.py index ca2414d..2d7b688 100644 --- a/tests/utility.py +++ b/tests/utility.py @@ -1,10 +1,19 @@ +import json + + class Response: """Sample Response Class.""" - def __init__(self, status_code, json_data=None): + def __init__(self, status_code, json_data=None, text=None): """Init Method for Response.""" self.status_code = status_code self._json_data = json_data if json_data is not None else {"status_code": self.status_code} + self._text = text if text is not None else json.dumps(self._json_data) + + @property + def text(self): + """Return response text.""" + return self._text def json(self): """Set json value."""