diff --git a/README.md b/README.md index 1d88f37b..6d1e0e9c 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ yarn dev:frontend **Terminal 2 - Backend:** ```bash cd src -yarn dev:backend +hatch -e dev run dev-backend ``` - Frontend: http://localhost:3000 @@ -171,5 +171,4 @@ This project is licensed under the Databricks License - see the [LICENSE.txt](LI --- -**Version**: 0.4.0 **Maintained by**: [Databricks](https://databricks.com) diff --git a/src/backend/.env.example b/src/backend/.env.example index 4ba34745..89234f11 100644 --- a/src/backend/.env.example +++ b/src/backend/.env.example @@ -98,6 +98,15 @@ LLM_ENABLED=False # Security: First-phase injection detection prompt (advanced - usually not changed) # LLM_INJECTION_CHECK_PROMPT="You are a security analyzer. Analyze the following content..." +# --- Graph Explorer Safety Limits --- +# Maximum nodes/edges returned from initial graph load (prevents browser OOM) +# GRAPH_MAX_NODES=5000 +# GRAPH_MAX_EDGES=10000 +# Default max neighbors per expansion +# GRAPH_NEIGHBOR_LIMIT=50 +# SQL statement timeout for graph queries +# GRAPH_QUERY_TIMEOUT=30s + # --- Self-Service Sandbox Policy Settings --- # These settings control which catalogs and schemas users can create objects in # via the self-service dialog. 
This is a global security boundary separate from diff --git a/src/backend/src/app.py b/src/backend/src/app.py index 5552e3b4..51301771 100644 --- a/src/backend/src/app.py +++ b/src/backend/src/app.py @@ -1,12 +1,12 @@ # Initialize configuration and logging first from src.common.config import get_settings, init_config -from src.common.logging import setup_logging, get_logger +from src.common.logging import get_logger, setup_logging + init_config() settings = get_settings() setup_logging(level=settings.LOG_LEVEL, log_file=settings.LOG_FILE) logger = get_logger(__name__) -import mimetypes import os import time from pathlib import Path @@ -14,21 +14,22 @@ # Server startup timestamp for cache invalidation SERVER_STARTUP_TIME = int(time.time()) -from fastapi import Depends, FastAPI, Request +from fastapi import Depends, FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse, HTMLResponse from fastapi.staticfiles import StaticFiles -from starlette.responses import Response -from fastapi import HTTPException, status from src.common.middleware import ErrorHandlingMiddleware, LoggingMiddleware from src.routes import ( access_grants_routes, + audit_routes, catalog_commander_routes, - data_catalog_routes, - compliance_routes, + change_log_routes, comments_routes, + compliance_routes, + costs_routes, data_asset_reviews_routes, + data_catalog_routes, data_contracts_routes, data_domains_routes, data_product_routes, @@ -36,7 +37,6 @@ entitlements_routes, entitlements_sync_routes, estate_manager_routes, - industry_ontology_routes, jobs_routes, llm_search_routes, mcp_routes, @@ -44,40 +44,21 @@ mdm_routes, metadata_routes, notifications_routes, + projects_routes, search_routes, security_features_routes, self_service_routes, - settings_routes, - semantic_models_routes, semantic_links_routes, - user_routes, - audit_routes, - change_log_routes, - workspace_routes, + semantic_models_routes, + settings_routes, tags_routes, teams_routes, - 
projects_routes, - costs_routes, + user_routes, + graph_explorer_routes, workflows_routes, + workspace_routes, ) - -from src.common.database import init_db, get_session_factory, SQLAlchemySession -from src.controller.data_products_manager import DataProductsManager -from src.controller.data_asset_reviews_manager import DataAssetReviewManager -from src.controller.data_contracts_manager import DataContractsManager -from src.controller.semantic_models_manager import SemanticModelsManager -from src.controller.search_manager import SearchManager -from src.common.workspace_client import get_workspace_client -from src.controller.settings_manager import SettingsManager -from src.controller.users_manager import UsersManager -from src.controller.authorization_manager import AuthorizationManager -from src.utils.startup_tasks import ( - initialize_database, - initialize_managers, - startup_event_handler, - shutdown_event_handler -) - +from src.utils.startup_tasks import initialize_database, initialize_managers logger.info(f"Starting application in {settings.ENV} mode.") logger.info(f"Debug mode: {settings.DEBUG}") @@ -94,28 +75,32 @@ # --- Application Lifecycle Events --- + # Application Startup Event async def startup_event(): import os - + # Skip startup tasks if running tests - if os.getenv('SKIP_STARTUP_TASKS') == 'true': + if os.getenv("SKIP_STARTUP_TASKS") == "true": logger.info("SKIP_STARTUP_TASKS=true detected - skipping startup tasks (test mode)") return - + logger.info("Running application startup event...") settings = get_settings() - + initialize_database(settings=settings) initialize_managers(app) # Handles DB-backed manager init - + # Initialize Git service for indirect delivery mode try: logger.info("Initializing Git service...") from src.common.git import init_git_service + git_service = init_git_service(settings) app.state.git_service = git_service - logger.info(f"Git service initialized (status: {git_service.get_status().clone_status.value})") + logger.info( 
+ f"Git service initialized (status: {git_service.get_status().clone_status.value})" + ) except Exception as e: logger.warning(f"Failed initializing Git service: {e}", exc_info=True) app.state.git_service = None @@ -123,8 +108,9 @@ async def startup_event(): # Initialize Grant Manager for direct delivery mode try: logger.info("Initializing Grant Manager...") - from src.controller.grant_manager import init_grant_manager from src.common.workspace_client import get_workspace_client + from src.controller.grant_manager import init_grant_manager + ws_client = get_workspace_client(settings=settings) grant_manager = init_grant_manager(ws_client=ws_client, settings=settings) app.state.grant_manager = grant_manager @@ -137,30 +123,36 @@ async def startup_event(): try: logger.info("Initializing Delivery Service...") from src.controller.delivery_service import init_delivery_service + delivery_service = init_delivery_service( settings=settings, - git_service=getattr(app.state, 'git_service', None), - grant_manager=getattr(app.state, 'grant_manager', None), - notifications_manager=getattr(app.state, 'notifications_manager', None), + git_service=getattr(app.state, "git_service", None), + grant_manager=getattr(app.state, "grant_manager", None), + notifications_manager=getattr(app.state, "notifications_manager", None), ) app.state.delivery_service = delivery_service - logger.info(f"Delivery Service initialized (active modes: {[m.value for m in delivery_service.get_active_modes()]})") + logger.info( + f"Delivery Service initialized (active modes: {[m.value for m in delivery_service.get_active_modes()]})" + ) except Exception as e: logger.warning(f"Failed initializing Delivery Service: {e}", exc_info=True) app.state.delivery_service = None - + # Demo data is loaded on-demand via POST /api/settings/demo-data/load # See: src/backend/src/data/demo_data.sql - + # Ensure SearchManager is initialized and index built try: from src.common.search_interfaces import SearchableAsset from 
src.controller.search_manager import SearchManager + logger.info("Initializing SearchManager after data load (app.py)...") searchable_managers_instances = [] - for attr_name, manager_instance in list(getattr(app.state, '_state', {}).items()): + for attr_name, manager_instance in list(getattr(app.state, "_state", {}).items()): try: - if isinstance(manager_instance, SearchableAsset) and hasattr(manager_instance, 'get_search_index_items'): + if isinstance(manager_instance, SearchableAsset) and hasattr( + manager_instance, "get_search_index_items" + ): searchable_managers_instances.append(manager_instance) except Exception: continue @@ -172,11 +164,13 @@ async def startup_event(): logger.info("Application startup complete.") + # Application Shutdown Event async def shutdown_event(): logger.info("Running application shutdown event...") logger.info("Application shutdown complete.") + # --- FastAPI App Instantiation (AFTER defining lifecycle functions) --- # Define paths @@ -202,29 +196,23 @@ async def shutdown_event(): {"name": "Datasets", "description": "Manage datasets and dataset instances"}, {"name": "Data Contracts", "description": "Manage data contracts for data products"}, {"name": "Data Products", "description": "Manage data products and subscriptions"}, - # Governance - Standards and approval workflows {"name": "Compliance", "description": "Manage compliance policies and runs"}, {"name": "Approvals", "description": "Manage approval workflows"}, {"name": "Process Workflows", "description": "Manage process workflows"}, {"name": "Data Asset Reviews", "description": "Manage data asset review workflows"}, - # Business Glossary - Semantic models and ontologies {"name": "Semantic Models", "description": "Manage semantic models and ontologies"}, {"name": "Semantic Links", "description": "Manage semantic links between entities"}, - {"name": "Industry Ontologies", "description": "Industry Ontology Library for importing standard ontologies"}, - # Operations - Monitoring 
and technical management {"name": "Estates", "description": "Manage data estates"}, {"name": "Master Data Management", "description": "Master data management features"}, {"name": "Catalog Commander", "description": "Dual-pane catalog explorer"}, - # Security - Access control and security features {"name": "Security Features", "description": "Advanced security features"}, {"name": "Entitlements", "description": "Manage entitlements and personas"}, {"name": "Entitlements Sync", "description": "Sync entitlements from external sources"}, {"name": "Access Grants", "description": "Manage time-limited access grants"}, - # System - Utilities, configuration, auxiliary services {"name": "Metadata", "description": "Manage metadata attachments"}, {"name": "Workspace", "description": "Workspace asset operations"}, @@ -250,7 +238,7 @@ async def shutdown_event(): dependencies=[Depends(get_settings)], on_startup=[startup_event], on_shutdown=[shutdown_event], - openapi_tags=openapi_tags + openapi_tags=openapi_tags, ) # Configure CORS @@ -278,7 +266,7 @@ async def shutdown_event(): app.add_middleware(LoggingMiddleware) # Mount static files for the React application (skip in test mode) -if not os.environ.get('TESTING'): +if not os.environ.get("TESTING"): app.mount("/static", StaticFiles(directory=STATIC_ASSETS_PATH, html=True), name="static") # Data Products - Core data lifecycle @@ -291,12 +279,12 @@ async def shutdown_event(): data_contracts_routes.register_routes(app) data_product_routes.register_routes(app) from src.routes import approvals_routes + approvals_routes.register_routes(app) # Governance - Standards and approval workflows semantic_models_routes.register_routes(app) semantic_links_routes.register_routes(app) -industry_ontology_routes.register_routes(app) # Industry Ontology Library data_asset_reviews_routes.register_routes(app) data_catalog_routes.register_routes(app) @@ -326,50 +314,61 @@ async def shutdown_event(): mcp_routes.register_routes(app) 
mcp_tokens_routes.register_routes(app) self_service_routes.register_routes(app) +graph_explorer_routes.register_routes(app) workflows_routes.register_routes(app) settings_routes.register_routes(app) + # Define other specific API routes BEFORE the catch-all @app.get("/api/time") async def get_current_time(): """Get the current time (for testing purposes mostly)""" - return {'time': time.time()} + return {"time": time.time()} + @app.get("/api/cache-version") async def get_cache_version(): """Get the server cache version for client-side cache invalidation""" - return {'version': SERVER_STARTUP_TIME, 'timestamp': int(time.time())} + return {"version": SERVER_STARTUP_TIME, "timestamp": int(time.time())} + @app.get("/api/version") async def get_app_version(): """Get the application version and server start time""" - return { - 'version': __version__, - 'startTime': SERVER_STARTUP_TIME, - 'timestamp': int(time.time()) - } + return {"version": __version__, "startTime": SERVER_STARTUP_TIME, "timestamp": int(time.time())} + # Define the SPA catch-all route LAST (skip in test mode) -if not os.environ.get('TESTING'): +if not os.environ.get("TESTING"): + @app.get("/{full_path:path}") def serve_spa(full_path: str): # Only catch routes that aren't API routes, static files, or API docs # This check might be redundant now due to ordering, but safe to keep - if not full_path.startswith("api/") and not full_path.startswith("static/") and full_path not in ["docs", "redoc", "openapi.json"]: + if ( + not full_path.startswith("api/") + and not full_path.startswith("static/") + and full_path not in ["docs", "redoc", "openapi.json"] + ): # Ensure the path exists before serving spa_index = STATIC_ASSETS_PATH / "index.html" if spa_index.is_file(): - return FileResponse(spa_index, media_type="text/html") + return FileResponse(spa_index, media_type="text/html") else: - # Optional: Return a 404 or a simple HTML message if index.html is missing - logger.error(f"SPA index.html not found at 
{spa_index}") - return HTMLResponse(content="Frontend not built or index.html missing.", status_code=404) + # Optional: Return a 404 or a simple HTML message if index.html is missing + logger.error(f"SPA index.html not found at {spa_index}") + return HTMLResponse( + content="Frontend not built or index.html missing.", + status_code=404, + ) # If it starts with api/ or static/ but wasn't handled by a router/StaticFiles, # FastAPI will return its default 404 Not Found, which is correct. # No explicit return needed here for that case. + logger.info("All routes registered.") -if __name__ == '__main__': +if __name__ == "__main__": import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/src/backend/src/common/config.py b/src/backend/src/common/config.py index f8e7ed5f..a1c84b38 100644 --- a/src/backend/src/common/config.py +++ b/src/backend/src/common/config.py @@ -138,6 +138,12 @@ class Settings(BaseSettings): env='LLM_INJECTION_CHECK_PROMPT' ) + # Graph Explorer safety limits + GRAPH_MAX_NODES: int = Field(5000, env='GRAPH_MAX_NODES') # Max nodes returned from initial graph load + GRAPH_MAX_EDGES: int = Field(10000, env='GRAPH_MAX_EDGES') # Max edges returned from initial graph load + GRAPH_QUERY_TIMEOUT: str = Field("30s", env='GRAPH_QUERY_TIMEOUT') # SQL statement timeout + GRAPH_NEIGHBOR_LIMIT: int = Field(50, env='GRAPH_NEIGHBOR_LIMIT') # Default max neighbors per expansion + # Sandbox allowlist settings sandbox_default_schema: str = Field('sandbox', validation_alias=AliasChoices('SANDBOX_DEFAULT_SCHEMA', 'sandbox_default_schema')) sandbox_allowed_catalog_prefixes: List[str] = Field(default_factory=lambda: ['user_'], validation_alias=AliasChoices('SANDBOX_ALLOWED_CATALOG_PREFIXES', 'sandbox_allowed_catalog_prefixes')) diff --git a/src/backend/src/common/database.py b/src/backend/src/common/database.py index d7f05e81..382569dc 100644 --- a/src/backend/src/common/database.py +++ b/src/backend/src/common/database.py @@ -1,38 +1,35 @@ import os 
-import uuid -import time import threading +import time +import uuid from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, List, Optional, TypeVar - -from sqlalchemy import create_engine, text, event -from sqlalchemy.orm import sessionmaker, Session as SQLAlchemySession -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy import pool -from sqlalchemy.engine import Connection, URL +from typing import Any, TypeVar from alembic.config import Config as AlembicConfig -from alembic.script import ScriptDirectory from alembic.runtime.migration import MigrationContext -from alembic import command as alembic_command +from alembic.script import ScriptDirectory + +# Import SDK components +from sqlalchemy import create_engine, event, pool, text +from sqlalchemy.engine import URL, Connection +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker -from .config import get_settings, Settings -from .logging import get_logger -from src.common.workspace_client import get_workspace_client from src.common.unity_catalog_utils import ( ensure_catalog_exists, ensure_schema_exists, sanitize_postgres_identifier, ) -# Import SDK components -from databricks.sdk.errors import NotFound, DatabricksError -from databricks.sdk.core import Config, oauth_service_principal +from src.common.workspace_client import get_workspace_client + +from .config import Settings, get_settings +from .logging import get_logger logger = get_logger(__name__) -T = TypeVar('T') +T = TypeVar("T") # Define the base class for SQLAlchemy models Base = declarative_base() @@ -41,40 +38,52 @@ # This ensures Base.metadata is populated before init_db needs it. 
logger.debug("Importing all DB model modules to register with Base...") try: - from src.db_models import settings as settings_db - from src.db_models import audit_log - from src.db_models import data_asset_reviews - from src.db_models import data_products - from src.db_models import notifications - from src.db_models import data_domains - from src.db_models import semantic_links - from src.db_models import metadata as metadata_db - from src.db_models import semantic_models - from src.db_models import comments - from src.db_models import costs - from src.db_models import change_log # Add missing model imports for Alembic - from src.db_models import data_contracts - from src.db_models import workflow_configurations - from src.db_models import workflow_installations - from src.db_models import workflow_job_runs - from src.db_models import compliance - from src.db_models import projects - from src.db_models import teams - from src.db_models import mdm # MDM models for master data management - # from src.db_models.data_products import DataProductDb, InfoDb, InputPortDb, OutputPortDb # Already imported via module import above - from src.db_models.settings import AppRoleDb + from src.db_models import ( + audit_log, + change_log, + comments, + compliance, + costs, + data_asset_reviews, + data_contracts, + data_domains, + data_products, + mdm, # MDM models for master data management + notifications, + projects, + semantic_links, + semantic_models, + teams, + workflow_configurations, + workflow_installations, + workflow_job_runs, + ) + from src.db_models import metadata as metadata_db + from src.db_models import settings as settings_db + # from src.db_models.users import UserActivityDb, UserSearchHistoryDb # Commented out due to missing file from src.db_models.audit_log import AuditLogDb from src.db_models.notifications import NotificationDb + + # from src.db_models.data_products import DataProductDb, InfoDb, InputPortDb, OutputPortDb # Already imported via module import 
above + from src.db_models.settings import AppRoleDb + # from src.db_models.business_glossary import GlossaryDb, TermDb, CategoryDb, term_category_association, term_related_terms, term_asset_association # Commented out due to missing file # Add new tag models - from src.db_models.tags import TagDb, TagNamespaceDb, TagNamespacePermissionDb, EntityTagAssociationDb + from src.db_models.tags import ( + EntityTagAssociationDb, + TagDb, + TagNamespaceDb, + TagNamespacePermissionDb, + ) + # Add imports for any other future model modules here logger.debug("DB model modules imported successfully.") except ImportError as e: logger.critical( - f"Failed to import a DB model module during initial registration: {e}", exc_info=True) + f"Failed to import a DB model module during initial registration: {e}", exc_info=True + ) # This is likely a fatal error, consider raising or exiting raise # ------------------------------------------------------------------------- # @@ -86,20 +95,20 @@ engine = None # OAuth token state for Lakebase connections -_oauth_token: Optional[str] = None +_oauth_token: str | None = None _token_last_refresh: float = 0 _token_refresh_lock = threading.Lock() -_token_refresh_thread: Optional[threading.Thread] = None +_token_refresh_thread: threading.Thread | None = None _token_refresh_stop_event = threading.Event() -def get_lakebase_instance_name(app_name: str, ws_client) -> Optional[str]: +def get_lakebase_instance_name(app_name: str, ws_client) -> str | None: """Get the Lakebase instance name from the Databricks App resources. 
- + Args: app_name: Name of the Databricks App ws_client: Workspace client instance - + Returns: The database instance name, or None if not found """ @@ -117,7 +126,8 @@ def get_lakebase_instance_name(app_name: str, ws_client) -> Optional[str]: @dataclass class InMemorySession: """In-memory session for managing transactions.""" - changes: List[Dict[str, Any]] + + changes: list[dict[str, Any]] def __init__(self): self.changes = [] @@ -135,10 +145,10 @@ class InMemoryStore: def __init__(self): """Initialize the in-memory store.""" - self._data: Dict[str, List[Dict[str, Any]]] = {} - self._metadata: Dict[str, Dict[str, Any]] = {} + self._data: dict[str, list[dict[str, Any]]] = {} + self._metadata: dict[str, dict[str, Any]] = {} - def create_table(self, table_name: str, metadata: Dict[str, Any] = None) -> None: + def create_table(self, table_name: str, metadata: dict[str, Any] = None) -> None: """Create a new table in the store. Args: @@ -150,7 +160,7 @@ def create_table(self, table_name: str, metadata: Dict[str, Any] = None) -> None if metadata: self._metadata[table_name] = metadata - def insert(self, table_name: str, data: Dict[str, Any]) -> None: + def insert(self, table_name: str, data: dict[str, Any]) -> None: """Insert a record into a table. 
Args: @@ -161,16 +171,16 @@ def insert(self, table_name: str, data: Dict[str, Any]) -> None: self.create_table(table_name) # Add timestamp and id if not present - if 'id' not in data: - data['id'] = str(len(self._data[table_name]) + 1) - if 'created_at' not in data: - data['created_at'] = datetime.utcnow().isoformat() - if 'updated_at' not in data: - data['updated_at'] = data['created_at'] + if "id" not in data: + data["id"] = str(len(self._data[table_name]) + 1) + if "created_at" not in data: + data["created_at"] = datetime.utcnow().isoformat() + if "updated_at" not in data: + data["updated_at"] = data["created_at"] self._data[table_name].append(data) - def get(self, table_name: str, id: str) -> Optional[Dict[str, Any]]: + def get(self, table_name: str, id: str) -> dict[str, Any] | None: """Get a record by ID. Args: @@ -182,9 +192,9 @@ def get(self, table_name: str, id: str) -> Optional[Dict[str, Any]]: """ if table_name not in self._data: return None - return next((item for item in self._data[table_name] if item['id'] == id), None) + return next((item for item in self._data[table_name] if item["id"] == id), None) - def get_all(self, table_name: str) -> List[Dict[str, Any]]: + def get_all(self, table_name: str) -> list[dict[str, Any]]: """Get all records from a table. Args: @@ -195,7 +205,7 @@ def get_all(self, table_name: str) -> List[Dict[str, Any]]: """ return self._data.get(table_name, []) - def update(self, table_name: str, id: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + def update(self, table_name: str, id: str, data: dict[str, Any]) -> dict[str, Any] | None: """Update a record. 
Args: @@ -210,9 +220,9 @@ def update(self, table_name: str, id: str, data: Dict[str, Any]) -> Optional[Dic return None for item in self._data[table_name]: - if item['id'] == id: + if item["id"] == id: item.update(data) - item['updated_at'] = datetime.utcnow().isoformat() + item["updated_at"] = datetime.utcnow().isoformat() return item return None @@ -230,8 +240,7 @@ def delete(self, table_name: str, id: str) -> bool: return False initial_length = len(self._data[table_name]) - self._data[table_name] = [ - item for item in self._data[table_name] if item['id'] != id] + self._data[table_name] = [item for item in self._data[table_name] if item["id"] != id] return len(self._data[table_name]) < initial_length def clear(self, table_name: str) -> None: @@ -276,37 +285,39 @@ def dispose(self) -> None: # Global database manager instance -db_manager: Optional[DatabaseManager] = None +db_manager: DatabaseManager | None = None def refresh_oauth_token(settings: Settings) -> str: """Generate fresh OAuth token from Databricks for Lakebase connection.""" global _oauth_token, _token_last_refresh - + with _token_refresh_lock: ws_client = get_workspace_client(settings) instance_name = get_lakebase_instance_name(settings.DATABRICKS_APP_NAME, ws_client) - + if not instance_name: - raise ValueError(f"Could not determine Lakebase instance name for app '{settings.DATABRICKS_APP_NAME}'") - + raise ValueError( + f"Could not determine Lakebase instance name for app '{settings.DATABRICKS_APP_NAME}'" + ) + logger.info(f"Generating OAuth token for Lakebase instance: {instance_name}") cred = ws_client.database.generate_database_credential( request_id=str(uuid.uuid4()), instance_names=[instance_name], ) - + _oauth_token = cred.token _token_last_refresh = time.time() logger.info("OAuth token refreshed successfully") - + return _oauth_token def start_token_refresh_background(settings: Settings): """Start background thread to refresh OAuth tokens every 50 minutes.""" global _token_refresh_thread, 
_token_refresh_stop_event - + def refresh_loop(): while not _token_refresh_stop_event.is_set(): _token_refresh_stop_event.wait(50 * 60) # 50 minutes @@ -315,7 +326,7 @@ def refresh_loop(): refresh_oauth_token(settings) except Exception as e: logger.error(f"Background token refresh failed: {e}", exc_info=True) - + _token_refresh_stop_event.clear() _token_refresh_thread = threading.Thread(target=refresh_loop, daemon=True) _token_refresh_thread.start() @@ -333,14 +344,16 @@ def stop_token_refresh_background(): def get_db_url(settings: Settings) -> str: """Construct the PostgreSQL SQLAlchemy URL with appropriate auth method.""" - + # Validate required settings if not all([settings.PGHOST, settings.PGDATABASE]): - raise ValueError("PostgreSQL connection details (PGHOST, PGDATABASE) are missing in settings.") - + raise ValueError( + "PostgreSQL connection details (PGHOST, PGDATABASE) are missing in settings." + ) + # Determine authentication mode based on ENV use_password_auth = settings.ENV.upper().startswith("LOCAL") - + if use_password_auth: logger.info("Database: Using password authentication (LOCAL mode)") if not settings.PGPASSWORD or not settings.PGUSER: @@ -351,20 +364,17 @@ def get_db_url(settings: Settings) -> str: logger.info("Database: Using OAuth authentication (Lakebase mode)") # Dynamically determine username from authenticated principal ws_client = get_workspace_client(settings) - username = ( - os.getenv("DATABRICKS_CLIENT_ID") - or ws_client.current_user.me().user_name - ) + username = os.getenv("DATABRICKS_CLIENT_ID") or ws_client.current_user.me().user_name if not username: raise ValueError("Could not determine database username from authenticated principal") - + logger.info(f"🔑 Detected service principal username: {username}") password = "" # Will be set via event handler - + # Build URL with schema options and statement timeout query_params = {} options_list = [] - + # Add schema to search_path if specified if settings.PGSCHEMA: # Validate schema 
name for connection options to prevent injection @@ -379,15 +389,15 @@ def get_db_url(settings: Settings) -> str: logger.info(f"PostgreSQL schema will be set via options: {validated_schema}") else: logger.info("No specific PostgreSQL schema configured, using default (public).") - + # Add statement timeout to prevent indefinite locks (30 seconds) # This helps prevent stuck transactions when operations fail options_list.append("-cstatement_timeout=30000") logger.info("PostgreSQL statement timeout set to 30 seconds") - + if options_list: query_params["options"] = " ".join(options_list) - + db_url_obj = URL.create( drivername="postgresql+psycopg2", username=username, @@ -395,7 +405,7 @@ def get_db_url(settings: Settings) -> str: host=settings.PGHOST, port=settings.PGPORT, database=settings.PGDATABASE, - query=query_params if query_params else None + query=query_params if query_params else None, ) url_str = db_url_obj.render_as_string(hide_password=False) logger.debug( @@ -409,52 +419,53 @@ def ensure_database_and_schema_exist(settings: Settings): """ Ensure the target database and schema exist. If not, create them. Works in both LOCAL and OAuth modes (authentication differs but logic is unified). - + If APP_DB_DROP_ON_START=true, drops the schema CASCADE before creating it. Connects to default postgres database to create target database (OAuth mode only), then creates schema within it. - + The app (as service principal) becomes the owner of what it creates in OAuth mode, eliminating permission issues. - + Security: All PostgreSQL identifiers are validated to prevent SQL injection. """ is_local_mode = settings.ENV.upper().startswith("LOCAL") - - logger.info(f"Ensuring database and schema exist ({'LOCAL' if is_local_mode else 'OAuth'} mode)...") - + + logger.info( + f"Ensuring database and schema exist ({'LOCAL' if is_local_mode else 'OAuth'} mode)..." 
+ ) + # Determine username based on mode if is_local_mode: username = settings.PGUSER else: # Get service principal username for OAuth mode ws_client = get_workspace_client(settings) - username = ( - os.getenv("DATABRICKS_CLIENT_ID") - or ws_client.current_user.me().user_name - ) - + username = os.getenv("DATABRICKS_CLIENT_ID") or ws_client.current_user.me().user_name + if not username: raise ValueError("Could not determine username/service principal") - + # Validate all PostgreSQL identifiers to prevent SQL injection try: target_db = sanitize_postgres_identifier(settings.PGDATABASE) - target_schema = sanitize_postgres_identifier(settings.PGSCHEMA) if settings.PGSCHEMA else None + target_schema = ( + sanitize_postgres_identifier(settings.PGSCHEMA) if settings.PGSCHEMA else None + ) username = sanitize_postgres_identifier(username) except ValueError as e: raise ValueError( f"Invalid PostgreSQL identifier in configuration: {e}. " "Please check PGDATABASE, PGSCHEMA, and username." ) from e - + logger.info(f"Username: {username}") logger.debug(f"Target database: {target_db}, schema: {target_schema}") - + # Generate initial OAuth token for OAuth mode if not is_local_mode: refresh_oauth_token(settings) - + # Build connection URL # In OAuth mode, connect directly to the target database (must be pre-created) # In LOCAL mode, connect to the target database (should already exist) @@ -466,24 +477,25 @@ def ensure_database_and_schema_exist(settings: Settings): port=settings.PGPORT, database=target_db, ) - + # Create temporary engine for schema setup temp_engine = create_engine( connection_url.render_as_string(hide_password=False), - isolation_level="AUTOCOMMIT" # Needed for CREATE SCHEMA + isolation_level="AUTOCOMMIT", # Needed for CREATE SCHEMA ) - + # Inject OAuth token for connections in OAuth mode if not is_local_mode: + @event.listens_for(temp_engine, "do_connect") def inject_token_temp(dialect, conn_rec, cargs, cparams): global _oauth_token if _oauth_token: 
cparams["password"] = _oauth_token - + try: # In OAuth mode, verify we can connect to the target database - # The database must be pre-created by an admin with: + # The database must be pre-created by an admin with: # CREATE DATABASE "app_ontos"; GRANT CREATE ON DATABASE "app_ontos" TO PUBLIC; if not is_local_mode: try: @@ -504,22 +516,26 @@ def inject_token_temp(dialect, conn_rec, cargs, cparams): f"See the README for detailed setup instructions." ) from e raise - + # Now handle schema (works for both LOCAL and OAuth modes) with temp_engine.connect() as conn: if target_schema and target_schema != "public": # Handle APP_DB_DROP_ON_START: Drop schema CASCADE before creating if settings.APP_DB_DROP_ON_START: - logger.warning(f"APP_DB_DROP_ON_START=true: Dropping schema '{target_schema}' CASCADE...") + logger.warning( + f"APP_DB_DROP_ON_START=true: Dropping schema '{target_schema}' CASCADE..." + ) # DROP SCHEMA cannot be parameterized, but identifier is validated conn.execute(text(f'DROP SCHEMA IF EXISTS "{target_schema}" CASCADE')) conn.commit() - logger.warning(f"✓ Schema '{target_schema}' dropped CASCADE. Will be recreated.") - + logger.warning( + f"✓ Schema '{target_schema}' dropped CASCADE. Will be recreated." 
+ ) + # Explicitly drop application enum types from public schema # These are often created in public and not dropped with CASCADE on a specific schema logger.info("Dropping application enum types from public schema...") - enum_types = ['commentstatus', 'commenttype', 'accesslevel'] + enum_types = ["commentstatus", "commenttype", "accesslevel"] for enum_type in enum_types: try: conn.execute(text(f'DROP TYPE IF EXISTS "{enum_type}" CASCADE')) @@ -527,47 +543,53 @@ def inject_token_temp(dialect, conn_rec, cargs, cparams): logger.warning(f"Could not drop enum type '{enum_type}': {e}") conn.commit() logger.info("✓ Enum types cleanup completed.") - + # Check if schema exists (using parameterized query) result = conn.execute( - text("SELECT 1 FROM information_schema.schemata WHERE schema_name = :schemaname"), - {"schemaname": target_schema} + text( + "SELECT 1 FROM information_schema.schemata WHERE schema_name = :schemaname" + ), + {"schemaname": target_schema}, ) schema_exists = result.scalar() is not None - + if not schema_exists: logger.info(f"Creating schema: {target_schema}") # CREATE SCHEMA cannot be parameterized, but identifier is validated conn.execute(text(f'CREATE SCHEMA "{target_schema}"')) logger.info(f"✓ Schema created: {target_schema} (owner: {username})") - + # Set default privileges for future objects (OAuth mode only) if not is_local_mode: # ALTER statements cannot be parameterized, but identifiers are validated logger.info(f"Setting default privileges in schema: {target_schema}") - conn.execute(text( - f'ALTER DEFAULT PRIVILEGES IN SCHEMA "{target_schema}" ' - f'GRANT ALL ON TABLES TO "{username}"' - )) - conn.execute(text( - f'ALTER DEFAULT PRIVILEGES IN SCHEMA "{target_schema}" ' - f'GRANT ALL ON SEQUENCES TO "{username}"' - )) - logger.info(f"✓ Default privileges configured") - + conn.execute( + text( + f'ALTER DEFAULT PRIVILEGES IN SCHEMA "{target_schema}" ' + f'GRANT ALL ON TABLES TO "{username}"' + ) + ) + conn.execute( + text( + f'ALTER DEFAULT 
PRIVILEGES IN SCHEMA "{target_schema}" ' + f'GRANT ALL ON SEQUENCES TO "{username}"' + ) + ) + logger.info("✓ Default privileges configured") + conn.commit() else: logger.info(f"✓ Schema already exists: {target_schema}") else: - logger.info(f"Using public schema (no custom schema specified)") - + logger.info("Using public schema (no custom schema specified)") + except Exception as e: if "permission denied" in str(e).lower(): logger.error( f"❌ Permission denied - check database/schema privileges for user '{username}'" ) if not is_local_mode: - logger.error(f"To fix this, run as a Lakebase admin:") + logger.error("To fix this, run as a Lakebase admin:") logger.error(f' DROP DATABASE IF EXISTS "{target_db}";') logger.error(f' CREATE DATABASE "{target_db}";') logger.error(f' GRANT CREATE ON DATABASE "{target_db}" TO "{username}";') @@ -575,13 +597,13 @@ def inject_token_temp(dialect, conn_rec, cargs, cparams): raise finally: temp_engine.dispose() - + logger.info("✓ Database and schema are ready") def ensure_catalog_schema_exists(settings: Settings): """Checks if the configured catalog and schema exist, creates them if not. - + Uses shared Unity Catalog utilities for secure, idempotent catalog/schema creation. """ logger.info("Ensuring required catalog and schema exist...") @@ -600,18 +622,15 @@ def ensure_catalog_schema_exists(settings: Settings): ensure_catalog_exists( ws=ws_client, catalog_name=catalog_name, - comment=f"System catalog for {settings.APP_NAME}" + comment=f"System catalog for {settings.APP_NAME}", ) logger.info(f"Catalog '{catalog_name}' is ready.") except Exception as e: # Map HTTPException or other errors to ConnectionError for consistency logger.critical( - f"Failed to ensure catalog '{catalog_name}': {e}. Check permissions.", - exc_info=True + f"Failed to ensure catalog '{catalog_name}': {e}. 
Check permissions.", exc_info=True ) - raise ConnectionError( - f"Failed to create required catalog '{catalog_name}': {e}" - ) from e + raise ConnectionError(f"Failed to create required catalog '{catalog_name}': {e}") from e try: logger.debug(f"Ensuring schema exists: {full_schema_name}") @@ -619,13 +638,13 @@ def ensure_catalog_schema_exists(settings: Settings): ws=ws_client, catalog_name=catalog_name, schema_name=schema_name, - comment=f"System schema for {settings.APP_NAME}" + comment=f"System schema for {settings.APP_NAME}", ) logger.info(f"Schema '{full_schema_name}' is ready.") except Exception as e: logger.critical( - f"Failed to ensure schema '{full_schema_name}': {e}. Check permissions.", - exc_info=True + f"Failed to ensure schema '{full_schema_name}': {e}. Check permissions.", + exc_info=True, ) raise ConnectionError( f"Failed to create required schema '{full_schema_name}': {e}" @@ -638,15 +657,14 @@ def ensure_catalog_schema_exists(settings: Settings): raise except Exception as e: logger.critical( - f"An unexpected error occurred during catalog/schema check/creation: {e}", - exc_info=True + f"An unexpected error occurred during catalog/schema check/creation: {e}", exc_info=True ) - raise ConnectionError( - f"Failed during catalog/schema setup: {e}" - ) from e + raise ConnectionError(f"Failed during catalog/schema setup: {e}") from e -def get_current_db_revision(engine_connection: Connection, alembic_cfg: AlembicConfig) -> str | None: +def get_current_db_revision( + engine_connection: Connection, alembic_cfg: AlembicConfig +) -> str | None: """Gets the current revision of the database.""" context = MigrationContext.configure(engine_connection) return context.get_current_revision() @@ -662,7 +680,7 @@ def init_db() -> None: return logger.info("Initializing database engine and session factory...") - + # Ensure database and schema exist (creates them if needed in OAuth mode) ensure_database_and_schema_exist(settings) @@ -675,27 +693,31 @@ def init_db() -> 
None: logger.info("Connecting to database...") logger.info(f"> Database URL: {db_url}") logger.info(f"> Connect args: {connect_args}") - logger.info(f"> Pool settings: size={settings.DB_POOL_SIZE}, max_overflow={settings.DB_MAX_OVERFLOW}, " - f"timeout={settings.DB_POOL_TIMEOUT}s, recycle={settings.DB_POOL_RECYCLE}s") - - _engine = create_engine(db_url, - connect_args=connect_args, - echo=settings.DB_ECHO, - poolclass=pool.QueuePool, - pool_size=settings.DB_POOL_SIZE, - max_overflow=settings.DB_MAX_OVERFLOW, - pool_timeout=settings.DB_POOL_TIMEOUT, - pool_recycle=settings.DB_POOL_RECYCLE, - pool_pre_ping=True) - engine = _engine # Assign to public variable + logger.info( + f"> Pool settings: size={settings.DB_POOL_SIZE}, max_overflow={settings.DB_MAX_OVERFLOW}, " + f"timeout={settings.DB_POOL_TIMEOUT}s, recycle={settings.DB_POOL_RECYCLE}s" + ) + + _engine = create_engine( + db_url, + connect_args=connect_args, + echo=settings.DB_ECHO, + poolclass=pool.QueuePool, + pool_size=settings.DB_POOL_SIZE, + max_overflow=settings.DB_MAX_OVERFLOW, + pool_timeout=settings.DB_POOL_TIMEOUT, + pool_recycle=settings.DB_POOL_RECYCLE, + pool_pre_ping=True, + ) + engine = _engine # Assign to public variable # Add OAuth token injection if not in LOCAL mode if not settings.ENV.upper().startswith("LOCAL"): logger.info("Setting up OAuth token injection for Lakebase...") - + # Generate initial token refresh_oauth_token(settings) - + # Register event handler to inject tokens for new connections # Use 'do_connect' event to inject password at connection creation time @event.listens_for(_engine, "do_connect") @@ -704,7 +726,7 @@ def inject_token_on_connect(dialect, conn_rec, cargs, cparams): if _oauth_token: cparams["password"] = _oauth_token logger.debug("Injected OAuth token into new database connection") - + # Start background refresh thread start_token_refresh_background(settings) logger.info("OAuth authentication configured successfully") @@ -738,13 +760,21 @@ def 
set_search_path(dbapi_connection, connection_record): logger.info("Database engine and session factory initialized.") # --- Alembic Migration Logic --- # - alembic_cfg_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..' , 'alembic.ini')) - alembic_script_location = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'alembic')) + alembic_cfg_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "alembic.ini") + ) + alembic_script_location = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "alembic") + ) logger.info(f"Loading Alembic configuration from: {alembic_cfg_path}") logger.info(f"Alembic script location: {alembic_script_location}") alembic_cfg = AlembicConfig(alembic_cfg_path) - alembic_cfg.set_main_option("sqlalchemy.url", db_url.replace("%", "%%")) # Ensure Alembic uses the same URL - alembic_cfg.set_main_option("script_location", alembic_script_location) # Set absolute path to alembic directory + alembic_cfg.set_main_option( + "sqlalchemy.url", db_url.replace("%", "%%") + ) # Ensure Alembic uses the same URL + alembic_cfg.set_main_option( + "script_location", alembic_script_location + ) # Set absolute path to alembic directory script = ScriptDirectory.from_config(alembic_cfg) head_revision = script.get_current_head() logger.info(f"Alembic Head Revision: {head_revision}") @@ -765,23 +795,29 @@ def set_search_path(dbapi_connection, connection_record): # Handle migrations based on database state if db_revision is None: # Fresh database - will use create_all() + stamp later - logger.info("Fresh database detected (no Alembic version table). Will initialize with create_all() and stamp.") + logger.info( + "Fresh database detected (no Alembic version table). Will initialize with create_all() and stamp." 
+ ) elif db_revision != head_revision: # Existing database needs migration - logger.info(f"Database revision '{db_revision}' differs from head revision '{head_revision}'.") + logger.info( + f"Database revision '{db_revision}' differs from head revision '{head_revision}'." + ) logger.info("Attempting Alembic upgrade to head...") try: # CRITICAL: Dispose the main engine's connection pool before running Alembic. # This releases all pooled connections that might hold implicit transaction state. - logger.info("Disposing engine pool to release connections before Alembic migration...") + logger.info( + "Disposing engine pool to release connections before Alembic migration..." + ) _engine.dispose() - + # Run Alembic upgrade via subprocess to avoid hanging issues with # Alembic's internal runpy-based execution when called programmatically. # This ensures proper process isolation and cleanup. + import shutil import subprocess import sys - import shutil # For Lakebase (OAuth mode), we need to pass the token to the subprocess # via environment variable since the subprocess can't access our in-memory token @@ -798,7 +834,11 @@ def set_search_path(dbapi_connection, connection_record): python_executable = None # Try sys.executable first if it exists and is absolute - if sys.executable and os.path.isabs(sys.executable) and os.path.exists(sys.executable): + if ( + sys.executable + and os.path.isabs(sys.executable) + and os.path.exists(sys.executable) + ): python_executable = sys.executable logger.info(f"Using sys.executable: {python_executable}") else: @@ -810,8 +850,12 @@ def set_search_path(dbapi_connection, connection_record): os.path.join(os.getcwd(), "venv", "bin", "python3"), os.path.join(os.getcwd(), "venv", "bin", "python"), # Try relative to the backend src directory - os.path.join(os.path.dirname(__file__), "..", "..", ".venv", "bin", "python3"), - os.path.join(os.path.dirname(__file__), "..", "..", ".venv", "bin", "python"), + os.path.join( + os.path.dirname(__file__), 
"..", "..", ".venv", "bin", "python3" + ), + os.path.join( + os.path.dirname(__file__), "..", "..", ".venv", "bin", "python" + ), # Try system Python as fallback "/usr/local/bin/python3", "/usr/bin/python3", @@ -828,7 +872,9 @@ def set_search_path(dbapi_connection, connection_record): if not python_executable: alembic_path = shutil.which("alembic") if alembic_path: - logger.warning("Could not find Python executable, will try running alembic command directly") + logger.warning( + "Could not find Python executable, will try running alembic command directly" + ) python_executable = None # Will use alembic directly below else: raise RuntimeError( @@ -851,18 +897,22 @@ def set_search_path(dbapi_connection, connection_record): capture_output=True, text=True, timeout=300, # 5 minute timeout - env=subprocess_env + env=subprocess_env, ) if result.returncode != 0: logger.error(f"Alembic upgrade stderr: {result.stderr}") - raise RuntimeError(f"Alembic upgrade failed with exit code {result.returncode}: {result.stderr}") + raise RuntimeError( + f"Alembic upgrade failed with exit code {result.returncode}: {result.stderr}" + ) logger.info(f"Alembic upgrade output: {result.stdout}") logger.info("✓ Alembic upgrade to head COMPLETED.") except subprocess.TimeoutExpired: logger.critical("Alembic upgrade timed out after 5 minutes!") raise RuntimeError("Alembic upgrade timed out") except Exception as alembic_err: - logger.critical("Alembic upgrade failed! Manual intervention may be required.", exc_info=True) + logger.critical( + "Alembic upgrade failed! 
Manual intervention may be required.", exc_info=True + ) raise RuntimeError("Failed to upgrade database schema.") from alembic_err else: logger.info("✓ Database schema is up to date according to Alembic.") @@ -882,43 +932,51 @@ def set_search_path(dbapi_connection, connection_record): # with engine.connect() as connection: # connection.execute(sqlalchemy.text(f"CREATE SCHEMA IF NOT EXISTS {schema_to_create_in}")) # connection.commit() - logger.info(f"PostgreSQL: Tables will be targeted for schema '{schema_to_create_in}' via search_path or model definitions.") + logger.info( + f"PostgreSQL: Tables will be targeted for schema '{schema_to_create_in}' via search_path or model definitions." + ) # No Databricks-specific metadata modifications required - - # Note: Schema creation and APP_DB_DROP_ON_START handling is done in + + # Note: Schema creation and APP_DB_DROP_ON_START handling is done in # ensure_database_and_schema_exist() before Alembic migrations run # Only use create_all() for fresh databases without Alembic version table # Once Alembic is tracking the schema, migrations handle all schema changes if db_revision is None: - logger.info("Fresh database detected (no Alembic version). Using create_all() for initial setup...") - target_schema = settings.PGSCHEMA or 'public' - + logger.info( + "Fresh database detected (no Alembic version). Using create_all() for initial setup..." 
+ ) + target_schema = settings.PGSCHEMA or "public" + # Create all tables in the target schema with _engine.begin() as connection: # Explicitly set search_path to ensure tables are created in correct schema connection.execute(text(f'SET search_path TO "{target_schema}"')) Base.metadata.create_all(bind=connection, checkfirst=True) logger.info("✓ Database tables created by create_all.") - + # Stamp the database with the baseline migration - # Using direct INSERT instead of alembic_command.stamp() to avoid + # Using direct INSERT instead of alembic_command.stamp() to avoid # hanging issues with Alembic's runpy-based execution in some environments logger.info("Stamping database with baseline migration...") try: with _engine.begin() as connection: connection.execute(text(f'SET search_path TO "{target_schema}"')) # Create alembic_version table if needed and insert head revision - connection.execute(text(""" + connection.execute( + text(""" CREATE TABLE IF NOT EXISTS alembic_version ( version_num VARCHAR(32) NOT NULL, CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num) ) - """)) + """) + ) connection.execute(text("DELETE FROM alembic_version")) - connection.execute(text("INSERT INTO alembic_version (version_num) VALUES (:rev)"), - {"rev": head_revision}) + connection.execute( + text("INSERT INTO alembic_version (version_num) VALUES (:rev)"), + {"rev": head_revision}, + ) logger.info(f"✓ Database stamped with baseline migration: {head_revision}") except Exception as stamp_err: logger.error(f"Failed to stamp database: {stamp_err}", exc_info=True) @@ -929,21 +987,24 @@ def set_search_path(dbapi_connection, connection_record): logger.critical(f"Database initialization failed: {e}", exc_info=True) _engine = None _SessionLocal = None - engine = None # Reset public engine on failure + engine = None # Reset public engine on failure raise ConnectionError("Failed to initialize database connection or run migrations.") from e + def get_db(): global _SessionLocal if _SessionLocal 
is None: logger.error("Database not initialized. Cannot get session.") # Consider raising HTTPException for FastAPI to handle gracefully if this occurs at runtime - raise RuntimeError("Database session factory is not available. Database might not have been initialized correctly.") - + raise RuntimeError( + "Database session factory is not available. Database might not have been initialized correctly." + ) + db = _SessionLocal() try: yield db db.commit() # Commit the transaction on successful completion of the request - except Exception as e: # Catch all exceptions to ensure rollback + except Exception as e: # Catch all exceptions to ensure rollback logger.error(f"Error during database session for request, rolling back: {e}", exc_info=True) db.rollback() # Re-raise the exception so FastAPI can handle it appropriately @@ -952,6 +1013,7 @@ def get_db(): finally: db.close() + @contextmanager def get_db_session(): """Context manager that yields a SQLAlchemy session. @@ -966,7 +1028,9 @@ def get_db_session(): init_db() except Exception as e: logger.critical(f"Failed to initialize database session factory: {e}", exc_info=True) - raise RuntimeError("Database session factory not available and initialization failed.") from e + raise RuntimeError( + "Database session factory not available and initialization failed." + ) from e session = _SessionLocal() try: @@ -979,12 +1043,14 @@ def get_db_session(): finally: session.close() + def get_engine(): global _engine if _engine is None: raise RuntimeError("Database engine not initialized.") return _engine + def get_session_factory(): global _SessionLocal if _SessionLocal is None: @@ -995,25 +1061,26 @@ def get_session_factory(): def set_session_factory(factory): """ Set the global session factory. Used by tests to inject a test database session factory. 
- + Args: factory: A sessionmaker instance or callable that returns database sessions """ global _SessionLocal _SessionLocal = factory + def cleanup_db(): """Cleanup database resources including OAuth token refresh.""" global _engine, _SessionLocal, engine - + # Stop token refresh if running stop_token_refresh_background() - + # Dispose engine if _engine: _engine.dispose() logger.info("Database engine disposed") - + _engine = None _SessionLocal = None engine = None diff --git a/src/backend/src/common/dependencies.py b/src/backend/src/common/dependencies.py index c9d33d76..6c6ccb16 100644 --- a/src/backend/src/common/dependencies.py +++ b/src/backend/src/common/dependencies.py @@ -20,7 +20,8 @@ from src.controller.tags_manager import TagsManager # Import TagsManager from src.controller.workspace_manager import WorkspaceManager # Import WorkspaceManager from src.controller.change_log_manager import ChangeLogManager # Import ChangeLogManager -from src.controller.datasets_manager import DatasetsManager # Import DatasetsManager +from src.controller.datasets_manager import DatasetsManager +from src.controller.graph_explorer_manager import GraphExplorerManager # Import GraphExplorerManager # Import base dependencies from src.common.database import get_session_factory # Import the factory function @@ -47,6 +48,7 @@ get_workspace_manager, get_change_log_manager, get_datasets_manager, + get_graph_explorer_manager, ) # Import workspace client getter separately as it might be structured differently from src.common.workspace_client import get_workspace_client_dependency # Fixed to use proper wrapper @@ -120,6 +122,7 @@ async def get_current_user(user_details: UserInfo = Depends(get_user_details_fro WorkspaceManagerDep = Annotated[WorkspaceManager, Depends(get_workspace_manager)] ChangeLogManagerDep = Annotated[ChangeLogManager, Depends(get_change_log_manager)] DatasetsManagerDep = Annotated[DatasetsManager, Depends(get_datasets_manager)] +GraphExplorerManagerDep =
Annotated[GraphExplorerManager, Depends(get_graph_explorer_manager)] # Permission Checker Dependency PermissionCheckerDep = AuthorizationManagerDep diff --git a/src/backend/src/common/features.py b/src/backend/src/common/features.py index f344e8f6..d6f1e702 100644 --- a/src/backend/src/common/features.py +++ b/src/backend/src/common/features.py @@ -164,6 +164,11 @@ class FeatureAccessLevel(str, Enum): 'name': 'Comments & Ratings', 'allowed_levels': READ_WRITE_ADMIN_LEVELS # READ_WRITE to add, ADMIN to manage all }, + # Graph Explorer + 'graph-explorer': { + 'name': 'Graph Explorer', + 'allowed_levels': READ_WRITE_ADMIN_LEVELS + }, # 'about': { ... } # About page doesn't need explicit permissions here } diff --git a/src/backend/src/common/manager_dependencies.py b/src/backend/src/common/manager_dependencies.py index 9436e192..4a901c4a 100644 --- a/src/backend/src/common/manager_dependencies.py +++ b/src/backend/src/common/manager_dependencies.py @@ -25,6 +25,7 @@ from src.controller.change_log_manager import ChangeLogManager from src.controller.datasets_manager import DatasetsManager from src.controller.delivery_service import DeliveryService +from src.controller.graph_explorer_manager import GraphExplorerManager # Import other dependencies needed by these providers from src.common.database import get_db @@ -181,6 +182,13 @@ def get_datasets_manager(request: Request) -> DatasetsManager: # Add getters for Compliance, Estate, MDM, Security, Entitlements, Catalog Commander managers when they are added +def get_graph_explorer_manager(request: Request) -> GraphExplorerManager: + manager = getattr(request.app.state, 'graph_explorer_manager', None) + if not manager: + logger.critical("GraphExplorerManager not found in application state during request!") + raise HTTPException(status_code=503, detail="Graph Explorer service not configured.") + return manager + def get_delivery_service(request: Request) -> DeliveryService: """Get the DeliveryService for multi-mode delivery of 
governance changes.""" service = getattr(request.app.state, "delivery_service", None) diff --git a/src/backend/src/controller/data_products_manager.py b/src/backend/src/controller/data_products_manager.py index e59eccde..67d54bd9 100644 --- a/src/backend/src/controller/data_products_manager.py +++ b/src/backend/src/controller/data_products_manager.py @@ -659,11 +659,19 @@ def publish_product(self, product_id: str, current_user: Optional[str] = None) - if not product_db: raise ValueError(f"Data product with ID {product_id} not found") + # Validate that product has at least one output port with a data contract + output_ports = product_db.output_ports or [] + ports_with_contracts = [p for p in output_ports if p.contract_id] + if not ports_with_contracts: + raise ValueError( + "Cannot publish product: Product must have at least one output port with a data contract assigned" + ) + # Validate that all output ports have data contracts - if product_db.output_ports: + if output_ports: ports_without_contracts = [ - port.name for port in product_db.output_ports - if not port.data_contract_id + port.name for port in output_ports + if not port.contract_id ] if ports_without_contracts: raise ValueError( @@ -677,9 +685,9 @@ def publish_product(self, product_id: str, current_user: Optional[str] = None) - valid_contract_statuses = ['approved', 'active', 'certified'] contracts_not_approved = [] - for port in product_db.output_ports: - if port.data_contract_id: - contract = data_contract_repo.get(db=self._db, id=port.data_contract_id) + for port in output_ports: + if port.contract_id: + contract = data_contract_repo.get(db=self._db, id=port.contract_id) if contract: contract_status = (contract.status or '').lower() if contract_status not in valid_contract_statuses: diff --git a/src/backend/src/controller/graph_explorer_manager.py b/src/backend/src/controller/graph_explorer_manager.py new file mode 100644 index 00000000..f33ddc61 --- /dev/null +++ 
b/src/backend/src/controller/graph_explorer_manager.py @@ -0,0 +1,792 @@ +""" +Graph Explorer Manager. + +Business logic for reading/writing property graph data +from/to Databricks Unity Catalog tables using the Statement Execution API. + +Table schema (edge-centric, same as graph-demo): + node_start_id STRING NOT NULL, + node_start_key STRING NOT NULL, -- node type + relationship STRING NOT NULL, + node_end_id STRING NOT NULL, + node_end_key STRING NOT NULL, -- node type + node_start_properties STRING, -- JSON + node_end_properties STRING -- JSON + +Standalone nodes are stored as self-referencing edges with relationship = 'EXISTS'. +""" + +import json +import os +import re +import time +from typing import Any, Dict, List, Optional, Tuple + +from src.common.logging import get_logger + +logger = get_logger(__name__) + +# Regex for valid Unity Catalog table names +TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z0-9_`]+(\.[a-zA-Z0-9_`]+){0,2}$') + +DEFAULT_TABLE_NAME = "main.default.property_graph_entity_edges" + + +def _validate_table_name(table_name: str) -> str: + """Validate and sanitize a table name to prevent SQL injection.""" + cleaned = table_name.strip().strip('`') + if not TABLE_NAME_PATTERN.match(cleaned): + raise ValueError(f"Invalid table name: {table_name}") + # Wrap parts in backticks for safety + parts = cleaned.split('.') + return '.'.join(f'`{p.strip("`")}`' for p in parts) + + +def _escape_sql_string(value: str) -> str: + """Escape single quotes for SQL string literals.""" + return value.replace("'", "''") + + +class GraphExplorerManager: + """Manages graph explorer operations via Databricks SQL.""" + + def __init__(self, settings=None): + self.settings = settings + + def _execute_sql(self, ws_client, sql: str, warehouse_id: str) -> Tuple[List[str], List[List[Any]]]: + """Execute a SQL statement and return (columns, rows).""" + logger.debug(f"Executing SQL: {sql[:200]}...") + timeout = getattr(self.settings, "GRAPH_QUERY_TIMEOUT", "30s") if self.settings 
else "30s" + result = ws_client.statement_execution.execute_statement( + statement=sql, + warehouse_id=warehouse_id, + wait_timeout=timeout, + ) + + if result.status and result.status.state: + state = str(result.status.state) + if "FAILED" in state or "CANCELED" in state: + error_msg = result.status.error.message if result.status.error else "Query failed" + raise RuntimeError(f"SQL execution failed: {error_msg}") + + columns = [] + if result.manifest and result.manifest.schema and result.manifest.schema.columns: + columns = [col.name for col in result.manifest.schema.columns] + + rows = [] + if result.result and result.result.data_array: + rows = result.result.data_array + + return columns, rows + + def ensure_table_exists(self, ws_client, table_name: str, warehouse_id: str) -> None: + """Create the graph table if it doesn't exist.""" + safe_table = _validate_table_name(table_name) + sql = f"""CREATE TABLE IF NOT EXISTS {safe_table} ( + node_start_id STRING NOT NULL, + node_start_key STRING NOT NULL, + relationship STRING NOT NULL, + node_end_id STRING NOT NULL, + node_end_key STRING NOT NULL, + node_start_properties STRING, + node_end_properties STRING + )""" + self._execute_sql(ws_client, sql, warehouse_id) + logger.info(f"Ensured table exists: {safe_table}") + + def get_graph_data(self, ws_client, table_name: str, warehouse_id: str, max_rows: int = 0) -> Dict[str, Any]: + """Read graph data from a Databricks table and return as nodes + edges. + + Args: + max_rows: Maximum rows to fetch. 0 means use the server-side + default derived from GRAPH_MAX_EDGES setting. 
+ """ + safe_table = _validate_table_name(table_name) + + # Determine effective row limit from settings or argument + if max_rows <= 0: + max_edges = getattr(self.settings, "GRAPH_MAX_EDGES", 10000) if self.settings else 10000 + effective_limit = max_edges + else: + effective_limit = max_rows + + # Count total rows for truncation info + count_sql = f"SELECT COUNT(*) FROM {safe_table}" + _, count_rows = self._execute_sql(ws_client, count_sql, warehouse_id) + total_rows = int(count_rows[0][0]) if count_rows and count_rows[0][0] else 0 + + sql = f"SELECT * FROM {safe_table} LIMIT {effective_limit}" + columns, rows = self._execute_sql(ws_client, sql, warehouse_id) + + if not columns: + return {"nodes": [], "edges": [], "truncated": False, "totalAvailable": None} + + # Build column index map + col_idx = {name: i for i, name in enumerate(columns)} + + # Track unique nodes and edges + nodes_map: Dict[str, Dict[str, Any]] = {} + edges: List[Dict[str, Any]] = [] + + for row in rows: + start_id = row[col_idx.get("node_start_id", 0)] or "" + start_key = row[col_idx.get("node_start_key", 1)] or "Node" + relationship = row[col_idx.get("relationship", 2)] or "" + end_id = row[col_idx.get("node_end_id", 3)] or "" + end_key = row[col_idx.get("node_end_key", 4)] or "Node" + start_props_raw = row[col_idx.get("node_start_properties", 5)] + end_props_raw = row[col_idx.get("node_end_properties", 6)] + + # Parse properties (handles VARIANT and STRING columns) + start_props = self._parse_props(start_props_raw) + end_props = self._parse_props(end_props_raw) + + # Extract label from properties or use id + start_label = start_props.pop("_label", None) or start_id + end_label = end_props.pop("_label", None) or end_id + + # Register start node + if start_id and start_id not in nodes_map: + nodes_map[start_id] = { + "id": start_id, + "label": start_label, + "type": start_key, + "properties": start_props, + "status": "existing", + } + + # Register end node + if end_id and end_id not in 
nodes_map: + nodes_map[end_id] = { + "id": end_id, + "label": end_label, + "type": end_key, + "properties": end_props, + "status": "existing", + } + + # Add edge (skip EXISTS self-references — those are standalone node markers) + if relationship != "EXISTS" and start_id and end_id: + edge_id = f"{start_id}-{relationship}-{end_id}" + edges.append({ + "id": edge_id, + "source": start_id, + "target": end_id, + "relationshipType": relationship, + "properties": {}, + "status": "existing", + }) + + truncated = total_rows > effective_limit + + return { + "nodes": list(nodes_map.values()), + "edges": edges, + "truncated": truncated, + "totalAvailable": total_rows if truncated else None, + } + + def get_neighbors( + self, + ws_client, + table_name: str, + warehouse_id: str, + node_id: str, + direction: str = "both", + edge_types: Optional[List[str]] = None, + limit: int = 25, + offset: int = 0, + ) -> Dict[str, Any]: + """Get the 1-hop neighborhood of a node. + + Args: + node_id: The node to expand from. + direction: 'outgoing', 'incoming', or 'both'. + edge_types: Optional list of relationship types to filter on. + limit: Max edges to return. + offset: Pagination offset. + + Returns: + Dict with nodes, edges, truncated, and totalAvailable. 
+ """ + safe_table = _validate_table_name(table_name) + safe_id = _escape_sql_string(node_id) + + # Build direction filter + if direction == "outgoing": + direction_filter = f"node_start_id = '{safe_id}'" + elif direction == "incoming": + direction_filter = f"node_end_id = '{safe_id}'" + else: # both + direction_filter = f"(node_start_id = '{safe_id}' OR node_end_id = '{safe_id}')" + + # Exclude EXISTS self-references (standalone node markers) + base_where = f"{direction_filter} AND relationship != 'EXISTS'" + + # Optional edge type filter + if edge_types: + escaped_types = [f"'{_escape_sql_string(t)}'" for t in edge_types] + base_where += f" AND relationship IN ({', '.join(escaped_types)})" + + # Count total available (for truncation info) + count_sql = f"SELECT COUNT(*) FROM {safe_table} WHERE {base_where}" + _, count_rows = self._execute_sql(ws_client, count_sql, warehouse_id) + total_available = int(count_rows[0][0]) if count_rows and count_rows[0][0] else 0 + + # Fetch the edges with limit/offset + sql = f"SELECT * FROM {safe_table} WHERE {base_where} LIMIT {limit} OFFSET {offset}" + columns, rows = self._execute_sql(ws_client, sql, warehouse_id) + + if not columns: + return {"nodes": [], "edges": [], "truncated": False, "totalAvailable": 0} + + # Parse rows into nodes + edges (reuse the same logic as get_graph_data) + col_idx = {name: i for i, name in enumerate(columns)} + nodes_map: Dict[str, Dict[str, Any]] = {} + edges: List[Dict[str, Any]] = [] + + for row in rows: + start_id = row[col_idx.get("node_start_id", 0)] or "" + start_key = row[col_idx.get("node_start_key", 1)] or "Node" + relationship = row[col_idx.get("relationship", 2)] or "" + end_id = row[col_idx.get("node_end_id", 3)] or "" + end_key = row[col_idx.get("node_end_key", 4)] or "Node" + start_props_raw = row[col_idx.get("node_start_properties", 5)] + end_props_raw = row[col_idx.get("node_end_properties", 6)] + + start_props = self._parse_props(start_props_raw) + end_props = 
self._parse_props(end_props_raw) + start_label = start_props.pop("_label", None) or start_id + end_label = end_props.pop("_label", None) or end_id + + if start_id and start_id not in nodes_map: + nodes_map[start_id] = { + "id": start_id, + "label": start_label, + "type": start_key, + "properties": start_props, + "status": "existing", + } + + if end_id and end_id not in nodes_map: + nodes_map[end_id] = { + "id": end_id, + "label": end_label, + "type": end_key, + "properties": end_props, + "status": "existing", + } + + if relationship and start_id and end_id: + edge_id = f"{start_id}-{relationship}-{end_id}" + edges.append({ + "id": edge_id, + "source": start_id, + "target": end_id, + "relationshipType": relationship, + "properties": {}, + "status": "existing", + }) + + truncated = total_available > (offset + limit) + + return { + "nodes": list(nodes_map.values()), + "edges": edges, + "truncated": truncated, + "totalAvailable": total_available, + } + + def write_nodes_and_edges( + self, + ws_client, + table_name: str, + warehouse_id: str, + nodes: List[Dict[str, Any]], + edges: List[Dict[str, Any]], + ) -> Dict[str, Any]: + """Write new nodes and edges to the Databricks table.""" + safe_table = _validate_table_name(table_name) + nodes_written = 0 + edges_written = 0 + + # Build a lookup of all nodes (from provided list) for edge writing + node_map = {n["id"]: n for n in nodes} + + # Write standalone nodes as EXISTS self-referencing edges + for node in nodes: + node_id = _escape_sql_string(node["id"]) + node_type = _escape_sql_string(node.get("type", "Node")) + props = dict(node.get("properties", {})) + props["_label"] = node.get("label", node["id"]) + props_json = _escape_sql_string(json.dumps(props)) + + sql = f"""INSERT INTO {safe_table} + (node_start_id, node_start_key, relationship, node_end_id, node_end_key, node_start_properties, node_end_properties) + VALUES ('{node_id}', '{node_type}', 'EXISTS', '{node_id}', '{node_type}', '{props_json}', '{props_json}')""" + 
self._execute_sql(ws_client, sql, warehouse_id) + nodes_written += 1 + + # Write edges + for edge in edges: + source_id = _escape_sql_string(edge["source"]) + target_id = _escape_sql_string(edge["target"]) + rel_type = _escape_sql_string(edge.get("relationshipType", "RELATES_TO")) + + # Get source/target node info for properties + source_node = node_map.get(edge["source"], {}) + target_node = node_map.get(edge["target"], {}) + + source_type = _escape_sql_string(source_node.get("type", "Node")) + target_type = _escape_sql_string(target_node.get("type", "Node")) + + source_props = dict(source_node.get("properties", {})) + source_props["_label"] = source_node.get("label", edge["source"]) + target_props = dict(target_node.get("properties", {})) + target_props["_label"] = target_node.get("label", edge["target"]) + + source_props_json = _escape_sql_string(json.dumps(source_props)) + target_props_json = _escape_sql_string(json.dumps(target_props)) + + sql = f"""INSERT INTO {safe_table} + (node_start_id, node_start_key, relationship, node_end_id, node_end_key, node_start_properties, node_end_properties) + VALUES ('{source_id}', '{source_type}', '{rel_type}', '{target_id}', '{target_type}', '{source_props_json}', '{target_props_json}')""" + self._execute_sql(ws_client, sql, warehouse_id) + edges_written += 1 + + return {"nodesWritten": nodes_written, "edgesWritten": edges_written} + + def delete_node(self, ws_client, table_name: str, warehouse_id: str, node_id: str) -> None: + """Delete a node and all its connected edges.""" + safe_table = _validate_table_name(table_name) + safe_id = _escape_sql_string(node_id) + sql = f"DELETE FROM {safe_table} WHERE node_start_id = '{safe_id}' OR node_end_id = '{safe_id}'" + self._execute_sql(ws_client, sql, warehouse_id) + logger.info(f"Deleted node {node_id} and connected edges from {safe_table}") + + def delete_edge(self, ws_client, table_name: str, warehouse_id: str, source_id: str, target_id: str, relationship_type: str) -> None: + 
"""Delete a specific edge.""" + safe_table = _validate_table_name(table_name) + safe_source = _escape_sql_string(source_id) + safe_target = _escape_sql_string(target_id) + safe_rel = _escape_sql_string(relationship_type) + sql = f"""DELETE FROM {safe_table} + WHERE node_start_id = '{safe_source}' + AND node_end_id = '{safe_target}' + AND relationship = '{safe_rel}'""" + self._execute_sql(ws_client, sql, warehouse_id) + logger.info(f"Deleted edge {source_id} -[{relationship_type}]-> {target_id} from {safe_table}") + + def update_node(self, ws_client, table_name: str, warehouse_id: str, node_id: str, label: str, node_type: str, properties: Dict) -> None: + """Update a node's properties, type, and label across all rows.""" + safe_table = _validate_table_name(table_name) + safe_id = _escape_sql_string(node_id) + safe_type = _escape_sql_string(node_type) + props = dict(properties) + props["_label"] = label + props_json = _escape_sql_string(json.dumps(props)) + + # Update as start node + sql = f"""UPDATE {safe_table} + SET node_start_key = '{safe_type}', node_start_properties = '{props_json}' + WHERE node_start_id = '{safe_id}'""" + self._execute_sql(ws_client, sql, warehouse_id) + + # Update as end node + sql = f"""UPDATE {safe_table} + SET node_end_key = '{safe_type}', node_end_properties = '{props_json}' + WHERE node_end_id = '{safe_id}'""" + self._execute_sql(ws_client, sql, warehouse_id) + logger.info(f"Updated node {node_id} in {safe_table}") + + # ----------------------------------------------------------------- + # LLM Configuration + # ----------------------------------------------------------------- + + def get_llm_config(self) -> Dict[str, Any]: + """Return the current LLM configuration state. + + The graph query panel is enabled whenever an LLM endpoint is + configured — it does NOT require the global LLM_ENABLED flag, which + gates the heavier conversational search feature. 
+ """ + if not self.settings: + return {"enabled": False, "defaultModel": "", "maxTokens": 4096, "provider": "databricks"} + + endpoint = getattr(self.settings, "LLM_ENDPOINT", "") or "" + return { + "enabled": bool(endpoint), + "defaultModel": endpoint, + "maxTokens": 4096, + "provider": "databricks", + } + + # ----------------------------------------------------------------- + # Graph Query Translation (Cypher / Gremlin → SQL via LLM) + # ----------------------------------------------------------------- + + _TRANSLATE_SYSTEM_PROMPT = ( + "You are a SQL translation engine for Databricks Unity Catalog. " + "You translate Cypher or Gremlin graph queries into valid Databricks SQL.\n\n" + "The target table has the following schema:\n" + " node_start_id STRING, node_start_key STRING, relationship STRING,\n" + " node_end_id STRING, node_end_key STRING,\n" + " node_start_properties VARIANT, node_end_properties VARIANT\n\n" + "Rules:\n" + "- Each row represents an edge between two nodes.\n" + "- node_start_id/node_start_key/node_start_properties describe the source node.\n" + "- node_end_id/node_end_key/node_end_properties describe the target node.\n" + "- relationship is the edge type (e.g. 'ALLIED_WITH', 'BORN_ON').\n" + "- Node labels (types) are stored in node_start_key and node_end_key.\n" + "- Some tables use relationship = 'EXISTS' for standalone nodes (self-edges).\n" + " Check the graph data context to see if EXISTS is listed as a relationship type.\n" + " If EXISTS is NOT listed, nodes only appear as edge endpoints.\n" + "\n" + "CRITICAL — Subgraph pattern for node-centric queries:\n" + "When the user asks for specific nodes (e.g. 'dark characters'), use a CTE to find\n" + "matching node IDs first, then return only edges BETWEEN those nodes. 
    # System prompt for the LLM translator. Runtime string — shipped verbatim
    # to the model; do not reformat. It pins the edge-table schema, the CTE
    # "matched" pattern for node-centric queries, and the SELECT * / LIMIT 5000
    # contract that the downstream parser and safety limits rely on.
    _TRANSLATE_SYSTEM_PROMPT = (
        "You are a SQL translation engine for Databricks Unity Catalog. "
        "You translate Cypher or Gremlin graph queries into valid Databricks SQL.\n\n"
        "The target table has the following schema:\n"
        "  node_start_id STRING, node_start_key STRING, relationship STRING,\n"
        "  node_end_id STRING, node_end_key STRING,\n"
        "  node_start_properties VARIANT, node_end_properties VARIANT\n\n"
        "Rules:\n"
        "- Each row represents an edge between two nodes.\n"
        "- node_start_id/node_start_key/node_start_properties describe the source node.\n"
        "- node_end_id/node_end_key/node_end_properties describe the target node.\n"
        "- relationship is the edge type (e.g. 'ALLIED_WITH', 'BORN_ON').\n"
        "- Node labels (types) are stored in node_start_key and node_end_key.\n"
        "- Some tables use relationship = 'EXISTS' for standalone nodes (self-edges).\n"
        "  Check the graph data context to see if EXISTS is listed as a relationship type.\n"
        "  If EXISTS is NOT listed, nodes only appear as edge endpoints.\n"
        "\n"
        "CRITICAL — Subgraph pattern for node-centric queries:\n"
        "When the user asks for specific nodes (e.g. 'dark characters'), use a CTE to find\n"
        "matching node IDs first, then return only edges BETWEEN those nodes. This avoids\n"
        "returning unrelated connected nodes.\n"
        "Example:\n"
        "  WITH matched AS (\n"
        "    SELECT DISTINCT node_start_id AS id FROM table\n"
        "    WHERE node_start_key = 'Character'\n"
        "      AND get_json_object(CAST(node_start_properties AS STRING), '$.alignment') = 'Dark'\n"
        "    UNION\n"
        "    SELECT DISTINCT node_end_id AS id FROM table\n"
        "    WHERE node_end_key = 'Character'\n"
        "      AND get_json_object(CAST(node_end_properties AS STRING), '$.alignment') = 'Dark'\n"
        "  )\n"
        "  SELECT * FROM table\n"
        "  WHERE node_start_id IN (SELECT id FROM matched)\n"
        "    AND node_end_id IN (SELECT id FROM matched)\n"
        "  LIMIT 5000\n"
        "\n"
        "For relationship queries (e.g. 'characters who fought in battles'), don't use the CTE\n"
        "pattern — just filter by relationship type and node types directly.\n"
        "\n"
        "- node_start_properties and node_end_properties are VARIANT columns.\n"
        "  IMPORTANT: Always use get_json_object with CAST to STRING for property access:\n"
        "  Example: get_json_object(CAST(node_start_properties AS STRING), '$.alignment') = 'Dark'\n"
        "  Example: CAST(get_json_object(CAST(node_start_properties AS STRING), '$.age') AS INT) > 30\n"
        "- For natural language queries, use LIKE for fuzzy string matching:\n"
        "  Example: LOWER(get_json_object(CAST(node_start_properties AS STRING), '$.alignment')) LIKE '%dark%'\n"
        "- For Cypher/Gremlin queries, use exact values from the graph data context.\n"
        "- IMPORTANT: Always prefer the actual property values shown in the graph data context.\n"
        "- IMPORTANT: Always SELECT * (all columns). Never select a subset of columns.\n"
        "  The downstream parser requires all columns to extract nodes and edges.\n"
        "- Always add LIMIT 5000 unless the user query already contains a limit.\n"
        "- Return ONLY the SQL query, nothing else — no markdown, no explanation.\n"
    )

    def _get_openai_client(self):
        """Get an OpenAI-compatible client for the configured LLM endpoint.

        Authentication fallback order (first token wins):
          1. DATABRICKS_TOKEN from settings or the environment (local dev).
          2. A bearer token extracted from the Databricks SDK's auth headers.

        Raises:
            RuntimeError: If no token can be obtained or no base URL can be
                derived from LLM_BASE_URL / DATABRICKS_HOST.
        """
        from openai import OpenAI

        token = None

        # Explicit token from settings / env (local dev)
        token = getattr(self.settings, "DATABRICKS_TOKEN", None) or os.environ.get("DATABRICKS_TOKEN")
        if token:
            logger.debug("Graph query LLM: using token from settings/environment")

        # Fall back to Databricks SDK OBO
        if not token:
            try:
                from databricks.sdk.core import Config
                config = Config()
                # authenticate() yields the HTTP auth headers; peel the bearer
                # token out of the "Authorization: Bearer ..." header.
                headers = config.authenticate()
                if headers and "Authorization" in headers:
                    auth_header = headers["Authorization"]
                    if auth_header.startswith("Bearer "):
                        token = auth_header[7:]
                        logger.debug("Graph query LLM: using SDK OBO token")
            except Exception as sdk_err:
                # Best-effort: missing SDK config just means we fail below
                # with a clearer error message.
                logger.debug(f"Could not get SDK token: {sdk_err}")

        if not token:
            raise RuntimeError("No authentication token available for LLM endpoint.")

        # Base URL: explicit LLM_BASE_URL wins; otherwise derive the
        # serving-endpoints URL from DATABRICKS_HOST.
        base_url = getattr(self.settings, "LLM_BASE_URL", None)
        if not base_url:
            host = getattr(self.settings, "DATABRICKS_HOST", "")
            if host:
                host = host.rstrip("/")
                if not host.startswith("http://") and not host.startswith("https://"):
                    host = f"https://{host}"
                base_url = f"{host}/serving-endpoints"

        if not base_url:
            raise RuntimeError("LLM_BASE_URL not configured.")

        return OpenAI(api_key=token, base_url=base_url)
    def _get_graph_schema(self, ws_client, table_name: str, warehouse_id: str) -> str:
        """Build a concise schema summary from the actual graph data for LLM context.

        Returns a text block describing actual column types, node types,
        relationship types, and sample property keys with distinct values.

        Best-effort: every probe step is wrapped so a failure degrades the
        context rather than failing the query.
        """
        safe_table = _validate_table_name(table_name)
        lines: List[str] = []

        try:
            # Step 0: Get actual column types via DESCRIBE
            try:
                sql = f"DESCRIBE TABLE {safe_table}"
                cols, desc_rows = self._execute_sql(ws_client, sql, warehouse_id)
                col_types = {r[0]: r[1] for r in desc_rows if r[0] and r[1]}
                if col_types:
                    lines.append("Actual column types: " + ", ".join(
                        f"{k} {v}" for k, v in col_types.items()
                    ))
            except Exception:
                # DESCRIBE is optional context; skip silently on failure.
                pass

            # Step 1: All distinct node types (start and end)
            sql = (
                f"SELECT DISTINCT node_start_key FROM {safe_table} "
                f"UNION SELECT DISTINCT node_end_key FROM {safe_table} LIMIT 50"
            )
            _, rows = self._execute_sql(ws_client, sql, warehouse_id)
            node_types = sorted(set(r[0] for r in rows if r[0]))
            if node_types:
                lines.append(f"Node types: {', '.join(node_types)}")

            # Step 2: All distinct relationship types
            sql = f"SELECT DISTINCT relationship FROM {safe_table} LIMIT 50"
            _, rows = self._execute_sql(ws_client, sql, warehouse_id)
            rel_types = sorted(set(r[0] for r in rows if r[0]))
            if rel_types:
                lines.append(f"Relationship types: {', '.join(rel_types)}")
                # Tell the model whether standalone-node marker rows exist,
                # since that changes how node-centric SQL must be written.
                if "EXISTS" not in rel_types:
                    lines.append("NOTE: No 'EXISTS' relationship found — nodes only appear as edge endpoints, not as standalone rows.")

            # Step 3: Sample properties per node type — use raw rows (no EXISTS filter)
            for ntype in node_types[:5]:
                safe_type = _escape_sql_string(ntype)
                sql = (
                    f"SELECT node_start_properties FROM {safe_table} "
                    f"WHERE node_start_key = '{safe_type}' "
                    f"AND node_start_properties IS NOT NULL LIMIT 10"
                )
                _, rows = self._execute_sql(ws_client, sql, warehouse_id)

                # Collect all property keys and their distinct values
                # (at most 5 values per key, truncated to 50 chars each).
                key_values: Dict[str, set] = {}
                for row in rows:
                    props = self._parse_props(row[0])
                    for k, v in props.items():
                        if k.startswith("_"):
                            # Reserved internal keys (e.g. _label) are not schema.
                            continue
                        if k not in key_values:
                            key_values[k] = set()
                        if v is not None and len(key_values[k]) < 5:
                            key_values[k].add(str(v)[:50])

                if key_values:
                    prop_parts = []
                    for k, vals in list(key_values.items())[:8]:
                        distinct = sorted(vals)[:4]
                        prop_parts.append(f"{k} (e.g. {', '.join(repr(v) for v in distinct)})")
                    lines.append(f"  {ntype} properties: {'; '.join(prop_parts)}")

            # Step 4: If no properties were found, dump one raw row for debugging
            if not any("properties:" in l for l in lines):
                sql = f"SELECT * FROM {safe_table} LIMIT 1"
                cols, rows = self._execute_sql(ws_client, sql, warehouse_id)
                if rows:
                    raw_sample = {cols[i]: rows[0][i] for i in range(len(cols))}
                    lines.append(f"Sample raw row: {json.dumps(raw_sample, default=str)[:500]}")

        except Exception as e:
            # Schema context is an optimisation; log and fall through.
            logger.warning(f"Failed to fetch graph schema for LLM context: {e}", exc_info=True)

        return "\n".join(lines) if lines else "No schema information available."

    def _translate_to_sql(self, query: str, language: str, table_name: str,
                          graph_schema: str = "") -> str:
        """Translate a natural language, Cypher, or Gremlin query to Databricks SQL via LLM.

        Args:
            query: The user's query text.
            language: 'natural', 'cypher', or 'gremlin' — selects the prompt shape.
            table_name: Target table, validated before being shown to the model.
            graph_schema: Optional schema context from _get_graph_schema.
        """
        client = self._get_openai_client()
        endpoint = getattr(self.settings, "LLM_ENDPOINT", "") or ""

        safe_table = _validate_table_name(table_name)

        schema_block = ""
        if graph_schema:
            schema_block = f"\nGraph data context:\n{graph_schema}\n"

        # Natural-language requests get an intent-oriented instruction;
        # Cypher/Gremlin get a direct translation instruction.
        if language == "natural":
            user_message = (
                f"Table: {safe_table}\n"
                f"{schema_block}"
                f"User request (natural language): {query}\n\n"
                "Write a Databricks SQL SELECT statement that answers this request. "
                "Return graph data (nodes and edges) that match the user's intent."
            )
        else:
            user_message = (
                f"Table: {safe_table}\n"
                f"{schema_block}"
                f"Language: {language}\n"
                f"Query: {query}\n\n"
                "Translate this to a Databricks SQL SELECT statement."
            )

        response = client.chat.completions.create(
            model=endpoint,
            messages=[
                {"role": "system", "content": self._TRANSLATE_SYSTEM_PROMPT},
                {"role": "user", "content": user_message},
            ],
            max_tokens=2048,
            temperature=0,
        )

        sql = response.choices[0].message.content.strip()
        # Strip markdown fencing if the model wraps it
        if sql.startswith("```"):
            sql = re.sub(r"^```(?:sql)?\s*", "", sql)
            sql = re.sub(r"\s*```$", "", sql)
        return sql.strip()
+ ) + + response = client.chat.completions.create( + model=endpoint, + messages=[ + {"role": "system", "content": self._TRANSLATE_SYSTEM_PROMPT}, + {"role": "user", "content": user_message}, + ], + max_tokens=2048, + temperature=0, + ) + + sql = response.choices[0].message.content.strip() + # Strip markdown fencing if the model wraps it + if sql.startswith("```"): + sql = re.sub(r"^```(?:sql)?\s*", "", sql) + sql = re.sub(r"\s*```$", "", sql) + return sql.strip() + + @staticmethod + def _parse_props(raw: Any) -> Dict[str, Any]: + """Parse a properties value that may be a VARIANT JSON string, dict, or None.""" + if raw is None: + return {} + if isinstance(raw, dict): + return raw + if isinstance(raw, str): + try: + parsed = json.loads(raw) + return parsed if isinstance(parsed, dict) else {} + except (json.JSONDecodeError, TypeError): + return {} + return {} + + def _parse_query_results(self, columns: List[str], rows: List[List[Any]]) -> Dict[str, Any]: + """Parse raw SQL result rows into nodes and edges, mirroring get_graph_data.""" + col_idx = {name: i for i, name in enumerate(columns)} + nodes_map: Dict[str, Dict[str, Any]] = {} + edges: List[Dict[str, Any]] = [] + + has_edge_columns = "relationship" in col_idx + + for row in rows: + start_id = row[col_idx["node_start_id"]] if "node_start_id" in col_idx else None + start_key = row[col_idx.get("node_start_key", -1)] if "node_start_key" in col_idx else "Node" + relationship = row[col_idx["relationship"]] if "relationship" in col_idx else None + end_id = row[col_idx["node_end_id"]] if "node_end_id" in col_idx else None + end_key = row[col_idx.get("node_end_key", -1)] if "node_end_key" in col_idx else "Node" + start_props_raw = row[col_idx.get("node_start_properties", -1)] if "node_start_properties" in col_idx else None + end_props_raw = row[col_idx.get("node_end_properties", -1)] if "node_end_properties" in col_idx else None + + # Parse properties — handles VARIANT (returned as JSON string by Statement Execution API) 
and dicts + start_props = self._parse_props(start_props_raw) + end_props = self._parse_props(end_props_raw) + + start_label = start_props.pop("_label", None) or (start_id or "") + end_label = end_props.pop("_label", None) or (end_id or "") + + if start_id and start_id not in nodes_map: + nodes_map[start_id] = { + "id": start_id, + "label": start_label, + "type": start_key or "Node", + "properties": start_props, + "status": "existing", + } + + if end_id and end_id not in nodes_map: + nodes_map[end_id] = { + "id": end_id, + "label": end_label, + "type": end_key or "Node", + "properties": end_props, + "status": "existing", + } + + if relationship and relationship != "EXISTS" and start_id and end_id: + edge_id = f"{start_id}-{relationship}-{end_id}" + edges.append({ + "id": edge_id, + "source": start_id, + "target": end_id, + "relationshipType": relationship, + "properties": {}, + "status": "existing", + }) + + return { + "nodes": list(nodes_map.values()), + "edges": edges, + "hasEdgeColumns": has_edge_columns, + } + + def execute_graph_query( + self, + ws_client, + warehouse_id: str, + query: str, + language: str, + table_name: str, + override_sql: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Translate a Cypher/Gremlin query to SQL via LLM, execute it, and + return graph nodes/edges. If *override_sql* is provided, skip + translation and execute the SQL directly. 
+ """ + t0 = time.time() + endpoint = getattr(self.settings, "LLM_ENDPOINT", "") or "" + + # Step 1: get SQL + graph_schema = "" + if override_sql: + sql = override_sql + else: + # Fetch graph schema context so the LLM knows what types/properties exist + graph_schema = self._get_graph_schema(ws_client, table_name, warehouse_id) + logger.info(f"Graph schema context:\n{graph_schema}") + sql = self._translate_to_sql(query, language, table_name, graph_schema=graph_schema) + logger.info(f"Graph query SQL: {sql}") + + # Step 2: execute SQL + try: + columns, rows = self._execute_sql(ws_client, sql, warehouse_id) + except RuntimeError as e: + return { + "success": False, + "nodes": [], + "edges": [], + "sql": sql, + "language": language, + "originalQuery": query, + "message": str(e), + } + + # Step 3: parse results + parsed = self._parse_query_results(columns, rows) + duration = f"{time.time() - t0:.2f}s" + + # Detect vertex-only queries: CTE pattern means the user asked for specific nodes + is_vertex_only = bool(re.search(r"(?i)\bWITH\s+matched\s+AS\b", sql)) + + return { + "success": True, + "nodes": parsed["nodes"], + "edges": parsed["edges"] if not is_vertex_only else [], + "sql": sql, + "language": language, + "originalQuery": query, + "rawRowCount": len(rows), + "hasEdgeColumns": parsed.get("hasEdgeColumns", False), + "vertexOnly": is_vertex_only, + "metadata": { + "source": "databricks", + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + "duration": duration, + "translationModel": endpoint, + "graphSchema": graph_schema if graph_schema else None, + }, + } diff --git a/src/backend/src/controller/notifications_manager.py b/src/backend/src/controller/notifications_manager.py index 89e00256..7813989c 100644 --- a/src/backend/src/controller/notifications_manager.py +++ b/src/backend/src/controller/notifications_manager.py @@ -63,7 +63,9 @@ def get_notifications(self, db: Session, user_info: Optional[UserInfo] = None) - # --- Filtering logic (similar to 
before, but uses API models and SettingsManager) --- if not user_info: # Return only broadcast notifications if no user info - return [n for n in all_notifications_api if not n.recipient] + broadcast = [n for n in all_notifications_api if not n.recipient] + broadcast.sort(key=lambda x: x.created_at, reverse=True) + return broadcast user_groups = set(user_info.groups or []) user_email = user_info.email @@ -276,12 +278,31 @@ async def create_notification( def delete_notification(self, db: Session, notification_id: str) -> bool: """Delete a notification by ID using the repository.""" try: + db_obj = self._repo.get(db=db, id=notification_id) + if db_obj is None: + return False + if not db_obj.can_delete: + return False deleted_obj = self._repo.remove(db=db, id=notification_id) return deleted_obj is not None except Exception as e: logger.error(f"Error deleting notification {notification_id}: {e}", exc_info=True) raise + def mark_all_as_read(self, db: Session, user_info: UserInfo) -> int: + """Mark all notifications for a user as read. 
Returns count of updated notifications.""" + if not user_info: + return 0 + all_notifications = self._repo.get_multi(db=db, limit=1000) + count = 0 + user_email = user_info.email + for db_obj in all_notifications: + if db_obj.recipient == user_email and not db_obj.read: + self._repo.update(db=db, db_obj=db_obj, obj_in={"read": True}) + count += 1 + db.commit() + return count + def mark_notification_read(self, db: Session, notification_id: str) -> Optional[Notification]: """Mark a notification as read using the repository.""" try: diff --git a/src/backend/src/controller/security_features_manager.py b/src/backend/src/controller/security_features_manager.py index 1ed0d36b..b7ce8e14 100644 --- a/src/backend/src/controller/security_features_manager.py +++ b/src/backend/src/controller/security_features_manager.py @@ -43,14 +43,14 @@ def list_features(self) -> List[SecurityFeature]: def update_feature(self, feature_id: str, feature: SecurityFeature) -> Optional[SecurityFeature]: if feature_id not in self.features: - logging.warning(f"Security feature not found: {feature_id}") + logger.warning(f"Security feature not found: {feature_id}") return None self.features[feature_id] = feature return feature def delete_feature(self, feature_id: str) -> bool: if feature_id not in self.features: - logging.warning(f"Security feature not found: {feature_id}") + logger.warning(f"Security feature not found: {feature_id}") return False del self.features[feature_id] return True @@ -119,5 +119,5 @@ def save_to_yaml(self, yaml_path: Path) -> None: with open(yaml_path, 'w') as f: yaml.dump(data, f) except Exception as e: - logging.exception(f"Error saving security features to YAML: {e!s}") + logger.exception(f"Error saving security features to YAML: {e!s}") raise diff --git a/src/backend/src/models/graph_explorer.py b/src/backend/src/models/graph_explorer.py new file mode 100644 index 00000000..3aae1420 --- /dev/null +++ b/src/backend/src/models/graph_explorer.py @@ -0,0 +1,134 @@ +""" 
"""
Pydantic models for Graph Explorer API.
"""

