diff --git a/README.md b/README.md
index 1d88f37b..6d1e0e9c 100644
--- a/README.md
+++ b/README.md
@@ -81,7 +81,7 @@ yarn dev:frontend
**Terminal 2 - Backend:**
```bash
cd src
-yarn dev:backend
+hatch -e dev run dev-backend
```
- Frontend: http://localhost:3000
@@ -171,5 +171,4 @@ This project is licensed under the Databricks License - see the [LICENSE.txt](LI
---
-**Version**: 0.4.0
**Maintained by**: [Databricks](https://databricks.com)
diff --git a/src/backend/.env.example b/src/backend/.env.example
index 4ba34745..89234f11 100644
--- a/src/backend/.env.example
+++ b/src/backend/.env.example
@@ -98,6 +98,15 @@ LLM_ENABLED=False
# Security: First-phase injection detection prompt (advanced - usually not changed)
# LLM_INJECTION_CHECK_PROMPT="You are a security analyzer. Analyze the following content..."
+# --- Graph Explorer Safety Limits ---
+# Maximum nodes/edges returned from initial graph load (prevents browser OOM)
+# GRAPH_MAX_NODES=5000
+# GRAPH_MAX_EDGES=10000
+# Default max neighbors per expansion
+# GRAPH_NEIGHBOR_LIMIT=50
+# SQL statement timeout for graph queries
+# GRAPH_QUERY_TIMEOUT=30s
+
# --- Self-Service Sandbox Policy Settings ---
# These settings control which catalogs and schemas users can create objects in
# via the self-service dialog. This is a global security boundary separate from
diff --git a/src/backend/src/app.py b/src/backend/src/app.py
index 5552e3b4..51301771 100644
--- a/src/backend/src/app.py
+++ b/src/backend/src/app.py
@@ -1,12 +1,12 @@
# Initialize configuration and logging first
from src.common.config import get_settings, init_config
-from src.common.logging import setup_logging, get_logger
+from src.common.logging import get_logger, setup_logging
+
init_config()
settings = get_settings()
setup_logging(level=settings.LOG_LEVEL, log_file=settings.LOG_FILE)
logger = get_logger(__name__)
-import mimetypes
import os
import time
from pathlib import Path
@@ -14,21 +14,22 @@
# Server startup timestamp for cache invalidation
SERVER_STARTUP_TIME = int(time.time())
-from fastapi import Depends, FastAPI, Request
+from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
-from starlette.responses import Response
-from fastapi import HTTPException, status
from src.common.middleware import ErrorHandlingMiddleware, LoggingMiddleware
from src.routes import (
access_grants_routes,
+ audit_routes,
catalog_commander_routes,
- data_catalog_routes,
- compliance_routes,
+ change_log_routes,
comments_routes,
+ compliance_routes,
+ costs_routes,
data_asset_reviews_routes,
+ data_catalog_routes,
data_contracts_routes,
data_domains_routes,
data_product_routes,
@@ -36,7 +37,6 @@
entitlements_routes,
entitlements_sync_routes,
estate_manager_routes,
- industry_ontology_routes,
jobs_routes,
llm_search_routes,
mcp_routes,
@@ -44,40 +44,21 @@
mdm_routes,
metadata_routes,
notifications_routes,
+ projects_routes,
search_routes,
security_features_routes,
self_service_routes,
- settings_routes,
- semantic_models_routes,
semantic_links_routes,
- user_routes,
- audit_routes,
- change_log_routes,
- workspace_routes,
+ semantic_models_routes,
+ settings_routes,
tags_routes,
teams_routes,
- projects_routes,
- costs_routes,
+ user_routes,
+ graph_explorer_routes,
workflows_routes,
+ workspace_routes,
)
-
-from src.common.database import init_db, get_session_factory, SQLAlchemySession
-from src.controller.data_products_manager import DataProductsManager
-from src.controller.data_asset_reviews_manager import DataAssetReviewManager
-from src.controller.data_contracts_manager import DataContractsManager
-from src.controller.semantic_models_manager import SemanticModelsManager
-from src.controller.search_manager import SearchManager
-from src.common.workspace_client import get_workspace_client
-from src.controller.settings_manager import SettingsManager
-from src.controller.users_manager import UsersManager
-from src.controller.authorization_manager import AuthorizationManager
-from src.utils.startup_tasks import (
- initialize_database,
- initialize_managers,
- startup_event_handler,
- shutdown_event_handler
-)
-
+from src.utils.startup_tasks import initialize_database, initialize_managers
logger.info(f"Starting application in {settings.ENV} mode.")
logger.info(f"Debug mode: {settings.DEBUG}")
@@ -94,28 +75,32 @@
# --- Application Lifecycle Events ---
+
# Application Startup Event
async def startup_event():
import os
-
+
# Skip startup tasks if running tests
- if os.getenv('SKIP_STARTUP_TASKS') == 'true':
+ if os.getenv("SKIP_STARTUP_TASKS") == "true":
logger.info("SKIP_STARTUP_TASKS=true detected - skipping startup tasks (test mode)")
return
-
+
logger.info("Running application startup event...")
settings = get_settings()
-
+
initialize_database(settings=settings)
initialize_managers(app) # Handles DB-backed manager init
-
+
# Initialize Git service for indirect delivery mode
try:
logger.info("Initializing Git service...")
from src.common.git import init_git_service
+
git_service = init_git_service(settings)
app.state.git_service = git_service
- logger.info(f"Git service initialized (status: {git_service.get_status().clone_status.value})")
+ logger.info(
+ f"Git service initialized (status: {git_service.get_status().clone_status.value})"
+ )
except Exception as e:
logger.warning(f"Failed initializing Git service: {e}", exc_info=True)
app.state.git_service = None
@@ -123,8 +108,9 @@ async def startup_event():
# Initialize Grant Manager for direct delivery mode
try:
logger.info("Initializing Grant Manager...")
- from src.controller.grant_manager import init_grant_manager
from src.common.workspace_client import get_workspace_client
+ from src.controller.grant_manager import init_grant_manager
+
ws_client = get_workspace_client(settings=settings)
grant_manager = init_grant_manager(ws_client=ws_client, settings=settings)
app.state.grant_manager = grant_manager
@@ -137,30 +123,36 @@ async def startup_event():
try:
logger.info("Initializing Delivery Service...")
from src.controller.delivery_service import init_delivery_service
+
delivery_service = init_delivery_service(
settings=settings,
- git_service=getattr(app.state, 'git_service', None),
- grant_manager=getattr(app.state, 'grant_manager', None),
- notifications_manager=getattr(app.state, 'notifications_manager', None),
+ git_service=getattr(app.state, "git_service", None),
+ grant_manager=getattr(app.state, "grant_manager", None),
+ notifications_manager=getattr(app.state, "notifications_manager", None),
)
app.state.delivery_service = delivery_service
- logger.info(f"Delivery Service initialized (active modes: {[m.value for m in delivery_service.get_active_modes()]})")
+ logger.info(
+ f"Delivery Service initialized (active modes: {[m.value for m in delivery_service.get_active_modes()]})"
+ )
except Exception as e:
logger.warning(f"Failed initializing Delivery Service: {e}", exc_info=True)
app.state.delivery_service = None
-
+
# Demo data is loaded on-demand via POST /api/settings/demo-data/load
# See: src/backend/src/data/demo_data.sql
-
+
# Ensure SearchManager is initialized and index built
try:
from src.common.search_interfaces import SearchableAsset
from src.controller.search_manager import SearchManager
+
logger.info("Initializing SearchManager after data load (app.py)...")
searchable_managers_instances = []
- for attr_name, manager_instance in list(getattr(app.state, '_state', {}).items()):
+ for attr_name, manager_instance in list(getattr(app.state, "_state", {}).items()):
try:
- if isinstance(manager_instance, SearchableAsset) and hasattr(manager_instance, 'get_search_index_items'):
+ if isinstance(manager_instance, SearchableAsset) and hasattr(
+ manager_instance, "get_search_index_items"
+ ):
searchable_managers_instances.append(manager_instance)
except Exception:
continue
@@ -172,11 +164,13 @@ async def startup_event():
logger.info("Application startup complete.")
+
# Application Shutdown Event
async def shutdown_event():
logger.info("Running application shutdown event...")
logger.info("Application shutdown complete.")
+
# --- FastAPI App Instantiation (AFTER defining lifecycle functions) ---
# Define paths
@@ -202,29 +196,23 @@ async def shutdown_event():
{"name": "Datasets", "description": "Manage datasets and dataset instances"},
{"name": "Data Contracts", "description": "Manage data contracts for data products"},
{"name": "Data Products", "description": "Manage data products and subscriptions"},
-
# Governance - Standards and approval workflows
{"name": "Compliance", "description": "Manage compliance policies and runs"},
{"name": "Approvals", "description": "Manage approval workflows"},
{"name": "Process Workflows", "description": "Manage process workflows"},
{"name": "Data Asset Reviews", "description": "Manage data asset review workflows"},
-
# Business Glossary - Semantic models and ontologies
{"name": "Semantic Models", "description": "Manage semantic models and ontologies"},
{"name": "Semantic Links", "description": "Manage semantic links between entities"},
- {"name": "Industry Ontologies", "description": "Industry Ontology Library for importing standard ontologies"},
-
# Operations - Monitoring and technical management
{"name": "Estates", "description": "Manage data estates"},
{"name": "Master Data Management", "description": "Master data management features"},
{"name": "Catalog Commander", "description": "Dual-pane catalog explorer"},
-
# Security - Access control and security features
{"name": "Security Features", "description": "Advanced security features"},
{"name": "Entitlements", "description": "Manage entitlements and personas"},
{"name": "Entitlements Sync", "description": "Sync entitlements from external sources"},
{"name": "Access Grants", "description": "Manage time-limited access grants"},
-
# System - Utilities, configuration, auxiliary services
{"name": "Metadata", "description": "Manage metadata attachments"},
{"name": "Workspace", "description": "Workspace asset operations"},
@@ -250,7 +238,7 @@ async def shutdown_event():
dependencies=[Depends(get_settings)],
on_startup=[startup_event],
on_shutdown=[shutdown_event],
- openapi_tags=openapi_tags
+ openapi_tags=openapi_tags,
)
# Configure CORS
@@ -278,7 +266,7 @@ async def shutdown_event():
app.add_middleware(LoggingMiddleware)
# Mount static files for the React application (skip in test mode)
-if not os.environ.get('TESTING'):
+if not os.environ.get("TESTING"):
app.mount("/static", StaticFiles(directory=STATIC_ASSETS_PATH, html=True), name="static")
# Data Products - Core data lifecycle
@@ -291,12 +279,12 @@ async def shutdown_event():
data_contracts_routes.register_routes(app)
data_product_routes.register_routes(app)
from src.routes import approvals_routes
+
approvals_routes.register_routes(app)
# Governance - Standards and approval workflows
semantic_models_routes.register_routes(app)
semantic_links_routes.register_routes(app)
-industry_ontology_routes.register_routes(app) # Industry Ontology Library
data_asset_reviews_routes.register_routes(app)
data_catalog_routes.register_routes(app)
@@ -326,50 +314,61 @@ async def shutdown_event():
mcp_routes.register_routes(app)
mcp_tokens_routes.register_routes(app)
self_service_routes.register_routes(app)
+graph_explorer_routes.register_routes(app)
workflows_routes.register_routes(app)
settings_routes.register_routes(app)
+
# Define other specific API routes BEFORE the catch-all
@app.get("/api/time")
async def get_current_time():
"""Get the current time (for testing purposes mostly)"""
- return {'time': time.time()}
+ return {"time": time.time()}
+
@app.get("/api/cache-version")
async def get_cache_version():
"""Get the server cache version for client-side cache invalidation"""
- return {'version': SERVER_STARTUP_TIME, 'timestamp': int(time.time())}
+ return {"version": SERVER_STARTUP_TIME, "timestamp": int(time.time())}
+
@app.get("/api/version")
async def get_app_version():
"""Get the application version and server start time"""
- return {
- 'version': __version__,
- 'startTime': SERVER_STARTUP_TIME,
- 'timestamp': int(time.time())
- }
+ return {"version": __version__, "startTime": SERVER_STARTUP_TIME, "timestamp": int(time.time())}
+
# Define the SPA catch-all route LAST (skip in test mode)
-if not os.environ.get('TESTING'):
+if not os.environ.get("TESTING"):
+
@app.get("/{full_path:path}")
def serve_spa(full_path: str):
# Only catch routes that aren't API routes, static files, or API docs
# This check might be redundant now due to ordering, but safe to keep
- if not full_path.startswith("api/") and not full_path.startswith("static/") and full_path not in ["docs", "redoc", "openapi.json"]:
+ if (
+ not full_path.startswith("api/")
+ and not full_path.startswith("static/")
+ and full_path not in ["docs", "redoc", "openapi.json"]
+ ):
# Ensure the path exists before serving
spa_index = STATIC_ASSETS_PATH / "index.html"
if spa_index.is_file():
- return FileResponse(spa_index, media_type="text/html")
+ return FileResponse(spa_index, media_type="text/html")
else:
- # Optional: Return a 404 or a simple HTML message if index.html is missing
- logger.error(f"SPA index.html not found at {spa_index}")
-                return HTMLResponse(content="Frontend not built or index.html missing.", status_code=404)
+ # Optional: Return a 404 or a simple HTML message if index.html is missing
+ logger.error(f"SPA index.html not found at {spa_index}")
+ return HTMLResponse(
+ content="Frontend not built or index.html missing.",
+ status_code=404,
+ )
# If it starts with api/ or static/ but wasn't handled by a router/StaticFiles,
# FastAPI will return its default 404 Not Found, which is correct.
# No explicit return needed here for that case.
+
logger.info("All routes registered.")
-if __name__ == '__main__':
+if __name__ == "__main__":
import uvicorn
+
uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/src/backend/src/common/config.py b/src/backend/src/common/config.py
index f8e7ed5f..a1c84b38 100644
--- a/src/backend/src/common/config.py
+++ b/src/backend/src/common/config.py
@@ -138,6 +138,12 @@ class Settings(BaseSettings):
env='LLM_INJECTION_CHECK_PROMPT'
)
+ # Graph Explorer safety limits
+ GRAPH_MAX_NODES: int = Field(5000, env='GRAPH_MAX_NODES') # Max nodes returned from initial graph load
+ GRAPH_MAX_EDGES: int = Field(10000, env='GRAPH_MAX_EDGES') # Max edges returned from initial graph load
+ GRAPH_QUERY_TIMEOUT: str = Field("30s", env='GRAPH_QUERY_TIMEOUT') # SQL statement timeout
+ GRAPH_NEIGHBOR_LIMIT: int = Field(50, env='GRAPH_NEIGHBOR_LIMIT') # Default max neighbors per expansion
+
# Sandbox allowlist settings
sandbox_default_schema: str = Field('sandbox', validation_alias=AliasChoices('SANDBOX_DEFAULT_SCHEMA', 'sandbox_default_schema'))
sandbox_allowed_catalog_prefixes: List[str] = Field(default_factory=lambda: ['user_'], validation_alias=AliasChoices('SANDBOX_ALLOWED_CATALOG_PREFIXES', 'sandbox_allowed_catalog_prefixes'))
diff --git a/src/backend/src/common/database.py b/src/backend/src/common/database.py
index d7f05e81..382569dc 100644
--- a/src/backend/src/common/database.py
+++ b/src/backend/src/common/database.py
@@ -1,38 +1,35 @@
import os
-import uuid
-import time
import threading
+import time
+import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
-from typing import Any, Dict, List, Optional, TypeVar
-
-from sqlalchemy import create_engine, text, event
-from sqlalchemy.orm import sessionmaker, Session as SQLAlchemySession
-from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy import pool
-from sqlalchemy.engine import Connection, URL
+from typing import Any, TypeVar
from alembic.config import Config as AlembicConfig
-from alembic.script import ScriptDirectory
from alembic.runtime.migration import MigrationContext
-from alembic import command as alembic_command
+from alembic.script import ScriptDirectory
+
+# Import SDK components
+from sqlalchemy import create_engine, event, pool, text
+from sqlalchemy.engine import URL, Connection
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker
-from .config import get_settings, Settings
-from .logging import get_logger
-from src.common.workspace_client import get_workspace_client
from src.common.unity_catalog_utils import (
ensure_catalog_exists,
ensure_schema_exists,
sanitize_postgres_identifier,
)
-# Import SDK components
-from databricks.sdk.errors import NotFound, DatabricksError
-from databricks.sdk.core import Config, oauth_service_principal
+from src.common.workspace_client import get_workspace_client
+
+from .config import Settings, get_settings
+from .logging import get_logger
logger = get_logger(__name__)
-T = TypeVar('T')
+T = TypeVar("T")
# Define the base class for SQLAlchemy models
Base = declarative_base()
@@ -41,40 +38,52 @@
# This ensures Base.metadata is populated before init_db needs it.
logger.debug("Importing all DB model modules to register with Base...")
try:
- from src.db_models import settings as settings_db
- from src.db_models import audit_log
- from src.db_models import data_asset_reviews
- from src.db_models import data_products
- from src.db_models import notifications
- from src.db_models import data_domains
- from src.db_models import semantic_links
- from src.db_models import metadata as metadata_db
- from src.db_models import semantic_models
- from src.db_models import comments
- from src.db_models import costs
- from src.db_models import change_log
# Add missing model imports for Alembic
- from src.db_models import data_contracts
- from src.db_models import workflow_configurations
- from src.db_models import workflow_installations
- from src.db_models import workflow_job_runs
- from src.db_models import compliance
- from src.db_models import projects
- from src.db_models import teams
- from src.db_models import mdm # MDM models for master data management
- # from src.db_models.data_products import DataProductDb, InfoDb, InputPortDb, OutputPortDb # Already imported via module import above
- from src.db_models.settings import AppRoleDb
+ from src.db_models import (
+ audit_log,
+ change_log,
+ comments,
+ compliance,
+ costs,
+ data_asset_reviews,
+ data_contracts,
+ data_domains,
+ data_products,
+ mdm, # MDM models for master data management
+ notifications,
+ projects,
+ semantic_links,
+ semantic_models,
+ teams,
+ workflow_configurations,
+ workflow_installations,
+ workflow_job_runs,
+ )
+ from src.db_models import metadata as metadata_db
+ from src.db_models import settings as settings_db
+
# from src.db_models.users import UserActivityDb, UserSearchHistoryDb # Commented out due to missing file
from src.db_models.audit_log import AuditLogDb
from src.db_models.notifications import NotificationDb
+
+ # from src.db_models.data_products import DataProductDb, InfoDb, InputPortDb, OutputPortDb # Already imported via module import above
+ from src.db_models.settings import AppRoleDb
+
# from src.db_models.business_glossary import GlossaryDb, TermDb, CategoryDb, term_category_association, term_related_terms, term_asset_association # Commented out due to missing file
# Add new tag models
- from src.db_models.tags import TagDb, TagNamespaceDb, TagNamespacePermissionDb, EntityTagAssociationDb
+ from src.db_models.tags import (
+ EntityTagAssociationDb,
+ TagDb,
+ TagNamespaceDb,
+ TagNamespacePermissionDb,
+ )
+
# Add imports for any other future model modules here
logger.debug("DB model modules imported successfully.")
except ImportError as e:
logger.critical(
- f"Failed to import a DB model module during initial registration: {e}", exc_info=True)
+ f"Failed to import a DB model module during initial registration: {e}", exc_info=True
+ )
# This is likely a fatal error, consider raising or exiting
raise
# ------------------------------------------------------------------------- #
@@ -86,20 +95,20 @@
engine = None
# OAuth token state for Lakebase connections
-_oauth_token: Optional[str] = None
+_oauth_token: str | None = None
_token_last_refresh: float = 0
_token_refresh_lock = threading.Lock()
-_token_refresh_thread: Optional[threading.Thread] = None
+_token_refresh_thread: threading.Thread | None = None
_token_refresh_stop_event = threading.Event()
-def get_lakebase_instance_name(app_name: str, ws_client) -> Optional[str]:
+def get_lakebase_instance_name(app_name: str, ws_client) -> str | None:
"""Get the Lakebase instance name from the Databricks App resources.
-
+
Args:
app_name: Name of the Databricks App
ws_client: Workspace client instance
-
+
Returns:
The database instance name, or None if not found
"""
@@ -117,7 +126,8 @@ def get_lakebase_instance_name(app_name: str, ws_client) -> Optional[str]:
@dataclass
class InMemorySession:
"""In-memory session for managing transactions."""
- changes: List[Dict[str, Any]]
+
+ changes: list[dict[str, Any]]
def __init__(self):
self.changes = []
@@ -135,10 +145,10 @@ class InMemoryStore:
def __init__(self):
"""Initialize the in-memory store."""
- self._data: Dict[str, List[Dict[str, Any]]] = {}
- self._metadata: Dict[str, Dict[str, Any]] = {}
+ self._data: dict[str, list[dict[str, Any]]] = {}
+ self._metadata: dict[str, dict[str, Any]] = {}
- def create_table(self, table_name: str, metadata: Dict[str, Any] = None) -> None:
+ def create_table(self, table_name: str, metadata: dict[str, Any] = None) -> None:
"""Create a new table in the store.
Args:
@@ -150,7 +160,7 @@ def create_table(self, table_name: str, metadata: Dict[str, Any] = None) -> None
if metadata:
self._metadata[table_name] = metadata
- def insert(self, table_name: str, data: Dict[str, Any]) -> None:
+ def insert(self, table_name: str, data: dict[str, Any]) -> None:
"""Insert a record into a table.
Args:
@@ -161,16 +171,16 @@ def insert(self, table_name: str, data: Dict[str, Any]) -> None:
self.create_table(table_name)
# Add timestamp and id if not present
- if 'id' not in data:
- data['id'] = str(len(self._data[table_name]) + 1)
- if 'created_at' not in data:
- data['created_at'] = datetime.utcnow().isoformat()
- if 'updated_at' not in data:
- data['updated_at'] = data['created_at']
+ if "id" not in data:
+ data["id"] = str(len(self._data[table_name]) + 1)
+ if "created_at" not in data:
+ data["created_at"] = datetime.utcnow().isoformat()
+ if "updated_at" not in data:
+ data["updated_at"] = data["created_at"]
self._data[table_name].append(data)
- def get(self, table_name: str, id: str) -> Optional[Dict[str, Any]]:
+ def get(self, table_name: str, id: str) -> dict[str, Any] | None:
"""Get a record by ID.
Args:
@@ -182,9 +192,9 @@ def get(self, table_name: str, id: str) -> Optional[Dict[str, Any]]:
"""
if table_name not in self._data:
return None
- return next((item for item in self._data[table_name] if item['id'] == id), None)
+ return next((item for item in self._data[table_name] if item["id"] == id), None)
- def get_all(self, table_name: str) -> List[Dict[str, Any]]:
+ def get_all(self, table_name: str) -> list[dict[str, Any]]:
"""Get all records from a table.
Args:
@@ -195,7 +205,7 @@ def get_all(self, table_name: str) -> List[Dict[str, Any]]:
"""
return self._data.get(table_name, [])
- def update(self, table_name: str, id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+ def update(self, table_name: str, id: str, data: dict[str, Any]) -> dict[str, Any] | None:
"""Update a record.
Args:
@@ -210,9 +220,9 @@ def update(self, table_name: str, id: str, data: Dict[str, Any]) -> Optional[Dic
return None
for item in self._data[table_name]:
- if item['id'] == id:
+ if item["id"] == id:
item.update(data)
- item['updated_at'] = datetime.utcnow().isoformat()
+ item["updated_at"] = datetime.utcnow().isoformat()
return item
return None
@@ -230,8 +240,7 @@ def delete(self, table_name: str, id: str) -> bool:
return False
initial_length = len(self._data[table_name])
- self._data[table_name] = [
- item for item in self._data[table_name] if item['id'] != id]
+ self._data[table_name] = [item for item in self._data[table_name] if item["id"] != id]
return len(self._data[table_name]) < initial_length
def clear(self, table_name: str) -> None:
@@ -276,37 +285,39 @@ def dispose(self) -> None:
# Global database manager instance
-db_manager: Optional[DatabaseManager] = None
+db_manager: DatabaseManager | None = None
def refresh_oauth_token(settings: Settings) -> str:
"""Generate fresh OAuth token from Databricks for Lakebase connection."""
global _oauth_token, _token_last_refresh
-
+
with _token_refresh_lock:
ws_client = get_workspace_client(settings)
instance_name = get_lakebase_instance_name(settings.DATABRICKS_APP_NAME, ws_client)
-
+
if not instance_name:
- raise ValueError(f"Could not determine Lakebase instance name for app '{settings.DATABRICKS_APP_NAME}'")
-
+ raise ValueError(
+ f"Could not determine Lakebase instance name for app '{settings.DATABRICKS_APP_NAME}'"
+ )
+
logger.info(f"Generating OAuth token for Lakebase instance: {instance_name}")
cred = ws_client.database.generate_database_credential(
request_id=str(uuid.uuid4()),
instance_names=[instance_name],
)
-
+
_oauth_token = cred.token
_token_last_refresh = time.time()
logger.info("OAuth token refreshed successfully")
-
+
return _oauth_token
def start_token_refresh_background(settings: Settings):
"""Start background thread to refresh OAuth tokens every 50 minutes."""
global _token_refresh_thread, _token_refresh_stop_event
-
+
def refresh_loop():
while not _token_refresh_stop_event.is_set():
_token_refresh_stop_event.wait(50 * 60) # 50 minutes
@@ -315,7 +326,7 @@ def refresh_loop():
refresh_oauth_token(settings)
except Exception as e:
logger.error(f"Background token refresh failed: {e}", exc_info=True)
-
+
_token_refresh_stop_event.clear()
_token_refresh_thread = threading.Thread(target=refresh_loop, daemon=True)
_token_refresh_thread.start()
@@ -333,14 +344,16 @@ def stop_token_refresh_background():
def get_db_url(settings: Settings) -> str:
"""Construct the PostgreSQL SQLAlchemy URL with appropriate auth method."""
-
+
# Validate required settings
if not all([settings.PGHOST, settings.PGDATABASE]):
- raise ValueError("PostgreSQL connection details (PGHOST, PGDATABASE) are missing in settings.")
-
+ raise ValueError(
+ "PostgreSQL connection details (PGHOST, PGDATABASE) are missing in settings."
+ )
+
# Determine authentication mode based on ENV
use_password_auth = settings.ENV.upper().startswith("LOCAL")
-
+
if use_password_auth:
logger.info("Database: Using password authentication (LOCAL mode)")
if not settings.PGPASSWORD or not settings.PGUSER:
@@ -351,20 +364,17 @@ def get_db_url(settings: Settings) -> str:
logger.info("Database: Using OAuth authentication (Lakebase mode)")
# Dynamically determine username from authenticated principal
ws_client = get_workspace_client(settings)
- username = (
- os.getenv("DATABRICKS_CLIENT_ID")
- or ws_client.current_user.me().user_name
- )
+ username = os.getenv("DATABRICKS_CLIENT_ID") or ws_client.current_user.me().user_name
if not username:
raise ValueError("Could not determine database username from authenticated principal")
-
+
logger.info(f"🔑 Detected service principal username: {username}")
password = "" # Will be set via event handler
-
+
# Build URL with schema options and statement timeout
query_params = {}
options_list = []
-
+
# Add schema to search_path if specified
if settings.PGSCHEMA:
# Validate schema name for connection options to prevent injection
@@ -379,15 +389,15 @@ def get_db_url(settings: Settings) -> str:
logger.info(f"PostgreSQL schema will be set via options: {validated_schema}")
else:
logger.info("No specific PostgreSQL schema configured, using default (public).")
-
+
# Add statement timeout to prevent indefinite locks (30 seconds)
# This helps prevent stuck transactions when operations fail
options_list.append("-cstatement_timeout=30000")
logger.info("PostgreSQL statement timeout set to 30 seconds")
-
+
if options_list:
query_params["options"] = " ".join(options_list)
-
+
db_url_obj = URL.create(
drivername="postgresql+psycopg2",
username=username,
@@ -395,7 +405,7 @@ def get_db_url(settings: Settings) -> str:
host=settings.PGHOST,
port=settings.PGPORT,
database=settings.PGDATABASE,
- query=query_params if query_params else None
+ query=query_params if query_params else None,
)
url_str = db_url_obj.render_as_string(hide_password=False)
logger.debug(
@@ -409,52 +419,53 @@ def ensure_database_and_schema_exist(settings: Settings):
"""
Ensure the target database and schema exist. If not, create them.
Works in both LOCAL and OAuth modes (authentication differs but logic is unified).
-
+
If APP_DB_DROP_ON_START=true, drops the schema CASCADE before creating it.
Connects to default postgres database to create target database (OAuth mode only),
then creates schema within it.
-
+
The app (as service principal) becomes the owner of what it creates in OAuth mode,
eliminating permission issues.
-
+
Security: All PostgreSQL identifiers are validated to prevent SQL injection.
"""
is_local_mode = settings.ENV.upper().startswith("LOCAL")
-
- logger.info(f"Ensuring database and schema exist ({'LOCAL' if is_local_mode else 'OAuth'} mode)...")
-
+
+ logger.info(
+ f"Ensuring database and schema exist ({'LOCAL' if is_local_mode else 'OAuth'} mode)..."
+ )
+
# Determine username based on mode
if is_local_mode:
username = settings.PGUSER
else:
# Get service principal username for OAuth mode
ws_client = get_workspace_client(settings)
- username = (
- os.getenv("DATABRICKS_CLIENT_ID")
- or ws_client.current_user.me().user_name
- )
-
+ username = os.getenv("DATABRICKS_CLIENT_ID") or ws_client.current_user.me().user_name
+
if not username:
raise ValueError("Could not determine username/service principal")
-
+
# Validate all PostgreSQL identifiers to prevent SQL injection
try:
target_db = sanitize_postgres_identifier(settings.PGDATABASE)
- target_schema = sanitize_postgres_identifier(settings.PGSCHEMA) if settings.PGSCHEMA else None
+ target_schema = (
+ sanitize_postgres_identifier(settings.PGSCHEMA) if settings.PGSCHEMA else None
+ )
username = sanitize_postgres_identifier(username)
except ValueError as e:
raise ValueError(
f"Invalid PostgreSQL identifier in configuration: {e}. "
"Please check PGDATABASE, PGSCHEMA, and username."
) from e
-
+
logger.info(f"Username: {username}")
logger.debug(f"Target database: {target_db}, schema: {target_schema}")
-
+
# Generate initial OAuth token for OAuth mode
if not is_local_mode:
refresh_oauth_token(settings)
-
+
# Build connection URL
# In OAuth mode, connect directly to the target database (must be pre-created)
# In LOCAL mode, connect to the target database (should already exist)
@@ -466,24 +477,25 @@ def ensure_database_and_schema_exist(settings: Settings):
port=settings.PGPORT,
database=target_db,
)
-
+
# Create temporary engine for schema setup
temp_engine = create_engine(
connection_url.render_as_string(hide_password=False),
- isolation_level="AUTOCOMMIT" # Needed for CREATE SCHEMA
+ isolation_level="AUTOCOMMIT", # Needed for CREATE SCHEMA
)
-
+
# Inject OAuth token for connections in OAuth mode
if not is_local_mode:
+
@event.listens_for(temp_engine, "do_connect")
def inject_token_temp(dialect, conn_rec, cargs, cparams):
global _oauth_token
if _oauth_token:
cparams["password"] = _oauth_token
-
+
try:
# In OAuth mode, verify we can connect to the target database
- # The database must be pre-created by an admin with:
+ # The database must be pre-created by an admin with:
# CREATE DATABASE "app_ontos"; GRANT CREATE ON DATABASE "app_ontos" TO PUBLIC;
if not is_local_mode:
try:
@@ -504,22 +516,26 @@ def inject_token_temp(dialect, conn_rec, cargs, cparams):
f"See the README for detailed setup instructions."
) from e
raise
-
+
# Now handle schema (works for both LOCAL and OAuth modes)
with temp_engine.connect() as conn:
if target_schema and target_schema != "public":
# Handle APP_DB_DROP_ON_START: Drop schema CASCADE before creating
if settings.APP_DB_DROP_ON_START:
- logger.warning(f"APP_DB_DROP_ON_START=true: Dropping schema '{target_schema}' CASCADE...")
+ logger.warning(
+ f"APP_DB_DROP_ON_START=true: Dropping schema '{target_schema}' CASCADE..."
+ )
# DROP SCHEMA cannot be parameterized, but identifier is validated
conn.execute(text(f'DROP SCHEMA IF EXISTS "{target_schema}" CASCADE'))
conn.commit()
- logger.warning(f"✓ Schema '{target_schema}' dropped CASCADE. Will be recreated.")
-
+ logger.warning(
+ f"✓ Schema '{target_schema}' dropped CASCADE. Will be recreated."
+ )
+
# Explicitly drop application enum types from public schema
# These are often created in public and not dropped with CASCADE on a specific schema
logger.info("Dropping application enum types from public schema...")
- enum_types = ['commentstatus', 'commenttype', 'accesslevel']
+ enum_types = ["commentstatus", "commenttype", "accesslevel"]
for enum_type in enum_types:
try:
conn.execute(text(f'DROP TYPE IF EXISTS "{enum_type}" CASCADE'))
@@ -527,47 +543,53 @@ def inject_token_temp(dialect, conn_rec, cargs, cparams):
logger.warning(f"Could not drop enum type '{enum_type}': {e}")
conn.commit()
logger.info("✓ Enum types cleanup completed.")
-
+
# Check if schema exists (using parameterized query)
result = conn.execute(
- text("SELECT 1 FROM information_schema.schemata WHERE schema_name = :schemaname"),
- {"schemaname": target_schema}
+ text(
+ "SELECT 1 FROM information_schema.schemata WHERE schema_name = :schemaname"
+ ),
+ {"schemaname": target_schema},
)
schema_exists = result.scalar() is not None
-
+
if not schema_exists:
logger.info(f"Creating schema: {target_schema}")
# CREATE SCHEMA cannot be parameterized, but identifier is validated
conn.execute(text(f'CREATE SCHEMA "{target_schema}"'))
logger.info(f"✓ Schema created: {target_schema} (owner: {username})")
-
+
# Set default privileges for future objects (OAuth mode only)
if not is_local_mode:
# ALTER statements cannot be parameterized, but identifiers are validated
logger.info(f"Setting default privileges in schema: {target_schema}")
- conn.execute(text(
- f'ALTER DEFAULT PRIVILEGES IN SCHEMA "{target_schema}" '
- f'GRANT ALL ON TABLES TO "{username}"'
- ))
- conn.execute(text(
- f'ALTER DEFAULT PRIVILEGES IN SCHEMA "{target_schema}" '
- f'GRANT ALL ON SEQUENCES TO "{username}"'
- ))
- logger.info(f"✓ Default privileges configured")
-
+ conn.execute(
+ text(
+ f'ALTER DEFAULT PRIVILEGES IN SCHEMA "{target_schema}" '
+ f'GRANT ALL ON TABLES TO "{username}"'
+ )
+ )
+ conn.execute(
+ text(
+ f'ALTER DEFAULT PRIVILEGES IN SCHEMA "{target_schema}" '
+ f'GRANT ALL ON SEQUENCES TO "{username}"'
+ )
+ )
+ logger.info("✓ Default privileges configured")
+
conn.commit()
else:
logger.info(f"✓ Schema already exists: {target_schema}")
else:
- logger.info(f"Using public schema (no custom schema specified)")
-
+ logger.info("Using public schema (no custom schema specified)")
+
except Exception as e:
if "permission denied" in str(e).lower():
logger.error(
f"❌ Permission denied - check database/schema privileges for user '{username}'"
)
if not is_local_mode:
- logger.error(f"To fix this, run as a Lakebase admin:")
+ logger.error("To fix this, run as a Lakebase admin:")
logger.error(f' DROP DATABASE IF EXISTS "{target_db}";')
logger.error(f' CREATE DATABASE "{target_db}";')
logger.error(f' GRANT CREATE ON DATABASE "{target_db}" TO "{username}";')
@@ -575,13 +597,13 @@ def inject_token_temp(dialect, conn_rec, cargs, cparams):
raise
finally:
temp_engine.dispose()
-
+
logger.info("✓ Database and schema are ready")
def ensure_catalog_schema_exists(settings: Settings):
"""Checks if the configured catalog and schema exist, creates them if not.
-
+
Uses shared Unity Catalog utilities for secure, idempotent catalog/schema creation.
"""
logger.info("Ensuring required catalog and schema exist...")
@@ -600,18 +622,15 @@ def ensure_catalog_schema_exists(settings: Settings):
ensure_catalog_exists(
ws=ws_client,
catalog_name=catalog_name,
- comment=f"System catalog for {settings.APP_NAME}"
+ comment=f"System catalog for {settings.APP_NAME}",
)
logger.info(f"Catalog '{catalog_name}' is ready.")
except Exception as e:
# Map HTTPException or other errors to ConnectionError for consistency
logger.critical(
- f"Failed to ensure catalog '{catalog_name}': {e}. Check permissions.",
- exc_info=True
+ f"Failed to ensure catalog '{catalog_name}': {e}. Check permissions.", exc_info=True
)
- raise ConnectionError(
- f"Failed to create required catalog '{catalog_name}': {e}"
- ) from e
+ raise ConnectionError(f"Failed to create required catalog '{catalog_name}': {e}") from e
try:
logger.debug(f"Ensuring schema exists: {full_schema_name}")
@@ -619,13 +638,13 @@ def ensure_catalog_schema_exists(settings: Settings):
ws=ws_client,
catalog_name=catalog_name,
schema_name=schema_name,
- comment=f"System schema for {settings.APP_NAME}"
+ comment=f"System schema for {settings.APP_NAME}",
)
logger.info(f"Schema '{full_schema_name}' is ready.")
except Exception as e:
logger.critical(
- f"Failed to ensure schema '{full_schema_name}': {e}. Check permissions.",
- exc_info=True
+ f"Failed to ensure schema '{full_schema_name}': {e}. Check permissions.",
+ exc_info=True,
)
raise ConnectionError(
f"Failed to create required schema '{full_schema_name}': {e}"
@@ -638,15 +657,14 @@ def ensure_catalog_schema_exists(settings: Settings):
raise
except Exception as e:
logger.critical(
- f"An unexpected error occurred during catalog/schema check/creation: {e}",
- exc_info=True
+ f"An unexpected error occurred during catalog/schema check/creation: {e}", exc_info=True
)
- raise ConnectionError(
- f"Failed during catalog/schema setup: {e}"
- ) from e
+ raise ConnectionError(f"Failed during catalog/schema setup: {e}") from e
-def get_current_db_revision(engine_connection: Connection, alembic_cfg: AlembicConfig) -> str | None:
+def get_current_db_revision(
+ engine_connection: Connection, alembic_cfg: AlembicConfig
+) -> str | None:
"""Gets the current revision of the database."""
context = MigrationContext.configure(engine_connection)
return context.get_current_revision()
@@ -662,7 +680,7 @@ def init_db() -> None:
return
logger.info("Initializing database engine and session factory...")
-
+
# Ensure database and schema exist (creates them if needed in OAuth mode)
ensure_database_and_schema_exist(settings)
@@ -675,27 +693,31 @@ def init_db() -> None:
logger.info("Connecting to database...")
logger.info(f"> Database URL: {db_url}")
logger.info(f"> Connect args: {connect_args}")
- logger.info(f"> Pool settings: size={settings.DB_POOL_SIZE}, max_overflow={settings.DB_MAX_OVERFLOW}, "
- f"timeout={settings.DB_POOL_TIMEOUT}s, recycle={settings.DB_POOL_RECYCLE}s")
-
- _engine = create_engine(db_url,
- connect_args=connect_args,
- echo=settings.DB_ECHO,
- poolclass=pool.QueuePool,
- pool_size=settings.DB_POOL_SIZE,
- max_overflow=settings.DB_MAX_OVERFLOW,
- pool_timeout=settings.DB_POOL_TIMEOUT,
- pool_recycle=settings.DB_POOL_RECYCLE,
- pool_pre_ping=True)
- engine = _engine # Assign to public variable
+ logger.info(
+ f"> Pool settings: size={settings.DB_POOL_SIZE}, max_overflow={settings.DB_MAX_OVERFLOW}, "
+ f"timeout={settings.DB_POOL_TIMEOUT}s, recycle={settings.DB_POOL_RECYCLE}s"
+ )
+
+ _engine = create_engine(
+ db_url,
+ connect_args=connect_args,
+ echo=settings.DB_ECHO,
+ poolclass=pool.QueuePool,
+ pool_size=settings.DB_POOL_SIZE,
+ max_overflow=settings.DB_MAX_OVERFLOW,
+ pool_timeout=settings.DB_POOL_TIMEOUT,
+ pool_recycle=settings.DB_POOL_RECYCLE,
+ pool_pre_ping=True,
+ )
+ engine = _engine # Assign to public variable
# Add OAuth token injection if not in LOCAL mode
if not settings.ENV.upper().startswith("LOCAL"):
logger.info("Setting up OAuth token injection for Lakebase...")
-
+
# Generate initial token
refresh_oauth_token(settings)
-
+
# Register event handler to inject tokens for new connections
# Use 'do_connect' event to inject password at connection creation time
@event.listens_for(_engine, "do_connect")
@@ -704,7 +726,7 @@ def inject_token_on_connect(dialect, conn_rec, cargs, cparams):
if _oauth_token:
cparams["password"] = _oauth_token
logger.debug("Injected OAuth token into new database connection")
-
+
# Start background refresh thread
start_token_refresh_background(settings)
logger.info("OAuth authentication configured successfully")
@@ -738,13 +760,21 @@ def set_search_path(dbapi_connection, connection_record):
logger.info("Database engine and session factory initialized.")
# --- Alembic Migration Logic --- #
- alembic_cfg_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..' , 'alembic.ini'))
- alembic_script_location = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'alembic'))
+ alembic_cfg_path = os.path.abspath(
+ os.path.join(os.path.dirname(__file__), "..", "..", "alembic.ini")
+ )
+ alembic_script_location = os.path.abspath(
+ os.path.join(os.path.dirname(__file__), "..", "..", "alembic")
+ )
logger.info(f"Loading Alembic configuration from: {alembic_cfg_path}")
logger.info(f"Alembic script location: {alembic_script_location}")
alembic_cfg = AlembicConfig(alembic_cfg_path)
- alembic_cfg.set_main_option("sqlalchemy.url", db_url.replace("%", "%%")) # Ensure Alembic uses the same URL
- alembic_cfg.set_main_option("script_location", alembic_script_location) # Set absolute path to alembic directory
+ alembic_cfg.set_main_option(
+ "sqlalchemy.url", db_url.replace("%", "%%")
+ ) # Ensure Alembic uses the same URL
+ alembic_cfg.set_main_option(
+ "script_location", alembic_script_location
+ ) # Set absolute path to alembic directory
script = ScriptDirectory.from_config(alembic_cfg)
head_revision = script.get_current_head()
logger.info(f"Alembic Head Revision: {head_revision}")
@@ -765,23 +795,29 @@ def set_search_path(dbapi_connection, connection_record):
# Handle migrations based on database state
if db_revision is None:
# Fresh database - will use create_all() + stamp later
- logger.info("Fresh database detected (no Alembic version table). Will initialize with create_all() and stamp.")
+ logger.info(
+ "Fresh database detected (no Alembic version table). Will initialize with create_all() and stamp."
+ )
elif db_revision != head_revision:
# Existing database needs migration
- logger.info(f"Database revision '{db_revision}' differs from head revision '{head_revision}'.")
+ logger.info(
+ f"Database revision '{db_revision}' differs from head revision '{head_revision}'."
+ )
logger.info("Attempting Alembic upgrade to head...")
try:
# CRITICAL: Dispose the main engine's connection pool before running Alembic.
# This releases all pooled connections that might hold implicit transaction state.
- logger.info("Disposing engine pool to release connections before Alembic migration...")
+ logger.info(
+ "Disposing engine pool to release connections before Alembic migration..."
+ )
_engine.dispose()
-
+
# Run Alembic upgrade via subprocess to avoid hanging issues with
# Alembic's internal runpy-based execution when called programmatically.
# This ensures proper process isolation and cleanup.
+ import shutil
import subprocess
import sys
- import shutil
# For Lakebase (OAuth mode), we need to pass the token to the subprocess
# via environment variable since the subprocess can't access our in-memory token
@@ -798,7 +834,11 @@ def set_search_path(dbapi_connection, connection_record):
python_executable = None
# Try sys.executable first if it exists and is absolute
- if sys.executable and os.path.isabs(sys.executable) and os.path.exists(sys.executable):
+ if (
+ sys.executable
+ and os.path.isabs(sys.executable)
+ and os.path.exists(sys.executable)
+ ):
python_executable = sys.executable
logger.info(f"Using sys.executable: {python_executable}")
else:
@@ -810,8 +850,12 @@ def set_search_path(dbapi_connection, connection_record):
os.path.join(os.getcwd(), "venv", "bin", "python3"),
os.path.join(os.getcwd(), "venv", "bin", "python"),
# Try relative to the backend src directory
- os.path.join(os.path.dirname(__file__), "..", "..", ".venv", "bin", "python3"),
- os.path.join(os.path.dirname(__file__), "..", "..", ".venv", "bin", "python"),
+ os.path.join(
+ os.path.dirname(__file__), "..", "..", ".venv", "bin", "python3"
+ ),
+ os.path.join(
+ os.path.dirname(__file__), "..", "..", ".venv", "bin", "python"
+ ),
# Try system Python as fallback
"/usr/local/bin/python3",
"/usr/bin/python3",
@@ -828,7 +872,9 @@ def set_search_path(dbapi_connection, connection_record):
if not python_executable:
alembic_path = shutil.which("alembic")
if alembic_path:
- logger.warning("Could not find Python executable, will try running alembic command directly")
+ logger.warning(
+ "Could not find Python executable, will try running alembic command directly"
+ )
python_executable = None # Will use alembic directly below
else:
raise RuntimeError(
@@ -851,18 +897,22 @@ def set_search_path(dbapi_connection, connection_record):
capture_output=True,
text=True,
timeout=300, # 5 minute timeout
- env=subprocess_env
+ env=subprocess_env,
)
if result.returncode != 0:
logger.error(f"Alembic upgrade stderr: {result.stderr}")
- raise RuntimeError(f"Alembic upgrade failed with exit code {result.returncode}: {result.stderr}")
+ raise RuntimeError(
+ f"Alembic upgrade failed with exit code {result.returncode}: {result.stderr}"
+ )
logger.info(f"Alembic upgrade output: {result.stdout}")
logger.info("✓ Alembic upgrade to head COMPLETED.")
except subprocess.TimeoutExpired:
logger.critical("Alembic upgrade timed out after 5 minutes!")
raise RuntimeError("Alembic upgrade timed out")
except Exception as alembic_err:
- logger.critical("Alembic upgrade failed! Manual intervention may be required.", exc_info=True)
+ logger.critical(
+ "Alembic upgrade failed! Manual intervention may be required.", exc_info=True
+ )
raise RuntimeError("Failed to upgrade database schema.") from alembic_err
else:
logger.info("✓ Database schema is up to date according to Alembic.")
@@ -882,43 +932,51 @@ def set_search_path(dbapi_connection, connection_record):
# with engine.connect() as connection:
# connection.execute(sqlalchemy.text(f"CREATE SCHEMA IF NOT EXISTS {schema_to_create_in}"))
# connection.commit()
- logger.info(f"PostgreSQL: Tables will be targeted for schema '{schema_to_create_in}' via search_path or model definitions.")
+ logger.info(
+ f"PostgreSQL: Tables will be targeted for schema '{schema_to_create_in}' via search_path or model definitions."
+ )
# No Databricks-specific metadata modifications required
-
- # Note: Schema creation and APP_DB_DROP_ON_START handling is done in
+
+ # Note: Schema creation and APP_DB_DROP_ON_START handling is done in
# ensure_database_and_schema_exist() before Alembic migrations run
# Only use create_all() for fresh databases without Alembic version table
# Once Alembic is tracking the schema, migrations handle all schema changes
if db_revision is None:
- logger.info("Fresh database detected (no Alembic version). Using create_all() for initial setup...")
- target_schema = settings.PGSCHEMA or 'public'
-
+ logger.info(
+ "Fresh database detected (no Alembic version). Using create_all() for initial setup..."
+ )
+ target_schema = settings.PGSCHEMA or "public"
+
# Create all tables in the target schema
with _engine.begin() as connection:
# Explicitly set search_path to ensure tables are created in correct schema
connection.execute(text(f'SET search_path TO "{target_schema}"'))
Base.metadata.create_all(bind=connection, checkfirst=True)
logger.info("✓ Database tables created by create_all.")
-
+
# Stamp the database with the baseline migration
- # Using direct INSERT instead of alembic_command.stamp() to avoid
+ # Using direct INSERT instead of alembic_command.stamp() to avoid
# hanging issues with Alembic's runpy-based execution in some environments
logger.info("Stamping database with baseline migration...")
try:
with _engine.begin() as connection:
connection.execute(text(f'SET search_path TO "{target_schema}"'))
# Create alembic_version table if needed and insert head revision
- connection.execute(text("""
+ connection.execute(
+ text("""
CREATE TABLE IF NOT EXISTS alembic_version (
version_num VARCHAR(32) NOT NULL,
CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num)
)
- """))
+ """)
+ )
connection.execute(text("DELETE FROM alembic_version"))
- connection.execute(text("INSERT INTO alembic_version (version_num) VALUES (:rev)"),
- {"rev": head_revision})
+ connection.execute(
+ text("INSERT INTO alembic_version (version_num) VALUES (:rev)"),
+ {"rev": head_revision},
+ )
logger.info(f"✓ Database stamped with baseline migration: {head_revision}")
except Exception as stamp_err:
logger.error(f"Failed to stamp database: {stamp_err}", exc_info=True)
@@ -929,21 +987,24 @@ def set_search_path(dbapi_connection, connection_record):
logger.critical(f"Database initialization failed: {e}", exc_info=True)
_engine = None
_SessionLocal = None
- engine = None # Reset public engine on failure
+ engine = None # Reset public engine on failure
raise ConnectionError("Failed to initialize database connection or run migrations.") from e
+
def get_db():
global _SessionLocal
if _SessionLocal is None:
logger.error("Database not initialized. Cannot get session.")
# Consider raising HTTPException for FastAPI to handle gracefully if this occurs at runtime
- raise RuntimeError("Database session factory is not available. Database might not have been initialized correctly.")
-
+ raise RuntimeError(
+ "Database session factory is not available. Database might not have been initialized correctly."
+ )
+
db = _SessionLocal()
try:
yield db
db.commit() # Commit the transaction on successful completion of the request
- except Exception as e: # Catch all exceptions to ensure rollback
+ except Exception as e: # Catch all exceptions to ensure rollback
logger.error(f"Error during database session for request, rolling back: {e}", exc_info=True)
db.rollback()
# Re-raise the exception so FastAPI can handle it appropriately
@@ -952,6 +1013,7 @@ def get_db():
finally:
db.close()
+
@contextmanager
def get_db_session():
"""Context manager that yields a SQLAlchemy session.
@@ -966,7 +1028,9 @@ def get_db_session():
init_db()
except Exception as e:
logger.critical(f"Failed to initialize database session factory: {e}", exc_info=True)
- raise RuntimeError("Database session factory not available and initialization failed.") from e
+ raise RuntimeError(
+ "Database session factory not available and initialization failed."
+ ) from e
session = _SessionLocal()
try:
@@ -979,12 +1043,14 @@ def get_db_session():
finally:
session.close()
+
def get_engine():
global _engine
if _engine is None:
raise RuntimeError("Database engine not initialized.")
return _engine
+
def get_session_factory():
global _SessionLocal
if _SessionLocal is None:
@@ -995,25 +1061,26 @@ def get_session_factory():
def set_session_factory(factory):
"""
Set the global session factory. Used by tests to inject a test database session factory.
-
+
Args:
factory: A sessionmaker instance or callable that returns database sessions
"""
global _SessionLocal
_SessionLocal = factory
+
def cleanup_db():
"""Cleanup database resources including OAuth token refresh."""
global _engine, _SessionLocal, engine
-
+
# Stop token refresh if running
stop_token_refresh_background()
-
+
# Dispose engine
if _engine:
_engine.dispose()
logger.info("Database engine disposed")
-
+
_engine = None
_SessionLocal = None
engine = None
diff --git a/src/backend/src/common/dependencies.py b/src/backend/src/common/dependencies.py
index c9d33d76..6c6ccb16 100644
--- a/src/backend/src/common/dependencies.py
+++ b/src/backend/src/common/dependencies.py
@@ -20,7 +20,8 @@
from src.controller.tags_manager import TagsManager # Import TagsManager
from src.controller.workspace_manager import WorkspaceManager # Import WorkspaceManager
from src.controller.change_log_manager import ChangeLogManager # Import ChangeLogManager
-from src.controller.datasets_manager import DatasetsManager # Import DatasetsManager
+from src.controller.datasets_manager import DatasetsManager
+from src.controller.graph_explorer_manager import GraphExplorerManager # Import GraphExplorerManager
# Import base dependencies
from src.common.database import get_session_factory # Import the factory function
@@ -47,6 +48,7 @@
get_workspace_manager,
get_change_log_manager,
get_datasets_manager,
+ get_graph_explorer_manager,
)
# Import workspace client getter separately as it might be structured differently
from src.common.workspace_client import get_workspace_client_dependency # Fixed to use proper wrapper
@@ -120,6 +122,7 @@ async def get_current_user(user_details: UserInfo = Depends(get_user_details_fro
WorkspaceManagerDep = Annotated[WorkspaceManager, Depends(get_workspace_manager)]
ChangeLogManagerDep = Annotated[ChangeLogManager, Depends(get_change_log_manager)]
DatasetsManagerDep = Annotated[DatasetsManager, Depends(get_datasets_manager)]
+GraphExplorerManagerDep = Annotated[GraphExplorerManager, Depends(get_graph_explorer_manager)]
# Permission Checker Dependency
PermissionCheckerDep = AuthorizationManagerDep
diff --git a/src/backend/src/common/features.py b/src/backend/src/common/features.py
index f344e8f6..d6f1e702 100644
--- a/src/backend/src/common/features.py
+++ b/src/backend/src/common/features.py
@@ -164,6 +164,11 @@ class FeatureAccessLevel(str, Enum):
'name': 'Comments & Ratings',
'allowed_levels': READ_WRITE_ADMIN_LEVELS # READ_WRITE to add, ADMIN to manage all
},
+ # Graph Explorer
+ 'graph-explorer': {
+ 'name': 'Graph Explorer',
+ 'allowed_levels': READ_WRITE_ADMIN_LEVELS
+ },
# 'about': { ... } # About page doesn't need explicit permissions here
}
diff --git a/src/backend/src/common/manager_dependencies.py b/src/backend/src/common/manager_dependencies.py
index 9436e192..4a901c4a 100644
--- a/src/backend/src/common/manager_dependencies.py
+++ b/src/backend/src/common/manager_dependencies.py
@@ -25,6 +25,7 @@
from src.controller.change_log_manager import ChangeLogManager
from src.controller.datasets_manager import DatasetsManager
from src.controller.delivery_service import DeliveryService
+from src.controller.graph_explorer_manager import GraphExplorerManager
# Import other dependencies needed by these providers
from src.common.database import get_db
@@ -181,6 +182,13 @@ def get_datasets_manager(request: Request) -> DatasetsManager:
# Add getters for Compliance, Estate, MDM, Security, Entitlements, Catalog Commander managers when they are added
+def get_graph_explorer_manager(request: Request) -> GraphExplorerManager:
+    """FastAPI dependency: fetch the GraphExplorerManager from app state.
+
+    The manager is expected to be attached to ``app.state.graph_explorer_manager``
+    during application startup.
+
+    Raises:
+        HTTPException: 503 if the manager is missing (startup wiring did not
+            run or failed), so clients see "service not configured" rather
+            than an opaque 500.
+    """
+    manager = getattr(request.app.state, 'graph_explorer_manager', None)
+    if not manager:
+        logger.critical("GraphExplorerManager not found in application state during request!")
+        raise HTTPException(status_code=503, detail="Graph Explorer service not configured.")
+    return manager
+
def get_delivery_service(request: Request) -> DeliveryService:
"""Get the DeliveryService for multi-mode delivery of governance changes."""
service = getattr(request.app.state, "delivery_service", None)
diff --git a/src/backend/src/controller/data_products_manager.py b/src/backend/src/controller/data_products_manager.py
index e59eccde..67d54bd9 100644
--- a/src/backend/src/controller/data_products_manager.py
+++ b/src/backend/src/controller/data_products_manager.py
@@ -659,11 +659,19 @@ def publish_product(self, product_id: str, current_user: Optional[str] = None) -
if not product_db:
raise ValueError(f"Data product with ID {product_id} not found")
+ # Validate that product has at least one output port with a data contract
+ output_ports = product_db.output_ports or []
+ ports_with_contracts = [p for p in output_ports if p.contract_id]
+ if not ports_with_contracts:
+ raise ValueError(
+ "Cannot publish product: Product must have at least one output port with a data contract assigned"
+ )
+
# Validate that all output ports have data contracts
- if product_db.output_ports:
+ if output_ports:
ports_without_contracts = [
- port.name for port in product_db.output_ports
- if not port.data_contract_id
+ port.name for port in output_ports
+ if not port.contract_id
]
if ports_without_contracts:
raise ValueError(
@@ -677,9 +685,9 @@ def publish_product(self, product_id: str, current_user: Optional[str] = None) -
valid_contract_statuses = ['approved', 'active', 'certified']
contracts_not_approved = []
- for port in product_db.output_ports:
- if port.data_contract_id:
- contract = data_contract_repo.get(db=self._db, id=port.data_contract_id)
+ for port in output_ports:
+ if port.contract_id:
+ contract = data_contract_repo.get(db=self._db, id=port.contract_id)
if contract:
contract_status = (contract.status or '').lower()
if contract_status not in valid_contract_statuses:
diff --git a/src/backend/src/controller/graph_explorer_manager.py b/src/backend/src/controller/graph_explorer_manager.py
new file mode 100644
index 00000000..f33ddc61
--- /dev/null
+++ b/src/backend/src/controller/graph_explorer_manager.py
@@ -0,0 +1,792 @@
+"""
+Graph Explorer Manager.
+
+Business logic for reading/writing property graph data
+from/to Databricks Unity Catalog tables using the Statement Execution API.
+
+Table schema (edge-centric, same as graph-demo):
+ node_start_id STRING NOT NULL,
+ node_start_key STRING NOT NULL, -- node type
+ relationship STRING NOT NULL,
+ node_end_id STRING NOT NULL,
+ node_end_key STRING NOT NULL, -- node type
+ node_start_properties STRING, -- JSON
+ node_end_properties STRING -- JSON
+
+Standalone nodes are stored as self-referencing edges with relationship = 'EXISTS'.
+"""
+
+import json
+import os
+import re
+import time
+from typing import Any, Dict, List, Optional, Tuple
+
+from src.common.logging import get_logger
+
+logger = get_logger(__name__)
+
+# Regex for valid Unity Catalog table names
+TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z0-9_`]+(\.[a-zA-Z0-9_`]+){0,2}$')
+
+DEFAULT_TABLE_NAME = "main.default.property_graph_entity_edges"
+
+
+def _validate_table_name(table_name: str) -> str:
+    """Validate and sanitize a table name to prevent SQL injection.
+
+    Accepts a 1- to 3-part dotted Unity Catalog name (e.g.
+    ``catalog.schema.table``), optionally backtick-quoted, and returns it
+    with every part wrapped in backticks.
+
+    Raises:
+        ValueError: If the name does not match TABLE_NAME_PATTERN.
+
+    NOTE(review): TABLE_NAME_PATTERN allows backticks *inside* a part; a
+    stray interior backtick would survive into the quoted output. The
+    character class excludes spaces/quotes/semicolons so breakout text is
+    limited, but confirm whether the pattern should be tightened to reject
+    interior backticks entirely.
+    """
+    cleaned = table_name.strip().strip('`')
+    if not TABLE_NAME_PATTERN.match(cleaned):
+        raise ValueError(f"Invalid table name: {table_name}")
+    # Wrap parts in backticks for safety
+    parts = cleaned.split('.')
+    return '.'.join(f'`{p.strip("`")}`' for p in parts)
+
+
+def _escape_sql_string(value: str) -> str:
+    """Escape single quotes for SQL string literals.
+
+    Doubling the quote is the standard ANSI-SQL escape. Callers embed the
+    escaped value directly into statement text (this Statement Execution
+    path does not use bind parameters).
+    """
+    return value.replace("'", "''")
+
+
+class GraphExplorerManager:
+ """Manages graph explorer operations via Databricks SQL."""
+
+    def __init__(self, settings=None):
+        """Store the app settings object.
+
+        Args:
+            settings: Application settings; read lazily via ``getattr`` for
+                GRAPH_* limits. May be None, in which case the query methods
+                fall back to hard-coded defaults.
+        """
+        self.settings = settings
+
+    def _execute_sql(self, ws_client, sql: str, warehouse_id: str) -> Tuple[List[str], List[List[Any]]]:
+        """Execute a SQL statement and return (columns, rows).
+
+        Args:
+            ws_client: Databricks workspace client exposing
+                ``statement_execution``.
+            sql: Statement text (identifiers/values must already be
+                validated/escaped by the caller).
+            warehouse_id: SQL warehouse to run against.
+
+        Returns:
+            Tuple of (column names, row data); both empty when the result
+            carries no manifest/data.
+
+        Raises:
+            RuntimeError: If the statement finished FAILED or CANCELED.
+        """
+        # Log only a prefix — statements can embed long literals.
+        logger.debug(f"Executing SQL: {sql[:200]}...")
+        # Timeout is a Databricks wait string (e.g. "30s"), configurable via
+        # GRAPH_QUERY_TIMEOUT; defaults to 30s when settings are absent.
+        timeout = getattr(self.settings, "GRAPH_QUERY_TIMEOUT", "30s") if self.settings else "30s"
+        result = ws_client.statement_execution.execute_statement(
+            statement=sql,
+            warehouse_id=warehouse_id,
+            wait_timeout=timeout,
+        )
+
+        # Surface terminal failure states as exceptions instead of returning
+        # an empty result the caller would misread as "no rows".
+        if result.status and result.status.state:
+            state = str(result.status.state)
+            if "FAILED" in state or "CANCELED" in state:
+                error_msg = result.status.error.message if result.status.error else "Query failed"
+                raise RuntimeError(f"SQL execution failed: {error_msg}")
+
+        columns = []
+        if result.manifest and result.manifest.schema and result.manifest.schema.columns:
+            columns = [col.name for col in result.manifest.schema.columns]
+
+        # NOTE(review): only the first chunk (result.result.data_array) is
+        # read; presumably fine given callers always apply LIMITs, but a
+        # multi-chunk result would be silently truncated — verify.
+        rows = []
+        if result.result and result.result.data_array:
+            rows = result.result.data_array
+
+        return columns, rows
+
+    def ensure_table_exists(self, ws_client, table_name: str, warehouse_id: str) -> None:
+        """Create the graph table if it doesn't exist.
+
+        Issues an idempotent ``CREATE TABLE IF NOT EXISTS`` with the
+        edge-centric schema documented in the module docstring.
+
+        Args:
+            ws_client: Databricks workspace client.
+            table_name: 1-3 part UC table name; validated before use.
+            warehouse_id: SQL warehouse to run the DDL on.
+
+        Raises:
+            ValueError: If the table name fails validation.
+            RuntimeError: If the DDL statement fails.
+        """
+        safe_table = _validate_table_name(table_name)
+        sql = f"""CREATE TABLE IF NOT EXISTS {safe_table} (
+            node_start_id STRING NOT NULL,
+            node_start_key STRING NOT NULL,
+            relationship STRING NOT NULL,
+            node_end_id STRING NOT NULL,
+            node_end_key STRING NOT NULL,
+            node_start_properties STRING,
+            node_end_properties STRING
+        )"""
+        self._execute_sql(ws_client, sql, warehouse_id)
+        logger.info(f"Ensured table exists: {safe_table}")
+
+    def get_graph_data(self, ws_client, table_name: str, warehouse_id: str, max_rows: int = 0) -> Dict[str, Any]:
+        """Read graph data from a Databricks table and return as nodes + edges.
+
+        Args:
+            ws_client: Databricks workspace client.
+            table_name: 1-3 part UC table name; validated before use.
+            warehouse_id: SQL warehouse to query.
+            max_rows: Maximum rows to fetch. 0 means use the server-side
+                default derived from GRAPH_MAX_EDGES setting.
+
+        Returns:
+            Dict with ``nodes`` (deduplicated by id), ``edges``,
+            ``truncated`` flag, and ``totalAvailable`` (total row count when
+            truncated, else None).
+        """
+        safe_table = _validate_table_name(table_name)
+
+        # Determine effective row limit from settings or argument
+        if max_rows <= 0:
+            max_edges = getattr(self.settings, "GRAPH_MAX_EDGES", 10000) if self.settings else 10000
+            effective_limit = max_edges
+        else:
+            effective_limit = max_rows
+
+        # Count total rows for truncation info
+        # NOTE(review): COUNT(*) and the SELECT below are two separate
+        # statements, so the count can drift if the table changes between
+        # them — acceptable for an informational "truncated" flag.
+        count_sql = f"SELECT COUNT(*) FROM {safe_table}"
+        _, count_rows = self._execute_sql(ws_client, count_sql, warehouse_id)
+        total_rows = int(count_rows[0][0]) if count_rows and count_rows[0][0] else 0
+
+        sql = f"SELECT * FROM {safe_table} LIMIT {effective_limit}"
+        columns, rows = self._execute_sql(ws_client, sql, warehouse_id)
+
+        # No manifest/columns means an empty or missing result set.
+        if not columns:
+            return {"nodes": [], "edges": [], "truncated": False, "totalAvailable": None}
+
+        # Build column index map (positional fallbacks match the documented
+        # column order of the edge table).
+        col_idx = {name: i for i, name in enumerate(columns)}
+
+        # Track unique nodes and edges
+        nodes_map: Dict[str, Dict[str, Any]] = {}
+        edges: List[Dict[str, Any]] = []
+
+        for row in rows:
+            start_id = row[col_idx.get("node_start_id", 0)] or ""
+            start_key = row[col_idx.get("node_start_key", 1)] or "Node"
+            relationship = row[col_idx.get("relationship", 2)] or ""
+            end_id = row[col_idx.get("node_end_id", 3)] or ""
+            end_key = row[col_idx.get("node_end_key", 4)] or "Node"
+            start_props_raw = row[col_idx.get("node_start_properties", 5)]
+            end_props_raw = row[col_idx.get("node_end_properties", 6)]
+
+            # Parse properties (handles VARIANT and STRING columns)
+            start_props = self._parse_props(start_props_raw)
+            end_props = self._parse_props(end_props_raw)
+
+            # Extract label from properties or use id
+            # (pop removes the internal "_label" key so it is not exposed
+            # in the node's properties dict)
+            start_label = start_props.pop("_label", None) or start_id
+            end_label = end_props.pop("_label", None) or end_id
+
+            # Register start node — first row seen for an id wins; later
+            # rows with different properties for the same id are ignored.
+            if start_id and start_id not in nodes_map:
+                nodes_map[start_id] = {
+                    "id": start_id,
+                    "label": start_label,
+                    "type": start_key,
+                    "properties": start_props,
+                    "status": "existing",
+                }
+
+            # Register end node
+            if end_id and end_id not in nodes_map:
+                nodes_map[end_id] = {
+                    "id": end_id,
+                    "label": end_label,
+                    "type": end_key,
+                    "properties": end_props,
+                    "status": "existing",
+                }
+
+            # Add edge (skip EXISTS self-references — those are standalone node markers)
+            # NOTE(review): edge ids are not deduplicated; duplicate
+            # (start, relationship, end) rows yield duplicate ids — confirm
+            # the frontend tolerates repeated edge ids.
+            if relationship != "EXISTS" and start_id and end_id:
+                edge_id = f"{start_id}-{relationship}-{end_id}"
+                edges.append({
+                    "id": edge_id,
+                    "source": start_id,
+                    "target": end_id,
+                    "relationshipType": relationship,
+                    "properties": {},
+                    "status": "existing",
+                })
+
+        truncated = total_rows > effective_limit
+
+        return {
+            "nodes": list(nodes_map.values()),
+            "edges": edges,
+            "truncated": truncated,
+            "totalAvailable": total_rows if truncated else None,
+        }
+
+ def get_neighbors(
+ self,
+ ws_client,
+ table_name: str,
+ warehouse_id: str,
+ node_id: str,
+ direction: str = "both",
+ edge_types: Optional[List[str]] = None,
+ limit: int = 25,
+ offset: int = 0,
+ ) -> Dict[str, Any]:
+ """Get the 1-hop neighborhood of a node.
+
+ Args:
+ node_id: The node to expand from.
+ direction: 'outgoing', 'incoming', or 'both'.
+ edge_types: Optional list of relationship types to filter on.
+ limit: Max edges to return.
+ offset: Pagination offset.
+
+ Returns:
+ Dict with nodes, edges, truncated, and totalAvailable.
+ """
+ safe_table = _validate_table_name(table_name)
+ safe_id = _escape_sql_string(node_id)
+
+ # Build direction filter
+ if direction == "outgoing":
+ direction_filter = f"node_start_id = '{safe_id}'"
+ elif direction == "incoming":
+ direction_filter = f"node_end_id = '{safe_id}'"
+ else: # both
+ direction_filter = f"(node_start_id = '{safe_id}' OR node_end_id = '{safe_id}')"
+
+ # Exclude EXISTS self-references (standalone node markers)
+ base_where = f"{direction_filter} AND relationship != 'EXISTS'"
+
+ # Optional edge type filter
+ if edge_types:
+ escaped_types = [f"'{_escape_sql_string(t)}'" for t in edge_types]
+ base_where += f" AND relationship IN ({', '.join(escaped_types)})"
+
+ # Count total available (for truncation info)
+ count_sql = f"SELECT COUNT(*) FROM {safe_table} WHERE {base_where}"
+ _, count_rows = self._execute_sql(ws_client, count_sql, warehouse_id)
+ total_available = int(count_rows[0][0]) if count_rows and count_rows[0][0] else 0
+
+ # Fetch the edges with limit/offset
+ sql = f"SELECT * FROM {safe_table} WHERE {base_where} LIMIT {limit} OFFSET {offset}"
+ columns, rows = self._execute_sql(ws_client, sql, warehouse_id)
+
+ if not columns:
+ return {"nodes": [], "edges": [], "truncated": False, "totalAvailable": 0}
+
+ # Parse rows into nodes + edges (reuse the same logic as get_graph_data)
+ col_idx = {name: i for i, name in enumerate(columns)}
+ nodes_map: Dict[str, Dict[str, Any]] = {}
+ edges: List[Dict[str, Any]] = []
+
+ for row in rows:
+ start_id = row[col_idx.get("node_start_id", 0)] or ""
+ start_key = row[col_idx.get("node_start_key", 1)] or "Node"
+ relationship = row[col_idx.get("relationship", 2)] or ""
+ end_id = row[col_idx.get("node_end_id", 3)] or ""
+ end_key = row[col_idx.get("node_end_key", 4)] or "Node"
+ start_props_raw = row[col_idx.get("node_start_properties", 5)]
+ end_props_raw = row[col_idx.get("node_end_properties", 6)]
+
+ start_props = self._parse_props(start_props_raw)
+ end_props = self._parse_props(end_props_raw)
+ start_label = start_props.pop("_label", None) or start_id
+ end_label = end_props.pop("_label", None) or end_id
+
+ if start_id and start_id not in nodes_map:
+ nodes_map[start_id] = {
+ "id": start_id,
+ "label": start_label,
+ "type": start_key,
+ "properties": start_props,
+ "status": "existing",
+ }
+
+ if end_id and end_id not in nodes_map:
+ nodes_map[end_id] = {
+ "id": end_id,
+ "label": end_label,
+ "type": end_key,
+ "properties": end_props,
+ "status": "existing",
+ }
+
+ if relationship and start_id and end_id:
+ edge_id = f"{start_id}-{relationship}-{end_id}"
+ edges.append({
+ "id": edge_id,
+ "source": start_id,
+ "target": end_id,
+ "relationshipType": relationship,
+ "properties": {},
+ "status": "existing",
+ })
+
+ truncated = total_available > (offset + limit)
+
+ return {
+ "nodes": list(nodes_map.values()),
+ "edges": edges,
+ "truncated": truncated,
+ "totalAvailable": total_available,
+ }
+
+ def write_nodes_and_edges(
+ self,
+ ws_client,
+ table_name: str,
+ warehouse_id: str,
+ nodes: List[Dict[str, Any]],
+ edges: List[Dict[str, Any]],
+ ) -> Dict[str, Any]:
+ """Write new nodes and edges to the Databricks table."""
+ safe_table = _validate_table_name(table_name)
+ nodes_written = 0
+ edges_written = 0
+
+ # Build a lookup of all nodes (from provided list) for edge writing
+ node_map = {n["id"]: n for n in nodes}
+
+ # Write standalone nodes as EXISTS self-referencing edges
+ for node in nodes:
+ node_id = _escape_sql_string(node["id"])
+ node_type = _escape_sql_string(node.get("type", "Node"))
+ props = dict(node.get("properties", {}))
+ props["_label"] = node.get("label", node["id"])
+ props_json = _escape_sql_string(json.dumps(props))
+
+ sql = f"""INSERT INTO {safe_table}
+ (node_start_id, node_start_key, relationship, node_end_id, node_end_key, node_start_properties, node_end_properties)
+ VALUES ('{node_id}', '{node_type}', 'EXISTS', '{node_id}', '{node_type}', '{props_json}', '{props_json}')"""
+ self._execute_sql(ws_client, sql, warehouse_id)
+ nodes_written += 1
+
+ # Write edges
+ for edge in edges:
+ source_id = _escape_sql_string(edge["source"])
+ target_id = _escape_sql_string(edge["target"])
+ rel_type = _escape_sql_string(edge.get("relationshipType", "RELATES_TO"))
+
+ # Get source/target node info for properties
+ source_node = node_map.get(edge["source"], {})
+ target_node = node_map.get(edge["target"], {})
+
+ source_type = _escape_sql_string(source_node.get("type", "Node"))
+ target_type = _escape_sql_string(target_node.get("type", "Node"))
+
+ source_props = dict(source_node.get("properties", {}))
+ source_props["_label"] = source_node.get("label", edge["source"])
+ target_props = dict(target_node.get("properties", {}))
+ target_props["_label"] = target_node.get("label", edge["target"])
+
+ source_props_json = _escape_sql_string(json.dumps(source_props))
+ target_props_json = _escape_sql_string(json.dumps(target_props))
+
+ sql = f"""INSERT INTO {safe_table}
+ (node_start_id, node_start_key, relationship, node_end_id, node_end_key, node_start_properties, node_end_properties)
+ VALUES ('{source_id}', '{source_type}', '{rel_type}', '{target_id}', '{target_type}', '{source_props_json}', '{target_props_json}')"""
+ self._execute_sql(ws_client, sql, warehouse_id)
+ edges_written += 1
+
+ return {"nodesWritten": nodes_written, "edgesWritten": edges_written}
+
+ def delete_node(self, ws_client, table_name: str, warehouse_id: str, node_id: str) -> None:
+ """Delete a node and all its connected edges."""
+ safe_table = _validate_table_name(table_name)
+ safe_id = _escape_sql_string(node_id)
+ sql = f"DELETE FROM {safe_table} WHERE node_start_id = '{safe_id}' OR node_end_id = '{safe_id}'"
+ self._execute_sql(ws_client, sql, warehouse_id)
+ logger.info(f"Deleted node {node_id} and connected edges from {safe_table}")
+
    def delete_edge(self, ws_client, table_name: str, warehouse_id: str, source_id: str, target_id: str, relationship_type: str) -> None:
        """Delete a specific edge.

        The edge is identified by its (source, target, relationship) triple —
        the same key under which it was inserted. Every row matching the
        triple is removed.

        Args:
            ws_client: Databricks workspace client for SQL execution.
            table_name: Fully qualified edge table.
            warehouse_id: SQL warehouse to run the statement on.
            source_id: Start-node id of the edge.
            target_id: End-node id of the edge.
            relationship_type: Edge type stored in the `relationship` column.
        """
        safe_table = _validate_table_name(table_name)
        # Escape all user-supplied values before interpolating into SQL.
        safe_source = _escape_sql_string(source_id)
        safe_target = _escape_sql_string(target_id)
        safe_rel = _escape_sql_string(relationship_type)
        sql = f"""DELETE FROM {safe_table}
            WHERE node_start_id = '{safe_source}'
            AND node_end_id = '{safe_target}'
            AND relationship = '{safe_rel}'"""
        self._execute_sql(ws_client, sql, warehouse_id)
        logger.info(f"Deleted edge {source_id} -[{relationship_type}]-> {target_id} from {safe_table}")
