@mcp.tool
def query_lakebase(
    sql_query: str,
    instance_name: Optional[str] = None,
    endpoint: Optional[str] = None,
    database: str = "databricks_postgres",
    timeout: int = 60,
) -> Dict[str, Any]:
    """
    Run a SQL statement on a Lakebase PostgreSQL instance.

    Works with both Provisioned and Autoscale Lakebase. Pass exactly one of
    instance_name (provisioned) or endpoint (autoscale); the core layer
    resolves connection details and executes the statement.

    Args:
        sql_query: SQL statement to run (SELECT, INSERT, UPDATE, DELETE, ...)
        instance_name: Name of a Provisioned Lakebase instance (e.g., "my-instance")
        endpoint: Autoscale endpoint - either the full resource name
            (e.g., "projects/xxx/branches/yyy/endpoints/zzz") or just the host
            (e.g., "ep-xxx.database.eastus2.azuredatabricks.net")
        database: PostgreSQL database name (default: "databricks_postgres")
        timeout: Query timeout in seconds (default: 60)

    Returns:
        On success, the core result dictionary:
        - columns: list of column names
        - data: list of rows (each row a list of values)
        - row_count: number of rows returned
        - type: "provisioned" or "autoscale"
        - target: the instance_name or endpoint that was used
        On failure, {"error": <message>, "target": <instance/endpoint/"unknown">}
        so the MCP client receives a structured error instead of an exception.

    Examples:
        # Provisioned instance
        >>> query_lakebase("SELECT * FROM users", instance_name="my-instance")

        # Autoscale endpoint (by host)
        >>> query_lakebase("SELECT 1", endpoint="ep-xxx.database.eastus2.azuredatabricks.net")

        # Autoscale endpoint (by resource name)
        >>> query_lakebase("SELECT 1", endpoint="projects/abc/branches/main/endpoints/primary")
    """
    try:
        return _execute_query(
            sql_query=sql_query,
            instance_name=instance_name,
            endpoint=endpoint,
            database=database,
            timeout=timeout,
        )
    except Exception as exc:
        # MCP tools report failures as data, not raised exceptions.
        return {
            "error": str(exc),
            "target": instance_name or endpoint or "unknown",
        }
class LakebaseQueryError(Exception):
    """Exception raised when Lakebase query execution fails."""


def execute_lakebase_query(
    sql_query: str,
    instance_name: Optional[str] = None,
    endpoint: Optional[str] = None,
    database: str = "databricks_postgres",
    timeout: int = 60,
) -> Dict[str, Any]:
    """
    Execute a SQL query against a Lakebase PostgreSQL instance.

    Supports both Provisioned and Autoscale Lakebase instances.
    Provide either instance_name (for provisioned) or endpoint (for autoscale).

    Args:
        sql_query: SQL query to execute
        instance_name: Name of a Provisioned Lakebase instance (e.g., "my-instance")
        endpoint: Autoscale endpoint - either full resource name
            (e.g., "projects/xxx/branches/yyy/endpoints/zzz") or just the host
            (e.g., "ep-xxx.database.eastus2.azuredatabricks.net")
        database: PostgreSQL database name (default: "databricks_postgres")
        timeout: Timeout in seconds, applied both to connection setup and to
            statement execution (via PostgreSQL statement_timeout).

    Returns:
        Dictionary with:
        - columns: List of column names
        - data: List of rows (each row is a list of values)
        - row_count: Number of rows returned
        - type: "provisioned" or "autoscale"
        - target: instance_name or endpoint used
        - database: the database the query ran against

    Raises:
        LakebaseQueryError: If arguments are invalid, psycopg2 is missing,
            connection details cannot be resolved, or query execution fails.
    """
    # Validate arguments before touching the optional psycopg2 dependency so
    # callers get the most specific error first.
    if not instance_name and not endpoint:
        raise LakebaseQueryError(
            "Provide either instance_name (for provisioned) or endpoint (for autoscale)"
        )

    try:
        import psycopg2
    except ImportError as e:
        raise LakebaseQueryError(
            "psycopg2 is not installed. Install it with: pip install psycopg2-binary"
        ) from e

    # Determine instance type and get connection details. These helpers raise
    # LakebaseQueryError themselves, so they stay outside the try below to
    # avoid double-wrapping.
    if instance_name:
        host, token, username, instance_type = _get_provisioned_connection(instance_name)
        target = instance_name
    else:
        host, token, username, instance_type = _get_autoscale_connection(endpoint)
        target = endpoint

    conn = None
    try:
        conn = psycopg2.connect(
            host=host,
            port=5432,
            dbname=database,
            user=username,
            password=token,
            sslmode="require",
            connect_timeout=timeout,
            # BUG FIX: the original applied `timeout` only to connection
            # setup, so a slow query could run forever. statement_timeout
            # (milliseconds) bounds execution of each statement as well.
            options=f"-c statement_timeout={int(timeout) * 1000}",
        )
        # Autocommit so DDL/DML take effect without an explicit commit.
        conn.set_session(readonly=False, autocommit=True)

        with conn.cursor() as cur:
            cur.execute(sql_query)

            # cur.description is set only for result-returning statements
            # (SELECT or RETURNING); plain DML/DDL yields no rows.
            columns = [desc[0] for desc in cur.description] if cur.description else []
            data = [list(row) for row in cur.fetchall()] if cur.description else []

            return {
                "columns": columns,
                "data": data,
                "row_count": len(data),
                "type": instance_type,
                "target": target,
                "database": database,
            }

    except psycopg2.Error as e:
        raise LakebaseQueryError(f"PostgreSQL error: {e}") from e
    except Exception as e:
        raise LakebaseQueryError(f"Query execution failed: {e}") from e
    finally:
        if conn:
            conn.close()