from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field


# ---------------------------------------------------------------------------
# Mutation requests (create / update / delete)
# ---------------------------------------------------------------------------

class GraphNodeRequest(BaseModel):
    """A node to be created or updated in the graph table."""
    id: str
    label: str
    type: str = "Node"
    properties: Dict[str, Any] = Field(default_factory=dict)


class GraphEdgeRequest(BaseModel):
    """An edge to be created or updated between two nodes."""
    id: Optional[str] = None
    source: str
    target: str
    relationshipType: str
    properties: Dict[str, Any] = Field(default_factory=dict)


class SaveGraphRequest(BaseModel):
    """Batch of new/modified nodes and edges to persist to one table."""
    tableName: str
    nodes: List[GraphNodeRequest] = Field(default_factory=list)
    edges: List[GraphEdgeRequest] = Field(default_factory=list)


class DeleteNodeRequest(BaseModel):
    """Identifies a single node to delete (with its connected edges)."""
    tableName: str
    nodeId: str


class DeleteEdgeRequest(BaseModel):
    """Identifies one edge by its (source, target, relationship) triple."""
    tableName: str
    sourceId: str
    targetId: str
    relationshipType: str


class UpdateNodeRequest(BaseModel):
    """New label/type/properties for an existing node."""
    tableName: str
    nodeId: str
    label: str
    type: str = "Node"
    properties: Dict[str, Any] = Field(default_factory=dict)


