diff --git a/databricks-mcp-server/databricks_mcp_server/tools/compute.py b/databricks-mcp-server/databricks_mcp_server/tools/compute.py index 06e52eff..d5ec99ea 100644 --- a/databricks-mcp-server/databricks_mcp_server/tools/compute.py +++ b/databricks-mcp-server/databricks_mcp_server/tools/compute.py @@ -1,6 +1,6 @@ -"""Compute tools - Execute code on Databricks clusters.""" +"""Compute tools - Execute code on Databricks clusters and serverless compute, and manage compute resources.""" -from typing import Dict, Any, List +from typing import Dict, Any, List, Optional from databricks_tools_core.compute import ( list_clusters as _list_clusters, @@ -8,8 +8,18 @@ start_cluster as _start_cluster, get_cluster_status as _get_cluster_status, execute_databricks_command as _execute_databricks_command, - run_python_file_on_databricks as _run_python_file_on_databricks, + run_file_on_databricks as _run_file_on_databricks, + run_code_on_serverless as _run_code_on_serverless, NoRunningClusterError, + create_cluster as _create_cluster, + modify_cluster as _modify_cluster, + terminate_cluster as _terminate_cluster, + delete_cluster as _delete_cluster, + list_node_types as _list_node_types, + list_spark_versions as _list_spark_versions, + create_sql_warehouse as _create_sql_warehouse, + modify_sql_warehouse as _modify_sql_warehouse, + delete_sql_warehouse as _delete_sql_warehouse, ) from ..server import mcp @@ -182,17 +192,26 @@ def execute_databricks_command( @mcp.tool -def run_python_file_on_databricks( +def run_file_on_databricks( file_path: str, cluster_id: str = None, context_id: str = None, + language: str = None, timeout: int = 600, destroy_context_on_completion: bool = False, + workspace_path: str = None, ) -> Dict[str, Any]: """ - Read a local Python file and execute it on a Databricks cluster. + Read a local file and execute it on a Databricks cluster. - Useful for running data generation scripts or other Python code. 
+ Supports Python (.py), Scala (.scala), SQL (.sql), and R (.r) files. + Language is auto-detected from the file extension if not specified. + + Two modes: + - Ephemeral (default): Sends code via Command Execution API. No workspace artifact. + - Persistent: If workspace_path is provided, also uploads the file as a notebook + to that workspace path so it's visible and re-runnable in the Databricks UI. + Use persistent mode for project work (model training, ETL scripts, etc.). If context_id is provided, reuses the existing context (faster, maintains state). If not provided, creates a new context. @@ -201,15 +220,17 @@ def run_python_file_on_databricks( returns an error with actionable suggestions (startable clusters, alternatives). Args: - file_path: Local path to the Python file + file_path: Local path to the file to execute. cluster_id: ID of the cluster to run on. If not provided, auto-selects a running cluster accessible to the current user. - Single-user clusters owned by other users are automatically skipped. - context_id: Optional existing execution context ID. If provided, reuses it - for faster execution and state preservation. - timeout: Maximum wait time in seconds (default: 600) + context_id: Optional existing execution context ID for reuse. + language: Programming language ("python", "scala", "sql", "r"). + If omitted, auto-detected from file extension. + timeout: Maximum wait time in seconds (default: 600). destroy_context_on_completion: If True, destroys the context after execution. - Default is False to allow reuse. + workspace_path: Optional workspace path to persist the file as a notebook + (e.g. "/Workspace/Users/user@company.com/my-project/train"). + If omitted, no workspace artifact is created. 
Returns: Dictionary with: @@ -226,14 +247,20 @@ def run_python_file_on_databricks( cluster_id = None if context_id == "": context_id = None + if language == "": + language = None + if workspace_path == "": + workspace_path = None try: - result = _run_python_file_on_databricks( + result = _run_file_on_databricks( file_path=file_path, cluster_id=cluster_id, context_id=context_id, + language=language, timeout=timeout, destroy_context_on_completion=destroy_context_on_completion, + workspace_path=workspace_path, ) return result.to_dict() except NoRunningClusterError as e: @@ -250,3 +277,393 @@ def run_python_file_on_databricks( "skipped_clusters": e.skipped_clusters, "available_clusters": e.available_clusters, } + + +@mcp.tool +def run_code_on_serverless( + code: str, + language: str = "python", + timeout: int = 1800, + run_name: Optional[str] = None, + workspace_path: Optional[str] = None, +) -> Dict[str, Any]: + """ + Execute code on serverless compute (no cluster required). + + This is the primary tool for running Python when no interactive cluster is + available. Uses the Jobs API (runs/submit) with serverless compute — the code + is uploaded as a notebook, executed, and (by default) cleaned up automatically. + + Two modes: + - Ephemeral (default): Uploads to a temp path and cleans up after execution. + Good for testing, exploration, one-off scripts. + - Persistent: If workspace_path is provided, saves the notebook at that path + and keeps it after execution. Good for project work the user wants saved + (model training, ETL, data pipelines). + + Also supports Jupyter notebooks (.ipynb): if the code content is valid .ipynb + JSON (i.e. contains a "cells" key), it is automatically uploaded using + Databricks' native Jupyter import. The language parameter is ignored for .ipynb. + To run a .ipynb file, read its contents and pass the raw JSON string as code. 
+ + Use this tool when: + - The user needs to run Python and no cluster is running + - Running one-off Python scripts that don't need an interactive session + - Running longer-running Python code (up to 30 min default timeout) + - Running a Jupyter notebook (.ipynb) on Databricks serverless + + Do NOT use this tool for: + - Interactive, iterative Python with state (use execute_databricks_command) + - SQL queries that need result rows (use execute_sql — works with serverless + SQL warehouses) + + SQL is supported (language="sql") but only for DDL/DML (CREATE TABLE, INSERT, + MERGE). SQL SELECT results are NOT captured — use execute_sql() instead. + + Args: + code: Code to execute (Python or SQL), or raw .ipynb JSON content (auto-detected). + language: Programming language ("python" or "sql"). Default: "python". Ignored for .ipynb. + timeout: Maximum wait time in seconds (default: 1800 = 30 minutes). + run_name: Optional human-readable name for the run. Auto-generated if omitted. + workspace_path: Optional workspace path to persist the notebook + (e.g. "/Workspace/Users/user@company.com/my-project/train"). + If provided, the notebook is saved at this path and kept after execution. + If omitted, uses a temp path and cleans up after. + + Returns: + Dictionary with: + - success: Whether execution succeeded + - output: The output from execution (notebook result or logs) + - error: Error message if failed + - run_id: Databricks Jobs run ID + - run_url: URL to view the run in Databricks UI + - duration_seconds: How long the execution took + - state: Final state (SUCCESS, FAILED, TIMEDOUT, etc.) 
+ - message: Human-readable summary + - workspace_path: (persistent mode only) Where the notebook was saved + """ + result = _run_code_on_serverless( + code=code, + language=language, + timeout=timeout, + run_name=run_name if run_name else None, + cleanup=workspace_path is None, + workspace_path=workspace_path if workspace_path else None, + ) + return result.to_dict() + + +# --- Compute Management Tools --- + + +@mcp.tool +def create_cluster( + name: str, + num_workers: int = 1, + spark_version: str = None, + node_type_id: str = None, + autotermination_minutes: int = 120, + data_security_mode: str = None, + spark_conf: str = None, + autoscale_min_workers: int = None, + autoscale_max_workers: int = None, +) -> Dict[str, Any]: + """ + Create a new Databricks cluster with sensible defaults. + + Just provide a name and num_workers — the tool auto-picks the latest LTS + Databricks Runtime, a reasonable node type for the cloud, SINGLE_USER + security mode, and 120-minute auto-termination. + + Power users can override any parameter for full control. + + Args: + name: Human-readable cluster name. + num_workers: Fixed number of workers (ignored if autoscale is set). Default 1. + spark_version: DBR version key (e.g. "15.4.x-scala2.12"). Auto-picks latest LTS if omitted. + node_type_id: Worker node type (e.g. "i3.xlarge"). Auto-picked if omitted. + autotermination_minutes: Minutes of inactivity before auto-termination. Default 120. + data_security_mode: Security mode ("SINGLE_USER", "USER_ISOLATION", etc.). Default SINGLE_USER. + spark_conf: JSON string of Spark config overrides (e.g. '{"spark.sql.shuffle.partitions": "8"}'). + autoscale_min_workers: If set with autoscale_max_workers, enables autoscaling. + autoscale_max_workers: Maximum workers for autoscaling. + + Returns: + Dictionary with cluster_id, cluster_name, state, spark_version, node_type_id, and message. 
+ """ + # Convert empty strings to None + if spark_version == "": + spark_version = None + if node_type_id == "": + node_type_id = None + if data_security_mode == "": + data_security_mode = None + + # Parse spark_conf from JSON string + parsed_spark_conf = None + if spark_conf and spark_conf.strip(): + import json + parsed_spark_conf = json.loads(spark_conf) + + kwargs = {} + if spark_version: + kwargs["spark_version"] = spark_version + if node_type_id: + kwargs["node_type_id"] = node_type_id + if data_security_mode: + kwargs["data_security_mode"] = data_security_mode + if parsed_spark_conf: + kwargs["spark_conf"] = parsed_spark_conf + if autoscale_min_workers is not None: + kwargs["autoscale_min_workers"] = autoscale_min_workers + if autoscale_max_workers is not None: + kwargs["autoscale_max_workers"] = autoscale_max_workers + + return _create_cluster( + name=name, + num_workers=num_workers, + autotermination_minutes=autotermination_minutes, + **kwargs, + ) + + +@mcp.tool +def modify_cluster( + cluster_id: str, + name: str = None, + num_workers: int = None, + spark_version: str = None, + node_type_id: str = None, + autotermination_minutes: int = None, + spark_conf: str = None, + autoscale_min_workers: int = None, + autoscale_max_workers: int = None, +) -> Dict[str, Any]: + """ + Modify an existing Databricks cluster configuration. + + Only the specified parameters are changed; others remain as-is. + If the cluster is running, it will restart to apply changes. + + Args: + cluster_id: ID of the cluster to modify. + name: New cluster name (optional). + num_workers: New fixed worker count (optional). + spark_version: New DBR version (optional). + node_type_id: New worker node type (optional). + autotermination_minutes: New auto-termination timeout (optional). + spark_conf: JSON string of Spark config overrides (optional). + autoscale_min_workers: Set to enable/modify autoscaling (optional). + autoscale_max_workers: Set to enable/modify autoscaling (optional). 
+ + Returns: + Dictionary with cluster_id, cluster_name, state, and message. + """ + # Convert empty strings to None + if name == "": + name = None + if spark_version == "": + spark_version = None + if node_type_id == "": + node_type_id = None + + kwargs = {} + if name: + kwargs["name"] = name + if num_workers is not None: + kwargs["num_workers"] = num_workers + if spark_version: + kwargs["spark_version"] = spark_version + if node_type_id: + kwargs["node_type_id"] = node_type_id + if autotermination_minutes is not None: + kwargs["autotermination_minutes"] = autotermination_minutes + if autoscale_min_workers is not None: + kwargs["autoscale_min_workers"] = autoscale_min_workers + if autoscale_max_workers is not None: + kwargs["autoscale_max_workers"] = autoscale_max_workers + + # Parse spark_conf from JSON string + if spark_conf and spark_conf.strip(): + import json + kwargs["spark_conf"] = json.loads(spark_conf) + + return _modify_cluster(cluster_id=cluster_id, **kwargs) + + +@mcp.tool +def terminate_cluster(cluster_id: str) -> Dict[str, Any]: + """ + Stop a running Databricks cluster (reversible). + + The cluster is terminated but NOT deleted. It can be restarted later + with start_cluster(). This is safe and reversible. + + Args: + cluster_id: ID of the cluster to terminate. + + Returns: + Dictionary with cluster_id, cluster_name, state, and message. + """ + return _terminate_cluster(cluster_id) + + +@mcp.tool +def delete_cluster(cluster_id: str) -> Dict[str, Any]: + """ + PERMANENTLY delete a Databricks cluster. + + WARNING: This is a DESTRUCTIVE, IRREVERSIBLE action. The cluster and its + configuration will be permanently removed. This cannot be undone. + + IMPORTANT: Always confirm with the user before calling this tool. + Ask: "Are you sure you want to permanently delete cluster ''? + This cannot be undone." + + Args: + cluster_id: ID of the cluster to permanently delete. + + Returns: + Dictionary with cluster_id, cluster_name, state, and warning message. 
+ """ + return _delete_cluster(cluster_id) + + +@mcp.tool +def list_node_types() -> List[Dict[str, Any]]: + """ + List available VM/node types for cluster creation. + + Returns node type IDs, memory, cores, and GPU info. Useful when the user + wants to choose a specific node type for create_cluster(). + + Returns: + List of node type dicts with node_type_id, memory_mb, num_cores, num_gpus, description. + """ + return _list_node_types() + + +@mcp.tool +def list_spark_versions() -> List[Dict[str, Any]]: + """ + List available Databricks Runtime (Spark) versions. + + Returns version keys and names. Filter for "LTS" in the name to find + long-term support versions for create_cluster(). + + Returns: + List of dicts with key and name for each version. + """ + return _list_spark_versions() + + +@mcp.tool +def create_sql_warehouse( + name: str, + size: str = "Small", + min_num_clusters: int = 1, + max_num_clusters: int = 1, + auto_stop_mins: int = 120, + warehouse_type: str = "PRO", + enable_serverless: bool = True, +) -> Dict[str, Any]: + """ + Create a new SQL warehouse with sensible defaults. + + By default creates a serverless Pro warehouse with auto-stop at 120 minutes. + + Args: + name: Human-readable warehouse name. + size: T-shirt size ("2X-Small", "X-Small", "Small", "Medium", "Large", + "X-Large", "2X-Large", "3X-Large", "4X-Large"). Default "Small". + min_num_clusters: Minimum cluster count. Default 1. + max_num_clusters: Maximum cluster count for scaling. Default 1. + auto_stop_mins: Minutes of inactivity before auto-stop. Default 120. + warehouse_type: "PRO" or "CLASSIC". Default "PRO". + enable_serverless: Enable serverless compute. Default True. + + Returns: + Dictionary with warehouse_id, name, size, state, and message. 
+ """ + # Convert empty strings to None + if size == "": + size = "Small" + if warehouse_type == "": + warehouse_type = "PRO" + + return _create_sql_warehouse( + name=name, + size=size, + min_num_clusters=min_num_clusters, + max_num_clusters=max_num_clusters, + auto_stop_mins=auto_stop_mins, + warehouse_type=warehouse_type, + enable_serverless=enable_serverless, + ) + + +@mcp.tool +def modify_sql_warehouse( + warehouse_id: str, + name: str = None, + size: str = None, + min_num_clusters: int = None, + max_num_clusters: int = None, + auto_stop_mins: int = None, +) -> Dict[str, Any]: + """ + Modify an existing SQL warehouse configuration. + + Only the specified parameters are changed; others remain as-is. + + Args: + warehouse_id: ID of the warehouse to modify. + name: New warehouse name (optional). + size: New T-shirt size (optional). + min_num_clusters: New minimum clusters (optional). + max_num_clusters: New maximum clusters (optional). + auto_stop_mins: New auto-stop timeout in minutes (optional). + + Returns: + Dictionary with warehouse_id, name, state, and message. + """ + # Convert empty strings to None + if name == "": + name = None + if size == "": + size = None + + kwargs = {} + if name: + kwargs["name"] = name + if size: + kwargs["size"] = size + if min_num_clusters is not None: + kwargs["min_num_clusters"] = min_num_clusters + if max_num_clusters is not None: + kwargs["max_num_clusters"] = max_num_clusters + if auto_stop_mins is not None: + kwargs["auto_stop_mins"] = auto_stop_mins + + return _modify_sql_warehouse(warehouse_id=warehouse_id, **kwargs) + + +@mcp.tool +def delete_sql_warehouse(warehouse_id: str) -> Dict[str, Any]: + """ + PERMANENTLY delete a SQL warehouse. + + WARNING: This is a DESTRUCTIVE, IRREVERSIBLE action. The warehouse and its + configuration will be permanently removed. This cannot be undone. + + IMPORTANT: Always confirm with the user before calling this tool. + Ask: "Are you sure you want to permanently delete warehouse ''? 
+ This cannot be undone." + + Args: + warehouse_id: ID of the warehouse to permanently delete. + + Returns: + Dictionary with warehouse_id, name, state, and warning message. + """ + return _delete_sql_warehouse(warehouse_id) diff --git a/databricks-skills/databricks-execution-compute/SKILL.md b/databricks-skills/databricks-execution-compute/SKILL.md new file mode 100644 index 00000000..bb3d9ffc --- /dev/null +++ b/databricks-skills/databricks-execution-compute/SKILL.md @@ -0,0 +1,166 @@ +--- +name: databricks-execution-compute +description: >- + Execute code on Databricks compute — serverless or classic clusters. Use this + skill when the user mentions: "run code", "execute", "run on databricks", + "serverless", "no cluster", "run python", "run scala", "run sql", "run R", + "run file", "push and run", "notebook run", "batch script", "model training", + "run script on cluster". Also use when the user wants to run local files on + Databricks or needs to choose between serverless and cluster compute. +--- + +# Databricks Execution Compute + +Run code on Databricks — either on serverless compute (no cluster required) or on classic clusters (interactive, multi-language). Supports pushing local files to the Databricks workspace and executing them. 
+ +## Choosing the Right Tool + +| Scenario | Tool | Why | +|----------|------|-----| +| **Run Python, no cluster available** | `run_code_on_serverless` | No cluster needed; serverless spins up automatically | +| **Run local file on a cluster** | `run_file_on_databricks` | Auto-detects language from extension; supports Python, Scala, SQL, R | +| **Interactive iteration (preserve variables)** | `execute_databricks_command` | Keeps execution context alive across calls | +| **SQL queries that need result rows** | `execute_sql` | Works with serverless SQL warehouses; returns data | +| **Batch/ETL Python, no interactivity needed** | `run_code_on_serverless` | Dedicated serverless resources, up to 30 min timeout | +| **Long-running production pipelines** | Databricks Jobs | Full scheduling, retries, monitoring | + +## Ephemeral vs Persistent Mode + +All execution tools support two modes: + +**Ephemeral (default):** Code is executed and no artifact is saved in the workspace. Good for testing, exploration, quick checks. + +**Persistent:** Pass `workspace_path` to save the code as a notebook in the Databricks workspace. The notebook stays after execution — visible in the UI, re-runnable, and versionable. Good for: +- Model training scripts +- ETL/data pipeline notebooks +- Any project work the user wants to keep + +When the user is working on a project, ask where they want files saved and suggest a path like: +`/Workspace/Users/{username}/{project-name}/` + +## MCP Tools + +### run_code_on_serverless + +Execute code on serverless compute via Jobs API. No cluster required. 
+ +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `code` | string | *(required)* | Python or SQL code to execute | +| `language` | string | `"python"` | `"python"` or `"sql"` | +| `timeout` | int | `1800` | Max wait time in seconds (30 min) | +| `run_name` | string | auto-generated | Optional human-readable run name | +| `workspace_path` | string | None | Workspace path to persist the notebook. If omitted, uses temp path and cleans up | + +**Returns:** `success`, `output`, `error`, `run_id`, `run_url`, `duration_seconds`, `state`, `message`, `workspace_path` (persistent mode). + +**Output capture:** Use `dbutils.notebook.exit(value)` to return structured output. `print()` output may not be reliably captured. SQL SELECT results are NOT captured — use `execute_sql()` instead. + +### run_file_on_databricks + +Read a local file and execute it on a Databricks cluster. Auto-detects language from file extension. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `file_path` | string | *(required)* | Local path to the file (.py, .scala, .sql, .r) | +| `cluster_id` | string | auto-selected | Cluster to run on; auto-selects if omitted | +| `context_id` | string | None | Reuse an existing execution context | +| `language` | string | auto-detected | Override language detection | +| `timeout` | int | `600` | Max wait time in seconds | +| `destroy_context_on_completion` | bool | `false` | Destroy context after execution | +| `workspace_path` | string | None | Workspace path to also persist the file as a notebook | + +**Returns:** `success`, `output`, `error`, `cluster_id`, `context_id`, `context_destroyed`, `message`. + +### execute_databricks_command + +Execute code interactively on a cluster. Best for iterative work — contexts persist variables and imports across calls. 
+ +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `code` | string | *(required)* | Code to execute | +| `cluster_id` | string | auto-selected | Cluster to run on | +| `context_id` | string | None | Reuse existing context for speed + state | +| `language` | string | `"python"` | `"python"`, `"scala"`, `"sql"`, or `"r"` | +| `timeout` | int | `120` | Max wait time in seconds | +| `destroy_context_on_completion` | bool | `false` | Destroy context after execution | + +**Returns:** `success`, `output`, `error`, `cluster_id`, `context_id`, `context_destroyed`, `message`. + +## Cluster Management Helpers + +| Tool | Description | +|------|-------------| +| `list_clusters` | List all user-created clusters in the workspace | +| `get_best_cluster` | Auto-select the best running cluster (prefers "shared" > "demo") | +| `start_cluster` | Start a terminated cluster (**always ask user first**) | +| `get_cluster_status` | Poll cluster state after starting | + +### When No Cluster Is Available + +If `execute_databricks_command` or `run_file_on_databricks` finds no running cluster: +1. The error response includes `startable_clusters` and `suggestions` +2. Ask the user if they want to start a terminated cluster (3-8 min startup) +3. Or suggest `run_code_on_serverless` for Python (no cluster needed) +4. 
Or suggest `execute_sql` for SQL workloads (uses SQL warehouses) + +## Limitations + +| Limitation | Applies To | Details | +|-----------|------------|---------| +| Cold start ~25-50s | Serverless | Serverless compute spin-up time | +| No interactive state | Serverless | Each invocation is fresh; no variables persist | +| Python and SQL only | Serverless | No R, Scala, or Java on serverless | +| SQL SELECT not captured | Serverless | Use `execute_sql()` for SELECT queries | +| Cluster must be running | Classic | Use `start_cluster` or switch to serverless | +| print() output unreliable | Serverless | Use `dbutils.notebook.exit()` instead | + +## Quick Start Examples + +### Run Python on serverless (ephemeral) + +```python +run_code_on_serverless( + code="dbutils.notebook.exit('hello from serverless')" +) +``` + +### Run Python on serverless (persistent — save to project) + +```python +run_code_on_serverless( + code=training_code, + workspace_path="/Workspace/Users/user@company.com/ml-project/train", + run_name="model-training-v1" +) +``` + +### Run a local file on a cluster + +```python +run_file_on_databricks(file_path="/local/path/to/etl.py") +``` + +### Run a local file and persist it to workspace + +```python +run_file_on_databricks( + file_path="/local/path/to/train.py", + workspace_path="/Workspace/Users/user@company.com/ml-project/train" +) +``` + +### Interactive iteration on a cluster + +```python +# First call — creates context +result = execute_databricks_command(code="import pandas as pd\ndf = pd.DataFrame({'a': [1,2,3]})") +# Follow-up — reuses context (faster, state preserved) +execute_databricks_command(code="print(df.shape)", context_id=result["context_id"], cluster_id=result["cluster_id"]) +``` + +## Related Skills + +- **[databricks-jobs](../databricks-jobs/SKILL.md)** — Production job orchestration with scheduling, retries, and multi-task DAGs +- **[databricks-dbsql](../databricks-dbsql/SKILL.md)** — SQL warehouse capabilities and AI functions +- 
**[databricks-python-sdk](../databricks-python-sdk/SKILL.md)** — Direct SDK usage for workspace automation diff --git a/databricks-skills/databricks-manage-compute/SKILL.md b/databricks-skills/databricks-manage-compute/SKILL.md new file mode 100644 index 00000000..9cae11f7 --- /dev/null +++ b/databricks-skills/databricks-manage-compute/SKILL.md @@ -0,0 +1,194 @@ +--- +name: databricks-manage-compute +description: >- + Create, modify, and delete Databricks compute resources (clusters and SQL + warehouses). Use this skill when the user mentions: "create cluster", "new + cluster", "resize cluster", "modify cluster", "delete cluster", "terminate + cluster", "create warehouse", "new warehouse", "resize warehouse", "delete + warehouse", "node types", "runtime versions", "DBR versions", "spin up + compute", "provision cluster". +--- + +# Databricks Manage Compute + +Create, modify, and delete Databricks compute resources — classic clusters and SQL warehouses. Provides opinionated defaults so simple operations just work, with full override for power users. 
+ +## Decision Matrix + +| User Intent | Tool | Notes | +|-------------|------|-------| +| **Create a new cluster** | `create_cluster` | Just needs name + num_workers; defaults handle the rest | +| **Resize or reconfigure a cluster** | `modify_cluster` | Change workers, DBR, node type, spark conf | +| **Stop a cluster (save costs)** | `terminate_cluster` | Reversible — can restart with `start_cluster` | +| **Permanently remove a cluster** | `delete_cluster` | DESTRUCTIVE — always confirm with user first | +| **Choose a node type** | `list_node_types` | Browse available VM types before creating | +| **Choose a DBR version** | `list_spark_versions` | Browse runtimes; filter for "LTS" | +| **Create a SQL warehouse** | `create_sql_warehouse` | Serverless Pro by default | +| **Resize a SQL warehouse** | `modify_sql_warehouse` | Change size, scaling, auto-stop | +| **Permanently remove a warehouse** | `delete_sql_warehouse` | DESTRUCTIVE — always confirm with user first | + +## MCP Tools — Clusters + +### create_cluster + +Create a new Databricks cluster with sensible defaults. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `name` | string | *(required)* | Human-readable cluster name | +| `num_workers` | int | `1` | Fixed worker count (ignored if autoscale is set) | +| `spark_version` | string | latest LTS | DBR version key (e.g. "15.4.x-scala2.12") | +| `node_type_id` | string | auto-picked | Worker node type (e.g. "i3.xlarge") | +| `autotermination_minutes` | int | `120` | Minutes of inactivity before auto-stop | +| `data_security_mode` | string | `"SINGLE_USER"` | Security mode | +| `spark_conf` | string (JSON) | None | Spark config overrides as JSON | +| `autoscale_min_workers` | int | None | Min workers for autoscaling | +| `autoscale_max_workers` | int | None | Max workers for autoscaling | + +**Returns:** `cluster_id`, `cluster_name`, `state`, `spark_version`, `node_type_id`, `message`. 
+ +### modify_cluster + +Modify an existing cluster. Only specified parameters change; the rest stay as-is. Running clusters will restart. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `cluster_id` | string | *(required)* | Cluster to modify | +| `name` | string | unchanged | New cluster name | +| `num_workers` | int | unchanged | New worker count | +| `spark_version` | string | unchanged | New DBR version | +| `node_type_id` | string | unchanged | New node type | +| `autotermination_minutes` | int | unchanged | New auto-termination | +| `spark_conf` | string (JSON) | unchanged | New Spark config | +| `autoscale_min_workers` | int | unchanged | Enable/modify autoscaling | +| `autoscale_max_workers` | int | unchanged | Enable/modify autoscaling | + +**Returns:** `cluster_id`, `cluster_name`, `state`, `message`. + +### terminate_cluster + +Stop a running cluster (reversible). The cluster can be restarted later with `start_cluster`. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `cluster_id` | string | *(required)* | Cluster to terminate | + +**Returns:** `cluster_id`, `cluster_name`, `state`, `message`. + +### delete_cluster + +**DESTRUCTIVE** — Permanently delete a cluster. Always confirm with the user before calling. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `cluster_id` | string | *(required)* | Cluster to permanently delete | + +**Returns:** `cluster_id`, `cluster_name`, `state`, `message` (includes warning). + +### list_node_types + +List available VM/node types for the workspace. Use this to help users choose a `node_type_id` for `create_cluster`. + +**Returns:** List of `node_type_id`, `memory_mb`, `num_cores`, `num_gpus`, `description`, `is_deprecated`. + +### list_spark_versions + +List available Databricks Runtime versions. Filter for "LTS" in the name for long-term support versions. 
+ +**Returns:** List of `key`, `name`. + +## MCP Tools — SQL Warehouses + +### create_sql_warehouse + +Create a new SQL warehouse. Defaults to serverless Pro with 120-minute auto-stop. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `name` | string | *(required)* | Warehouse name | +| `size` | string | `"Small"` | T-shirt size (2X-Small through 4X-Large) | +| `min_num_clusters` | int | `1` | Minimum clusters | +| `max_num_clusters` | int | `1` | Maximum clusters for scaling | +| `auto_stop_mins` | int | `120` | Auto-stop after inactivity | +| `warehouse_type` | string | `"PRO"` | PRO or CLASSIC | +| `enable_serverless` | bool | `true` | Enable serverless compute | + +**Returns:** `warehouse_id`, `name`, `size`, `state`, `message`. + +### modify_sql_warehouse + +Modify an existing SQL warehouse. Only specified parameters change. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `warehouse_id` | string | *(required)* | Warehouse to modify | +| `name` | string | unchanged | New warehouse name | +| `size` | string | unchanged | New T-shirt size | +| `min_num_clusters` | int | unchanged | New min clusters | +| `max_num_clusters` | int | unchanged | New max clusters | +| `auto_stop_mins` | int | unchanged | New auto-stop timeout | + +**Returns:** `warehouse_id`, `name`, `state`, `message`. + +### delete_sql_warehouse + +**DESTRUCTIVE** — Permanently delete a SQL warehouse. Always confirm with the user before calling. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `warehouse_id` | string | *(required)* | Warehouse to permanently delete | + +**Returns:** `warehouse_id`, `name`, `state`, `message` (includes warning). + +## Destructive Actions + +`delete_cluster` and `delete_sql_warehouse` are permanent and irreversible. Before calling either: + +1. Tell the user the action is permanent +2. Ask for explicit confirmation +3. 
Only proceed if the user confirms + +`terminate_cluster` is safe and reversible — the cluster can be restarted. + +## Quick Start Examples + +### Create a simple cluster (all defaults) + +```python +create_cluster(name="my-dev-cluster", num_workers=1) +``` + +### Create an autoscaling cluster + +```python +create_cluster( + name="my-scaling-cluster", + autoscale_min_workers=1, + autoscale_max_workers=8, + autotermination_minutes=60 +) +``` + +### Resize a cluster + +```python +modify_cluster(cluster_id="1234-567890-abcdef", num_workers=4) +``` + +### Create a SQL warehouse + +```python +create_sql_warehouse(name="analytics-warehouse", size="Medium") +``` + +### Stop a cluster to save costs + +```python +terminate_cluster(cluster_id="1234-567890-abcdef") +``` + +## Related Skills + +- **[databricks-execution-compute](../databricks-execution-compute/SKILL.md)** — Execute code on clusters and serverless compute +- **[databricks-dbsql](../databricks-dbsql/SKILL.md)** — SQL warehouse query capabilities +- **[databricks-python-sdk](../databricks-python-sdk/SKILL.md)** — Direct SDK usage for workspace automation diff --git a/databricks-skills/install_skills.sh b/databricks-skills/install_skills.sh index 6220195e..49bc2916 100755 --- a/databricks-skills/install_skills.sh +++ b/databricks-skills/install_skills.sh @@ -42,7 +42,7 @@ MLFLOW_REPO_RAW_URL="https://raw.githubusercontent.com/mlflow/skills" MLFLOW_REPO_REF="main" # Databricks skills (hosted in this repo) -DATABRICKS_SKILLS="databricks-agent-bricks databricks-aibi-dashboards databricks-asset-bundles databricks-app-python databricks-config databricks-dbsql databricks-docs databricks-genie databricks-iceberg databricks-jobs databricks-lakebase-autoscale databricks-lakebase-provisioned databricks-metric-views databricks-mlflow-evaluation databricks-model-serving databricks-parsing databricks-python-sdk databricks-spark-declarative-pipelines databricks-spark-structured-streaming databricks-synthetic-data-gen 
databricks-unity-catalog databricks-unstructured-pdf-generation databricks-vector-search databricks-zerobus-ingest spark-python-data-source" +DATABRICKS_SKILLS="databricks-agent-bricks databricks-aibi-dashboards databricks-asset-bundles databricks-app-python databricks-config databricks-dbsql databricks-docs databricks-genie databricks-iceberg databricks-jobs databricks-lakebase-autoscale databricks-lakebase-provisioned databricks-manage-compute databricks-metric-views databricks-mlflow-evaluation databricks-model-serving databricks-parsing databricks-python-sdk databricks-execution-compute databricks-spark-declarative-pipelines databricks-spark-structured-streaming databricks-synthetic-data-gen databricks-unity-catalog databricks-unstructured-pdf-generation databricks-vector-search databricks-zerobus-ingest spark-python-data-source" # MLflow skills (fetched from mlflow/skills repo) MLFLOW_SKILLS="agent-evaluation analyze-mlflow-chat-session analyze-mlflow-trace instrumenting-with-mlflow-tracing mlflow-onboarding querying-mlflow-metrics retrieving-mlflow-traces searching-mlflow-docs" @@ -73,6 +73,8 @@ get_skill_description() { "databricks-iceberg") echo "Apache Iceberg - managed tables, UniForm, IRC, Snowflake interop, migration" ;; "databricks-jobs") echo "Databricks Lakeflow Jobs - workflow orchestration" ;; "databricks-python-sdk") echo "Databricks Python SDK, Connect, and REST API" ;; + "databricks-execution-compute") echo "Execute code on Databricks - serverless and classic cluster compute" ;; + "databricks-manage-compute") echo "Create, modify, and delete Databricks clusters and SQL warehouses" ;; "databricks-unity-catalog") echo "System tables for lineage, audit, billing" ;; "databricks-lakebase-autoscale") echo "Lakebase Autoscale - managed PostgreSQL with autoscaling" ;; "databricks-lakebase-provisioned") echo "Lakebase Provisioned - data connections and reverse ETL" ;; diff --git a/databricks-tools-core/databricks_tools_core/compute/__init__.py 
b/databricks-tools-core/databricks_tools_core/compute/__init__.py index 55a7eed2..958952e2 100644 --- a/databricks-tools-core/databricks_tools_core/compute/__init__.py +++ b/databricks-tools-core/databricks_tools_core/compute/__init__.py @@ -1,7 +1,8 @@ """ -Compute - Execution Context Operations +Compute - Code Execution and Compute Management Operations -Functions for executing code on Databricks clusters. +Functions for executing code on Databricks clusters and serverless compute, +and for creating, modifying, and deleting compute resources. """ from .execution import ( @@ -14,9 +15,27 @@ create_context, destroy_context, execute_databricks_command, + run_file_on_databricks, run_python_file_on_databricks, ) +from .serverless import ( + ServerlessRunResult, + run_code_on_serverless, +) + +from .manage import ( + create_cluster, + modify_cluster, + terminate_cluster, + delete_cluster, + list_node_types, + list_spark_versions, + create_sql_warehouse, + modify_sql_warehouse, + delete_sql_warehouse, +) + __all__ = [ "ExecutionResult", "NoRunningClusterError", @@ -27,5 +46,17 @@ "create_context", "destroy_context", "execute_databricks_command", + "run_file_on_databricks", "run_python_file_on_databricks", + "ServerlessRunResult", + "run_code_on_serverless", + "create_cluster", + "modify_cluster", + "terminate_cluster", + "delete_cluster", + "list_node_types", + "list_spark_versions", + "create_sql_warehouse", + "modify_sql_warehouse", + "delete_sql_warehouse", ] diff --git a/databricks-tools-core/databricks_tools_core/compute/execution.py b/databricks-tools-core/databricks_tools_core/compute/execution.py index a4c04c5f..805e30db 100644 --- a/databricks-tools-core/databricks_tools_core/compute/execution.py +++ b/databricks-tools-core/databricks_tools_core/compute/execution.py @@ -670,52 +670,53 @@ def execute_databricks_command( raise -def run_python_file_on_databricks( +_FILE_EXT_LANGUAGE = { + ".py": "python", + ".scala": "scala", + ".sql": "sql", + ".r": "r", +} + + 
+def run_file_on_databricks( file_path: str, cluster_id: Optional[str] = None, context_id: Optional[str] = None, + language: Optional[str] = None, timeout: int = 600, destroy_context_on_completion: bool = False, + workspace_path: Optional[str] = None, ) -> ExecutionResult: """ - Read a local Python file and execute it on a Databricks cluster. + Read a local file and execute it on a Databricks cluster. - This is useful for running data generation scripts or other Python code - that has been written locally and needs to be executed on Databricks. + Supports Python, Scala, SQL, and R files. If ``language`` is not specified, + it is auto-detected from the file extension (.py, .scala, .sql, .r). - If context_id is provided, reuses the existing context (faster, maintains state). - If not provided, creates a new context. + Two modes: + - **Ephemeral** (default): Sends code directly via Command Execution API. + No artifact is saved in the workspace. + - **Persistent**: If ``workspace_path`` is provided, also uploads the file + as a notebook to that workspace path so it is visible in the Databricks UI. Args: - file_path: Local path to the Python file to execute - cluster_id: ID of the cluster to run the code on. If not provided, - auto-selects a running cluster (prefers "shared" or "demo"). - context_id: Optional existing execution context ID. If provided, reuses it - for faster execution and state preservation. - timeout: Maximum time to wait for execution (seconds, default 600) + file_path: Local path to the file to execute. + cluster_id: ID of the cluster to run on. If not provided, auto-selects + a running cluster (prefers "shared" or "demo"). + context_id: Optional existing execution context ID for reuse. + language: Programming language ("python", "scala", "sql", "r"). + If omitted, auto-detected from file extension. + timeout: Maximum time to wait for execution (seconds, default 600). destroy_context_on_completion: If True, destroys the context after execution. 
- Default is False to allow reuse. + workspace_path: Optional workspace path to persist the file as a notebook + (e.g. "/Workspace/Users/user@company.com/my-project/train"). + If omitted, no workspace artifact is created. Returns: ExecutionResult with output, error, and context info for reuse. - - Raises: - FileNotFoundError: If the file doesn't exist - NoRunningClusterError: If no cluster_id provided and no running cluster found - DatabricksError: If API request fails - - Example: - >>> # First execution - creates context - >>> result = run_python_file_on_databricks(file_path="/path/to/script.py") - >>> print(result.context_id) # Save this for follow-up - >>> - >>> # Follow-up execution - reuses context (faster) - >>> result2 = run_python_file_on_databricks( - ... file_path="/path/to/another_script.py", - ... cluster_id=result.cluster_id, - ... context_id=result.context_id - ... ) """ + import os + # Read the file contents try: with open(file_path, "r", encoding="utf-8") as f: @@ -728,12 +729,61 @@ def run_python_file_on_databricks( if not code.strip(): return ExecutionResult(success=False, error=f"File is empty: {file_path}") + # Auto-detect language from file extension if not specified + if language is None: + ext = os.path.splitext(file_path)[1].lower() + language = _FILE_EXT_LANGUAGE.get(ext, "python") + + # Persist to workspace if requested + if workspace_path: + try: + _upload_to_workspace(code, language, workspace_path) + except Exception as e: + return ExecutionResult(success=False, error=f"Failed to upload to workspace: {e}") + # Execute the code on Databricks return execute_databricks_command( code=code, cluster_id=cluster_id, context_id=context_id, - language="python", + language=language, timeout=timeout, destroy_context_on_completion=destroy_context_on_completion, ) + + +def _upload_to_workspace(code: str, language: str, workspace_path: str) -> None: + """Upload code as a notebook to the Databricks workspace for persistence.""" + import base64 + + 
from databricks.sdk.service.workspace import ImportFormat, Language + + lang_map = { + "python": Language.PYTHON, + "scala": Language.SCALA, + "sql": Language.SQL, + "r": Language.R, + } + + w = get_workspace_client() + lang_enum = lang_map.get(language.lower(), Language.PYTHON) + content_b64 = base64.b64encode(code.encode("utf-8")).decode("utf-8") + + # Ensure parent directory exists + parent = workspace_path.rsplit("/", 1)[0] + try: + w.workspace.mkdirs(parent) + except Exception: + pass # Directory may already exist + + w.workspace.import_( + path=workspace_path, + content=content_b64, + language=lang_enum, + format=ImportFormat.SOURCE, + overwrite=True, + ) + + +# Keep old name as alias for backwards compatibility +run_python_file_on_databricks = run_file_on_databricks diff --git a/databricks-tools-core/databricks_tools_core/compute/manage.py b/databricks-tools-core/databricks_tools_core/compute/manage.py new file mode 100644 index 00000000..f0fd8ce7 --- /dev/null +++ b/databricks-tools-core/databricks_tools_core/compute/manage.py @@ -0,0 +1,561 @@ +""" +Compute - Manage Compute Resources + +Functions for creating, modifying, and deleting Databricks clusters and SQL warehouses. +Uses Databricks SDK for all operations. +""" + +import logging +from typing import Optional, List, Dict, Any + +from databricks.sdk.service.compute import ( + AutoScale, + DataSecurityMode, +) + +from ..auth import get_workspace_client + +logger = logging.getLogger(__name__) + + +# --- Clusters --- + + +def _get_latest_lts_spark_version(w) -> str: + """Pick the latest LTS Databricks Runtime version. + + Falls back to the latest non-ML, non-GPU, non-Photon LTS version, + or the first available version if no LTS is found. 
+ """ + versions = w.clusters.spark_versions() + lts_versions = [] + for v in versions.versions: + key = v.key or "" + name = (v.name or "").lower() + # Skip ML, GPU, Photon, and aarch64 runtimes + if any(tag in key for tag in ("-ml-", "-gpu-", "-photon-", "-aarch64-")): + continue + if "lts" in name: + lts_versions.append(v) + + if lts_versions: + # Sort by key descending to get latest + lts_versions.sort(key=lambda v: v.key, reverse=True) + return lts_versions[0].key + + # Fallback: first available version + if versions.versions: + return versions.versions[0].key + + raise RuntimeError("No Spark versions available in this workspace") + + +def _get_default_node_type(w) -> str: + """Pick a reasonable default node type for the current cloud. + + Prefers memory-optimized, mid-size instances. Falls back to the + smallest available node type. + """ + node_types = w.clusters.list_node_types() + + # Common sensible defaults by cloud + preferred = [ + "i3.xlarge", # AWS + "Standard_DS3_v2", # Azure + "n1-highmem-4", # GCP + "Standard_D4ds_v5", # Azure newer + "m5d.xlarge", # AWS newer + ] + + available_ids = {nt.node_type_id for nt in node_types.node_types} + + for pref in preferred: + if pref in available_ids: + return pref + + # Fallback: pick smallest available node type by memory + if node_types.node_types: + sorted_types = sorted( + node_types.node_types, + key=lambda nt: getattr(nt, "memory_mb", 0) or 0, + ) + # Skip types with 0 memory (metadata-only entries) + for nt in sorted_types: + if (getattr(nt, "memory_mb", 0) or 0) > 0: + return nt.node_type_id + return sorted_types[0].node_type_id + + raise RuntimeError("No node types available in this workspace") + + +def create_cluster( + name: str, + num_workers: int = 1, + spark_version: Optional[str] = None, + node_type_id: Optional[str] = None, + driver_node_type_id: Optional[str] = None, + autotermination_minutes: int = 120, + data_security_mode: Optional[str] = None, + single_user_name: Optional[str] = None, + 
spark_conf: Optional[Dict[str, str]] = None, + autoscale_min_workers: Optional[int] = None, + autoscale_max_workers: Optional[int] = None, + **kwargs, +) -> Dict[str, Any]: + """Create a new Databricks cluster with sensible defaults. + + Provides opinionated defaults so ``create_cluster(name="my-cluster", num_workers=1)`` + just works — auto-picks the latest LTS DBR, a reasonable node type, single-user + security mode, and 120-minute auto-termination. + + Power users can override any parameter or pass additional SDK parameters via kwargs. + + Args: + name: Human-readable cluster name. + num_workers: Fixed number of workers (ignored if autoscale is set). Default 1. + spark_version: DBR version key (e.g. "15.4.x-scala2.12"). Auto-picks latest LTS if omitted. + node_type_id: Worker node type (e.g. "i3.xlarge"). Auto-picked if omitted. + driver_node_type_id: Driver node type. Defaults to same as worker. + autotermination_minutes: Minutes of inactivity before auto-termination. Default 120. + data_security_mode: Security mode string ("SINGLE_USER", "USER_ISOLATION", etc.). + Defaults to SINGLE_USER. + single_user_name: User for SINGLE_USER mode. Auto-detected if omitted. + spark_conf: Spark configuration overrides. + autoscale_min_workers: If set (with autoscale_max_workers), enables autoscaling + instead of fixed num_workers. + autoscale_max_workers: Maximum workers for autoscaling. + **kwargs: Additional parameters passed directly to the SDK clusters.create() call. + + Returns: + Dict with cluster_id, cluster_name, state, and message. 
+ """ + w = get_workspace_client() + + # Auto-pick defaults + if spark_version is None: + spark_version = _get_latest_lts_spark_version(w) + if node_type_id is None: + node_type_id = _get_default_node_type(w) + if driver_node_type_id is None: + driver_node_type_id = node_type_id + + # Security mode defaults + if data_security_mode is None: + dsm = DataSecurityMode.SINGLE_USER + else: + dsm = DataSecurityMode(data_security_mode) + + if dsm == DataSecurityMode.SINGLE_USER and single_user_name is None: + from ..auth import get_current_username + single_user_name = get_current_username() + + # Build create kwargs + create_kwargs = { + "cluster_name": name, + "spark_version": spark_version, + "node_type_id": node_type_id, + "driver_node_type_id": driver_node_type_id, + "autotermination_minutes": autotermination_minutes, + "data_security_mode": dsm, + } + + if single_user_name: + create_kwargs["single_user_name"] = single_user_name + if spark_conf: + create_kwargs["spark_conf"] = spark_conf + + # Autoscale vs fixed workers + if autoscale_min_workers is not None and autoscale_max_workers is not None: + create_kwargs["autoscale"] = AutoScale( + min_workers=autoscale_min_workers, + max_workers=autoscale_max_workers, + ) + else: + create_kwargs["num_workers"] = num_workers + + # Merge any extra SDK parameters + create_kwargs.update(kwargs) + + # Create the cluster (non-blocking — returns immediately) + wait = w.clusters.create(**create_kwargs) + cluster_id = wait.cluster_id + + return { + "cluster_id": cluster_id, + "cluster_name": name, + "state": "PENDING", + "spark_version": spark_version, + "node_type_id": node_type_id, + "message": ( + f"Cluster '{name}' is being created (cluster_id='{cluster_id}'). " + f"It typically takes 3-8 minutes to start. " + f"Use get_cluster_status(cluster_id='{cluster_id}') to check progress." 
+ ), + } + + +def modify_cluster( + cluster_id: str, + name: Optional[str] = None, + num_workers: Optional[int] = None, + spark_version: Optional[str] = None, + node_type_id: Optional[str] = None, + driver_node_type_id: Optional[str] = None, + autotermination_minutes: Optional[int] = None, + spark_conf: Optional[Dict[str, str]] = None, + autoscale_min_workers: Optional[int] = None, + autoscale_max_workers: Optional[int] = None, + **kwargs, +) -> Dict[str, Any]: + """Modify an existing Databricks cluster configuration. + + Fetches the current config, applies the requested changes, and calls the + edit API. The cluster will restart if it is running. + + Args: + cluster_id: ID of the cluster to modify. + name: New cluster name (optional). + num_workers: New fixed worker count (optional). + spark_version: New DBR version (optional). + node_type_id: New worker node type (optional). + driver_node_type_id: New driver node type (optional). + autotermination_minutes: New auto-termination timeout (optional). + spark_conf: Spark configuration overrides (optional). + autoscale_min_workers: Set to enable/modify autoscaling (optional). + autoscale_max_workers: Set to enable/modify autoscaling (optional). + **kwargs: Additional SDK parameters. + + Returns: + Dict with cluster_id, cluster_name, state, and message. 
+ """ + w = get_workspace_client() + + # Get current cluster config + cluster = w.clusters.get(cluster_id) + + # Build edit kwargs from current config + edit_kwargs = { + "cluster_id": cluster_id, + "cluster_name": name or cluster.cluster_name, + "spark_version": spark_version or cluster.spark_version, + "node_type_id": node_type_id or cluster.node_type_id, + "driver_node_type_id": driver_node_type_id or cluster.driver_node_type_id or cluster.node_type_id, + } + + if autotermination_minutes is not None: + edit_kwargs["autotermination_minutes"] = autotermination_minutes + elif cluster.autotermination_minutes: + edit_kwargs["autotermination_minutes"] = cluster.autotermination_minutes + + if spark_conf is not None: + edit_kwargs["spark_conf"] = spark_conf + elif cluster.spark_conf: + edit_kwargs["spark_conf"] = cluster.spark_conf + + # Handle data_security_mode and single_user_name from existing config + if cluster.data_security_mode: + edit_kwargs["data_security_mode"] = cluster.data_security_mode + if cluster.single_user_name: + edit_kwargs["single_user_name"] = cluster.single_user_name + + # Autoscale vs fixed workers + if autoscale_min_workers is not None and autoscale_max_workers is not None: + edit_kwargs["autoscale"] = AutoScale( + min_workers=autoscale_min_workers, + max_workers=autoscale_max_workers, + ) + elif num_workers is not None: + edit_kwargs["num_workers"] = num_workers + elif cluster.autoscale: + edit_kwargs["autoscale"] = cluster.autoscale + else: + edit_kwargs["num_workers"] = cluster.num_workers or 0 + + # Merge extra SDK params + edit_kwargs.update(kwargs) + + w.clusters.edit(**edit_kwargs) + + current_state = cluster.state.value if cluster.state else "UNKNOWN" + cluster_name = edit_kwargs["cluster_name"] + + return { + "cluster_id": cluster_id, + "cluster_name": cluster_name, + "state": current_state, + "message": ( + f"Cluster '{cluster_name}' configuration updated. " + + ( + "The cluster will restart to apply changes." 
+ if current_state == "RUNNING" + else "Changes will take effect when the cluster starts." + ) + ), + } + + +def terminate_cluster(cluster_id: str) -> Dict[str, Any]: + """Stop a running Databricks cluster (reversible). + + The cluster is terminated but not deleted. It can be restarted later + with start_cluster(). This is a safe, reversible operation. + + Args: + cluster_id: ID of the cluster to terminate. + + Returns: + Dict with cluster_id, cluster_name, state, and message. + """ + w = get_workspace_client() + cluster = w.clusters.get(cluster_id) + cluster_name = cluster.cluster_name or cluster_id + current_state = cluster.state.value if cluster.state else "UNKNOWN" + + if current_state == "TERMINATED": + return { + "cluster_id": cluster_id, + "cluster_name": cluster_name, + "state": "TERMINATED", + "message": f"Cluster '{cluster_name}' is already terminated.", + } + + w.clusters.delete(cluster_id) # SDK's delete = terminate (confusing but correct) + + return { + "cluster_id": cluster_id, + "cluster_name": cluster_name, + "previous_state": current_state, + "state": "TERMINATING", + "message": f"Cluster '{cluster_name}' is being terminated. This is reversible — use start_cluster() to restart.", + } + + +def delete_cluster(cluster_id: str) -> Dict[str, Any]: + """Permanently delete a Databricks cluster. + + WARNING: This action is PERMANENT and cannot be undone. The cluster + and its configuration will be permanently removed. + + Args: + cluster_id: ID of the cluster to permanently delete. + + Returns: + Dict with cluster_id, cluster_name, and warning message. + """ + w = get_workspace_client() + cluster = w.clusters.get(cluster_id) + cluster_name = cluster.cluster_name or cluster_id + + w.clusters.permanent_delete(cluster_id) + + return { + "cluster_id": cluster_id, + "cluster_name": cluster_name, + "state": "DELETED", + "message": ( + f"WARNING: Cluster '{cluster_name}' has been PERMANENTLY deleted. " + f"This action cannot be undone. 
The cluster configuration is gone." + ), + } + + +def list_node_types() -> List[Dict[str, Any]]: + """List available VM/node types for cluster creation. + + Returns a summary of each node type including ID, memory, cores, + and GPU info. Useful for choosing node_type_id when creating clusters. + + Returns: + List of node type info dicts. + """ + w = get_workspace_client() + result = w.clusters.list_node_types() + + node_types = [] + for nt in result.node_types: + node_types.append({ + "node_type_id": nt.node_type_id, + "memory_mb": nt.memory_mb, + "num_cores": getattr(nt, "num_cores", None), + "num_gpus": getattr(nt, "num_gpus", None) or 0, + "description": getattr(nt, "description", None) or nt.node_type_id, + "is_deprecated": getattr(nt, "is_deprecated", False), + }) + return node_types + + +def list_spark_versions() -> List[Dict[str, Any]]: + """List available Databricks Runtime (Spark) versions. + + Returns version key and name. Filter for "LTS" in the name to find + long-term support versions. + + Returns: + List of dicts with key and name for each version. + """ + w = get_workspace_client() + result = w.clusters.spark_versions() + + versions = [] + for v in result.versions: + versions.append({ + "key": v.key, + "name": v.name, + }) + return versions + + +# --- SQL Warehouses --- + + +def create_sql_warehouse( + name: str, + size: str = "Small", + min_num_clusters: int = 1, + max_num_clusters: int = 1, + auto_stop_mins: int = 120, + warehouse_type: str = "PRO", + enable_serverless: bool = True, + **kwargs, +) -> Dict[str, Any]: + """Create a new SQL warehouse with sensible defaults. + + By default creates a serverless Pro warehouse with auto-stop at 120 minutes. + + Args: + name: Human-readable warehouse name. + size: T-shirt size ("2X-Small", "X-Small", "Small", "Medium", "Large", + "X-Large", "2X-Large", "3X-Large", "4X-Large"). Default "Small". + min_num_clusters: Minimum number of clusters. Default 1. 
+ max_num_clusters: Maximum number of clusters for scaling. Default 1. + auto_stop_mins: Minutes of inactivity before auto-stop. Default 120. + warehouse_type: "PRO", "CLASSIC", or "TYPE_UNSPECIFIED". Default "PRO". + enable_serverless: Enable serverless compute. Default True. + **kwargs: Additional SDK parameters. + + Returns: + Dict with warehouse_id, name, state, and message. + """ + w = get_workspace_client() + + from databricks.sdk.service.sql import ( + CreateWarehouseRequestWarehouseType, + EndpointInfoWarehouseType, + ) + + # Map warehouse type string to enum + type_map = { + "PRO": CreateWarehouseRequestWarehouseType.PRO, + "CLASSIC": CreateWarehouseRequestWarehouseType.CLASSIC, + "TYPE_UNSPECIFIED": CreateWarehouseRequestWarehouseType.TYPE_UNSPECIFIED, + } + wh_type = type_map.get(warehouse_type.upper(), CreateWarehouseRequestWarehouseType.PRO) + + create_kwargs = { + "name": name, + "cluster_size": size, + "min_num_clusters": min_num_clusters, + "max_num_clusters": max_num_clusters, + "auto_stop_mins": auto_stop_mins, + "warehouse_type": wh_type, + "enable_serverless_compute": enable_serverless, + } + create_kwargs.update(kwargs) + + wait = w.warehouses.create(**create_kwargs) + warehouse_id = wait.id + + return { + "warehouse_id": warehouse_id, + "name": name, + "size": size, + "state": "STARTING", + "message": ( + f"SQL warehouse '{name}' is being created (warehouse_id='{warehouse_id}'). " + f"It typically takes 1-3 minutes to start." + ), + } + + +def modify_sql_warehouse( + warehouse_id: str, + name: Optional[str] = None, + size: Optional[str] = None, + min_num_clusters: Optional[int] = None, + max_num_clusters: Optional[int] = None, + auto_stop_mins: Optional[int] = None, + **kwargs, +) -> Dict[str, Any]: + """Modify an existing SQL warehouse configuration. + + Only the specified parameters are changed; others remain as-is. + + Args: + warehouse_id: ID of the warehouse to modify. + name: New warehouse name (optional). 
+ size: New T-shirt size (optional). + min_num_clusters: New minimum clusters (optional). + max_num_clusters: New maximum clusters (optional). + auto_stop_mins: New auto-stop timeout in minutes (optional). + **kwargs: Additional SDK parameters. + + Returns: + Dict with warehouse_id, name, state, and message. + """ + w = get_workspace_client() + + # Get current config + wh = w.warehouses.get(warehouse_id) + + edit_kwargs = { + "id": warehouse_id, + "name": name or wh.name, + "cluster_size": size or wh.cluster_size, + "min_num_clusters": min_num_clusters if min_num_clusters is not None else wh.min_num_clusters, + "max_num_clusters": max_num_clusters if max_num_clusters is not None else wh.max_num_clusters, + "auto_stop_mins": auto_stop_mins if auto_stop_mins is not None else wh.auto_stop_mins, + } + edit_kwargs.update(kwargs) + + w.warehouses.edit(**edit_kwargs) + + current_state = wh.state.value if wh.state else "UNKNOWN" + wh_name = edit_kwargs["name"] + + return { + "warehouse_id": warehouse_id, + "name": wh_name, + "state": current_state, + "message": f"SQL warehouse '{wh_name}' configuration updated.", + } + + +def delete_sql_warehouse(warehouse_id: str) -> Dict[str, Any]: + """Permanently delete a SQL warehouse. + + WARNING: This action is PERMANENT and cannot be undone. The warehouse + and its configuration will be permanently removed. + + Args: + warehouse_id: ID of the warehouse to permanently delete. + + Returns: + Dict with warehouse_id, name, and warning message. + """ + w = get_workspace_client() + + # Get warehouse info before deleting + wh = w.warehouses.get(warehouse_id) + wh_name = wh.name or warehouse_id + + w.warehouses.delete(warehouse_id) + + return { + "warehouse_id": warehouse_id, + "name": wh_name, + "state": "DELETED", + "message": ( + f"WARNING: SQL warehouse '{wh_name}' has been PERMANENTLY deleted. " + f"This action cannot be undone." 
+ ), + } diff --git a/databricks-tools-core/databricks_tools_core/compute/serverless.py b/databricks-tools-core/databricks_tools_core/compute/serverless.py new file mode 100644 index 00000000..433846cd --- /dev/null +++ b/databricks-tools-core/databricks_tools_core/compute/serverless.py @@ -0,0 +1,427 @@ +""" +Compute - Serverless Code Execution + +Execute Python or SQL code on Databricks serverless compute via the Jobs API +(runs/submit). No interactive cluster required. + +Usage: + from databricks_tools_core.compute.serverless import run_code_on_serverless + + result = run_code_on_serverless("print('hello')", language="python") + result = run_code_on_serverless("SELECT 1", language="sql") +""" + +import base64 +import datetime +import json +import logging +import time +import uuid +from dataclasses import dataclass +from typing import Dict, Any, Optional + +from databricks.sdk.service.compute import Environment +from databricks.sdk.service.jobs import ( + JobEnvironment, + NotebookTask, + RunResultState, + SubmitTask, +) +from databricks.sdk.service.workspace import ImportFormat, Language + +from ..auth import get_workspace_client, get_current_username + +logger = logging.getLogger(__name__) + +# Language string to workspace Language enum +_LANGUAGE_MAP = { + "python": Language.PYTHON, + "sql": Language.SQL, +} + + +@dataclass +class ServerlessRunResult: + """Result from serverless code execution via Jobs API. + + Attributes: + success: Whether the execution completed successfully. + output: The output from the execution (notebook result or logs). + error: Error message if execution failed. + run_id: Databricks Jobs run ID. + run_url: URL to the run in the Databricks UI. + duration_seconds: Wall-clock duration of the execution. + state: Final state string (SUCCESS, FAILED, TIMEDOUT, CANCELED, etc.). + message: Human-readable summary of the result. 
+ """ + + success: bool + output: Optional[str] = None + error: Optional[str] = None + run_id: Optional[int] = None + run_url: Optional[str] = None + duration_seconds: Optional[float] = None + state: Optional[str] = None + message: Optional[str] = None + workspace_path: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + d = { + "success": self.success, + "output": self.output, + "error": self.error, + "run_id": self.run_id, + "run_url": self.run_url, + "duration_seconds": self.duration_seconds, + "state": self.state, + "message": self.message, + } + if self.workspace_path: + d["workspace_path"] = self.workspace_path + return d + + +def _get_temp_notebook_path(run_label: str) -> str: + """Build a workspace path for a temporary serverless notebook. + + Args: + run_label: Unique label for this run. + + Returns: + Workspace path string under the current user's home directory. + """ + username = get_current_username() + base = f"/Workspace/Users/{username}" if username else "/Workspace" + return f"{base}/.ai_dev_kit_tmp/{run_label}" + + +def _is_ipynb(content: str) -> bool: + """Check if content is a Jupyter notebook (.ipynb) JSON structure.""" + try: + data = json.loads(content) + return isinstance(data, dict) and "cells" in data + except (json.JSONDecodeError, ValueError): + return False + + +def _upload_temp_notebook( + code: str, language: str, workspace_path: str, is_jupyter: bool = False +) -> None: + """Upload code as a temporary notebook to the Databricks workspace. + + Args: + code: Source code or .ipynb JSON content to upload. + language: Language string ("python" or "sql"). Ignored for Jupyter uploads. + workspace_path: Target workspace path (without file extension). + is_jupyter: If True, upload as Jupyter format (ImportFormat.JUPYTER). + + Raises: + Exception: If the upload fails. 
+ """ + w = get_workspace_client() + content_b64 = base64.b64encode(code.encode("utf-8")).decode("utf-8") + + # Ensure parent directory exists + parent = workspace_path.rsplit("/", 1)[0] + try: + w.workspace.mkdirs(parent) + except Exception: + pass # Directory may already exist + + if is_jupyter: + w.workspace.import_( + path=workspace_path, + content=content_b64, + format=ImportFormat.JUPYTER, + overwrite=True, + ) + else: + lang_enum = _LANGUAGE_MAP[language] + w.workspace.import_( + path=workspace_path, + content=content_b64, + language=lang_enum, + format=ImportFormat.SOURCE, + overwrite=True, + ) + + +def _cleanup_temp_notebook(workspace_path: str) -> None: + """Delete a temporary notebook from the workspace (best-effort).""" + try: + w = get_workspace_client() + w.workspace.delete(path=workspace_path, recursive=False) + except Exception as e: + logger.debug(f"Cleanup of {workspace_path} failed (non-fatal): {e}") + + +def _get_run_output(task_run_id: int) -> Dict[str, Optional[str]]: + """Retrieve output and error text from a completed task run. + + Args: + task_run_id: The run ID of the specific task (not the parent run). + + Returns: + Dict with ``output`` and ``error`` keys (both may be None). 
+ """ + w = get_workspace_client() + result: Dict[str, Optional[str]] = {"output": None, "error": None} + + try: + run_output = w.jobs.get_run_output(run_id=task_run_id) + + # Notebook output (from dbutils.notebook.exit() or last cell) + if run_output.notebook_output and run_output.notebook_output.result: + result["output"] = run_output.notebook_output.result + + # Logs (stdout/stderr, typically for spark_python_task) + if run_output.logs: + if result["output"]: + result["output"] += f"\n\n--- Logs ---\n{run_output.logs}" + else: + result["output"] = run_output.logs + + # Error details + if run_output.error: + error_parts = [run_output.error] + if run_output.error_trace: + error_parts.append(run_output.error_trace) + result["error"] = "\n\n".join(error_parts) + + except Exception as e: + logger.debug(f"Failed to get output for task run {task_run_id}: {e}") + result["error"] = str(e) + + return result + + +def run_code_on_serverless( + code: str, + language: str = "python", + timeout: int = 1800, + run_name: Optional[str] = None, + cleanup: bool = True, + workspace_path: Optional[str] = None, +) -> ServerlessRunResult: + """Execute code on serverless compute via Jobs API runs/submit. + + Uploads the code as a notebook, submits it as a one-time run on serverless + compute (no cluster required), waits for completion, and retrieves output. + + Two modes: + - **Ephemeral** (default): Uploads to a temp path and cleans up after. + - **Persistent**: If ``workspace_path`` is provided, uploads to that path + and keeps it after execution. Useful for project notebooks (model training, + ETL) the user wants saved in their workspace. + + Jupyter notebooks (.ipynb) are also supported. If the code content is + detected as .ipynb JSON (contains "cells" key), it is uploaded using + Databricks' native Jupyter import (ImportFormat.JUPYTER). The language + parameter is ignored in this case since the notebook carries its own + kernel metadata. 
+ + SQL is supported but SELECT query results are NOT captured in the output. + SQL via this tool is only useful for DDL/DML (CREATE TABLE, INSERT, MERGE). + For SQL that needs result rows, use execute_sql() instead. + + Args: + code: Code to execute, or raw .ipynb JSON content (auto-detected). + language: Programming language ("python" or "sql"). Ignored for .ipynb. + timeout: Maximum wait time in seconds (default: 1800 = 30 minutes). + run_name: Optional human-readable run name. Auto-generated if omitted. + cleanup: Delete the notebook after execution (default: True). + Ignored when ``workspace_path`` is provided (persistent mode never cleans up). + workspace_path: Optional workspace path to save the notebook to + (e.g. "/Workspace/Users/user@company.com/my-project/train"). + If provided, the notebook is persisted at this path. If omitted, + a temporary path is used and cleaned up after execution. + + Returns: + ServerlessRunResult with output, error, run_id, run_url, and timing info. + """ + if not code or not code.strip(): + return ServerlessRunResult( + success=False, + error="Code cannot be empty.", + state="INVALID_INPUT", + message="No code provided to execute.", + ) + + # Auto-detect .ipynb content + is_jupyter = _is_ipynb(code) + + language = language.lower() + if not is_jupyter and language not in _LANGUAGE_MAP: + return ServerlessRunResult( + success=False, + error=f"Unsupported language: {language!r}. Must be 'python' or 'sql'.", + state="INVALID_INPUT", + message=f"Unsupported language {language!r}. 
Use 'python' or 'sql'.", + ) + + unique_id = uuid.uuid4().hex[:12] + if not run_name: + run_name = f"ai_dev_kit_serverless_{unique_id}" + + # Persistent mode: user-specified path, never cleanup + if workspace_path: + notebook_path = workspace_path + cleanup = False + else: + notebook_path = _get_temp_notebook_path(f"serverless_{unique_id}") + + start_time = time.time() + w = get_workspace_client() + + # --- Step 1: Upload code as a notebook --- + try: + _upload_temp_notebook(code, language, notebook_path, is_jupyter=is_jupyter) + except Exception as e: + return ServerlessRunResult( + success=False, + error=f"Failed to upload code to workspace: {e}", + state="UPLOAD_FAILED", + message=f"Could not upload temporary notebook: {e}", + ) + + run_id = None + run_url = None + + try: + # --- Step 2: Submit serverless run --- + try: + wait = w.jobs.submit( + run_name=run_name, + tasks=[ + SubmitTask( + task_key="main", + notebook_task=NotebookTask(notebook_path=notebook_path), + environment_key="Default", + ) + ], + environments=[ + JobEnvironment( + environment_key="Default", + spec=Environment(client="1"), + ) + ], + ) + # Extract run_id from the Wait object + run_id = getattr(wait, "run_id", None) + if run_id is None and hasattr(wait, "response"): + run_id = getattr(wait.response, "run_id", None) + + # Get the canonical run URL immediately via get_run so the user + # can monitor progress even before the run completes. 
+ if run_id: + try: + initial_run = w.jobs.get_run(run_id=run_id) + run_url = initial_run.run_page_url + except Exception: + pass # Fall back to no URL rather than a guessed one + + except Exception as e: + return ServerlessRunResult( + success=False, + error=f"Failed to submit serverless run: {e}", + state="SUBMIT_FAILED", + message=f"Jobs API runs/submit call failed: {e}", + ) + + # --- Step 3: Wait for completion --- + try: + run = wait.result(timeout=datetime.timedelta(seconds=timeout)) + except TimeoutError: + elapsed = time.time() - start_time + return ServerlessRunResult( + success=False, + error=f"Run timed out after {timeout}s.", + run_id=run_id, + run_url=run_url, + duration_seconds=round(elapsed, 2), + state="TIMEDOUT", + message=(f"Serverless run {run_id} did not complete within {timeout}s. Check status at {run_url}"), + ) + except Exception as e: + elapsed = time.time() - start_time + error_text = str(e) + + # Best-effort: retrieve the actual error traceback from run output + if run_id: + try: + failed_run = w.jobs.get_run(run_id=run_id) + if failed_run.tasks: + task_run_id = failed_run.tasks[0].run_id + output_data = _get_run_output(task_run_id) + if output_data.get("error"): + error_text = output_data["error"] + except Exception: + pass # Fall back to the original exception message + + return ServerlessRunResult( + success=False, + error=error_text, + run_id=run_id, + run_url=run_url, + duration_seconds=round(elapsed, 2), + state="FAILED", + message=f"Run {run_id} failed: {e}", + ) + + elapsed = time.time() - start_time + + # --- Step 4: Determine result state --- + result_state = None + state_message = None + if run.state: + result_state = run.state.result_state + state_message = run.state.state_message + + # Prefer the canonical URL from the Run object + if run.run_page_url: + run_url = run.run_page_url + + is_success = result_state == RunResultState.SUCCESS + state_str = result_state.value if result_state else "UNKNOWN" + + # --- Step 5: Retrieve 
output --- + task_run_id = None + if run.tasks: + task_run_id = run.tasks[0].run_id + + output_text = None + error_text = None + + if task_run_id: + output_data = _get_run_output(task_run_id) + output_text = output_data["output"] + error_text = output_data["error"] + + # Fallback error from state message + if not is_success and not error_text: + error_text = state_message or f"Run ended with state: {state_str}" + + if is_success: + if not output_text: + output_text = "Success (no output)" + message = f"Code executed successfully on serverless compute in {round(elapsed, 1)}s." + else: + message = f"Serverless run failed with state {state_str}. Check {run_url} for details." + + return ServerlessRunResult( + success=is_success, + output=output_text if is_success else None, + error=error_text if not is_success else None, + run_id=run_id, + run_url=run_url, + duration_seconds=round(elapsed, 2), + state=state_str, + message=message, + workspace_path=notebook_path if workspace_path else None, + ) + + finally: + # --- Step 6: Cleanup temporary notebook --- + if cleanup: + _cleanup_temp_notebook(notebook_path) diff --git a/databricks-tools-core/tests/integration/compute/test_execution.py b/databricks-tools-core/tests/integration/compute/test_execution.py index 073faa54..ec459770 100644 --- a/databricks-tools-core/tests/integration/compute/test_execution.py +++ b/databricks-tools-core/tests/integration/compute/test_execution.py @@ -1,21 +1,28 @@ """ Integration tests for compute execution functions. -Tests execute_databricks_command and run_python_file_on_databricks with a real cluster. +Tests execute_databricks_command, run_file_on_databricks (with language detection, +workspace_path persistence), and backwards-compatible run_python_file_on_databricks. 
""" +import logging import tempfile import pytest from pathlib import Path from databricks_tools_core.compute import ( execute_databricks_command, + run_file_on_databricks, run_python_file_on_databricks, list_clusters, get_best_cluster, destroy_context, NoRunningClusterError, + ExecutionResult, ) +from databricks_tools_core.auth import get_workspace_client, get_current_username + +logger = logging.getLogger(__name__) @pytest.fixture(scope="module") @@ -286,3 +293,180 @@ def test_file_not_found(self): assert not result.success assert "not found" in result.error.lower() + + +@pytest.mark.integration +class TestRunFileOnDatabricks: + """Tests for run_file_on_databricks (renamed from run_python_file_on_databricks). + + Covers: language auto-detection, multi-language support, workspace_path persistence, + backwards compatibility alias. + """ + + def test_backwards_compat_alias(self): + """run_python_file_on_databricks should be an alias for run_file_on_databricks.""" + assert run_python_file_on_databricks is run_file_on_databricks + + def test_python_auto_detect(self, shared_context): + """Should auto-detect Python from .py extension.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write('print("auto-detected python")') + temp_path = f.name + + try: + result = run_file_on_databricks( + file_path=temp_path, + cluster_id=shared_context["cluster_id"], + context_id=shared_context["context_id"], + timeout=120, + ) + assert result.success, f"Execution failed: {result.error}" + assert "auto-detected python" in result.output + finally: + Path(temp_path).unlink(missing_ok=True) + + def test_sql_auto_detect(self, shared_context): + """Should auto-detect SQL from .sql extension.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".sql", delete=False) as f: + f.write("SELECT 42 as answer") + temp_path = f.name + + try: + result = run_file_on_databricks( + file_path=temp_path, + cluster_id=shared_context["cluster_id"], + language=None, # 
should auto-detect + timeout=120, + ) + + logger.info(f"SQL auto-detect: success={result.success}, output={result.output}") + + assert result.success, f"SQL execution failed: {result.error}" + finally: + Path(temp_path).unlink(missing_ok=True) + + def test_explicit_language_override(self, shared_context): + """Should use explicit language even if extension differs.""" + # File has .txt extension but we specify python + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f: + f.write('print("explicit python")') + temp_path = f.name + + try: + result = run_file_on_databricks( + file_path=temp_path, + cluster_id=shared_context["cluster_id"], + context_id=shared_context["context_id"], + language="python", + timeout=120, + ) + assert result.success, f"Execution failed: {result.error}" + assert "explicit python" in result.output + finally: + Path(temp_path).unlink(missing_ok=True) + + def test_empty_file(self): + """Should reject empty files gracefully.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write("") + temp_path = f.name + + try: + result = run_file_on_databricks(file_path=temp_path, timeout=120) + assert not result.success + assert "empty" in result.error.lower() + finally: + Path(temp_path).unlink(missing_ok=True) + + def test_returns_execution_result(self, shared_context): + """Should return ExecutionResult type.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write('print("type check")') + temp_path = f.name + + try: + result = run_file_on_databricks( + file_path=temp_path, + cluster_id=shared_context["cluster_id"], + context_id=shared_context["context_id"], + timeout=120, + ) + assert isinstance(result, ExecutionResult) + assert result.success + d = result.to_dict() + assert isinstance(d, dict) + assert "success" in d + assert "cluster_id" in d + assert "context_id" in d + finally: + Path(temp_path).unlink(missing_ok=True) + + +@pytest.mark.integration +class 
TestRunFileWorkspacePath: + """Tests for run_file_on_databricks with workspace_path (persistent mode).""" + + @pytest.fixture(autouse=True) + def _setup_cleanup(self): + """Track workspace paths for cleanup.""" + self._paths_to_cleanup = [] + yield + try: + w = get_workspace_client() + for path in self._paths_to_cleanup: + try: + w.workspace.delete(path=path, recursive=False) + logger.info(f"Cleaned up: {path}") + except Exception: + pass + except Exception: + pass + + def test_workspace_path_uploads_notebook(self, shared_context): + """Should upload file as notebook when workspace_path is provided.""" + username = get_current_username() + ws_path = f"/Workspace/Users/{username}/.ai_dev_kit_test/file_persist_test" + self._paths_to_cleanup.append(ws_path) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write('print("persisted via run_file")') + temp_path = f.name + + try: + result = run_file_on_databricks( + file_path=temp_path, + cluster_id=shared_context["cluster_id"], + context_id=shared_context["context_id"], + workspace_path=ws_path, + timeout=120, + ) + + logger.info(f"Workspace path result: success={result.success}") + + assert result.success, f"Execution failed: {result.error}" + assert "persisted via run_file" in result.output + + # Verify notebook exists in workspace + w = get_workspace_client() + status = w.workspace.get_status(ws_path) + assert status is not None + finally: + Path(temp_path).unlink(missing_ok=True) + + def test_workspace_path_none_no_upload(self, shared_context): + """Without workspace_path, no notebook should be uploaded (ephemeral).""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write('print("ephemeral file")') + temp_path = f.name + + try: + result = run_file_on_databricks( + file_path=temp_path, + cluster_id=shared_context["cluster_id"], + context_id=shared_context["context_id"], + timeout=120, + ) + assert result.success, f"Execution failed: {result.error}" 
"""
Integration tests for compute management functions.

Tests create_cluster, modify_cluster, terminate_cluster, delete_cluster,
list_node_types, list_spark_versions, create_sql_warehouse, modify_sql_warehouse,
and delete_sql_warehouse.

Requires a valid Databricks connection (e.g. DATABRICKS_CONFIG_PROFILE=E2-Demo).
"""

import logging
import time
import pytest

from databricks_tools_core.compute import (
    create_cluster,
    modify_cluster,
    terminate_cluster,
    delete_cluster,
    list_node_types,
    list_spark_versions,
    create_sql_warehouse,
    modify_sql_warehouse,
    delete_sql_warehouse,
    get_cluster_status,
)
from databricks_tools_core.auth import get_workspace_client

logger = logging.getLogger(__name__)


@pytest.fixture(scope="module")
def managed_cluster():
    """Create a test cluster and clean it up after all tests.

    Yields the cluster_id. The cluster is permanently deleted after the
    test module completes.
    """
    result = create_cluster(
        name="ai-dev-kit-test-manage",
        num_workers=0,
        autotermination_minutes=10,
    )

    assert result["cluster_id"] is not None
    cluster_id = result["cluster_id"]

    logger.info(f"Created test cluster: {cluster_id}")

    yield cluster_id

    # Cleanup: permanently delete
    try:
        delete_cluster(cluster_id)
        logger.info(f"Cleaned up test cluster: {cluster_id}")
    except Exception as e:
        logger.warning(f"Failed to cleanup test cluster {cluster_id}: {e}")


@pytest.fixture(scope="module")
def managed_warehouse():
    """Create a test SQL warehouse and clean it up after all tests.

    Yields the warehouse_id. The warehouse is permanently deleted after the
    test module completes.
    """
    result = create_sql_warehouse(
        name="ai-dev-kit-test-manage",
        size="2X-Small",
        auto_stop_mins=10,
        enable_serverless=True,
    )

    assert result["warehouse_id"] is not None
    warehouse_id = result["warehouse_id"]

    logger.info(f"Created test warehouse: {warehouse_id}")

    yield warehouse_id

    # Cleanup: permanently delete
    try:
        delete_sql_warehouse(warehouse_id)
        logger.info(f"Cleaned up test warehouse: {warehouse_id}")
    except Exception as e:
        logger.warning(f"Failed to cleanup test warehouse {warehouse_id}: {e}")


@pytest.mark.integration
class TestListNodeTypes:
    """Tests for list_node_types function."""

    def test_list_node_types(self):
        """Should return a non-empty list of node types."""
        node_types = list_node_types()

        # Fix: no placeholders, so these headers don't need f-strings (F541)
        print("\n=== List Node Types ===")
        print(f"Found {len(node_types)} node types")
        for nt in node_types[:5]:
            print(f"  - {nt['node_type_id']} ({nt['memory_mb']}MB, {nt['num_cores']} cores)")

        assert isinstance(node_types, list)
        assert len(node_types) > 0
        assert "node_type_id" in node_types[0]
        assert "memory_mb" in node_types[0]

    def test_node_type_has_expected_fields(self):
        """Each node type should have expected fields."""
        node_types = list_node_types()
        nt = node_types[0]

        assert "node_type_id" in nt
        assert "memory_mb" in nt
        assert "num_gpus" in nt
        assert "description" in nt


@pytest.mark.integration
class TestListSparkVersions:
    """Tests for list_spark_versions function."""

    def test_list_spark_versions(self):
        """Should return a non-empty list of spark versions."""
        versions = list_spark_versions()

        print("\n=== List Spark Versions ===")
        print(f"Found {len(versions)} versions")
        for v in versions[:5]:
            print(f"  - {v['key']}: {v['name']}")

        assert isinstance(versions, list)
        assert len(versions) > 0
        assert "key" in versions[0]
        assert "name" in versions[0]

    def test_has_lts_versions(self):
        """Should include at least one LTS version."""
        versions = list_spark_versions()
        lts = [v for v in versions if "LTS" in (v["name"] or "")]
        assert len(lts) > 0, "No LTS versions found"


@pytest.mark.integration
class TestCreateCluster:
    """Tests for create_cluster function."""

    def test_create_cluster_returns_expected_fields(self, managed_cluster):
        """managed_cluster fixture validates create_cluster returns cluster_id.

        This test just verifies the cluster exists.
        """
        status = get_cluster_status(managed_cluster)

        print("\n=== Created Cluster Status ===")
        print(f"Cluster ID: {status['cluster_id']}")
        print(f"Name: {status['cluster_name']}")
        print(f"State: {status['state']}")

        assert status["cluster_id"] == managed_cluster
        assert status["cluster_name"] == "ai-dev-kit-test-manage"


@pytest.mark.integration
class TestTerminateCluster:
    """Tests for terminate_cluster function."""

    def test_terminate_cluster(self, managed_cluster):
        """Should terminate the cluster (reversible)."""
        result = terminate_cluster(managed_cluster)

        print("\n=== Terminate Cluster ===")
        print(f"Result: {result}")

        assert result["cluster_id"] == managed_cluster
        assert result["state"] in ("TERMINATING", "TERMINATED")
        assert "reversible" in result["message"].lower() or "terminated" in result["message"].lower()


@pytest.mark.integration
class TestModifyCluster:
    """Tests for modify_cluster function.

    Runs after TestTerminateCluster so the cluster is in a stable (TERMINATED/TERMINATING)
    state — the edit API rejects edits on PENDING clusters.
    """

    def _wait_for_terminated(self, cluster_id, timeout=120):
        """Wait until cluster reaches TERMINATED state."""
        # Fix: use the module-level `time` import instead of re-importing locally
        start = time.time()
        while time.time() - start < timeout:
            status = get_cluster_status(cluster_id)
            if status["state"] == "TERMINATED":
                return
            time.sleep(5)
        pytest.fail(f"Cluster did not terminate within {timeout}s")

    def test_modify_cluster_name(self, managed_cluster):
        """Should modify cluster name."""
        self._wait_for_terminated(managed_cluster)

        result = modify_cluster(
            cluster_id=managed_cluster,
            name="ai-dev-kit-test-manage-renamed",
        )

        print("\n=== Modify Cluster ===")
        print(f"Result: {result}")

        assert result["cluster_id"] == managed_cluster
        assert result["cluster_name"] == "ai-dev-kit-test-manage-renamed"
        assert "updated" in result["message"].lower()

        # Rename back for other tests
        modify_cluster(
            cluster_id=managed_cluster,
            name="ai-dev-kit-test-manage",
        )


@pytest.mark.integration
class TestCreateSqlWarehouse:
    """Tests for create_sql_warehouse function."""

    def test_create_warehouse_returns_expected_fields(self, managed_warehouse):
        """managed_warehouse fixture validates create returns warehouse_id."""
        w = get_workspace_client()
        wh = w.warehouses.get(managed_warehouse)

        print("\n=== Created Warehouse ===")
        print(f"Warehouse ID: {wh.id}")
        print(f"Name: {wh.name}")
        print(f"State: {wh.state}")

        assert wh.id == managed_warehouse
        assert wh.name == "ai-dev-kit-test-manage"


@pytest.mark.integration
class TestModifySqlWarehouse:
    """Tests for modify_sql_warehouse function."""

    def test_modify_warehouse_name(self, managed_warehouse):
        """Should modify warehouse name."""
        result = modify_sql_warehouse(
            warehouse_id=managed_warehouse,
            name="ai-dev-kit-test-manage-renamed",
        )

        print("\n=== Modify Warehouse ===")
        print(f"Result: {result}")

        assert result["warehouse_id"] == managed_warehouse
        assert result["name"] == "ai-dev-kit-test-manage-renamed"
        assert "updated" in result["message"].lower()

        # Rename back
        modify_sql_warehouse(
            warehouse_id=managed_warehouse,
            name="ai-dev-kit-test-manage",
        )


@pytest.mark.integration
class TestDeleteCluster:
    """Tests for delete_cluster function."""

    def test_delete_cluster_warning_message(self):
        """Should include a permanent deletion warning in the response."""
        # Create a throwaway cluster for deletion test
        result = create_cluster(
            name="ai-dev-kit-test-delete",
            num_workers=0,
            autotermination_minutes=10,
        )
        cluster_id = result["cluster_id"]

        try:
            delete_result = delete_cluster(cluster_id)

            print("\n=== Delete Cluster ===")
            print(f"Result: {delete_result}")

            assert delete_result["cluster_id"] == cluster_id
            assert delete_result["state"] == "DELETED"
            assert "permanent" in delete_result["message"].lower()
            assert "warning" in delete_result["message"].lower()
        except Exception:
            # Best-effort cleanup if delete fails
            try:
                delete_cluster(cluster_id)
            except Exception:
                pass
            raise


@pytest.mark.integration
class TestDeleteSqlWarehouse:
    """Tests for delete_sql_warehouse function."""

    def test_delete_warehouse_warning_message(self):
        """Should include a permanent deletion warning in the response."""
        # Create a throwaway warehouse for deletion test
        result = create_sql_warehouse(
            name="ai-dev-kit-test-delete",
            size="2X-Small",
            auto_stop_mins=10,
        )
        warehouse_id = result["warehouse_id"]

        try:
            delete_result = delete_sql_warehouse(warehouse_id)

            print("\n=== Delete Warehouse ===")
            print(f"Result: {delete_result}")

            assert delete_result["warehouse_id"] == warehouse_id
            assert delete_result["state"] == "DELETED"
            assert "permanent" in delete_result["message"].lower()
            assert "warning" in delete_result["message"].lower()
        except Exception:
            try:
                delete_sql_warehouse(warehouse_id)
            except Exception:
                pass
            raise
"""
Integration tests for serverless compute execution (run_code_on_serverless).

Tests serverless Python/SQL execution, ephemeral vs persistent modes,
workspace_path, error handling, and input validation.
"""

import logging
import pytest

from databricks_tools_core.compute import (
    run_code_on_serverless,
    ServerlessRunResult,
)
from databricks_tools_core.auth import get_workspace_client, get_current_username

logger = logging.getLogger(__name__)


@pytest.mark.integration
class TestServerlessInputValidation:
    """Input-validation tests (no cluster or serverless run needed)."""

    def test_empty_code(self):
        """Empty code is rejected before any run is submitted."""
        res = run_code_on_serverless(code="")
        assert not res.success
        assert res.state == "INVALID_INPUT"
        assert "empty" in res.error.lower()

    def test_whitespace_only_code(self):
        """Whitespace-only code is rejected as well."""
        res = run_code_on_serverless(code="  \n\n  ")
        assert not res.success
        assert res.state == "INVALID_INPUT"

    def test_unsupported_language(self):
        """Languages outside the supported set are rejected."""
        res = run_code_on_serverless(code="println('hi')", language="scala")
        assert not res.success
        assert res.state == "INVALID_INPUT"
        assert "scala" in res.error.lower()

    def test_result_is_serverless_run_result(self):
        """Even validation errors come back as ServerlessRunResult."""
        res = run_code_on_serverless(code="", language="python")
        assert isinstance(res, ServerlessRunResult)

    def test_to_dict(self):
        """The result serializes to a dict with the expected keys."""
        res = run_code_on_serverless(code="")
        payload = res.to_dict()
        assert isinstance(payload, dict)
        for key in ("success", "output", "error", "run_id", "state"):
            assert key in payload


@pytest.mark.integration
class TestServerlessPythonExecution:
    """Python code execution on serverless compute."""

    def test_simple_python_dbutils_exit(self):
        """Output from dbutils.notebook.exit() is captured."""
        code = 'dbutils.notebook.exit("hello from serverless")'
        res = run_code_on_serverless(code=code, language="python")

        logger.info(f"Result: success={res.success}, output={res.output}, "
                    f"duration={res.duration_seconds}s")

        assert res.success, f"Execution failed: {res.error}"
        assert "hello from serverless" in res.output
        assert res.run_id is not None
        assert res.run_url is not None
        assert res.duration_seconds is not None
        assert res.state == "SUCCESS"

    def test_python_computation(self):
        """A computation's result comes back via dbutils.notebook.exit()."""
        code = """
import math
result = sum(math.factorial(i) for i in range(10))
dbutils.notebook.exit(str(result))
"""
        res = run_code_on_serverless(code=code, language="python")

        assert res.success, f"Execution failed: {res.error}"
        assert "409114" in res.output  # sum of 0! through 9!

    def test_python_error_handling(self):
        """Python errors are surfaced with their traceback."""
        code = """
x = 1 / 0
"""
        res = run_code_on_serverless(code=code, language="python")

        logger.info(f"Error result: success={res.success}, error={res.error[:200] if res.error else None}")

        assert not res.success
        assert res.error is not None
        assert "ZeroDivisionError" in res.error
        assert res.state == "FAILED"

    def test_python_with_spark(self):
        """Spark is available on serverless compute."""
        code = """
df = spark.range(10)
count = df.count()
dbutils.notebook.exit(str(count))
"""
        res = run_code_on_serverless(code=code, language="python")

        assert res.success, f"Execution failed: {res.error}"
        assert "10" in res.output

    def test_custom_run_name(self):
        """A caller-supplied run name is accepted."""
        res = run_code_on_serverless(
            code='dbutils.notebook.exit("named run")',
            run_name="test_custom_name_integration",
        )

        assert res.success, f"Execution failed: {res.error}"
        assert "named run" in res.output


@pytest.mark.integration
class TestServerlessSQLExecution:
    """SQL execution on serverless compute."""

    def test_sql_ddl(self):
        """SQL DDL statements execute successfully."""
        code = """
CREATE DATABASE IF NOT EXISTS ai_dev_kit_serverless_test;
"""
        res = run_code_on_serverless(code=code, language="sql")

        logger.info(f"SQL DDL result: success={res.success}, state={res.state}")

        assert res.success, f"SQL DDL failed: {res.error}"


@pytest.mark.integration
class TestServerlessEphemeralMode:
    """Ephemeral mode (default — the temp notebook is cleaned up)."""

    def test_ephemeral_no_workspace_path_in_result(self):
        """Ephemeral runs report no workspace_path on the result."""
        res = run_code_on_serverless(
            code='dbutils.notebook.exit("ephemeral")',
        )

        assert res.success, f"Execution failed: {res.error}"
        assert res.workspace_path is None

    def test_ephemeral_to_dict_no_workspace_path(self):
        """Ephemeral runs omit workspace_path from the dict view."""
        res = run_code_on_serverless(
            code='dbutils.notebook.exit("ephemeral dict")',
        )

        assert res.success, f"Execution failed: {res.error}"
        payload = res.to_dict()
        assert "workspace_path" not in payload


@pytest.mark.integration
class TestServerlessPersistentMode:
    """Persistent mode (workspace_path provided, notebook kept)."""

    @pytest.fixture(autouse=True)
    def _setup_cleanup(self):
        """Collect workspace paths and best-effort delete them after each test."""
        self._paths_to_cleanup = []
        yield
        # Cleanup persisted notebooks
        try:
            client = get_workspace_client()
            for path in self._paths_to_cleanup:
                try:
                    client.workspace.delete(path=path, recursive=False)
                    logger.info(f"Cleaned up: {path}")
                except Exception:
                    pass
        except Exception:
            pass

    def test_persistent_saves_notebook(self):
        """Persistent mode keeps the notebook at workspace_path."""
        username = get_current_username()
        ws_path = f"/Workspace/Users/{username}/.ai_dev_kit_test/persistent_test"
        self._paths_to_cleanup.append(ws_path)

        res = run_code_on_serverless(
            code='dbutils.notebook.exit("persisted!")',
            workspace_path=ws_path,
        )

        logger.info(f"Persistent result: success={res.success}, "
                    f"workspace_path={res.workspace_path}")

        assert res.success, f"Execution failed: {res.error}"
        assert res.workspace_path == ws_path
        assert "persisted!" in res.output

        # The notebook must still exist in the workspace
        client = get_workspace_client()
        status = client.workspace.get_status(ws_path)
        assert status is not None

    def test_persistent_to_dict_includes_workspace_path(self):
        """Persistent mode exposes workspace_path in the dict view."""
        username = get_current_username()
        ws_path = f"/Workspace/Users/{username}/.ai_dev_kit_test/persistent_dict_test"
        self._paths_to_cleanup.append(ws_path)

        res = run_code_on_serverless(
            code='dbutils.notebook.exit("dict test")',
            workspace_path=ws_path,
        )

        assert res.success, f"Execution failed: {res.error}"
        payload = res.to_dict()
        assert payload["workspace_path"] == ws_path