diff --git a/app/bin/databricksquery.py b/app/bin/databricksquery.py
index 8769c50..82b561d 100755
--- a/app/bin/databricksquery.py
+++ b/app/bin/databricksquery.py
@@ -325,22 +325,70 @@ def handle_cluster_method():

             def handle_dbsql_method(row_limit, thread_count):

-                def fetch_warehouse_status(id_of_warehouse):
+                def ensure_warehouse_running(id_of_warehouse):
+                    """Ensure warehouse is in RUNNING state, starting it if necessary.
+
+                    Handles all warehouse states:
+                    - RUNNING: Returns immediately
+                    - STARTING: Waits for startup to complete
+                    - STOPPING: Waits for stop to complete, then starts
+                    - STOPPED: Starts the warehouse and waits
+                    - Other states (DELETED, DELETING, etc.): Raises an error
+
+                    Args:
+                        id_of_warehouse: The warehouse ID to ensure is running
+
+                    Raises:
+                        Exception: If the warehouse is in a state it cannot be started from,
+                            or is still STOPPED after max_stopped_polls polls following the start request.
+                    """
+                    start_was_requested = False
+                    # Bound the post-start STOPPED wait (20 polls x 3s) so a failed start cannot loop forever.
+                    max_stopped_polls = 20
+                    stopped_polls = 0
+
                     while True:
                         warehouse_resp = client.databricks_api(
                             "get", const.SPECIFIC_WAREHOUSE_STATUS_ENDPOINT.format(id_of_warehouse)
                         )

-                        if warehouse_resp.get("state").lower() == "starting":
-                            time.sleep(30)
-                        elif warehouse_resp.get("state").lower() == "running":
-                            _LOGGER.info("Warehouse started successfully.")
+                        current_state = (warehouse_resp.get("state") or "").lower()
+
+                        if current_state == "running":
+                            if start_was_requested:
+                                _LOGGER.info("Warehouse started successfully.")
+                            else:
+                                _LOGGER.info("Warehouse is already running.")
                             break
+                        elif current_state == "starting":
+                            _LOGGER.info("Warehouse is in STARTING state, waiting...")
+                            time.sleep(3)
+                        elif current_state == "stopping":
+                            _LOGGER.info("Warehouse is in STOPPING state, waiting for it to stop...")
+                            time.sleep(3)
+                        elif current_state == "stopped":
+                            if start_was_requested:
+                                # After calling start API, warehouse may briefly still show as STOPPED
+                                # before transitioning to STARTING. Wait and retry, but give up once the
+                                # poll budget is spent: a warehouse whose start failed stays STOPPED.
+                                stopped_polls += 1
+                                if stopped_polls > max_stopped_polls:
+                                    err = "Warehouse cannot be started. Current SQL warehouse state is {}."
+                                    raise Exception(err.format(warehouse_resp.get("state")))
+                                _LOGGER.info("Warehouse still in STOPPED state after start request, "
+                                             "waiting for state transition...")
+                                time.sleep(3)
+                            else:
+                                _LOGGER.info("Warehouse is in STOPPED state. Starting the warehouse.")
+                                client.databricks_api(
+                                    "post", const.WAREHOUSE_START_ENDPOINT.format(id_of_warehouse)
+                                )
+                                start_was_requested = True
                         else:
-                            err = "Warehouse is not in RUNNING or STARTING state. Current SQL warehouse state is {}."
+                            err = "Warehouse cannot be started. Current SQL warehouse state is {}."
                             raise Exception(err.format(warehouse_resp.get("state")))

-                # Check whether SQL Warehouse exists. If yes, check its status.
+                # Check whether SQL Warehouse exists. If yes, ensure it's running.
                 warehouse_exist = False
                 list_of_links = []
                 list_of_chunk_number = []
@@ -350,20 +398,7 @@ def fetch_warehouse_status(id_of_warehouse):
                     if res.get("id") == self.warehouse_id:
                         warehouse_exist = True
                         if res.get("state").lower() != "running":