+
    def update_node(self, ws_client, table_name: str, warehouse_id: str, node_id: str, label: str, node_type: str, properties: Dict) -> None:
        """Update a node's properties, type, and label across all rows.

        Node data is denormalized onto every edge row the node appears in, so
        two UPDATEs are required: one for rows where it is the start node and
        one for rows where it is the end node. The display label is folded
        into the properties JSON under the synthetic `_label` key.
        """
        safe_table = _validate_table_name(table_name)
        safe_id = _escape_sql_string(node_id)
        safe_type = _escape_sql_string(node_type)
        # Copy before mutating so the caller's dict is left untouched.
        props = dict(properties)
        props["_label"] = label
        props_json = _escape_sql_string(json.dumps(props))

        # Update as start node
        sql = f"""UPDATE {safe_table}
            SET node_start_key = '{safe_type}', node_start_properties = '{props_json}'
            WHERE node_start_id = '{safe_id}'"""
        self._execute_sql(ws_client, sql, warehouse_id)

        # Update as end node
        sql = f"""UPDATE {safe_table}
            SET node_end_key = '{safe_type}', node_end_properties = '{props_json}'
            WHERE node_end_id = '{safe_id}'"""
        self._execute_sql(ws_client, sql, warehouse_id)
        logger.info(f"Updated node {node_id} in {safe_table}")