def _get_provisioned_connection(instance_name: str) -> tuple:
    """
    Resolve (host, token, username, "provisioned") for a Provisioned instance.

    Raises:
        LakebaseQueryError: If the instance is missing, stopped, has no
            read/write DNS endpoint, or credential generation fails.
    """
    from .instances import get_lakebase_instance, generate_lakebase_credential
    from ..auth import get_workspace_client

    # Get instance details
    instance = get_lakebase_instance(instance_name)
    if instance.get("state") == "NOT_FOUND":
        raise LakebaseQueryError(f"Provisioned instance '{instance_name}' not found")

    host = instance.get("read_write_dns")
    if not host:
        raise LakebaseQueryError(
            f"Instance '{instance_name}' does not have a read_write_dns endpoint. "
            f"State: {instance.get('state')}"
        )

    # A stopped instance cannot accept connections; tell the caller how to fix it.
    state = str(instance.get("state", ""))
    if "STOPPED" in state:
        raise LakebaseQueryError(
            f"Instance '{instance_name}' is stopped. Start it first with "
            f"update_lakebase_instance('{instance_name}', stopped=False)"
        )

    # Generate a short-lived OAuth credential used as the PostgreSQL password.
    cred = generate_lakebase_credential(instance_names=[instance_name])
    token = cred.get("token")
    if not token:
        raise LakebaseQueryError("Failed to generate OAuth token for provisioned instance")

    # PostgreSQL username is the workspace user; fall back to a generic role
    # (best-effort) when the current-user lookup fails.
    client = get_workspace_client()
    try:
        username = client.current_user.me().user_name
    except Exception:
        username = "databricks"

    return host, token, username, "provisioned"
def _get_autoscale_connection(endpoint: str) -> tuple:
    """
    Resolve (host, token, username, "autoscale") for an Autoscale endpoint.

    Accepts either a full resource name ("projects/.../endpoints/...") or a
    bare host ("ep-xxx.database.<region>.azuredatabricks.net"); a bare host is
    mapped back to its resource name by scanning all projects/branches.

    Raises:
        LakebaseQueryError: If the endpoint cannot be resolved or credential
            generation fails.
    """
    # NOTE: list_projects/list_branches/list_endpoints were previously
    # imported here but never used in this function (only in
    # _find_endpoint_by_host); the unused imports are removed.
    from ..lakebase_autoscale import generate_credential, get_endpoint
    from ..auth import get_workspace_client

    if endpoint.startswith("projects/"):
        # Full resource name - look the host up from endpoint details.
        ep_info = get_endpoint(endpoint)
        host = ep_info.get("host")
        endpoint_name = endpoint
        if not host:
            raise LakebaseQueryError(f"Endpoint '{endpoint}' does not have a host")
    elif ".database." in endpoint and ".azuredatabricks.net" in endpoint:
        # It's a host - find the full endpoint resource name by searching.
        host = endpoint
        endpoint_name = _find_endpoint_by_host(host)
        if not endpoint_name:
            raise LakebaseQueryError(
                f"Could not find autoscale endpoint with host '{endpoint}'. "
                "Try providing the full resource name instead."
            )
    else:
        raise LakebaseQueryError(
            f"Invalid endpoint format: '{endpoint}'. Provide either a host "
            "(e.g., 'ep-xxx.database.eastus2.azuredatabricks.net') or full resource name "
            "(e.g., 'projects/xxx/branches/yyy/endpoints/zzz')"
        )

    # Generate a short-lived OAuth credential used as the PostgreSQL password.
    cred = generate_credential(endpoint=endpoint_name)
    token = cred.get("token")
    if not token:
        raise LakebaseQueryError("Failed to generate OAuth token for autoscale endpoint")

    # PostgreSQL username is the workspace user; fall back to a generic role
    # (best-effort) when the current-user lookup fails.
    client = get_workspace_client()
    try:
        username = client.current_user.me().user_name
    except Exception:
        username = "databricks"

    return host, token, username, "autoscale"


def _find_endpoint_by_host(target_host: str) -> Optional[str]:
    """
    Find the full endpoint resource name whose host matches target_host.

    Scans every project -> branch -> endpoint; returns the endpoint's resource
    name on a match, or None if nothing matches or the search fails (failures
    are logged, not raised, so the caller can produce its own error).
    """
    from ..lakebase_autoscale import list_projects, list_branches, list_endpoints

    try:
        for project in list_projects():
            project_name = project.get("name", "")
            # Resource names look like "projects/<id>"; keep the last segment.
            project_id = project_name.split("/")[-1] if "/" in project_name else project_name

            for branch in list_branches(project_id):
                branch_name = branch.get("name", "")

                for ep in list_endpoints(branch_name):
                    if ep.get("host", "") == target_host:
                        return ep.get("name")
    except Exception as e:
        # Lazy %-formatting so the message is only built when the record is emitted.
        logger.warning("Error searching for endpoint by host: %s", e)

    return None