Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 44 additions & 21 deletions app/bin/databricksquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,22 +325,58 @@ def handle_cluster_method():

def handle_dbsql_method(row_limit, thread_count):

def fetch_warehouse_status(id_of_warehouse):
def ensure_warehouse_running(id_of_warehouse, poll_interval=3):
    """Block until the SQL warehouse reports RUNNING, starting it if necessary.

    The warehouse status endpoint is polled in a loop and each reported
    state is handled as follows:

    - RUNNING: returns.
    - STARTING / STOPPING: sleeps ``poll_interval`` seconds, then polls again.
    - STOPPED: sends the start request (at most once per call) and polls
      again; if the warehouse still reports STOPPED after the start request,
      sleeps ``poll_interval`` seconds and polls again.
    - Any other state (including a missing or null state): raises.

    Args:
        id_of_warehouse: The warehouse ID to ensure is running.
        poll_interval: Seconds to sleep between status polls while waiting
            for a state transition. Defaults to 3.

    Raises:
        Exception: If the warehouse reports a state it cannot be started from.
    """
    # NOTE(review): the loop has no overall deadline; a warehouse that never
    # leaves STARTING/STOPPING keeps this call polling indefinitely.
    wait_messages = {
        "starting": "Warehouse is in STARTING state, waiting...",
        "stopping": "Warehouse is in STOPPING state, waiting for it to stop...",
    }
    start_was_requested = False

    while True:
        warehouse_resp = client.databricks_api(
            "get",
            const.SPECIFIC_WAREHOUSE_STATUS_ENDPOINT.format(id_of_warehouse)
        )
        # Treat a missing or null "state" as an empty string so it reaches
        # the explicit error branch instead of failing on ``None.lower()``.
        raw_state = warehouse_resp.get("state")
        current_state = (raw_state or "").lower()

        if current_state == "running":
            if start_was_requested:
                _LOGGER.info("Warehouse started successfully.")
            else:
                _LOGGER.info("Warehouse is already running.")
            return

        if current_state in wait_messages:
            _LOGGER.info(wait_messages[current_state])
            time.sleep(poll_interval)
            continue

        if current_state == "stopped":
            if start_was_requested:
                # After calling start API, warehouse may briefly still show as
                # STOPPED before transitioning to STARTING. Wait and retry
                # rather than sending a second start request.
                _LOGGER.info("Warehouse still in STOPPED state after start request, "
                             "waiting for state transition...")
                time.sleep(poll_interval)
            else:
                _LOGGER.info("Warehouse is in STOPPED state. Starting the warehouse.")
                client.databricks_api(
                    "post", const.WAREHOUSE_START_ENDPOINT.format(id_of_warehouse)
                )
                start_was_requested = True
            continue

        err = "Warehouse cannot be started. Current SQL warehouse state is {}."
        raise Exception(err.format(raw_state))