class NeighborDirection(str, Enum):
    """Direction for neighbor expansion."""
    OUTGOING = "outgoing"
    INCOMING = "incoming"
    BOTH = "both"


# ---------------------------------------------------------------------------
# Responses
# ---------------------------------------------------------------------------

class GraphDataResponse(BaseModel):
    """Nodes and edges returned by read/expand endpoints."""
    nodes: List[Dict[str, Any]]
    edges: List[Dict[str, Any]]
    truncated: bool = False
    totalAvailable: Optional[int] = None


class SaveGraphResponse(BaseModel):
    """Counts of rows written by a save operation."""
    nodesWritten: int
    edgesWritten: int


class EnsureTableResponse(BaseModel):
    """Outcome of a create-table-if-missing call."""
    tableName: str
    status: str = "ok"


# ---------------------------------------------------------------------------
# Graph Query (Cypher / Gremlin → SQL via LLM)
# ---------------------------------------------------------------------------

class GraphQueryRequest(BaseModel):
    """Request model for executing a Cypher/Gremlin graph query."""
    query: str
    language: str = Field(default="cypher", description="Query language: 'cypher' or 'gremlin'")
    tableName: Optional[str] = None
    modelEndpoint: Optional[str] = None
    sql: Optional[str] = Field(default=None, description="Override SQL — skip LLM translation and execute directly")