+
+ # -----------------------------------------------------------------
+ # LLM Configuration
+ # -----------------------------------------------------------------
+
+ def get_llm_config(self) -> Dict[str, Any]:
+ """Return the current LLM configuration state.
+
+ The graph query panel is enabled whenever an LLM endpoint is
+ configured — it does NOT require the global LLM_ENABLED flag, which
+ gates the heavier conversational search feature.
+ """
+ if not self.settings:
+ return {"enabled": False, "defaultModel": "", "maxTokens": 4096, "provider": "databricks"}
+
+ endpoint = getattr(self.settings, "LLM_ENDPOINT", "") or ""
+ return {
+ "enabled": bool(endpoint),
+ "defaultModel": endpoint,
+ "maxTokens": 4096,
+ "provider": "databricks",
+ }
+
+ # -----------------------------------------------------------------
+ # Graph Query Translation (Cypher / Gremlin → SQL via LLM)
+ # -----------------------------------------------------------------
+
    # System prompt for the Cypher/Gremlin/natural-language → SQL translator.
    # NOTE: this string is runtime behavior — it is sent verbatim to the model.
    # The downstream row parser depends on rules enforced here: SELECT * (so
    # _parse_query_results sees every column), the LIMIT 5000 cap, and the
    # `matched` CTE alias that execute_graph_query sniffs for to detect
    # vertex-only results. Edit with care.
    _TRANSLATE_SYSTEM_PROMPT = (
        "You are a SQL translation engine for Databricks Unity Catalog. "
        "You translate Cypher or Gremlin graph queries into valid Databricks SQL.\n\n"
        "The target table has the following schema:\n"
        "  node_start_id STRING, node_start_key STRING, relationship STRING,\n"
        "  node_end_id STRING, node_end_key STRING,\n"
        "  node_start_properties VARIANT, node_end_properties VARIANT\n\n"
        "Rules:\n"
        "- Each row represents an edge between two nodes.\n"
        "- node_start_id/node_start_key/node_start_properties describe the source node.\n"
        "- node_end_id/node_end_key/node_end_properties describe the target node.\n"
        "- relationship is the edge type (e.g. 'ALLIED_WITH', 'BORN_ON').\n"
        "- Node labels (types) are stored in node_start_key and node_end_key.\n"
        "- Some tables use relationship = 'EXISTS' for standalone nodes (self-edges).\n"
        "  Check the graph data context to see if EXISTS is listed as a relationship type.\n"
        "  If EXISTS is NOT listed, nodes only appear as edge endpoints.\n"
        "\n"
        "CRITICAL — Subgraph pattern for node-centric queries:\n"
        "When the user asks for specific nodes (e.g. 'dark characters'), use a CTE to find\n"
        "matching node IDs first, then return only edges BETWEEN those nodes. This avoids\n"
        "returning unrelated connected nodes.\n"
        "Example:\n"
        "  WITH matched AS (\n"
        "    SELECT DISTINCT node_start_id AS id FROM table\n"
        "    WHERE node_start_key = 'Character'\n"
        "    AND get_json_object(CAST(node_start_properties AS STRING), '$.alignment') = 'Dark'\n"
        "    UNION\n"
        "    SELECT DISTINCT node_end_id AS id FROM table\n"
        "    WHERE node_end_key = 'Character'\n"
        "    AND get_json_object(CAST(node_end_properties AS STRING), '$.alignment') = 'Dark'\n"
        "  )\n"
        "  SELECT * FROM table\n"
        "  WHERE node_start_id IN (SELECT id FROM matched)\n"
        "  AND node_end_id IN (SELECT id FROM matched)\n"
        "  LIMIT 5000\n"
        "\n"
        "For relationship queries (e.g. 'characters who fought in battles'), don't use the CTE\n"
        "pattern — just filter by relationship type and node types directly.\n"
        "\n"
        "- node_start_properties and node_end_properties are VARIANT columns.\n"
        "  IMPORTANT: Always use get_json_object with CAST to STRING for property access:\n"
        "  Example: get_json_object(CAST(node_start_properties AS STRING), '$.alignment') = 'Dark'\n"
        "  Example: CAST(get_json_object(CAST(node_start_properties AS STRING), '$.age') AS INT) > 30\n"
        "- For natural language queries, use LIKE for fuzzy string matching:\n"
        "  Example: LOWER(get_json_object(CAST(node_start_properties AS STRING), '$.alignment')) LIKE '%dark%'\n"
        "- For Cypher/Gremlin queries, use exact values from the graph data context.\n"
        "- IMPORTANT: Always prefer the actual property values shown in the graph data context.\n"
        "- IMPORTANT: Always SELECT * (all columns). Never select a subset of columns.\n"
        "  The downstream parser requires all columns to extract nodes and edges.\n"
        "- Always add LIMIT 5000 unless the user query already contains a limit.\n"
        "- Return ONLY the SQL query, nothing else — no markdown, no explanation.\n"
    )