# Check whether SQL Warehouse exists. If yes, check its status.
# Check whether SQL Warehouse exists. If yes, ensure it's running.
warehouse_exist = False
list_of_links = []
list_of_chunk_number = []
Expand All @@ -350,20 +386,7 @@ def fetch_warehouse_status(id_of_warehouse):
if res.get("id") == self.warehouse_id:
warehouse_exist = True
if res.get("state").lower() != "running":
try:
if res.get("state").lower() == "starting":
_LOGGER.info("Warehouse is not in RUNNING state. It is in STARTING state.")
time.sleep(30)
fetch_warehouse_status(self.warehouse_id)
else:
_LOGGER.info("Warehouse is not in RUNNING or STARTING state. "
"Starting the warehouse.")
client.databricks_api(
"post", const.WAREHOUSE_START_ENDPOINT.format(self.warehouse_id)
)
fetch_warehouse_status(self.warehouse_id)
except Exception as err:
raise Exception(err)
ensure_warehouse_running(self.warehouse_id)
break
if not warehouse_exist:
raise Exception("No SQL warehouse found with ID: {}. Provide a valid SQL warehouse ID."
Expand Down
201 changes: 191 additions & 10 deletions tests/test_databricksquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,9 @@ def test_warehouse_starting_state(self, mock_sleep, mock_utils, mock_com):
client.databricks_api.side_effect = [
# Initial warehouse list - STARTING
{"warehouses": [{"id": "w123", "state": "STARTING"}]},
# Status check - now RUNNING
# ensure_warehouse_running: status check - still STARTING
{"state": "STARTING"},
# ensure_warehouse_running: status check - now RUNNING
{"state": "RUNNING"},
# Execute query
{"statement_id": "stmt123"},
Expand Down Expand Up @@ -837,9 +839,11 @@ def test_warehouse_stopped_then_started(self, mock_sleep, mock_utils, mock_com):
client.databricks_api.side_effect = [
# Initial warehouse list - STOPPED
{"warehouses": [{"id": "w123", "state": "STOPPED"}]},
# Start warehouse call
# ensure_warehouse_running: status check - STOPPED
{"state": "STOPPED"},
# ensure_warehouse_running: Start warehouse call
{},
# Status check - now RUNNING
# ensure_warehouse_running: status check - now RUNNING
{"state": "RUNNING"},
# Execute query
{"statement_id": "stmt123"},
Expand Down Expand Up @@ -1213,12 +1217,10 @@ def test_warehouse_invalid_state(self, mock_utils, mock_com):
client = mock_com.return_value = MagicMock()

client.databricks_api.side_effect = [
# Warehouse in ERROR state
{"warehouses": [{"id": "w123", "state": "ERROR"}]},
# Start warehouse call
{},
# Status check returns ERROR state (never becomes RUNNING)
{"state": "ERROR"}
# Warehouse in DELETED state (invalid)
{"warehouses": [{"id": "w123", "state": "DELETED"}]},
# ensure_warehouse_running: status check returns DELETED state
{"state": "DELETED"}
]

resp = db_query_obj.generate()
Expand All @@ -1230,7 +1232,7 @@ def test_warehouse_invalid_state(self, mock_utils, mock_com):
# Verify error was written
db_query_obj.write_error.assert_called()
error_msg = db_query_obj.write_error.call_args[0][0]
self.assertIn("ERROR", error_msg)
self.assertIn("DELETED", error_msg)

@patch("databricksquery.com.DatabricksClient", autospec=True)
@patch("databricksquery.utils", autospec=True)
Expand Down Expand Up @@ -1411,3 +1413,182 @@ def test_no_limit_uses_admin_default(self, mock_utils, mock_com):
db_query_obj.write_warning.assert_not_called()
# Verify limit was not set (remains None, uses admin default)
self.assertIsNone(db_query_obj.limit)

@patch("databricksquery.com.DatabricksClient", autospec=True)
@patch("databricksquery.utils", autospec=True)
@patch("databricksquery.time.sleep")
def test_warehouse_stopped_delayed_transition(self, mock_sleep, mock_utils, mock_com):
    """Test warehouse in STOPPED state with delayed transition to STARTING after start API call.

    This tests the race condition fix where after calling start API,
    the warehouse may briefly still show as STOPPED before transitioning to STARTING.
    The start request must be sent exactly once even though STOPPED is
    observed a second time after the request.
    """
    db_query_obj = self.DatabricksQueryCommand()
    db_query_obj._metadata = MagicMock()
    db_query_obj.warehouse_id = "w123"

    mock_utils.get_databricks_configs.return_value = {
        "auth_type": "PAT",
        "databricks_instance": "test.databricks.com",
        "databricks_pat": "test_token",
        "admin_command_timeout": "300",
        "query_result_limit": "1000",
        "thread_count": "4"
    }

    client = mock_com.return_value = MagicMock()

    client.databricks_api.side_effect = [
        # Initial warehouse list - STOPPED
        {"warehouses": [{"id": "w123", "state": "STOPPED"}]},
        # ensure_warehouse_running: status check - STOPPED
        {"state": "STOPPED"},
        # ensure_warehouse_running: Start warehouse call
        {},
        # ensure_warehouse_running: status check - still STOPPED (race condition)
        {"state": "STOPPED"},
        # ensure_warehouse_running: status check - now STARTING
        {"state": "STARTING"},
        # ensure_warehouse_running: status check - now RUNNING
        {"state": "RUNNING"},
        # Execute query
        {"statement_id": "stmt123"},
        # Query status - SUCCEEDED
        {
            "status": {"state": "SUCCEEDED"},
            "manifest": {
                "total_row_count": 0,
                "truncated": False,
                "schema": {"columns": []}
            },
            "result": {"external_links": []}
        }
    ]

    resp = db_query_obj.generate()
    try:
        list(resp)
    except (StopIteration, SystemExit):
        pass

    # The second STOPPED observation must not trigger another start request:
    # exactly one call whose endpoint argument mentions 'start' is expected.
    start_calls = [
        api_call for api_call in client.databricks_api.call_args_list
        if len(api_call[0]) > 1 and 'start' in str(api_call[0][1])
    ]
    self.assertEqual(len(start_calls), 1)

    # Verify sleep was called (for waiting during state transitions)
    mock_sleep.assert_called()

@patch("databricksquery.com.DatabricksClient", autospec=True)
@patch("databricksquery.utils", autospec=True)
@patch("databricksquery.time.sleep")
def test_warehouse_stopping_state(self, mock_sleep, mock_utils, mock_com):
    """Warehouse starts out STOPPING: wait until STOPPED, start it, then run the query."""
    command = self.DatabricksQueryCommand()
    command._metadata = MagicMock()
    command.warehouse_id = "w123"

    configs = {
        "auth_type": "PAT",
        "databricks_instance": "test.databricks.com",
        "databricks_pat": "test_token",
        "admin_command_timeout": "300",
        "query_result_limit": "1000",
        "thread_count": "4"
    }
    mock_utils.get_databricks_configs.return_value = configs

    client = MagicMock()
    mock_com.return_value = client

    # Responses handed out by databricks_api, in call order.
    warehouse_listing = {"warehouses": [{"id": "w123", "state": "STOPPING"}]}
    succeeded_status = {
        "status": {"state": "SUCCEEDED"},
        "manifest": {
            "total_row_count": 0,
            "truncated": False,
            "schema": {"columns": []}
        },
        "result": {"external_links": []}
    }
    client.databricks_api.side_effect = [
        warehouse_listing,              # initial warehouse list - STOPPING
        {"state": "STOPPING"},          # status poll - still STOPPING
        {"state": "STOPPED"},           # status poll - now STOPPED
        {},                             # start warehouse request
        {"state": "RUNNING"},           # status poll - now RUNNING
        {"statement_id": "stmt123"},    # execute query
        succeeded_status,               # query status - SUCCEEDED
    ]

    results = command.generate()
    try:
        list(results)
    except (StopIteration, SystemExit):
        pass

    # Collect every API call whose second positional argument mentions 'start'.
    start_requests = []
    for recorded in client.databricks_api.call_args_list:
        positional = recorded[0]
        if len(positional) > 1 and 'start' in str(positional[1]):
            start_requests.append(recorded)
    self.assertGreater(len(start_requests), 0)

    # Waiting for the warehouse to stop must have gone through time.sleep.
    mock_sleep.assert_called()

@patch("databricksquery.com.DatabricksClient", autospec=True)
@patch("databricksquery.utils", autospec=True)
@patch("databricksquery.time.sleep")
def test_warehouse_stopping_transitions_to_running(self, mock_sleep, mock_utils, mock_com):
    """Test warehouse in STOPPING state that transitions directly to RUNNING (edge case)

    This can happen if someone else starts the warehouse while we're waiting for it to stop.
    ensure_warehouse_running returns as soon as a status poll reports RUNNING,
    so no start request is expected in this scenario.
    """
    db_query_obj = self.DatabricksQueryCommand()
    db_query_obj._metadata = MagicMock()
    db_query_obj.warehouse_id = "w123"

    mock_utils.get_databricks_configs.return_value = {
        "auth_type": "PAT",
        "databricks_instance": "test.databricks.com",
        "databricks_pat": "test_token",
        "admin_command_timeout": "300",
        "query_result_limit": "1000",
        "thread_count": "4"
    }

    client = mock_com.return_value = MagicMock()

    # The responses must line up with the calls actually made: once the
    # status poll reports RUNNING the next API call is the query execution,
    # not a start request.
    client.databricks_api.side_effect = [
        # Initial warehouse list - STOPPING
        {"warehouses": [{"id": "w123", "state": "STOPPING"}]},
        # ensure_warehouse_running: status check - already RUNNING
        # (someone else started it)
        {"state": "RUNNING"},
        # Execute query
        {"statement_id": "stmt123"},
        # Query status - SUCCEEDED
        {
            "status": {"state": "SUCCEEDED"},
            "manifest": {
                "total_row_count": 0,
                "truncated": False,
                "schema": {"columns": []}
            },
            "result": {"external_links": []}
        }
    ]

    resp = db_query_obj.generate()
    try:
        list(resp)
    except (StopIteration, SystemExit):
        pass

    # The warehouse was observed RUNNING on the first poll, so the start
    # endpoint must never have been called.
    start_calls = [
        api_call for api_call in client.databricks_api.call_args_list
        if len(api_call[0]) > 1 and 'start' in str(api_call[0][1])
    ]
    self.assertEqual(start_calls, [])

    # RUNNING was seen immediately, so the helper had nothing to wait for
    # with a 3-second poll sleep.
    self.assertNotIn(((3,), {}), [tuple(c) for c in mock_sleep.call_args_list])