class GraphQueryResponseMetadata(BaseModel):
    """Metadata about the query execution."""
    source: str = "databricks"
    timestamp: Optional[str] = None
    duration: Optional[str] = None
    translationModel: Optional[str] = None
    graphSchema: Optional[str] = None


class GraphQueryResponse(BaseModel):
    """Full result of a graph query execution, including the generated SQL."""
    success: bool
    nodes: List[Dict[str, Any]] = Field(default_factory=list)
    edges: List[Dict[str, Any]] = Field(default_factory=list)
    sql: str = ""
    language: str = ""
    originalQuery: str = ""
    rawRowCount: Optional[int] = None
    hasEdgeColumns: Optional[bool] = None
    vertexOnly: Optional[bool] = None
    message: Optional[str] = None
    metadata: Optional[GraphQueryResponseMetadata] = None


class GraphLimitsResponse(BaseModel):
    """Server-side safety limits for graph operations."""
    maxNodes: int = 5000
    maxEdges: int = 10000
    neighborLimit: int = 50
    queryTimeout: str = "30s"


class LlmConfigResponse(BaseModel):
    """LLM configuration status for the query panel."""
    enabled: bool
    defaultModel: str = ""
    maxTokens: int = 4096
    provider: str = "databricks"
"""
FastAPI routes for Graph Explorer.

All operations go through Databricks Statement Execution API.
The table name is passed as a query parameter or in the request body.
"""

from typing import List, Optional
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Query, Request

from src.common.authorization import PermissionChecker
from src.common.config import Settings, get_settings
from src.common.dependencies import GraphExplorerManagerDep
from src.common.features import FeatureAccessLevel
from src.common.logging import get_logger
from src.common.workspace_client import get_workspace_client
from src.controller.graph_explorer_manager import GraphExplorerManager, DEFAULT_TABLE_NAME
from src.models.graph_explorer import (
    DeleteEdgeRequest,
    DeleteNodeRequest,
    EnsureTableResponse,
    GraphDataResponse,
    GraphLimitsResponse,
    GraphQueryRequest,
    GraphQueryResponse,
    LlmConfigResponse,
    NeighborDirection,
    SaveGraphRequest,
    SaveGraphResponse,
    UpdateNodeRequest,
)

logger = get_logger(__name__)

# Feature id used by PermissionChecker for all endpoints in this router.
GRAPH_EXPLORER_FEATURE_ID = 'graph-explorer'

router = APIRouter(prefix="/api/graph-explorer", tags=["graph-explorer"])


def _get_ws_and_warehouse(settings: Settings):
    """Get workspace client and warehouse_id.

    Raises HTTP 500 when DATABRICKS_WAREHOUSE_ID is not configured, since
    no graph operation can run without a warehouse.
    """
    ws_client = get_workspace_client()
    warehouse_id = settings.DATABRICKS_WAREHOUSE_ID
    if not warehouse_id:
        raise HTTPException(status_code=500, detail="DATABRICKS_WAREHOUSE_ID not configured")
    return ws_client, warehouse_id


@router.get("", response_model=GraphDataResponse)
async def get_graph_data(
    table_name: str = Query(default=DEFAULT_TABLE_NAME, alias="tableName"),
    manager: GraphExplorerManagerDep = None,
    settings: Settings = Depends(get_settings),
    _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_ONLY)),
):
    """Read graph data from a Databricks table."""
    try:
        ws_client, warehouse_id = _get_ws_and_warehouse(settings)
        # Ensure the table exists first
        manager.ensure_table_exists(ws_client, table_name, warehouse_id)
        data = manager.get_graph_data(ws_client, table_name, warehouse_id)
        return data
    except ValueError as e:
        # Table-name validation failures surface as 400s.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error reading graph data: {e}")
        raise HTTPException(status_code=500, detail=f"Error reading graph data: {str(e)}")


@router.get("/neighbors", response_model=GraphDataResponse)
async def get_neighbors(
    node_id: str = Query(..., alias="nodeId"),
    table_name: str = Query(default=DEFAULT_TABLE_NAME, alias="tableName"),
    direction: NeighborDirection = Query(default=NeighborDirection.BOTH),
    edge_types: Optional[List[str]] = Query(default=None, alias="edgeTypes"),
    limit: int = Query(default=25, ge=1, le=500),
    offset: int = Query(default=0, ge=0),
    manager: GraphExplorerManagerDep = None,
    settings: Settings = Depends(get_settings),
    _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_ONLY)),
):
    """Get the 1-hop neighborhood of a node for incremental expansion."""
    try:
        ws_client, warehouse_id = _get_ws_and_warehouse(settings)
        data = manager.get_neighbors(
            ws_client, table_name, warehouse_id,
            node_id=node_id,
            direction=direction.value,
            edge_types=edge_types,
            limit=limit,
            offset=offset,
        )
        return data
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error getting neighbors: {e}")
        raise HTTPException(status_code=500, detail=f"Error getting neighbors: {str(e)}")


@router.post("/save", response_model=SaveGraphResponse)
async def save_graph_data(
    request: SaveGraphRequest,
    manager: GraphExplorerManagerDep = None,
    settings: Settings = Depends(get_settings),
    _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_WRITE)),
):
    """Write new nodes and edges to a Databricks table."""
    try:
        ws_client, warehouse_id = _get_ws_and_warehouse(settings)
        manager.ensure_table_exists(ws_client, request.tableName, warehouse_id)

        nodes_dicts = [n.model_dump() for n in request.nodes]
        edges_dicts = [e.model_dump() for e in request.edges]

        result = manager.write_nodes_and_edges(
            ws_client, request.tableName, warehouse_id, nodes_dicts, edges_dicts,
        )
        return result
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error saving graph data: {e}")
        raise HTTPException(status_code=500, detail=f"Error saving graph data: {str(e)}")


@router.post("/ensure-table", response_model=EnsureTableResponse)
async def ensure_table(
    table_name: str = Query(default=DEFAULT_TABLE_NAME, alias="tableName"),
    manager: GraphExplorerManagerDep = None,
    settings: Settings = Depends(get_settings),
    _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_WRITE)),
):
    """Ensure the graph table exists, creating it if necessary."""
    try:
        ws_client, warehouse_id = _get_ws_and_warehouse(settings)
        manager.ensure_table_exists(ws_client, table_name, warehouse_id)
        return {"tableName": table_name, "status": "ok"}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error ensuring table: {e}")
        raise HTTPException(status_code=500, detail=f"Error ensuring table: {str(e)}")


# NOTE(review): DELETE endpoints below take a JSON request body. FastAPI
# supports this, but some HTTP clients/proxies strip DELETE bodies —
# confirm the deployment path handles it.
@router.delete("/node")
async def delete_node(
    request: DeleteNodeRequest,
    manager: GraphExplorerManagerDep = None,
    settings: Settings = Depends(get_settings),
    _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_WRITE)),
):
    """Delete a node and its connected edges from the Databricks table."""
    try:
        ws_client, warehouse_id = _get_ws_and_warehouse(settings)
        manager.delete_node(ws_client, request.tableName, warehouse_id, request.nodeId)
        return {"status": "deleted", "nodeId": request.nodeId}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error deleting node: {e}")
        raise HTTPException(status_code=500, detail=f"Error deleting node: {str(e)}")


@router.delete("/edge")
async def delete_edge(
    request: DeleteEdgeRequest,
    manager: GraphExplorerManagerDep = None,
    settings: Settings = Depends(get_settings),
    _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_WRITE)),
):
    """Delete an edge from the Databricks table."""
    try:
        ws_client, warehouse_id = _get_ws_and_warehouse(settings)
        manager.delete_edge(
            ws_client, request.tableName, warehouse_id,
            request.sourceId, request.targetId, request.relationshipType,
        )
        return {"status": "deleted"}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error deleting edge: {e}")
        raise HTTPException(status_code=500, detail=f"Error deleting edge: {str(e)}")


@router.put("/node")
async def update_node(
    request: UpdateNodeRequest,
    manager: GraphExplorerManagerDep = None,
    settings: Settings = Depends(get_settings),
    _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_WRITE)),
):
    """Update a node in the Databricks table."""
    try:
        ws_client, warehouse_id = _get_ws_and_warehouse(settings)
        manager.update_node(
            ws_client, request.tableName, warehouse_id,
            request.nodeId, request.label, request.type, request.properties,
        )
        return {"status": "updated", "nodeId": request.nodeId}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error updating node: {e}")
        raise HTTPException(status_code=500, detail=f"Error updating node: {str(e)}")


@router.get("/limits", response_model=GraphLimitsResponse)
async def get_graph_limits(
    settings: Settings = Depends(get_settings),
    _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_ONLY)),
):
    """Return the current server-side safety limits for graph operations."""
    return GraphLimitsResponse(
        maxNodes=settings.GRAPH_MAX_NODES,
        maxEdges=settings.GRAPH_MAX_EDGES,
        neighborLimit=settings.GRAPH_NEIGHBOR_LIMIT,
        queryTimeout=settings.GRAPH_QUERY_TIMEOUT,
    )


@router.get("/llm-config", response_model=LlmConfigResponse)
async def get_llm_config(
    manager: GraphExplorerManagerDep = None,
    _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_ONLY)),
):
    """Return the current LLM configuration status for the query panel."""
    return manager.get_llm_config()


