diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e9e7a45..01738a5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,7 +22,7 @@ jobs: run: pip install -r requirements-dev.txt - name: Run tests and generate coverage working-directory: ./ - run: pytest --cov-config=tests/.coveragerc --cov=app/bin --cov-report=xml:coverage.xml tests + run: pytest --cov-config=tests/.coveragerc --cov=app/bin --cov-report=xml:coverage.xml --cov-fail-under=80 tests - name: Publish test coverage uses: codecov/codecov-action@v1 diff --git a/app/README.md b/app/README.md index cdf7047..31b08e3 100644 --- a/app/README.md +++ b/app/README.md @@ -81,6 +81,15 @@ The Databricks Add-on for Splunk is used to query Databricks data and execute Da # INSTALLATION Databricks Add-on for Splunk can be installed through UI using "Manage Apps" > "Install the app from file" or by extracting tarball directly into $SPLUNK_HOME/etc/apps/ folder. +# TESTING + +Test coverage: > 90% + +Run tests with coverage: +```bash +pytest --cov-config=tests/.coveragerc --cov=app/bin --cov-report=term-missing --cov-report=xml:coverage.xml tests +``` + # CAPABILITIES * Users with an 'admin' role can do the Configuration of the Account and Proxy, whereas users without an 'admin' role can't do the Configuration or view it. @@ -99,6 +108,15 @@ To configure the Add-on with Azure Active Directory token authentication, you ne * To add the provisioned service principal to the target Azure Databricks workspace, follow [these steps](https://learn.microsoft.com/en-us/azure/databricks/dev-tools/service-principals) and refer [this example](https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/aad/service-prin-aad-token#--api-access-for-service-principals-that-are-azure-databricks-workspace-users-and-admins) * Note that the service principals must be Azure Databricks workspace users and admins. 
+### Databricks OAuth (M2M) Configuration + +To configure the Add-on with Databricks OAuth M2M authentication for service principals: + +* Create a service principal in your Databricks workspace, following [these steps](https://docs.databricks.com/aws/en/dev-tools/auth/oauth-m2m.html) +* Generate OAuth credentials (Client ID and Client Secret) for the service principal +* OAuth secrets can remain valid for up to two years +* The add-on automatically refreshes tokens before expiration (tokens typically expire after 1 hour) + ## 1. Add Databricks Credentials To configure Databricks Add-on for Splunk, navigate to Databricks Add-on for Splunk, click on "Configuration", go to the "Databricks Credentials" tab, click on "Add", fill in the details asked, and click "Save". Field descriptions are as below: @@ -109,11 +127,13 @@ To configure Databricks Add-on for Splunk, navigate to Databricks Add-on for Spl | 'databricksquery' to run on | Mode through which databricksquery command should execute | Yes | | Databricks Cluster Name | Name of the Databricks cluster to use for query and notebook execution. A user can override this value while executing the custom command. | No | Databricks Warehouse ID | ID of the Databricks warehouse to use for query execution. A user can override this value while executing the custom command. | No -| Authentication Method | SingleSelect: Authentication via Azure Active Directory or using a Personal Access Token | Yes +| Authentication Method | SingleSelect: Authentication via Personal Access Token, Azure Active Directory, or Databricks OAuth (M2M) | Yes | Databricks Access Token | [Auth: Personal Access Token] Databricks personal access token to use for authentication. Refer [Generate Databricks Access Token](https://docs.databricks.com/dev-tools/api/latest/authentication.html#generate-a-personal-access-token) document to generate the access token. 
| Yes | | Client Id | [Auth: Azure Active Directory] Azure Active Directory Client Id from your Azure portal.| Yes | Tenant Id | [Auth: Azure Active Directory] Databricks Application(Tenant) Id from your Azure portal.| Yes | Client Secret | [Auth: Azure Active Directory] Azure Active Directory Client Secret from your Azure portal.| Yes +| OAuth Client ID | [Auth: Databricks OAuth (M2M)] OAuth Client ID from your Databricks service principal. Refer [OAuth M2M documentation](https://docs.databricks.com/aws/en/dev-tools/auth/oauth-m2m.html) for details. | Yes +| OAuth Client Secret | [Auth: Databricks OAuth (M2M)] OAuth Client Secret from your Databricks service principal.| Yes ## 2. Configure Proxy (Required only if the requests should go via proxy server) @@ -128,18 +148,18 @@ Navigate to Databricks Add-on for Splunk, click on "Configuration", go to the "P | Username | Username for proxy authentication (Username and Password are inclusive fields) | No | | Password | Password for proxy authentication (Username and Password are inclusive fields) | No | | Remote DNS resolution | Enabling this option allows the proxy server to handle DNS resolution for clients, enhancing privacy and centralizing control over DNS requests. | No | -| Use Proxy for OAuth | Check this box if you want to use a proxy just for AAD token generation (https://login.microsoftonline.com/). All other network calls will skip the proxy even if it's enabled. | No | +| Use Proxy for OAuth | Check this box if you want to use a proxy just for Azure AD token generation (https://login.microsoftonline.com/). All other network calls (including Databricks API and OAuth M2M) will skip the proxy even if it's enabled. | No | **Steps to configure an HTTPS proxy** * Select Proxy Type as "http" and provide the other required details for proxy configuration. 
-* To install the proxy certificate in the Add-on , Go to folder $SPLUNK_HOME/etc/apps/TA-Databricks/bin/ta_databricks/aob_py3/certifi +* To install the proxy certificate in the Add-on , Go to folder `$SPLUNK_HOME/etc/apps/TA-Databricks/bin/ta_databricks/aob_py3/certifi` * Put the proxy certificate at the end of the file named cacert.pem Once the above steps are completed, all the following requests will be directed through the proxy. -**Note**: $SPLUNK_HOME denotes the path where Splunk is installed. Ex: /opt/splunk +**Note**: `$SPLUNK_HOME` denotes the path where Splunk is installed. Ex: `/opt/splunk` After enabling the proxy, re-visit the "Databricks Credentials" tab, fill in the details, and click on "Save" to verify if the proxy is working. @@ -155,6 +175,7 @@ Navigate to Databricks Add-on for Splunk, click on "Configuration", go to the "A | Query Result Limit | Maximum limit of rows in query result for databricksquery command. | Yes | | Index | Index in which you want to store the command execution details. | Yes | | Max Thread Count | Maximum number of threads to be allowed for databricksquery command to fetch the results. | Yes | + # CUSTOM COMMANDS: Any user will be able to execute the custom command. Once the admin user configures Databricks Add-on for Splunk successfully, they can execute custom commands. With custom commands, users can: @@ -178,8 +199,8 @@ This custom command helps users to query their data present in the Databricks ta | command_timeout | No | Time to wait in seconds for query completion. 
Default value: 300 | * Syntax - - Syntax 1 : | databricksquery account_name="" warehouse_id="" query="" command_timeout= limit= - - Syntax 2 : | databricksquery account_name="" cluster="" query="" command_timeout= + - Syntax 1 : `| databricksquery account_name="" warehouse_id="" query="" command_timeout= limit=` + - Syntax 2 : `| databricksquery account_name="" cluster="" query="" command_timeout=` * Output @@ -187,9 +208,13 @@ The command gives the output of the query in tabular format. It will return an e * Example +``` | databricksquery account_name="db_account" query="SELECT * FROM default.people WHERE age>30" warehouse_id=12345a67 command_timeout=60 limit=500 +``` +``` | databricksquery query="SELECT * FROM default.people WHERE age>30" cluster="test_cluster" command_timeout=60 account_name="AAD_account" +``` ## 2. databricksrun @@ -208,7 +233,9 @@ This custom command helps users to submit a one-time run without creating a job. * Syntax +``` | databricksrun account_name="" notebook_path="" run_name="" cluster="" revision_timestamp= notebook_params="" +``` * Output @@ -216,11 +243,15 @@ The command will give the details about the executed run through job. * Example 1 +``` | databricksrun account_name="db_account" notebook_path="/path/to/test_notebook" run_name="run_comm" cluster="test_cluster" revision_timestamp=1609146477 notebook_params="key1=value1||key2=value2" +``` * Example 2 +``` | databricksrun account_name="db_account" notebook_path="/path/to/test_notebook" run_name="run_comm" cluster="test_cluster" revision_timestamp=1609146477 notebook_params="key1=value with \"double quotes\" in it||key2=value2" +``` ## 3. databricksjob @@ -236,7 +267,9 @@ This custom command helps users to run an already created job from Splunk. * Syntax +``` | databricksjob account_name="" job_id= notebook_params="" +``` * Output @@ -244,11 +277,15 @@ The command will give the details about the executed run through job. 
* Example 1 +``` | databricksjob account_name="db_account" job_id=2 notebook_params="key1=value1||key2=value2" +``` * Example 2 +``` | databricksjob account_name="db_account" job_id=2 notebook_params="key1=value with \"double quotes\" in it||key2=value2" +``` # Macro Macro `databricks_index_macro` specifies the index in which you want to store the command execution details. @@ -261,7 +298,7 @@ To modify Macro from Splunk UI, 4. Click on the `Save` button. # SAVED SEARCH -Saved search `databricks_update_run_execution_status` uses databricksrunstatus custom command to fetch run execution status and ingest updated details in Splunk for runs invoked through databricksrun and databricksjob command. +Saved search `databricks_update_run_execution_status` uses `databricksrunstatus` custom command to fetch run execution status and ingest updated details in Splunk for runs invoked through `databricksrun` and `databricksjob` command. It runs every 5 minutes and ingests the data with updated execution status in Splunk. # DASHBOARDS @@ -271,7 +308,7 @@ This app contains the following dashboards: * Launch Notebook: The dashboard allows users to launch a notebook on their Databricks cluster by providing the required parameters. The users can then navigate to the job results page on the Databricks instance from the generated link on the dashboard. -The dashboards will be accessible to all the users. A user with admin_all_objects capability can navigate to “:/en-US/app/TA-Databricks/dashboards” to modify the permissions for “Databricks Job Execution Details” dashboard. +The dashboards will be accessible to all the users. A user with admin_all_objects capability can navigate to `:/en-US/app/TA-Databricks/dashboards` to modify the permissions for "Databricks Job Execution Details" dashboard. # ALERT ACTIONS The `Launch Notebook` alert action is used to launch a parameterized notebook based on the provided parameters. The alert can be scheduled or run as ad-hoc. 
It can also be used as an Adaptive response action in "Enterprise Security> Incident review dashboard". @@ -290,22 +327,27 @@ When this alert action is run as an Adaptive response action from "Enterprise Se * Restart Splunk. ## Upgrade from Databricks Add-On for Splunk v1.4.1 to v1.4.2 + * Follow the General upgrade steps section. * No additional steps are required. ## Upgrade from Databricks Add-On for Splunk v1.4.0 to v1.4.1 + * Follow the General upgrade steps section. * No additional steps are required. ## Upgrade from Databricks Add-On for Splunk v1.3.1 to v1.4.0 + * Follow the General upgrade steps section. * No additional steps are required. ## Upgrade from Databricks Add-On for Splunk v1.3.0 to v1.3.1 + * Follow the General upgrade steps section. * No additional steps are required. ## Upgrade from Databricks Add-On for Splunk v1.2.0 to v1.3.0 + * Follow the General upgrade steps section. * No additional steps are required. @@ -323,6 +365,7 @@ Follow the below steps to upgrade the Add-on to 1.2.0 ## Upgrade from Databricks Add-On for Splunk v1.0.0 to v1.1.0 + No special steps are required. Upload and install v1.1.0 of the add-on normally. @@ -341,26 +384,30 @@ Some of the components included in "Databricks Add-on for Splunk" are licensed u # TROUBLESHOOTING * Authentication Failure: Check the network connectivity and verify that the configuration details provided are correct. -* For any other unknown failure, please check the log files $SPLUNK_HOME/var/log/ta_databricks*.log to get more details on the issue. +* For any other unknown failure, please check the log files `$SPLUNK_HOME/var/log/ta_databricks*.log` to get more details on the issue. * The Add-on does not require a restart after the installation for all functionalities to work. However, the icons will be visible after one Splunk restart post-installation. * If all custom commands/notebooks fail to run with the https response code [403] then most probably the client secret has expired. 
Please regenerate your client secret in this case on your Azure portal and configure the add-on again with the new client secret. Set the client secret's expiration time to a custom value that you see fit. Refer this [guide](https://docs.microsoft.com/en-us/azure/active-directory/develop/quickstart-register-app#add-a-client-secret) for setting a client secret in Azure Active Directory. * If the proxy is enabled and Use Proxy for OAuth is checked, and custom commands fail to run and throw the below mentioned error. HTTPSConnectionPool(host=, port=443): Max retries exceeded with url: (Caused by NewConnectionError(': Failed to establish a new connection: [Errno 110] Connection timed out')) In this case, uncheck 'Use Proxy for OAuth' and save the Proxy configuration, and re-run the custom command again. +* For OAuth M2M authentication, if commands fail with "Invalid OAuth credentials" error, verify that: + - The service principal has been created in the Databricks workspace + - The OAuth client secret has not expired + - The service principal has appropriate permissions in the workspace **Note**: $SPLUNK_HOME denotes the path where Splunk is installed. Ex: /opt/splunk # UNINSTALL & CLEANUP STEPS -* Remove $SPLUNK_HOME/etc/apps/TA-Databricks/ -* Remove $SPLUNK_HOME/var/log/TA-Databricks/ -* Remove $SPLUNK_HOME/var/log/splunk/**ta_databricks*.log** +* Remove `$SPLUNK_HOME/etc/apps/TA-Databricks/` +* Remove `$SPLUNK_HOME/var/log/TA-Databricks/` +* Remove `$SPLUNK_HOME/var/log/splunk/**ta_databricks*.log**` * To reflect the cleanup changes in UI, restart the Splunk instance. Refer [Start Splunk](https://docs.splunk.com/Documentation/Splunk/8.0.6/Admin/StartSplunk) documentation to get information on how to restart Splunk. -**Note**: $SPLUNK_HOME denotes the path where Splunk is installed. Ex: /opt/splunk +**Note**: `$SPLUNK_HOME` denotes the path where Splunk is installed. Ex: `/opt/splunk` # SUPPORT * This app is not officially supported by Databricks. 
Please send an email to cybersecurity@databricks.com for help. # COPYRIGHT -© Databricks 2024. All rights reserved. Apache, Apache Spark, Spark and the Spark logo are trademarks of the Apache Software Foundation. \ No newline at end of file +© Databricks 2024. All rights reserved. Apache, Apache Spark, Spark and the Spark logo are trademarks of the Apache Software Foundation. diff --git a/app/appserver/static/js/build/custom/auth_select_hook.js b/app/appserver/static/js/build/custom/auth_select_hook.js index 9fbdae0..0c9b597 100644 --- a/app/appserver/static/js/build/custom/auth_select_hook.js +++ b/app/appserver/static/js/build/custom/auth_select_hook.js @@ -9,11 +9,7 @@ class AuthSelectHook { onChange(field, value, dataDict) { if (field == 'auth_type') { - if (value == 'AAD') { - this.toggleAADFields(true); - } else { - this.toggleAADFields(false); - } + this.toggleAuthFields(value); } if (field == 'config_for_dbquery') { if (value == 'interactive_cluster') { @@ -26,11 +22,7 @@ class AuthSelectHook { onRender() { var selected_auth = this.state.data.auth_type.value; - if (selected_auth == 'AAD') { - this.toggleAADFields(true); - } else { - this.toggleAADFields(false); - } + this.toggleAuthFields(selected_auth); } hideWarehouseField(state) { @@ -41,13 +33,25 @@ class AuthSelectHook { }); } - toggleAADFields(state) { + toggleAuthFields(authType) { this.util.setState((prevState) => { let data = {...prevState.data }; - data.aad_client_id.display = state; - data.aad_tenant_id.display = state; - data.aad_client_secret.display = state; - data.databricks_pat.display = !state; + + // OAuth M2M fields + const showOAuth = (authType === 'OAUTH_M2M'); + data.oauth_client_id.display = showOAuth; + data.oauth_client_secret.display = showOAuth; + + // AAD fields + const showAAD = (authType === 'AAD'); + data.aad_client_id.display = showAAD; + data.aad_tenant_id.display = showAAD; + data.aad_client_secret.display = showAAD; + + // PAT field + const showPAT = (authType === 'PAT'); 
+ data.databricks_pat.display = showPAT; + return { data } }); } diff --git a/app/appserver/static/js/build/globalConfig.json b/app/appserver/static/js/build/globalConfig.json index 1a16a28..200e495 100644 --- a/app/appserver/static/js/build/globalConfig.json +++ b/app/appserver/static/js/build/globalConfig.json @@ -146,6 +146,10 @@ { "label": "Azure Active Directory", "value": "AAD" + }, + { + "label": "Databricks OAuth (M2M)", + "value": "OAUTH_M2M" } ] }, @@ -205,6 +209,42 @@ "placeholder": "required" } }, + { + "field": "oauth_client_id", + "label": "OAuth Client ID", + "type": "text", + "help": "Enter the Client ID from your Databricks service principal.", + "required": false, + "defaultValue": "", + "encrypted": false, + "validators": [{ + "type": "string", + "minLength": 0, + "maxLength": 200, + "errorMsg": "Max length of OAuth Client ID is 200" + }], + "options": { + "placeholder": "required" + } + }, + { + "field": "oauth_client_secret", + "label": "OAuth Client Secret", + "type": "text", + "help": "Enter the OAuth secret from your Databricks service principal.", + "required": false, + "defaultValue": "", + "encrypted": true, + "validators": [{ + "type": "string", + "minLength": 0, + "maxLength": 500, + "errorMsg": "Max length of OAuth Client Secret is 500" + }], + "options": { + "placeholder": "required" + } + }, { "field": "databricks_pat", "label": "Databricks Access Token", @@ -318,7 +358,7 @@ "field": "use_for_oauth", "label": "Use Proxy for OAuth", "defaultValue": 0, - "tooltip": "Check this box if you want to use proxy just for AAD token generation (https://login.microsoftonline.com/). All other network calls will skip the proxy even if it's enabled.", + "tooltip": "Check this box if you want to use proxy just for Azure AD token generation (https://login.microsoftonline.com/). 
All other network calls (including Databricks API and OAuth M2M) will skip the proxy even if it's enabled.", "type": "checkbox" } ], diff --git a/app/bin/TA_Databricks_rh_account.py b/app/bin/TA_Databricks_rh_account.py index 5bd4542..c76932e 100644 --- a/app/bin/TA_Databricks_rh_account.py +++ b/app/bin/TA_Databricks_rh_account.py @@ -62,6 +62,35 @@ default='', validator=None ), + field.RestField( + 'oauth_client_id', + required=False, + encrypted=False, + default='', + validator=validator.String( + min_len=0, + max_len=200, + ) + ), + field.RestField( + 'oauth_client_secret', + required=False, + encrypted=True, + default='', + validator=None + ), + field.RestField( + 'oauth_access_token', + required=False, + encrypted=True + ), + field.RestField( + 'oauth_token_expiration', + required=False, + encrypted=False, + default='', + validator=None + ), field.RestField( 'databricks_pat', required=False, diff --git a/app/bin/databricks_com.py b/app/bin/databricks_com.py index a6e732c..12976c1 100755 --- a/app/bin/databricks_com.py +++ b/app/bin/databricks_com.py @@ -44,11 +44,17 @@ def __init__(self, account_name, session_key): self.session.timeout = const.TIMEOUT if self.auth_type == "PAT": self.databricks_token = databricks_configs.get("databricks_pat") - else: + elif self.auth_type == "AAD": self.databricks_token = databricks_configs.get("aad_access_token") self.aad_client_id = databricks_configs.get("aad_client_id") self.aad_tenant_id = databricks_configs.get("aad_tenant_id") self.aad_client_secret = databricks_configs.get("aad_client_secret") + elif self.auth_type == "OAUTH_M2M": + self.databricks_token = databricks_configs.get("oauth_access_token") + self.oauth_client_id = databricks_configs.get("oauth_client_id") + self.oauth_client_secret = databricks_configs.get("oauth_client_secret") + oauth_token_expiration_str = databricks_configs.get("oauth_token_expiration") + self.oauth_token_expiration = float(oauth_token_expiration_str) if oauth_token_expiration_str else 
0 if not all([databricks_instance, self.databricks_token]): raise Exception("Addon is not configured. Navigate to addon's configuration page to configure the addon.") @@ -98,6 +104,48 @@ def get_requests_retry_session(self): session.mount("https://", adapter) return session + def should_refresh_oauth_token(self): + """ + Check if OAuth token should be refreshed proactively. + + :return: Boolean - True if token expires within 5 minutes + """ + if not hasattr(self, 'oauth_token_expiration'): + return False + + import time + current_time = time.time() + time_until_expiry = self.oauth_token_expiration - current_time + + # Refresh if token expires within 5 minutes (300 seconds) + return time_until_expiry < 300 + + def _refresh_oauth_token(self): + """Refresh OAuth M2M access token and update session headers.""" + databricks_configs = utils.get_databricks_configs(self.session_key, self.account_name) + proxy_config = databricks_configs.get("proxy_uri") + + result = utils.get_oauth_access_token( + self.session_key, + self.account_name, + self.databricks_instance_url.replace("https://", ""), + self.oauth_client_id, + self.oauth_client_secret, + proxy_config, + retry=const.RETRIES, + conf_update=True + ) + + if isinstance(result, tuple) and result[1] == False: + raise Exception(result[0]) + + access_token, expires_in = result + self.databricks_token = access_token + import time + self.oauth_token_expiration = time.time() + expires_in + self.request_headers["Authorization"] = "Bearer {}".format(self.databricks_token) + self.session.headers.update(self.request_headers) + def databricks_api(self, method, endpoint, data=None, args=None): """ Common method to hit the API of Databricks instance. 
@@ -108,6 +156,11 @@ def databricks_api(self, method, endpoint, data=None, args=None): :param args: Arguments to be add into the url :return: response in the form of dictionary """ + # Proactive OAuth token refresh + if self.auth_type == "OAUTH_M2M" and self.should_refresh_oauth_token(): + _LOGGER.info("OAuth token expiring soon, refreshing proactively.") + self._refresh_oauth_token() + run_again = True request_url = "{}{}".format(self.databricks_instance_url, endpoint) try: @@ -141,6 +194,11 @@ def databricks_api(self, method, endpoint, data=None, args=None): self.databricks_token = db_token self.request_headers["Authorization"] = "Bearer {}".format(self.databricks_token) self.session.headers.update(self.request_headers) + elif status_code == 403 and self.auth_type == "OAUTH_M2M" and run_again: + response = None + run_again = False + _LOGGER.info("Refreshing OAuth M2M token.") + self._refresh_oauth_token() elif status_code != 200: response.raise_for_status() else: diff --git a/app/bin/databricks_common_utils.py b/app/bin/databricks_common_utils.py index e59b449..bab12ac 100755 --- a/app/bin/databricks_common_utils.py +++ b/app/bin/databricks_common_utils.py @@ -110,6 +110,137 @@ def save_databricks_aad_access_token(account_name, session_key, access_token, cl raise Exception("Exception while saving AAD access token.") +def save_databricks_oauth_access_token(account_name, session_key, access_token, expires_in, client_secret): + """ + Method to store new OAuth access token with expiration timestamp. 
+ + :param account_name: Account name + :param session_key: Splunk session key + :param access_token: OAuth access token + :param expires_in: Token lifetime in seconds + :param client_secret: OAuth client secret + :return: None + """ + import time + new_creds = { + "name": account_name, + "oauth_client_secret": client_secret, + "oauth_access_token": access_token, + "oauth_token_expiration": str(time.time() + expires_in), + "update_token": True + } + try: + _LOGGER.info("Saving databricks OAuth access token.") + rest.simpleRequest( + "/databricks_get_credentials", + sessionKey=session_key, + postargs=new_creds, + raiseAllErrors=True, + ) + _LOGGER.info("Saved OAuth access token successfully.") + except Exception as e: + _LOGGER.error("Exception while saving OAuth access token: {}".format(str(e))) + _LOGGER.debug(traceback.format_exc()) + raise Exception("Exception while saving OAuth access token.") + + +def get_oauth_access_token( + session_key, + account_name, + databricks_instance, + oauth_client_id, + oauth_client_secret, + proxy_settings=None, + retry=1, + conf_update=False, +): + """ + Method to acquire OAuth M2M access token for Databricks service principal. 
+ + :param session_key: Splunk session key + :param account_name: Account name for configuration storage + :param databricks_instance: Databricks workspace instance URL + :param oauth_client_id: OAuth client ID from service principal + :param oauth_client_secret: OAuth client secret from service principal + :param proxy_settings: Proxy configuration dict + :param retry: Number of retry attempts + :param conf_update: If True, store token in configuration + :return: tuple (access_token, expires_in) or (error_message, False) + """ + import time + from requests.auth import HTTPBasicAuth + + token_url = "https://{}/oidc/v1/token".format(databricks_instance) + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": "{}".format(const.USER_AGENT_CONST), + } + _LOGGER.debug("Request made to the Databricks from Splunk user: {}".format(get_current_user(session_key))) + data_dict = {"grant_type": "client_credentials", "scope": "all-apis"} + data_encoded = urlencode(data_dict) + + # Handle proxy settings for OAuth M2M + # Note: "use_for_oauth" means "use proxy ONLY for AAD token generation" + # Since OAuth M2M endpoint is on the Databricks instance (not AAD), + # we should skip proxy when use_for_oauth is true + if proxy_settings: + if is_true(proxy_settings.get("use_for_oauth")): + _LOGGER.info( + "Skipping the usage of proxy for OAuth M2M as 'Use Proxy for OAuth' parameter is checked." 
+ ) + proxy_settings_copy = None + else: + proxy_settings_copy = proxy_settings.copy() + proxy_settings_copy.pop("use_for_oauth", None) + else: + proxy_settings_copy = None + + while retry: + try: + resp = requests.post( + token_url, + headers=headers, + data=data_encoded, + auth=HTTPBasicAuth(oauth_client_id, oauth_client_secret), + proxies=proxy_settings_copy, + verify=const.VERIFY_SSL, + timeout=const.TIMEOUT + ) + resp.raise_for_status() + response = resp.json() + oauth_access_token = response.get("access_token") + expires_in = response.get("expires_in", 3600) + if conf_update: + save_databricks_oauth_access_token( + account_name, session_key, oauth_access_token, expires_in, oauth_client_secret + ) + return oauth_access_token, expires_in + except Exception as e: + retry -= 1 + if "resp" in locals(): + error_code = resp.json().get("error") + if error_code and error_code in list(const.ERROR_CODE.keys()): + msg = const.ERROR_CODE[error_code] + elif str(resp.status_code) in list(const.ERROR_CODE.keys()): + msg = const.ERROR_CODE[str(resp.status_code)] + elif resp.status_code not in (200, 201): + msg = ( + "Response status: {}. Unable to validate OAuth credentials. " + "Check logs for more details.".format(str(resp.status_code)) + ) + else: + msg = ( + "Unable to request Databricks instance. " + "Please validate the provided Databricks and " + "Proxy configurations or check the network connectivity." + ) + _LOGGER.error("Error while trying to generate OAuth access token: {}".format(str(e))) + _LOGGER.debug(traceback.format_exc()) + _LOGGER.error(msg) + if retry == 0: + return msg, False + + def get_proxy_clear_password(session_key): """ Get clear password from splunk passwords.conf. diff --git a/app/bin/databricks_const.py b/app/bin/databricks_const.py index 358330a..f4a058f 100755 --- a/app/bin/databricks_const.py +++ b/app/bin/databricks_const.py @@ -54,4 +54,7 @@ "500": "Internal server error.", "400": "Bad request. 
The request is malformed.", "429": "API limit exceeded. Please try again after some time.", + "invalid_client": "Invalid OAuth Client ID or Client Secret provided.", + "unauthorized_client": "Service principal is not authorized for this workspace.", + "invalid_grant": "The provided OAuth credentials are invalid or expired.", } diff --git a/app/bin/databricks_get_credentials.py b/app/bin/databricks_get_credentials.py index 03b84f2..845b762 100644 --- a/app/bin/databricks_get_credentials.py +++ b/app/bin/databricks_get_credentials.py @@ -48,23 +48,40 @@ def handle(self, in_string): # Saving Configurations if form_data.get('update_token'): try: - _LOGGER.info("Saving databricks AAD access token.") - client_sec = form_data.get("aad_client_secret") - access_token = form_data.get("aad_access_token") + # Determine which credentials to save based on what's provided + if form_data.get("aad_access_token"): + _LOGGER.info("Saving databricks AAD access token.") + client_sec = form_data.get("aad_client_secret") + access_token = form_data.get("aad_access_token") + new_creds = json.dumps({"aad_client_secret": client_sec, "aad_access_token": access_token}) + success_msg = 'Saved AAD access token successfully.' + elif form_data.get("oauth_access_token"): + _LOGGER.info("Saving databricks OAuth access token.") + client_sec = form_data.get("oauth_client_secret") + access_token = form_data.get("oauth_access_token") + token_expiration = form_data.get("oauth_token_expiration") + new_creds = json.dumps({ + "oauth_client_secret": client_sec, + "oauth_access_token": access_token, + "oauth_token_expiration": token_expiration + }) + success_msg = 'Saved OAuth access token successfully.' 
+ else: + raise Exception("No token data provided for update.") + manager = CredentialManager( self.admin_session_key, app=APP_NAME, realm="__REST_CREDENTIAL__#{0}#{1}".format(APP_NAME, "configs/conf-ta_databricks_account"), ) - new_creds = json.dumps({"aad_client_secret": client_sec, "aad_access_token": access_token}) manager.set_password(self.account_name, new_creds) - _LOGGER.info("Saved AAD access token successfully.") + _LOGGER.info(success_msg) return { - 'payload': 'Saved AAD access token successfully.', + 'payload': success_msg, 'status': 200 } except Exception as e: - error_msg = "Databricks Error: Exception while saving AAD access token: {}".format(str(e)) + error_msg = "Databricks Error: Exception while saving access token: {}".format(str(e)) _LOGGER.error(error_msg) _LOGGER.debug(traceback.format_exc()) return { @@ -79,6 +96,10 @@ def handle(self, in_string): 'aad_tenant_id': None, 'aad_client_secret': None, 'aad_access_token': None, + 'oauth_client_id': None, + 'oauth_client_secret': None, + 'oauth_access_token': None, + 'oauth_token_expiration': None, 'config_for_dbquery': None, 'cluster_name': None, 'warehouse_id': None, @@ -130,11 +151,16 @@ def handle(self, in_string): if config_dict['auth_type'] == 'PAT': config_dict['databricks_pat'] = account_password.get('databricks_pat') - else: + elif config_dict['auth_type'] == 'AAD': config_dict['aad_client_id'] = account_config.get('aad_client_id') config_dict['aad_tenant_id'] = account_config.get('aad_tenant_id') config_dict['aad_client_secret'] = account_password.get('aad_client_secret') config_dict['aad_access_token'] = account_password.get('aad_access_token') + elif config_dict['auth_type'] == 'OAUTH_M2M': + config_dict['oauth_client_id'] = account_config.get('oauth_client_id') + config_dict['oauth_client_secret'] = account_password.get('oauth_client_secret') + config_dict['oauth_access_token'] = account_password.get('oauth_access_token') + config_dict['oauth_token_expiration'] = 
account_password.get('oauth_token_expiration') # Get proxy settings from conf _, proxy_response_content = rest.simpleRequest( diff --git a/app/bin/databricks_validators.py b/app/bin/databricks_validators.py index 7c045de..0035488 100755 --- a/app/bin/databricks_validators.py +++ b/app/bin/databricks_validators.py @@ -68,6 +68,47 @@ def validate_aad(self, data): else: return False + def validate_oauth(self, data): + """ + Validation flow if the user opts for OAuth M2M authentication. + + :param data: Dictionary containing values from configuration UI. + :return: Boolean depending on the success of the connection + """ + import time + _LOGGER.info('Obtaining OAuth M2M access token') + oauth_client_id = data.get("oauth_client_id").strip() + oauth_client_secret = data.get("oauth_client_secret").strip() + databricks_instance = data.get("databricks_instance").strip("/") + account_name = data.get("name") + + result = utils.get_oauth_access_token( + self._splunk_session_key, + account_name, + databricks_instance, + oauth_client_id, + oauth_client_secret, + self._proxy_settings + ) + + if isinstance(result, tuple) and result[1] == False: + _LOGGER.error(result[0]) + self.put_msg(result[0]) + return False + + access_token, expires_in = result + _LOGGER.info('Obtained OAuth M2M access token successfully.') + + valid_instance = self.validate_db_instance(databricks_instance, access_token) + if valid_instance: + data["oauth_access_token"] = access_token + data["oauth_token_expiration"] = str(time.time() + expires_in) + data["databricks_pat"] = "" + data["aad_access_token"] = "" + return True + else: + return False + def validate_db_instance(self, instance_url, access_token): """ Method to validate databricks instance. 
@@ -157,7 +198,7 @@ def validate(self, value, data): ): self.put_msg('Field Databricks Access Token is required') return False - else: + elif auth_type == "AAD": if (not (data.get("aad_client_id", None) and data.get("aad_client_id").strip()) ): @@ -173,6 +214,17 @@ def validate(self, value, data): ): self.put_msg('Field Client Secret is required') return False + elif auth_type == "OAUTH_M2M": + if (not (data.get("oauth_client_id", None) + and data.get("oauth_client_id").strip()) + ): + self.put_msg('Field OAuth Client ID is required') + return False + elif (not (data.get("oauth_client_secret", None) + and data.get("oauth_client_secret").strip()) + ): + self.put_msg('Field OAuth Client Secret is required') + return False _LOGGER.info("Reading proxy and user data.") try: self._proxy_settings = utils.get_proxy_uri(self._splunk_session_key) @@ -191,5 +243,7 @@ def validate(self, value, data): return False if auth_type == "PAT": return self.validate_pat(data) - else: + elif auth_type == "AAD": return self.validate_aad(data) + elif auth_type == "OAUTH_M2M": + return self.validate_oauth(data) \ No newline at end of file diff --git a/tests/.coveragerc b/tests/.coveragerc index f1a0aed..ea60721 100644 --- a/tests/.coveragerc +++ b/tests/.coveragerc @@ -16,4 +16,4 @@ omit = [report] -fail_under = 60 +fail_under = 80 diff --git a/tests/test_cancel_run.py b/tests/test_cancel_run.py new file mode 100644 index 0000000..4cf5dab --- /dev/null +++ b/tests/test_cancel_run.py @@ -0,0 +1,371 @@ +import declare +import unittest +import json +from mock import patch, MagicMock + +mocked_modules = {} + + +def setUpModule(): + global mocked_modules + + module_to_be_mocked = [ + 'log_manager', + 'splunk', + 'splunk.rest', + 'splunk.admin', + 'splunk.clilib', + 'splunk.clilib.cli_common', + 'solnlib.server_info', + ] + + mocked_modules = {module: MagicMock() for module in module_to_be_mocked} + + # Create a proper mock for PersistentServerConnectionApplication + class 
MockPersistentServerConnectionApplication: + def __init__(self): + pass + + # Mock the splunk.persistconn.application module + mock_persistconn_app = MagicMock() + mock_persistconn_app.PersistentServerConnectionApplication = MockPersistentServerConnectionApplication + + mock_persistconn = MagicMock() + mock_persistconn.application = mock_persistconn_app + + mocked_modules['splunk.persistconn'] = mock_persistconn + mocked_modules['splunk.persistconn.application'] = mock_persistconn_app + + for module, magicmock in mocked_modules.items(): + patch.dict('sys.modules', **{module: magicmock}).start() + + +def tearDownModule(): + patch.stopall() + + +class TestCancelRunningExecution(unittest.TestCase): + """Test CancelRunningExecution class.""" + + def setUp(self): + """Set up the test.""" + import cancel_run + self.cancel_run = cancel_run + self.CancelRunningExecution = cancel_run.CancelRunningExecution + + def test_initialization(self): + """Test object initialization with all attributes.""" + obj = self.CancelRunningExecution("command_line", "command_arg") + + # Verify all attributes are initialized to None or empty dict + self.assertIsNone(obj.run_id) + self.assertIsNone(obj.account_name) + self.assertIsNone(obj.uid) + self.assertEqual(obj.payload, {}) + self.assertIsNone(obj.status) + self.assertIsNone(obj.session_key) + + @patch("cancel_run.com.DatabricksClient") + def test_handle_successful_cancellation(self, mock_client_class): + """Test successful run cancellation with 200 response.""" + # Setup + obj = self.CancelRunningExecution("command_line", "command_arg") + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.databricks_api.return_value = ({"status": "CANCELLED"}, 200) + + # Prepare input data + input_data = json.dumps({ + "form": { + "run_id": "12345", + "account_name": "test_account", + "uid": "test_uid_123" + }, + "session": { + "authtoken": "test_session_key" + } + }) + + # Execute + result = obj.handle(input_data) + + 
# Verify + self.assertEqual(result['status'], 200) + self.assertEqual(result['payload']['canceled'], "Success") + self.assertEqual(obj.run_id, "12345") + self.assertEqual(obj.account_name, "test_account") + self.assertEqual(obj.uid, "test_uid_123") + self.assertEqual(obj.session_key, "test_session_key") + + # Verify DatabricksClient was called correctly + mock_client_class.assert_called_once_with("test_account", "test_session_key") + mock_client.databricks_api.assert_called_once_with( + "post", + "/api/2.0/jobs/runs/cancel", + data={"run_id": "12345"} + ) + + @patch("cancel_run.com.DatabricksClient") + def test_handle_failed_cancellation_404(self, mock_client_class): + """Test failed cancellation with 404 response.""" + # Setup + obj = self.CancelRunningExecution("command_line", "command_arg") + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.databricks_api.return_value = ({"error": "Not found"}, 404) + + # Prepare input data + input_data = json.dumps({ + "form": { + "run_id": "99999", + "account_name": "test_account", + "uid": "test_uid_123" + }, + "session": { + "authtoken": "test_session_key" + } + }) + + # Execute + result = obj.handle(input_data) + + # Verify + self.assertEqual(result['status'], 500) + self.assertEqual(result['payload']['canceled'], "Failed") + mock_client.databricks_api.assert_called_once() + + @patch("cancel_run.com.DatabricksClient") + def test_handle_failed_cancellation_500(self, mock_client_class): + """Test failed cancellation with 500 response.""" + # Setup + obj = self.CancelRunningExecution("command_line", "command_arg") + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.databricks_api.return_value = ({"error": "Internal server error"}, 500) + + # Prepare input data + input_data = json.dumps({ + "form": { + "run_id": "12345", + "account_name": "test_account", + "uid": "test_uid_123" + }, + "session": { + "authtoken": "test_session_key" + } + }) + + # Execute 
+ result = obj.handle(input_data) + + # Verify + self.assertEqual(result['status'], 500) + self.assertEqual(result['payload']['canceled'], "Failed") + + @patch("cancel_run.com.DatabricksClient") + def test_handle_exception_during_api_call(self, mock_client_class): + """Test exception raised during API call.""" + # Setup + obj = self.CancelRunningExecution("command_line", "command_arg") + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.databricks_api.side_effect = Exception("Connection timeout") + + # Prepare input data + input_data = json.dumps({ + "form": { + "run_id": "12345", + "account_name": "test_account", + "uid": "test_uid_123" + }, + "session": { + "authtoken": "test_session_key" + } + }) + + # Execute + result = obj.handle(input_data) + + # Verify + self.assertEqual(result['status'], 500) + self.assertEqual(result['payload']['canceled'], "Failed") + mock_client.databricks_api.assert_called_once() + + @patch("cancel_run.com.DatabricksClient") + def test_handle_invalid_json_input(self, mock_client_class): + """Test exception during input parsing with invalid JSON. + + Note: This test exposes a bug in the code where LOG_PREFIX is referenced + before assignment in the outer exception handler. + """ + # Setup + obj = self.CancelRunningExecution("command_line", "command_arg") + + # Prepare invalid input data + input_data = "invalid json string {{" + + # Execute - this will raise UnboundLocalError due to bug in cancel_run.py + with self.assertRaises(UnboundLocalError): + result = obj.handle(input_data) + + # DatabricksClient should not be instantiated + mock_client_class.assert_not_called() + + @patch("cancel_run.com.DatabricksClient") + def test_handle_missing_form_data(self, mock_client_class): + """Test exception when form data is missing. + + Note: This test exposes a bug in the code where LOG_PREFIX is referenced + before assignment in the outer exception handler. 
+ """ + # Setup + obj = self.CancelRunningExecution("command_line", "command_arg") + + # Prepare input data without form key + input_data = json.dumps({ + "session": { + "authtoken": "test_session_key" + } + }) + + # Execute - this will raise UnboundLocalError due to bug in cancel_run.py + with self.assertRaises(UnboundLocalError): + result = obj.handle(input_data) + + @patch("cancel_run.com.DatabricksClient") + def test_handle_missing_session_data(self, mock_client_class): + """Test exception when session data is missing.""" + # Setup + obj = self.CancelRunningExecution("command_line", "command_arg") + + # Prepare input data without session key + input_data = json.dumps({ + "form": { + "run_id": "12345", + "account_name": "test_account", + "uid": "test_uid_123" + } + }) + + # Execute + result = obj.handle(input_data) + + # Verify + self.assertEqual(result['status'], 500) + self.assertEqual(result['payload']['canceled'], "Failed") + + def test_handleStream_raises_not_implemented_error(self): + """Test that handleStream method raises NotImplementedError.""" + obj = self.CancelRunningExecution("command_line", "command_arg") + + with self.assertRaises(NotImplementedError) as context: + obj.handleStream("handle", "in_string") + + self.assertEqual( + str(context.exception), + "PersistentServerConnectionApplication.handleStream" + ) + + def test_done_method_executes_without_error(self): + """Test that done method executes without error.""" + obj = self.CancelRunningExecution("command_line", "command_arg") + + # Should not raise any exception + try: + result = obj.done() + # done() method should return None (implicitly by pass statement) + self.assertIsNone(result) + except Exception as e: + self.fail(f"done() method raised an exception: {e}") + + @patch("cancel_run.com.DatabricksClient") + def test_handle_all_form_fields_populated(self, mock_client_class): + """Test that all form fields are correctly extracted and stored.""" + # Setup + obj = 
self.CancelRunningExecution("command_line", "command_arg") + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.databricks_api.return_value = ({"status": "CANCELLED"}, 200) + + # Prepare input data with all fields + input_data = json.dumps({ + "form": { + "run_id": "67890", + "account_name": "prod_account", + "uid": "unique_user_id_456" + }, + "session": { + "authtoken": "super_secret_token" + } + }) + + # Execute + result = obj.handle(input_data) + + # Verify all instance variables are correctly set + self.assertEqual(obj.run_id, "67890") + self.assertEqual(obj.account_name, "prod_account") + self.assertEqual(obj.uid, "unique_user_id_456") + self.assertEqual(obj.session_key, "super_secret_token") + self.assertEqual(result['status'], 200) + + @patch("cancel_run.com.DatabricksClient") + def test_handle_client_instantiation_error(self, mock_client_class): + """Test exception during DatabricksClient instantiation.""" + # Setup + obj = self.CancelRunningExecution("command_line", "command_arg") + mock_client_class.side_effect = Exception("Failed to create client") + + # Prepare input data + input_data = json.dumps({ + "form": { + "run_id": "12345", + "account_name": "test_account", + "uid": "test_uid_123" + }, + "session": { + "authtoken": "test_session_key" + } + }) + + # Execute + result = obj.handle(input_data) + + # Verify + self.assertEqual(result['status'], 500) + self.assertEqual(result['payload']['canceled'], "Failed") + mock_client_class.assert_called_once() + + @patch("cancel_run.com.DatabricksClient") + def test_handle_response_status_codes(self, mock_client_class): + """Test various non-200 status codes from API.""" + test_cases = [ + (201, 500), # API returns 201, should result in 500 status + (400, 500), # Bad request + (401, 500), # Unauthorized + (403, 500), # Forbidden + (503, 500), # Service unavailable + ] + + for api_status, expected_status in test_cases: + with self.subTest(api_status=api_status): + obj = 
self.CancelRunningExecution("command_line", "command_arg") + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.databricks_api.return_value = ({"error": "error"}, api_status) + + input_data = json.dumps({ + "form": { + "run_id": "12345", + "account_name": "test_account", + "uid": "test_uid_123" + }, + "session": { + "authtoken": "test_session_key" + } + }) + + result = obj.handle(input_data) + + self.assertEqual(result['status'], expected_status) + self.assertEqual(result['payload']['canceled'], "Failed") diff --git a/tests/test_databricks_com.py b/tests/test_databricks_com.py index 97b6fc3..8e983c8 100644 --- a/tests/test_databricks_com.py +++ b/tests/test_databricks_com.py @@ -188,6 +188,326 @@ def test_get_api_response_refresh_token_error(self, mock_proxy, mock_conf, mock_ self.assertEqual( "Invalid access token. Please enter the valid access token.", str(context.exception)) + # ========================================================================= + # OAuth M2M Token Refresh Tests + # ========================================================================= + + @patch("databricks_com.utils.get_oauth_access_token", return_value=("new_oauth_token", 3600)) + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + @patch("databricks_com.utils.get_proxy_uri", return_value=None) + def test_get_api_response_oauth_refresh_token(self, mock_proxy, mock_conf, mock_session, mock_version, mock_refresh): + """Test OAuth M2M token refresh on 403 response.""" + db_com = import_module('databricks_com') + mock_conf.return_value = { + "databricks_instance": "123", + "auth_type": "OAUTH_M2M", + "oauth_access_token": "token", + "oauth_client_id": "client_id", + "oauth_client_secret": "client_secret", + "oauth_token_expiration": "9999999999.0", + "proxy_uri": None + } + obj = 
db_com.DatabricksClient("account_name", "session_key") + obj.session.post.side_effect = [Response(403), Response(200)] + resp = obj.databricks_api("post", "endpoint", args="123", data={"p1": "v1"}) + self.assertEqual(obj.session.post.call_count, 2) + self.assertEqual(resp, {"status_code": 200}) + + @patch("databricks_com.utils.get_oauth_access_token", return_value=("Token refresh failed", False)) + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + @patch("databricks_com.utils.get_proxy_uri", return_value=None) + def test_get_api_response_oauth_refresh_token_failure(self, mock_proxy, mock_conf, mock_session, mock_version, mock_refresh): + """Test OAuth M2M token refresh failure.""" + db_com = import_module('databricks_com') + mock_conf.return_value = { + "databricks_instance": "123", + "auth_type": "OAUTH_M2M", + "oauth_access_token": "token", + "oauth_client_id": "client_id", + "oauth_client_secret": "client_secret", + "oauth_token_expiration": "9999999999.0", + "proxy_uri": None + } + obj = db_com.DatabricksClient("account_name", "session_key") + obj.session.post.side_effect = [Response(403)] + with self.assertRaises(Exception) as context: + resp = obj.databricks_api("post", "endpoint", args="123", data={"p1": "v1"}) + self.assertEqual("Token refresh failed", str(context.exception)) + + @patch("databricks_com.utils.get_oauth_access_token", return_value=("new_oauth_token", 3600)) + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + @patch("databricks_com.utils.get_proxy_uri", return_value=None) + def test_get_api_response_oauth_refresh_still_fails(self, mock_proxy, mock_conf, mock_session, mock_version, mock_refresh): + 
"""Test OAuth M2M token refresh when second request still fails.""" + db_com = import_module('databricks_com') + mock_conf.return_value = { + "databricks_instance": "123", + "auth_type": "OAUTH_M2M", + "oauth_access_token": "token", + "oauth_client_id": "client_id", + "oauth_client_secret": "client_secret", + "oauth_token_expiration": "9999999999.0", + "proxy_uri": None + } + obj = db_com.DatabricksClient("account_name", "session_key") + obj.session.post.side_effect = [Response(403), Response(403)] + with self.assertRaises(Exception) as context: + resp = obj.databricks_api("post", "endpoint", args="123", data={"p1": "v1"}) + self.assertEqual(obj.session.post.call_count, 2) + self.assertEqual("Invalid access token. Please enter the valid access token.", str(context.exception)) + + @patch("databricks_com.utils.get_oauth_access_token", return_value=("new_oauth_token", 300)) + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + @patch("databricks_com.utils.get_proxy_uri", return_value=None) + @patch("time.time", return_value=1000000.0) + def test_proactive_oauth_token_refresh(self, mock_time, mock_proxy, mock_conf, mock_session, mock_version, mock_refresh): + """Test proactive OAuth token refresh when token is about to expire.""" + db_com = import_module('databricks_com') + # Token expires in 4 minutes (240 seconds) - should trigger refresh + mock_conf.return_value = { + "databricks_instance": "123", + "auth_type": "OAUTH_M2M", + "oauth_access_token": "token", + "oauth_client_id": "client_id", + "oauth_client_secret": "client_secret", + "oauth_token_expiration": "1000240.0", # 240 seconds from now + "proxy_uri": None + } + obj = db_com.DatabricksClient("account_name", "session_key") + obj.session.post.return_value = Response(200) + + resp = obj.databricks_api("post", "endpoint", args="123", 
data={"p1": "v1"}) + + # Token should have been refreshed proactively + mock_refresh.assert_called_once() + self.assertEqual(resp, {"status_code": 200}) + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + @patch("time.time", return_value=1000000.0) + def test_should_refresh_oauth_token_within_threshold(self, mock_time, mock_conf, mock_session, mock_version): + """Test should_refresh_oauth_token returns True when token expires within 5 minutes.""" + db_com = import_module('databricks_com') + mock_conf.return_value = { + "databricks_instance": "123", + "auth_type": "OAUTH_M2M", + "oauth_access_token": "token", + "oauth_client_id": "client_id", + "oauth_client_secret": "client_secret", + "oauth_token_expiration": "1000200.0", # 200 seconds from now (< 5 min) + "proxy_uri": None + } + obj = db_com.DatabricksClient("account_name", "session_key") + + self.assertTrue(obj.should_refresh_oauth_token()) + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + @patch("time.time", return_value=1000000.0) + def test_should_refresh_oauth_token_outside_threshold(self, mock_time, mock_conf, mock_session, mock_version): + """Test should_refresh_oauth_token returns False when token has > 5 minutes validity.""" + db_com = import_module('databricks_com') + mock_conf.return_value = { + "databricks_instance": "123", + "auth_type": "OAUTH_M2M", + "oauth_access_token": "token", + "oauth_client_id": "client_id", + "oauth_client_secret": "client_secret", + "oauth_token_expiration": "1003700.0", # 3700 seconds from now (> 5 min) + "proxy_uri": None + } + obj = db_com.DatabricksClient("account_name", "session_key") + + 
self.assertFalse(obj.should_refresh_oauth_token()) + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + def test_should_refresh_oauth_token_not_oauth(self, mock_conf, mock_session, mock_version): + """Test should_refresh_oauth_token returns False for non-OAuth auth types.""" + db_com = import_module('databricks_com') + mock_conf.return_value = { + "databricks_instance": "123", + "auth_type": "PAT", + "databricks_pat": "token", + "proxy_uri": None + } + obj = db_com.DatabricksClient("account_name", "session_key") + + # PAT auth doesn't have oauth_token_expiration attribute + self.assertFalse(obj.should_refresh_oauth_token()) + + # ========================================================================= + # Connection Error Handling Tests + # ========================================================================= + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + def test_get_api_response_connection_error(self, mock_conf, mock_session, mock_version): + """Test handling of connection errors.""" + db_com = import_module('databricks_com') + mock_conf.return_value = {"databricks_instance": "123", "auth_type": "PAT", "databricks_pat": "token", "proxy_uri": None} + obj = db_com.DatabricksClient("account_name", "session_key") + obj.session.get.side_effect = Exception("Connection refused") + + with self.assertRaises(Exception) as context: + obj.databricks_api("get", "endpoint") + + self.assertEqual("Connection refused", str(context.exception)) + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + 
@patch("databricks_com.utils.get_databricks_configs", autospec=True) + def test_get_api_response_400_error(self, mock_conf, mock_session, mock_version): + """Test handling of 400 Bad Request response.""" + db_com = import_module('databricks_com') + mock_conf.return_value = {"databricks_instance": "123", "auth_type": "PAT", "databricks_pat": "token", "proxy_uri": None} + obj = db_com.DatabricksClient("account_name", "session_key") + obj.session.post.return_value = Response(400, {"message": "Invalid parameter"}) + + with self.assertRaises(Exception) as context: + obj.databricks_api("post", "endpoint", data={"p1": "v1"}) + + self.assertEqual("Invalid parameter", str(context.exception)) + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + def test_get_api_response_500_error(self, mock_conf, mock_session, mock_version): + """Test handling of 500 Internal Server Error response.""" + db_com = import_module('databricks_com') + mock_conf.return_value = {"databricks_instance": "123", "auth_type": "PAT", "databricks_pat": "token", "proxy_uri": None} + obj = db_com.DatabricksClient("account_name", "session_key") + obj.session.get.return_value = Response(500, {"error": "Database connection failed"}) + + with self.assertRaises(Exception) as context: + obj.databricks_api("get", "endpoint") + + self.assertEqual("Database connection failed", str(context.exception)) + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + def test_get_api_response_404_error(self, mock_conf, mock_session, mock_version): + """Test handling of 404 Not Found response.""" + db_com = import_module('databricks_com') + mock_conf.return_value = 
{"databricks_instance": "123", "auth_type": "PAT", "databricks_pat": "token", "proxy_uri": None} + obj = db_com.DatabricksClient("account_name", "session_key") + obj.session.get.return_value = Response(404, {}) + + with self.assertRaises(Exception) as context: + obj.databricks_api("get", "endpoint") + + self.assertEqual("Invalid API endpoint.", str(context.exception)) + + # ========================================================================= + # Cancel Endpoint Tests + # ========================================================================= + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + def test_get_api_response_cancel_endpoint(self, mock_conf, mock_session, mock_version): + """Test cancel endpoint returns tuple with status code.""" + db_com = import_module('databricks_com') + mock_conf.return_value = {"databricks_instance": "123", "auth_type": "PAT", "databricks_pat": "token", "proxy_uri": None} + obj = db_com.DatabricksClient("account_name", "session_key") + obj.session.post.return_value = Response(200, {"cancelled": True}) + + resp, status_code = obj.databricks_api("post", "/api/cancel", data={"run_id": "123"}) + + self.assertEqual(resp, {"cancelled": True}) + self.assertEqual(status_code, 200) + + # ========================================================================= + # OAuth M2M Client Initialization Tests + # ========================================================================= + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + def test_get_object_oauth_m2m(self, mock_conf, mock_session, mock_version): + """Test DatabricksClient initialization with OAuth M2M auth.""" + db_com = 
import_module('databricks_com') + db_com._LOGGER = MagicMock() + mock_conf.return_value = { + "databricks_instance": "123", + "auth_type": "OAUTH_M2M", + "oauth_access_token": "oauth_token", + "oauth_client_id": "client_id", + "oauth_client_secret": "client_secret", + "oauth_token_expiration": "9999999999.0", + "proxy_uri": None + } + obj = db_com.DatabricksClient("account_name", "session_key") + + self.assertIsInstance(obj, db_com.DatabricksClient) + self.assertEqual(obj.auth_type, "OAUTH_M2M") + self.assertEqual(obj.databricks_token, "oauth_token") + self.assertEqual(obj.oauth_client_id, "client_id") + self.assertEqual(obj.oauth_client_secret, "client_secret") + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + def test_get_object_account_not_found(self, mock_conf, mock_session, mock_version): + """Test DatabricksClient initialization when account not found.""" + db_com = import_module('databricks_com') + mock_conf.return_value = None + + with self.assertRaises(Exception) as context: + obj = db_com.DatabricksClient("nonexistent_account", "session_key") + + self.assertEqual( + "Account 'nonexistent_account' not found. 
Please provide valid Databricks account.", + str(context.exception) + ) + + # ========================================================================= + # External API Tests + # ========================================================================= + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + def test_external_api_get(self, mock_conf, mock_session, mock_version): + """Test external API GET request.""" + db_com = import_module('databricks_com') + mock_conf.return_value = {"databricks_instance": "123", "auth_type": "PAT", "databricks_pat": "token", "proxy_uri": None} + obj = db_com.DatabricksClient("account_name", "session_key") + obj.external_session.get.return_value = Response(200, {"data": "test"}) + + resp = obj.external_api("get", "https://external.api.com/endpoint") + + self.assertEqual(obj.external_session.get.call_count, 1) + self.assertEqual(resp, {"data": "test"}) + + @patch("solnlib.server_info", return_value=MagicMock()) + @patch("databricks_com.DatabricksClient.get_requests_retry_session", return_value=MagicMock()) + @patch("databricks_com.utils.get_databricks_configs", autospec=True) + def test_external_api_post(self, mock_conf, mock_session, mock_version): + """Test external API POST request.""" + db_com = import_module('databricks_com') + mock_conf.return_value = {"databricks_instance": "123", "auth_type": "PAT", "databricks_pat": "token", "proxy_uri": None} + obj = db_com.DatabricksClient("account_name", "session_key") + obj.external_session.post.return_value = Response(200, {"result": "success"}) + + resp = obj.external_api("post", "https://external.api.com/endpoint", data={"key": "value"}) + + self.assertEqual(obj.external_session.post.call_count, 1) + self.assertEqual(resp, {"result": "success"}) + diff --git a/tests/test_databricks_common_utils.py 
b/tests/test_databricks_common_utils.py index fac5484..659fc21 100644 --- a/tests/test_databricks_common_utils.py +++ b/tests/test_databricks_common_utils.py @@ -48,8 +48,8 @@ def test_get_user_agent(self, mock_user): def test_get_current_user(self, mock_common, mock_json, mock_jobs, mock_client): db_utils = import_module('databricks_common_utils') mock_common.return_value = 8089 - mock_client.retrun_value = MagicMock() - mock_jobs.retrun_value = '[{"username": "db_admin"}]' + mock_client.return_value = MagicMock() + mock_jobs.return_value = '[{"username": "db_admin"}]' mock_json.return_value = [{"username": "db_admin"}] response = db_utils.get_current_user("session_key") self.assertEqual(response, "db_admin") @@ -139,64 +139,490 @@ def test_format_to_json_parameters_exception(self): self.assertEqual( "Invalid format for parameter notebook_params. Provide the value in 'param1=val1||param2=val2' format.", str(context.exception)) - @patch("databricks_common_utils.get_proxy_uri") - @patch("databricks_common_utils.get_databricks_configs") - @patch("databricks_common_utils.save_databricks_aad_access_token") - @patch("databricks_common_utils.requests.post") - def test_get_aad_access_token(self, mock_post, mock_save, mock_conf, mock_proxy): - db_utils = import_module('databricks_common_utils') - mock_save.side_effect = MagicMock - mock_conf. 
return_value = MagicMock() - mock_proxy.return_value = MagicMock() - mock_post.return_value.json.return_value = {"access_token": "123"} - mock_post.return_value.status_code = 200 - return_val = db_utils.get_aad_access_token("session_key", "user_agent", "account_name", "aad_client_id", "aad_client_secret") - self.assertEqual (return_val, "123") - - @patch("databricks_common_utils.get_proxy_uri") - @patch("databricks_common_utils.get_databricks_configs") - @patch("databricks_common_utils.save_databricks_aad_access_token") - @patch("databricks_common_utils.requests.post") - def test_get_aad_access_token_200(self, mock_post, mock_save, mock_conf, mock_proxy): - db_utils = import_module('databricks_common_utils') - mock_save.return_value = MagicMock() - mock_conf. return_value = MagicMock() - mock_proxy.return_value = MagicMock() - mock_post.return_value.json.return_value = {"access_token": "123"} - mock_post.return_value.status_code = 200 - return_val = db_utils.get_aad_access_token("session_key", "user_agent", "account_name", "aad_client_id", "aad_client_secret") - self.assertEqual (return_val, "123") - - @patch("databricks_common_utils.get_proxy_uri") - @patch("databricks_common_utils.get_databricks_configs") + @patch("databricks_common_utils.get_proxy_uri") + @patch("databricks_common_utils.get_databricks_configs") @patch("databricks_common_utils.save_databricks_aad_access_token") @patch("databricks_common_utils.requests.post") def test_get_aad_access_token_200(self, mock_post, mock_save, mock_conf, mock_proxy): db_utils = import_module('databricks_common_utils') mock_save.return_value = MagicMock() - mock_conf. 
return_value = MagicMock() + mock_conf.return_value = MagicMock() mock_proxy.return_value = MagicMock() mock_post.return_value.json.return_value = {"access_token": "123"} - mock_post.return_value.status_code = 200 + mock_post.return_value.status_code = 200 return_val = db_utils.get_aad_access_token("session_key", "user_agent", "account_name", "aad_client_id", "aad_client_secret") - self.assertEqual (return_val, "123") + self.assertEqual(return_val, "123") - @patch("databricks_common_utils.get_proxy_uri") - @patch("databricks_common_utils.get_databricks_configs") + @patch("databricks_common_utils.get_proxy_uri") + @patch("databricks_common_utils.get_databricks_configs") @patch("databricks_common_utils.save_databricks_aad_access_token") @patch("databricks_common_utils.requests.post") def test_get_aad_access_token_403(self, mock_post, mock_save, mock_conf, mock_proxy): db_utils = import_module('databricks_common_utils') mock_save.return_value = MagicMock() - mock_conf. return_value = MagicMock() - mock_proxy.side_effect = MagicMock() + mock_conf.return_value = MagicMock() + mock_proxy.return_value = MagicMock() mock_post.side_effect = [Response(403), Response(403), Response(403)] return_val = db_utils.get_aad_access_token("session_key", "user_agent", "account_name", "aad_client_id", "aad_client_secret", retry=3) - self.assertEqual (return_val, ("Client secret may have expired. Please configure a valid Client secret.", False)) + self.assertEqual(return_val, ("Client secret may have expired. 
Please configure a valid Client secret.", False)) self.assertEqual(mock_post.call_count, 3) + @patch("databricks_common_utils.rest.simpleRequest") + def test_save_databricks_oauth_access_token(self, mock_request): + """Test saving OAuth access token successfully.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + mock_request.return_value = (200, '{}') + + db_utils.save_databricks_oauth_access_token("account_name", "session_key", "access_token", 3600, "client_secret") + + self.assertEqual(db_utils._LOGGER.info.call_count, 2) + db_utils._LOGGER.info.assert_called_with("Saved OAuth access token successfully.") + + @patch("databricks_common_utils.rest.simpleRequest") + def test_save_databricks_oauth_access_token_exception(self, mock_request): + """Test saving OAuth access token with exception.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + mock_request.side_effect = Exception("test error") + + with self.assertRaises(Exception) as context: + db_utils.save_databricks_oauth_access_token("account_name", "session_key", "access_token", 3600, "client_secret") + + self.assertEqual(db_utils._LOGGER.error.call_count, 1) + db_utils._LOGGER.error.assert_called_with("Exception while saving OAuth access token: test error") + self.assertEqual("Exception while saving OAuth access token.", str(context.exception)) + + @patch("databricks_common_utils.get_current_user") + @patch("databricks_common_utils.requests.post") + def test_get_oauth_access_token_success(self, mock_post, mock_user): + """Test successful OAuth M2M token acquisition.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + mock_user.return_value = "test_user" + + mock_response = MagicMock() + mock_response.json.return_value = {"access_token": "oauth_token_123", "expires_in": 3600} + mock_response.status_code = 200 + mock_post.return_value = mock_response + + token, expires_in = 
db_utils.get_oauth_access_token( + "session_key", "account_name", "databricks.instance.com", + "client_id", "client_secret" + ) + + self.assertEqual(token, "oauth_token_123") + self.assertEqual(expires_in, 3600) + + @patch("databricks_common_utils.get_current_user") + @patch("databricks_common_utils.save_databricks_oauth_access_token") + @patch("databricks_common_utils.requests.post") + def test_get_oauth_access_token_with_conf_update(self, mock_post, mock_save, mock_user): + """Test OAuth token acquisition with configuration update.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + mock_user.return_value = "test_user" + mock_save.return_value = None + + mock_response = MagicMock() + mock_response.json.return_value = {"access_token": "oauth_token_123", "expires_in": 3600} + mock_response.status_code = 200 + mock_post.return_value = mock_response + + token, expires_in = db_utils.get_oauth_access_token( + "session_key", "account_name", "databricks.instance.com", + "client_id", "client_secret", conf_update=True + ) + + self.assertEqual(token, "oauth_token_123") + mock_save.assert_called_once() + + @patch("databricks_common_utils.get_current_user") + @patch("databricks_common_utils.requests.post") + def test_get_oauth_access_token_with_proxy_use_for_oauth_true(self, mock_post, mock_user): + """Test OAuth token acquisition skips proxy when use_for_oauth is true.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + mock_user.return_value = "test_user" + + mock_response = MagicMock() + mock_response.json.return_value = {"access_token": "oauth_token_123", "expires_in": 3600} + mock_response.status_code = 200 + mock_post.return_value = mock_response + + proxy_settings = {"http": "http://proxy:8080", "https": "http://proxy:8080", "use_for_oauth": "1"} + + token, expires_in = db_utils.get_oauth_access_token( + "session_key", "account_name", "databricks.instance.com", + "client_id", 
"client_secret", proxy_settings=proxy_settings + ) + + # Should be called with None proxy when use_for_oauth is true + call_args = mock_post.call_args + self.assertIsNone(call_args[1]['proxies']) + + @patch("databricks_common_utils.get_current_user") + @patch("databricks_common_utils.requests.post") + def test_get_oauth_access_token_with_proxy_use_for_oauth_false(self, mock_post, mock_user): + """Test OAuth token acquisition uses proxy when use_for_oauth is false.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + mock_user.return_value = "test_user" + + mock_response = MagicMock() + mock_response.json.return_value = {"access_token": "oauth_token_123", "expires_in": 3600} + mock_response.status_code = 200 + mock_post.return_value = mock_response + + proxy_settings = {"http": "http://proxy:8080", "https": "http://proxy:8080", "use_for_oauth": "0"} + + token, expires_in = db_utils.get_oauth_access_token( + "session_key", "account_name", "databricks.instance.com", + "client_id", "client_secret", proxy_settings=proxy_settings + ) + + # Should be called with proxy settings (without use_for_oauth key) + call_args = mock_post.call_args + self.assertEqual(call_args[1]['proxies'], {"http": "http://proxy:8080", "https": "http://proxy:8080"}) + + @patch("databricks_common_utils.get_current_user") + @patch("databricks_common_utils.requests.post") + def test_get_oauth_access_token_invalid_client(self, mock_post, mock_user): + """Test OAuth token acquisition with invalid_client error.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + mock_user.return_value = "test_user" + + mock_response = MagicMock() + mock_response.json.return_value = {"error": "invalid_client"} + mock_response.status_code = 400 + mock_response.raise_for_status.side_effect = Exception("Invalid client") + mock_post.return_value = mock_response + + msg, status = db_utils.get_oauth_access_token( + "session_key", "account_name", 
"databricks.instance.com", + "client_id", "client_secret" + ) + + self.assertEqual(msg, "Invalid OAuth Client ID or Client Secret provided.") + self.assertFalse(status) + + @patch("databricks_common_utils.get_current_user") + @patch("databricks_common_utils.requests.post") + def test_get_oauth_access_token_unauthorized_client(self, mock_post, mock_user): + """Test OAuth token acquisition with unauthorized_client error.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + mock_user.return_value = "test_user" + + mock_response = MagicMock() + mock_response.json.return_value = {"error": "unauthorized_client"} + mock_response.status_code = 401 + mock_response.raise_for_status.side_effect = Exception("Unauthorized") + mock_post.return_value = mock_response + + msg, status = db_utils.get_oauth_access_token( + "session_key", "account_name", "databricks.instance.com", + "client_id", "client_secret" + ) + + self.assertEqual(msg, "Service principal is not authorized for this workspace.") + self.assertFalse(status) + + @patch("databricks_common_utils.get_current_user") + @patch("databricks_common_utils.requests.post") + def test_get_oauth_access_token_status_code_error(self, mock_post, mock_user): + """Test OAuth token acquisition with status code based error.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + mock_user.return_value = "test_user" + + mock_response = MagicMock() + mock_response.json.return_value = {} + mock_response.status_code = 429 + mock_response.raise_for_status.side_effect = Exception("Too many requests") + mock_post.return_value = mock_response + + msg, status = db_utils.get_oauth_access_token( + "session_key", "account_name", "databricks.instance.com", + "client_id", "client_secret" + ) + + self.assertEqual(msg, "API limit exceeded. 
Please try again after some time.") + self.assertFalse(status) + + @patch("databricks_common_utils.get_current_user") + @patch("databricks_common_utils.requests.post") + def test_get_oauth_access_token_generic_error(self, mock_post, mock_user): + """Test OAuth token acquisition with generic HTTP error.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + mock_user.return_value = "test_user" + + mock_response = MagicMock() + mock_response.json.return_value = {} + mock_response.status_code = 502 + mock_response.raise_for_status.side_effect = Exception("Bad Gateway") + mock_post.return_value = mock_response + + msg, status = db_utils.get_oauth_access_token( + "session_key", "account_name", "databricks.instance.com", + "client_id", "client_secret" + ) + + self.assertIn("Unable to validate OAuth credentials", msg) + self.assertFalse(status) + + @patch("databricks_common_utils.get_current_user") + @patch("databricks_common_utils.requests.post") + def test_get_oauth_access_token_connection_error(self, mock_post, mock_user): + """Test OAuth token acquisition with connection error (no response).""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + mock_user.return_value = "test_user" + + mock_post.side_effect = Exception("Connection refused") + + msg, status = db_utils.get_oauth_access_token( + "session_key", "account_name", "databricks.instance.com", + "client_id", "client_secret" + ) + + self.assertIn("Unable to request Databricks instance", msg) + self.assertFalse(status) + + @patch("databricks_common_utils.get_current_user") + @patch("databricks_common_utils.requests.post") + def test_get_oauth_access_token_retry_mechanism(self, mock_post, mock_user): + """Test OAuth token acquisition retry mechanism.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + mock_user.return_value = "test_user" + + mock_post.side_effect = Exception("Connection timeout") + + 
msg, status = db_utils.get_oauth_access_token( + "session_key", "account_name", "databricks.instance.com", + "client_id", "client_secret", retry=3 + ) + + self.assertEqual(mock_post.call_count, 3) + self.assertFalse(status) + + @patch("databricks_common_utils.get_proxy_configuration") + @patch("databricks_common_utils.get_proxy_clear_password") + def test_get_proxy_uri_with_special_chars(self, mock_pwd, mock_conf): + """Test proxy URI formatting with special characters in username/password.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + + mock_conf.return_value = { + "proxy_enabled": 1, + "proxy_type": "http", + "proxy_url": "proxy.example.com", + "proxy_port": 8080, + "proxy_username": "user@domain.com", + "use_for_oauth": "0" + } + mock_pwd.return_value = "p@ssw:rd!" + + proxy_uri = db_utils.get_proxy_uri("session_key") + + # Verify that special characters are URL encoded + self.assertIn("user%40domain.com", proxy_uri['http']) + self.assertIn("p%40ssw%3Ard%21", proxy_uri['http']) + self.assertEqual(proxy_uri['use_for_oauth'], "0") + + @patch("databricks_common_utils.rest.simpleRequest") + def test_get_databricks_configs_with_proxy_and_credentials(self, mock_request): + """Test get_databricks_configs with proxy enabled and credentials.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + + config_data = { + "databricks_instance": "test.databricks.com", + "auth_type": "OAuth", + "proxy_enabled": "1", + "proxy_url": "proxy.example.com", + "proxy_port": "8080", + "proxy_type": "http", + "proxy_username": "proxyuser", + "proxy_password": "proxypass", + "use_for_oauth": "1" + } + + mock_request.return_value = (200, json.dumps(config_data)) + + response = db_utils.get_databricks_configs("session_key", "account_name") + + self.assertIn("proxy_uri", response) + self.assertIn("proxyuser", response["proxy_uri"]["http"]) + self.assertIn("proxypass", response["proxy_uri"]["http"]) + 
self.assertEqual(response["proxy_uri"]["use_for_oauth"], "1") + + @patch("databricks_common_utils.rest.simpleRequest") + def test_get_databricks_configs_with_proxy_special_chars(self, mock_request): + """Test get_databricks_configs with special characters in proxy credentials.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + + config_data = { + "databricks_instance": "test.databricks.com", + "auth_type": "OAuth", + "proxy_enabled": "1", + "proxy_url": "proxy.example.com", + "proxy_port": "3128", + "proxy_type": "https", + "proxy_username": "admin@corp", + "proxy_password": "p@ss:123", + "use_for_oauth": "0" + } + + mock_request.return_value = (200, json.dumps(config_data)) + + response = db_utils.get_databricks_configs("session_key", "account_name") + + # Verify special characters are URL encoded + self.assertIn("admin%40corp", response["proxy_uri"]["http"]) + self.assertIn("p%40ss%3A123", response["proxy_uri"]["http"]) + + @patch("databricks_common_utils.rest.simpleRequest") + def test_get_databricks_configs_with_proxy_no_credentials(self, mock_request): + """Test get_databricks_configs with proxy but no credentials.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + + config_data = { + "databricks_instance": "test.databricks.com", + "auth_type": "PAT", + "proxy_enabled": "1", + "proxy_url": "proxy.example.com", + "proxy_port": "8080", + "proxy_type": "http", + "use_for_oauth": "0" + } + + mock_request.return_value = (200, json.dumps(config_data)) + + response = db_utils.get_databricks_configs("session_key", "account_name") + + self.assertIn("proxy_uri", response) + # Should not contain @ symbol (no credentials) + self.assertNotIn("@", response["proxy_uri"]["http"]) + self.assertIn("proxy.example.com:8080", response["proxy_uri"]["http"]) + + @patch("databricks_common_utils.rest.simpleRequest") + def test_get_databricks_configs_exception(self, mock_request): + """Test 
get_databricks_configs with exception.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + + mock_request.side_effect = Exception("Connection error") + + response = db_utils.get_databricks_configs("session_key", "account_name") + + self.assertIsNone(response) + self.assertEqual(db_utils._LOGGER.error.call_count, 1) + + @patch("databricks_common_utils.get_current_user") + @patch("databricks_common_utils.requests.post") + def test_get_aad_access_token_with_conf_update(self, mock_post, mock_user): + """Test AAD token acquisition with configuration update.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + mock_user.return_value = "test_user" + + mock_response = MagicMock() + mock_response.json.return_value = {"access_token": "aad_token_123"} + mock_response.status_code = 200 + mock_post.return_value = mock_response + + with patch("databricks_common_utils.save_databricks_aad_access_token") as mock_save: + token = db_utils.get_aad_access_token( + "session_key", "account_name", "tenant_id", + "client_id", "client_secret", conf_update=True + ) + + self.assertEqual(token, "aad_token_123") + mock_save.assert_called_once() + + @patch("databricks_common_utils.get_current_user") + @patch("databricks_common_utils.requests.post") + def test_get_aad_access_token_with_proxy_settings(self, mock_post, mock_user): + """Test AAD token acquisition with proxy settings removes use_for_oauth key.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + mock_user.return_value = "test_user" + + mock_response = MagicMock() + mock_response.json.return_value = {"access_token": "aad_token_123"} + mock_response.status_code = 200 + mock_post.return_value = mock_response + + proxy_settings = {"http": "http://proxy:8080", "https": "http://proxy:8080", "use_for_oauth": "1"} + + token = db_utils.get_aad_access_token( + "session_key", "account_name", "tenant_id", + "client_id", 
"client_secret", proxy_settings=proxy_settings + ) + + # Verify use_for_oauth key is removed from proxy_settings + call_args = mock_post.call_args + self.assertNotIn("use_for_oauth", call_args[1]['proxies']) + + @patch("databricks_common_utils.get_current_user") + @patch("databricks_common_utils.requests.post") + def test_get_aad_access_token_error_code_handling(self, mock_post, mock_user): + """Test AAD token acquisition with error_codes in response.""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + mock_user.return_value = "test_user" + + mock_response = MagicMock() + mock_response.json.return_value = {"error_codes": [700016]} + mock_response.status_code = 400 + mock_response.raise_for_status.side_effect = Exception("Invalid client ID") + mock_post.return_value = mock_response + + msg, status = db_utils.get_aad_access_token( + "session_key", "account_name", "tenant_id", + "client_id", "client_secret" + ) + + self.assertEqual(msg, "Invalid Client ID provided.") + self.assertFalse(status) + + @patch("databricks_common_utils.get_current_user") + @patch("databricks_common_utils.requests.post") + def test_get_aad_access_token_connection_error(self, mock_post, mock_user): + """Test AAD token acquisition with connection error (no response object).""" + db_utils = import_module('databricks_common_utils') + db_utils._LOGGER = MagicMock() + mock_user.return_value = "test_user" + + mock_post.side_effect = Exception("Connection refused") + + msg, status = db_utils.get_aad_access_token( + "session_key", "account_name", "tenant_id", + "client_id", "client_secret" + ) + + self.assertIn("Unable to request Databricks instance", msg) + self.assertFalse(status) + + @patch("databricks_common_utils.CredentialManager", autospec=True) + def test_get_proxy_clear_password_not_exist(self, mock_manager): + """Test get_proxy_clear_password when credential doesn't exist.""" + db_utils = import_module('databricks_common_utils') + from 
solnlib.credentials import CredentialNotExistException + + mock_manager.return_value.get_password.side_effect = CredentialNotExistException("Not found") + + pwd = db_utils.get_proxy_clear_password("session_key") + + self.assertIsNone(pwd) + + - - diff --git a/tests/test_databricks_get_credentials.py b/tests/test_databricks_get_credentials.py index 19f3d74..48d3af2 100644 --- a/tests/test_databricks_get_credentials.py +++ b/tests/test_databricks_get_credentials.py @@ -86,7 +86,7 @@ def test_handle_save_access_token_failure(self): obj1 = db_cm.DatabricksGetCredentials("command_line", "command_args") result = obj1.handle(input_string) - assert result["payload"] == "Databricks Error: Exception while saving AAD access token: Failed to save AAD access token" + assert result["payload"] == "Databricks Error: Exception while saving access token: Failed to save AAD access token" assert result["status"] == 500 @patch("databricks_get_credentials.rest.simpleRequest") @@ -106,3 +106,273 @@ def test_handle_retrieve_config_success(self, mock_request): mock_request.return_value = (200, json.dumps({"entry":[{"content":{"auth_type":"PAT", "databricks_instance":"http", "cluster_name":"test"}},"test"]})) result = obj1.handle(input_string) db_cm._LOGGER.debug.assert_called_with("Account configurations read successfully from account.conf .") + + # ========================================================================= + # OAuth Token Save Tests + # ========================================================================= + + def test_handle_save_oauth_access_token_success(self): + """Test successful saving of OAuth access token.""" + db_cm = import_module("databricks_get_credentials") + input_string = json.dumps({ + "system_authtoken": "dummy_token", + "form": { + "name": "test", + "update_token": "1", + "oauth_client_secret": "oauth_client_secret", + "oauth_access_token": "oauth_access_token", + "oauth_token_expiration": "1234567890.0" + } + }) + + # mock the CredentialManager + 
credential_manager_mock = MagicMock() + credential_manager_mock.set_password.return_value = "Saved OAuth access token successfully." + db_cm.CredentialManager = MagicMock(return_value=credential_manager_mock) + + obj1 = db_cm.DatabricksGetCredentials("command_line", "command_args") + result = obj1.handle(input_string) + + assert result["payload"] == "Saved OAuth access token successfully." + assert result["status"] == 200 + + # Verify the credential manager was called with correct parameters + credential_manager_mock.set_password.assert_called_once() + + def test_handle_save_oauth_access_token_failure(self): + """Test failure when saving OAuth access token.""" + db_cm = import_module("databricks_get_credentials") + input_string = json.dumps({ + "system_authtoken": "dummy_token", + "form": { + "name": "test", + "update_token": "1", + "oauth_client_secret": "oauth_client_secret", + "oauth_access_token": "oauth_access_token", + "oauth_token_expiration": "1234567890.0" + } + }) + + # mock the CredentialManager to raise an exception + credential_manager_mock = MagicMock() + credential_manager_mock.set_password.side_effect = Exception("Failed to save OAuth access token") + db_cm.CredentialManager = MagicMock(return_value=credential_manager_mock) + + obj1 = db_cm.DatabricksGetCredentials("command_line", "command_args") + result = obj1.handle(input_string) + + assert result["payload"] == "Databricks Error: Exception while saving access token: Failed to save OAuth access token" + assert result["status"] == 500 + + def test_handle_no_token_data_provided(self): + """Test error when update_token is set but no token data provided.""" + db_cm = import_module("databricks_get_credentials") + input_string = json.dumps({ + "system_authtoken": "dummy_token", + "form": { + "name": "test", + "update_token": "1" + # No aad_access_token or oauth_access_token + } + }) + + obj1 = db_cm.DatabricksGetCredentials("command_line", "command_args") + result = obj1.handle(input_string) + + assert 
"No token data provided for update" in result["payload"] + assert result["status"] == 500 + + # ========================================================================= + # Retrieve Configurations Tests (OAuth M2M) + # ========================================================================= + + @patch("databricks_get_credentials.rest.simpleRequest") + @patch("databricks_get_credentials.CredentialManager") + def test_handle_retrieve_oauth_config(self, mock_cred_manager, mock_request): + """Test retrieving OAuth M2M configuration.""" + db_cm = import_module("databricks_get_credentials") + db_cm._LOGGER = MagicMock() + + # Mock account manager + account_manager_mock = MagicMock() + account_manager_mock.get_password.return_value = json.dumps({ + "oauth_client_secret": "oauth_secret", + "oauth_access_token": "oauth_token", + "oauth_token_expiration": "1234567890.0" + }) + + # Mock proxy manager + proxy_manager_mock = MagicMock() + proxy_manager_mock.get_password.return_value = json.dumps({ + "proxy_password": "proxy_pass" + }) + + mock_cred_manager.side_effect = [account_manager_mock, proxy_manager_mock] + + input_string = json.dumps({ + "system_authtoken": "dummy_token", + "form": { + "name": "test_account" + } + }) + + # Mock REST calls + def mock_request_side_effect(*args, **kwargs): + if "ta_databricks_account" in args[0]: + return (200, json.dumps({ + "entry": [{ + "content": { + "auth_type": "OAUTH_M2M", + "databricks_instance": "test.databricks.azure.net", + "oauth_client_id": "oauth_client_id", + "config_for_dbquery": "warehouse", + "warehouse_id": "warehouse123", + "cluster_name": None + } + }] + })) + elif "proxy" in args[0]: + return (200, json.dumps({ + "entry": [{ + "content": { + "proxy_enabled": "0", + "proxy_password": "" + } + }] + })) + elif "additional_parameters" in args[0]: + return (200, json.dumps({ + "entry": [{ + "content": { + "admin_command_timeout": "300", + "query_result_limit": "1000", + "index": "main", + "thread_count": "4" + } + }] + 
})) + + mock_request.side_effect = mock_request_side_effect + + obj1 = db_cm.DatabricksGetCredentials("command_line", "command_args") + result = obj1.handle(input_string) + + assert result["status"] == 200 + payload = result["payload"] + assert payload["auth_type"] == "OAUTH_M2M" + assert payload["oauth_client_id"] == "oauth_client_id" + assert payload["oauth_access_token"] == "oauth_token" + assert payload["oauth_token_expiration"] == "1234567890.0" + + @patch("databricks_get_credentials.rest.simpleRequest") + @patch("databricks_get_credentials.CredentialManager") + def test_handle_retrieve_pat_config(self, mock_cred_manager, mock_request): + """Test retrieving PAT configuration.""" + db_cm = import_module("databricks_get_credentials") + db_cm._LOGGER = MagicMock() + + # Mock account manager + account_manager_mock = MagicMock() + account_manager_mock.get_password.return_value = json.dumps({ + "databricks_pat": "pat_token_value" + }) + + # Mock proxy manager + proxy_manager_mock = MagicMock() + proxy_manager_mock.get_password.return_value = json.dumps({}) + + mock_cred_manager.side_effect = [account_manager_mock, proxy_manager_mock] + + input_string = json.dumps({ + "system_authtoken": "dummy_token", + "form": { + "name": "test_account" + } + }) + + # Mock REST calls + def mock_request_side_effect(*args, **kwargs): + if "ta_databricks_account" in args[0]: + return (200, json.dumps({ + "entry": [{ + "content": { + "auth_type": "PAT", + "databricks_instance": "test.databricks.azure.net", + "config_for_dbquery": "cluster", + "cluster_name": "test_cluster", + "warehouse_id": None + } + }] + })) + elif "proxy" in args[0]: + return (200, json.dumps({ + "entry": [{ + "content": { + "proxy_enabled": "0", + "proxy_password": "" + } + }] + })) + elif "additional_parameters" in args[0]: + return (200, json.dumps({ + "entry": [{ + "content": { + "admin_command_timeout": "300", + "query_result_limit": "1000", + "index": "main", + "thread_count": "4" + } + }] + })) + + 
mock_request.side_effect = mock_request_side_effect + + obj1 = db_cm.DatabricksGetCredentials("command_line", "command_args") + result = obj1.handle(input_string) + + assert result["status"] == 200 + payload = result["payload"] + assert payload["auth_type"] == "PAT" + assert payload["databricks_pat"] == "pat_token_value" + + @patch("databricks_get_credentials.rest.simpleRequest") + def test_handle_retrieve_config_error(self, mock_request): + """Test error handling when retrieving configuration fails.""" + db_cm = import_module("databricks_get_credentials") + db_cm._LOGGER = MagicMock() + + mock_request.side_effect = Exception("Connection error") + + input_string = json.dumps({ + "system_authtoken": "dummy_token", + "form": { + "name": "test_account" + } + }) + + obj1 = db_cm.DatabricksGetCredentials("command_line", "command_args") + result = obj1.handle(input_string) + + assert result["status"] == 500 + assert "Databricks Error: Error occured while retrieving account and proxy configurations" in result["payload"] + + # ========================================================================= + # handleStream and done Tests + # ========================================================================= + + def test_handleStream_not_implemented(self): + """Test that handleStream raises NotImplementedError.""" + db_cm = import_module("databricks_get_credentials") + obj1 = db_cm.DatabricksGetCredentials("command_line", "command_args") + + with self.assertRaises(NotImplementedError): + obj1.handleStream(None, "input_string") + + def test_done_method(self): + """Test that done method executes without error.""" + db_cm = import_module("databricks_get_credentials") + obj1 = db_cm.DatabricksGetCredentials("command_line", "command_args") + + # done() should complete without raising any exception + result = obj1.done() + self.assertIsNone(result) diff --git a/tests/test_databricks_validators.py b/tests/test_databricks_validators.py index 430e696..0073870 100644 --- 
a/tests/test_databricks_validators.py +++ b/tests/test_databricks_validators.py @@ -171,3 +171,262 @@ def test_validate_instance_false(self, mock_user_agent, mock_get, mock_put, mock ret_val = db_val_obj.validate_db_instance("instance", "token") mock_put.assert_called_once_with("Internal server error. Cannot verify Databricks instance.") self.assertEqual(ret_val, False) + + # ========================================================================= + # OAuth M2M Validation Tests + # ========================================================================= + + @patch("databricks_validators.SessionKeyProvider", return_value=MagicMock()) + @patch("databricks_validators.utils.get_proxy_uri", return_value="{}") + @patch("splunk_aoblib.rest_migration.ConfigMigrationHandler") + @patch("databricks_validators.Validator") + @patch("databricks_validators.ValidateDatabricksInstance.validate_oauth") + def test_validate_oauth_auth_type(self, mock_oauth, mock_validator, mock_conf, mock_proxy, mock_session): + """Test that validate() calls validate_oauth for OAUTH_M2M auth type.""" + db_val = import_module('databricks_validators') + db_val._LOGGER = MagicMock() + db_val_obj = db_val.ValidateDatabricksInstance() + mock_oauth.return_value = True + db_val_obj.validate("OAUTH_M2M", { + "auth_type": "OAUTH_M2M", + "oauth_client_id": "client_id", + "oauth_client_secret": "client_secret", + "databricks_instance": "test.databricks.azure.net" + }) + self.assertEqual(mock_oauth.call_count, 1) + + @patch("databricks_validators.SessionKeyProvider", return_value=MagicMock()) + @patch("databricks_validators.utils.get_proxy_uri", return_value="{}") + @patch("splunk_aoblib.rest_migration.ConfigMigrationHandler") + @patch("databricks_validators.Validator.put_msg", return_value=MagicMock()) + def test_validate_oauth_client_id_error(self, mock_put, mock_conf, mock_proxy, mock_session): + """Test validation error when OAuth client ID is missing.""" + db_val = import_module('databricks_validators') + 
db_val._LOGGER = MagicMock() + db_val_obj = db_val.ValidateDatabricksInstance() + db_val_obj.validate("OAUTH_M2M", {"auth_type": "OAUTH_M2M"}) + mock_put.assert_called_once_with("Field OAuth Client ID is required") + + @patch("databricks_validators.SessionKeyProvider", return_value=MagicMock()) + @patch("databricks_validators.utils.get_proxy_uri", return_value="{}") + @patch("splunk_aoblib.rest_migration.ConfigMigrationHandler") + @patch("databricks_validators.Validator.put_msg", return_value=MagicMock()) + def test_validate_oauth_client_secret_error(self, mock_put, mock_conf, mock_proxy, mock_session): + """Test validation error when OAuth client secret is missing.""" + db_val = import_module('databricks_validators') + db_val._LOGGER = MagicMock() + db_val_obj = db_val.ValidateDatabricksInstance() + db_val_obj.validate("OAUTH_M2M", {"auth_type": "OAUTH_M2M", "oauth_client_id": "client_id"}) + mock_put.assert_called_once_with("Field OAuth Client Secret is required") + + @patch("databricks_validators.utils.get_oauth_access_token", return_value=("access_token", 3600)) + @patch("databricks_validators.Validator", autospec=True) + @patch("databricks_validators.ValidateDatabricksInstance.validate_db_instance", autospec=True) + def test_validate_oauth_function_success(self, mock_valid_inst, mock_validator, mock_access): + """Test successful OAuth M2M validation flow.""" + db_val = import_module('databricks_validators') + db_val._LOGGER = MagicMock() + db_val_obj = db_val.ValidateDatabricksInstance() + db_val_obj._splunk_session_key = "session_key" + db_val_obj._proxy_settings = {} + mock_valid_inst.return_value = True + + data = { + "auth_type": "OAUTH_M2M", + "oauth_client_id": "cl_id", + "oauth_client_secret": "cl_secret", + "databricks_instance": "db_instance", + "name": "test_account" + } + result = db_val_obj.validate_oauth(data) + + mock_valid_inst.assert_called_once_with(db_val_obj, "db_instance", "access_token") + self.assertTrue(result) + 
self.assertEqual(data["oauth_access_token"], "access_token") + self.assertEqual(data["databricks_pat"], "") + self.assertEqual(data["aad_access_token"], "") + + @patch("databricks_validators.Validator.put_msg", return_value=MagicMock()) + @patch("databricks_validators.utils.get_oauth_access_token", return_value=("Token retrieval failed", False)) + @patch("databricks_validators.ValidateDatabricksInstance.validate_db_instance") + def test_validate_oauth_function_token_error(self, mock_valid_inst, mock_access, mock_put): + """Test OAuth validation when token retrieval fails.""" + db_val = import_module('databricks_validators') + db_val._LOGGER = MagicMock() + db_val_obj = db_val.ValidateDatabricksInstance() + db_val_obj._splunk_session_key = "session_key" + db_val_obj._proxy_settings = {} + + data = { + "auth_type": "OAUTH_M2M", + "oauth_client_id": "cl_id", + "oauth_client_secret": "cl_secret", + "databricks_instance": "db_instance", + "name": "test_account" + } + result = db_val_obj.validate_oauth(data) + + mock_put.assert_called_once_with("Token retrieval failed") + self.assertEqual(mock_valid_inst.call_count, 0) + self.assertFalse(result) + + @patch("databricks_validators.utils.get_oauth_access_token", return_value=("access_token", 3600)) + @patch("databricks_validators.Validator", autospec=True) + @patch("databricks_validators.ValidateDatabricksInstance.validate_db_instance", autospec=True) + def test_validate_oauth_function_instance_validation_failure(self, mock_valid_inst, mock_validator, mock_access): + """Test OAuth validation when instance validation fails.""" + db_val = import_module('databricks_validators') + db_val._LOGGER = MagicMock() + db_val_obj = db_val.ValidateDatabricksInstance() + db_val_obj._splunk_session_key = "session_key" + db_val_obj._proxy_settings = {} + mock_valid_inst.return_value = False + + data = { + "auth_type": "OAUTH_M2M", + "oauth_client_id": "cl_id", + "oauth_client_secret": "cl_secret", + "databricks_instance": "db_instance", + 
"name": "test_account" + } + result = db_val_obj.validate_oauth(data) + + mock_valid_inst.assert_called_once() + self.assertFalse(result) + + # ========================================================================= + # Additional AAD Validation Tests + # ========================================================================= + + @patch("databricks_validators.utils.get_aad_access_token", return_value="access_token") + @patch("databricks_validators.Validator", autospec=True) + @patch("databricks_validators.ValidateDatabricksInstance.validate_db_instance", autospec=True) + def test_validate_aad_function_instance_failure(self, mock_valid_inst, mock_validator, mock_access): + """Test AAD validation when instance validation fails.""" + db_val = import_module('databricks_validators') + db_val._LOGGER = MagicMock() + db_val_obj = db_val.ValidateDatabricksInstance() + db_val_obj._splunk_session_key = "session_key" + db_val_obj._splunk_version = "splunk_version" + db_val_obj._proxy_settings = {} + mock_valid_inst.return_value = False + + result = db_val_obj.validate_aad({ + "auth_type": "AAD", + "aad_client_id": "cl_id", + "aad_tenant_id": "tenant_id", + "aad_client_secret": "client_secret", + "databricks_instance": "db_instance" + }) + + mock_valid_inst.assert_called_once() + self.assertFalse(result) + + # ========================================================================= + # Proxy Configuration Tests + # ========================================================================= + + @patch("databricks_validators.utils.get_proxy_uri") + @patch("databricks_validators.utils.get_current_user", return_value="test_user") + @patch("requests.get") + @patch("databricks_validators.Validator.put_msg", return_value=MagicMock()) + def test_validate_db_instance_with_proxy_skip_oauth(self, mock_put, mock_get, mock_user, mock_proxy): + """Test instance validation skips proxy when use_for_oauth is set.""" + db_val = import_module('databricks_validators') + db_val._LOGGER = 
MagicMock() + db_val_obj = db_val.ValidateDatabricksInstance() + db_val_obj._splunk_session_key = "session_key" + + # Proxy settings with use_for_oauth set to skip proxy + mock_proxy.return_value = {"use_for_oauth": "1", "http": "http://proxy:8080"} + mock_get.return_value = Response(200, {"clusters": []}) + + result = db_val_obj.validate_db_instance("instance", "token") + + # Verify proxy was skipped (set to None) + self.assertTrue(result) + self.assertIsNone(db_val_obj._proxy_settings) + + @patch("databricks_validators.utils.get_proxy_uri") + @patch("databricks_validators.utils.get_current_user", return_value="test_user") + @patch("requests.get") + def test_validate_db_instance_with_proxy_enabled(self, mock_get, mock_user, mock_proxy): + """Test instance validation uses proxy when configured.""" + db_val = import_module('databricks_validators') + db_val._LOGGER = MagicMock() + db_val_obj = db_val.ValidateDatabricksInstance() + db_val_obj._splunk_session_key = "session_key" + + # Proxy settings with use_for_oauth=0 means use proxy + mock_proxy.return_value = {"use_for_oauth": "0", "http": "http://proxy:8080"} + mock_get.return_value = Response(200, {"clusters": []}) + + result = db_val_obj.validate_db_instance("instance", "token") + + self.assertTrue(result) + # Verify proxy is configured (use_for_oauth key removed) + self.assertEqual(db_val_obj._proxy_settings, {"http": "http://proxy:8080"}) + + @patch("databricks_validators.utils.get_proxy_uri", return_value=None) + @patch("databricks_validators.Validator.put_msg", return_value=MagicMock()) + @patch("requests.get", return_value=Response(403)) + @patch("databricks_common_utils.get_user_agent") + def test_validate_instance_invalid_token(self, mock_user_agent, mock_get, mock_put, mock_proxy): + """Test instance validation with invalid access token (403 response).""" + db_val = import_module('databricks_validators') + db_val._LOGGER = MagicMock() + db_val_obj = db_val.ValidateDatabricksInstance() + 
db_val_obj._splunk_session_key = "session_key" + + ret_val = db_val_obj.validate_db_instance("instance", "invalid_token") + + mock_put.assert_called_once_with("Invalid access token. Please enter the valid access token.") + self.assertFalse(ret_val) + + @patch("databricks_validators.utils.get_proxy_uri", return_value=None) + @patch("databricks_validators.Validator.put_msg", return_value=MagicMock()) + @patch("requests.get", return_value=Response(404)) + @patch("databricks_common_utils.get_user_agent") + def test_validate_instance_not_found(self, mock_user_agent, mock_get, mock_put, mock_proxy): + """Test instance validation with 404 response.""" + db_val = import_module('databricks_validators') + db_val._LOGGER = MagicMock() + db_val_obj = db_val.ValidateDatabricksInstance() + db_val_obj._splunk_session_key = "session_key" + + ret_val = db_val_obj.validate_db_instance("instance", "token") + + mock_put.assert_called_once_with("Please validate the provided details.") + self.assertFalse(ret_val) + + @patch("databricks_validators.utils.get_proxy_uri", return_value=None) + @patch("databricks_validators.Validator.put_msg", return_value=MagicMock()) + @patch("requests.get", return_value=Response(400)) + @patch("databricks_common_utils.get_user_agent") + def test_validate_instance_invalid_instance(self, mock_user_agent, mock_get, mock_put, mock_proxy): + """Test instance validation with 400 response.""" + db_val = import_module('databricks_validators') + db_val._LOGGER = MagicMock() + db_val_obj = db_val.ValidateDatabricksInstance() + db_val_obj._splunk_session_key = "session_key" + + ret_val = db_val_obj.validate_db_instance("instance", "token") + + mock_put.assert_called_once_with("Invalid Databricks instance.") + self.assertFalse(ret_val) + + @patch("databricks_validators.utils.get_proxy_uri", return_value=None) + @patch("databricks_validators.utils.get_current_user", return_value="test_user") + @patch("requests.get") + def test_validate_instance_success(self, mock_get, 
mock_user, mock_proxy): + """Test successful instance validation.""" + db_val = import_module('databricks_validators') + db_val._LOGGER = MagicMock() + db_val_obj = db_val.ValidateDatabricksInstance() + db_val_obj._splunk_session_key = "session_key" + + mock_get.return_value = Response(200, {"clusters": []}) + + ret_val = db_val_obj.validate_db_instance("instance", "valid_token") + + self.assertTrue(ret_val) diff --git a/tests/test_databricksquery.py b/tests/test_databricksquery.py index 9bb5fc7..c1386d4 100644 --- a/tests/test_databricksquery.py +++ b/tests/test_databricksquery.py @@ -30,11 +30,10 @@ def tearDownModule(): class TestDatabricksQuery(unittest.TestCase): """Test databricksquery.""" - @classmethod - def setUp(cls): + def setUp(self): import databricksquery - cls.databricksquery = databricksquery - cls.DatabricksQueryCommand = databricksquery.DatabricksQueryCommand + self.databricksquery = databricksquery + self.DatabricksQueryCommand = databricksquery.DatabricksQueryCommand @patch("databricksquery.utils.get_databricks_configs") def test_cluster_exception(self, mock_get_config): @@ -218,8 +217,8 @@ def test_fetch_data_status_finished_loop(self,mock_time, mock_utils, mock_com): db_query_obj.cluster = "test_cluster" client = mock_com.return_value = MagicMock() client.get_cluster_id.return_value = "c1" - client.databricks_api.side_effect = [{"contextId": "context1"}, - {"id": "command_id1"}, + client.databricks_api.side_effect = [{"contextId": "context1"}, + {"id": "command_id1"}, {"status":"processing"}, {"status":"Finished", "results": {"data": [["1", "2"],["3", "4"]],"resultType": "table", "truncated": True, "schema":[{"name": "field1"},{"name": "field2"}]}}] db_query_obj.write_warning = MagicMock() @@ -231,3 +230,1184 @@ def test_fetch_data_status_finished_loop(self,mock_time, mock_utils, mock_com): db_query_obj.write_warning.assert_called_once_with("Results are truncated due to Databricks API limitations.") self.assertEqual(row1 , {'field1': '1', 
'field2': '2'}) self.assertEqual(row2 , {'field1': '3', 'field2': '4'}) + + @patch("databricksquery.utils.get_databricks_configs") + def test_command_timeout_below_minimum(self, mock_get_config): + """Test that command timeout below minimum value raises error""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.account_name = "test_account" + db_query_obj.command_timeout = 15 # Below minimum of 30 + db_query_obj.write_error = MagicMock() + + mock_get_config.return_value = { + "auth_type": "PAT", + "databricks_instance": "test.databricks.com", + "databricks_pat": "test_token", + "config_for_dbquery": "dbsql", + "warehouse_id": "w123", + "admin_command_timeout": "300" + } + + resp = db_query_obj.generate() + try: + next(resp) + except (StopIteration, SystemExit): + pass + + db_query_obj.write_error.assert_called_with( + "Command Timeout value must be greater than or equal to 30 seconds." + ) + + @patch("databricksquery.get_splunkd_uri") + @patch("databricksquery.rest.simpleRequest") + @patch("databricksquery.time.sleep") + def test_cancel_query_success(self, mock_sleep, mock_simple_request, mock_get_uri): + """Test successful query cancellation when Splunk search is stopped""" + # Mock Splunkd URI + mock_get_uri.return_value = "https://localhost:8089" + + # Create mock XML response for Splunk search status + xml_response = ''' + + FINALIZING + 1 + ''' + + mock_simple_request.return_value = (None, xml_response.encode()) + + client = MagicMock() + client.databricks_api.return_value = ({}, 200) + + db_query_obj = self.DatabricksQueryCommand() + + # Call cancel_query method + db_query_obj.cancel_query( + search_sid="test_sid", + session_key="test_key", + client=client, + cancel_endpoint="/api/cancel", + data_for_cancelation={"statement_id": "stmt123"} + ) + + # Verify API was called to cancel query + client.databricks_api.assert_called_once_with( + "post", "/api/cancel", data={"statement_id": "stmt123"} + ) + + 
@patch("databricksquery.rest.simpleRequest") + @patch("databricksquery.time.sleep") + def test_cancel_query_unknown_sid(self, mock_sleep, mock_simple_request): + """Test cancel_query handles unknown SID gracefully""" + mock_simple_request.side_effect = Exception("unknown sid") + + client = MagicMock() + db_query_obj = self.DatabricksQueryCommand() + + # Should not raise exception, just log and break + db_query_obj.cancel_query( + search_sid="invalid_sid", + session_key="test_key", + client=client, + cancel_endpoint="/api/cancel", + data_for_cancelation={} + ) + + # Verify no API call was made since exception occurred early + client.databricks_api.assert_not_called() + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + @patch("databricksquery.ThreadPoolExecutor") + def test_warehouse_query_success(self, mock_executor, mock_utils, mock_com): + """Test successful warehouse query execution with result pagination""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.warehouse_id = "w123" + db_query_obj.write_warning = MagicMock() + + client = mock_com.return_value = MagicMock() + + # Mock warehouse status check + client.databricks_api.side_effect = [ + # Warehouse list + {"warehouses": [{"id": "w123", "state": "RUNNING"}]}, + # Execute query + {"statement_id": "stmt123"}, + # Status check - SUCCEEDED + { + "status": {"state": "SUCCEEDED"}, + "manifest": { + "total_row_count": 2, + "truncated": False, + "schema": { + "columns": [{"name": "col1"}, {"name": "col2"}] + } + }, + "result": { + "external_links": [ + { + "external_link": "http://example.com/chunk0", + "chunk_index": 0, + "next_chunk_internal_link": None + } + ] + } + } + ] + + # Mock external API response + client.external_api.return_value = [["val1", "val2"], ["val3", "val4"]] + + # Mock ThreadPoolExecutor + mock_executor_instance = MagicMock() + mock_executor.return_value.__enter__.return_value = 
mock_executor_instance + mock_executor_instance.map.return_value = [[["val1", "val2"], ["val3", "val4"]]] + + resp = db_query_obj.generate() + results = list(resp) + + self.assertEqual(len(results), 2) + self.assertEqual(results[0], {"col1": "val1", "col2": "val2"}) + self.assertEqual(results[1], {"col1": "val3", "col2": "val4"}) + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + def test_warehouse_query_failed_state(self, mock_utils, mock_com): + """Test warehouse query with FAILED state""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.warehouse_id = "w123" + db_query_obj.write_error = MagicMock() + + client = mock_com.return_value = MagicMock() + + client.databricks_api.side_effect = [ + # Warehouse list + {"warehouses": [{"id": "w123", "state": "RUNNING"}]}, + # Execute query + {"statement_id": "stmt123"}, + # Status check - FAILED + { + "status": { + "state": "FAILED", + "error": {"message": "Syntax error in query"} + } + } + ] + + resp = db_query_obj.generate() + try: + next(resp) + except StopIteration: + pass + + db_query_obj.write_error.assert_called() + error_msg = db_query_obj.write_error.call_args[0][0] + self.assertIn("FAILED", error_msg) + self.assertIn("Syntax error in query", error_msg) + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + @patch("databricksquery.time.sleep") + def test_cluster_query_timeout(self, mock_sleep, mock_utils, mock_com): + """Test cluster query execution timeout scenario""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.cluster = "test_cluster" + db_query_obj.command_timeout = 30 # Short timeout + db_query_obj.write_error = MagicMock() + + mock_utils.get_databricks_configs.return_value = { + "auth_type": "PAT", + "databricks_instance": "test.databricks.com", + "databricks_pat": 
"test_token", + "admin_command_timeout": "300" + } + + client = mock_com.return_value = MagicMock() + client.get_cluster_id.return_value = "c1" + + # Mock responses - keep returning PENDING state + client.databricks_api.side_effect = [ + {"id": "context1"}, # Create context + {"id": "command_id1"}, # Submit query + {"status": "Running"}, # Status checks - keep pending + {"status": "Running"}, + {"status": "Running"}, + {"status": "Running"}, + {"status": "Running"}, + {"status": "Running"}, + {"status": "Running"}, + {"status": "Running"}, + {"status": "Running"}, + {"status": "Running"}, + {"status": "Running"}, + ({}, 200), # Cancel response + {} # Context destroy + ] + + resp = db_query_obj.generate() + try: + list(resp) + except StopIteration: + pass + + db_query_obj.write_error.assert_called_with( + "Canceled the execution as command execution timed out" + ) + + @patch("databricksquery.utils.get_databricks_configs") + def test_invalid_account_name(self, mock_get_config): + """Test handling of invalid account name""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.account_name = "invalid_account" + db_query_obj.write_error = MagicMock() + + mock_get_config.return_value = None + + resp = db_query_obj.generate() + try: + next(resp) + except (StopIteration, SystemExit): + pass + + db_query_obj.write_error.assert_called_with( + "Account 'invalid_account' not found. Please provide valid Databricks account." 
+ ) + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + @patch("databricksquery.time.sleep") + def test_warehouse_query_timeout(self, mock_sleep, mock_utils, mock_com): + """Test warehouse query execution timeout scenario""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.warehouse_id = "w123" + db_query_obj.command_timeout = 30 # Short timeout + db_query_obj.write_error = MagicMock() + + mock_utils.get_databricks_configs.return_value = { + "auth_type": "PAT", + "databricks_instance": "test.databricks.com", + "databricks_pat": "test_token", + "admin_command_timeout": "300", + "query_result_limit": "1000", + "thread_count": "4" + } + + client = mock_com.return_value = MagicMock() + + client.databricks_api.side_effect = [ + # Warehouse list + {"warehouses": [{"id": "w123", "state": "RUNNING"}]}, + # Execute query + {"statement_id": "stmt123"}, + # Status checks - keep returning PENDING/RUNNING + {"status": {"state": "PENDING"}}, + {"status": {"state": "RUNNING"}}, + {"status": {"state": "RUNNING"}}, + {"status": {"state": "RUNNING"}}, + {"status": {"state": "RUNNING"}}, + {"status": {"state": "RUNNING"}}, + {"status": {"state": "RUNNING"}}, + {"status": {"state": "RUNNING"}}, + {"status": {"state": "RUNNING"}}, + {"status": {"state": "RUNNING"}}, + {"status": {"state": "RUNNING"}}, + # Cancel response + ({}, 200) + ] + + resp = db_query_obj.generate() + try: + list(resp) + except (StopIteration, SystemExit): + pass + + db_query_obj.write_error.assert_called_with( + "Canceled the execution as command execution timed out" + ) + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + def test_command_timeout_exceeds_admin_max(self, mock_utils, mock_com): + """Test command timeout exceeds admin configured maximum""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = 
MagicMock() + db_query_obj.warehouse_id = "w123" + db_query_obj.command_timeout = 500 # Exceeds admin max of 300 + db_query_obj.write_warning = MagicMock() + db_query_obj.write_error = MagicMock() + + mock_utils.get_databricks_configs.return_value = { + "auth_type": "PAT", + "databricks_instance": "test.databricks.com", + "databricks_pat": "test_token", + "warehouse_id": "w123", + "admin_command_timeout": "300", + "query_result_limit": "1000", + "thread_count": "4" + } + + client = mock_com.return_value = MagicMock() + client.databricks_api.side_effect = [ + {"warehouses": [{"id": "w123", "state": "RUNNING"}]}, + {"statement_id": "stmt123"}, + { + "status": {"state": "SUCCEEDED"}, + "manifest": { + "total_row_count": 0, + "truncated": False, + "schema": {"columns": []} + }, + "result": {"external_links": []} + } + ] + + resp = db_query_obj.generate() + try: + list(resp) + except (StopIteration, SystemExit): + pass + + # Should have written warning about using max value + db_query_obj.write_warning.assert_called() + warning_msg = db_query_obj.write_warning.call_args[0][0] + self.assertIn("300", warning_msg) + self.assertIn("500", warning_msg) + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + def test_limit_exceeds_admin_max(self, mock_utils, mock_com): + """Test limit exceeds admin configured maximum""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.warehouse_id = "w123" + db_query_obj.limit = 5000 # Exceeds admin max of 1000 + db_query_obj.write_warning = MagicMock() + + mock_utils.get_databricks_configs.return_value = { + "auth_type": "PAT", + "databricks_instance": "test.databricks.com", + "databricks_pat": "test_token", + "warehouse_id": "w123", + "admin_command_timeout": "300", + "query_result_limit": "1000", + "thread_count": "4" + } + + client = mock_com.return_value = MagicMock() + client.databricks_api.side_effect = [ + {"warehouses": 
[{"id": "w123", "state": "RUNNING"}]}, + {"statement_id": "stmt123"}, + { + "status": {"state": "SUCCEEDED"}, + "manifest": { + "total_row_count": 0, + "truncated": False, + "schema": {"columns": []} + }, + "result": {"external_links": []} + } + ] + + resp = db_query_obj.generate() + try: + list(resp) + except (StopIteration, SystemExit): + pass + + # Should have written warning about using max value + db_query_obj.write_warning.assert_called() + warning_msg = db_query_obj.write_warning.call_args[0][0] + self.assertIn("1000", warning_msg) + self.assertIn("5000", warning_msg) + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + def test_cluster_canceled_status(self, mock_utils, mock_com): + """Test cluster query with Canceled status""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.cluster = "test_cluster" + db_query_obj.write_error = MagicMock() + + client = mock_com.return_value = MagicMock() + client.get_cluster_id.return_value = "c1" + client.databricks_api.side_effect = [ + {"id": "context1"}, + {"id": "command_id1"}, + {"status": "Canceled"} + ] + + resp = db_query_obj.generate() + try: + next(resp) + except StopIteration: + pass + + db_query_obj.write_error.assert_called_with( + "Could not complete the query execution. Status: Canceled." 
+ ) + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + def test_cluster_command_cancellation_exception(self, mock_utils, mock_com): + """Test cluster query with CommandCancelledException""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.cluster = "test_cluster" + db_query_obj.write_error = MagicMock() + + client = mock_com.return_value = MagicMock() + client.get_cluster_id.return_value = "c1" + client.databricks_api.side_effect = [ + {"id": "context1"}, + {"id": "command_id1"}, + { + "status": "Finished", + "results": { + "resultType": "error", + "cause": "CommandCancelledException: Query was canceled" + } + } + ] + + resp = db_query_obj.generate() + try: + next(resp) + except StopIteration: + pass + + db_query_obj.write_error.assert_called_with("Search Canceled!") + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + def test_cluster_finished_with_custom_error_summary(self, mock_utils, mock_com): + """Test cluster query finished with custom error summary""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.cluster = "test_cluster" + db_query_obj.write_error = MagicMock() + + client = mock_com.return_value = MagicMock() + client.get_cluster_id.return_value = "c1" + client.databricks_api.side_effect = [ + {"id": "context1"}, + {"id": "command_id1"}, + { + "status": "Finished", + "results": { + "resultType": "error", + "summary": "Table not found: my_table" + } + } + ] + + resp = db_query_obj.generate() + try: + next(resp) + except StopIteration: + pass + + db_query_obj.write_error.assert_called_with("Table not found: my_table") + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + def test_both_cluster_and_warehouse_provided(self, mock_utils, mock_com): + """Test error when 
both cluster and warehouse_id are provided""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.cluster = "test_cluster" + db_query_obj.warehouse_id = "w123" + db_query_obj.write_error = MagicMock() + + mock_utils.get_databricks_configs.return_value = { + "auth_type": "PAT", + "databricks_instance": "test.databricks.com", + "databricks_pat": "test_token", + "admin_command_timeout": "300" + } + + resp = db_query_obj.generate() + try: + next(resp) + except StopIteration: + pass + + db_query_obj.write_error.assert_called_with( + "Provide only one of Cluster or Warehouse ID" + ) + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + def test_no_config_for_cluster_or_warehouse(self, mock_utils, mock_com): + """Test error when no cluster or warehouse configuration found""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.write_error = MagicMock() + + mock_utils.get_databricks_configs.return_value = { + "auth_type": "PAT", + "databricks_instance": "test.databricks.com", + "databricks_pat": "test_token", + "admin_command_timeout": "300", + "config_for_dbquery": None + } + + resp = db_query_obj.generate() + try: + next(resp) + except StopIteration: + pass + + error_msg = db_query_obj.write_error.call_args[0][0] + self.assertIn("No configuration found", error_msg) + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + @patch("databricksquery.time.sleep") + def test_warehouse_starting_state(self, mock_sleep, mock_utils, mock_com): + """Test warehouse in STARTING state, then transitions to RUNNING""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.warehouse_id = "w123" + + mock_utils.get_databricks_configs.return_value = { + "auth_type": "PAT", + "databricks_instance": "test.databricks.com", + 
"databricks_pat": "test_token", + "admin_command_timeout": "300", + "query_result_limit": "1000", + "thread_count": "4" + } + + client = mock_com.return_value = MagicMock() + + client.databricks_api.side_effect = [ + # Initial warehouse list - STARTING + {"warehouses": [{"id": "w123", "state": "STARTING"}]}, + # Status check - now RUNNING + {"state": "RUNNING"}, + # Execute query + {"statement_id": "stmt123"}, + # Query status - SUCCEEDED + { + "status": {"state": "SUCCEEDED"}, + "manifest": { + "total_row_count": 0, + "truncated": False, + "schema": {"columns": []} + }, + "result": {"external_links": []} + } + ] + + resp = db_query_obj.generate() + try: + list(resp) + except (StopIteration, SystemExit): + pass + + # Verify sleep was called while waiting for warehouse + mock_sleep.assert_called() + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + @patch("databricksquery.time.sleep") + def test_warehouse_stopped_then_started(self, mock_sleep, mock_utils, mock_com): + """Test warehouse in STOPPED state, gets started, then runs query""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.warehouse_id = "w123" + + mock_utils.get_databricks_configs.return_value = { + "auth_type": "PAT", + "databricks_instance": "test.databricks.com", + "databricks_pat": "test_token", + "admin_command_timeout": "300", + "query_result_limit": "1000", + "thread_count": "4" + } + + client = mock_com.return_value = MagicMock() + + client.databricks_api.side_effect = [ + # Initial warehouse list - STOPPED + {"warehouses": [{"id": "w123", "state": "STOPPED"}]}, + # Start warehouse call + {}, + # Status check - now RUNNING + {"state": "RUNNING"}, + # Execute query + {"statement_id": "stmt123"}, + # Query status - SUCCEEDED + { + "status": {"state": "SUCCEEDED"}, + "manifest": { + "total_row_count": 0, + "truncated": False, + "schema": {"columns": []} + }, + "result": 
{"external_links": []} + } + ] + + resp = db_query_obj.generate() + try: + list(resp) + except (StopIteration, SystemExit): + pass + + # Verify warehouse start was called + calls = [call for call in client.databricks_api.call_args_list + if len(call[0]) > 1 and 'start' in str(call[0][1])] + self.assertTrue(len(calls) > 0) + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + def test_warehouse_not_found(self, mock_utils, mock_com): + """Test error when warehouse ID not found""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.warehouse_id = "w999" + db_query_obj.write_error = MagicMock() + + client = mock_com.return_value = MagicMock() + client.databricks_api.return_value = { + "warehouses": [{"id": "w123", "state": "RUNNING"}] + } + + resp = db_query_obj.generate() + try: + next(resp) + except StopIteration: + pass + + error_msg = db_query_obj.write_error.call_args[0][0] + self.assertIn("No SQL warehouse found", error_msg) + self.assertIn("w999", error_msg) + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + @patch("databricksquery.ThreadPoolExecutor") + def test_warehouse_query_multiple_chunks(self, mock_executor, mock_utils, mock_com): + """Test warehouse query with multiple result chunks""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.warehouse_id = "w123" + + client = mock_com.return_value = MagicMock() + + client.databricks_api.side_effect = [ + # Warehouse list + {"warehouses": [{"id": "w123", "state": "RUNNING"}]}, + # Execute query + {"statement_id": "stmt123"}, + # Status check - SUCCEEDED with multiple chunks + { + "status": {"state": "SUCCEEDED"}, + "manifest": { + "total_row_count": 4, + "truncated": False, + "schema": {"columns": [{"name": "col1"}]} + }, + "result": { + "external_links": [{ + "external_link": 
"http://example.com/chunk0", + "chunk_index": 0, + "next_chunk_internal_link": "/internal/chunk1" + }] + } + }, + # Get next chunk + { + "external_links": [{ + "external_link": "http://example.com/chunk1", + "chunk_index": 1, + "next_chunk_internal_link": None + }] + } + ] + + # Mock ThreadPoolExecutor to return data from both chunks + mock_executor_instance = MagicMock() + mock_executor.return_value.__enter__.return_value = mock_executor_instance + mock_executor_instance.map.return_value = [ + [["val1"], ["val2"]], # Chunk 0 + [["val3"], ["val4"]] # Chunk 1 + ] + + resp = db_query_obj.generate() + results = list(resp) + + self.assertEqual(len(results), 4) + self.assertEqual(results[0], {"col1": "val1"}) + self.assertEqual(results[3], {"col1": "val4"}) + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + def test_warehouse_query_truncated_results(self, mock_utils, mock_com): + """Test warehouse query with truncated results warning""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.warehouse_id = "w123" + db_query_obj.write_warning = MagicMock() + + client = mock_com.return_value = MagicMock() + + client.databricks_api.side_effect = [ + {"warehouses": [{"id": "w123", "state": "RUNNING"}]}, + {"statement_id": "stmt123"}, + { + "status": {"state": "SUCCEEDED"}, + "manifest": { + "total_row_count": 1000, + "truncated": True, # Results are truncated + "schema": {"columns": [{"name": "col1"}]} + }, + "result": { + "external_links": [{ + "external_link": "http://example.com/chunk0", + "chunk_index": 0, + "next_chunk_internal_link": None + }] + } + } + ] + + client.external_api.return_value = [["val1"]] + + resp = db_query_obj.generate() + list(resp) + + db_query_obj.write_warning.assert_called_with( + "Result limit exceeded, hence results are truncated." 
+ ) + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + def test_warehouse_query_no_external_links(self, mock_utils, mock_com): + """Test warehouse query with no external links raises error""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.warehouse_id = "w123" + db_query_obj.write_error = MagicMock() + + client = mock_com.return_value = MagicMock() + + client.databricks_api.side_effect = [ + {"warehouses": [{"id": "w123", "state": "RUNNING"}]}, + {"statement_id": "stmt123"}, + { + "status": {"state": "SUCCEEDED"}, + "manifest": { + "total_row_count": 10, + "truncated": False, + "schema": {"columns": [{"name": "col1"}]} + }, + "result": { + "external_links": None # No links! + } + } + ] + + resp = db_query_obj.generate() + try: + list(resp) + except StopIteration: + pass + + error_msg = db_query_obj.write_error.call_args[0][0] + self.assertIn("No data returned", error_msg) + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + def test_config_uses_warehouse_from_settings(self, mock_utils, mock_com): + """Test using warehouse_id from configuration when not provided in command""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + # No warehouse_id or cluster set on command + + mock_utils.get_databricks_configs.return_value = { + "auth_type": "PAT", + "databricks_instance": "test.databricks.com", + "databricks_pat": "test_token", + "config_for_dbquery": "dbsql", + "warehouse_id": "w123", + "admin_command_timeout": "300", + "query_result_limit": "1000", + "thread_count": "4" + } + + client = mock_com.return_value = MagicMock() + client.databricks_api.side_effect = [ + {"warehouses": [{"id": "w123", "state": "RUNNING"}]}, + {"statement_id": "stmt123"}, + { + "status": {"state": "SUCCEEDED"}, + "manifest": { + "total_row_count": 0, + "truncated": False, + 
"schema": {"columns": []} + }, + "result": {"external_links": []} + } + ] + + resp = db_query_obj.generate() + try: + list(resp) + except (StopIteration, SystemExit): + pass + + # Verify warehouse_id was set from config + self.assertEqual(db_query_obj.warehouse_id, "w123") + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + def test_config_uses_cluster_from_settings(self, mock_utils, mock_com): + """Test using cluster from configuration when not provided in command""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.write_error = MagicMock() + + mock_utils.get_databricks_configs.return_value = { + "auth_type": "PAT", + "databricks_instance": "test.databricks.com", + "databricks_pat": "test_token", + "config_for_dbquery": "interactive_cluster", + "cluster_name": "test_cluster", + "admin_command_timeout": "300" + } + + client = mock_com.return_value = MagicMock() + client.get_cluster_id.side_effect = Exception("test error") + + resp = db_query_obj.generate() + try: + list(resp) + except StopIteration: + pass + + # Verify cluster was set from config + self.assertEqual(db_query_obj.cluster, "test_cluster") + + @patch("databricksquery.get_splunkd_uri") + @patch("databricksquery.rest.simpleRequest") + @patch("databricksquery.time.sleep") + def test_cancel_query_non_200_response(self, mock_sleep, mock_simple_request, mock_get_uri): + """Test cancel_query with non-200 response from cancel API""" + mock_get_uri.return_value = "https://localhost:8089" + + xml_response = ''' + + FINALIZING + 1 + ''' + + mock_simple_request.return_value = (None, xml_response.encode()) + + client = MagicMock() + client.databricks_api.return_value = ({"error": "some error"}, 400) + + db_query_obj = self.DatabricksQueryCommand() + + db_query_obj.cancel_query( + search_sid="test_sid", + session_key="test_key", + client=client, + cancel_endpoint="/api/cancel", + 
data_for_cancelation={"statement_id": "stmt123"} + ) + + # Verify API was called + client.databricks_api.assert_called_once() + + @patch("databricksquery.get_splunkd_uri") + @patch("databricksquery.rest.simpleRequest") + @patch("databricksquery.time.sleep") + def test_cancel_query_other_exception(self, mock_sleep, mock_simple_request, mock_get_uri): + """Test cancel_query handles other exceptions gracefully""" + mock_get_uri.return_value = "https://localhost:8089" + mock_simple_request.side_effect = Exception("Network error") + + client = MagicMock() + db_query_obj = self.DatabricksQueryCommand() + + # Should not raise exception, just log and break + db_query_obj.cancel_query( + search_sid="test_sid", + session_key="test_key", + client=client, + cancel_endpoint="/api/cancel", + data_for_cancelation={} + ) + + # Verify no API call was made since exception occurred early + client.databricks_api.assert_not_called() + + @patch("databricksquery.get_splunkd_uri") + @patch("databricksquery.rest.simpleRequest") + @patch("databricksquery.time.sleep") + def test_cancel_query_continues_when_not_finalized(self, mock_sleep, mock_simple_request, mock_get_uri): + """Test cancel_query continues checking when search is not finalized""" + mock_get_uri.return_value = "https://localhost:8089" + + # First call - not finalized, second call - finalized + xml_not_finalized = ''' + + RUNNING + 0 + ''' + + xml_finalized = ''' + + FINALIZING + 1 + ''' + + mock_simple_request.side_effect = [ + (None, xml_not_finalized.encode()), + (None, xml_finalized.encode()) + ] + + client = MagicMock() + client.databricks_api.return_value = ({}, 200) + + db_query_obj = self.DatabricksQueryCommand() + + db_query_obj.cancel_query( + search_sid="test_sid", + session_key="test_key", + client=client, + cancel_endpoint="/api/cancel", + data_for_cancelation={"statement_id": "stmt123"} + ) + + # Verify sleep was called while waiting + mock_sleep.assert_called() + # Verify API was eventually called to cancel + 
client.databricks_api.assert_called_once() + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + def test_warehouse_invalid_state(self, mock_utils, mock_com): + """Test warehouse in invalid state raises error""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.warehouse_id = "w123" + db_query_obj.write_error = MagicMock() + + mock_utils.get_databricks_configs.return_value = { + "auth_type": "PAT", + "databricks_instance": "test.databricks.com", + "databricks_pat": "test_token", + "admin_command_timeout": "300", + "query_result_limit": "1000", + "thread_count": "4" + } + + client = mock_com.return_value = MagicMock() + + client.databricks_api.side_effect = [ + # Warehouse in ERROR state + {"warehouses": [{"id": "w123", "state": "ERROR"}]}, + # Start warehouse call + {}, + # Status check returns ERROR state (never becomes RUNNING) + {"state": "ERROR"} + ] + + resp = db_query_obj.generate() + try: + list(resp) + except (StopIteration, SystemExit): + pass + + # Verify error was written + db_query_obj.write_error.assert_called() + error_msg = db_query_obj.write_error.call_args[0][0] + self.assertIn("ERROR", error_msg) + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + def test_command_timeout_within_admin_max(self, mock_utils, mock_com): + """Test command timeout within admin max uses provided value""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.warehouse_id = "w123" + db_query_obj.command_timeout = 100 # Within admin max of 300 + db_query_obj.write_warning = MagicMock() + + mock_utils.get_databricks_configs.return_value = { + "auth_type": "PAT", + "databricks_instance": "test.databricks.com", + "databricks_pat": "test_token", + "warehouse_id": "w123", + "admin_command_timeout": "300", + "query_result_limit": "1000", + "thread_count": 
"4" + } + + client = mock_com.return_value = MagicMock() + client.databricks_api.side_effect = [ + {"warehouses": [{"id": "w123", "state": "RUNNING"}]}, + {"statement_id": "stmt123"}, + { + "status": {"state": "SUCCEEDED"}, + "manifest": { + "total_row_count": 0, + "truncated": False, + "schema": {"columns": []} + }, + "result": {"external_links": []} + } + ] + + resp = db_query_obj.generate() + try: + list(resp) + except (StopIteration, SystemExit): + pass + + # Verify no warning was issued (timeout value was within bounds) + db_query_obj.write_warning.assert_not_called() + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + def test_limit_within_admin_max(self, mock_utils, mock_com): + """Test limit within admin max uses provided value""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.warehouse_id = "w123" + db_query_obj.limit = 100 # Within admin max of 1000 + db_query_obj.write_warning = MagicMock() + + mock_utils.get_databricks_configs.return_value = { + "auth_type": "PAT", + "databricks_instance": "test.databricks.com", + "databricks_pat": "test_token", + "warehouse_id": "w123", + "admin_command_timeout": "300", + "query_result_limit": "1000", + "thread_count": "4" + } + + client = mock_com.return_value = MagicMock() + client.databricks_api.side_effect = [ + {"warehouses": [{"id": "w123", "state": "RUNNING"}]}, + {"statement_id": "stmt123"}, + { + "status": {"state": "SUCCEEDED"}, + "manifest": { + "total_row_count": 0, + "truncated": False, + "schema": {"columns": []} + }, + "result": {"external_links": []} + } + ] + + resp = db_query_obj.generate() + try: + list(resp) + except (StopIteration, SystemExit): + pass + + # Verify no warning was issued (limit value was within bounds) + db_query_obj.write_warning.assert_not_called() + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + 
def test_no_command_timeout_uses_admin_default(self, mock_utils, mock_com): + """Test no command timeout uses admin default""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.warehouse_id = "w123" + db_query_obj.write_warning = MagicMock() + # No command_timeout set (defaults to None) + + mock_utils.get_databricks_configs.return_value = { + "auth_type": "PAT", + "databricks_instance": "test.databricks.com", + "databricks_pat": "test_token", + "warehouse_id": "w123", + "admin_command_timeout": "300", + "query_result_limit": "1000", + "thread_count": "4" + } + + client = mock_com.return_value = MagicMock() + client.databricks_api.side_effect = [ + {"warehouses": [{"id": "w123", "state": "RUNNING"}]}, + {"statement_id": "stmt123"}, + { + "status": {"state": "SUCCEEDED"}, + "manifest": { + "total_row_count": 0, + "truncated": False, + "schema": {"columns": []} + }, + "result": {"external_links": []} + } + ] + + resp = db_query_obj.generate() + try: + list(resp) + except (StopIteration, SystemExit): + pass + + # Verify no warning was issued (should use admin default without warning) + db_query_obj.write_warning.assert_not_called() + # Verify command_timeout was not set (remains None, uses admin default) + self.assertIsNone(db_query_obj.command_timeout) + + @patch("databricksquery.com.DatabricksClient", autospec=True) + @patch("databricksquery.utils", autospec=True) + def test_no_limit_uses_admin_default(self, mock_utils, mock_com): + """Test no limit uses admin default""" + db_query_obj = self.DatabricksQueryCommand() + db_query_obj._metadata = MagicMock() + db_query_obj.warehouse_id = "w123" + db_query_obj.write_warning = MagicMock() + # No limit set (defaults to None) + + mock_utils.get_databricks_configs.return_value = { + "auth_type": "PAT", + "databricks_instance": "test.databricks.com", + "databricks_pat": "test_token", + "warehouse_id": "w123", + "admin_command_timeout": "300", + "query_result_limit": "1000", + 
"thread_count": "4" + } + + client = mock_com.return_value = MagicMock() + client.databricks_api.side_effect = [ + {"warehouses": [{"id": "w123", "state": "RUNNING"}]}, + {"statement_id": "stmt123"}, + { + "status": {"state": "SUCCEEDED"}, + "manifest": { + "total_row_count": 0, + "truncated": False, + "schema": {"columns": []} + }, + "result": {"external_links": []} + } + ] + + resp = db_query_obj.generate() + try: + list(resp) + except (StopIteration, SystemExit): + pass + + # Verify no warning was issued (should use admin default without warning) + db_query_obj.write_warning.assert_not_called() + # Verify limit was not set (remains None, uses admin default) + self.assertIsNone(db_query_obj.limit) diff --git a/tests/test_databricksrunstatus.py b/tests/test_databricksrunstatus.py new file mode 100644 index 0000000..73b0e1e --- /dev/null +++ b/tests/test_databricksrunstatus.py @@ -0,0 +1,793 @@ +import declare +import unittest +import sys +import os +from mock import patch, MagicMock, call + +mocked_modules = {} + + +def setUpModule(): + global mocked_modules + + module_to_be_mocked = [ + 'log_manager', + 'splunk', + 'splunk.rest', + 'splunk.admin', + 'splunk.clilib', + 'splunk.clilib.cli_common', + 'solnlib', + 'solnlib.server_info', + 'solnlib.utils', + 'solnlib.credentials', + 'splunk_aoblib', + 'splunk_aoblib.rest_migration', + 'splunklib', + 'splunklib.client', + 'splunklib.results', + 'splunklib.binding', + 'splunk.Intersplunk', + 'splunktaucclib', + 'splunktaucclib.rest_handler', + 'splunktaucclib.rest_handler.endpoint', + 'splunktaucclib.rest_handler.endpoint.validator' + ] + + mocked_modules = {module: MagicMock() for module in module_to_be_mocked} + + for module, magicmock in mocked_modules.items(): + patch.dict('sys.modules', **{module: magicmock}).start() + + +def tearDownModule(): + patch.stopall() + + +class TestDatabricksrunstatus(unittest.TestCase): + """Test databricksrunstatus script.""" + + def _run_script(self): + """Helper method to run the 
databricksrunstatus script as __main__.""" + # Delete the module from cache to force reimport + if 'databricksrunstatus' in sys.modules: + del sys.modules['databricksrunstatus'] + + try: + # Read and execute the script with __name__ == "__main__" + script_path = os.path.join( + os.path.dirname(__file__), + '..', + 'app', + 'bin', + 'databricksrunstatus.py' + ) + script_path = os.path.abspath(script_path) + + with open(script_path, 'r') as f: + script_code = f.read() + + # Create a globals dict with __name__ set to "__main__" + script_globals = { + '__name__': '__main__', + '__file__': script_path, + } + + # Execute the script code + exec(compile(script_code, script_path, 'exec'), script_globals) + except SystemExit: + # Script calls sys.exit, which we want to capture + pass + + def test_no_results_early_exit(self): + """Test early exit when no results are returned (lines 24-26).""" + with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \ + patch('sys.exit') as mock_exit: + + # Mock getOrganizedResults to return empty results + mock_get_results.return_value = ([], {}, {'sessionKey': 'test_session_key'}) + + # Execute the script + self._run_script() + + # Verify sys.exit(0) was called + mock_exit.assert_called_once_with(0) + + def test_running_state_no_status_change(self): + """Test RUNNING state when status hasn't changed (lines 49-53).""" + with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \ + patch('databricks_com.DatabricksClient') as mock_client_class, \ + patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \ + patch('time.time', return_value=1234567890.0), \ + patch('sys.exit'): + + # Setup test data + test_result = { + 'run_id': '123', + 'account_name': 'test_account', + 'index': 'test_index', + 'sourcetype': 'databricks:databricksjob', + 'run_execution_status': 'Running', # Already Running + 'uid': 'test_uid' + } + mock_get_results.return_value = ([test_result], {}, {'sessionKey': 
'test_session_key'}) + + # Mock DatabricksClient + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.databricks_api.return_value = { + 'state': { + 'life_cycle_state': 'RUNNING' + } + } + + # Execute the script + self._run_script() + + # Verify no ingestion occurred since status didn't change + mock_ingest.assert_not_called() + + def test_state_change_to_running(self): + """Test state change to RUNNING triggers ingestion (lines 49-53).""" + with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \ + patch('databricks_com.DatabricksClient') as mock_client_class, \ + patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \ + patch('time.time', return_value=1234567890.0), \ + patch('sys.exit'): + + # Setup test data + test_result = { + 'run_id': '123', + 'account_name': 'test_account', + 'index': 'test_index', + 'sourcetype': 'databricks:databricksjob', + 'run_execution_status': 'Pending', # Different from RUNNING + 'uid': 'test_uid' + } + mock_get_results.return_value = ([test_result], {}, {'sessionKey': 'test_session_key'}) + + # Mock DatabricksClient + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.databricks_api.return_value = { + 'state': { + 'life_cycle_state': 'RUNNING' + } + } + + # Execute the script + self._run_script() + + # Verify ingestion occurred + mock_ingest.assert_called_once() + call_args = mock_ingest.call_args + ingested_data = call_args[0][0] + self.assertEqual(ingested_data['run_execution_status'], 'Running') + self.assertEqual(ingested_data['created_time'], 1234567890.0) + + def test_pending_state_transition(self): + """Test PENDING state transition (lines 55-59).""" + with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \ + patch('databricks_com.DatabricksClient') as mock_client_class, \ + patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \ + patch('time.time', return_value=1234567890.0), \ + 
patch('sys.exit'): + + # Setup test data + test_result = { + 'run_id': '456', + 'account_name': 'test_account', + 'index': 'test_index', + 'sourcetype': 'databricks:databricksjob', + 'run_execution_status': 'Running', # Different from PENDING + 'uid': 'test_uid_456' + } + mock_get_results.return_value = ([test_result], {}, {'sessionKey': 'test_session_key'}) + + # Mock DatabricksClient + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.databricks_api.return_value = { + 'state': { + 'life_cycle_state': 'PENDING' + } + } + + # Execute the script + self._run_script() + + # Verify ingestion occurred with Pending status + mock_ingest.assert_called_once() + call_args = mock_ingest.call_args + ingested_data = call_args[0][0] + self.assertEqual(ingested_data['run_execution_status'], 'Pending') + self.assertEqual(ingested_data['created_time'], 1234567890.0) + + def test_terminated_success_state(self): + """Test TERMINATED/SUCCESS state (lines 61-65).""" + with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \ + patch('databricks_com.DatabricksClient') as mock_client_class, \ + patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \ + patch('time.time', return_value=1234567890.0), \ + patch('sys.exit'): + + # Setup test data + test_result = { + 'run_id': '789', + 'account_name': 'test_account', + 'index': 'test_index', + 'sourcetype': 'databricks:databricksjob', + 'run_execution_status': 'Running', + 'uid': 'test_uid_789' + } + mock_get_results.return_value = ([test_result], {}, {'sessionKey': 'test_session_key'}) + + # Mock DatabricksClient + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.databricks_api.return_value = { + 'state': { + 'life_cycle_state': 'TERMINATED', + 'result_state': 'SUCCESS' + } + } + + # Execute the script + self._run_script() + + # Verify ingestion occurred with Success status + mock_ingest.assert_called_once() + call_args = 
mock_ingest.call_args + ingested_data = call_args[0][0] + self.assertEqual(ingested_data['run_execution_status'], 'Success') + self.assertEqual(ingested_data['created_time'], 1234567890.0) + + def test_terminated_failed_state(self): + """Test TERMINATED/FAILED state (lines 66-69).""" + with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \ + patch('databricks_com.DatabricksClient') as mock_client_class, \ + patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \ + patch('time.time', return_value=1234567890.0), \ + patch('sys.exit'): + + # Setup test data + test_result = { + 'run_id': '999', + 'account_name': 'test_account', + 'index': 'test_index', + 'sourcetype': 'databricks:databricksjob', + 'run_execution_status': 'Running', + 'uid': 'test_uid_999' + } + mock_get_results.return_value = ([test_result], {}, {'sessionKey': 'test_session_key'}) + + # Mock DatabricksClient + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.databricks_api.return_value = { + 'state': { + 'life_cycle_state': 'TERMINATED', + 'result_state': 'FAILED' + } + } + + # Execute the script + self._run_script() + + # Verify ingestion occurred with Failed status + mock_ingest.assert_called_once() + call_args = mock_ingest.call_args + ingested_data = call_args[0][0] + self.assertEqual(ingested_data['run_execution_status'], 'Failed') + + def test_terminated_canceled_state(self): + """Test TERMINATED/CANCELED state (lines 70-73).""" + with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \ + patch('databricks_com.DatabricksClient') as mock_client_class, \ + patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \ + patch('time.time', return_value=1234567890.0), \ + patch('sys.exit'): + + # Setup test data + test_result = { + 'run_id': '111', + 'account_name': 'test_account', + 'index': 'test_index', + 'sourcetype': 'databricks:databricksjob', + 'run_execution_status': 'Running', + 'uid': 
'test_uid_111' + } + mock_get_results.return_value = ([test_result], {}, {'sessionKey': 'test_session_key'}) + + # Mock DatabricksClient + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.databricks_api.return_value = { + 'state': { + 'life_cycle_state': 'TERMINATED', + 'result_state': 'CANCELED' + } + } + + # Execute the script + self._run_script() + + # Verify ingestion occurred with Canceled status + mock_ingest.assert_called_once() + call_args = mock_ingest.call_args + ingested_data = call_args[0][0] + self.assertEqual(ingested_data['run_execution_status'], 'Canceled') + + def test_other_state_handling(self): + """Test other state handling (lines 74-80).""" + with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \ + patch('databricks_com.DatabricksClient') as mock_client_class, \ + patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \ + patch('time.time', return_value=1234567890.0), \ + patch('sys.exit'): + + # Setup test data + test_result = { + 'run_id': '222', + 'account_name': 'test_account', + 'index': 'test_index', + 'sourcetype': 'databricks:databricksjob', + 'run_execution_status': 'Running', + 'uid': 'test_uid_222' + } + mock_get_results.return_value = ([test_result], {}, {'sessionKey': 'test_session_key'}) + + # Mock DatabricksClient + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.databricks_api.return_value = { + 'state': { + 'life_cycle_state': 'UNKNOWN_STATE', + 'result_state': 'UNKNOWN_RESULT' + } + } + + # Execute the script + self._run_script() + + # Verify ingestion occurred with the new state + mock_ingest.assert_called_once() + call_args = mock_ingest.call_args + ingested_data = call_args[0][0] + self.assertEqual(ingested_data['run_execution_status'], 'UNKNOWN_RESULT') + + def test_invalid_run_id_value_error(self): + """Test invalid run_id handling (lines 44-46, ValueError).""" + with 
patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \ + patch('databricks_com.DatabricksClient') as mock_client_class, \ + patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \ + patch('sys.exit'): + + # Setup test data with invalid run_id + test_result = { + 'run_id': 'invalid_run_id', # String that can't be converted to int + 'account_name': 'test_account', + 'index': 'test_index', + 'sourcetype': 'databricks:databricksjob', + 'run_execution_status': 'Running', + 'uid': 'test_uid_invalid' + } + mock_get_results.return_value = ([test_result], {}, {'sessionKey': 'test_session_key'}) + + # Mock DatabricksClient (should not be called) + mock_client = MagicMock() + mock_client_class.return_value = mock_client + + # Execute the script + self._run_script() + + # Verify no ingestion occurred and databricks_api was not called + mock_ingest.assert_not_called() + mock_client.databricks_api.assert_not_called() + + def test_exception_handling_per_result(self): + """Test exception handling per result continues processing (lines 87-89).""" + with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \ + patch('databricks_com.DatabricksClient') as mock_client_class, \ + patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \ + patch('time.time', return_value=1234567890.0), \ + patch('sys.exit'): + + # Setup test data with two results + test_result1 = { + 'run_id': '333', + 'account_name': 'test_account', + 'index': 'test_index', + 'sourcetype': 'databricks:databricksjob', + 'run_execution_status': 'Running', + 'uid': 'test_uid_333' + } + test_result2 = { + 'run_id': '444', + 'account_name': 'test_account', + 'index': 'test_index', + 'sourcetype': 'databricks:databricksjob', + 'run_execution_status': 'Running', + 'uid': 'test_uid_444' + } + mock_get_results.return_value = ( + [test_result1, test_result2], + {}, + {'sessionKey': 'test_session_key'} + ) + + # Mock DatabricksClient - first call raises exception, 
second succeeds + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.databricks_api.side_effect = [ + Exception("API error"), # First result fails + { # Second result succeeds + 'state': { + 'life_cycle_state': 'TERMINATED', + 'result_state': 'SUCCESS' + } + } + ] + + # Execute the script + self._run_script() + + # Verify that second result was still processed despite first failing + mock_ingest.assert_called_once() + call_args = mock_ingest.call_args + ingested_data = call_args[0][0] + self.assertEqual(ingested_data['run_id'], '444') + self.assertEqual(ingested_data['run_execution_status'], 'Success') + + def test_outer_exception_handling(self): + """Test outer exception handling (lines 91-94).""" + with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \ + patch('sys.exit') as mock_exit: + + # Mock getOrganizedResults to raise an exception + mock_get_results.side_effect = Exception("Fatal error") + + # Execute the script + self._run_script() + + # Verify sys.exit(0) was called + mock_exit.assert_called_once_with(0) + + def test_param_field_splitting(self): + """Test param field is split on newlines (line 32).""" + with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \ + patch('databricks_com.DatabricksClient') as mock_client_class, \ + patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \ + patch('time.time', return_value=1234567890.0), \ + patch('sys.exit'): + + # Setup test data with param field + test_result = { + 'run_id': '555', + 'account_name': 'test_account', + 'index': 'test_index', + 'sourcetype': 'databricks:databricksjob', + 'run_execution_status': 'Pending', + 'param': 'param1\nparam2\nparam3', + 'uid': 'test_uid_555' + } + mock_get_results.return_value = ([test_result], {}, {'sessionKey': 'test_session_key'}) + + # Mock DatabricksClient + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.databricks_api.return_value = { 
+ 'state': { + 'life_cycle_state': 'RUNNING' + } + } + + # Execute the script + self._run_script() + + # Verify param was split + mock_ingest.assert_called_once() + call_args = mock_ingest.call_args + ingested_data = call_args[0][0] + self.assertEqual(ingested_data['param'], ['param1', 'param2', 'param3']) + + def test_identifier_removal_for_databricksjob(self): + """Test identifier is removed for databricksjob sourcetype (line 38).""" + with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \ + patch('databricks_com.DatabricksClient') as mock_client_class, \ + patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \ + patch('time.time', return_value=1234567890.0), \ + patch('sys.exit'): + + # Setup test data with identifier field + test_result = { + 'run_id': '666', + 'account_name': 'test_account', + 'index': 'test_index', + 'sourcetype': 'databricks:databricksjob', + 'run_execution_status': 'Pending', + 'identifier': 'should_be_removed', + 'uid': 'test_uid_666' + } + mock_get_results.return_value = ([test_result], {}, {'sessionKey': 'test_session_key'}) + + # Mock DatabricksClient + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.databricks_api.return_value = { + 'state': { + 'life_cycle_state': 'RUNNING' + } + } + + # Execute the script + self._run_script() + + # Verify identifier was removed + mock_ingest.assert_called_once() + call_args = mock_ingest.call_args + ingested_data = call_args[0][0] + self.assertNotIn('identifier', ingested_data) + + def test_index_and_sourcetype_removed_from_ingestion(self): + """Test index and sourcetype are removed from ingested data (lines 39-40).""" + with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \ + patch('databricks_com.DatabricksClient') as mock_client_class, \ + patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \ + patch('time.time', return_value=1234567890.0), \ + patch('sys.exit'): + + # Setup test 
data
+            test_result = {
+                'run_id': '777',
+                'account_name': 'test_account',
+                'index': 'test_index',
+                'sourcetype': 'databricks:databricksjob',
+                'run_execution_status': 'Pending',
+                'uid': 'test_uid_777'
+            }
+            mock_get_results.return_value = ([test_result], {}, {'sessionKey': 'test_session_key'})
+
+            # Mock DatabricksClient
+            mock_client = MagicMock()
+            mock_client_class.return_value = mock_client
+            mock_client.databricks_api.return_value = {
+                'state': {
+                    'life_cycle_state': 'RUNNING'
+                }
+            }
+
+            # Execute the script
+            self._run_script()
+
+            # Verify index and sourcetype were removed from ingested data
+            mock_ingest.assert_called_once()
+            call_args = mock_ingest.call_args
+            ingested_data = call_args[0][0]
+            self.assertNotIn('index', ingested_data)
+            self.assertNotIn('sourcetype', ingested_data)
+            # But they should be passed as parameters to ingest_data_to_splunk
+            self.assertEqual(call_args[0][2], 'test_index')
+            self.assertEqual(call_args[0][3], 'databricks:databricksjob')
+
+    def test_multiple_results_all_processed(self):
+        """Test multiple results are all processed."""
+        with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \
+             patch('databricks_com.DatabricksClient') as mock_client_class, \
+             patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \
+             patch('time.time', return_value=1234567890.0), \
+             patch('sys.exit'):
+
+            # Setup test data with three results
+            test_results = [
+                {
+                    'run_id': '100',
+                    'account_name': 'test_account',
+                    'index': 'test_index',
+                    'sourcetype': 'databricks:databricksjob',
+                    'run_execution_status': 'Pending',
+                    'uid': 'uid_100'
+                },
+                {
+                    'run_id': '101',
+                    'account_name': 'test_account',
+                    'index': 'test_index',
+                    'sourcetype': 'databricks:databricksjob',
+                    'run_execution_status': 'Running',
+                    'uid': 'uid_101'
+                },
+                {
+                    'run_id': '102',
+                    'account_name': 'test_account',
+                    'index': 'test_index',
+                    'sourcetype': 'databricks:databricksjob',
+                    'run_execution_status': 'Running',
+                    'uid': 'uid_102'
+                }
+            ]
+            mock_get_results.return_value = (test_results, {}, {'sessionKey': 'test_session_key'})
+
+            # Mock DatabricksClient - all state changes
+            mock_client = MagicMock()
+            mock_client_class.return_value = mock_client
+            mock_client.databricks_api.side_effect = [
+                {'state': {'life_cycle_state': 'RUNNING'}},
+                {'state': {'life_cycle_state': 'TERMINATED', 'result_state': 'SUCCESS'}},
+                {'state': {'life_cycle_state': 'TERMINATED', 'result_state': 'FAILED'}}
+            ]
+
+            # Execute the script
+            self._run_script()
+
+            # Verify all three results were ingested
+            self.assertEqual(mock_ingest.call_count, 3)
+
+    def test_no_uid_in_result(self):
+        """Test handling when uid is not present in result."""
+        with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \
+             patch('databricks_com.DatabricksClient') as mock_client_class, \
+             patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \
+             patch('time.time', return_value=1234567890.0), \
+             patch('sys.exit'):
+
+            # Setup test data without uid
+            test_result = {
+                'run_id': '888',
+                'account_name': 'test_account',
+                'index': 'test_index',
+                'sourcetype': 'databricks:databricksjob',
+                'run_execution_status': 'Pending'
+                # Note: no 'uid' field
+            }
+            mock_get_results.return_value = ([test_result], {}, {'sessionKey': 'test_session_key'})
+
+            # Mock DatabricksClient
+            mock_client = MagicMock()
+            mock_client_class.return_value = mock_client
+            mock_client.databricks_api.return_value = {
+                'state': {
+                    'life_cycle_state': 'RUNNING'
+                }
+            }
+
+            # Execute the script - should handle missing uid gracefully
+            self._run_script()
+
+            # Verify ingestion still occurred
+            mock_ingest.assert_called_once()
+
+    def test_non_databricksjob_sourcetype(self):
+        """Test that identifier is not removed for non-databricksjob sourcetype."""
+        with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \
+             patch('databricks_com.DatabricksClient') as mock_client_class, \
+             patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \
+             patch('time.time', return_value=1234567890.0), \
+             patch('sys.exit'):
+
+            # Setup test data with different sourcetype
+            test_result = {
+                'run_id': '999',
+                'account_name': 'test_account',
+                'index': 'test_index',
+                'sourcetype': 'databricks:other',  # Not databricksjob
+                'run_execution_status': 'Pending',
+                'identifier': 'should_be_kept',
+                'uid': 'test_uid_999'
+            }
+            mock_get_results.return_value = ([test_result], {}, {'sessionKey': 'test_session_key'})
+
+            # Mock DatabricksClient
+            mock_client = MagicMock()
+            mock_client_class.return_value = mock_client
+            mock_client.databricks_api.return_value = {
+                'state': {
+                    'life_cycle_state': 'RUNNING'
+                }
+            }
+
+            # Execute the script
+            self._run_script()
+
+            # Verify identifier was NOT removed
+            mock_ingest.assert_called_once()
+            call_args = mock_ingest.call_args
+            ingested_data = call_args[0][0]
+            self.assertEqual(ingested_data['identifier'], 'should_be_kept')
+
+    def test_pending_state_no_change(self):
+        """Test PENDING state when status is already Pending."""
+        with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \
+             patch('databricks_com.DatabricksClient') as mock_client_class, \
+             patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \
+             patch('time.time', return_value=1234567890.0), \
+             patch('sys.exit'):
+
+            # Setup test data
+            test_result = {
+                'run_id': '1000',
+                'account_name': 'test_account',
+                'index': 'test_index',
+                'sourcetype': 'databricks:databricksjob',
+                'run_execution_status': 'Pending',  # Already Pending
+                'uid': 'test_uid_1000'
+            }
+            mock_get_results.return_value = ([test_result], {}, {'sessionKey': 'test_session_key'})
+
+            # Mock DatabricksClient
+            mock_client = MagicMock()
+            mock_client_class.return_value = mock_client
+            mock_client.databricks_api.return_value = {
+                'state': {
+                    'life_cycle_state': 'PENDING'
+                }
+            }
+
+            # Execute the script
+            self._run_script()
+
+            # Verify no ingestion occurred since status didn't change
+            mock_ingest.assert_not_called()
+
+    def test_other_state_no_change(self):
+        """Test other state when status hasn't changed."""
+        with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \
+             patch('databricks_com.DatabricksClient') as mock_client_class, \
+             patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \
+             patch('time.time', return_value=1234567890.0), \
+             patch('sys.exit'):
+
+            # Setup test data
+            test_result = {
+                'run_id': '1001',
+                'account_name': 'test_account',
+                'index': 'test_index',
+                'sourcetype': 'databricks:databricksjob',
+                'run_execution_status': 'CUSTOM_STATE',  # Same as result_state
+                'uid': 'test_uid_1001'
+            }
+            mock_get_results.return_value = ([test_result], {}, {'sessionKey': 'test_session_key'})
+
+            # Mock DatabricksClient
+            mock_client = MagicMock()
+            mock_client_class.return_value = mock_client
+            mock_client.databricks_api.return_value = {
+                'state': {
+                    'life_cycle_state': 'OTHER',
+                    'result_state': 'CUSTOM_STATE'
+                }
+            }
+
+            # Execute the script
+            self._run_script()
+
+            # Verify no ingestion occurred since status didn't change
+            mock_ingest.assert_not_called()
+
+    def test_none_response_from_api(self):
+        """Test handling when API returns None response."""
+        with patch('splunk.Intersplunk.getOrganizedResults') as mock_get_results, \
+             patch('databricks_com.DatabricksClient') as mock_client_class, \
+             patch('databricks_common_utils.ingest_data_to_splunk') as mock_ingest, \
+             patch('sys.exit'):
+
+            # Setup test data
+            test_result = {
+                'run_id': '1002',
+                'account_name': 'test_account',
+                'index': 'test_index',
+                'sourcetype': 'databricks:databricksjob',
+                'run_execution_status': 'Running',
+                'uid': 'test_uid_1002'
+            }
+            mock_get_results.return_value = ([test_result], {}, {'sessionKey': 'test_session_key'})
+
+            # Mock DatabricksClient to return None
+            mock_client = MagicMock()
+            mock_client_class.return_value = mock_client
+            mock_client.databricks_api.return_value = None
+
+            # Execute the script
+            self._run_script()
+
+            # Verify no ingestion occurred
+            mock_ingest.assert_not_called()
diff --git a/tests/utility.py b/tests/utility.py
index 0d2292d..ca2414d 100644
--- a/tests/utility.py
+++ b/tests/utility.py
@@ -1,10 +1,16 @@
 class Response:
     """Sample Response Class."""
 
-    def __init__(self, status_code):
+    def __init__(self, status_code, json_data=None):
         """Init Method for Response."""
         self.status_code = status_code
-
+        self._json_data = json_data if json_data is not None else {"status_code": self.status_code}
+
     def json(self):
         """Set json value."""
-        return {"status_code": self.status_code}
\ No newline at end of file
+        return self._json_data
+
+    def raise_for_status(self):
+        """Raise exception for non-2xx status codes."""
+        if self.status_code >= 400:
+            raise Exception(f"HTTP {self.status_code}")
\ No newline at end of file