Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions app/README/alert_actions.conf.spec
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ param._cam = <json> Active response parameters.
param.notebook = <string> Notebook
param.paramOne = <string> Field Name for Parameter One
param.paramTwo = <string> Field Name for Parameter Two
param.account_name = <string> Account Name

[launch_notebook]
python.version = python3
Expand All @@ -12,4 +13,5 @@ param.revision_timestamp = <string> Revision Timestamp.
param.notebook_parameters = <string> Notebook Parameters.
param.cluster_name = <string> Cluster Name.
param.run_name = <string> Run Name.
param.account_name = <string> Account Name.
param._cam = <json> Active response parameters.
5 changes: 5 additions & 0 deletions app/README/ta_databricks_account.conf.spec
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
[<name>]
databricks_instance =
auth_type =
aad_client_id =
aad_tenant_id =
aad_client_secret =
aad_access_token =
oauth_client_id =
oauth_client_secret =
oauth_access_token =
oauth_token_expiration =
config_for_dbquery =
cluster_name =
warehouse_id =
Expand Down
3 changes: 2 additions & 1 deletion app/README/ta_databricks_settings.conf.spec
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ loglevel =
[additional_parameters]
admin_command_timeout =
query_result_limit =
index =
index =
thread_count =
7 changes: 7 additions & 0 deletions app/bin/TA_Databricks_rh_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@
'aad_access_token',
required=False,
encrypted=True
),
field.RestField(
'aad_token_expiration',
required=False,
encrypted=False,
default='',
validator=None
)
]
model_databricks_credentials = RestModel(fields, name=None)
Expand Down
122 changes: 98 additions & 24 deletions app/bin/databricks_com.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def __init__(self, account_name, session_key):
self.aad_client_id = databricks_configs.get("aad_client_id")
self.aad_tenant_id = databricks_configs.get("aad_tenant_id")
self.aad_client_secret = databricks_configs.get("aad_client_secret")
aad_token_expiration_str = databricks_configs.get("aad_token_expiration")
self.aad_token_expiration = float(aad_token_expiration_str) if aad_token_expiration_str else 0
elif self.auth_type == "OAUTH_M2M":
self.databricks_token = databricks_configs.get("oauth_access_token")
self.oauth_client_id = databricks_configs.get("oauth_client_id")
Expand Down Expand Up @@ -104,6 +106,22 @@ def get_requests_retry_session(self):
session.mount("https://", adapter)
return session

def should_refresh_aad_token(self):
"""
Check if AAD token should be refreshed proactively.

:return: Boolean - True if token expires within 5 minutes
"""
if not hasattr(self, 'aad_token_expiration') or self.aad_token_expiration == 0:
return False

import time
current_time = time.time()
time_until_expiry = self.aad_token_expiration - current_time

# Refresh if token expires within 5 minutes (300 seconds)
return time_until_expiry < 300

def should_refresh_oauth_token(self):
"""
Check if OAuth token should be refreshed proactively.
Expand All @@ -120,6 +138,74 @@ def should_refresh_oauth_token(self):
# Refresh if token expires within 5 minutes (300 seconds)
return time_until_expiry < 300

def _is_token_expired_response(self, response):
"""
Check if the API response indicates an expired token.

Databricks API may return different status codes for expired tokens:
- 403 Forbidden (legacy)
- 401 Unauthorized with "Token is expired"
- 400 Bad Request with "ExpiredJwtException"

:param response: requests.Response object
:return: Boolean - True if response indicates expired token
"""
if response is None:
return False

status_code = response.status_code

# 403 Forbidden - legacy expired token response
if status_code == 403:
return True

# 401 Unauthorized - check for token expiry message
if status_code == 401:
try:
response_text = response.text.lower()
if "token is expired" in response_text or "expired" in response_text:
return True
except Exception:
pass
return True # Treat all 401s as potentially expired tokens

# 400 Bad Request - check for ExpiredJwtException
if status_code == 400:
try:
response_text = response.text.lower()
if "expiredjwtexception" in response_text or "expired" in response_text:
return True
except Exception:
pass

return False

def _refresh_aad_token(self):
"""Refresh AAD access token and update session headers."""
databricks_configs = utils.get_databricks_configs(self.session_key, self.account_name)
proxy_config = databricks_configs.get("proxy_uri")

result = utils.get_aad_access_token(
self.session_key,
self.account_name,
self.aad_tenant_id,
self.aad_client_id,
self.aad_client_secret,
proxy_config,
retry=const.RETRIES,
conf_update=True
)

if isinstance(result, tuple) and result[1] == False:
raise Exception(result[0])

access_token, expires_in = result
self.databricks_token = access_token
import time
self.aad_token_expiration = time.time() + expires_in
self.request_headers["Authorization"] = "Bearer {}".format(self.databricks_token)
self.session.headers.update(self.request_headers)

def _refresh_oauth_token(self):
"""Refresh OAuth M2M access token and update session headers."""
databricks_configs = utils.get_databricks_configs(self.session_key, self.account_name)
Expand Down Expand Up @@ -156,8 +242,11 @@ def databricks_api(self, method, endpoint, data=None, args=None):
:param args: Arguments to be add into the url
:return: response in the form of dictionary
"""
# Proactive OAuth token refresh
if self.auth_type == "OAUTH_M2M" and self.should_refresh_oauth_token():
# Proactive token refresh
if self.auth_type == "AAD" and self.should_refresh_aad_token():
_LOGGER.info("AAD token expiring soon, refreshing proactively.")
self._refresh_aad_token()
elif self.auth_type == "OAUTH_M2M" and self.should_refresh_oauth_token():
_LOGGER.info("OAuth token expiring soon, refreshing proactively.")
self._refresh_oauth_token()