@router.post("/query", response_model=GraphQueryResponse)
async def execute_graph_query(
    request: GraphQueryRequest,
    manager: GraphExplorerManagerDep = None,
    settings: Settings = Depends(get_settings),
    _: bool = Depends(PermissionChecker(GRAPH_EXPLORER_FEATURE_ID, FeatureAccessLevel.READ_WRITE)),
):
    """Translate a Cypher/Gremlin query to SQL via LLM and execute it."""
    try:
        ws_client, warehouse_id = _get_ws_and_warehouse(settings)
        table_name = request.tableName or DEFAULT_TABLE_NAME
        result = manager.execute_graph_query(
            ws_client=ws_client,
            warehouse_id=warehouse_id,
            query=request.query,
            language=request.language,
            table_name=table_name,
            override_sql=request.sql,
        )
        return result
    except RuntimeError as e:
        # LLM or SQL execution error — return as a structured error, not 500
        return GraphQueryResponse(
            success=False,
            sql="",
            language=request.language,
            originalQuery=request.query,
            message=str(e),
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error executing graph query: {e}")
        raise HTTPException(status_code=500, detail=f"Error executing graph query: {str(e)}")
the FastAPI app.""" + app.include_router(router) diff --git a/src/backend/src/tests/unit/test_audit_manager.py b/src/backend/src/tests/unit/test_audit_manager.py index 3435b633..2f524396 100644 --- a/src/backend/src/tests/unit/test_audit_manager.py +++ b/src/backend/src/tests/unit/test_audit_manager.py @@ -1,6 +1,8 @@ """ Unit tests for AuditManager """ +import uuid + import pytest from datetime import datetime from unittest.mock import Mock, MagicMock, patch @@ -41,7 +43,7 @@ def manager(self, mock_repo, mock_settings, mock_db_session): def sample_audit_log_db(self): """Sample audit log database object.""" log = Mock(spec=AuditLogDb) - log.id = 1 + log.id = uuid.uuid4() log.username = "user@example.com" log.ip_address = "192.168.1.1" log.feature = "data-products" diff --git a/src/backend/src/tests/unit/test_comments_manager.py b/src/backend/src/tests/unit/test_comments_manager.py index 507cb1e1..658cb866 100644 --- a/src/backend/src/tests/unit/test_comments_manager.py +++ b/src/backend/src/tests/unit/test_comments_manager.py @@ -3,11 +3,12 @@ """ import pytest import json +import uuid from datetime import datetime from unittest.mock import Mock, MagicMock from src.controller.comments_manager import CommentsManager from src.models.comments import CommentCreate, CommentUpdate, Comment -from src.db_models.comments import CommentDb, CommentStatus +from src.db_models.comments import CommentDb, CommentStatus, CommentType as DbCommentType class TestCommentsManager: @@ -27,13 +28,16 @@ def manager(self, mock_repository): def sample_comment_db(self): """Sample comment database object.""" comment = Mock(spec=CommentDb) - comment.id = "comment-123" + comment.id = uuid.UUID("11111111-2222-3333-4444-555555555555") comment.entity_id = "entity-456" comment.entity_type = "data_product" comment.title = "Test Comment" comment.comment = "This is a test comment" comment.audience = None + comment.project_id = None comment.status = CommentStatus.ACTIVE + comment.comment_type = 
DbCommentType.COMMENT + comment.rating = None comment.created_by = "user@example.com" comment.updated_by = "user@example.com" comment.created_at = datetime(2024, 1, 1, 0, 0, 0) @@ -49,6 +53,7 @@ def sample_comment_create(self): title="Test Comment", comment="This is a test comment", audience=None, + project_id="project-1", # Project-scoped; global (project_id=None) requires admin ) # Initialization Tests @@ -82,7 +87,7 @@ def test_db_to_api_model(self, manager, sample_comment_db): """Test converting database model to API model.""" result = manager._db_to_api_model(sample_comment_db) assert isinstance(result, Comment) - assert result.id == "comment-123" + assert result.id == uuid.UUID("11111111-2222-3333-4444-555555555555") assert result.title == "Test Comment" # Create Comment Tests @@ -95,7 +100,7 @@ def test_create_comment_success(self, manager, mock_repository, sample_comment_c result = manager.create_comment(db_session, data=sample_comment_create, user_email="user@example.com") assert isinstance(result, Comment) - assert result.id == "comment-123" + assert result.id == uuid.UUID("11111111-2222-3333-4444-555555555555") mock_repository.create_with_audience.assert_called_once() db_session.commit.assert_called_once() @@ -107,6 +112,7 @@ def test_create_comment_with_audience(self, manager, mock_repository, sample_com title="Restricted Comment", comment="For admins only", audience=["admins", "data_stewards"], + project_id="project-1", # Project-scoped; global (project_id=None) requires admin ) sample_comment_db.audience = '["admins", "data_stewards"]' mock_repository.create_with_audience.return_value = sample_comment_db @@ -114,7 +120,7 @@ def test_create_comment_with_audience(self, manager, mock_repository, sample_com db_session = Mock() result = manager.create_comment(db_session, data=comment_data, user_email="admin@example.com") - assert result.id == "comment-123" + assert result.id == uuid.UUID("11111111-2222-3333-4444-555555555555") 
mock_repository.create_with_audience.assert_called_once_with( db_session, obj_in=comment_data, created_by="admin@example.com" ) @@ -146,18 +152,21 @@ def test_list_comments_with_results(self, manager, mock_repository, sample_comme assert result.total_count == 1 assert result.visible_count == 1 assert len(result.comments) == 1 - assert result.comments[0].id == "comment-123" + assert result.comments[0].id == uuid.UUID("11111111-2222-3333-4444-555555555555") def test_list_comments_filtered_by_groups(self, manager, mock_repository, sample_comment_db): """Test listing comments filtered by user groups.""" comment2 = Mock(spec=CommentDb) - comment2.id = "comment-456" + comment2.id = uuid.UUID("66666666-7777-8888-9999-000000000000") comment2.entity_id = "entity-456" comment2.entity_type = "data_product" comment2.title = "Another Comment" comment2.comment = "Another test" comment2.audience = '["admins"]' + comment2.project_id = None comment2.status = CommentStatus.ACTIVE + comment2.comment_type = DbCommentType.COMMENT + comment2.rating = None comment2.created_by = "admin@example.com" comment2.updated_by = "admin@example.com" comment2.created_at = datetime(2024, 1, 1, 0, 0, 0) @@ -201,7 +210,10 @@ def test_list_comments_include_deleted(self, manager, mock_repository, sample_co def test_update_comment_success(self, manager, mock_repository, sample_comment_db): """Test updating a comment.""" updated_comment = Mock(spec=CommentDb) - updated_comment.__dict__.update(sample_comment_db.__dict__) + for attr in ("id", "entity_id", "entity_type", "comment", "audience", "project_id", + "status", "comment_type", "rating", "created_by", "updated_by", + "created_at", "updated_at"): + setattr(updated_comment, attr, getattr(sample_comment_db, attr)) updated_comment.title = "Updated Title" mock_repository.get.return_value = sample_comment_db @@ -212,7 +224,7 @@ def test_update_comment_success(self, manager, mock_repository, sample_comment_d update_data = CommentUpdate(title="Updated Title") 
result = manager.update_comment( db_session, - comment_id="comment-123", + comment_id="11111111-2222-3333-4444-555555555555", data=update_data, user_email="user@example.com", ) @@ -284,7 +296,7 @@ def test_delete_comment_soft_success(self, manager, mock_repository, sample_comm db_session = Mock() result = manager.delete_comment( - db_session, comment_id="comment-123", user_email="user@example.com" + db_session, comment_id="11111111-2222-3333-4444-555555555555", user_email="user@example.com" ) assert result is True @@ -324,7 +336,7 @@ def test_delete_comment_permission_denied(self, manager, mock_repository, sample db_session = Mock() result = manager.delete_comment( - db_session, comment_id="comment-123", user_email="other@example.com" + db_session, comment_id="11111111-2222-3333-4444-555555555555", user_email="other@example.com" ) assert result is False @@ -336,10 +348,10 @@ def test_get_comment_success(self, manager, mock_repository, sample_comment_db): mock_repository.get.return_value = sample_comment_db db_session = Mock() - result = manager.get_comment(db_session, comment_id="comment-123") + result = manager.get_comment(db_session, comment_id="11111111-2222-3333-4444-555555555555") assert result is not None - assert result.id == "comment-123" + assert result.id == uuid.UUID("11111111-2222-3333-4444-555555555555") def test_get_comment_not_found(self, manager, mock_repository): """Test getting non-existent comment.""" @@ -359,7 +371,7 @@ def test_can_user_modify_comment_true(self, manager, mock_repository, sample_com db_session = Mock() result = manager.can_user_modify_comment( - db_session, comment_id="comment-123", user_email="user@example.com" + db_session, comment_id="11111111-2222-3333-4444-555555555555", user_email="user@example.com" ) assert result is True @@ -371,7 +383,7 @@ def test_can_user_modify_comment_false(self, manager, mock_repository, sample_co db_session = Mock() result = manager.can_user_modify_comment( - db_session, comment_id="comment-123", 
user_email="other@example.com" + db_session, comment_id="11111111-2222-3333-4444-555555555555", user_email="other@example.com" ) assert result is False diff --git a/src/backend/src/tests/unit/test_costs_manager.py b/src/backend/src/tests/unit/test_costs_manager.py index 3bf0903f..766f2896 100644 --- a/src/backend/src/tests/unit/test_costs_manager.py +++ b/src/backend/src/tests/unit/test_costs_manager.py @@ -2,12 +2,41 @@ Unit tests for CostsManager """ import pytest -from datetime import date +from datetime import date, datetime, timezone +from uuid import uuid4, UUID from unittest.mock import Mock from src.controller.costs_manager import CostsManager from src.models.costs import CostItemCreate, CostItemUpdate, CostItem from src.db_models.costs import CostItemDb +# Fixed UUID for deterministic test assertions +SAMPLE_COST_UUID = UUID("a1b2c3d4-e5f6-7890-abcd-ef1234567890") + + +def _make_cost_item_db( + id_val=None, + amount_cents=150000, + cost_center="INFRASTRUCTURE", +): + """Build a mock CostItemDb with properly typed values for Pydantic validation.""" + cost_item = Mock(spec=CostItemDb) + cost_item.id = id_val or uuid4() + cost_item.entity_type = "data_product" + cost_item.entity_id = "product-456" + cost_item.title = "Storage Costs" + cost_item.description = "Monthly storage costs" + cost_item.cost_center = cost_center + cost_item.custom_center_name = None + cost_item.amount_cents = amount_cents + cost_item.currency = "USD" + cost_item.start_month = date(2024, 1, 1) + cost_item.end_month = None + cost_item.created_by = None + cost_item.updated_by = None + cost_item.created_at = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + cost_item.updated_at = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + return cost_item + class TestCostsManager: """Test suite for CostsManager.""" @@ -25,18 +54,7 @@ def manager(self, mock_repository): @pytest.fixture def sample_cost_item_db(self): """Sample cost item database object.""" - cost_item = Mock(spec=CostItemDb) 
- cost_item.id = "cost-123" - cost_item.entity_type = "data_product" - cost_item.entity_id = "product-456" - cost_item.title = "Storage Costs" - cost_item.description = "Monthly storage costs" - cost_item.cost_center = "infrastructure" - cost_item.amount_cents = 150000 - cost_item.currency = "USD" - cost_item.start_month = date(2024, 1, 1) - cost_item.end_month = None - return cost_item + return _make_cost_item_db(id_val=SAMPLE_COST_UUID, amount_cents=150000) @pytest.fixture def sample_cost_create(self): @@ -69,7 +87,7 @@ def test_create_cost_item_success(self, manager, mock_repository, sample_cost_cr result = manager.create(mock_db, data=sample_cost_create, user_email="user@example.com") assert isinstance(result, CostItem) - assert result.id == "cost-123" + assert result.id == SAMPLE_COST_UUID assert result.amount_cents == 150000 mock_repository.create.assert_called_once() assert mock_db.commit.called @@ -107,7 +125,7 @@ def test_list_cost_items_with_results(self, manager, mock_repository, sample_cos result = manager.list(mock_db, entity_type="data_product", entity_id="product-456") assert len(result) == 1 - assert result[0].id == "cost-123" + assert result[0].id == SAMPLE_COST_UUID assert result[0].amount_cents == 150000 def test_list_cost_items_filtered_by_month(self, manager, mock_repository, sample_cost_item_db): @@ -130,13 +148,9 @@ def test_list_cost_items_filtered_by_month(self, manager, mock_repository, sampl def test_update_cost_item_success(self, manager, mock_repository, sample_cost_item_db): """Test updating a cost item.""" # Create updated item with same attributes but changed amount - updated_item = Mock(spec=CostItemDb) - for key, value in sample_cost_item_db.__dict__.items(): - if not key.startswith('_'): - setattr(updated_item, key, value) - updated_item.amount_cents = 200000 - updated_item.entity_type = sample_cost_item_db.entity_type - updated_item.entity_id = sample_cost_item_db.entity_id + updated_item = _make_cost_item_db( + 
id_val=SAMPLE_COST_UUID, amount_cents=200000 + ) mock_repository.get.return_value = sample_cost_item_db mock_repository.update.return_value = updated_item @@ -161,12 +175,9 @@ def test_update_cost_item_not_found(self, manager, mock_repository): def test_update_cost_item_logs_change(self, manager, mock_repository, sample_cost_item_db): """Test that updating a cost item logs a change.""" - updated_item = Mock(spec=CostItemDb) - for key, value in sample_cost_item_db.__dict__.items(): - if not key.startswith('_'): - setattr(updated_item, key, value) - updated_item.entity_type = sample_cost_item_db.entity_type - updated_item.entity_id = sample_cost_item_db.entity_id + updated_item = _make_cost_item_db( + id_val=SAMPLE_COST_UUID, amount_cents=200000 + ) mock_repository.get.return_value = sample_cost_item_db mock_repository.update.return_value = updated_item diff --git a/src/backend/src/tests/unit/test_data_products_manager.py b/src/backend/src/tests/unit/test_data_products_manager.py index a8c29339..97be055b 100644 --- a/src/backend/src/tests/unit/test_data_products_manager.py +++ b/src/backend/src/tests/unit/test_data_products_manager.py @@ -134,10 +134,11 @@ def test_create_product_with_tags( def test_create_product_validation_error(self, manager, db_session): """Test that invalid product data raises ValueError.""" - # Arrange + # Arrange - Use data that fails Pydantic validation (id must be str, not int) invalid_data = { - "name": "", # Empty name should fail - "version": "invalid", # Invalid version format might fail + "id": 12345, # Invalid type - Pydantic expects str + "name": "Test", + "version": "1.0.0", } # Act & Assert @@ -290,9 +291,10 @@ def test_update_product_validation_error( """Test that invalid update data raises ValueError.""" # Arrange created = manager.create_product(sample_product_data, db=db_session) + # Use data that fails Pydantic validation (status must be str, not int) invalid_update = { "id": created.id, - "status": "invalid-status", # Invalid 
status + "status": 123, # Invalid type - Pydantic expects str } # Act & Assert @@ -489,13 +491,17 @@ def test_get_distinct_product_types(self, manager): def test_get_distinct_owners(self, manager, db_session): """Test retrieving distinct owners.""" - # Arrange - Create products with different owners + # Arrange - Create products with team members (ODPS stores owners in team.members) for i, owner in enumerate(["owner1@test.com", "owner2@test.com"]): product_data = { "name": f"Product {i}", "version": "1.0.0", "productType": "sourceAligned", - "owner": owner, + "team": { + "members": [ + {"username": owner, "role": "owner", "name": owner} + ] + }, } manager.create_product(product_data, db=db_session) @@ -562,12 +568,13 @@ def test_create_product_db_error_handling( self, manager, db_session, sample_product_data ): """Test graceful handling of database errors during creation.""" - # Arrange - Force DB error by closing session - db_session.close() + # Arrange - Patch repository to simulate DB error + from sqlalchemy.exc import SQLAlchemyError - # Act & Assert - with pytest.raises(Exception): # SQLAlchemy error - manager.create_product(sample_product_data, db=db_session) + with patch.object(manager._repo, 'create', side_effect=SQLAlchemyError("DB connection failed")): + # Act & Assert + with pytest.raises(SQLAlchemyError): + manager.create_product(sample_product_data, db=db_session) def test_manager_without_workspace_client(self, db_session): """Test manager initialization without WorkspaceClient.""" diff --git a/src/backend/src/tests/unit/test_graph_explorer_manager.py b/src/backend/src/tests/unit/test_graph_explorer_manager.py new file mode 100644 index 00000000..394aa212 --- /dev/null +++ b/src/backend/src/tests/unit/test_graph_explorer_manager.py @@ -0,0 +1,489 @@ +""" +Unit tests for GraphExplorerManager. + +Tests the get_neighbors() method with mocked Databricks Statement Execution API. 
+""" + +import json +import pytest +from unittest.mock import MagicMock, patch, call + +from src.controller.graph_explorer_manager import GraphExplorerManager, _validate_table_name, _escape_sql_string + + +class TestValidateTableName: + """Test table name validation and sanitization.""" + + def test_valid_three_part_name(self): + result = _validate_table_name("main.default.my_table") + assert result == "`main`.`default`.`my_table`" + + def test_valid_single_part(self): + result = _validate_table_name("my_table") + assert result == "`my_table`" + + def test_rejects_sql_injection(self): + with pytest.raises(ValueError, match="Invalid table name"): + _validate_table_name("main; DROP TABLE users --") + + def test_strips_whitespace(self): + result = _validate_table_name(" main.default.t ") + assert result == "`main`.`default`.`t`" + + +class TestEscapeSqlString: + """Test SQL string escaping.""" + + def test_escapes_single_quotes(self): + assert _escape_sql_string("O'Brien") == "O''Brien" + + def test_no_change_for_safe_strings(self): + assert _escape_sql_string("hello") == "hello" + + +class TestGraphExplorerManager: + """Unit tests for GraphExplorerManager business logic.""" + + @pytest.fixture + def manager(self): + settings = MagicMock() + settings.DATABRICKS_WAREHOUSE_ID = "test-warehouse" + settings.LLM_ENDPOINT = "" + settings.GRAPH_MAX_EDGES = 10000 + settings.GRAPH_MAX_NODES = 5000 + settings.GRAPH_QUERY_TIMEOUT = "30s" + settings.GRAPH_NEIGHBOR_LIMIT = 50 + return GraphExplorerManager(settings=settings) + + @pytest.fixture + def mock_ws_client(self): + return MagicMock() + + def _make_sql_result(self, columns, rows): + """Helper to create a mock Statement Execution API result.""" + mock_result = MagicMock() + mock_result.status.state = "SUCCEEDED" + + # Build column schema + col_objects = [] + for col_name in columns: + col_obj = MagicMock() + col_obj.name = col_name + col_objects.append(col_obj) + + mock_result.manifest.schema.columns = col_objects + 
mock_result.result.data_array = rows + + return mock_result + + # --------------------------------------------------------------- + # get_neighbors tests + # --------------------------------------------------------------- + + def test_get_neighbors_both_directions(self, manager, mock_ws_client): + """Test expanding all neighbors of a node.""" + columns = [ + "node_start_id", "node_start_key", "relationship", + "node_end_id", "node_end_key", + "node_start_properties", "node_end_properties", + ] + + # First call: COUNT query + count_result = self._make_sql_result(["count"], [["2"]]) + # Second call: SELECT query + data_result = self._make_sql_result(columns, [ + ["alice", "Person", "KNOWS", "bob", "Person", + json.dumps({"_label": "Alice"}), json.dumps({"_label": "Bob"})], + ["alice", "Person", "WORKS_AT", "acme", "Company", + json.dumps({"_label": "Alice"}), json.dumps({"_label": "Acme"})], + ]) + + mock_ws_client.statement_execution.execute_statement.side_effect = [ + count_result, data_result + ] + + result = manager.get_neighbors( + mock_ws_client, "main.default.test_graph", "wh-1", + node_id="alice", direction="both", limit=25, + ) + + assert len(result["nodes"]) == 3 # alice, bob, acme + assert len(result["edges"]) == 2 # KNOWS, WORKS_AT + assert result["truncated"] is False + assert result["totalAvailable"] == 2 + + # Verify SQL contains the right direction filter + calls = mock_ws_client.statement_execution.execute_statement.call_args_list + count_sql = calls[0].kwargs["statement"] + assert "node_start_id = 'alice' OR node_end_id = 'alice'" in count_sql + assert "relationship != 'EXISTS'" in count_sql + + def test_get_neighbors_outgoing_only(self, manager, mock_ws_client): + """Test expanding only outgoing edges.""" + columns = [ + "node_start_id", "node_start_key", "relationship", + "node_end_id", "node_end_key", + "node_start_properties", "node_end_properties", + ] + + count_result = self._make_sql_result(["count"], [["1"]]) + data_result = 
self._make_sql_result(columns, [ + ["alice", "Person", "KNOWS", "bob", "Person", + json.dumps({"_label": "Alice"}), json.dumps({"_label": "Bob"})], + ]) + + mock_ws_client.statement_execution.execute_statement.side_effect = [ + count_result, data_result + ] + + result = manager.get_neighbors( + mock_ws_client, "main.default.test_graph", "wh-1", + node_id="alice", direction="outgoing", limit=25, + ) + + assert len(result["nodes"]) == 2 # alice, bob + assert len(result["edges"]) == 1 + + # Verify direction filter + calls = mock_ws_client.statement_execution.execute_statement.call_args_list + count_sql = calls[0].kwargs["statement"] + assert "node_start_id = 'alice'" in count_sql + assert "OR" not in count_sql + + def test_get_neighbors_incoming_only(self, manager, mock_ws_client): + """Test expanding only incoming edges.""" + columns = [ + "node_start_id", "node_start_key", "relationship", + "node_end_id", "node_end_key", + "node_start_properties", "node_end_properties", + ] + + count_result = self._make_sql_result(["count"], [["1"]]) + data_result = self._make_sql_result(columns, [ + ["bob", "Person", "KNOWS", "alice", "Person", + json.dumps({"_label": "Bob"}), json.dumps({"_label": "Alice"})], + ]) + + mock_ws_client.statement_execution.execute_statement.side_effect = [ + count_result, data_result + ] + + result = manager.get_neighbors( + mock_ws_client, "main.default.test_graph", "wh-1", + node_id="alice", direction="incoming", limit=25, + ) + + assert len(result["nodes"]) == 2 + calls = mock_ws_client.statement_execution.execute_statement.call_args_list + count_sql = calls[0].kwargs["statement"] + assert "node_end_id = 'alice'" in count_sql + assert "node_start_id" not in count_sql + + def test_get_neighbors_with_edge_type_filter(self, manager, mock_ws_client): + """Test filtering by specific edge types.""" + columns = [ + "node_start_id", "node_start_key", "relationship", + "node_end_id", "node_end_key", + "node_start_properties", "node_end_properties", + ] + + 
count_result = self._make_sql_result(["count"], [["1"]]) + data_result = self._make_sql_result(columns, [ + ["alice", "Person", "KNOWS", "bob", "Person", + json.dumps({"_label": "Alice"}), json.dumps({"_label": "Bob"})], + ]) + + mock_ws_client.statement_execution.execute_statement.side_effect = [ + count_result, data_result + ] + + result = manager.get_neighbors( + mock_ws_client, "main.default.test_graph", "wh-1", + node_id="alice", direction="both", + edge_types=["KNOWS"], limit=25, + ) + + assert len(result["edges"]) == 1 + assert result["edges"][0]["relationshipType"] == "KNOWS" + + # Verify edge type filter in SQL + calls = mock_ws_client.statement_execution.execute_statement.call_args_list + count_sql = calls[0].kwargs["statement"] + assert "relationship IN ('KNOWS')" in count_sql + + def test_get_neighbors_truncated(self, manager, mock_ws_client): + """Test truncation when more neighbors exist than the limit.""" + columns = [ + "node_start_id", "node_start_key", "relationship", + "node_end_id", "node_end_key", + "node_start_properties", "node_end_properties", + ] + + # COUNT returns 50, but we're requesting limit=2 + count_result = self._make_sql_result(["count"], [["50"]]) + data_result = self._make_sql_result(columns, [ + ["alice", "Person", "KNOWS", "bob", "Person", None, None], + ["alice", "Person", "KNOWS", "carol", "Person", None, None], + ]) + + mock_ws_client.statement_execution.execute_statement.side_effect = [ + count_result, data_result + ] + + result = manager.get_neighbors( + mock_ws_client, "main.default.test_graph", "wh-1", + node_id="alice", direction="both", limit=2, + ) + + assert result["truncated"] is True + assert result["totalAvailable"] == 50 + assert len(result["edges"]) == 2 + + def test_get_neighbors_empty_result(self, manager, mock_ws_client): + """Test expanding a node with no neighbors.""" + count_result = self._make_sql_result(["count"], [["0"]]) + data_result = self._make_sql_result([], []) + # Override to handle empty columns 
+ data_result.manifest.schema.columns = [] + data_result.result.data_array = [] + + mock_ws_client.statement_execution.execute_statement.side_effect = [ + count_result, data_result + ] + + result = manager.get_neighbors( + mock_ws_client, "main.default.test_graph", "wh-1", + node_id="lonely_node", direction="both", limit=25, + ) + + assert len(result["nodes"]) == 0 + assert len(result["edges"]) == 0 + assert result["truncated"] is False + assert result["totalAvailable"] == 0 + + def test_get_neighbors_deduplicates_nodes(self, manager, mock_ws_client): + """Test that nodes appearing in multiple edges are not duplicated.""" + columns = [ + "node_start_id", "node_start_key", "relationship", + "node_end_id", "node_end_key", + "node_start_properties", "node_end_properties", + ] + + count_result = self._make_sql_result(["count"], [["2"]]) + # alice appears as start node in both rows + data_result = self._make_sql_result(columns, [ + ["alice", "Person", "KNOWS", "bob", "Person", None, None], + ["alice", "Person", "LIKES", "bob", "Person", None, None], + ]) + + mock_ws_client.statement_execution.execute_statement.side_effect = [ + count_result, data_result + ] + + result = manager.get_neighbors( + mock_ws_client, "main.default.test_graph", "wh-1", + node_id="alice", direction="both", limit=25, + ) + + # Only 2 unique nodes despite appearing in multiple rows + assert len(result["nodes"]) == 2 + node_ids = {n["id"] for n in result["nodes"]} + assert node_ids == {"alice", "bob"} + # But both edges are present + assert len(result["edges"]) == 2 + + def test_get_neighbors_sql_injection_protection(self, manager, mock_ws_client): + """Test that node IDs with special characters are escaped.""" + columns = [ + "node_start_id", "node_start_key", "relationship", + "node_end_id", "node_end_key", + "node_start_properties", "node_end_properties", + ] + + count_result = self._make_sql_result(["count"], [["0"]]) + data_result = self._make_sql_result(columns, []) + 
data_result.result.data_array = [] + + mock_ws_client.statement_execution.execute_statement.side_effect = [ + count_result, data_result + ] + + # Node ID with a single quote (SQL injection attempt) + result = manager.get_neighbors( + mock_ws_client, "main.default.test_graph", "wh-1", + node_id="alice'; DROP TABLE users; --", direction="both", limit=25, + ) + + # Verify the escaped string is in the SQL + calls = mock_ws_client.statement_execution.execute_statement.call_args_list + count_sql = calls[0].kwargs["statement"] + assert "alice''; DROP TABLE users; --" in count_sql # Double-escaped + + # --------------------------------------------------------------- + # get_graph_data tests + # --------------------------------------------------------------- + + def test_get_graph_data_basic(self, manager, mock_ws_client): + """Test basic graph data loading.""" + columns = [ + "node_start_id", "node_start_key", "relationship", + "node_end_id", "node_end_key", + "node_start_properties", "node_end_properties", + ] + + count_result = self._make_sql_result(["count"], [["2"]]) + data_result = self._make_sql_result(columns, [ + ["alice", "Person", "KNOWS", "bob", "Person", + json.dumps({"_label": "Alice", "age": "30"}), + json.dumps({"_label": "Bob"})], + ["alice", "Person", "EXISTS", "alice", "Person", + json.dumps({"_label": "Alice"}), + json.dumps({"_label": "Alice"})], + ]) + + mock_ws_client.statement_execution.execute_statement.side_effect = [ + count_result, data_result + ] + + data = manager.get_graph_data(mock_ws_client, "main.default.test", "wh-1") + + assert len(data["nodes"]) == 2 # alice, bob + assert len(data["edges"]) == 1 # KNOWS only, EXISTS filtered out + assert data["truncated"] is False + + def test_get_graph_data_empty_table(self, manager, mock_ws_client): + """Test loading from an empty table.""" + count_result = self._make_sql_result(["count"], [["0"]]) + data_result = self._make_sql_result([], []) + data_result.manifest.schema.columns = [] + 
data_result.result.data_array = [] + + mock_ws_client.statement_execution.execute_statement.side_effect = [ + count_result, data_result + ] + + data = manager.get_graph_data(mock_ws_client, "main.default.test", "wh-1") + + assert data["nodes"] == [] + assert data["edges"] == [] + assert data["truncated"] is False + + def test_get_graph_data_truncated(self, mock_ws_client): + """Test that get_graph_data returns truncation info when data exceeds limit.""" + settings = MagicMock() + settings.DATABRICKS_WAREHOUSE_ID = "test-warehouse" + settings.LLM_ENDPOINT = "" + settings.GRAPH_MAX_EDGES = 2 # Very low limit + settings.GRAPH_QUERY_TIMEOUT = "30s" + mgr = GraphExplorerManager(settings=settings) + + columns = [ + "node_start_id", "node_start_key", "relationship", + "node_end_id", "node_end_key", + "node_start_properties", "node_end_properties", + ] + + # Total rows = 100, but limit is 2 + count_result = self._make_sql_result(["count"], [["100"]]) + data_result = self._make_sql_result(columns, [ + ["a", "N", "R", "b", "N", None, None], + ["c", "N", "R", "d", "N", None, None], + ]) + + mock_ws_client.statement_execution.execute_statement.side_effect = [ + count_result, data_result + ] + + data = mgr.get_graph_data(mock_ws_client, "main.default.test", "wh-1") + + assert data["truncated"] is True + assert data["totalAvailable"] == 100 + + # Verify LIMIT in the SQL + calls = mock_ws_client.statement_execution.execute_statement.call_args_list + select_sql = calls[1].kwargs["statement"] + assert "LIMIT 2" in select_sql + + def test_get_graph_data_respects_max_rows_argument(self, manager, mock_ws_client): + """Test that explicit max_rows overrides settings.""" + columns = [ + "node_start_id", "node_start_key", "relationship", + "node_end_id", "node_end_key", + "node_start_properties", "node_end_properties", + ] + + count_result = self._make_sql_result(["count"], [["5"]]) + data_result = self._make_sql_result(columns, [ + ["a", "N", "R", "b", "N", None, None], + ]) + + 
mock_ws_client.statement_execution.execute_statement.side_effect = [ + count_result, data_result + ] + + data = manager.get_graph_data(mock_ws_client, "main.default.test", "wh-1", max_rows=1) + + # Verify LIMIT in the SQL + calls = mock_ws_client.statement_execution.execute_statement.call_args_list + select_sql = calls[1].kwargs["statement"] + assert "LIMIT 1" in select_sql + + def test_get_graph_data_uses_timeout_from_settings(self, mock_ws_client): + """Test that _execute_sql uses GRAPH_QUERY_TIMEOUT from settings.""" + settings = MagicMock() + settings.DATABRICKS_WAREHOUSE_ID = "test-warehouse" + settings.LLM_ENDPOINT = "" + settings.GRAPH_MAX_EDGES = 10000 + settings.GRAPH_QUERY_TIMEOUT = "60s" + mgr = GraphExplorerManager(settings=settings) + + count_result = self._make_sql_result(["count"], [["0"]]) + data_result = self._make_sql_result([], []) + data_result.manifest.schema.columns = [] + data_result.result.data_array = [] + + mock_ws_client.statement_execution.execute_statement.side_effect = [ + count_result, data_result + ] + + mgr.get_graph_data(mock_ws_client, "main.default.test", "wh-1") + + # Verify timeout was passed + calls = mock_ws_client.statement_execution.execute_statement.call_args_list + for c in calls: + assert c.kwargs["wait_timeout"] == "60s" + + # --------------------------------------------------------------- + # _parse_props tests + # --------------------------------------------------------------- + + def test_parse_props_json_string(self, manager): + result = manager._parse_props('{"name": "Alice", "age": 30}') + assert result == {"name": "Alice", "age": 30} + + def test_parse_props_dict(self, manager): + result = manager._parse_props({"name": "Alice"}) + assert result == {"name": "Alice"} + + def test_parse_props_none(self, manager): + assert manager._parse_props(None) == {} + + def test_parse_props_invalid_json(self, manager): + assert manager._parse_props("not json") == {} + + # 
--------------------------------------------------------------- + # LLM config tests + # --------------------------------------------------------------- + + def test_get_llm_config_disabled(self, manager): + config = manager.get_llm_config() + assert config["enabled"] is False + + def test_get_llm_config_enabled(self): + settings = MagicMock() + settings.LLM_ENDPOINT = "databricks-claude-sonnet" + mgr = GraphExplorerManager(settings=settings) + + config = mgr.get_llm_config() + assert config["enabled"] is True + assert config["defaultModel"] == "databricks-claude-sonnet" diff --git a/src/backend/src/tests/unit/test_metadata_manager.py b/src/backend/src/tests/unit/test_metadata_manager.py index fcf0c055..eb7f4d6c 100644 --- a/src/backend/src/tests/unit/test_metadata_manager.py +++ b/src/backend/src/tests/unit/test_metadata_manager.py @@ -1,6 +1,8 @@ """ Unit tests for MetadataManager """ +import uuid +from datetime import datetime import pytest from unittest.mock import Mock, MagicMock from databricks.sdk.service.catalog import VolumeType @@ -40,31 +42,45 @@ def manager(self, mock_rich_text_repo, mock_link_repo, mock_document_repo): def sample_rich_text_db(self): """Sample rich text database object.""" rt = Mock(spec=RichTextMetadataDb) - rt.id = "rt-123" + rt.id = uuid.UUID("12345678-1234-1234-1234-123456789012") rt.entity_type = "data_product" rt.entity_id = "product-456" rt.title = "Documentation" rt.short_description = "Product docs" rt.content_markdown = "# Test\nContent" + rt.is_shared = False + rt.level = 50 + rt.inheritable = True + rt.created_by = "user@example.com" + rt.updated_by = "user@example.com" + rt.created_at = datetime(2024, 1, 1, 12, 0, 0) + rt.updated_at = datetime(2024, 1, 1, 12, 0, 0) return rt @pytest.fixture def sample_link_db(self): """Sample link database object.""" link = Mock(spec=LinkMetadataDb) - link.id = "link-123" + link.id = uuid.UUID("22345678-1234-1234-1234-123456789012") link.entity_type = "data_contract" link.entity_id = 
"contract-456" link.title = "Related Link" link.short_description = "External resource" link.url = "https://example.com" + link.is_shared = False + link.level = 50 + link.inheritable = True + link.created_by = "user@example.com" + link.updated_by = "user@example.com" + link.created_at = datetime(2024, 1, 1, 12, 0, 0) + link.updated_at = datetime(2024, 1, 1, 12, 0, 0) return link @pytest.fixture def sample_document_db(self): """Sample document database object.""" doc = Mock(spec=DocumentMetadataDb) - doc.id = "doc-123" + doc.id = uuid.UUID("32345678-1234-1234-1234-123456789012") doc.entity_type = "data_product" doc.entity_id = "product-789" doc.title = "User Guide" @@ -73,6 +89,13 @@ def sample_document_db(self): doc.content_type = "application/pdf" doc.size_bytes = 1024000 doc.storage_path = "/volumes/catalog/schema/volume/guide.pdf" + doc.is_shared = False + doc.level = 50 + doc.inheritable = True + doc.created_by = "user@example.com" + doc.updated_by = "user@example.com" + doc.created_at = datetime(2024, 1, 1, 12, 0, 0) + doc.updated_at = datetime(2024, 1, 1, 12, 0, 0) return doc # Initialization Tests @@ -105,7 +128,7 @@ def test_create_rich_text_success(self, manager, mock_rich_text_repo, sample_ric result = manager.create_rich_text(mock_db, data=data, user_email="user@example.com") - assert result.id == "rt-123" + assert result.id == uuid.UUID("12345678-1234-1234-1234-123456789012") mock_rich_text_repo.create.assert_called_once() assert mock_db.commit.called @@ -126,14 +149,17 @@ def test_list_rich_texts_with_results(self, manager, mock_rich_text_repo, sample result = manager.list_rich_texts(mock_db, entity_type="data_product", entity_id="product-456") assert len(result) == 1 - assert result[0].id == "rt-123" + assert result[0].id == uuid.UUID("12345678-1234-1234-1234-123456789012") def test_update_rich_text_success(self, manager, mock_rich_text_repo, sample_rich_text_db): """Test updating rich text metadata.""" updated_rt = Mock(spec=RichTextMetadataDb) - for 
key, value in sample_rich_text_db.__dict__.items(): - if not key.startswith('_'): - setattr(updated_rt, key, value) + for attr in ( + "id", "entity_type", "entity_id", "title", "short_description", "content_markdown", + "is_shared", "level", "inheritable", "created_by", "updated_by", + "created_at", "updated_at", + ): + setattr(updated_rt, attr, getattr(sample_rich_text_db, attr)) updated_rt.title = "Updated Title" mock_rich_text_repo.get.return_value = sample_rich_text_db @@ -141,7 +167,9 @@ def test_update_rich_text_success(self, manager, mock_rich_text_repo, sample_ric mock_db = Mock() update_data = RichTextUpdate(title="Updated Title") - result = manager.update_rich_text(mock_db, id="rt-123", data=update_data, user_email="user@example.com") + result = manager.update_rich_text( + mock_db, id="12345678-1234-1234-1234-123456789012", data=update_data, user_email="user@example.com" + ) assert result is not None assert result.title == "Updated Title" @@ -162,7 +190,9 @@ def test_delete_rich_text_success(self, manager, mock_rich_text_repo, sample_ric mock_rich_text_repo.remove.return_value = sample_rich_text_db mock_db = Mock() - result = manager.delete_rich_text(mock_db, id="rt-123", user_email="user@example.com") + result = manager.delete_rich_text( + mock_db, id="12345678-1234-1234-1234-123456789012", user_email="user@example.com" + ) assert result is True mock_rich_text_repo.remove.assert_called_once() @@ -193,7 +223,7 @@ def test_create_link_success(self, manager, mock_link_repo, sample_link_db): result = manager.create_link(mock_db, data=data, user_email="user@example.com") - assert result.id == "link-123" + assert result.id == uuid.UUID("22345678-1234-1234-1234-123456789012") mock_link_repo.create.assert_called_once() def test_list_links_empty(self, manager, mock_link_repo): @@ -218,9 +248,12 @@ def test_list_links_with_results(self, manager, mock_link_repo, sample_link_db): def test_update_link_success(self, manager, mock_link_repo, sample_link_db): """Test 
updating link metadata.""" updated_link = Mock(spec=LinkMetadataDb) - for key, value in sample_link_db.__dict__.items(): - if not key.startswith('_'): - setattr(updated_link, key, value) + for attr in ( + "id", "entity_type", "entity_id", "title", "short_description", "url", + "is_shared", "level", "inheritable", "created_by", "updated_by", + "created_at", "updated_at", + ): + setattr(updated_link, attr, getattr(sample_link_db, attr)) updated_link.url = "https://updated.com" mock_link_repo.get.return_value = sample_link_db @@ -228,7 +261,9 @@ def test_update_link_success(self, manager, mock_link_repo, sample_link_db): mock_db = Mock() update_data = LinkUpdate(url="https://updated.com") - result = manager.update_link(mock_db, id="link-123", data=update_data, user_email="user@example.com") + result = manager.update_link( + mock_db, id="22345678-1234-1234-1234-123456789012", data=update_data, user_email="user@example.com" + ) assert result is not None assert result.url == "https://updated.com" @@ -249,7 +284,9 @@ def test_delete_link_success(self, manager, mock_link_repo, sample_link_db): mock_link_repo.remove.return_value = sample_link_db mock_db = Mock() - result = manager.delete_link(mock_db, id="link-123", user_email="user@example.com") + result = manager.delete_link( + mock_db, id="22345678-1234-1234-1234-123456789012", user_email="user@example.com" + ) assert result is True @@ -268,6 +305,17 @@ def test_create_document_record_success(self, manager): """Test creating document record.""" mock_db = Mock() + # Simulate db.refresh populating DB-generated fields (id, timestamps, defaults) + def _mock_refresh(db_obj): + setattr(db_obj, "id", uuid.UUID("32345678-1234-1234-1234-123456789012")) + setattr(db_obj, "created_at", datetime(2024, 1, 1, 12, 0, 0)) + setattr(db_obj, "updated_at", datetime(2024, 1, 1, 12, 0, 0)) + setattr(db_obj, "is_shared", False) + setattr(db_obj, "level", 50) + setattr(db_obj, "inheritable", True) + + mock_db.refresh = 
Mock(side_effect=_mock_refresh) + data = DocumentCreate( entity_type="data_product", entity_id="product-789", @@ -313,10 +361,10 @@ def test_get_document_success(self, manager, mock_document_repo, sample_document mock_document_repo.get.return_value = sample_document_db mock_db = Mock() - result = manager.get_document(mock_db, id="doc-123") + result = manager.get_document(mock_db, id="32345678-1234-1234-1234-123456789012") assert result is not None - assert result.id == "doc-123" + assert result.id == uuid.UUID("32345678-1234-1234-1234-123456789012") def test_get_document_not_found(self, manager, mock_document_repo): """Test getting non-existent document.""" @@ -333,7 +381,9 @@ def test_delete_document_success(self, manager, mock_document_repo, sample_docum mock_document_repo.remove.return_value = sample_document_db mock_db = Mock() - result = manager.delete_document(mock_db, id="doc-123", user_email="user@example.com") + result = manager.delete_document( + mock_db, id="32345678-1234-1234-1234-123456789012", user_email="user@example.com" + ) assert result is True diff --git a/src/backend/src/tests/unit/test_notifications_manager.py b/src/backend/src/tests/unit/test_notifications_manager.py index d4b04a31..c476d51b 100644 --- a/src/backend/src/tests/unit/test_notifications_manager.py +++ b/src/backend/src/tests/unit/test_notifications_manager.py @@ -3,13 +3,40 @@ """ import pytest from datetime import datetime -from unittest.mock import Mock, MagicMock, patch, mock_open -from src.controller.notifications_manager import NotificationsManager, NotificationNotFoundError +from types import SimpleNamespace +from unittest.mock import Mock +from src.controller.notifications_manager import NotificationsManager from src.models.notifications import Notification, NotificationType from src.models.users import UserInfo from src.db_models.notifications import NotificationDb +def _make_notif_db(**kwargs): + """Create a mock NotificationDb with all attributes needed for 
Notification.model_validate.""" + defaults = { + "id": "notif-1", + "recipient": None, + "title": "Test", + "type": "info", + "created_at": datetime(2024, 1, 1), + "read": False, + "can_delete": True, + "subtitle": None, + "description": None, + "message": None, + "link": None, + "action_type": None, + "action_payload": None, + "data": None, + "target_roles": None, + "updated_at": None, + "recipient_role_id": None, + "recipient_role_name": None, + } + defaults.update(kwargs) + return SimpleNamespace(**defaults) + + class TestNotificationsManager: """Test suite for NotificationsManager.""" @@ -26,9 +53,11 @@ def mock_repository(self): return Mock() @pytest.fixture - def manager(self, mock_settings_manager): + def manager(self, mock_settings_manager, mock_repository): """Create NotificationsManager instance for testing.""" - return NotificationsManager(settings_manager=mock_settings_manager) + manager = NotificationsManager(settings_manager=mock_settings_manager) + manager._repo = mock_repository # Replace real repo with mock for unit testing + return manager @pytest.fixture def sample_notification(self): @@ -80,21 +109,7 @@ def test_get_notifications_broadcast_only(self, manager): mock_db = Mock() # Create notification DB objects - broadcast_notif = Mock(spec=NotificationDb) - broadcast_notif.id = "notif-1" - broadcast_notif.recipient = None - broadcast_notif.title = "Broadcast" - broadcast_notif.type = "info" - broadcast_notif.created_at = datetime(2024, 1, 1) - broadcast_notif.read = False - broadcast_notif.can_delete = True - broadcast_notif.subtitle = None - broadcast_notif.description = None - broadcast_notif.link = None - broadcast_notif.action_type = None - broadcast_notif.action_payload = None - broadcast_notif.data = None - broadcast_notif.target_roles = None + broadcast_notif = _make_notif_db(id="notif-1", recipient=None, title="Broadcast") manager._repo.get_multi.return_value = [broadcast_notif] @@ -108,37 +123,13 @@ def 
test_get_notifications_filtered_by_email(self, manager, sample_user_info): mock_db = Mock() # Create notifications - user_notif = Mock(spec=NotificationDb) - user_notif.id = "notif-1" - user_notif.recipient = "user@example.com" - user_notif.title = "User Notification" - user_notif.type = "info" - user_notif.created_at = datetime(2024, 1, 1) - user_notif.read = False - user_notif.can_delete = True - user_notif.subtitle = None - user_notif.description = None - user_notif.link = None - user_notif.action_type = None - user_notif.action_payload = None - user_notif.data = None - user_notif.target_roles = None - - other_notif = Mock(spec=NotificationDb) - other_notif.id = "notif-2" - other_notif.recipient = "other@example.com" - other_notif.title = "Other Notification" - other_notif.type = "info" - other_notif.created_at = datetime(2024, 1, 2) - other_notif.read = False - other_notif.can_delete = True - other_notif.subtitle = None - other_notif.description = None - other_notif.link = None - other_notif.action_type = None - other_notif.action_payload = None - other_notif.data = None - other_notif.target_roles = None + user_notif = _make_notif_db(id="notif-1", recipient="user@example.com", title="User Notification") + other_notif = _make_notif_db( + id="notif-2", + recipient="other@example.com", + title="Other Notification", + created_at=datetime(2024, 1, 2), + ) manager._repo.get_multi.return_value = [user_notif, other_notif] @@ -159,21 +150,9 @@ def test_get_notifications_filtered_by_role(self, manager, sample_user_info, moc mock_settings_manager.list_app_roles.return_value = [mock_role] # Create role-targeted notification - role_notif = Mock(spec=NotificationDb) - role_notif.id = "notif-1" - role_notif.recipient = "Data Consumer" - role_notif.title = "Role Notification" - role_notif.type = "info" - role_notif.created_at = datetime(2024, 1, 1) - role_notif.read = False - role_notif.can_delete = True - role_notif.subtitle = None - role_notif.description = None - 
role_notif.link = None - role_notif.action_type = None - role_notif.action_payload = None - role_notif.data = None - role_notif.target_roles = None + role_notif = _make_notif_db( + id="notif-1", recipient="Data Consumer", title="Role Notification" + ) manager._repo.get_multi.return_value = [role_notif] @@ -187,53 +166,9 @@ def test_get_notifications_sorts_by_created_at(self, manager): mock_db = Mock() # Create notifications with different timestamps - notif1 = Mock(spec=NotificationDb) - notif1.id = "notif-1" - notif1.recipient = None - notif1.title = "First" - notif1.type = "info" - notif1.created_at = datetime(2024, 1, 1) - notif1.read = False - notif1.can_delete = True - notif1.subtitle = None - notif1.description = None - notif1.link = None - notif1.action_type = None - notif1.action_payload = None - notif1.data = None - notif1.target_roles = None - - notif2 = Mock(spec=NotificationDb) - notif2.id = "notif-2" - notif2.recipient = None - notif2.title = "Second" - notif2.type = "info" - notif2.created_at = datetime(2024, 1, 3) - notif2.read = False - notif2.can_delete = True - notif2.subtitle = None - notif2.description = None - notif2.link = None - notif2.action_type = None - notif2.action_payload = None - notif2.data = None - notif2.target_roles = None - - notif3 = Mock(spec=NotificationDb) - notif3.id = "notif-3" - notif3.recipient = None - notif3.title = "Third" - notif3.type = "info" - notif3.created_at = datetime(2024, 1, 2) - notif3.read = False - notif3.can_delete = True - notif3.subtitle = None - notif3.description = None - notif3.link = None - notif3.action_type = None - notif3.action_payload = None - notif3.data = None - notif3.target_roles = None + notif1 = _make_notif_db(id="notif-1", title="First", created_at=datetime(2024, 1, 1)) + notif2 = _make_notif_db(id="notif-2", title="Second", created_at=datetime(2024, 1, 3)) + notif3 = _make_notif_db(id="notif-3", title="Third", created_at=datetime(2024, 1, 2)) manager._repo.get_multi.return_value = 
[notif1, notif2, notif3] @@ -250,15 +185,19 @@ def test_mark_as_read_success(self, manager): """Test marking notification as read.""" mock_db = Mock() - notif_db = Mock(spec=NotificationDb) - notif_db.id = "notif-123" - notif_db.read = False + notif_db = _make_notif_db(id="notif-123", read=False, title="Test") + + def mock_update(db, db_obj, obj_in): + for k, v in obj_in.items(): + setattr(db_obj, k, v) + return db_obj manager._repo.get.return_value = notif_db + manager._repo.update.side_effect = mock_update - result = manager.mark_as_read(mock_db, notification_id="notif-123") + result = manager.mark_notification_read(db=mock_db, notification_id="notif-123") - assert result is True + assert result is not None assert notif_db.read is True mock_db.commit.assert_called_once() @@ -267,8 +206,9 @@ def test_mark_as_read_not_found(self, manager): mock_db = Mock() manager._repo.get.return_value = None - with pytest.raises(NotificationNotFoundError): - manager.mark_as_read(mock_db, notification_id="nonexistent") + result = manager.mark_notification_read(db=mock_db, notification_id="nonexistent") + + assert result is None # Mark All as Read Tests @@ -276,39 +216,22 @@ def test_mark_all_as_read_success(self, manager, sample_user_info): """Test marking all notifications as read for user.""" mock_db = Mock() - notif1 = Mock(spec=NotificationDb) - notif1.id = "notif-1" - notif1.recipient = "user@example.com" - notif1.read = False - notif1.title = "Test 1" - notif1.type = "info" - notif1.created_at = datetime(2024, 1, 1) - notif1.can_delete = True - notif1.subtitle = None - notif1.description = None - notif1.link = None - notif1.action_type = None - notif1.action_payload = None - notif1.data = None - notif1.target_roles = None - - notif2 = Mock(spec=NotificationDb) - notif2.id = "notif-2" - notif2.recipient = "user@example.com" - notif2.read = False - notif2.title = "Test 2" - notif2.type = "info" - notif2.created_at = datetime(2024, 1, 2) - notif2.can_delete = True - 
notif2.subtitle = None - notif2.description = None - notif2.link = None - notif2.action_type = None - notif2.action_payload = None - notif2.data = None - notif2.target_roles = None + notif1 = _make_notif_db(id="notif-1", recipient="user@example.com", read=False, title="Test 1") + notif2 = _make_notif_db( + id="notif-2", + recipient="user@example.com", + read=False, + title="Test 2", + created_at=datetime(2024, 1, 2), + ) + + def mock_update(db, db_obj, obj_in): + for k, v in obj_in.items(): + setattr(db_obj, k, v) + return db_obj manager._repo.get_multi.return_value = [notif1, notif2] + manager._repo.update.side_effect = mock_update result = manager.mark_all_as_read(mock_db, user_info=sample_user_info) @@ -322,9 +245,7 @@ def test_delete_notification_success(self, manager): """Test deleting a notification.""" mock_db = Mock() - notif_db = Mock(spec=NotificationDb) - notif_db.id = "notif-123" - notif_db.can_delete = True + notif_db = _make_notif_db(id="notif-123", can_delete=True) manager._repo.get.return_value = notif_db manager._repo.remove.return_value = notif_db @@ -339,16 +260,15 @@ def test_delete_notification_not_found(self, manager): mock_db = Mock() manager._repo.get.return_value = None - with pytest.raises(NotificationNotFoundError): - manager.delete_notification(mock_db, notification_id="nonexistent") + result = manager.delete_notification(mock_db, notification_id="nonexistent") + + assert result is False def test_delete_notification_cannot_delete(self, manager): """Test trying to delete a notification marked as non-deletable.""" mock_db = Mock() - notif_db = Mock(spec=NotificationDb) - notif_db.id = "notif-123" - notif_db.can_delete = False + notif_db = _make_notif_db(id="notif-123", can_delete=False) manager._repo.get.return_value = notif_db diff --git a/src/backend/src/tests/unit/test_search_manager.py b/src/backend/src/tests/unit/test_search_manager.py index 756f84f0..8c17bf95 100644 --- a/src/backend/src/tests/unit/test_search_manager.py +++ 
b/src/backend/src/tests/unit/test_search_manager.py @@ -64,8 +64,8 @@ def sample_user(self): return UserInfo( username="testuser", email="test@example.com", - display_name="Test User", - active=True, + user="Test User", + ip="127.0.0.1", groups=["users", "data_consumers"], ) @@ -279,8 +279,8 @@ def test_search_user_without_groups(self, mock_searchable_manager, mock_auth_man user_no_groups = UserInfo( username="nogroups", email="nogroups@example.com", - display_name="No Groups User", - active=True, + user="No Groups User", + ip="127.0.0.1", groups=[], ) diff --git a/src/backend/src/tests/unit/test_security_features_manager.py b/src/backend/src/tests/unit/test_security_features_manager.py index 5f589cfa..1c679c91 100644 --- a/src/backend/src/tests/unit/test_security_features_manager.py +++ b/src/backend/src/tests/unit/test_security_features_manager.py @@ -188,7 +188,7 @@ def test_load_from_yaml_empty_file(self, mock_exists, mock_file, manager): assert result is True assert len(manager.features) == 0 - @patch('builtins.open', new_callable=mock_open, read_data="features:") + @patch('builtins.open', new_callable=mock_open, read_data="features: []") @patch('pathlib.Path.exists') def test_load_from_yaml_empty_features_list(self, mock_exists, mock_file, manager): """Test loading from YAML with empty features list.""" diff --git a/src/backend/src/tests/unit/test_settings_manager.py b/src/backend/src/tests/unit/test_settings_manager.py index 0f4815bc..0faad61b 100644 --- a/src/backend/src/tests/unit/test_settings_manager.py +++ b/src/backend/src/tests/unit/test_settings_manager.py @@ -23,9 +23,29 @@ class TestSettingsManager: @pytest.fixture def mock_settings(self): - """Create mock settings.""" + """Create mock settings with attributes required by get_settings().""" mock = MagicMock(spec=Settings) mock.job_cluster_id = "test-cluster" + mock.WORKSPACE_DEPLOYMENT_PATH = None + mock.DATABRICKS_CATALOG = None + mock.DATABRICKS_SCHEMA = None + mock.DATABRICKS_VOLUME = None + 
mock.APP_AUDIT_LOG_DIR = None + mock.LLM_ENABLED = False + mock.LLM_ENDPOINT = None + mock.LLM_SYSTEM_PROMPT = None + mock.LLM_DISCLAIMER_TEXT = None + mock.DELIVERY_MODE_DIRECT = False + mock.DELIVERY_MODE_INDIRECT = False + mock.DELIVERY_MODE_MANUAL = True + mock.DELIVERY_DIRECT_DRY_RUN = False + mock.GIT_REPO_URL = None + mock.GIT_BRANCH = None + mock.GIT_USERNAME = None + mock.UI_I18N_ENABLED = True + mock.UI_CUSTOM_LOGO_URL = None + mock.UI_ABOUT_CONTENT = None + mock.UI_CUSTOM_CSS = None mock.to_dict.return_value = {"job_cluster_id": "test-cluster"} return mock diff --git a/src/backend/src/utils/startup_tasks.py b/src/backend/src/utils/startup_tasks.py index a39a787b..563fb512 100644 --- a/src/backend/src/utils/startup_tasks.py +++ b/src/backend/src/utils/startup_tasks.py @@ -294,6 +294,11 @@ def initialize_managers(app: FastAPI): # SEARCHABLE_ASSET_MANAGERS.append(metadata_manager) # If it's searchable # logger.info("MetadataManager initialized.") + # Graph Explorer Manager (no DB needed - uses Databricks SQL) + from src.controller.graph_explorer_manager import GraphExplorerManager + app.state.graph_explorer_manager = GraphExplorerManager(settings=settings) + logger.info("GraphExplorerManager initialized.") + logger.info("All managers instantiated and stored in app.state.") # Defer SearchManager initialization until after initial data loading completes diff --git a/src/frontend/package.json b/src/frontend/package.json index 870e64f4..3cba73db 100644 --- a/src/frontend/package.json +++ b/src/frontend/package.json @@ -35,6 +35,7 @@ "@radix-ui/react-toast": "^1.2.11", "@radix-ui/react-tooltip": "^1.1.8", "@tanstack/react-table": "^8.21.2", + "@types/d3-force": "^3.0.10", "@types/node": "^20.11.16", "@types/react": "^18.2.55", "@types/react-dom": "^18.2.19", @@ -47,6 +48,7 @@ "clsx": "^2.1.1", "cmdk": "^1.1.1", "cytoscape": "^3.33.1", + "d3-force": "^3.0.0", "date-fns": "^4.1.0", "framer-motion": "^12.6.3", "i18next": "^25.5.3", @@ -56,6 +58,7 @@ 
"react-cytoscapejs": "^2.0.0", "react-dom": "^18.2.0", "react-dropzone": "^14.3.8", + "react-force-graph-2d": "^1.29.1", "react-hook-form": "^7.56.1", "react-i18next": "^16.0.0", "react-markdown": "^10.1.0", diff --git a/src/frontend/postcss.config.js b/src/frontend/postcss.config.cjs similarity index 100% rename from src/frontend/postcss.config.js rename to src/frontend/postcss.config.cjs diff --git a/src/frontend/src/app.tsx b/src/frontend/src/app.tsx index 2f86da0c..5ac34a25 100644 --- a/src/frontend/src/app.tsx +++ b/src/frontend/src/app.tsx @@ -47,6 +47,7 @@ import ProjectsView from './views/projects'; import AuditTrail from './views/audit-trail'; import WorkflowDesignerView from './views/workflow-designer'; import Workflows from './views/workflows'; +import GraphExplorer from './views/graph-explorer'; export default function App() { const fetchUserInfo = useUserStore((state: any) => state.fetchUserInfo); @@ -118,6 +119,7 @@ export default function App() { } /> } /> } /> + } /> } /> } /> diff --git a/src/frontend/src/components/graph-explorer/diagram-manager.test.tsx b/src/frontend/src/components/graph-explorer/diagram-manager.test.tsx new file mode 100644 index 00000000..b4bdd74e --- /dev/null +++ b/src/frontend/src/components/graph-explorer/diagram-manager.test.tsx @@ -0,0 +1,186 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { render, screen, within } from '@testing-library/react'; +import userEvent from '@testing-library/user-event'; +import { DiagramManager, getDiagrams, saveDiagrams, type SavedDiagram } from './diagram-manager'; +import type { GraphData } from '@/types/graph-explorer'; + +// Mock react-i18next +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string, opts?: Record) => { + // Minimal interpolation for description text + if (opts && key === 'diagrams.saveDescription') { + return `Save current view (${opts.nodeCount} nodes, ${opts.edgeCount} edges)`; + } + return key; + }, + i18n: { 
language: 'en' }, + }), +})); + +const sampleData: GraphData = { + nodes: [ + { id: 'n1', label: 'Alice', type: 'Person', properties: {}, status: 'existing' as const }, + { id: 'n2', label: 'Bob', type: 'Person', properties: {}, status: 'existing' as const }, + ], + edges: [ + { id: 'e1', source: 'n1', target: 'n2', relationshipType: 'KNOWS', properties: {}, status: 'existing' as const }, + ], +}; + +const emptyData: GraphData = { nodes: [], edges: [] }; + +describe('diagram localStorage helpers', () => { + beforeEach(() => { + localStorage.clear(); + }); + + it('getDiagrams returns empty array for unknown table', () => { + expect(getDiagrams('nonexistent')).toEqual([]); + }); + + it('saveDiagrams + getDiagrams round-trips data', () => { + const diagrams: SavedDiagram[] = [ + { + id: 'diag-1', + name: 'Test', + savedAt: '2026-01-01T00:00:00Z', + nodeCount: 2, + edgeCount: 1, + data: sampleData, + }, + ]; + saveDiagrams('test-table', diagrams); + const loaded = getDiagrams('test-table'); + expect(loaded).toEqual(diagrams); + }); + + it('getDiagrams handles corrupt data gracefully', () => { + localStorage.setItem('graph-explorer-diagrams:bad', '{invalid json'); + expect(getDiagrams('bad')).toEqual([]); + }); +}); + +describe('DiagramManager component', () => { + beforeEach(() => { + localStorage.clear(); + }); + + it('shows empty state when no diagrams saved', () => { + render( + , + ); + expect(screen.getByText('diagrams.empty')).toBeInTheDocument(); + }); + + it('disables save button when data is empty', () => { + render( + , + ); + const saveBtn = screen.getByText('diagrams.save').closest('button'); + expect(saveBtn).toBeDisabled(); + }); + + it('opens save dialog and saves a diagram', async () => { + const user = userEvent.setup(); + render( + , + ); + + // Click save button to open dialog + await user.click(screen.getByText('diagrams.save')); + // Type a name + const input = screen.getByPlaceholderText('diagrams.namePlaceholder'); + await user.type(input, 'My 
Diagram'); + // Click save in dialog + const dialogSaveButtons = screen.getAllByText('diagrams.save'); + // The second one is inside the dialog + await user.click(dialogSaveButtons[dialogSaveButtons.length - 1]); + + // Diagram should now appear in the list + expect(screen.getByText('My Diagram')).toBeInTheDocument(); + // Should be saved in localStorage + const saved = getDiagrams('test'); + expect(saved).toHaveLength(1); + expect(saved[0].name).toBe('My Diagram'); + }); + + it('restores a diagram when restore button is clicked', async () => { + const user = userEvent.setup(); + const onRestore = vi.fn(); + + // Pre-save a diagram + saveDiagrams('test', [ + { + id: 'diag-1', + name: 'Saved View', + savedAt: '2026-01-01T00:00:00Z', + nodeCount: 2, + edgeCount: 1, + data: sampleData, + }, + ]); + + render( + , + ); + + // Find the diagram entry and hover to reveal buttons + const entry = screen.getByText('Saved View').closest('div[class*="group"]')!; + // Click the restore button (FolderOpen icon button) + const restoreBtn = within(entry as HTMLElement).getAllByRole('button')[0]; + await user.click(restoreBtn); + + expect(onRestore).toHaveBeenCalledWith(sampleData); + }); + + it('deletes a diagram when delete button is clicked', async () => { + const user = userEvent.setup(); + + saveDiagrams('test', [ + { + id: 'diag-1', + name: 'To Delete', + savedAt: '2026-01-01T00:00:00Z', + nodeCount: 2, + edgeCount: 1, + data: sampleData, + }, + ]); + + render( + , + ); + + expect(screen.getByText('To Delete')).toBeInTheDocument(); + + // Click delete button + const entry = screen.getByText('To Delete').closest('div[class*="group"]')!; + const deleteBtn = within(entry as HTMLElement).getAllByRole('button')[1]; + await user.click(deleteBtn); + + // Should be removed from view and localStorage + expect(screen.queryByText('To Delete')).not.toBeInTheDocument(); + expect(getDiagrams('test')).toHaveLength(0); + }); +}); diff --git 
a/src/frontend/src/components/graph-explorer/diagram-manager.tsx b/src/frontend/src/components/graph-explorer/diagram-manager.tsx new file mode 100644 index 00000000..81d26b0b --- /dev/null +++ b/src/frontend/src/components/graph-explorer/diagram-manager.tsx @@ -0,0 +1,226 @@ +/** + * Diagram Manager for Graph Explorer. + * + * Allows users to save, name, and restore curated subgraph snapshots. + * Diagrams are persisted in localStorage, keyed by table name so each + * dataset has its own diagram collection. + */ + +import React, { useState, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; +import { Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger, +} from '@/components/ui/dialog'; +import { Save, FolderOpen, Trash2, Plus } from 'lucide-react'; +import { cn } from '@/lib/utils'; +import type { GraphData } from '@/types/graph-explorer'; + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +export interface SavedDiagram { + id: string; + name: string; + /** ISO timestamp */ + savedAt: string; + nodeCount: number; + edgeCount: number; + data: GraphData; +} + +// --------------------------------------------------------------------------- +// localStorage helpers +// --------------------------------------------------------------------------- + +const STORAGE_PREFIX = 'graph-explorer-diagrams:'; + +export function getDiagrams(tableName: string): SavedDiagram[] { + try { + const raw = localStorage.getItem(`${STORAGE_PREFIX}${tableName}`); + if (!raw) return []; + return JSON.parse(raw) as SavedDiagram[]; + } catch { + return []; + } +} + +export function saveDiagrams(tableName: string, 
diagrams: SavedDiagram[]): void { + localStorage.setItem(`${STORAGE_PREFIX}${tableName}`, JSON.stringify(diagrams)); +} + +// --------------------------------------------------------------------------- +// Component +// --------------------------------------------------------------------------- + +export interface DiagramManagerProps { + tableName: string; + currentData: GraphData; + onRestoreDiagram: (data: GraphData) => void; + disabled?: boolean; + className?: string; +} + +export function DiagramManager({ + tableName, + currentData, + onRestoreDiagram, + disabled = false, + className, +}: DiagramManagerProps) { + const { t } = useTranslation('graph-explorer'); + + const [diagrams, setDiagrams] = useState(() => getDiagrams(tableName)); + const [saveDialogOpen, setSaveDialogOpen] = useState(false); + const [diagramName, setDiagramName] = useState(''); + + // Reload diagrams when tableName changes + React.useEffect(() => { + setDiagrams(getDiagrams(tableName)); + }, [tableName]); + + const handleSave = useCallback(() => { + const name = diagramName.trim() || `Diagram ${diagrams.length + 1}`; + const newDiagram: SavedDiagram = { + id: `diag-${Date.now()}`, + name, + savedAt: new Date().toISOString(), + nodeCount: currentData.nodes.length, + edgeCount: currentData.edges.length, + data: currentData, + }; + const updated = [newDiagram, ...diagrams]; + setDiagrams(updated); + saveDiagrams(tableName, updated); + setDiagramName(''); + setSaveDialogOpen(false); + }, [diagramName, diagrams, currentData, tableName]); + + const handleDelete = useCallback( + (diagramId: string) => { + const updated = diagrams.filter((d) => d.id !== diagramId); + setDiagrams(updated); + saveDiagrams(tableName, updated); + }, + [diagrams, tableName], + ); + + const handleRestore = useCallback( + (diagram: SavedDiagram) => { + onRestoreDiagram(diagram.data); + }, + [onRestoreDiagram], + ); + + const formatDate = (iso: string) => { + try { + return new Date(iso).toLocaleString(undefined, { + 
month: 'short', + day: 'numeric', + hour: '2-digit', + minute: '2-digit', + }); + } catch { + return iso; + } + }; + + return ( + + +
+ {t('diagrams.title')} + + + + + + + {t('diagrams.saveTitle')} + + {t('diagrams.saveDescription', { + nodeCount: currentData.nodes.length, + edgeCount: currentData.edges.length, + })} + + + setDiagramName(e.target.value)} + onKeyDown={(e) => { + if (e.key === 'Enter') handleSave(); + }} + autoFocus + /> + + + + + + +
+
+ + + {diagrams.length === 0 ? ( +

{t('diagrams.empty')}

+ ) : ( + diagrams.map((d) => ( +
+
+
{d.name}
+
+ {d.nodeCount}n / {d.edgeCount}e · {formatDate(d.savedAt)} +
+
+ + +
+ )) + )} +
+
+ ); +} diff --git a/src/frontend/src/components/graph-explorer/graph-context-menu.test.tsx b/src/frontend/src/components/graph-explorer/graph-context-menu.test.tsx new file mode 100644 index 00000000..ca2542f6 --- /dev/null +++ b/src/frontend/src/components/graph-explorer/graph-context-menu.test.tsx @@ -0,0 +1,273 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { render, screen, fireEvent } from '@testing-library/react'; +import userEvent from '@testing-library/user-event'; +import { + GraphContextMenu, + type GraphContextMenuProps, + type ContextMenuTarget, + type ContextMenuPosition, +} from './graph-context-menu'; + +// Mock react-i18next — return the key as the display text +vi.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + i18n: { language: 'en' }, + }), +})); + +// Mock navigator.clipboard +const mockWriteText = vi.fn().mockResolvedValue(undefined); + +describe('GraphContextMenu', () => { + const defaultPosition: ContextMenuPosition = { x: 100, y: 200 }; + + const defaultCallbacks = { + onClose: vi.fn(), + onExpandNeighbors: vi.fn(), + onExpandByType: vi.fn(), + onCollapseNode: vi.fn(), + onEditNode: vi.fn(), + onDeleteNode: vi.fn(), + onCenterOnNode: vi.fn(), + onEditEdge: vi.fn(), + onDeleteEdge: vi.fn(), + onCreateNode: vi.fn(), + onResetView: vi.fn(), + onFitToScreen: vi.fn(), + }; + + beforeEach(() => { + vi.clearAllMocks(); + mockWriteText.mockResolvedValue(undefined); + // Ensure navigator.clipboard.writeText is our mock + if (!navigator.clipboard || navigator.clipboard.writeText !== mockWriteText) { + Object.defineProperty(global.navigator, 'clipboard', { + value: { writeText: mockWriteText }, + writable: true, + configurable: true, + }); + } + }); + + function renderMenu(overrides: Partial = {}) { + return render( + , + ); + } + + // ------------------------------------------------------------------- + // Visibility + // 
------------------------------------------------------------------- + it('renders nothing when position is null', () => { + const { container } = renderMenu({ position: null }); + expect(container.innerHTML).toBe(''); + }); + + it('renders nothing when target is null', () => { + const { container } = renderMenu({ target: null }); + expect(container.innerHTML).toBe(''); + }); + + // ------------------------------------------------------------------- + // Node context menu + // ------------------------------------------------------------------- + describe('node context menu', () => { + const nodeTarget: ContextMenuTarget = { + type: 'node', + id: 'node-1', + label: 'Alice', + nodeType: 'Person', + isExpanded: false, + connectedEdgeTypes: ['KNOWS', 'WORKS_AT'], + }; + + it('renders node menu items', () => { + renderMenu({ target: nodeTarget }); + // Header + expect(screen.getByText(/Alice/)).toBeInTheDocument(); + // Expand all neighbors + expect(screen.getByText('contextMenu.expandAll')).toBeInTheDocument(); + // Edit, Delete, Center, Copy + expect(screen.getByText('contextMenu.editNode')).toBeInTheDocument(); + expect(screen.getByText('contextMenu.deleteNode')).toBeInTheDocument(); + expect(screen.getByText('contextMenu.centerOnNode')).toBeInTheDocument(); + expect(screen.getByText('contextMenu.copyNodeId')).toBeInTheDocument(); + }); + + it('calls onExpandNeighbors with "both" when expand all is clicked', async () => { + const user = userEvent.setup(); + renderMenu({ target: nodeTarget }); + await user.click(screen.getByText('contextMenu.expandAll')); + expect(defaultCallbacks.onExpandNeighbors).toHaveBeenCalledWith('node-1', 'both'); + expect(defaultCallbacks.onClose).toHaveBeenCalled(); + }); + + it('calls onEditNode when edit is clicked', async () => { + const user = userEvent.setup(); + renderMenu({ target: nodeTarget }); + await user.click(screen.getByText('contextMenu.editNode')); + expect(defaultCallbacks.onEditNode).toHaveBeenCalledWith('node-1'); + 
expect(defaultCallbacks.onClose).toHaveBeenCalled(); + }); + + it('calls onDeleteNode when delete is clicked', async () => { + const user = userEvent.setup(); + renderMenu({ target: nodeTarget }); + await user.click(screen.getByText('contextMenu.deleteNode')); + expect(defaultCallbacks.onDeleteNode).toHaveBeenCalledWith('node-1'); + expect(defaultCallbacks.onClose).toHaveBeenCalled(); + }); + + it('calls onCenterOnNode when center is clicked', async () => { + const user = userEvent.setup(); + renderMenu({ target: nodeTarget }); + await user.click(screen.getByText('contextMenu.centerOnNode')); + expect(defaultCallbacks.onCenterOnNode).toHaveBeenCalledWith('node-1'); + expect(defaultCallbacks.onClose).toHaveBeenCalled(); + }); + + it('copies node ID to clipboard when copy is clicked', async () => { + renderMenu({ target: nodeTarget }); + fireEvent.click(screen.getByText('contextMenu.copyNodeId')); + // navigator.clipboard.writeText returns a Promise, so wait a tick + await vi.waitFor(() => { + expect(mockWriteText).toHaveBeenCalledWith('node-1'); + }); + expect(defaultCallbacks.onClose).toHaveBeenCalled(); + }); + + it('shows collapse instead of expand sub-menu when node is expanded', () => { + const expandedTarget: ContextMenuTarget = { ...nodeTarget, isExpanded: true }; + renderMenu({ target: expandedTarget }); + expect(screen.getByText('contextMenu.collapse')).toBeInTheDocument(); + // Expand all should still be present (re-expand) + expect(screen.getByText('contextMenu.expandAll')).toBeInTheDocument(); + }); + + it('calls onCollapseNode when collapse is clicked', async () => { + const user = userEvent.setup(); + const expandedTarget: ContextMenuTarget = { ...nodeTarget, isExpanded: true }; + renderMenu({ target: expandedTarget }); + await user.click(screen.getByText('contextMenu.collapse')); + expect(defaultCallbacks.onCollapseNode).toHaveBeenCalledWith('node-1'); + expect(defaultCallbacks.onClose).toHaveBeenCalled(); + }); + }); + + // 
------------------------------------------------------------------- + // Edge context menu + // ------------------------------------------------------------------- + describe('edge context menu', () => { + const edgeTarget: ContextMenuTarget = { + type: 'edge', + id: 'edge-1', + relationshipType: 'KNOWS', + }; + + it('renders edge menu items', () => { + renderMenu({ target: edgeTarget }); + expect(screen.getByText('KNOWS')).toBeInTheDocument(); + expect(screen.getByText('contextMenu.editEdge')).toBeInTheDocument(); + expect(screen.getByText('contextMenu.deleteEdge')).toBeInTheDocument(); + expect(screen.getByText('contextMenu.copyEdgeDetails')).toBeInTheDocument(); + }); + + it('calls onEditEdge when edit is clicked', async () => { + const user = userEvent.setup(); + renderMenu({ target: edgeTarget }); + await user.click(screen.getByText('contextMenu.editEdge')); + expect(defaultCallbacks.onEditEdge).toHaveBeenCalledWith('edge-1'); + expect(defaultCallbacks.onClose).toHaveBeenCalled(); + }); + + it('calls onDeleteEdge when delete is clicked', async () => { + const user = userEvent.setup(); + renderMenu({ target: edgeTarget }); + await user.click(screen.getByText('contextMenu.deleteEdge')); + expect(defaultCallbacks.onDeleteEdge).toHaveBeenCalledWith('edge-1'); + expect(defaultCallbacks.onClose).toHaveBeenCalled(); + }); + }); + + // ------------------------------------------------------------------- + // Canvas context menu + // ------------------------------------------------------------------- + describe('canvas context menu', () => { + const canvasTarget: ContextMenuTarget = { type: 'canvas' }; + + it('renders canvas menu items', () => { + renderMenu({ target: canvasTarget }); + expect(screen.getByText('contextMenu.createNode')).toBeInTheDocument(); + expect(screen.getByText('contextMenu.fitToScreen')).toBeInTheDocument(); + expect(screen.getByText('contextMenu.resetView')).toBeInTheDocument(); + }); + + it('calls onCreateNode when create node is clicked', async 
() => { + const user = userEvent.setup(); + renderMenu({ target: canvasTarget }); + await user.click(screen.getByText('contextMenu.createNode')); + expect(defaultCallbacks.onCreateNode).toHaveBeenCalled(); + expect(defaultCallbacks.onClose).toHaveBeenCalled(); + }); + + it('calls onFitToScreen when fit to screen is clicked', async () => { + const user = userEvent.setup(); + renderMenu({ target: canvasTarget }); + await user.click(screen.getByText('contextMenu.fitToScreen')); + expect(defaultCallbacks.onFitToScreen).toHaveBeenCalled(); + expect(defaultCallbacks.onClose).toHaveBeenCalled(); + }); + + it('calls onResetView when reset view is clicked', async () => { + const user = userEvent.setup(); + renderMenu({ target: canvasTarget }); + await user.click(screen.getByText('contextMenu.resetView')); + expect(defaultCallbacks.onResetView).toHaveBeenCalled(); + expect(defaultCallbacks.onClose).toHaveBeenCalled(); + }); + }); + + // ------------------------------------------------------------------- + // Dismiss behavior + // ------------------------------------------------------------------- + describe('dismiss behavior', () => { + it('calls onClose when Escape is pressed', async () => { + renderMenu({ target: { type: 'canvas' } }); + // Wait for the event listener to be attached (setTimeout(0) in the component) + await vi.waitFor(() => { + fireEvent.keyDown(document, { key: 'Escape' }); + expect(defaultCallbacks.onClose).toHaveBeenCalled(); + }); + }); + + it('calls onClose when clicking outside the menu', async () => { + renderMenu({ target: { type: 'canvas' } }); + // Wait for the event listener to be attached + await vi.waitFor(() => { + fireEvent.mouseDown(document.body); + expect(defaultCallbacks.onClose).toHaveBeenCalled(); + }); + }); + }); + + // ------------------------------------------------------------------- + // Position adjustment + // ------------------------------------------------------------------- + describe('position adjustment', () => { + 
it('renders at the provided position', () => { + const { container } = renderMenu({ + target: { type: 'canvas' }, + position: { x: 150, y: 250 }, + }); + const menuEl = container.firstElementChild as HTMLElement; + expect(menuEl.style.left).toBe('150px'); + expect(menuEl.style.top).toBe('250px'); + }); + }); +}); diff --git a/src/frontend/src/components/graph-explorer/graph-context-menu.tsx b/src/frontend/src/components/graph-explorer/graph-context-menu.tsx new file mode 100644 index 00000000..2b5c354f --- /dev/null +++ b/src/frontend/src/components/graph-explorer/graph-context-menu.tsx @@ -0,0 +1,409 @@ +/** + * Context menu for Graph Explorer. + * + * Renders a positioned menu overlay triggered by right-click on nodes, edges, + * or the canvas background. Since the graph is rendered on a , Radix + * context-menu (DOM trigger-based) can't be used directly, so we build a custom + * positioned menu with Shadcn-consistent styling. + */ + +import React, { useEffect, useRef, useCallback } from 'react'; +import { cn } from '@/lib/utils'; +import { + Expand, + ArrowUpRight, + ArrowDownLeft, + Shrink, + Pencil, + Trash2, + Copy, + Crosshair, + Maximize, + Plus, + ChevronRight, +} from 'lucide-react'; +import { useTranslation } from 'react-i18next'; + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +export type ContextMenuTargetType = 'node' | 'edge' | 'canvas'; + +export interface ContextMenuTarget { + type: ContextMenuTargetType; + /** Node or edge ID */ + id?: string; + /** Node label for display */ + label?: string; + /** Node type for display */ + nodeType?: string; + /** Edge relationship type for display */ + relationshipType?: string; + /** Whether this node is currently expanded (has neighbor data loaded) */ + isExpanded?: boolean; + /** Edge types connected to this node (for "Expand by Type" submenu) */ + connectedEdgeTypes?: string[]; +} + 
+export interface ContextMenuPosition { + x: number; + y: number; +} + +interface MenuItemProps { + icon?: React.ReactNode; + label: string; + onClick: () => void; + variant?: 'default' | 'destructive'; + disabled?: boolean; +} + +interface SubMenuProps { + icon?: React.ReactNode; + label: string; + children: React.ReactNode; +} + +// --------------------------------------------------------------------------- +// Menu item and sub-menu components +// --------------------------------------------------------------------------- + +function MenuItem({ icon, label, onClick, variant = 'default', disabled = false }: MenuItemProps) { + return ( + + ); +} + +function SubMenu({ icon, label, children }: SubMenuProps) { + const [isOpen, setIsOpen] = React.useState(false); + const subMenuRef = useRef(null); + + return ( +
setIsOpen(true)} + onMouseLeave={() => setIsOpen(false)} + > +
+ {icon && {icon}} + {label} + +
+ + {isOpen && ( +
+ {children} +
+ )} +
+ ); +} + +function MenuSeparator() { + return
; +} + +// --------------------------------------------------------------------------- +// Main context menu +// --------------------------------------------------------------------------- + +export interface GraphContextMenuProps { + position: ContextMenuPosition | null; + target: ContextMenuTarget | null; + onClose: () => void; + /** Expand neighbors of a node in a given direction */ + onExpandNeighbors?: (nodeId: string, direction: 'outgoing' | 'incoming' | 'both') => void; + /** Expand neighbors filtered by a specific edge type */ + onExpandByType?: (nodeId: string, edgeType: string) => void; + /** Collapse a previously expanded node */ + onCollapseNode?: (nodeId: string) => void; + /** Edit a node */ + onEditNode?: (nodeId: string) => void; + /** Delete a node */ + onDeleteNode?: (nodeId: string) => void; + /** Center the view on a node */ + onCenterOnNode?: (nodeId: string) => void; + /** Edit an edge */ + onEditEdge?: (edgeId: string) => void; + /** Delete an edge */ + onDeleteEdge?: (edgeId: string) => void; + /** Create a new node (canvas action) */ + onCreateNode?: () => void; + /** Reset view / fit to screen */ + onResetView?: () => void; + /** Fit graph to screen */ + onFitToScreen?: () => void; +} + +export function GraphContextMenu({ + position, + target, + onClose, + onExpandNeighbors, + onExpandByType, + onCollapseNode, + onEditNode, + onDeleteNode, + onCenterOnNode, + onEditEdge, + onDeleteEdge, + onCreateNode, + onResetView, + onFitToScreen, +}: GraphContextMenuProps) { + const menuRef = useRef(null); + const { t } = useTranslation('graph-explorer'); + + // Close on click outside + useEffect(() => { + if (!position) return; + + const handleClickOutside = (e: MouseEvent) => { + if (menuRef.current && !menuRef.current.contains(e.target as Node)) { + onClose(); + } + }; + + const handleKeyDown = (e: KeyboardEvent) => { + if (e.key === 'Escape') onClose(); + }; + + // Delay attaching to prevent the right-click from immediately closing + const timer = 
setTimeout(() => { + document.addEventListener('mousedown', handleClickOutside); + document.addEventListener('keydown', handleKeyDown); + }, 0); + + return () => { + clearTimeout(timer); + document.removeEventListener('mousedown', handleClickOutside); + document.removeEventListener('keydown', handleKeyDown); + }; + }, [position, onClose]); + + // Adjust position to stay on-screen + const adjustedPosition = React.useMemo(() => { + if (!position) return null; + const menuWidth = 220; + const menuHeight = 300; + const padding = 8; + + let { x, y } = position; + if (typeof window !== 'undefined') { + if (x + menuWidth + padding > window.innerWidth) { + x = window.innerWidth - menuWidth - padding; + } + if (y + menuHeight + padding > window.innerHeight) { + y = window.innerHeight - menuHeight - padding; + } + } + return { x: Math.max(padding, x), y: Math.max(padding, y) }; + }, [position]); + + const handleAction = useCallback( + (action: () => void) => { + action(); + onClose(); + }, + [onClose], + ); + + if (!position || !target || !adjustedPosition) return null; + + return ( +
e.preventDefault()} + > + {/* --------- Node context menu --------- */} + {target.type === 'node' && target.id && ( + <> + {/* Header */} +
+ {target.label || target.id} + {target.nodeType && ( + ({target.nodeType}) + )} +
+ + + {/* Expand / Collapse */} + {!target.isExpanded ? ( + <> + } + label={t('contextMenu.expandAll')} + onClick={() => handleAction(() => onExpandNeighbors?.(target.id!, 'both'))} + /> + } + label={t('contextMenu.expandByType')} + > + } + label={t('contextMenu.expandOutgoing')} + onClick={() => handleAction(() => onExpandNeighbors?.(target.id!, 'outgoing'))} + /> + } + label={t('contextMenu.expandIncoming')} + onClick={() => handleAction(() => onExpandNeighbors?.(target.id!, 'incoming'))} + /> + {target.connectedEdgeTypes && target.connectedEdgeTypes.length > 0 && ( + <> + + {target.connectedEdgeTypes.map((edgeType) => ( + handleAction(() => onExpandByType?.(target.id!, edgeType))} + /> + ))} + + )} + + + ) : ( + <> + } + label={t('contextMenu.expandAll')} + onClick={() => handleAction(() => onExpandNeighbors?.(target.id!, 'both'))} + /> + } + label={t('contextMenu.collapse')} + onClick={() => handleAction(() => onCollapseNode?.(target.id!))} + /> + + )} + + + + {/* Actions */} + } + label={t('contextMenu.centerOnNode')} + onClick={() => handleAction(() => onCenterOnNode?.(target.id!))} + /> + } + label={t('contextMenu.editNode')} + onClick={() => handleAction(() => onEditNode?.(target.id!))} + /> + } + label={t('contextMenu.copyNodeId')} + onClick={() => + handleAction(() => { + navigator.clipboard.writeText(target.id!); + }) + } + /> + + + + } + label={t('contextMenu.deleteNode')} + onClick={() => handleAction(() => onDeleteNode?.(target.id!))} + variant="destructive" + /> + + )} + + {/* --------- Edge context menu --------- */} + {target.type === 'edge' && target.id && ( + <> +
+ {target.relationshipType || target.id} +
+ + + } + label={t('contextMenu.editEdge')} + onClick={() => handleAction(() => onEditEdge?.(target.id!))} + /> + } + label={t('contextMenu.copyEdgeDetails')} + onClick={() => + handleAction(() => { + const text = `${target.relationshipType || ''} (${target.id})`; + navigator.clipboard.writeText(text); + }) + } + /> + + + + } + label={t('contextMenu.deleteEdge')} + onClick={() => handleAction(() => onDeleteEdge?.(target.id!))} + variant="destructive" + /> + + )} + + {/* --------- Canvas context menu --------- */} + {target.type === 'canvas' && ( + <> + } + label={t('contextMenu.createNode')} + onClick={() => handleAction(() => onCreateNode?.())} + /> + + + + } + label={t('contextMenu.fitToScreen')} + onClick={() => handleAction(() => onFitToScreen?.())} + /> + } + label={t('contextMenu.resetView')} + onClick={() => handleAction(() => onResetView?.())} + /> + + )} +
+ ); +} diff --git a/src/frontend/src/components/graph-explorer/graph-controls.tsx b/src/frontend/src/components/graph-explorer/graph-controls.tsx new file mode 100644 index 00000000..1616529e --- /dev/null +++ b/src/frontend/src/components/graph-explorer/graph-controls.tsx @@ -0,0 +1,392 @@ +import { useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'; +import { Switch } from '@/components/ui/switch'; +import { Label } from '@/components/ui/label'; +import { Badge } from '@/components/ui/badge'; +import { Button } from '@/components/ui/button'; +import { Separator } from '@/components/ui/separator'; +import { ScrollArea } from '@/components/ui/scroll-area'; +import { RefreshCw, Eye, Tag } from 'lucide-react'; +import { + getUniqueNodeTypes, + getUniqueRelationshipTypes, + getColorForType, + type GraphData, +} from '@/types/graph-explorer'; +import { cn } from '@/lib/utils'; + +interface GraphControlsProps { + showProposed: boolean; + onToggleProposed: (show: boolean) => void; + selectedNodeTypes: string[]; + onNodeTypeChange: (types: string[]) => void; + selectedRelationshipTypes: string[]; + onRelationshipTypeChange: (types: string[]) => void; + showNodeLabels: boolean; + onToggleNodeLabels: (show: boolean) => void; + showEdgeLabels: boolean; + onToggleEdgeLabels: (show: boolean) => void; + edgeLength: number; + onEdgeLengthChange: (length: number) => void; + edgeOpacity?: number; + onEdgeOpacityChange?: (opacity: number) => void; + nodeSize: number; + onNodeSizeChange: (size: number) => void; + onResetView: () => void; + graphData: GraphData; + stats: { totalNodes: number; totalEdges: number; newNodes: number; newEdges: number }; +} + +export default function GraphControls({ + showProposed, + onToggleProposed, + selectedNodeTypes, + onNodeTypeChange, + selectedRelationshipTypes, + onRelationshipTypeChange, + showNodeLabels, + onToggleNodeLabels, + 
showEdgeLabels, + onToggleEdgeLabels, + edgeLength, + onEdgeLengthChange, + edgeOpacity, + onEdgeOpacityChange, + nodeSize, + onNodeSizeChange, + onResetView, + graphData, + stats, +}: GraphControlsProps) { + const { t } = useTranslation('graph-explorer'); + const nodeTypes = useMemo(() => getUniqueNodeTypes(graphData), [graphData]); + const relationshipTypes = useMemo(() => getUniqueRelationshipTypes(graphData), [graphData]); + + // Color map for node type legend + const isDarkMode = typeof document !== 'undefined' && document.documentElement.classList.contains('dark'); + const nodeTypeColors = useMemo(() => { + const colorMap = new Map(); + nodeTypes.forEach((type) => { + colorMap.set(type, getColorForType(type, isDarkMode)); + }); + return colorMap; + }, [nodeTypes, isDarkMode]); + + // "Show all" = empty array. Clicking a type when showing all deselects it + // (shows all except it). Select All resets to empty array. + const handleNodeTypeToggle = (type: string) => { + if (selectedNodeTypes.length === 0) { + // "Show all" mode — deselect this one (show all except it) + onNodeTypeChange(nodeTypes.filter((t) => t !== type)); + } else if (selectedNodeTypes.includes(type)) { + onNodeTypeChange(selectedNodeTypes.filter((t) => t !== type)); + } else { + onNodeTypeChange([...selectedNodeTypes, type]); + } + }; + + const handleRelationshipTypeToggle = (type: string) => { + if (selectedRelationshipTypes.length === 0) { + onRelationshipTypeChange(relationshipTypes.filter((t) => t !== type)); + } else if (selectedRelationshipTypes.includes(type)) { + onRelationshipTypeChange(selectedRelationshipTypes.filter((t) => t !== type)); + } else { + onRelationshipTypeChange([...selectedRelationshipTypes, type]); + } + }; + + const allNodesSelected = selectedNodeTypes.length === 0; + const allRelsSelected = selectedRelationshipTypes.length === 0; + + return ( + + + {t('controls.title')} + + + {/* Graph Statistics */} +
+
+
{t('controls.totalNodes')}
+
+ {stats.totalNodes} +
+
+
+
{t('controls.totalEdges')}
+
+ {stats.totalEdges} +
+
+
+
{t('controls.newNodes')}
+
+ {stats.newNodes} +
+
+
+
{t('controls.newEdges')}
+
+ {stats.newEdges} +
+
+
+ + + + {/* Visibility Toggles */} +
+
+ + +
+
+ + +
+
+ + +
+
+ + + + {/* Layout Sliders */} +
+
+
+ + {edgeLength} +
+ onEdgeLengthChange(Number(e.target.value))} + className="w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer dark:bg-gray-700" + style={{ + background: `linear-gradient(to right, hsl(var(--primary)) 0%, hsl(var(--primary)) ${ + ((edgeLength - 30) / (1000 - 30)) * 100 + }%, rgb(229 231 235) ${((edgeLength - 30) / (1000 - 30)) * 100}%, rgb(229 231 235) 100%)`, + }} + /> +
+ + {/* Edge Opacity Slider */} + {edgeOpacity != null && onEdgeOpacityChange && ( +
+
+ + {Math.round(edgeOpacity * 100)}% +
+ onEdgeOpacityChange(Number(e.target.value))} + className="w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer dark:bg-gray-700" + style={{ + background: `linear-gradient(to right, hsl(var(--primary)) 0%, hsl(var(--primary)) ${ + ((edgeOpacity - 0.05) / (1 - 0.05)) * 100 + }%, rgb(229 231 235) ${((edgeOpacity - 0.05) / (1 - 0.05)) * 100}%, rgb(229 231 235) 100%)`, + }} + /> +
+ )} + +
+
+ + {nodeSize} +
+ onNodeSizeChange(Number(e.target.value))} + className="w-full h-2 bg-gray-200 rounded-lg appearance-none cursor-pointer dark:bg-gray-700" + style={{ + background: `linear-gradient(to right, hsl(var(--primary)) 0%, hsl(var(--primary)) ${ + ((nodeSize - 3) / (15 - 3)) * 100 + }%, rgb(229 231 235) ${((nodeSize - 3) / (15 - 3)) * 100}%, rgb(229 231 235) 100%)`, + }} + /> +
+
+ + + + {/* Node Type Filter */} +
+
+ + {!allNodesSelected && ( + + )} +
+ +
+ {nodeTypes.map((type) => { + const isSelected = selectedNodeTypes.length === 0 || selectedNodeTypes.includes(type); + const color = nodeTypeColors.get(type) || getColorForType(type, isDarkMode); + return ( + handleNodeTypeToggle(type)} + > + {type} + + ); + })} +
+
+
+ + + + {/* Relationship Type Filter */} +
+
+ + {!allRelsSelected && ( + + )} +
+ +
+ {relationshipTypes.map((type) => { + const isSelected = + selectedRelationshipTypes.length === 0 || selectedRelationshipTypes.includes(type); + const color = getColorForType(type, isDarkMode); + return ( + handleRelationshipTypeToggle(type)} + > + {type} + + ); + })} +
+
+
+ + + + {/* Legend */} +
+ + +
+ {/* Show only selected types (or all if none selected) */} + {(selectedNodeTypes.length === 0 ? nodeTypes : selectedNodeTypes).map((type) => { + const color = nodeTypeColors.get(type) || getColorForType(type, isDarkMode); + return ( +
+
+ {type} +
+ ); + })} + {nodeTypes.length > 10 && selectedNodeTypes.length === 0 && ( +
+ {t('controls.moreTypes', { count: nodeTypes.length - 10 })} +
+ )} + + {/* Proposed indicator */} + {showProposed && ( + <> + +
+
+ Proposed New +
+
+
+ Modified +
+ + )} +
+ +
+ + + + {/* Reset View Button */} + + + + ); +} diff --git a/src/frontend/src/components/graph-explorer/graph-query-panel.tsx b/src/frontend/src/components/graph-explorer/graph-query-panel.tsx new file mode 100644 index 00000000..8ddddbd2 --- /dev/null +++ b/src/frontend/src/components/graph-explorer/graph-query-panel.tsx @@ -0,0 +1,669 @@ +/** + * Graph Query Panel + * + * Collapsible panel for running Cypher/Gremlin queries against the graph. + * Queries are translated to SQL via a backend LLM endpoint, executed on + * Databricks, and the resulting nodes/edges are applied to the graph. + */ + +import React, { useState, useCallback, useMemo, useEffect } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + Collapsible, + CollapsibleTrigger, + CollapsibleContent, +} from '@/components/ui/collapsible'; +import { Button } from '@/components/ui/button'; +import { Badge } from '@/components/ui/badge'; +import { Textarea } from '@/components/ui/textarea'; +import { Alert, AlertDescription, AlertTitle } from '@/components/ui/alert'; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/components/ui/select'; +import { + Tooltip, + TooltipContent, + TooltipProvider, + TooltipTrigger, +} from '@/components/ui/tooltip'; +import { + Sparkles, + Play, + Code2, + ChevronDown, + ChevronUp, + Copy, + X, + Loader2, + AlertCircle, + CheckCircle2, + Info, + AlertTriangle, + Database, +} from 'lucide-react'; +import type { GraphNode, GraphEdge, GraphData } from '@/types/graph-explorer'; + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +type QueryLanguage = 'natural' | 'cypher' | 'gremlin'; + +interface GraphQueryResult { + success: boolean; + nodes: GraphNode[]; + edges: GraphEdge[]; + sql: string; + language: QueryLanguage; + originalQuery: string; + rawRowCount?: number; + hasEdgeColumns?: boolean; 
+ message?: string; + metadata?: { + source: string; + timestamp: string; + duration: string; + translationModel: string; + graphSchema?: string; + }; +} + +interface LlmConfig { + enabled: boolean; + defaultModel: string; + maxTokens: number; + provider: string; +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +function buildExampleQueries( + nodes: GraphNode[], + edges: GraphEdge[], +): Record { + const nodeTypes = [...new Set(nodes.map((n) => n.type).filter(Boolean))]; + const relTypes = [...new Set(edges.map((e) => e.relationshipType).filter(Boolean))]; + + const t1 = nodeTypes[0]; + const t2 = nodeTypes[1]; + const t3 = nodeTypes[2]; + const r1 = relTypes[0]; + const r2 = relTypes[1]; + + const unhelpfulValues = new Set(['unknown', 'null', 'undefined', 'n/a', 'none', '']); + let samplePropKey: string | null = null; + let samplePropVal: string | null = null; + + if (t1) { + const candidates = nodes.filter( + (n) => n.type === t1 && n.properties && Object.keys(n.properties).length > 0, + ); + for (const node of candidates) { + for (const [key, val] of Object.entries(node.properties)) { + if (val != null && !unhelpfulValues.has(String(val).toLowerCase().trim())) { + samplePropKey = key; + samplePropVal = String(val); + break; + } + } + if (samplePropKey) break; + } + } + + // Cypher + const cypher: { label: string; query: string }[] = []; + if (t1) cypher.push({ label: `All ${t1} nodes`, query: `MATCH (n:${t1}) RETURN n` }); + if (t1 && r1 && t2) cypher.push({ label: `${t1} -[${r1}]-> ${t2}`, query: `MATCH (a:${t1})-[:${r1}]->(b:${t2}) RETURN a, b` }); + else if (t1 && r1) cypher.push({ label: `${t1} via ${r1}`, query: `MATCH (a:${t1})-[:${r1}]->(b) RETURN a, b` }); + if (t1 && r1 && t2 && r2 && t3) cypher.push({ label: `2-hop: ${t1} → ${t2} → ${t3}`, query: `MATCH (a:${t1})-[:${r1}]->(b:${t2})-[:${r2}]->(c:${t3}) RETURN a, b, c` }); + 
else if (t1 && r1 && r2) cypher.push({ label: '2-hop path', query: `MATCH (a:${t1})-[:${r1}]->(b)-[:${r2}]->(c) RETURN a, b, c` }); + if (t1 && samplePropKey && samplePropVal) cypher.push({ label: `Filter by ${samplePropKey}`, query: `MATCH (n:${t1}) WHERE n.${samplePropKey} = '${samplePropVal}' RETURN n` }); + cypher.push({ label: 'All relationships', query: 'MATCH (a)-[r]->(b) RETURN a, r, b LIMIT 100' }); + + // Gremlin + const gremlin: { label: string; query: string }[] = []; + if (t1) gremlin.push({ label: `All ${t1} nodes`, query: `g.V().hasLabel('${t1}')` }); + if (t1 && r1 && t2) gremlin.push({ label: `${t1} -[${r1}]-> ${t2}`, query: `g.V().hasLabel('${t1}').out('${r1}').hasLabel('${t2}')` }); + else if (t1 && r1) gremlin.push({ label: `${t1} via ${r1}`, query: `g.V().hasLabel('${t1}').out('${r1}')` }); + if (t1 && r1 && t2 && r2 && t3) gremlin.push({ label: `2-hop: ${t1} → ${t2} → ${t3}`, query: `g.V().hasLabel('${t1}').out('${r1}').hasLabel('${t2}').out('${r2}').hasLabel('${t3}')` }); + else if (t1 && r1 && r2) gremlin.push({ label: '2-hop path', query: `g.V().hasLabel('${t1}').out('${r1}').out('${r2}')` }); + if (t1 && samplePropKey && samplePropVal) gremlin.push({ label: `Filter by ${samplePropKey}`, query: `g.V().hasLabel('${t1}').has('${samplePropKey}', '${samplePropVal}')` }); + gremlin.push({ label: 'All edges', query: 'g.E().limit(100)' }); + + return { natural: [], cypher, gremlin }; +} + +// --------------------------------------------------------------------------- +// Component +// --------------------------------------------------------------------------- + +interface GraphQueryPanelProps { + /** Called when query results should be applied to the graph */ + onApplyResults: (nodes: GraphNode[], edges: GraphEdge[]) => void; + /** Called when the user clears an active query to restore the full dataset */ + onClearQuery: () => void; + /** Full graph data — used to generate contextual example queries */ + graphData?: GraphData; + /** Current table 
name for context */ + tableName?: string; + /** Whether the panel is initially expanded */ + defaultExpanded?: boolean; +} + +const GraphQueryPanel: React.FC = ({ + onApplyResults, + onClearQuery, + graphData, + tableName, + defaultExpanded = false, +}) => { + const { t } = useTranslation(['graph-explorer']); + + // Example queries derived from current graph data + const exampleQueries = useMemo( + () => buildExampleQueries(graphData?.nodes ?? [], graphData?.edges ?? []), + [graphData], + ); + + // Panel state + const [expanded, setExpanded] = useState(defaultExpanded); + + // Query state + const [language, setLanguage] = useState('natural'); + const [query, setQuery] = useState(''); + const [isLoading, setIsLoading] = useState(false); + + // LLM availability + const [llmEnabled, setLlmEnabled] = useState(null); + const [llmModel, setLlmModel] = useState(''); + + // Fetch LLM config on mount + useEffect(() => { + fetch('/api/graph-explorer/llm-config') + .then((res) => (res.ok ? res.json() : null)) + .then((cfg: LlmConfig | null) => { + if (cfg) { + setLlmEnabled(cfg.enabled); + setLlmModel(cfg.defaultModel); + } else { + setLlmEnabled(false); + } + }) + .catch(() => setLlmEnabled(false)); + }, []); + + // Result state + const [result, setResult] = useState(null); + const [showSql, setShowSql] = useState(false); + const [error, setError] = useState(null); + const [queryApplied, setQueryApplied] = useState(false); + + // --- Handlers --- + + const handleLanguageChange = useCallback((value: string) => { + setLanguage(value as QueryLanguage); + setResult(null); + setError(null); + }, []); + + const handleExampleSelect = useCallback((value: string) => { + if (value) { + setQuery(value); + setResult(null); + setError(null); + } + }, []); + + const handleRunQuery = useCallback(async () => { + if (!query.trim()) return; + + setIsLoading(true); + setError(null); + setResult(null); + + try { + const response = await fetch('/api/graph-explorer/query', { + method: 'POST', + 
headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ query, language, tableName }), + }); + + const data: GraphQueryResult = await response.json(); + setResult(data); + + if (!data.success) { + setError(data.message || 'Query failed'); + } else if (data.nodes.length > 0 || data.edges.length > 0) { + let nodes = data.nodes; + let edges = data.edges; + + // Vertex-only / edge-only heuristics only apply to structured query languages + if (language !== 'natural') { + const q = query.trim().toLowerCase(); + const isVertexOnly = + language === 'gremlin' + ? /^g\.v\(\)/.test(q) && !/(\.out\(|\.in\(|\.both\(|\.oute\(|\.ine\(|\.bothe\(|g\.e\()/.test(q) + : language === 'cypher' + ? !/-\[/.test(query) && !/\]-/.test(query) + : false; + const isEdgeOnly = language === 'gremlin' ? /^g\.e\(\)/.test(q) : false; + + if (isEdgeOnly) nodes = []; + if (isVertexOnly) edges = []; + + // Filter nodes for vertex-only queries + if (isVertexOnly && nodes.length > 0) { + const labelSet = new Set(); + if (language === 'cypher') { + const matches = query.matchAll(/\(\w*:(\w+)/g); + for (const m of matches) labelSet.add(m[1]); + } else if (language === 'gremlin') { + const matches = q.matchAll(/\.haslabel\('([^']+)'\)/g); + for (const m of matches) labelSet.add(m[1]); + } + if (labelSet.size > 0) { + const lowerSet = new Set([...labelSet].map((l) => l.toLowerCase())); + nodes = nodes.filter((n) => lowerSet.has(n.type.toLowerCase())); + } + } + } + + onApplyResults(nodes, edges); + setQueryApplied(true); + } else if (data.success) { + onApplyResults([], []); + setQueryApplied(true); + } + } catch (err) { + setError(err instanceof Error ? 
err.message : 'An unexpected error occurred'); + } finally { + setIsLoading(false); + } + }, [query, language, tableName, onApplyResults]); + + const handleRemoveLimit = useCallback(async () => { + if (!result?.sql) return; + const sqlWithoutLimit = result.sql.replace(/\s+LIMIT\s+\d+\s*$/i, ''); + if (sqlWithoutLimit === result.sql) return; + + setIsLoading(true); + setError(null); + + try { + const response = await fetch('/api/graph-explorer/query', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ query, language, tableName, sql: sqlWithoutLimit }), + }); + const data: GraphQueryResult = await response.json(); + setResult(data); + + if (!data.success) { + setError(data.message || 'Query failed'); + } else if (data.nodes.length > 0 || data.edges.length > 0) { + onApplyResults(data.nodes, data.edges); + setQueryApplied(true); + } else { + onApplyResults([], []); + setQueryApplied(true); + } + } catch (err) { + setError(err instanceof Error ? err.message : 'An unexpected error occurred'); + } finally { + setIsLoading(false); + } + }, [result, query, language, tableName, onApplyResults]); + + const handleCopySql = useCallback(() => { + if (result?.sql) { + navigator.clipboard.writeText(result.sql); + } + }, [result]); + + const handleClearQuery = useCallback(() => { + setQuery(''); + setResult(null); + setError(null); + setShowSql(false); + setQueryApplied(false); + onClearQuery(); + }, [onClearQuery]); + + const handleKeyDown = useCallback( + (e: React.KeyboardEvent) => { + if ((e.ctrlKey || e.metaKey) && e.key === 'Enter') { + e.preventDefault(); + handleRunQuery(); + } + }, + [handleRunQuery], + ); + + const placeholderText = useMemo(() => { + if (language === 'natural') { + return t('queryPanel.naturalPlaceholder'); + } + const examples = exampleQueries[language]; + const hop = examples.length > 1 ? examples[1] : examples[0]; + if (hop) return hop.query; + return language === 'cypher' + ? 
'MATCH (a)-[r]->(b) RETURN a, r, b LIMIT 100' + : 'g.E().limit(100)'; + }, [exampleQueries, language, t]); + + // SQL LIMIT detection for warning + const limitMatch = result?.sql?.match(/LIMIT\s+(\d+)/i); + const sqlLimit = limitMatch ? parseInt(limitMatch[1], 10) : null; + const isLimitHit = + sqlLimit != null && result?.rawRowCount != null && result.rawRowCount >= sqlLimit; + + return ( + +
+ {/* Header */} +
+ + + + {queryApplied && ( + + )} +
+ + {/* Collapsible body */} + +
+ {/* LLM not configured */} + {llmEnabled === false && ( + + + {t('queryPanel.llmNotConfiguredTitle')} + + {t('queryPanel.llmNotConfiguredDescription')} + + + )} + + {/* Loading LLM config */} + {llmEnabled === null && ( +
+ + {t('queryPanel.checkingConfig')} +
+ )} + + {/* Query controls */} + {llmEnabled && ( + <> + {/* Language toggle + examples + run */} +
+ {/* Language selector */} +
+ + + +
+ + {/* Example queries dropdown — only for Cypher/Gremlin */} + {language !== 'natural' && ( + + )} + + {/* Run button */} + + + + {isLoading ? t('queryPanel.translatingViaLlm') : '⌘+Enter'} + +
+ + {/* Query input */} +