-                            try:
-                                if res.get("state").lower() == "starting":
-                                    _LOGGER.info("Warehouse is not in RUNNING state. It is in STARTING state.")
-                                    time.sleep(30)
-                                    fetch_warehouse_status(self.warehouse_id)
-                                else:
-                                    _LOGGER.info("Warehouse is not in RUNNING or STARTING state. "
-                                                 "Starting the warehouse.")
-                                    client.databricks_api(
-                                        "post", const.WAREHOUSE_START_ENDPOINT.format(self.warehouse_id)
-                                    )
-                                    fetch_warehouse_status(self.warehouse_id)
-                            except Exception as err:
-                                raise Exception(err)
+                            ensure_warehouse_running(self.warehouse_id)
                         break
                 if not warehouse_exist:
                     raise Exception("No SQL warehouse found with ID: {}. Provide a valid SQL warehouse ID."
diff --git a/tests/test_databricksquery.py b/tests/test_databricksquery.py
index c1386d4..a329a18 100644
--- a/tests/test_databricksquery.py
+++ b/tests/test_databricksquery.py
@@ -789,7 +789,9 @@ def test_warehouse_starting_state(self, mock_sleep, mock_utils, mock_com):
         client.databricks_api.side_effect = [
             # Initial warehouse list - STARTING
             {"warehouses": [{"id": "w123", "state": "STARTING"}]},
-            # Status check - now RUNNING
+            # ensure_warehouse_running: status check - still STARTING
+            {"state": "STARTING"},
+            # ensure_warehouse_running: status check - now RUNNING
             {"state": "RUNNING"},
             # Execute query
             {"statement_id": "stmt123"},
@@ -837,9 +839,11 @@ def test_warehouse_stopped_then_started(self, mock_sleep, mock_utils, mock_com):
         client.databricks_api.side_effect = [
             # Initial warehouse list - STOPPED
             {"warehouses": [{"id": "w123", "state": "STOPPED"}]},
-            # Start warehouse call
+            # ensure_warehouse_running: status check - STOPPED
+            {"state": "STOPPED"},
+            # ensure_warehouse_running: Start warehouse call
             {},
-            # Status check - now RUNNING
+            # ensure_warehouse_running: status check - now RUNNING
             {"state": "RUNNING"},
             # Execute query
             {"statement_id": "stmt123"},
@@ -1213,12 +1217,10 @@ def test_warehouse_invalid_state(self, mock_utils, mock_com):
         client = mock_com.return_value = MagicMock()

         client.databricks_api.side_effect = [
-            # Warehouse in ERROR state
-            {"warehouses": [{"id": "w123", "state": "ERROR"}]},
-            # Start warehouse call
-            {},
-            # Status check returns ERROR state (never becomes RUNNING)
-            {"state": "ERROR"}
+            # Warehouse in DELETED state (invalid)
+            {"warehouses": [{"id": "w123", "state": "DELETED"}]},
+            # ensure_warehouse_running: status check returns DELETED state
+            {"state": "DELETED"}
         ]

         resp = db_query_obj.generate()
@@ -1230,7 +1232,7 @@ def test_warehouse_invalid_state(self, mock_utils, mock_com):
         # Verify error was written
         db_query_obj.write_error.assert_called()
         error_msg = db_query_obj.write_error.call_args[0][0]
-        self.assertIn("ERROR", error_msg)
+        self.assertIn("DELETED", error_msg)

     @patch("databricksquery.com.DatabricksClient", autospec=True)
     @patch("databricksquery.utils", autospec=True)
@@ -1411,3 +1413,180 @@ def test_no_limit_uses_admin_default(self, mock_utils, mock_com):
         db_query_obj.write_warning.assert_not_called()
         # Verify limit was not set (remains None, uses admin default)
         self.assertIsNone(db_query_obj.limit)
+
+    @patch("databricksquery.com.DatabricksClient", autospec=True)
+    @patch("databricksquery.utils", autospec=True)
+    @patch("databricksquery.time.sleep")
+    def test_warehouse_stopped_delayed_transition(self, mock_sleep, mock_utils, mock_com):
+        """Test warehouse in STOPPED state with delayed transition to STARTING after start API call.
+
+        This tests the race condition fix where after calling start API,
+        the warehouse may briefly still show as STOPPED before transitioning to STARTING.
+        """
+        db_query_obj = self.DatabricksQueryCommand()
+        db_query_obj._metadata = MagicMock()
+        db_query_obj.warehouse_id = "w123"
+
+        mock_utils.get_databricks_configs.return_value = {
+            "auth_type": "PAT",
+            "databricks_instance": "test.databricks.com",
+            "databricks_pat": "test_token",
+            "admin_command_timeout": "300",
+            "query_result_limit": "1000",
+            "thread_count": "4"
+        }
+
+        client = mock_com.return_value = MagicMock()
+
+        client.databricks_api.side_effect = [
+            # Initial warehouse list - STOPPED
+            {"warehouses": [{"id": "w123", "state": "STOPPED"}]},
+            # ensure_warehouse_running: status check - STOPPED
+            {"state": "STOPPED"},
+            # ensure_warehouse_running: Start warehouse call
+            {},
+            # ensure_warehouse_running: status check - still STOPPED (race condition)
+            {"state": "STOPPED"},
+            # ensure_warehouse_running: status check - now STARTING
+            {"state": "STARTING"},
+            # ensure_warehouse_running: status check - now RUNNING
+            {"state": "RUNNING"},
+            # Execute query
+            {"statement_id": "stmt123"},
+            # Query status - SUCCEEDED
+            {
+                "status": {"state": "SUCCEEDED"},
+                "manifest": {
+                    "total_row_count": 0,
+                    "truncated": False,
+                    "schema": {"columns": []}
+                },
+                "result": {"external_links": []}
+            }
+        ]
+
+        resp = db_query_obj.generate()
+        try:
+            list(resp)
+        except (StopIteration, SystemExit):
+            pass
+
+        # Verify warehouse start was called
+        calls = [call for call in client.databricks_api.call_args_list
+                 if len(call[0]) > 1 and 'start' in str(call[0][1])]
+        self.assertTrue(len(calls) > 0)
+
+        # Verify sleep was called (for waiting during state transitions)
+        mock_sleep.assert_called()
+
+    @patch("databricksquery.com.DatabricksClient", autospec=True)
+    @patch("databricksquery.utils", autospec=True)
+    @patch("databricksquery.time.sleep")
+    def test_warehouse_stopping_state(self, mock_sleep, mock_utils, mock_com):
+        """Test warehouse in STOPPING state, waits for STOPPED, then starts and runs query"""
+        db_query_obj = self.DatabricksQueryCommand()
+        db_query_obj._metadata = MagicMock()
+        db_query_obj.warehouse_id = "w123"
+
+        mock_utils.get_databricks_configs.return_value = {
+            "auth_type": "PAT",
+            "databricks_instance": "test.databricks.com",
+            "databricks_pat": "test_token",
+            "admin_command_timeout": "300",
+            "query_result_limit": "1000",
+            "thread_count": "4"
+        }
+
+        client = mock_com.return_value = MagicMock()
+
+        client.databricks_api.side_effect = [
+            # Initial warehouse list - STOPPING
+            {"warehouses": [{"id": "w123", "state": "STOPPING"}]},
+            # Status check - still STOPPING
+            {"state": "STOPPING"},
+            # Status check - now STOPPED
+            {"state": "STOPPED"},
+            # Start warehouse call
+            {},
+            # Status check - now RUNNING
+            {"state": "RUNNING"},
+            # Execute query
+            {"statement_id": "stmt123"},
+            # Query status - SUCCEEDED
+            {
+                "status": {"state": "SUCCEEDED"},
+                "manifest": {
+                    "total_row_count": 0,
+                    "truncated": False,
+                    "schema": {"columns": []}
+                },
+                "result": {"external_links": []}
+            }
+        ]
+
+        resp = db_query_obj.generate()
+        try:
+            list(resp)
+        except (StopIteration, SystemExit):
+            pass
+
+        # Verify warehouse start was called
+        calls = [call for call in client.databricks_api.call_args_list
+                 if len(call[0]) > 1 and 'start' in str(call[0][1])]
+        self.assertTrue(len(calls) > 0)
+
+        # Verify sleep was called while waiting for warehouse to stop
+        mock_sleep.assert_called()
+
+    @patch("databricksquery.com.DatabricksClient", autospec=True)
+    @patch("databricksquery.utils", autospec=True)
+    @patch("databricksquery.time.sleep")
+    def test_warehouse_stopping_transitions_to_running(self, mock_sleep, mock_utils, mock_com):
+        """Test warehouse in STOPPING state that transitions directly to RUNNING (edge case)
+
+        This can happen if someone else starts the warehouse while we're waiting for it to stop.
+        """
+        db_query_obj = self.DatabricksQueryCommand()
+        db_query_obj._metadata = MagicMock()
+        db_query_obj.warehouse_id = "w123"
+
+        mock_utils.get_databricks_configs.return_value = {
+            "auth_type": "PAT",
+            "databricks_instance": "test.databricks.com",
+            "databricks_pat": "test_token",
+            "admin_command_timeout": "300",
+            "query_result_limit": "1000",
+            "thread_count": "4"
+        }
+
+        client = mock_com.return_value = MagicMock()
+
+        client.databricks_api.side_effect = [
+            # Initial warehouse list - STOPPING
+            {"warehouses": [{"id": "w123", "state": "STOPPING"}]},
+            # ensure_warehouse_running: status check - already RUNNING (someone else started it)
+            {"state": "RUNNING"},
+            # Execute query
+            {"statement_id": "stmt123"},
+            # Query status - SUCCEEDED
+            {
+                "status": {"state": "SUCCEEDED"},
+                "manifest": {
+                    "total_row_count": 0,
+                    "truncated": False,
+                    "schema": {"columns": []}
+                },
+                "result": {"external_links": []}
+            }
+        ]
+
+        resp = db_query_obj.generate()
+        try:
+            list(resp)
+        except (StopIteration, SystemExit):
+            pass
+
+        # RUNNING was seen on the first status poll, so no start request may be issued
+        calls = [call for call in client.databricks_api.call_args_list
+                 if len(call[0]) > 1 and 'start' in str(call[0][1])]
+        self.assertEqual(calls, [])