Expand All @@ -172,32 +261,16 @@ def databricks_api(self, method, endpoint, data=None, args=None):
_LOGGER.info("Executing REST call: {} Payload: {}.".format(endpoint, str(data)))
response = self.session.post(request_url, params=args, json=data, timeout=self.session.timeout)
status_code = response.status_code
if status_code == 403 and self.auth_type == "AAD" and run_again:
# Check if token is expired and needs refresh (handles 403, 401, 400 with expiry messages)
if self._is_token_expired_response(response) and self.auth_type == "AAD" and run_again:
_LOGGER.info("Token expired (status: {}). Refreshing AAD token.".format(status_code))
response = None
run_again = False
_LOGGER.info("Refreshing AAD token.")
databricks_configs = utils.get_databricks_configs(self.session_key, self.account_name)
proxy_config = databricks_configs.get("proxy_uri")
db_token = utils.get_aad_access_token(
self.session_key,
self.account_name,
self.aad_tenant_id,
self.aad_client_id,
self.aad_client_secret,
proxy_config, # Using the reinit proxy. As proxy is getting updated on Line no: 43, 45
retry=const.RETRIES, # based on the condition and for this call we will always need proxy.
conf_update=True, # By passing True, the AAD access token will be updated in conf
)
if isinstance(db_token, tuple):
raise Exception(db_token[0])
else:
self.databricks_token = db_token
self.request_headers["Authorization"] = "Bearer {}".format(self.databricks_token)
self.session.headers.update(self.request_headers)
elif status_code == 403 and self.auth_type == "OAUTH_M2M" and run_again:
self._refresh_aad_token()
elif self._is_token_expired_response(response) and self.auth_type == "OAUTH_M2M" and run_again:
_LOGGER.info("Token expired (status: {}). Refreshing OAuth M2M token.".format(status_code))
response = None
run_again = False
_LOGGER.info("Refreshing OAuth M2M token.")
self._refresh_oauth_token()
elif status_code != 200:
response.raise_for_status()
Expand All @@ -215,6 +288,7 @@ def databricks_api(self, method, endpoint, data=None, args=None):
if "response" in locals() and response is not None:
status_code_messages = {
400: response.json().get("message", "Bad request. The request is malformed."),
401: "Unauthorized. Access token may be invalid or expired.",
403: "Invalid access token. Please enter the valid access token.",
404: "Invalid API endpoint.",
429: "API limit exceeded. Please try again after some time.",
Expand Down
27 changes: 22 additions & 5 deletions app/bin/databricks_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,25 @@ def get_databricks_configs(session_key, account_name):
return configs_dict


def save_databricks_aad_access_token(account_name, session_key, access_token, client_sec):
def save_databricks_aad_access_token(account_name, session_key, access_token, client_sec, expires_in=None):
"""
Method to store new AAD access token.
Method to store new AAD access token with expiration timestamp.

:param account_name: Account name
:param session_key: Splunk session key
:param access_token: AAD access token
:param client_sec: AAD client secret
:param expires_in: Token lifetime in seconds (optional, defaults to 3600)
:return: None
"""
import time
if expires_in is None:
expires_in = 3600 # Default to 1 hour if not provided
new_creds = {
"name": account_name,
"aad_client_secret": client_sec,
"aad_access_token": access_token,
"aad_token_expiration": str(time.time() + expires_in),
"update_token": True
}
try:
Expand Down Expand Up @@ -434,7 +443,14 @@ def get_aad_access_token(
Method to acquire a new AAD access token.

:param session_key: Splunk session key
:return: access token
:param account_name: Account name for configuration storage
:param aad_tenant_id: Azure AD tenant ID
:param aad_client_id: Azure AD client ID
:param aad_client_secret: Azure AD client secret
:param proxy_settings: Proxy configuration dict
:param retry: Number of retry attempts
:param conf_update: If True, store token in configuration
:return: tuple (access_token, expires_in) or (error_message, False)
"""
token_url = const.AAD_TOKEN_ENDPOINT.format(aad_tenant_id)
headers = {
Expand Down Expand Up @@ -464,11 +480,12 @@ def get_aad_access_token(
resp.raise_for_status()
response = resp.json()
aad_access_token = response.get("access_token")
expires_in = response.get("expires_in", 3600) # Default to 1 hour
if conf_update:
save_databricks_aad_access_token(
account_name, session_key, aad_access_token, aad_client_secret
account_name, session_key, aad_access_token, aad_client_secret, expires_in
)
return aad_access_token
return aad_access_token, expires_in
except Exception as e:
retry -= 1
if "resp" in locals():
Expand Down
5 changes: 3 additions & 2 deletions app/bin/databricks_const.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,12 @@
"700016": "Invalid Client ID provided.",
"900023": "Invalid Tenant ID provided.",
"7000215": "Invalid Client Secret provided.",
"400": "Bad request. The request is malformed.",
"401": "Unauthorized. Access token may be invalid or expired.",
"403": "Client secret may have expired. Please configure a valid Client secret.",
"404": "Invalid API endpoint.",
"500": "Internal server error.",
"400": "Bad request. The request is malformed.",
"429": "API limit exceeded. Please try again after some time.",
"500": "Internal server error.",
"invalid_client": "Invalid OAuth Client ID or Client Secret provided.",
"unauthorized_client": "Service principal is not authorized for this workspace.",
"invalid_grant": "The provided OAuth credentials are invalid or expired.",
Expand Down
9 changes: 8 additions & 1 deletion app/bin/databricks_get_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ def handle(self, in_string):
_LOGGER.info("Saving databricks AAD access token.")
client_sec = form_data.get("aad_client_secret")
access_token = form_data.get("aad_access_token")
new_creds = json.dumps({"aad_client_secret": client_sec, "aad_access_token": access_token})
token_expiration = form_data.get("aad_token_expiration")
new_creds = json.dumps({
"aad_client_secret": client_sec,
"aad_access_token": access_token,
"aad_token_expiration": token_expiration
})
success_msg = 'Saved AAD access token successfully.'
elif form_data.get("oauth_access_token"):
_LOGGER.info("Saving databricks OAuth access token.")
Expand Down Expand Up @@ -96,6 +101,7 @@ def handle(self, in_string):
'aad_tenant_id': None,
'aad_client_secret': None,
'aad_access_token': None,
'aad_token_expiration': None,
'oauth_client_id': None,
'oauth_client_secret': None,
'oauth_access_token': None,
Expand Down Expand Up @@ -156,6 +162,7 @@ def handle(self, in_string):
config_dict['aad_tenant_id'] = account_config.get('aad_tenant_id')
config_dict['aad_client_secret'] = account_password.get('aad_client_secret')
config_dict['aad_access_token'] = account_password.get('aad_access_token')
config_dict['aad_token_expiration'] = account_password.get('aad_token_expiration')
elif config_dict['auth_type'] == 'OAUTH_M2M':
config_dict['oauth_client_id'] = account_config.get('oauth_client_id')
config_dict['oauth_client_secret'] = account_password.get('oauth_client_secret')
Expand Down
17 changes: 11 additions & 6 deletions app/bin/databricks_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,28 @@ def validate_aad(self, data):
:param data: Dictionary containing values from configuration UI.
:return: Boolean depending on the sucess of the connection
"""
import time
_LOGGER.info('Obtaining Azure Active Directory access token')
aad_client_id = data.get("aad_client_id").strip()
client_sec = data.get("aad_client_secret").strip()
aad_tenant_id = data.get("aad_tenant_id").strip()
account_name = data.get("name")
aad_access_token = utils.get_aad_access_token(
result = utils.get_aad_access_token(
self._splunk_session_key, account_name,
aad_tenant_id, aad_client_id, client_sec, self._proxy_settings)
if isinstance(aad_access_token, tuple):
_LOGGER.error(aad_access_token[0])
self.put_msg(aad_access_token[0])

if isinstance(result, tuple) and result[1] == False:
_LOGGER.error(result[0])
self.put_msg(result[0])
return False

access_token, expires_in = result
_LOGGER.info('Obtained Azure Active Directory access token Successfully.')
databricks_instance = data.get("databricks_instance").strip("/")
valid_instance = self.validate_db_instance(databricks_instance, aad_access_token)
valid_instance = self.validate_db_instance(databricks_instance, access_token)
if valid_instance:
data["aad_access_token"] = aad_access_token
data["aad_access_token"] = access_token
data["aad_token_expiration"] = str(time.time() + expires_in)
data["databricks_pat"] = ""
return True
else:
Expand Down
Loading