+
+ def _get_openai_client(self):
+ """Get an OpenAI-compatible client for the configured LLM endpoint."""
+ from openai import OpenAI
+
+ token = None
+
+ # Explicit token from settings / env (local dev)
+ token = getattr(self.settings, "DATABRICKS_TOKEN", None) or os.environ.get("DATABRICKS_TOKEN")
+ if token:
+ logger.debug("Graph query LLM: using token from settings/environment")
+
+ # Fall back to Databricks SDK OBO
+ if not token:
+ try:
+ from databricks.sdk.core import Config
+ config = Config()
+ headers = config.authenticate()
+ if headers and "Authorization" in headers:
+ auth_header = headers["Authorization"]
+ if auth_header.startswith("Bearer "):
+ token = auth_header[7:]
+ logger.debug("Graph query LLM: using SDK OBO token")
+ except Exception as sdk_err:
+ logger.debug(f"Could not get SDK token: {sdk_err}")
+
+ if not token:
+ raise RuntimeError("No authentication token available for LLM endpoint.")
+
+ base_url = getattr(self.settings, "LLM_BASE_URL", None)
+ if not base_url:
+ host = getattr(self.settings, "DATABRICKS_HOST", "")
+ if host:
+ host = host.rstrip("/")
+ if not host.startswith("http://") and not host.startswith("https://"):
+ host = f"https://{host}"
+ base_url = f"{host}/serving-endpoints"
+
+ if not base_url:
+ raise RuntimeError("LLM_BASE_URL not configured.")
+
+ return OpenAI(api_key=token, base_url=base_url)
+
    def _get_graph_schema(self, ws_client, table_name: str, warehouse_id: str) -> str:
        """Build a concise schema summary from the actual graph data for LLM context.

        Runs a handful of small queries (DESCRIBE, DISTINCT scans, property
        samples) and collapses the results into a text block describing actual
        column types, node types, relationship types, and sample property keys
        with distinct values. Everything is best-effort: any failure is logged
        and whatever lines were gathered so far are returned.

        Returns:
            A newline-joined context string, or a fallback message when no
            information could be gathered.
        """
        safe_table = _validate_table_name(table_name)
        lines: List[str] = []

        try:
            # Step 0: Get actual column types via DESCRIBE (best-effort — some
            # warehouses/tables may not support it, so failures are swallowed).
            try:
                sql = f"DESCRIBE TABLE {safe_table}"
                cols, desc_rows = self._execute_sql(ws_client, sql, warehouse_id)
                col_types = {r[0]: r[1] for r in desc_rows if r[0] and r[1]}
                if col_types:
                    lines.append("Actual column types: " + ", ".join(
                        f"{k} {v}" for k, v in col_types.items()
                    ))
            except Exception:
                pass

            # Step 1: All distinct node types (start and end), capped at 50
            # so the prompt stays small.
            sql = (
                f"SELECT DISTINCT node_start_key FROM {safe_table} "
                f"UNION SELECT DISTINCT node_end_key FROM {safe_table} LIMIT 50"
            )
            _, rows = self._execute_sql(ws_client, sql, warehouse_id)
            node_types = sorted(set(r[0] for r in rows if r[0]))
            if node_types:
                lines.append(f"Node types: {', '.join(node_types)}")

            # Step 2: All distinct relationship types. The LLM prompt tells the
            # model to check for 'EXISTS' here, so call out its absence explicitly.
            sql = f"SELECT DISTINCT relationship FROM {safe_table} LIMIT 50"
            _, rows = self._execute_sql(ws_client, sql, warehouse_id)
            rel_types = sorted(set(r[0] for r in rows if r[0]))
            if rel_types:
                lines.append(f"Relationship types: {', '.join(rel_types)}")
                if "EXISTS" not in rel_types:
                    lines.append("NOTE: No 'EXISTS' relationship found — nodes only appear as edge endpoints, not as standalone rows.")

            # Step 3: Sample properties per node type — use raw rows (no EXISTS
            # filter), at most 5 node types with 10 sample rows each.
            for ntype in node_types[:5]:
                safe_type = _escape_sql_string(ntype)
                sql = (
                    f"SELECT node_start_properties FROM {safe_table} "
                    f"WHERE node_start_key = '{safe_type}' "
                    f"AND node_start_properties IS NOT NULL LIMIT 10"
                )
                _, rows = self._execute_sql(ws_client, sql, warehouse_id)

                # Collect property keys and up to 5 distinct (truncated) values
                # each; keys starting with '_' are internal (e.g. _label).
                key_values: Dict[str, set] = {}
                for row in rows:
                    props = self._parse_props(row[0])
                    for k, v in props.items():
                        if k.startswith("_"):
                            continue
                        if k not in key_values:
                            key_values[k] = set()
                        if v is not None and len(key_values[k]) < 5:
                            key_values[k].add(str(v)[:50])

                if key_values:
                    prop_parts = []
                    for k, vals in list(key_values.items())[:8]:
                        distinct = sorted(vals)[:4]
                        prop_parts.append(f"{k} (e.g. {', '.join(repr(v) for v in distinct)})")
                    lines.append(f"  {ntype} properties: {'; '.join(prop_parts)}")

            # Step 4: If no properties were found, dump one raw row so the
            # model (and a debugging engineer) can see the actual data shape.
            if not any("properties:" in l for l in lines):
                sql = f"SELECT * FROM {safe_table} LIMIT 1"
                cols, rows = self._execute_sql(ws_client, sql, warehouse_id)
                if rows:
                    raw_sample = {cols[i]: rows[0][i] for i in range(len(cols))}
                    lines.append(f"Sample raw row: {json.dumps(raw_sample, default=str)[:500]}")

        except Exception as e:
            # Schema context is an optimization, not a requirement — degrade
            # to whatever was collected rather than failing the query.
            logger.warning(f"Failed to fetch graph schema for LLM context: {e}", exc_info=True)

        return "\n".join(lines) if lines else "No schema information available."
+
    def _translate_to_sql(self, query: str, language: str, table_name: str,
                          graph_schema: str = "") -> str:
        """Translate a natural language, Cypher, or Gremlin query to Databricks SQL via LLM.

        Args:
            query: The user's query text.
            language: 'natural' for free-form requests; any other value
                (e.g. 'cypher', 'gremlin') is passed to the model as-is.
            table_name: Target table; validated before being placed in the prompt.
            graph_schema: Optional context block from _get_graph_schema.

        Returns:
            The SQL string produced by the model, with any markdown code
            fencing stripped.
        """
        client = self._get_openai_client()
        endpoint = getattr(self.settings, "LLM_ENDPOINT", "") or ""

        safe_table = _validate_table_name(table_name)

        schema_block = ""
        if graph_schema:
            schema_block = f"\nGraph data context:\n{graph_schema}\n"

        # Natural-language requests get a different framing than formal
        # Cypher/Gremlin queries, but both share the same system prompt.
        if language == "natural":
            user_message = (
                f"Table: {safe_table}\n"
                f"{schema_block}"
                f"User request (natural language): {query}\n\n"
                "Write a Databricks SQL SELECT statement that answers this request. "
                "Return graph data (nodes and edges) that match the user's intent."
            )
        else:
            user_message = (
                f"Table: {safe_table}\n"
                f"{schema_block}"
                f"Language: {language}\n"
                f"Query: {query}\n\n"
                "Translate this to a Databricks SQL SELECT statement."
            )

        # temperature=0 for determinism; the system prompt forbids prose output.
        response = client.chat.completions.create(
            model=endpoint,
            messages=[
                {"role": "system", "content": self._TRANSLATE_SYSTEM_PROMPT},
                {"role": "user", "content": user_message},
            ],
            max_tokens=2048,
            temperature=0,
        )

        sql = response.choices[0].message.content.strip()
        # Strip markdown fencing if the model wraps it despite instructions
        if sql.startswith("```"):
            sql = re.sub(r"^```(?:sql)?\s*", "", sql)
            sql = re.sub(r"\s*```$", "", sql)
        return sql.strip()
