Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
run: pip install -r requirements-dev.txt
- name: Run tests and generate coverage
working-directory: ./
run: pytest --cov-config=tests/.coveragerc --cov=app/bin --cov-report=xml:coverage.xml tests
run: pytest --cov-config=tests/.coveragerc --cov=app/bin --cov-report=xml:coverage.xml --cov-fail-under=80 tests
- name: Publish test coverage
uses: codecov/codecov-action@v1

75 changes: 61 additions & 14 deletions app/README.md

Large diffs are not rendered by default.

34 changes: 19 additions & 15 deletions app/appserver/static/js/build/custom/auth_select_hook.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@ class AuthSelectHook {

onChange(field, value, dataDict) {
if (field == 'auth_type') {
if (value == 'AAD') {
this.toggleAADFields(true);
} else {
this.toggleAADFields(false);
}
this.toggleAuthFields(value);
}
if (field == 'config_for_dbquery') {
if (value == 'interactive_cluster') {
Expand All @@ -26,11 +22,7 @@ class AuthSelectHook {

/**
 * Reads the currently selected auth type from `this.state.data.auth_type.value`
 * and updates which credential fields are displayed for it.
 * Presumably invoked by the hosting UI framework when the form renders — confirm.
 */
onRender() {
var selected_auth = this.state.data.auth_type.value;
// NOTE(review): this if/else calling toggleAADFields looks like the removed side
// of the diff hunk, carried over verbatim by the page extraction — confirm.
if (selected_auth == 'AAD') {
this.toggleAADFields(true);
} else {
this.toggleAADFields(false);
}
// Hand the selected auth type to toggleAuthFields, which sets the display flags.
this.toggleAuthFields(selected_auth);
}

hideWarehouseField(state) {
Expand All @@ -41,13 +33,25 @@ class AuthSelectHook {
});
}

toggleAADFields(state) {
/**
 * Sets the `display` flag of each auth-specific credential field so that only
 * the fields belonging to the given auth type are shown.
 *
 * `authType` is compared strictly against 'OAUTH_M2M', 'AAD' and 'PAT'; any
 * other value leaves every field listed here hidden.
 */
toggleAuthFields(authType) {
this.util.setState((prevState) => {
// Shallow copy only: the field objects mutated below are shared with prevState.data.
let data = {...prevState.data };
// NOTE(review): the next four assignments read `state`, which is not defined in
// this method; they look like removed diff lines kept by the page extraction — confirm.
data.aad_client_id.display = state;
data.aad_tenant_id.display = state;
data.aad_client_secret.display = state;
data.databricks_pat.display = !state;

// OAuth M2M fields
const showOAuth = (authType === 'OAUTH_M2M');
data.oauth_client_id.display = showOAuth;
data.oauth_client_secret.display = showOAuth;

// AAD fields
const showAAD = (authType === 'AAD');
data.aad_client_id.display = showAAD;
data.aad_tenant_id.display = showAAD;
data.aad_client_secret.display = showAAD;

// PAT field
const showPAT = (authType === 'PAT');
data.databricks_pat.display = showPAT;

// Return the updated data object as the new partial state.
return { data }
});
}
Expand Down
42 changes: 41 additions & 1 deletion app/appserver/static/js/build/globalConfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@
{
"label": "Azure Active Directory",
"value": "AAD"
},
{
"label": "Databricks OAuth (M2M)",
"value": "OAUTH_M2M"
}
]
},
Expand Down Expand Up @@ -205,6 +209,42 @@
"placeholder": "required"
}
},
{
"field": "oauth_client_id",
"label": "OAuth Client ID",
"type": "text",
"help": "Enter the Client ID from your Databricks service principal.",
"required": false,
"defaultValue": "",
"encrypted": false,
"validators": [{
"type": "string",
"minLength": 0,
"maxLength": 200,
"errorMsg": "Max length of OAuth Client ID is 200"
}],
"options": {
"placeholder": "required"
}
},
{
"field": "oauth_client_secret",
"label": "OAuth Client Secret",
"type": "text",
"help": "Enter the OAuth secret from your Databricks service principal.",
"required": false,
"defaultValue": "",
"encrypted": true,
"validators": [{
"type": "string",
"minLength": 0,
"maxLength": 500,
"errorMsg": "Max length of OAuth Client Secret is 500"
}],
"options": {
"placeholder": "required"
}
},
{
"field": "databricks_pat",
"label": "Databricks Access Token",
Expand Down Expand Up @@ -318,7 +358,7 @@
"field": "use_for_oauth",
"label": "Use Proxy for OAuth",
"defaultValue": 0,
"tooltip": "Check this box if you want to use proxy just for AAD token generation (https://login.microsoftonline.com/). All other network calls will skip the proxy even if it's enabled.",
"tooltip": "Check this box if you want to use proxy just for Azure AD token generation (https://login.microsoftonline.com/). All other network calls (including Databricks API and OAuth M2M) will skip the proxy even if it's enabled.",
"type": "checkbox"
}
],
Expand Down
29 changes: 29 additions & 0 deletions app/bin/TA_Databricks_rh_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,35 @@
default='',
validator=None
),
field.RestField(
'oauth_client_id',
required=False,
encrypted=False,
default='',
validator=validator.String(
min_len=0,
max_len=200,
)
),
field.RestField(
'oauth_client_secret',
required=False,
encrypted=True,
default='',
validator=None
),
field.RestField(
'oauth_access_token',
required=False,
encrypted=True
),
field.RestField(
'oauth_token_expiration',
required=False,
encrypted=False,
default='',
validator=None
),
field.RestField(
'databricks_pat',
required=False,
Expand Down
60 changes: 59 additions & 1 deletion app/bin/databricks_com.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,17 @@ def __init__(self, account_name, session_key):
self.session.timeout = const.TIMEOUT
if self.auth_type == "PAT":
self.databricks_token = databricks_configs.get("databricks_pat")
else:
elif self.auth_type == "AAD":
self.databricks_token = databricks_configs.get("aad_access_token")
self.aad_client_id = databricks_configs.get("aad_client_id")
self.aad_tenant_id = databricks_configs.get("aad_tenant_id")
self.aad_client_secret = databricks_configs.get("aad_client_secret")
elif self.auth_type == "OAUTH_M2M":
self.databricks_token = databricks_configs.get("oauth_access_token")
self.oauth_client_id = databricks_configs.get("oauth_client_id")
self.oauth_client_secret = databricks_configs.get("oauth_client_secret")
oauth_token_expiration_str = databricks_configs.get("oauth_token_expiration")
self.oauth_token_expiration = float(oauth_token_expiration_str) if oauth_token_expiration_str else 0

if not all([databricks_instance, self.databricks_token]):
raise Exception("Addon is not configured. Navigate to addon's configuration page to configure the addon.")
Expand Down Expand Up @@ -98,6 +104,48 @@ def get_requests_retry_session(self):
session.mount("https://", adapter)
return session

def should_refresh_oauth_token(self):
    """
    Tell whether the OAuth token is close enough to expiry to be renewed now.

    :return: Boolean - False when this instance tracks no token expiration,
        otherwise True when fewer than 300 seconds of validity remain
    """
    import time

    try:
        expires_at = self.oauth_token_expiration
    except AttributeError:
        # No expiry is tracked on this instance, so there is nothing to refresh.
        return False

    refresh_window = 300  # seconds (5 minutes)
    seconds_left = expires_at - time.time()
    return seconds_left < refresh_window

def _refresh_oauth_token(self):
    """
    Refresh the OAuth M2M access token and update the session headers.

    Re-reads the account configuration to obtain the proxy settings, requests a
    new token through ``utils.get_oauth_access_token`` (with ``conf_update=True``),
    then updates ``self.databricks_token``, ``self.oauth_token_expiration`` and the
    ``Authorization`` header on ``self.session``.

    :return: None
    :raises Exception: if the token request fails; the message is the error text
        returned by ``utils.get_oauth_access_token``
    """
    import time

    databricks_configs = utils.get_databricks_configs(self.session_key, self.account_name)
    proxy_config = databricks_configs.get("proxy_uri")

    result = utils.get_oauth_access_token(
        self.session_key,
        self.account_name,
        # The helper builds the token URL itself and expects a bare host name.
        self.databricks_instance_url.replace("https://", ""),
        self.oauth_client_id,
        self.oauth_client_secret,
        proxy_config,
        retry=const.RETRIES,
        conf_update=True
    )

    # Guard against anything that is not the documented two-item tuple, so the
    # failure is reported clearly instead of as an unpacking TypeError.
    if not isinstance(result, tuple) or len(result) != 2:
        raise Exception("Unable to refresh OAuth M2M access token.")

    # Failure is signalled as (error_message, False). Compare by identity:
    # "== False" would also match an expires_in of 0 and raise an exception
    # whose message is the access token itself.
    if result[1] is False:
        raise Exception(result[0])

    access_token, expires_in = result
    self.databricks_token = access_token
    self.oauth_token_expiration = time.time() + expires_in
    self.request_headers["Authorization"] = "Bearer {}".format(self.databricks_token)
    self.session.headers.update(self.request_headers)

def databricks_api(self, method, endpoint, data=None, args=None):
"""
Common method to hit the API of Databricks instance.
Expand All @@ -108,6 +156,11 @@ def databricks_api(self, method, endpoint, data=None, args=None):
:param args: Arguments to be add into the url
:return: response in the form of dictionary
"""
# Proactive OAuth token refresh
if self.auth_type == "OAUTH_M2M" and self.should_refresh_oauth_token():
_LOGGER.info("OAuth token expiring soon, refreshing proactively.")
self._refresh_oauth_token()

run_again = True
request_url = "{}{}".format(self.databricks_instance_url, endpoint)
try:
Expand Down Expand Up @@ -141,6 +194,11 @@ def databricks_api(self, method, endpoint, data=None, args=None):
self.databricks_token = db_token
self.request_headers["Authorization"] = "Bearer {}".format(self.databricks_token)
self.session.headers.update(self.request_headers)
elif status_code == 403 and self.auth_type == "OAUTH_M2M" and run_again:
response = None
run_again = False
_LOGGER.info("Refreshing OAuth M2M token.")
self._refresh_oauth_token()
elif status_code != 200:
response.raise_for_status()
else:
Expand Down
131 changes: 131 additions & 0 deletions app/bin/databricks_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,137 @@ def save_databricks_aad_access_token(account_name, session_key, access_token, cl
raise Exception("Exception while saving AAD access token.")


def save_databricks_oauth_access_token(account_name, session_key, access_token, expires_in, client_secret):
    """
    Persist a newly issued OAuth access token along with its absolute expiry time.

    The expiry is stored as a string holding the current epoch time plus
    ``expires_in``.

    :param account_name: Account name
    :param session_key: Splunk session key
    :param access_token: OAuth access token
    :param expires_in: Token lifetime in seconds
    :param client_secret: OAuth client secret
    :return: None
    :raises Exception: if the request that stores the credentials fails
    """
    import time

    expiration_timestamp = str(time.time() + expires_in)
    payload = {
        "name": account_name,
        "oauth_client_secret": client_secret,
        "oauth_access_token": access_token,
        "oauth_token_expiration": expiration_timestamp,
        "update_token": True
    }
    request_kwargs = {
        "sessionKey": session_key,
        "postargs": payload,
        "raiseAllErrors": True,
    }
    try:
        _LOGGER.info("Saving databricks OAuth access token.")
        rest.simpleRequest("/databricks_get_credentials", **request_kwargs)
        _LOGGER.info("Saved OAuth access token successfully.")
    except Exception as err:
        _LOGGER.error("Exception while saving OAuth access token: {}".format(str(err)))
        _LOGGER.debug(traceback.format_exc())
        raise Exception("Exception while saving OAuth access token.")


def get_oauth_access_token(
    session_key,
    account_name,
    databricks_instance,
    oauth_client_id,
    oauth_client_secret,
    proxy_settings=None,
    retry=1,
    conf_update=False,
):
    """
    Acquire an OAuth M2M access token for a Databricks service principal.

    Sends a ``client_credentials`` grant request with scope ``all-apis`` to
    ``https://<databricks_instance>/oidc/v1/token``, authenticating with HTTP
    Basic auth. Errors are never raised to the caller; they are reported through
    the return value.

    :param session_key: Splunk session key
    :param account_name: Account name for configuration storage
    :param databricks_instance: Databricks workspace host name (without scheme)
    :param oauth_client_id: OAuth client ID from service principal
    :param oauth_client_secret: OAuth client secret from service principal
    :param proxy_settings: Proxy configuration dict
    :param retry: Number of attempts; values below 1 are treated as 1
    :param conf_update: If True, store token in configuration
    :return: tuple (access_token, expires_in) or (error_message, False)
    """
    from requests.auth import HTTPBasicAuth

    token_url = "https://{}/oidc/v1/token".format(databricks_instance)
    headers = {
        "Content-Type": "application/x-www-form-urlencoded",
        "User-Agent": "{}".format(const.USER_AGENT_CONST),
    }
    _LOGGER.debug("Request made to the Databricks from Splunk user: {}".format(get_current_user(session_key)))
    data_dict = {"grant_type": "client_credentials", "scope": "all-apis"}
    data_encoded = urlencode(data_dict)

    # Handle proxy settings for OAuth M2M
    # Note: "use_for_oauth" means "use proxy ONLY for AAD token generation"
    # Since OAuth M2M endpoint is on the Databricks instance (not AAD),
    # we should skip proxy when use_for_oauth is true
    if proxy_settings:
        if is_true(proxy_settings.get("use_for_oauth")):
            _LOGGER.info(
                "Skipping the usage of proxy for OAuth M2M as 'Use Proxy for OAuth' parameter is checked."
            )
            proxy_settings_copy = None
        else:
            proxy_settings_copy = proxy_settings.copy()
            proxy_settings_copy.pop("use_for_oauth", None)
    else:
        proxy_settings_copy = None

    # Fallback used whenever no more specific message can be derived, so that
    # "msg" is always bound when it is logged and returned.
    generic_msg = (
        "Unable to request Databricks instance. "
        "Please validate the provided Databricks and "
        "Proxy configurations or check the network connectivity."
    )
    msg = generic_msg

    # Always make at least one attempt so the function never falls through
    # and returns None.
    attempts_left = retry if isinstance(retry, int) and retry > 0 else 1

    while attempts_left:
        # Reset per attempt so a response from an earlier attempt is never
        # mistaken for the outcome of this one.
        resp = None
        try:
            resp = requests.post(
                token_url,
                headers=headers,
                data=data_encoded,
                auth=HTTPBasicAuth(oauth_client_id, oauth_client_secret),
                proxies=proxy_settings_copy,
                verify=const.VERIFY_SSL,
                timeout=const.TIMEOUT
            )
            resp.raise_for_status()
            response = resp.json()
            oauth_access_token = response.get("access_token")
            expires_in = response.get("expires_in", 3600)
            if conf_update:
                save_databricks_oauth_access_token(
                    account_name, session_key, oauth_access_token, expires_in, oauth_client_secret
                )
            return oauth_access_token, expires_in
        except Exception as e:
            attempts_left -= 1
            msg = generic_msg
            if resp is not None:
                # The error body may not be JSON (e.g. an HTML page from a proxy);
                # parsing must not raise from inside this handler.
                try:
                    error_body = resp.json()
                except ValueError:
                    error_body = None
                error_code = error_body.get("error") if isinstance(error_body, dict) else None
                status_key = str(resp.status_code)
                if isinstance(error_code, str) and error_code in const.ERROR_CODE:
                    msg = const.ERROR_CODE[error_code]
                elif status_key in const.ERROR_CODE:
                    msg = const.ERROR_CODE[status_key]
                elif resp.status_code not in (200, 201):
                    msg = (
                        "Response status: {}. Unable to validate OAuth credentials. "
                        "Check logs for more details.".format(status_key)
                    )
                # A 200/201 response reaching this handler means a later step
                # failed (for example saving the token); the generic message is
                # kept and the actual cause is logged below.
            _LOGGER.error("Error while trying to generate OAuth access token: {}".format(str(e)))
            _LOGGER.debug(traceback.format_exc())
            _LOGGER.error(msg)

    return msg, False


def get_proxy_clear_password(session_key):
"""
Get clear password from splunk passwords.conf.
Expand Down
Loading