+
+ @staticmethod
+ def _parse_props(raw: Any) -> Dict[str, Any]:
+ """Parse a properties value that may be a VARIANT JSON string, dict, or None."""
+ if raw is None:
+ return {}
+ if isinstance(raw, dict):
+ return raw
+ if isinstance(raw, str):
+ try:
+ parsed = json.loads(raw)
+ return parsed if isinstance(parsed, dict) else {}
+ except (json.JSONDecodeError, TypeError):
+ return {}
+ return {}
+
+ def _parse_query_results(self, columns: List[str], rows: List[List[Any]]) -> Dict[str, Any]:
+ """Parse raw SQL result rows into nodes and edges, mirroring get_graph_data."""
+ col_idx = {name: i for i, name in enumerate(columns)}
+ nodes_map: Dict[str, Dict[str, Any]] = {}
+ edges: List[Dict[str, Any]] = []
+
+ has_edge_columns = "relationship" in col_idx
+
+ for row in rows:
+ start_id = row[col_idx["node_start_id"]] if "node_start_id" in col_idx else None
+ start_key = row[col_idx.get("node_start_key", -1)] if "node_start_key" in col_idx else "Node"
+ relationship = row[col_idx["relationship"]] if "relationship" in col_idx else None
+ end_id = row[col_idx["node_end_id"]] if "node_end_id" in col_idx else None
+ end_key = row[col_idx.get("node_end_key", -1)] if "node_end_key" in col_idx else "Node"
+ start_props_raw = row[col_idx.get("node_start_properties", -1)] if "node_start_properties" in col_idx else None
+ end_props_raw = row[col_idx.get("node_end_properties", -1)] if "node_end_properties" in col_idx else None
+
+ # Parse properties — handles VARIANT (returned as JSON string by Statement Execution API) and dicts
+ start_props = self._parse_props(start_props_raw)
+ end_props = self._parse_props(end_props_raw)
+
+ start_label = start_props.pop("_label", None) or (start_id or "")
+ end_label = end_props.pop("_label", None) or (end_id or "")
+
+ if start_id and start_id not in nodes_map:
+ nodes_map[start_id] = {
+ "id": start_id,
+ "label": start_label,
+ "type": start_key or "Node",
+ "properties": start_props,
+ "status": "existing",
+ }
+
+ if end_id and end_id not in nodes_map:
+ nodes_map[end_id] = {
+ "id": end_id,
+ "label": end_label,
+ "type": end_key or "Node",
+ "properties": end_props,
+ "status": "existing",
+ }
+
+ if relationship and relationship != "EXISTS" and start_id and end_id:
+ edge_id = f"{start_id}-{relationship}-{end_id}"
+ edges.append({
+ "id": edge_id,
+ "source": start_id,
+ "target": end_id,
+ "relationshipType": relationship,
+ "properties": {},
+ "status": "existing",
+ })
+
+ return {
+ "nodes": list(nodes_map.values()),
+ "edges": edges,
+ "hasEdgeColumns": has_edge_columns,
+ }
+
    def execute_graph_query(
        self,
        ws_client,
        warehouse_id: str,
        query: str,
        language: str,
        table_name: str,
        override_sql: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Translate a Cypher/Gremlin query to SQL via LLM, execute it, and
        return graph nodes/edges. If *override_sql* is provided, skip
        translation (and the schema-context fetch) and execute the SQL
        directly.

        Returns:
            A dict matching GraphQueryResponse: on SQL failure it carries
            success=False plus the failing SQL and error message so the
            client can surface/edit it; on success it carries the parsed
            nodes/edges plus execution metadata.
        """
        t0 = time.time()
        endpoint = getattr(self.settings, "LLM_ENDPOINT", "") or ""

        # Step 1: get SQL
        graph_schema = ""
        if override_sql:
            sql = override_sql
        else:
            # Fetch graph schema context so the LLM knows what types/properties exist
            graph_schema = self._get_graph_schema(ws_client, table_name, warehouse_id)
            logger.info(f"Graph schema context:\n{graph_schema}")
            sql = self._translate_to_sql(query, language, table_name, graph_schema=graph_schema)
        logger.info(f"Graph query SQL: {sql}")

        # Step 2: execute SQL. Execution errors are returned to the caller
        # (not raised) so the UI can show the SQL alongside the message.
        try:
            columns, rows = self._execute_sql(ws_client, sql, warehouse_id)
        except RuntimeError as e:
            return {
                "success": False,
                "nodes": [],
                "edges": [],
                "sql": sql,
                "language": language,
                "originalQuery": query,
                "message": str(e),
            }

        # Step 3: parse results
        parsed = self._parse_query_results(columns, rows)
        duration = f"{time.time() - t0:.2f}s"

        # Detect vertex-only queries: the `WITH matched AS` CTE is the pattern
        # the system prompt mandates for node-centric requests, so its presence
        # means the user asked for specific nodes (edges are suppressed below).
        is_vertex_only = bool(re.search(r"(?i)\bWITH\s+matched\s+AS\b", sql))

        return {
            "success": True,
            "nodes": parsed["nodes"],
            "edges": parsed["edges"] if not is_vertex_only else [],
            "sql": sql,
            "language": language,
            "originalQuery": query,
            "rawRowCount": len(rows),
            "hasEdgeColumns": parsed.get("hasEdgeColumns", False),
            "vertexOnly": is_vertex_only,
            "metadata": {
                "source": "databricks",
                "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
                "duration": duration,
                "translationModel": endpoint,
                "graphSchema": graph_schema if graph_schema else None,
            },
        }
diff --git a/src/backend/src/controller/notifications_manager.py b/src/backend/src/controller/notifications_manager.py
index 89e00256..7813989c 100644
--- a/src/backend/src/controller/notifications_manager.py
+++ b/src/backend/src/controller/notifications_manager.py
@@ -63,7 +63,9 @@ def get_notifications(self, db: Session, user_info: Optional[UserInfo] = None) -
# --- Filtering logic (similar to before, but uses API models and SettingsManager) ---
if not user_info:
# Return only broadcast notifications if no user info
- return [n for n in all_notifications_api if not n.recipient]
+ broadcast = [n for n in all_notifications_api if not n.recipient]
+ broadcast.sort(key=lambda x: x.created_at, reverse=True)
+ return broadcast
user_groups = set(user_info.groups or [])
user_email = user_info.email
@@ -276,12 +278,31 @@ async def create_notification(
def delete_notification(self, db: Session, notification_id: str) -> bool:
"""Delete a notification by ID using the repository."""
try:
+ db_obj = self._repo.get(db=db, id=notification_id)
+ if db_obj is None:
+ return False
+ if not db_obj.can_delete:
+ return False
deleted_obj = self._repo.remove(db=db, id=notification_id)
return deleted_obj is not None
except Exception as e:
logger.error(f"Error deleting notification {notification_id}: {e}", exc_info=True)
raise
    def mark_all_as_read(self, db: Session, user_info: UserInfo) -> int:
        """Mark all notifications for a user as read. Returns count of updated notifications.

        Only direct notifications addressed to the user's email are touched;
        broadcast notifications (no recipient) are left alone.

        NOTE(review): scans at most the first 1000 notifications — a user
        could be missed beyond that; confirm whether the repository supports
        recipient-filtered queries before raising the limit.
        """
        if not user_info:
            return 0
        all_notifications = self._repo.get_multi(db=db, limit=1000)
        count = 0
        user_email = user_info.email
        for db_obj in all_notifications:
            # Direct, still-unread notifications for this user only.
            if db_obj.recipient == user_email and not db_obj.read:
                self._repo.update(db=db, db_obj=db_obj, obj_in={"read": True})
                count += 1
        # Single commit after all per-row updates.
        db.commit()
        return count
+
def mark_notification_read(self, db: Session, notification_id: str) -> Optional[Notification]:
"""Mark a notification as read using the repository."""
try:
diff --git a/src/backend/src/controller/security_features_manager.py b/src/backend/src/controller/security_features_manager.py
index 1ed0d36b..b7ce8e14 100644
--- a/src/backend/src/controller/security_features_manager.py
+++ b/src/backend/src/controller/security_features_manager.py
@@ -43,14 +43,14 @@ def list_features(self) -> List[SecurityFeature]:
def update_feature(self, feature_id: str, feature: SecurityFeature) -> Optional[SecurityFeature]:
if feature_id not in self.features:
- logging.warning(f"Security feature not found: {feature_id}")
+ logger.warning(f"Security feature not found: {feature_id}")
return None
self.features[feature_id] = feature
return feature
def delete_feature(self, feature_id: str) -> bool:
if feature_id not in self.features:
- logging.warning(f"Security feature not found: {feature_id}")
+ logger.warning(f"Security feature not found: {feature_id}")
return False
del self.features[feature_id]
return True
@@ -119,5 +119,5 @@ def save_to_yaml(self, yaml_path: Path) -> None:
with open(yaml_path, 'w') as f:
yaml.dump(data, f)
except Exception as e:
- logging.exception(f"Error saving security features to YAML: {e!s}")
+ logger.exception(f"Error saving security features to YAML: {e!s}")
raise
diff --git a/src/backend/src/models/graph_explorer.py b/src/backend/src/models/graph_explorer.py
new file mode 100644
index 00000000..3aae1420
--- /dev/null
+++ b/src/backend/src/models/graph_explorer.py
@@ -0,0 +1,134 @@
+"""
+Pydantic models for Graph Explorer API.
+"""
+
+from enum import Enum
+from typing import Any, Dict, List, Optional
+from pydantic import BaseModel, Field
+
+
class GraphNodeRequest(BaseModel):
    """Request model for creating/updating a node.

    Field names are camelCase to match the frontend wire format.
    """
    id: str  # unique node id, stored in node_start_id/node_end_id columns
    label: str  # display label, persisted inside properties as the `_label` key
    type: str = "Node"  # node type/category, stored in node_start_key/node_end_key
    properties: Dict[str, Any] = Field(default_factory=dict)  # arbitrary node attributes
+
+
class GraphEdgeRequest(BaseModel):
    """Request model for creating/updating an edge."""
    id: Optional[str] = None  # optional client-side edge id; the server derives its own key
    source: str  # id of the start node
    target: str  # id of the end node
    relationshipType: str  # edge type, stored in the `relationship` column
    properties: Dict[str, Any] = Field(default_factory=dict)  # edge attributes (not persisted by the current writer)
+
+
class SaveGraphRequest(BaseModel):
    """Request model for saving new/modified graph data."""
    tableName: str  # fully qualified target table (catalog.schema.table)
    nodes: List[GraphNodeRequest] = Field(default_factory=list)  # nodes to insert
    edges: List[GraphEdgeRequest] = Field(default_factory=list)  # edges to insert
+
+
class DeleteNodeRequest(BaseModel):
    """Request model for deleting a node (removes all rows touching it)."""
    tableName: str  # fully qualified target table
    nodeId: str  # id of the node to delete
+
+
class DeleteEdgeRequest(BaseModel):
    """Request model for deleting an edge, keyed by its full triple."""
    tableName: str  # fully qualified target table
    sourceId: str  # start-node id of the edge
    targetId: str  # end-node id of the edge
    relationshipType: str  # edge type of the row to delete
+
+
class UpdateNodeRequest(BaseModel):
    """Request model for updating a node's label, type, and properties."""
    tableName: str  # fully qualified target table
    nodeId: str  # id of the node to update (unchanged by the update)
    label: str  # new display label (folded into properties as `_label`)
    type: str = "Node"  # new node type
    properties: Dict[str, Any] = Field(default_factory=dict)  # replacement property set
+
+
class NeighborDirection(str, Enum):
    """Direction for neighbor expansion (matches get_neighbors' direction arg)."""
    OUTGOING = "outgoing"  # edges where the node is the start
    INCOMING = "incoming"  # edges where the node is the end
    BOTH = "both"  # either endpoint
+
+
class GraphDataResponse(BaseModel):
    """Response model for graph data."""
    nodes: List[Dict[str, Any]]  # node dicts: id, label, type, properties, status
    edges: List[Dict[str, Any]]  # edge dicts: id, source, target, relationshipType, properties, status
    truncated: bool = False  # True when results were cut off by a safety limit
    totalAvailable: Optional[int] = None  # total matching rows, reported when truncated
+
+
class SaveGraphResponse(BaseModel):
    """Response model for save operation."""
    nodesWritten: int  # number of node rows inserted
    edgesWritten: int  # number of edge rows inserted
+
+
class EnsureTableResponse(BaseModel):
    """Response model for table creation."""
    tableName: str  # the (possibly defaulted) table that now exists
    status: str = "ok"  # outcome marker
+
+
+# ---------------------------------------------------------------------------
+# Graph Query (Cypher / Gremlin → SQL via LLM)
+# ---------------------------------------------------------------------------
+
class GraphQueryRequest(BaseModel):
    """Request model for executing a Cypher/Gremlin graph query."""
    query: str  # query text (Cypher, Gremlin, or natural language)
    language: str = Field(default="cypher", description="Query language: 'cypher' or 'gremlin'")
    tableName: Optional[str] = None  # target edge table; None presumably falls back to the server default — confirm in route handler
    modelEndpoint: Optional[str] = None  # optional LLM serving-endpoint override
    sql: Optional[str] = Field(default=None, description="Override SQL — skip LLM translation and execute directly")
+
+
class GraphQueryResponseMetadata(BaseModel):
    """Metadata about the query execution."""
    source: str = "databricks"  # backing execution engine
    timestamp: Optional[str] = None  # UTC completion time (ISO-8601 with Z suffix)
    duration: Optional[str] = None  # wall-clock duration, e.g. "1.23s"
    translationModel: Optional[str] = None  # LLM endpoint used for translation
    graphSchema: Optional[str] = None  # schema context fed to the LLM, when one was fetched
+
+
class GraphQueryResponse(BaseModel):
    """Response model for a graph query execution."""
    success: bool  # False when SQL execution failed (see `message`)
    nodes: List[Dict[str, Any]] = Field(default_factory=list)  # parsed node dicts
    edges: List[Dict[str, Any]] = Field(default_factory=list)  # parsed edge dicts (empty for vertex-only queries)
    sql: str = ""  # the SQL that was executed (LLM-translated or override)
    language: str = ""  # language the request was made in
    originalQuery: str = ""  # the user's original query text
    rawRowCount: Optional[int] = None  # number of raw SQL rows returned
    hasEdgeColumns: Optional[bool] = None  # whether the result included the `relationship` column
    vertexOnly: Optional[bool] = None  # True when the SQL used the node-centric `matched` CTE pattern
    message: Optional[str] = None  # error message when success is False
    metadata: Optional[GraphQueryResponseMetadata] = None  # execution metadata
+
+
class GraphLimitsResponse(BaseModel):
    """Response model for graph safety limits (defaults mirror the GRAPH_* env settings)."""
    maxNodes: int = 5000  # max nodes returned from an initial graph load
    maxEdges: int = 10000  # max edges returned from an initial graph load
    neighborLimit: int = 50  # default max neighbors per expansion
    queryTimeout: str = "30s"  # SQL statement timeout for graph queries
+
+
class LlmConfigResponse(BaseModel):
    """Response model for LLM configuration status."""
    enabled: bool  # True when an LLM endpoint is configured
    defaultModel: str = ""  # the configured serving endpoint name
    maxTokens: int = 4096  # advertised completion-token budget
    provider: str = "databricks"  # LLM provider identifier
diff --git a/src/backend/src/routes/graph_explorer_routes.py b/src/backend/src/routes/graph_explorer_routes.py
new file mode 100644
index 00000000..984fe5c7
--- /dev/null
+++ b/src/backend/src/routes/graph_explorer_routes.py
@@ -0,0 +1,270 @@
+"""
+FastAPI routes for Graph Explorer.
+
+All operations go through the Databricks Statement Execution API.
+The table name is passed as a query parameter or in the request body.
+"""
+
+from typing import List, Optional
+from fastapi import APIRouter, Depends, FastAPI, HTTPException, Query, Request
+
+from src.common.authorization import PermissionChecker
+from src.common.config import Settings, get_settings
+from src.common.dependencies import GraphExplorerManagerDep
+from src.common.features import FeatureAccessLevel
+from src.common.logging import get_logger
+from src.common.workspace_client import get_workspace_client
+from src.controller.graph_explorer_manager import GraphExplorerManager, DEFAULT_TABLE_NAME
+from src.models.graph_explorer import (
+ DeleteEdgeRequest,
+ DeleteNodeRequest,
+ EnsureTableResponse,
+ GraphDataResponse,
+ GraphLimitsResponse,
+ GraphQueryRequest,
+ GraphQueryResponse,
+ LlmConfigResponse,
+ NeighborDirection,
+ SaveGraphRequest,
+ SaveGraphResponse,
+ UpdateNodeRequest,
+)
+
+logger = get_logger(__name__)
+
+GRAPH_EXPLORER_FEATURE_ID = 'graph-explorer'
+
+router = APIRouter(prefix="/api/graph-explorer", tags=["graph-explorer"])
+
+
+def _get_ws_and_warehouse(settings: Settings):
+ """Get workspace client and warehouse_id."""
+ ws_client = get_workspace_client()
+ warehouse_id = settings.DATABRICKS_WAREHOUSE_ID
+ if not warehouse_id:
+ raise HTTPException(status_code=500, detail="DATABRICKS_WAREHOUSE_ID not configured")
+ return ws_client, warehouse_id
+
+
+@router.get("", response_model=GraphDataResponse)
+async def get_graph_data(
+ table_name: str = Query(default=DEFAULT_TABLE_NAME, alias="tableName"),
+ manager: GraphExplorerManagerDep = None,
+ settings: Settings = Depends(get_settings),
+ _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_ONLY)),
+):
+ """Read graph data from a Databricks table."""
+ try:
+ ws_client, warehouse_id = _get_ws_and_warehouse(settings)
+        # Create the backing table if it is missing so a first-time read succeeds
+ manager.ensure_table_exists(ws_client, table_name, warehouse_id)
+ data = manager.get_graph_data(ws_client, table_name, warehouse_id)
+ return data
+ except ValueError as e:
+ raise HTTPException(status_code=400, detail=str(e))
+ except Exception as e:
+ logger.error(f"Error reading graph data: {e}")
+ raise HTTPException(status_code=500, detail=f"Error reading graph data: {str(e)}")
+
+
+@router.get("/neighbors", response_model=GraphDataResponse)
+async def get_neighbors(
+ node_id: str = Query(..., alias="nodeId"),
+ table_name: str = Query(default=DEFAULT_TABLE_NAME, alias="tableName"),
+ direction: NeighborDirection = Query(default=NeighborDirection.BOTH),
+ edge_types: Optional[List[str]] = Query(default=None, alias="edgeTypes"),
+ limit: int = Query(default=25, ge=1, le=500),
+ offset: int = Query(default=0, ge=0),
+ manager: GraphExplorerManagerDep = None,
+ settings: Settings = Depends(get_settings),
+ _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_ONLY)),
+):
+ """Get the 1-hop neighborhood of a node for incremental expansion."""
+ try:
+ ws_client, warehouse_id = _get_ws_and_warehouse(settings)
+ data = manager.get_neighbors(
+ ws_client, table_name, warehouse_id,
+ node_id=node_id,
+ direction=direction.value,
+ edge_types=edge_types,
+ limit=limit,
+ offset=offset,
+ )
+ return data
+ except ValueError as e:
+ raise HTTPException(status_code=400, detail=str(e))
+ except Exception as e:
+ logger.error(f"Error getting neighbors: {e}")
+ raise HTTPException(status_code=500, detail=f"Error getting neighbors: {str(e)}")
+
+
+@router.post("/save", response_model=SaveGraphResponse)
+async def save_graph_data(
+ request: SaveGraphRequest,
+ manager: GraphExplorerManagerDep = None,
+ settings: Settings = Depends(get_settings),
+ _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_WRITE)),
+):
+ """Write new nodes and edges to a Databricks table."""
+ try:
+ ws_client, warehouse_id = _get_ws_and_warehouse(settings)
+ manager.ensure_table_exists(ws_client, request.tableName, warehouse_id)
+
+ nodes_dicts = [n.model_dump() for n in request.nodes]
+ edges_dicts = [e.model_dump() for e in request.edges]
+
+ result = manager.write_nodes_and_edges(
+ ws_client, request.tableName, warehouse_id, nodes_dicts, edges_dicts,
+ )
+ return result
+ except ValueError as e:
+ raise HTTPException(status_code=400, detail=str(e))
+ except Exception as e:
+ logger.error(f"Error saving graph data: {e}")
+ raise HTTPException(status_code=500, detail=f"Error saving graph data: {str(e)}")
+
+
+@router.post("/ensure-table", response_model=EnsureTableResponse)
+async def ensure_table(
+ table_name: str = Query(default=DEFAULT_TABLE_NAME, alias="tableName"),
+ manager: GraphExplorerManagerDep = None,
+ settings: Settings = Depends(get_settings),
+ _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_WRITE)),
+):
+ """Ensure the graph table exists, creating it if necessary."""
+ try:
+ ws_client, warehouse_id = _get_ws_and_warehouse(settings)
+ manager.ensure_table_exists(ws_client, table_name, warehouse_id)
+ return {"tableName": table_name, "status": "ok"}
+ except ValueError as e:
+ raise HTTPException(status_code=400, detail=str(e))
+ except Exception as e:
+ logger.error(f"Error ensuring table: {e}")
+ raise HTTPException(status_code=500, detail=f"Error ensuring table: {str(e)}")
+
+
+@router.delete("/node")
+async def delete_node(
+ request: DeleteNodeRequest,
+ manager: GraphExplorerManagerDep = None,
+ settings: Settings = Depends(get_settings),
+ _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_WRITE)),
+):
+ """Delete a node and its connected edges from the Databricks table."""
+ try:
+ ws_client, warehouse_id = _get_ws_and_warehouse(settings)
+ manager.delete_node(ws_client, request.tableName, warehouse_id, request.nodeId)
+ return {"status": "deleted", "nodeId": request.nodeId}
+ except ValueError as e:
+ raise HTTPException(status_code=400, detail=str(e))
+ except Exception as e:
+ logger.error(f"Error deleting node: {e}")
+ raise HTTPException(status_code=500, detail=f"Error deleting node: {str(e)}")
+
+
+@router.delete("/edge")
+async def delete_edge(
+ request: DeleteEdgeRequest,
+ manager: GraphExplorerManagerDep = None,
+ settings: Settings = Depends(get_settings),
+ _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_WRITE)),
+):
+ """Delete an edge from the Databricks table."""
+ try:
+ ws_client, warehouse_id = _get_ws_and_warehouse(settings)
+ manager.delete_edge(
+ ws_client, request.tableName, warehouse_id,
+ request.sourceId, request.targetId, request.relationshipType,
+ )
+ return {"status": "deleted"}
+ except ValueError as e:
+ raise HTTPException(status_code=400, detail=str(e))
+ except Exception as e:
+ logger.error(f"Error deleting edge: {e}")
+ raise HTTPException(status_code=500, detail=f"Error deleting edge: {str(e)}")
+
+
+@router.put("/node")
+async def update_node(
+ request: UpdateNodeRequest,
+ manager: GraphExplorerManagerDep = None,
+ settings: Settings = Depends(get_settings),
+ _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_WRITE)),
+):
+ """Update a node in the Databricks table."""
+ try:
+ ws_client, warehouse_id = _get_ws_and_warehouse(settings)
+ manager.update_node(
+ ws_client, request.tableName, warehouse_id,
+ request.nodeId, request.label, request.type, request.properties,
+ )
+ return {"status": "updated", "nodeId": request.nodeId}
+ except ValueError as e:
+ raise HTTPException(status_code=400, detail=str(e))
+ except Exception as e:
+ logger.error(f"Error updating node: {e}")
+ raise HTTPException(status_code=500, detail=f"Error updating node: {str(e)}")
+
+
+@router.get("/limits", response_model=GraphLimitsResponse)
+async def get_graph_limits(
+ settings: Settings = Depends(get_settings),
+ _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_ONLY)),
+):
+ """Return the current server-side safety limits for graph operations."""
+ return GraphLimitsResponse(
+ maxNodes=settings.GRAPH_MAX_NODES,
+ maxEdges=settings.GRAPH_MAX_EDGES,
+ neighborLimit=settings.GRAPH_NEIGHBOR_LIMIT,
+ queryTimeout=settings.GRAPH_QUERY_TIMEOUT,
+ )
+
+
+@router.get("/llm-config", response_model=LlmConfigResponse)
+async def get_llm_config(
+ manager: GraphExplorerManagerDep = None,
+ _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_ONLY)),
+):
+ """Return the current LLM configuration status for the query panel."""
+ return manager.get_llm_config()
+
+
+@router.post("/query", response_model=GraphQueryResponse)
+async def execute_graph_query(
+ request: GraphQueryRequest,
+ manager: GraphExplorerManagerDep = None,
+ settings: Settings = Depends(get_settings),
+ _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_WRITE)),
+):
+ """Translate a Cypher/Gremlin query to SQL via LLM and execute it."""
+ try:
+ ws_client, warehouse_id = _get_ws_and_warehouse(settings)
+ table_name = request.tableName or DEFAULT_TABLE_NAME
+ result = manager.execute_graph_query(
+ ws_client=ws_client,
+ warehouse_id=warehouse_id,
+ query=request.query,
+ language=request.language,
+ table_name=table_name,
+ override_sql=request.sql,
+ )
+ return result
+ except RuntimeError as e:
+ # LLM or SQL execution error — return as a structured error, not 500
+ return GraphQueryResponse(
+ success=False,
+ sql="",
+ language=request.language,
+ originalQuery=request.query,
+ message=str(e),
+ )
+ except ValueError as e:
+ raise HTTPException(status_code=400, detail=str(e))
+ except Exception as e:
+ logger.error(f"Error executing graph query: {e}")
+ raise HTTPException(status_code=500, detail=f"Error executing graph query: {str(e)}")
+
+
+def register_routes(app: FastAPI):
+ """Register graph explorer routes with the FastAPI app."""
+ app.include_router(router)
diff --git a/src/backend/src/tests/unit/test_audit_manager.py b/src/backend/src/tests/unit/test_audit_manager.py
index 3435b633..2f524396 100644
--- a/src/backend/src/tests/unit/test_audit_manager.py
+++ b/src/backend/src/tests/unit/test_audit_manager.py
@@ -1,6 +1,8 @@
"""
Unit tests for AuditManager
"""
+import uuid
+
import pytest
from datetime import datetime
from unittest.mock import Mock, MagicMock, patch
@@ -41,7 +43,7 @@ def manager(self, mock_repo, mock_settings, mock_db_session):
def sample_audit_log_db(self):
"""Sample audit log database object."""
log = Mock(spec=AuditLogDb)
- log.id = 1
+ log.id = uuid.uuid4()
log.username = "user@example.com"
log.ip_address = "192.168.1.1"
log.feature = "data-products"
diff --git a/src/backend/src/tests/unit/test_comments_manager.py b/src/backend/src/tests/unit/test_comments_manager.py
index 507cb1e1..658cb866 100644
--- a/src/backend/src/tests/unit/test_comments_manager.py
+++ b/src/backend/src/tests/unit/test_comments_manager.py
@@ -3,11 +3,12 @@
"""
import pytest
import json
+import uuid
from datetime import datetime
from unittest.mock import Mock, MagicMock
from src.controller.comments_manager import CommentsManager
from src.models.comments import CommentCreate, CommentUpdate, Comment
-from src.db_models.comments import CommentDb, CommentStatus
+from src.db_models.comments import CommentDb, CommentStatus, CommentType as DbCommentType
class TestCommentsManager:
@@ -27,13 +28,16 @@ def manager(self, mock_repository):
def sample_comment_db(self):
"""Sample comment database object."""
comment = Mock(spec=CommentDb)
- comment.id = "comment-123"
+ comment.id = uuid.UUID("11111111-2222-3333-4444-555555555555")
comment.entity_id = "entity-456"
comment.entity_type = "data_product"
comment.title = "Test Comment"
comment.comment = "This is a test comment"
comment.audience = None
+ comment.project_id = None
comment.status = CommentStatus.ACTIVE
+ comment.comment_type = DbCommentType.COMMENT
+ comment.rating = None
comment.created_by = "user@example.com"
comment.updated_by = "user@example.com"
comment.created_at = datetime(2024, 1, 1, 0, 0, 0)
@@ -49,6 +53,7 @@ def sample_comment_create(self):
title="Test Comment",
comment="This is a test comment",
audience=None,
+ project_id="project-1", # Project-scoped; global (project_id=None) requires admin
)
# Initialization Tests
@@ -82,7 +87,7 @@ def test_db_to_api_model(self, manager, sample_comment_db):
"""Test converting database model to API model."""
result = manager._db_to_api_model(sample_comment_db)
assert isinstance(result, Comment)
- assert result.id == "comment-123"
+ assert result.id == uuid.UUID("11111111-2222-3333-4444-555555555555")
assert result.title == "Test Comment"
# Create Comment Tests
@@ -95,7 +100,7 @@ def test_create_comment_success(self, manager, mock_repository, sample_comment_c
result = manager.create_comment(db_session, data=sample_comment_create, user_email="user@example.com")
assert isinstance(result, Comment)
- assert result.id == "comment-123"
+ assert result.id == uuid.UUID("11111111-2222-3333-4444-555555555555")
mock_repository.create_with_audience.assert_called_once()
db_session.commit.assert_called_once()
@@ -107,6 +112,7 @@ def test_create_comment_with_audience(self, manager, mock_repository, sample_com
title="Restricted Comment",
comment="For admins only",
audience=["admins", "data_stewards"],
+ project_id="project-1", # Project-scoped; global (project_id=None) requires admin
)
sample_comment_db.audience = '["admins", "data_stewards"]'
mock_repository.create_with_audience.return_value = sample_comment_db
@@ -114,7 +120,7 @@ def test_create_comment_with_audience(self, manager, mock_repository, sample_com
db_session = Mock()
result = manager.create_comment(db_session, data=comment_data, user_email="admin@example.com")
- assert result.id == "comment-123"
+ assert result.id == uuid.UUID("11111111-2222-3333-4444-555555555555")
mock_repository.create_with_audience.assert_called_once_with(
db_session, obj_in=comment_data, created_by="admin@example.com"
)
@@ -146,18 +152,21 @@ def test_list_comments_with_results(self, manager, mock_repository, sample_comme
assert result.total_count == 1
assert result.visible_count == 1
assert len(result.comments) == 1
- assert result.comments[0].id == "comment-123"
+ assert result.comments[0].id == uuid.UUID("11111111-2222-3333-4444-555555555555")
def test_list_comments_filtered_by_groups(self, manager, mock_repository, sample_comment_db):
"""Test listing comments filtered by user groups."""
comment2 = Mock(spec=CommentDb)
- comment2.id = "comment-456"
+ comment2.id = uuid.UUID("66666666-7777-8888-9999-000000000000")
comment2.entity_id = "entity-456"
comment2.entity_type = "data_product"
comment2.title = "Another Comment"
comment2.comment = "Another test"
comment2.audience = '["admins"]'
+ comment2.project_id = None
comment2.status = CommentStatus.ACTIVE
+ comment2.comment_type = DbCommentType.COMMENT
+ comment2.rating = None
comment2.created_by = "admin@example.com"
comment2.updated_by = "admin@example.com"
comment2.created_at = datetime(2024, 1, 1, 0, 0, 0)
@@ -201,7 +210,10 @@ def test_list_comments_include_deleted(self, manager, mock_repository, sample_co
def test_update_comment_success(self, manager, mock_repository, sample_comment_db):
"""Test updating a comment."""
updated_comment = Mock(spec=CommentDb)
- updated_comment.__dict__.update(sample_comment_db.__dict__)
+ for attr in ("id", "entity_id", "entity_type", "comment", "audience", "project_id",
+ "status", "comment_type", "rating", "created_by", "updated_by",
+ "created_at", "updated_at"):
+ setattr(updated_comment, attr, getattr(sample_comment_db, attr))
updated_comment.title = "Updated Title"
mock_repository.get.return_value = sample_comment_db
@@ -212,7 +224,7 @@ def test_update_comment_success(self, manager, mock_repository, sample_comment_d
update_data = CommentUpdate(title="Updated Title")
result = manager.update_comment(
db_session,
- comment_id="comment-123",
+ comment_id="11111111-2222-3333-4444-555555555555",
data=update_data,
user_email="user@example.com",
)
@@ -284,7 +296,7 @@ def test_delete_comment_soft_success(self, manager, mock_repository, sample_comm
db_session = Mock()
result = manager.delete_comment(
- db_session, comment_id="comment-123", user_email="user@example.com"
+ db_session, comment_id="11111111-2222-3333-4444-555555555555", user_email="user@example.com"
)
assert result is True
@@ -324,7 +336,7 @@ def test_delete_comment_permission_denied(self, manager, mock_repository, sample
db_session = Mock()
result = manager.delete_comment(
- db_session, comment_id="comment-123", user_email="other@example.com"
+ db_session, comment_id="11111111-2222-3333-4444-555555555555", user_email="other@example.com"
)
assert result is False
@@ -336,10 +348,10 @@ def test_get_comment_success(self, manager, mock_repository, sample_comment_db):
mock_repository.get.return_value = sample_comment_db
db_session = Mock()
- result = manager.get_comment(db_session, comment_id="comment-123")
+ result = manager.get_comment(db_session, comment_id="11111111-2222-3333-4444-555555555555")
assert result is not None
- assert result.id == "comment-123"
+ assert result.id == uuid.UUID("11111111-2222-3333-4444-555555555555")
def test_get_comment_not_found(self, manager, mock_repository):
"""Test getting non-existent comment."""
@@ -359,7 +371,7 @@ def test_can_user_modify_comment_true(self, manager, mock_repository, sample_com
db_session = Mock()
result = manager.can_user_modify_comment(
- db_session, comment_id="comment-123", user_email="user@example.com"
+ db_session, comment_id="11111111-2222-3333-4444-555555555555", user_email="user@example.com"
)
assert result is True
@@ -371,7 +383,7 @@ def test_can_user_modify_comment_false(self, manager, mock_repository, sample_co
db_session = Mock()
result = manager.can_user_modify_comment(
- db_session, comment_id="comment-123", user_email="other@example.com"
+ db_session, comment_id="11111111-2222-3333-4444-555555555555", user_email="other@example.com"
)
assert result is False
diff --git a/src/backend/src/tests/unit/test_costs_manager.py b/src/backend/src/tests/unit/test_costs_manager.py
index 3bf0903f..766f2896 100644
--- a/src/backend/src/tests/unit/test_costs_manager.py
+++ b/src/backend/src/tests/unit/test_costs_manager.py
@@ -2,12 +2,41 @@
Unit tests for CostsManager
"""
import pytest
-from datetime import date
+from datetime import date, datetime, timezone
+from uuid import uuid4, UUID
from unittest.mock import Mock
from src.controller.costs_manager import CostsManager
from src.models.costs import CostItemCreate, CostItemUpdate, CostItem
from src.db_models.costs import CostItemDb
+# Fixed UUID for deterministic test assertions
+SAMPLE_COST_UUID = UUID("a1b2c3d4-e5f6-7890-abcd-ef1234567890")
+
+
+def _make_cost_item_db(
+ id_val=None,
+ amount_cents=150000,
+ cost_center="INFRASTRUCTURE",
+):
+ """Build a mock CostItemDb with properly typed values for Pydantic validation."""
+ cost_item = Mock(spec=CostItemDb)
+ cost_item.id = id_val or uuid4()
+ cost_item.entity_type = "data_product"
+ cost_item.entity_id = "product-456"
+ cost_item.title = "Storage Costs"
+ cost_item.description = "Monthly storage costs"
+ cost_item.cost_center = cost_center
+ cost_item.custom_center_name = None
+ cost_item.amount_cents = amount_cents
+ cost_item.currency = "USD"
+ cost_item.start_month = date(2024, 1, 1)
+ cost_item.end_month = None
+ cost_item.created_by = None
+ cost_item.updated_by = None
+ cost_item.created_at = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc)
+ cost_item.updated_at = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc)
+ return cost_item
+
class TestCostsManager:
"""Test suite for CostsManager."""
@@ -25,18 +54,7 @@ def manager(self, mock_repository):
@pytest.fixture
def sample_cost_item_db(self):
"""Sample cost item database object."""
- cost_item = Mock(spec=CostItemDb)
- cost_item.id = "cost-123"
- cost_item.entity_type = "data_product"
- cost_item.entity_id = "product-456"
- cost_item.title = "Storage Costs"
- cost_item.description = "Monthly storage costs"
- cost_item.cost_center = "infrastructure"
- cost_item.amount_cents = 150000
- cost_item.currency = "USD"
- cost_item.start_month = date(2024, 1, 1)
- cost_item.end_month = None
- return cost_item
+ return _make_cost_item_db(id_val=SAMPLE_COST_UUID, amount_cents=150000)
@pytest.fixture
def sample_cost_create(self):
@@ -69,7 +87,7 @@ def test_create_cost_item_success(self, manager, mock_repository, sample_cost_cr
result = manager.create(mock_db, data=sample_cost_create, user_email="user@example.com")
assert isinstance(result, CostItem)
- assert result.id == "cost-123"
+ assert result.id == SAMPLE_COST_UUID
assert result.amount_cents == 150000
mock_repository.create.assert_called_once()
assert mock_db.commit.called
@@ -107,7 +125,7 @@ def test_list_cost_items_with_results(self, manager, mock_repository, sample_cos
result = manager.list(mock_db, entity_type="data_product", entity_id="product-456")
assert len(result) == 1
- assert result[0].id == "cost-123"
+ assert result[0].id == SAMPLE_COST_UUID
assert result[0].amount_cents == 150000
def test_list_cost_items_filtered_by_month(self, manager, mock_repository, sample_cost_item_db):
@@ -130,13 +148,9 @@ def test_list_cost_items_filtered_by_month(self, manager, mock_repository, sampl
def test_update_cost_item_success(self, manager, mock_repository, sample_cost_item_db):
"""Test updating a cost item."""
# Create updated item with same attributes but changed amount
- updated_item = Mock(spec=CostItemDb)
- for key, value in sample_cost_item_db.__dict__.items():
- if not key.startswith('_'):
- setattr(updated_item, key, value)
- updated_item.amount_cents = 200000
- updated_item.entity_type = sample_cost_item_db.entity_type
- updated_item.entity_id = sample_cost_item_db.entity_id
+ updated_item = _make_cost_item_db(
+ id_val=SAMPLE_COST_UUID, amount_cents=200000
+ )
mock_repository.get.return_value = sample_cost_item_db
mock_repository.update.return_value = updated_item
@@ -161,12 +175,9 @@ def test_update_cost_item_not_found(self, manager, mock_repository):
def test_update_cost_item_logs_change(self, manager, mock_repository, sample_cost_item_db):
"""Test that updating a cost item logs a change."""
- updated_item = Mock(spec=CostItemDb)
- for key, value in sample_cost_item_db.__dict__.items():
- if not key.startswith('_'):
- setattr(updated_item, key, value)
- updated_item.entity_type = sample_cost_item_db.entity_type
- updated_item.entity_id = sample_cost_item_db.entity_id
+ updated_item = _make_cost_item_db(
+ id_val=SAMPLE_COST_UUID, amount_cents=200000
+ )
mock_repository.get.return_value = sample_cost_item_db
mock_repository.update.return_value = updated_item
diff --git a/src/backend/src/tests/unit/test_data_products_manager.py b/src/backend/src/tests/unit/test_data_products_manager.py
index a8c29339..97be055b 100644
--- a/src/backend/src/tests/unit/test_data_products_manager.py
+++ b/src/backend/src/tests/unit/test_data_products_manager.py
@@ -134,10 +134,11 @@ def test_create_product_with_tags(
def test_create_product_validation_error(self, manager, db_session):
"""Test that invalid product data raises ValueError."""
- # Arrange
+ # Arrange - Use data that fails Pydantic validation (id must be str, not int)
invalid_data = {
- "name": "", # Empty name should fail
- "version": "invalid", # Invalid version format might fail
+ "id": 12345, # Invalid type - Pydantic expects str
+ "name": "Test",
+ "version": "1.0.0",
}
# Act & Assert
@@ -290,9 +291,10 @@ def test_update_product_validation_error(
"""Test that invalid update data raises ValueError."""
# Arrange
created = manager.create_product(sample_product_data, db=db_session)
+ # Use data that fails Pydantic validation (status must be str, not int)
invalid_update = {
"id": created.id,
- "status": "invalid-status", # Invalid status
+ "status": 123, # Invalid type - Pydantic expects str
}
# Act & Assert
@@ -489,13 +491,17 @@ def test_get_distinct_product_types(self, manager):
def test_get_distinct_owners(self, manager, db_session):
"""Test retrieving distinct owners."""
- # Arrange - Create products with different owners
+ # Arrange - Create products with team members (ODPS stores owners in team.members)
for i, owner in enumerate(["owner1@test.com", "owner2@test.com"]):
product_data = {
"name": f"Product {i}",
"version": "1.0.0",
"productType": "sourceAligned",
- "owner": owner,
+ "team": {
+ "members": [
+ {"username": owner, "role": "owner", "name": owner}
+ ]
+ },
}
manager.create_product(product_data, db=db_session)
@@ -562,12 +568,13 @@ def test_create_product_db_error_handling(
self, manager, db_session, sample_product_data
):
"""Test graceful handling of database errors during creation."""
- # Arrange - Force DB error by closing session
- db_session.close()
+ # Arrange - Patch repository to simulate DB error
+ from sqlalchemy.exc import SQLAlchemyError
- # Act & Assert
- with pytest.raises(Exception): # SQLAlchemy error
- manager.create_product(sample_product_data, db=db_session)
+ with patch.object(manager._repo, 'create', side_effect=SQLAlchemyError("DB connection failed")):
+ # Act & Assert
+ with pytest.raises(SQLAlchemyError):
+ manager.create_product(sample_product_data, db=db_session)
def test_manager_without_workspace_client(self, db_session):
"""Test manager initialization without WorkspaceClient."""
diff --git a/src/backend/src/tests/unit/test_graph_explorer_manager.py b/src/backend/src/tests/unit/test_graph_explorer_manager.py
new file mode 100644
index 00000000..394aa212
--- /dev/null
+++ b/src/backend/src/tests/unit/test_graph_explorer_manager.py
@@ -0,0 +1,489 @@
+"""
+Unit tests for GraphExplorerManager.
+
+Covers table-name validation, SQL string escaping, and get_neighbors() against a mocked Databricks Statement Execution API.
+"""
+
+import json
+import pytest
+from unittest.mock import MagicMock, patch, call
+
+from src.controller.graph_explorer_manager import GraphExplorerManager, _validate_table_name, _escape_sql_string
+
+
+class TestValidateTableName:
+ """Test table name validation and sanitization."""
+
+ def test_valid_three_part_name(self):
+ result = _validate_table_name("main.default.my_table")
+ assert result == "`main`.`default`.`my_table`"
+
+ def test_valid_single_part(self):
+ result = _validate_table_name("my_table")
+ assert result == "`my_table`"
+
+ def test_rejects_sql_injection(self):
+ with pytest.raises(ValueError, match="Invalid table name"):
+ _validate_table_name("main; DROP TABLE users --")
+
+ def test_strips_whitespace(self):
+ result = _validate_table_name(" main.default.t ")
+ assert result == "`main`.`default`.`t`"
+
+
+class TestEscapeSqlString:
+ """Test SQL string escaping."""
+
+ def test_escapes_single_quotes(self):
+ assert _escape_sql_string("O'Brien") == "O''Brien"
+
+ def test_no_change_for_safe_strings(self):
+ assert _escape_sql_string("hello") == "hello"
+
+
+class TestGraphExplorerManager:
+ """Unit tests for GraphExplorerManager business logic."""
+
+ @pytest.fixture
+ def manager(self):
+ settings = MagicMock()
+ settings.DATABRICKS_WAREHOUSE_ID = "test-warehouse"
+ settings.LLM_ENDPOINT = ""
+ settings.GRAPH_MAX_EDGES = 10000
+ settings.GRAPH_MAX_NODES = 5000
+ settings.GRAPH_QUERY_TIMEOUT = "30s"
+ settings.GRAPH_NEIGHBOR_LIMIT = 50
+ return GraphExplorerManager(settings=settings)
+
+ @pytest.fixture
+ def mock_ws_client(self):
+ return MagicMock()
+
+ def _make_sql_result(self, columns, rows):
+ """Helper to create a mock Statement Execution API result."""
+ mock_result = MagicMock()
+ mock_result.status.state = "SUCCEEDED"
+
+ # Build column schema
+ col_objects = []
+ for col_name in columns:
+ col_obj = MagicMock()
+ col_obj.name = col_name
+ col_objects.append(col_obj)
+
+ mock_result.manifest.schema.columns = col_objects
+ mock_result.result.data_array = rows
+
+ return mock_result
+
+ # ---------------------------------------------------------------
+ # get_neighbors tests
+ # ---------------------------------------------------------------
+
+ def test_get_neighbors_both_directions(self, manager, mock_ws_client):
+ """Test expanding all neighbors of a node."""
+ columns = [
+ "node_start_id", "node_start_key", "relationship",
+ "node_end_id", "node_end_key",
+ "node_start_properties", "node_end_properties",
+ ]
+
+ # First call: COUNT query
+ count_result = self._make_sql_result(["count"], [["2"]])
+ # Second call: SELECT query
+ data_result = self._make_sql_result(columns, [
+ ["alice", "Person", "KNOWS", "bob", "Person",
+ json.dumps({"_label": "Alice"}), json.dumps({"_label": "Bob"})],
+ ["alice", "Person", "WORKS_AT", "acme", "Company",
+ json.dumps({"_label": "Alice"}), json.dumps({"_label": "Acme"})],
+ ])
+
+ mock_ws_client.statement_execution.execute_statement.side_effect = [
+ count_result, data_result
+ ]
+
+ result = manager.get_neighbors(
+ mock_ws_client, "main.default.test_graph", "wh-1",
+ node_id="alice", direction="both", limit=25,
+ )
+
+ assert len(result["nodes"]) == 3 # alice, bob, acme
+ assert len(result["edges"]) == 2 # KNOWS, WORKS_AT
+ assert result["truncated"] is False
+ assert result["totalAvailable"] == 2
+
+ # Verify SQL contains the right direction filter
+ calls = mock_ws_client.statement_execution.execute_statement.call_args_list
+ count_sql = calls[0].kwargs["statement"]
+ assert "node_start_id = 'alice' OR node_end_id = 'alice'" in count_sql
+ assert "relationship != 'EXISTS'" in count_sql
+
+ def test_get_neighbors_outgoing_only(self, manager, mock_ws_client):
+ """Test expanding only outgoing edges."""
+ columns = [
+ "node_start_id", "node_start_key", "relationship",
+ "node_end_id", "node_end_key",
+ "node_start_properties", "node_end_properties",
+ ]
+
+ count_result = self._make_sql_result(["count"], [["1"]])
+ data_result = self._make_sql_result(columns, [
+ ["alice", "Person", "KNOWS", "bob", "Person",
+ json.dumps({"_label": "Alice"}), json.dumps({"_label": "Bob"})],
+ ])
+
+ mock_ws_client.statement_execution.execute_statement.side_effect = [
+ count_result, data_result
+ ]
+
+ result = manager.get_neighbors(
+ mock_ws_client, "main.default.test_graph", "wh-1",
+ node_id="alice", direction="outgoing", limit=25,
+ )
+
+ assert len(result["nodes"]) == 2 # alice, bob
+ assert len(result["edges"]) == 1
+
+ # Verify direction filter
+ calls = mock_ws_client.statement_execution.execute_statement.call_args_list
+ count_sql = calls[0].kwargs["statement"]
+ assert "node_start_id = 'alice'" in count_sql
+ assert "OR" not in count_sql
+
+ def test_get_neighbors_incoming_only(self, manager, mock_ws_client):
+ """Test expanding only incoming edges."""
+ columns = [
+ "node_start_id", "node_start_key", "relationship",
+ "node_end_id", "node_end_key",
+ "node_start_properties", "node_end_properties",
+ ]
+
+ count_result = self._make_sql_result(["count"], [["1"]])
+ data_result = self._make_sql_result(columns, [
+ ["bob", "Person", "KNOWS", "alice", "Person",
+ json.dumps({"_label": "Bob"}), json.dumps({"_label": "Alice"})],
+ ])
+
+ mock_ws_client.statement_execution.execute_statement.side_effect = [
+ count_result, data_result
+ ]
+
+ result = manager.get_neighbors(
+ mock_ws_client, "main.default.test_graph", "wh-1",
+ node_id="alice", direction="incoming", limit=25,
+ )
+
+ assert len(result["nodes"]) == 2
+ calls = mock_ws_client.statement_execution.execute_statement.call_args_list
+ count_sql = calls[0].kwargs["statement"]
+ assert "node_end_id = 'alice'" in count_sql
+ assert "node_start_id" not in count_sql
+
+ def test_get_neighbors_with_edge_type_filter(self, manager, mock_ws_client):
+ """Test filtering by specific edge types."""
+ columns = [
+ "node_start_id", "node_start_key", "relationship",
+ "node_end_id", "node_end_key",
+ "node_start_properties", "node_end_properties",
+ ]
+
+ count_result = self._make_sql_result(["count"], [["1"]])
+ data_result = self._make_sql_result(columns, [
+ ["alice", "Person", "KNOWS", "bob", "Person",
+ json.dumps({"_label": "Alice"}), json.dumps({"_label": "Bob"})],
+ ])
+
+ mock_ws_client.statement_execution.execute_statement.side_effect = [
+ count_result, data_result
+ ]
+
+ result = manager.get_neighbors(
+ mock_ws_client, "main.default.test_graph", "wh-1",
+ node_id="alice", direction="both",
+ edge_types=["KNOWS"], limit=25,
+ )
+
+ assert len(result["edges"]) == 1
+ assert result["edges"][0]["relationshipType"] == "KNOWS"
+
+ # Verify edge type filter in SQL
+ calls = mock_ws_client.statement_execution.execute_statement.call_args_list
+ count_sql = calls[0].kwargs["statement"]
+ assert "relationship IN ('KNOWS')" in count_sql
+
+ def test_get_neighbors_truncated(self, manager, mock_ws_client):
+ """Test truncation when more neighbors exist than the limit."""
+ columns = [
+ "node_start_id", "node_start_key", "relationship",
+ "node_end_id", "node_end_key",
+ "node_start_properties", "node_end_properties",
+ ]
+
+ # COUNT returns 50, but we're requesting limit=2
+ count_result = self._make_sql_result(["count"], [["50"]])
+ data_result = self._make_sql_result(columns, [
+ ["alice", "Person", "KNOWS", "bob", "Person", None, None],
+ ["alice", "Person", "KNOWS", "carol", "Person", None, None],
+ ])
+
+ mock_ws_client.statement_execution.execute_statement.side_effect = [
+ count_result, data_result
+ ]
+
+ result = manager.get_neighbors(
+ mock_ws_client, "main.default.test_graph", "wh-1",
+ node_id="alice", direction="both", limit=2,
+ )
+
+ assert result["truncated"] is True
+ assert result["totalAvailable"] == 50
+ assert len(result["edges"]) == 2
+
+ def test_get_neighbors_empty_result(self, manager, mock_ws_client):
+ """Test expanding a node with no neighbors."""
+ count_result = self._make_sql_result(["count"], [["0"]])
+ data_result = self._make_sql_result([], [])
+ # Override to handle empty columns
+ data_result.manifest.schema.columns = []
+ data_result.result.data_array = []
+
+ mock_ws_client.statement_execution.execute_statement.side_effect = [
+ count_result, data_result
+ ]
+
+ result = manager.get_neighbors(
+ mock_ws_client, "main.default.test_graph", "wh-1",
+ node_id="lonely_node", direction="both", limit=25,
+ )
+
+ assert len(result["nodes"]) == 0
+ assert len(result["edges"]) == 0
+ assert result["truncated"] is False
+ assert result["totalAvailable"] == 0
+
+ def test_get_neighbors_deduplicates_nodes(self, manager, mock_ws_client):
+ """Test that nodes appearing in multiple edges are not duplicated."""
+ columns = [
+ "node_start_id", "node_start_key", "relationship",
+ "node_end_id", "node_end_key",
+ "node_start_properties", "node_end_properties",
+ ]
+
+ count_result = self._make_sql_result(["count"], [["2"]])
+ # alice appears as start node in both rows
+ data_result = self._make_sql_result(columns, [
+ ["alice", "Person", "KNOWS", "bob", "Person", None, None],
+ ["alice", "Person", "LIKES", "bob", "Person", None, None],
+ ])
+
+ mock_ws_client.statement_execution.execute_statement.side_effect = [
+ count_result, data_result
+ ]
+
+ result = manager.get_neighbors(
+ mock_ws_client, "main.default.test_graph", "wh-1",
+ node_id="alice", direction="both", limit=25,
+ )
+
+ # Only 2 unique nodes despite appearing in multiple rows
+ assert len(result["nodes"]) == 2
+ node_ids = {n["id"] for n in result["nodes"]}
+ assert node_ids == {"alice", "bob"}
+ # But both edges are present
+ assert len(result["edges"]) == 2
+
+ def test_get_neighbors_sql_injection_protection(self, manager, mock_ws_client):
+ """Test that node IDs with special characters are escaped."""
+ columns = [
+ "node_start_id", "node_start_key", "relationship",
+ "node_end_id", "node_end_key",
+ "node_start_properties", "node_end_properties",
+ ]
+
+ count_result = self._make_sql_result(["count"], [["0"]])
+ data_result = self._make_sql_result(columns, [])
+ data_result.result.data_array = []
+
+ mock_ws_client.statement_execution.execute_statement.side_effect = [
+ count_result, data_result
+ ]
+
+ # Node ID with a single quote (SQL injection attempt)
+ result = manager.get_neighbors(
+ mock_ws_client, "main.default.test_graph", "wh-1",
+ node_id="alice'; DROP TABLE users; --", direction="both", limit=25,
+ )
+
+ # Verify the escaped string is in the SQL
+ calls = mock_ws_client.statement_execution.execute_statement.call_args_list
+ count_sql = calls[0].kwargs["statement"]
+ assert "alice''; DROP TABLE users; --" in count_sql # single quote doubled (standard SQL escaping)
+
+ # ---------------------------------------------------------------
+ # get_graph_data tests
+ # ---------------------------------------------------------------
+
+ def test_get_graph_data_basic(self, manager, mock_ws_client):
+ """Test basic graph data loading."""
+ columns = [
+ "node_start_id", "node_start_key", "relationship",
+ "node_end_id", "node_end_key",
+ "node_start_properties", "node_end_properties",
+ ]
+
+ count_result = self._make_sql_result(["count"], [["2"]])
+ data_result = self._make_sql_result(columns, [
+ ["alice", "Person", "KNOWS", "bob", "Person",
+ json.dumps({"_label": "Alice", "age": "30"}),
+ json.dumps({"_label": "Bob"})],
+ ["alice", "Person", "EXISTS", "alice", "Person",
+ json.dumps({"_label": "Alice"}),
+ json.dumps({"_label": "Alice"})],
+ ])
+
+ mock_ws_client.statement_execution.execute_statement.side_effect = [
+ count_result, data_result
+ ]
+
+ data = manager.get_graph_data(mock_ws_client, "main.default.test", "wh-1")
+
+ assert len(data["nodes"]) == 2 # alice, bob
+ assert len(data["edges"]) == 1 # KNOWS only, EXISTS filtered out
+ assert data["truncated"] is False
+
+ def test_get_graph_data_empty_table(self, manager, mock_ws_client):
+ """Test loading from an empty table."""
+ count_result = self._make_sql_result(["count"], [["0"]])
+ data_result = self._make_sql_result([], [])
+ data_result.manifest.schema.columns = []
+ data_result.result.data_array = []
+
+ mock_ws_client.statement_execution.execute_statement.side_effect = [
+ count_result, data_result
+ ]
+
+ data = manager.get_graph_data(mock_ws_client, "main.default.test", "wh-1")
+
+ assert data["nodes"] == []
+ assert data["edges"] == []
+ assert data["truncated"] is False
+
+ def test_get_graph_data_truncated(self, mock_ws_client):
+ """Test that get_graph_data returns truncation info when data exceeds limit."""
+ settings = MagicMock()
+ settings.DATABRICKS_WAREHOUSE_ID = "test-warehouse"
+ settings.LLM_ENDPOINT = ""
+ settings.GRAPH_MAX_EDGES = 2 # Very low limit
+ settings.GRAPH_QUERY_TIMEOUT = "30s"
+ mgr = GraphExplorerManager(settings=settings)
+
+ columns = [
+ "node_start_id", "node_start_key", "relationship",
+ "node_end_id", "node_end_key",
+ "node_start_properties", "node_end_properties",
+ ]
+
+ # Total rows = 100, but limit is 2
+ count_result = self._make_sql_result(["count"], [["100"]])
+ data_result = self._make_sql_result(columns, [
+ ["a", "N", "R", "b", "N", None, None],
+ ["c", "N", "R", "d", "N", None, None],
+ ])
+
+ mock_ws_client.statement_execution.execute_statement.side_effect = [
+ count_result, data_result
+ ]
+
+ data = mgr.get_graph_data(mock_ws_client, "main.default.test", "wh-1")
+
+ assert data["truncated"] is True
+ assert data["totalAvailable"] == 100
+
+ # Verify LIMIT in the SQL
+ calls = mock_ws_client.statement_execution.execute_statement.call_args_list
+ select_sql = calls[1].kwargs["statement"]
+ assert "LIMIT 2" in select_sql
+
+ def test_get_graph_data_respects_max_rows_argument(self, manager, mock_ws_client):
+ """Test that explicit max_rows overrides settings."""
+ columns = [
+ "node_start_id", "node_start_key", "relationship",
+ "node_end_id", "node_end_key",
+ "node_start_properties", "node_end_properties",
+ ]
+
+ count_result = self._make_sql_result(["count"], [["5"]])
+ data_result = self._make_sql_result(columns, [
+ ["a", "N", "R", "b", "N", None, None],
+ ])
+
+ mock_ws_client.statement_execution.execute_statement.side_effect = [
+ count_result, data_result
+ ]
+
+ data = manager.get_graph_data(mock_ws_client, "main.default.test", "wh-1", max_rows=1)
+
+ # Verify LIMIT in the SQL
+ calls = mock_ws_client.statement_execution.execute_statement.call_args_list
+ select_sql = calls[1].kwargs["statement"]
+ assert "LIMIT 1" in select_sql
+
+ def test_get_graph_data_uses_timeout_from_settings(self, mock_ws_client):
+ """Test that _execute_sql uses GRAPH_QUERY_TIMEOUT from settings."""
+ settings = MagicMock()
+ settings.DATABRICKS_WAREHOUSE_ID = "test-warehouse"
+ settings.LLM_ENDPOINT = ""
+ settings.GRAPH_MAX_EDGES = 10000
+ settings.GRAPH_QUERY_TIMEOUT = "60s"
+ mgr = GraphExplorerManager(settings=settings)
+
+ count_result = self._make_sql_result(["count"], [["0"]])
+ data_result = self._make_sql_result([], [])
+ data_result.manifest.schema.columns = []
+ data_result.result.data_array = []
+
+ mock_ws_client.statement_execution.execute_statement.side_effect = [
+ count_result, data_result
+ ]
+
+ mgr.get_graph_data(mock_ws_client, "main.default.test", "wh-1")
+
+ # Verify timeout was passed
+ calls = mock_ws_client.statement_execution.execute_statement.call_args_list
+ for c in calls:
+ assert c.kwargs["wait_timeout"] == "60s"
+
+ # ---------------------------------------------------------------
+ # _parse_props tests
+ # ---------------------------------------------------------------
+
+ def test_parse_props_json_string(self, manager):
+ result = manager._parse_props('{"name": "Alice", "age": 30}')
+ assert result == {"name": "Alice", "age": 30}
+
+ def test_parse_props_dict(self, manager):
+ result = manager._parse_props({"name": "Alice"})
+ assert result == {"name": "Alice"}
+
+ def test_parse_props_none(self, manager):
+ assert manager._parse_props(None) == {}
+
+ def test_parse_props_invalid_json(self, manager):
+ assert manager._parse_props("not json") == {}
+
+ # ---------------------------------------------------------------
+ # LLM config tests
+ # ---------------------------------------------------------------
+
+ def test_get_llm_config_disabled(self, manager):
+ config = manager.get_llm_config()
+ assert config["enabled"] is False
+
+ def test_get_llm_config_enabled(self):
+ settings = MagicMock()
+ settings.LLM_ENDPOINT = "databricks-claude-sonnet"
+ mgr = GraphExplorerManager(settings=settings)
+
+ config = mgr.get_llm_config()
+ assert config["enabled"] is True
+ assert config["defaultModel"] == "databricks-claude-sonnet"
diff --git a/src/backend/src/tests/unit/test_metadata_manager.py b/src/backend/src/tests/unit/test_metadata_manager.py
index fcf0c055..eb7f4d6c 100644
--- a/src/backend/src/tests/unit/test_metadata_manager.py
+++ b/src/backend/src/tests/unit/test_metadata_manager.py
@@ -1,6 +1,8 @@
"""
Unit tests for MetadataManager
"""
+import uuid
+from datetime import datetime
import pytest
from unittest.mock import Mock, MagicMock
from databricks.sdk.service.catalog import VolumeType
@@ -40,31 +42,45 @@ def manager(self, mock_rich_text_repo, mock_link_repo, mock_document_repo):
def sample_rich_text_db(self):
"""Sample rich text database object."""
rt = Mock(spec=RichTextMetadataDb)
- rt.id = "rt-123"
+ rt.id = uuid.UUID("12345678-1234-1234-1234-123456789012")
rt.entity_type = "data_product"
rt.entity_id = "product-456"
rt.title = "Documentation"
rt.short_description = "Product docs"
rt.content_markdown = "# Test\nContent"
+ rt.is_shared = False
+ rt.level = 50
+ rt.inheritable = True
+ rt.created_by = "user@example.com"
+ rt.updated_by = "user@example.com"
+ rt.created_at = datetime(2024, 1, 1, 12, 0, 0)
+ rt.updated_at = datetime(2024, 1, 1, 12, 0, 0)
return rt
@pytest.fixture
def sample_link_db(self):
"""Sample link database object."""
link = Mock(spec=LinkMetadataDb)
- link.id = "link-123"
+ link.id = uuid.UUID("22345678-1234-1234-1234-123456789012")
link.entity_type = "data_contract"
link.entity_id = "contract-456"
link.title = "Related Link"
link.short_description = "External resource"
link.url = "https://example.com"
+ link.is_shared = False
+ link.level = 50
+ link.inheritable = True
+ link.created_by = "user@example.com"
+ link.updated_by = "user@example.com"
+ link.created_at = datetime(2024, 1, 1, 12, 0, 0)
+ link.updated_at = datetime(2024, 1, 1, 12, 0, 0)
return link
@pytest.fixture
def sample_document_db(self):
"""Sample document database object."""
doc = Mock(spec=DocumentMetadataDb)
- doc.id = "doc-123"
+ doc.id = uuid.UUID("32345678-1234-1234-1234-123456789012")
doc.entity_type = "data_product"
doc.entity_id = "product-789"
doc.title = "User Guide"
@@ -73,6 +89,13 @@ def sample_document_db(self):
doc.content_type = "application/pdf"
doc.size_bytes = 1024000
doc.storage_path = "/volumes/catalog/schema/volume/guide.pdf"
+ doc.is_shared = False
+ doc.level = 50
+ doc.inheritable = True
+ doc.created_by = "user@example.com"
+ doc.updated_by = "user@example.com"
+ doc.created_at = datetime(2024, 1, 1, 12, 0, 0)
+ doc.updated_at = datetime(2024, 1, 1, 12, 0, 0)
return doc
# Initialization Tests
@@ -105,7 +128,7 @@ def test_create_rich_text_success(self, manager, mock_rich_text_repo, sample_ric
result = manager.create_rich_text(mock_db, data=data, user_email="user@example.com")
- assert result.id == "rt-123"
+ assert result.id == uuid.UUID("12345678-1234-1234-1234-123456789012")
mock_rich_text_repo.create.assert_called_once()
assert mock_db.commit.called
@@ -126,14 +149,17 @@ def test_list_rich_texts_with_results(self, manager, mock_rich_text_repo, sample
result = manager.list_rich_texts(mock_db, entity_type="data_product", entity_id="product-456")
assert len(result) == 1
- assert result[0].id == "rt-123"
+ assert result[0].id == uuid.UUID("12345678-1234-1234-1234-123456789012")
def test_update_rich_text_success(self, manager, mock_rich_text_repo, sample_rich_text_db):
"""Test updating rich text metadata."""
updated_rt = Mock(spec=RichTextMetadataDb)
- for key, value in sample_rich_text_db.__dict__.items():
- if not key.startswith('_'):
- setattr(updated_rt, key, value)
+ for attr in (
+ "id", "entity_type", "entity_id", "title", "short_description", "content_markdown",
+ "is_shared", "level", "inheritable", "created_by", "updated_by",
+ "created_at", "updated_at",
+ ):
+ setattr(updated_rt, attr, getattr(sample_rich_text_db, attr))
updated_rt.title = "Updated Title"
mock_rich_text_repo.get.return_value = sample_rich_text_db
@@ -141,7 +167,9 @@ def test_update_rich_text_success(self, manager, mock_rich_text_repo, sample_ric
mock_db = Mock()
update_data = RichTextUpdate(title="Updated Title")
- result = manager.update_rich_text(mock_db, id="rt-123", data=update_data, user_email="user@example.com")
+ result = manager.update_rich_text(
+ mock_db, id="12345678-1234-1234-1234-123456789012", data=update_data, user_email="user@example.com"
+ )
assert result is not None
assert result.title == "Updated Title"
@@ -162,7 +190,9 @@ def test_delete_rich_text_success(self, manager, mock_rich_text_repo, sample_ric
mock_rich_text_repo.remove.return_value = sample_rich_text_db
mock_db = Mock()
- result = manager.delete_rich_text(mock_db, id="rt-123", user_email="user@example.com")
+ result = manager.delete_rich_text(
+ mock_db, id="12345678-1234-1234-1234-123456789012", user_email="user@example.com"
+ )
assert result is True
mock_rich_text_repo.remove.assert_called_once()
@@ -193,7 +223,7 @@ def test_create_link_success(self, manager, mock_link_repo, sample_link_db):
result = manager.create_link(mock_db, data=data, user_email="user@example.com")
- assert result.id == "link-123"
+ assert result.id == uuid.UUID("22345678-1234-1234-1234-123456789012")
mock_link_repo.create.assert_called_once()
def test_list_links_empty(self, manager, mock_link_repo):
@@ -218,9 +248,12 @@ def test_list_links_with_results(self, manager, mock_link_repo, sample_link_db):
def test_update_link_success(self, manager, mock_link_repo, sample_link_db):
"""Test updating link metadata."""
updated_link = Mock(spec=LinkMetadataDb)
- for key, value in sample_link_db.__dict__.items():
- if not key.startswith('_'):
- setattr(updated_link, key, value)
+ for attr in (
+ "id", "entity_type", "entity_id", "title", "short_description", "url",
+ "is_shared", "level", "inheritable", "created_by", "updated_by",
+ "created_at", "updated_at",
+ ):
+ setattr(updated_link, attr, getattr(sample_link_db, attr))
updated_link.url = "https://updated.com"
mock_link_repo.get.return_value = sample_link_db
@@ -228,7 +261,9 @@ def test_update_link_success(self, manager, mock_link_repo, sample_link_db):
mock_db = Mock()
update_data = LinkUpdate(url="https://updated.com")
- result = manager.update_link(mock_db, id="link-123", data=update_data, user_email="user@example.com")
+ result = manager.update_link(
+ mock_db, id="22345678-1234-1234-1234-123456789012", data=update_data, user_email="user@example.com"
+ )
assert result is not None
assert result.url == "https://updated.com"
@@ -249,7 +284,9 @@ def test_delete_link_success(self, manager, mock_link_repo, sample_link_db):
mock_link_repo.remove.return_value = sample_link_db
mock_db = Mock()
- result = manager.delete_link(mock_db, id="link-123", user_email="user@example.com")
+ result = manager.delete_link(
+ mock_db, id="22345678-1234-1234-1234-123456789012", user_email="user@example.com"
+ )
assert result is True
@@ -268,6 +305,17 @@ def test_create_document_record_success(self, manager):
"""Test creating document record."""
mock_db = Mock()
+ # Simulate db.refresh populating DB-generated fields (id, timestamps, defaults)
+ def _mock_refresh(db_obj):
+ setattr(db_obj, "id", uuid.UUID("32345678-1234-1234-1234-123456789012"))
+ setattr(db_obj, "created_at", datetime(2024, 1, 1, 12, 0, 0))
+ setattr(db_obj, "updated_at", datetime(2024, 1, 1, 12, 0, 0))
+ setattr(db_obj, "is_shared", False)
+ setattr(db_obj, "level", 50)
+ setattr(db_obj, "inheritable", True)
+
+ mock_db.refresh = Mock(side_effect=_mock_refresh)
+
data = DocumentCreate(
entity_type="data_product",
entity_id="product-789",
@@ -313,10 +361,10 @@ def test_get_document_success(self, manager, mock_document_repo, sample_document
mock_document_repo.get.return_value = sample_document_db
mock_db = Mock()
- result = manager.get_document(mock_db, id="doc-123")
+ result = manager.get_document(mock_db, id="32345678-1234-1234-1234-123456789012")
assert result is not None
- assert result.id == "doc-123"
+ assert result.id == uuid.UUID("32345678-1234-1234-1234-123456789012")
def test_get_document_not_found(self, manager, mock_document_repo):
"""Test getting non-existent document."""
@@ -333,7 +381,9 @@ def test_delete_document_success(self, manager, mock_document_repo, sample_docum
mock_document_repo.remove.return_value = sample_document_db
mock_db = Mock()
- result = manager.delete_document(mock_db, id="doc-123", user_email="user@example.com")
+ result = manager.delete_document(
+ mock_db, id="32345678-1234-1234-1234-123456789012", user_email="user@example.com"
+ )
assert result is True
diff --git a/src/backend/src/tests/unit/test_notifications_manager.py b/src/backend/src/tests/unit/test_notifications_manager.py
index d4b04a31..c476d51b 100644
--- a/src/backend/src/tests/unit/test_notifications_manager.py
+++ b/src/backend/src/tests/unit/test_notifications_manager.py
@@ -3,13 +3,40 @@
"""
import pytest
from datetime import datetime
-from unittest.mock import Mock, MagicMock, patch, mock_open
-from src.controller.notifications_manager import NotificationsManager, NotificationNotFoundError
+from types import SimpleNamespace
+from unittest.mock import Mock
+from src.controller.notifications_manager import NotificationsManager
from src.models.notifications import Notification, NotificationType
from src.models.users import UserInfo
from src.db_models.notifications import NotificationDb
+def _make_notif_db(**kwargs):
+ """Create a mock NotificationDb with all attributes needed for Notification.model_validate."""
+ defaults = {
+ "id": "notif-1",
+ "recipient": None,
+ "title": "Test",
+ "type": "info",
+ "created_at": datetime(2024, 1, 1),
+ "read": False,
+ "can_delete": True,
+ "subtitle": None,
+ "description": None,
+ "message": None,
+ "link": None,
+ "action_type": None,
+ "action_payload": None,
+ "data": None,
+ "target_roles": None,
+ "updated_at": None,
+ "recipient_role_id": None,
+ "recipient_role_name": None,
+ }
+ defaults.update(kwargs)
+ return SimpleNamespace(**defaults)
+
+
class TestNotificationsManager:
"""Test suite for NotificationsManager."""
@@ -26,9 +53,11 @@ def mock_repository(self):
return Mock()
@pytest.fixture
- def manager(self, mock_settings_manager):
+ def manager(self, mock_settings_manager, mock_repository):
"""Create NotificationsManager instance for testing."""
- return NotificationsManager(settings_manager=mock_settings_manager)
+ manager = NotificationsManager(settings_manager=mock_settings_manager)
+ manager._repo = mock_repository # Replace real repo with mock for unit testing
+ return manager
@pytest.fixture
def sample_notification(self):
@@ -80,21 +109,7 @@ def test_get_notifications_broadcast_only(self, manager):
mock_db = Mock()
# Create notification DB objects
- broadcast_notif = Mock(spec=NotificationDb)
- broadcast_notif.id = "notif-1"
- broadcast_notif.recipient = None
- broadcast_notif.title = "Broadcast"
- broadcast_notif.type = "info"
- broadcast_notif.created_at = datetime(2024, 1, 1)
- broadcast_notif.read = False
- broadcast_notif.can_delete = True
- broadcast_notif.subtitle = None
- broadcast_notif.description = None
- broadcast_notif.link = None
- broadcast_notif.action_type = None
- broadcast_notif.action_payload = None
- broadcast_notif.data = None
- broadcast_notif.target_roles = None
+ broadcast_notif = _make_notif_db(id="notif-1", recipient=None, title="Broadcast")
manager._repo.get_multi.return_value = [broadcast_notif]
@@ -108,37 +123,13 @@ def test_get_notifications_filtered_by_email(self, manager, sample_user_info):
mock_db = Mock()
# Create notifications
- user_notif = Mock(spec=NotificationDb)
- user_notif.id = "notif-1"
- user_notif.recipient = "user@example.com"
- user_notif.title = "User Notification"
- user_notif.type = "info"
- user_notif.created_at = datetime(2024, 1, 1)
- user_notif.read = False
- user_notif.can_delete = True
- user_notif.subtitle = None
- user_notif.description = None
- user_notif.link = None
- user_notif.action_type = None
- user_notif.action_payload = None
- user_notif.data = None
- user_notif.target_roles = None
-
- other_notif = Mock(spec=NotificationDb)
- other_notif.id = "notif-2"
- other_notif.recipient = "other@example.com"
- other_notif.title = "Other Notification"
- other_notif.type = "info"
- other_notif.created_at = datetime(2024, 1, 2)
- other_notif.read = False
- other_notif.can_delete = True
- other_notif.subtitle = None
- other_notif.description = None
- other_notif.link = None
- other_notif.action_type = None
- other_notif.action_payload = None
- other_notif.data = None
- other_notif.target_roles = None
+ user_notif = _make_notif_db(id="notif-1", recipient="user@example.com", title="User Notification")
+ other_notif = _make_notif_db(
+ id="notif-2",
+ recipient="other@example.com",
+ title="Other Notification",
+ created_at=datetime(2024, 1, 2),
+ )
manager._repo.get_multi.return_value = [user_notif, other_notif]
@@ -159,21 +150,9 @@ def test_get_notifications_filtered_by_role(self, manager, sample_user_info, moc
mock_settings_manager.list_app_roles.return_value = [mock_role]
# Create role-targeted notification
- role_notif = Mock(spec=NotificationDb)
- role_notif.id = "notif-1"
- role_notif.recipient = "Data Consumer"
- role_notif.title = "Role Notification"
- role_notif.type = "info"
- role_notif.created_at = datetime(2024, 1, 1)
- role_notif.read = False
- role_notif.can_delete = True
- role_notif.subtitle = None
- role_notif.description = None
- role_notif.link = None
- role_notif.action_type = None
- role_notif.action_payload = None
- role_notif.data = None
- role_notif.target_roles = None
+ role_notif = _make_notif_db(
+ id="notif-1", recipient="Data Consumer", title="Role Notification"
+ )
manager._repo.get_multi.return_value = [role_notif]
@@ -187,53 +166,9 @@ def test_get_notifications_sorts_by_created_at(self, manager):
mock_db = Mock()
# Create notifications with different timestamps
- notif1 = Mock(spec=NotificationDb)
- notif1.id = "notif-1"
- notif1.recipient = None
- notif1.title = "First"
- notif1.type = "info"
- notif1.created_at = datetime(2024, 1, 1)
- notif1.read = False
- notif1.can_delete = True
- notif1.subtitle = None
- notif1.description = None
- notif1.link = None
- notif1.action_type = None
- notif1.action_payload = None
- notif1.data = None
- notif1.target_roles = None
-
- notif2 = Mock(spec=NotificationDb)
- notif2.id = "notif-2"
- notif2.recipient = None
- notif2.title = "Second"
- notif2.type = "info"
- notif2.created_at = datetime(2024, 1, 3)
- notif2.read = False
- notif2.can_delete = True
- notif2.subtitle = None
- notif2.description = None
- notif2.link = None
- notif2.action_type = None
- notif2.action_payload = None
- notif2.data = None
- notif2.target_roles = None
-
- notif3 = Mock(spec=NotificationDb)
- notif3.id = "notif-3"
- notif3.recipient = None
- notif3.title = "Third"
- notif3.type = "info"
- notif3.created_at = datetime(2024, 1, 2)
- notif3.read = False
- notif3.can_delete = True
- notif3.subtitle = None
- notif3.description = None
- notif3.link = None
- notif3.action_type = None
- notif3.action_payload = None
- notif3.data = None
- notif3.target_roles = None
+ notif1 = _make_notif_db(id="notif-1", title="First", created_at=datetime(2024, 1, 1))
+ notif2 = _make_notif_db(id="notif-2", title="Second", created_at=datetime(2024, 1, 3))
+ notif3 = _make_notif_db(id="notif-3", title="Third", created_at=datetime(2024, 1, 2))
manager._repo.get_multi.return_value = [notif1, notif2, notif3]
@@ -250,15 +185,19 @@ def test_mark_as_read_success(self, manager):
"""Test marking notification as read."""
mock_db = Mock()
- notif_db = Mock(spec=NotificationDb)
- notif_db.id = "notif-123"
- notif_db.read = False
+ notif_db = _make_notif_db(id="notif-123", read=False, title="Test")
+
+ def mock_update(db, db_obj, obj_in):
+ for k, v in obj_in.items():
+ setattr(db_obj, k, v)
+ return db_obj
manager._repo.get.return_value = notif_db
+ manager._repo.update.side_effect = mock_update
- result = manager.mark_as_read(mock_db, notification_id="notif-123")
+ result = manager.mark_notification_read(db=mock_db, notification_id="notif-123")
- assert result is True
+ assert result is not None
assert notif_db.read is True
mock_db.commit.assert_called_once()
@@ -267,8 +206,9 @@ def test_mark_as_read_not_found(self, manager):
mock_db = Mock()
manager._repo.get.return_value = None
- with pytest.raises(NotificationNotFoundError):
- manager.mark_as_read(mock_db, notification_id="nonexistent")
+ result = manager.mark_notification_read(db=mock_db, notification_id="nonexistent")
+
+ assert result is None
# Mark All as Read Tests
@@ -276,39 +216,22 @@ def test_mark_all_as_read_success(self, manager, sample_user_info):
"""Test marking all notifications as read for user."""
mock_db = Mock()
- notif1 = Mock(spec=NotificationDb)
- notif1.id = "notif-1"
- notif1.recipient = "user@example.com"
- notif1.read = False
- notif1.title = "Test 1"
- notif1.type = "info"
- notif1.created_at = datetime(2024, 1, 1)
- notif1.can_delete = True
- notif1.subtitle = None
- notif1.description = None
- notif1.link = None
- notif1.action_type = None
- notif1.action_payload = None
- notif1.data = None
- notif1.target_roles = None
-
- notif2 = Mock(spec=NotificationDb)
- notif2.id = "notif-2"
- notif2.recipient = "user@example.com"
- notif2.read = False
- notif2.title = "Test 2"
- notif2.type = "info"
- notif2.created_at = datetime(2024, 1, 2)
- notif2.can_delete = True
- notif2.subtitle = None
- notif2.description = None
- notif2.link = None
- notif2.action_type = None
- notif2.action_payload = None
- notif2.data = None
- notif2.target_roles = None
+ notif1 = _make_notif_db(id="notif-1", recipient="user@example.com", read=False, title="Test 1")
+ notif2 = _make_notif_db(
+ id="notif-2",
+ recipient="user@example.com",
+ read=False,
+ title="Test 2",
+ created_at=datetime(2024, 1, 2),
+ )
+
+ def mock_update(db, db_obj, obj_in):
+ for k, v in obj_in.items():
+ setattr(db_obj, k, v)
+ return db_obj
manager._repo.get_multi.return_value = [notif1, notif2]
+ manager._repo.update.side_effect = mock_update
result = manager.mark_all_as_read(mock_db, user_info=sample_user_info)
@@ -322,9 +245,7 @@ def test_delete_notification_success(self, manager):
"""Test deleting a notification."""
mock_db = Mock()
- notif_db = Mock(spec=NotificationDb)
- notif_db.id = "notif-123"
- notif_db.can_delete = True
+ notif_db = _make_notif_db(id="notif-123", can_delete=True)
manager._repo.get.return_value = notif_db
manager._repo.remove.return_value = notif_db
@@ -339,16 +260,15 @@ def test_delete_notification_not_found(self, manager):
mock_db = Mock()
manager._repo.get.return_value = None
- with pytest.raises(NotificationNotFoundError):
- manager.delete_notification(mock_db, notification_id="nonexistent")
+ result = manager.delete_notification(mock_db, notification_id="nonexistent")
+
+ assert result is False
def test_delete_notification_cannot_delete(self, manager):
"""Test trying to delete a notification marked as non-deletable."""
mock_db = Mock()
- notif_db = Mock(spec=NotificationDb)
- notif_db.id = "notif-123"
- notif_db.can_delete = False
+ notif_db = _make_notif_db(id="notif-123", can_delete=False)
manager._repo.get.return_value = notif_db
diff --git a/src/backend/src/tests/unit/test_search_manager.py b/src/backend/src/tests/unit/test_search_manager.py
index 756f84f0..8c17bf95 100644
--- a/src/backend/src/tests/unit/test_search_manager.py
+++ b/src/backend/src/tests/unit/test_search_manager.py
@@ -64,8 +64,8 @@ def sample_user(self):
return UserInfo(
username="testuser",
email="test@example.com",
- display_name="Test User",
- active=True,
+ user="Test User",
+ ip="127.0.0.1",
groups=["users", "data_consumers"],
)
@@ -279,8 +279,8 @@ def test_search_user_without_groups(self, mock_searchable_manager, mock_auth_man
user_no_groups = UserInfo(
username="nogroups",
email="nogroups@example.com",
- display_name="No Groups User",
- active=True,
+ user="No Groups User",
+ ip="127.0.0.1",
groups=[],
)
diff --git a/src/backend/src/tests/unit/test_security_features_manager.py b/src/backend/src/tests/unit/test_security_features_manager.py
index 5f589cfa..1c679c91 100644
--- a/src/backend/src/tests/unit/test_security_features_manager.py
+++ b/src/backend/src/tests/unit/test_security_features_manager.py
@@ -188,7 +188,7 @@ def test_load_from_yaml_empty_file(self, mock_exists, mock_file, manager):
assert result is True
assert len(manager.features) == 0
- @patch('builtins.open', new_callable=mock_open, read_data="features:")
+ @patch('builtins.open', new_callable=mock_open, read_data="features: []")
@patch('pathlib.Path.exists')
def test_load_from_yaml_empty_features_list(self, mock_exists, mock_file, manager):
"""Test loading from YAML with empty features list."""
diff --git a/src/backend/src/tests/unit/test_settings_manager.py b/src/backend/src/tests/unit/test_settings_manager.py
index 0f4815bc..0faad61b 100644
--- a/src/backend/src/tests/unit/test_settings_manager.py
+++ b/src/backend/src/tests/unit/test_settings_manager.py
@@ -23,9 +23,29 @@ class TestSettingsManager:
@pytest.fixture
def mock_settings(self):
- """Create mock settings."""
+ """Create mock settings with attributes required by get_settings()."""
mock = MagicMock(spec=Settings)
mock.job_cluster_id = "test-cluster"
+ mock.WORKSPACE_DEPLOYMENT_PATH = None
+ mock.DATABRICKS_CATALOG = None
+ mock.DATABRICKS_SCHEMA = None
+ mock.DATABRICKS_VOLUME = None
+ mock.APP_AUDIT_LOG_DIR = None
+ mock.LLM_ENABLED = False
+ mock.LLM_ENDPOINT = None
+ mock.LLM_SYSTEM_PROMPT = None
+ mock.LLM_DISCLAIMER_TEXT = None
+ mock.DELIVERY_MODE_DIRECT = False
+ mock.DELIVERY_MODE_INDIRECT = False
+ mock.DELIVERY_MODE_MANUAL = True
+ mock.DELIVERY_DIRECT_DRY_RUN = False
+ mock.GIT_REPO_URL = None
+ mock.GIT_BRANCH = None
+ mock.GIT_USERNAME = None
+ mock.UI_I18N_ENABLED = True
+ mock.UI_CUSTOM_LOGO_URL = None
+ mock.UI_ABOUT_CONTENT = None
+ mock.UI_CUSTOM_CSS = None
mock.to_dict.return_value = {"job_cluster_id": "test-cluster"}
return mock
diff --git a/src/backend/src/utils/startup_tasks.py b/src/backend/src/utils/startup_tasks.py
index a39a787b..563fb512 100644
--- a/src/backend/src/utils/startup_tasks.py
+++ b/src/backend/src/utils/startup_tasks.py
@@ -294,6 +294,11 @@ def initialize_managers(app: FastAPI):
# SEARCHABLE_ASSET_MANAGERS.append(metadata_manager) # If it's searchable
# logger.info("MetadataManager initialized.")
+ # Graph Explorer Manager (no DB needed - uses Databricks SQL)
+ from src.controller.graph_explorer_manager import GraphExplorerManager
+ app.state.graph_explorer_manager = GraphExplorerManager(settings=settings)
+ logger.info("GraphExplorerManager initialized.")
+
logger.info("All managers instantiated and stored in app.state.")
# Defer SearchManager initialization until after initial data loading completes
diff --git a/src/frontend/package.json b/src/frontend/package.json
index 870e64f4..3cba73db 100644
--- a/src/frontend/package.json
+++ b/src/frontend/package.json
@@ -35,6 +35,7 @@
"@radix-ui/react-toast": "^1.2.11",
"@radix-ui/react-tooltip": "^1.1.8",
"@tanstack/react-table": "^8.21.2",
+ "@types/d3-force": "^3.0.10",
"@types/node": "^20.11.16",
"@types/react": "^18.2.55",
"@types/react-dom": "^18.2.19",
@@ -47,6 +48,7 @@
"clsx": "^2.1.1",
"cmdk": "^1.1.1",
"cytoscape": "^3.33.1",
+ "d3-force": "^3.0.0",
"date-fns": "^4.1.0",
"framer-motion": "^12.6.3",
"i18next": "^25.5.3",
@@ -56,6 +58,7 @@
"react-cytoscapejs": "^2.0.0",
"react-dom": "^18.2.0",
"react-dropzone": "^14.3.8",
+ "react-force-graph-2d": "^1.29.1",
"react-hook-form": "^7.56.1",
"react-i18next": "^16.0.0",
"react-markdown": "^10.1.0",
diff --git a/src/frontend/postcss.config.js b/src/frontend/postcss.config.cjs
similarity index 100%
rename from src/frontend/postcss.config.js
rename to src/frontend/postcss.config.cjs
diff --git a/src/frontend/src/app.tsx b/src/frontend/src/app.tsx
index 2f86da0c..5ac34a25 100644
--- a/src/frontend/src/app.tsx
+++ b/src/frontend/src/app.tsx
@@ -47,6 +47,7 @@ import ProjectsView from './views/projects';
import AuditTrail from './views/audit-trail';
import WorkflowDesignerView from './views/workflow-designer';
import Workflows from './views/workflows';
+import GraphExplorer from './views/graph-explorer';
export default function App() {
const fetchUserInfo = useUserStore((state: any) => state.fetchUserInfo);
@@ -118,6 +119,7 @@ export default function App() {
} />
} />
} />
+ } />
} />
} />
diff --git a/src/frontend/src/components/graph-explorer/diagram-manager.test.tsx b/src/frontend/src/components/graph-explorer/diagram-manager.test.tsx
new file mode 100644
index 00000000..b4bdd74e
--- /dev/null
+++ b/src/frontend/src/components/graph-explorer/diagram-manager.test.tsx
@@ -0,0 +1,186 @@
+import { describe, it, expect, vi, beforeEach } from 'vitest';
+import { render, screen, within } from '@testing-library/react';
+import userEvent from '@testing-library/user-event';
+import { DiagramManager, getDiagrams, saveDiagrams, type SavedDiagram } from './diagram-manager';
+import type { GraphData } from '@/types/graph-explorer';
+
+// Mock react-i18next
+vi.mock('react-i18next', () => ({
+ useTranslation: () => ({
+ t: (key: string, opts?: Record) => {
+ // Minimal interpolation for description text
+ if (opts && key === 'diagrams.saveDescription') {
+ return `Save current view (${opts.nodeCount} nodes, ${opts.edgeCount} edges)`;
+ }
+ return key;
+ },
+ i18n: { language: 'en' },
+ }),
+}));
+
+const sampleData: GraphData = {
+ nodes: [
+ { id: 'n1', label: 'Alice', type: 'Person', properties: {}, status: 'existing' as const },
+ { id: 'n2', label: 'Bob', type: 'Person', properties: {}, status: 'existing' as const },
+ ],
+ edges: [
+ { id: 'e1', source: 'n1', target: 'n2', relationshipType: 'KNOWS', properties: {}, status: 'existing' as const },
+ ],
+};
+
+const emptyData: GraphData = { nodes: [], edges: [] };
+
+describe('diagram localStorage helpers', () => {
+ beforeEach(() => {
+ localStorage.clear();
+ });
+
+ it('getDiagrams returns empty array for unknown table', () => {
+ expect(getDiagrams('nonexistent')).toEqual([]);
+ });
+
+ it('saveDiagrams + getDiagrams round-trips data', () => {
+ const diagrams: SavedDiagram[] = [
+ {
+ id: 'diag-1',
+ name: 'Test',
+ savedAt: '2026-01-01T00:00:00Z',
+ nodeCount: 2,
+ edgeCount: 1,
+ data: sampleData,
+ },
+ ];
+ saveDiagrams('test-table', diagrams);
+ const loaded = getDiagrams('test-table');
+ expect(loaded).toEqual(diagrams);
+ });
+
+ it('getDiagrams handles corrupt data gracefully', () => {
+ localStorage.setItem('graph-explorer-diagrams:bad', '{invalid json');
+ expect(getDiagrams('bad')).toEqual([]);
+ });
+});
+
+describe('DiagramManager component', () => {
+ beforeEach(() => {
+ localStorage.clear();
+ });
+
+ it('shows empty state when no diagrams saved', () => {
+ render(
+ ,
+ );
+ expect(screen.getByText('diagrams.empty')).toBeInTheDocument();
+ });
+
+ it('disables save button when data is empty', () => {
+ render(
+ ,
+ );
+ const saveBtn = screen.getByText('diagrams.save').closest('button');
+ expect(saveBtn).toBeDisabled();
+ });
+
+ it('opens save dialog and saves a diagram', async () => {
+ const user = userEvent.setup();
+ render(
+ ,
+ );
+
+ // Click save button to open dialog
+ await user.click(screen.getByText('diagrams.save'));
+ // Type a name
+ const input = screen.getByPlaceholderText('diagrams.namePlaceholder');
+ await user.type(input, 'My Diagram');
+ // Click save in dialog
+ const dialogSaveButtons = screen.getAllByText('diagrams.save');
+ // The second one is inside the dialog
+ await user.click(dialogSaveButtons[dialogSaveButtons.length - 1]);
+
+ // Diagram should now appear in the list
+ expect(screen.getByText('My Diagram')).toBeInTheDocument();
+ // Should be saved in localStorage
+ const saved = getDiagrams('test');
+ expect(saved).toHaveLength(1);
+ expect(saved[0].name).toBe('My Diagram');
+ });
+
+ it('restores a diagram when restore button is clicked', async () => {
+ const user = userEvent.setup();
+ const onRestore = vi.fn();
+
+ // Pre-save a diagram
+ saveDiagrams('test', [
+ {
+ id: 'diag-1',
+ name: 'Saved View',
+ savedAt: '2026-01-01T00:00:00Z',
+ nodeCount: 2,
+ edgeCount: 1,
+ data: sampleData,
+ },
+ ]);
+
+ render(
+ ,
+ );
+
+ // Find the diagram entry and hover to reveal buttons
+ const entry = screen.getByText('Saved View').closest('div[class*="group"]')!;
+ // Click the restore button (FolderOpen icon button)
+ const restoreBtn = within(entry as HTMLElement).getAllByRole('button')[0];
+ await user.click(restoreBtn);
+
+ expect(onRestore).toHaveBeenCalledWith(sampleData);
+ });
+
+ it('deletes a diagram when delete button is clicked', async () => {
+ const user = userEvent.setup();
+
+ saveDiagrams('test', [
+ {
+ id: 'diag-1',
+ name: 'To Delete',
+ savedAt: '2026-01-01T00:00:00Z',
+ nodeCount: 2,
+ edgeCount: 1,
+ data: sampleData,
+ },
+ ]);
+
+ render(
+ ,
+ );
+
+ expect(screen.getByText('To Delete')).toBeInTheDocument();
+
+ // Click delete button
+ const entry = screen.getByText('To Delete').closest('div[class*="group"]')!;
+ const deleteBtn = within(entry as HTMLElement).getAllByRole('button')[1];
+ await user.click(deleteBtn);
+
+ // Should be removed from view and localStorage
+ expect(screen.queryByText('To Delete')).not.toBeInTheDocument();
+ expect(getDiagrams('test')).toHaveLength(0);
+ });
+});
diff --git a/src/frontend/src/components/graph-explorer/diagram-manager.tsx b/src/frontend/src/components/graph-explorer/diagram-manager.tsx
new file mode 100644
index 00000000..81d26b0b
--- /dev/null
+++ b/src/frontend/src/components/graph-explorer/diagram-manager.tsx
@@ -0,0 +1,226 @@
+/**
+ * Diagram Manager for Graph Explorer.
+ *
+ * Allows users to save, name, and restore curated subgraph snapshots.
+ * Diagrams are persisted in localStorage, keyed by table name so each
+ * dataset has its own diagram collection.
+ */
+
+import React, { useState, useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
+import { Button } from '@/components/ui/button';
+import { Input } from '@/components/ui/input';
+import {
+ Dialog,
+ DialogContent,
+ DialogDescription,
+ DialogFooter,
+ DialogHeader,
+ DialogTitle,
+ DialogTrigger,
+} from '@/components/ui/dialog';
+import { Save, FolderOpen, Trash2, Plus } from 'lucide-react';
+import { cn } from '@/lib/utils';
+import type { GraphData } from '@/types/graph-explorer';
+
+// ---------------------------------------------------------------------------
+// Types
+// ---------------------------------------------------------------------------
+
+export interface SavedDiagram {
+ id: string;
+ name: string;
+ /** ISO timestamp */
+ savedAt: string;
+ nodeCount: number;
+ edgeCount: number;
+ data: GraphData;
+}
+
+// ---------------------------------------------------------------------------
+// localStorage helpers
+// ---------------------------------------------------------------------------
+
+const STORAGE_PREFIX = 'graph-explorer-diagrams:';
+
+export function getDiagrams(tableName: string): SavedDiagram[] {
+ try {
+ const raw = localStorage.getItem(`${STORAGE_PREFIX}${tableName}`);
+ if (!raw) return [];
+ return JSON.parse(raw) as SavedDiagram[];
+ } catch {
+ return [];
+ }
+}
+
+export function saveDiagrams(tableName: string, diagrams: SavedDiagram[]): void {
+ localStorage.setItem(`${STORAGE_PREFIX}${tableName}`, JSON.stringify(diagrams));
+}
+
+// ---------------------------------------------------------------------------
+// Component
+// ---------------------------------------------------------------------------
+
+export interface DiagramManagerProps {
+ tableName: string;
+ currentData: GraphData;
+ onRestoreDiagram: (data: GraphData) => void;
+ disabled?: boolean;
+ className?: string;
+}
+
+export function DiagramManager({
+ tableName,
+ currentData,
+ onRestoreDiagram,
+ disabled = false,
+ className,
+}: DiagramManagerProps) {
+ const { t } = useTranslation('graph-explorer');
+
+ const [diagrams, setDiagrams] = useState(() => getDiagrams(tableName));
+ const [saveDialogOpen, setSaveDialogOpen] = useState(false);
+ const [diagramName, setDiagramName] = useState('');
+
+ // Reload diagrams when tableName changes
+ React.useEffect(() => {
+ setDiagrams(getDiagrams(tableName));
+ }, [tableName]);
+
+ const handleSave = useCallback(() => {
+ const name = diagramName.trim() || `Diagram ${diagrams.length + 1}`;
+ const newDiagram: SavedDiagram = {
+ id: `diag-${Date.now()}`,
+ name,
+ savedAt: new Date().toISOString(),
+ nodeCount: currentData.nodes.length,
+ edgeCount: currentData.edges.length,
+ data: currentData,
+ };
+ const updated = [newDiagram, ...diagrams];
+ setDiagrams(updated);
+ saveDiagrams(tableName, updated);
+ setDiagramName('');
+ setSaveDialogOpen(false);
+ }, [diagramName, diagrams, currentData, tableName]);
+
+ const handleDelete = useCallback(
+ (diagramId: string) => {
+ const updated = diagrams.filter((d) => d.id !== diagramId);
+ setDiagrams(updated);
+ saveDiagrams(tableName, updated);
+ },
+ [diagrams, tableName],
+ );
+
+ const handleRestore = useCallback(
+ (diagram: SavedDiagram) => {
+ onRestoreDiagram(diagram.data);
+ },
+ [onRestoreDiagram],
+ );
+
+ const formatDate = (iso: string) => {
+ try {
+ return new Date(iso).toLocaleString(undefined, {
+ month: 'short',
+ day: 'numeric',
+ hour: '2-digit',
+ minute: '2-digit',
+ });
+ } catch {
+ return iso;
+ }
+ };
+
+ return (
+
+
+
+
{t('diagrams.title')}
+
+
+
+
+
+ {diagrams.length === 0 ? (
+ {t('diagrams.empty')}
+ ) : (
+ diagrams.map((d) => (
+
+
+
{d.name}
+
+ {d.nodeCount}n / {d.edgeCount}e · {formatDate(d.savedAt)}
+
+
+
+
+
+ ))
+ )}
+
+
+ );
+}
diff --git a/src/frontend/src/components/graph-explorer/graph-context-menu.test.tsx b/src/frontend/src/components/graph-explorer/graph-context-menu.test.tsx
new file mode 100644
index 00000000..ca2542f6
--- /dev/null
+++ b/src/frontend/src/components/graph-explorer/graph-context-menu.test.tsx
@@ -0,0 +1,273 @@
+import { describe, it, expect, vi, beforeEach } from 'vitest';
+import { render, screen, fireEvent } from '@testing-library/react';
+import userEvent from '@testing-library/user-event';
+import {
+ GraphContextMenu,
+ type GraphContextMenuProps,
+ type ContextMenuTarget,
+ type ContextMenuPosition,
+} from './graph-context-menu';
+
+// Mock react-i18next — return the key as the display text
+vi.mock('react-i18next', () => ({
+ useTranslation: () => ({
+ t: (key: string) => key,
+ i18n: { language: 'en' },
+ }),
+}));
+
+// Mock navigator.clipboard
+const mockWriteText = vi.fn().mockResolvedValue(undefined);
+
+describe('GraphContextMenu', () => {
+ const defaultPosition: ContextMenuPosition = { x: 100, y: 200 };
+
+ const defaultCallbacks = {
+ onClose: vi.fn(),
+ onExpandNeighbors: vi.fn(),
+ onExpandByType: vi.fn(),
+ onCollapseNode: vi.fn(),
+ onEditNode: vi.fn(),
+ onDeleteNode: vi.fn(),
+ onCenterOnNode: vi.fn(),
+ onEditEdge: vi.fn(),
+ onDeleteEdge: vi.fn(),
+ onCreateNode: vi.fn(),
+ onResetView: vi.fn(),
+ onFitToScreen: vi.fn(),
+ };
+
+ beforeEach(() => {
+ vi.clearAllMocks();
+ mockWriteText.mockResolvedValue(undefined);
+ // Ensure navigator.clipboard.writeText is our mock
+ if (!navigator.clipboard || navigator.clipboard.writeText !== mockWriteText) {
+ Object.defineProperty(global.navigator, 'clipboard', {
+ value: { writeText: mockWriteText },
+ writable: true,
+ configurable: true,
+ });
+ }
+ });
+
+ function renderMenu(overrides: Partial = {}) {
+ return render(
+ ,
+ );
+ }
+
+ // -------------------------------------------------------------------
+ // Visibility
+ // -------------------------------------------------------------------
+ it('renders nothing when position is null', () => {
+ const { container } = renderMenu({ position: null });
+ expect(container.innerHTML).toBe('');
+ });
+
+ it('renders nothing when target is null', () => {
+ const { container } = renderMenu({ target: null });
+ expect(container.innerHTML).toBe('');
+ });
+
+ // -------------------------------------------------------------------
+ // Node context menu
+ // -------------------------------------------------------------------
+ describe('node context menu', () => {
+ const nodeTarget: ContextMenuTarget = {
+ type: 'node',
+ id: 'node-1',
+ label: 'Alice',
+ nodeType: 'Person',
+ isExpanded: false,
+ connectedEdgeTypes: ['KNOWS', 'WORKS_AT'],
+ };
+
+ it('renders node menu items', () => {
+ renderMenu({ target: nodeTarget });
+ // Header
+ expect(screen.getByText(/Alice/)).toBeInTheDocument();
+ // Expand all neighbors
+ expect(screen.getByText('contextMenu.expandAll')).toBeInTheDocument();
+ // Edit, Delete, Center, Copy
+ expect(screen.getByText('contextMenu.editNode')).toBeInTheDocument();
+ expect(screen.getByText('contextMenu.deleteNode')).toBeInTheDocument();
+ expect(screen.getByText('contextMenu.centerOnNode')).toBeInTheDocument();
+ expect(screen.getByText('contextMenu.copyNodeId')).toBeInTheDocument();
+ });
+
+ it('calls onExpandNeighbors with "both" when expand all is clicked', async () => {
+ const user = userEvent.setup();
+ renderMenu({ target: nodeTarget });
+ await user.click(screen.getByText('contextMenu.expandAll'));
+ expect(defaultCallbacks.onExpandNeighbors).toHaveBeenCalledWith('node-1', 'both');
+ expect(defaultCallbacks.onClose).toHaveBeenCalled();
+ });
+
+ it('calls onEditNode when edit is clicked', async () => {
+ const user = userEvent.setup();
+ renderMenu({ target: nodeTarget });
+ await user.click(screen.getByText('contextMenu.editNode'));
+ expect(defaultCallbacks.onEditNode).toHaveBeenCalledWith('node-1');
+ expect(defaultCallbacks.onClose).toHaveBeenCalled();
+ });
+
+ it('calls onDeleteNode when delete is clicked', async () => {
+ const user = userEvent.setup();
+ renderMenu({ target: nodeTarget });
+ await user.click(screen.getByText('contextMenu.deleteNode'));
+ expect(defaultCallbacks.onDeleteNode).toHaveBeenCalledWith('node-1');
+ expect(defaultCallbacks.onClose).toHaveBeenCalled();
+ });
+
+ it('calls onCenterOnNode when center is clicked', async () => {
+ const user = userEvent.setup();
+ renderMenu({ target: nodeTarget });
+ await user.click(screen.getByText('contextMenu.centerOnNode'));
+ expect(defaultCallbacks.onCenterOnNode).toHaveBeenCalledWith('node-1');
+ expect(defaultCallbacks.onClose).toHaveBeenCalled();
+ });
+
+ it('copies node ID to clipboard when copy is clicked', async () => {
+ renderMenu({ target: nodeTarget });
+ fireEvent.click(screen.getByText('contextMenu.copyNodeId'));
+ // navigator.clipboard.writeText returns a Promise, so wait a tick
+ await vi.waitFor(() => {
+ expect(mockWriteText).toHaveBeenCalledWith('node-1');
+ });
+ expect(defaultCallbacks.onClose).toHaveBeenCalled();
+ });
+
+ it('shows collapse instead of expand sub-menu when node is expanded', () => {
+ const expandedTarget: ContextMenuTarget = { ...nodeTarget, isExpanded: true };
+ renderMenu({ target: expandedTarget });
+ expect(screen.getByText('contextMenu.collapse')).toBeInTheDocument();
+ // Expand all should still be present (re-expand)
+ expect(screen.getByText('contextMenu.expandAll')).toBeInTheDocument();
+ });
+
+ it('calls onCollapseNode when collapse is clicked', async () => {
+ const user = userEvent.setup();
+ const expandedTarget: ContextMenuTarget = { ...nodeTarget, isExpanded: true };
+ renderMenu({ target: expandedTarget });
+ await user.click(screen.getByText('contextMenu.collapse'));
+ expect(defaultCallbacks.onCollapseNode).toHaveBeenCalledWith('node-1');
+ expect(defaultCallbacks.onClose).toHaveBeenCalled();
+ });
+ });
+
+ // -------------------------------------------------------------------
+ // Edge context menu
+ // -------------------------------------------------------------------
+ describe('edge context menu', () => {
+ const edgeTarget: ContextMenuTarget = {
+ type: 'edge',
+ id: 'edge-1',
+ relationshipType: 'KNOWS',
+ };
+
+ it('renders edge menu items', () => {
+ renderMenu({ target: edgeTarget });
+ expect(screen.getByText('KNOWS')).toBeInTheDocument();
+ expect(screen.getByText('contextMenu.editEdge')).toBeInTheDocument();
+ expect(screen.getByText('contextMenu.deleteEdge')).toBeInTheDocument();
+ expect(screen.getByText('contextMenu.copyEdgeDetails')).toBeInTheDocument();
+ });
+
+ it('calls onEditEdge when edit is clicked', async () => {
+ const user = userEvent.setup();
+ renderMenu({ target: edgeTarget });
+ await user.click(screen.getByText('contextMenu.editEdge'));
+ expect(defaultCallbacks.onEditEdge).toHaveBeenCalledWith('edge-1');
+ expect(defaultCallbacks.onClose).toHaveBeenCalled();
+ });
+
+ it('calls onDeleteEdge when delete is clicked', async () => {
+ const user = userEvent.setup();
+ renderMenu({ target: edgeTarget });
+ await user.click(screen.getByText('contextMenu.deleteEdge'));
+ expect(defaultCallbacks.onDeleteEdge).toHaveBeenCalledWith('edge-1');
+ expect(defaultCallbacks.onClose).toHaveBeenCalled();
+ });
+ });
+
+ // -------------------------------------------------------------------
+ // Canvas context menu
+ // -------------------------------------------------------------------
+ describe('canvas context menu', () => {
+ const canvasTarget: ContextMenuTarget = { type: 'canvas' };
+
+ it('renders canvas menu items', () => {
+ renderMenu({ target: canvasTarget });
+ expect(screen.getByText('contextMenu.createNode')).toBeInTheDocument();
+ expect(screen.getByText('contextMenu.fitToScreen')).toBeInTheDocument();
+ expect(screen.getByText('contextMenu.resetView')).toBeInTheDocument();
+ });
+
+ it('calls onCreateNode when create node is clicked', async () => {
+ const user = userEvent.setup();
+ renderMenu({ target: canvasTarget });
+ await user.click(screen.getByText('contextMenu.createNode'));
+ expect(defaultCallbacks.onCreateNode).toHaveBeenCalled();
+ expect(defaultCallbacks.onClose).toHaveBeenCalled();
+ });
+
+ it('calls onFitToScreen when fit to screen is clicked', async () => {
+ const user = userEvent.setup();
+ renderMenu({ target: canvasTarget });
+ await user.click(screen.getByText('contextMenu.fitToScreen'));
+ expect(defaultCallbacks.onFitToScreen).toHaveBeenCalled();
+ expect(defaultCallbacks.onClose).toHaveBeenCalled();
+ });
+
+ it('calls onResetView when reset view is clicked', async () => {
+ const user = userEvent.setup();
+ renderMenu({ target: canvasTarget });
+ await user.click(screen.getByText('contextMenu.resetView'));
+ expect(defaultCallbacks.onResetView).toHaveBeenCalled();
+ expect(defaultCallbacks.onClose).toHaveBeenCalled();
+ });
+ });
+
+ // -------------------------------------------------------------------
+ // Dismiss behavior
+ // -------------------------------------------------------------------
+ describe('dismiss behavior', () => {
+ it('calls onClose when Escape is pressed', async () => {
+ renderMenu({ target: { type: 'canvas' } });
+ // Wait for the event listener to be attached (setTimeout(0) in the component)
+ await vi.waitFor(() => {
+ fireEvent.keyDown(document, { key: 'Escape' });
+ expect(defaultCallbacks.onClose).toHaveBeenCalled();
+ });
+ });
+
+ it('calls onClose when clicking outside the menu', async () => {
+ renderMenu({ target: { type: 'canvas' } });
+ // Wait for the event listener to be attached
+ await vi.waitFor(() => {
+ fireEvent.mouseDown(document.body);
+ expect(defaultCallbacks.onClose).toHaveBeenCalled();
+ });
+ });
+ });
+
+ // -------------------------------------------------------------------
+ // Position adjustment
+ // -------------------------------------------------------------------
+ describe('position adjustment', () => {
+ it('renders at the provided position', () => {
+ const { container } = renderMenu({
+ target: { type: 'canvas' },
+ position: { x: 150, y: 250 },
+ });
+ const menuEl = container.firstElementChild as HTMLElement;
+ expect(menuEl.style.left).toBe('150px');
+ expect(menuEl.style.top).toBe('250px');
+ });
+ });
+});
diff --git a/src/frontend/src/components/graph-explorer/graph-context-menu.tsx b/src/frontend/src/components/graph-explorer/graph-context-menu.tsx
new file mode 100644
index 00000000..2b5c354f
--- /dev/null
+++ b/src/frontend/src/components/graph-explorer/graph-context-menu.tsx
@@ -0,0 +1,409 @@
+/**
+ * Context menu for Graph Explorer.
+ *
+ * Renders a positioned menu overlay triggered by right-click on nodes, edges,
+ * or the canvas background. Since the graph is rendered on a