diff --git a/packages/jumpstarter-cli/jumpstarter_cli/get.py b/packages/jumpstarter-cli/jumpstarter_cli/get.py index f7d1a041..869dfd06 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/get.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/get.py @@ -21,8 +21,8 @@ def get(): @opt_output_all @opt_comma_separated( "with", - {"leases", "online"}, - help_text="Include fields: leases, online (comma-separated or repeated)" + {"leases", "online", "status"}, + help_text="Include fields: leases, online, status (comma-separated or repeated)", ) @handle_exceptions_with_reauthentication(relogin_client) def get_exporters(config, selector: str | None, output: OutputType, with_options: list[str]): @@ -32,7 +32,10 @@ def get_exporters(config, selector: str | None, output: OutputType, with_options include_leases = "leases" in with_options include_online = "online" in with_options - exporters = config.list_exporters(filter=selector, include_leases=include_leases, include_online=include_online) + include_status = "status" in with_options + exporters = config.list_exporters( + filter=selector, include_leases=include_leases, include_online=include_online, include_status=include_status + ) model_print(exporters, output) diff --git a/packages/jumpstarter/jumpstarter/client/core.py b/packages/jumpstarter/jumpstarter/client/core.py index 3befe92e..2f6491db 100644 --- a/packages/jumpstarter/jumpstarter/client/core.py +++ b/packages/jumpstarter/jumpstarter/client/core.py @@ -14,7 +14,7 @@ from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc, router_pb2_grpc from rich.logging import RichHandler -from jumpstarter.common import Metadata +from jumpstarter.common import ExporterStatus, Metadata from jumpstarter.common.exceptions import JumpstarterException from jumpstarter.common.resources import ResourceMetadata from jumpstarter.common.serde import decode_value, encode_value @@ -48,6 +48,12 @@ class DriverInvalidArgument(DriverError, ValueError): """ +class ExporterNotReady(DriverError): + """ + Raised when the exporter is not ready to accept driver calls + """ + + @dataclass(kw_only=True) class AsyncDriverClient( Metadata, @@ -76,9 +82,28 @@ def __post_init__(self): handler = RichHandler() self.logger.addHandler(handler) + async def check_exporter_status(self): + """Check if the exporter is ready to accept driver calls""" + try: + response = await self.stub.GetStatus(jumpstarter_pb2.GetStatusRequest()) + status = ExporterStatus.from_proto(response.status) + + if status != ExporterStatus.LEASE_READY: + raise ExporterNotReady(f"Exporter status is {status}: {response.status_message}") + + except AioRpcError as e: + # If GetStatus is not implemented, assume ready for backward compatibility + if e.code() == StatusCode.UNIMPLEMENTED: + self.logger.debug("GetStatus not implemented, assuming exporter is ready") + return + raise DriverError(f"Failed to check exporter status: {e.details()}") from e + async def call_async(self, method, *args): """Make DriverCall by method name and arguments""" + # Check exporter status before making the call + await self.check_exporter_status() + request = jumpstarter_pb2.DriverCallRequest( uuid=str(self.uuid), method=method, @@ -105,6 +130,9 @@ async def call_async(self, method, *args): async def streamingcall_async(self, method, *args): """Make StreamingDriverCall by method name and arguments""" + # Check exporter status before making the call + await self.check_exporter_status() + request = jumpstarter_pb2.StreamingDriverCallRequest( uuid=str(self.uuid), method=method, 
diff --git a/packages/jumpstarter/jumpstarter/client/grpc.py b/packages/jumpstarter/jumpstarter/client/grpc.py index fc6f5252..295e4b97 100644 --- a/packages/jumpstarter/jumpstarter/client/grpc.py +++ b/packages/jumpstarter/jumpstarter/client/grpc.py @@ -13,6 +13,7 @@ from jumpstarter_protocol import client_pb2, client_pb2_grpc, jumpstarter_pb2_grpc, kubernetes_pb2, router_pb2_grpc from pydantic import BaseModel, ConfigDict, Field, field_serializer +from jumpstarter.common import ExporterStatus from jumpstarter.common.grpc import translate_grpc_exceptions @@ -20,6 +21,7 @@ class WithOptions: show_online: bool = False show_leases: bool = False + show_status: bool = False def add_display_columns(table, options: WithOptions = None): @@ -28,6 +30,8 @@ def add_display_columns(table, options: WithOptions = None): table.add_column("NAME") if options.show_online: table.add_column("ONLINE") + if options.show_status: + table.add_column("STATUS") table.add_column("LABELS") if options.show_leases: table.add_column("LEASED BY") @@ -42,6 +46,9 @@ def add_exporter_row(table, exporter, options: WithOptions = None, lease_info: t row_data.append(exporter.name) if options.show_online: row_data.append("yes" if exporter.online else "no") + if options.show_status: + status_str = str(exporter.status) if exporter.status else "UNKNOWN" + row_data.append(status_str) row_data.append(",".join(("{}={}".format(k, v) for k, v in sorted(exporter.labels.items())))) if options.show_leases: if lease_info: @@ -81,12 +88,16 @@ class Exporter(BaseModel): name: str labels: dict[str, str] online: bool = False + status: ExporterStatus | None = None lease: Lease | None = None @classmethod def from_protobuf(cls, data: client_pb2.Exporter) -> Exporter: namespace, name = parse_exporter_identifier(data.name) - return cls(namespace=namespace, name=name, labels=data.labels, online=data.online) + status = None + if hasattr(data, "status") and data.status: + status = ExporterStatus.from_proto(data.status) + return cls(namespace=namespace, name=name, labels=data.labels, online=data.online, status=status) @classmethod def rich_add_columns(cls, table, options: WithOptions = None): @@ -244,6 +255,7 @@ class ExporterList(BaseModel): next_page_token: str | None = Field(exclude=True) include_online: bool = Field(default=False, exclude=True) include_leases: bool = Field(default=False, exclude=True) + include_status: bool = Field(default=False, exclude=True) @classmethod def from_protobuf(cls, data: client_pb2.ListExportersResponse) -> ExporterList: @@ -253,11 +265,15 @@ def from_protobuf(cls, data: client_pb2.ListExportersResponse) -> ExporterList: ) def rich_add_columns(self, table): - options = WithOptions(show_online=self.include_online, show_leases=self.include_leases) + options = WithOptions( + show_online=self.include_online, show_leases=self.include_leases, show_status=self.include_status + ) Exporter.rich_add_columns(table, options) def rich_add_rows(self, table): - options = WithOptions(show_online=self.include_online, show_leases=self.include_leases) + options = WithOptions( + show_online=self.include_online, show_leases=self.include_leases, show_status=self.include_status + ) for exporter in self.exporters: exporter.rich_add_rows(table, options) @@ -274,12 +290,10 @@ def model_dump_json(self, **kwargs): exclude_fields.add("lease") if not self.include_online: exclude_fields.add("online") + if not self.include_status: + exclude_fields.add("status") - data = { - "exporters": [ - exporter.model_dump(mode="json", exclude=exclude_fields) for 
exporter in self.exporters - ] - } + data = {"exporters": [exporter.model_dump(mode="json", exclude=exclude_fields) for exporter in self.exporters]} return json.dumps(data, **json_kwargs) def model_dump(self, **kwargs): @@ -288,12 +302,11 @@ def model_dump(self, **kwargs): exclude_fields.add("lease") if not self.include_online: exclude_fields.add("online") + if not self.include_status: + exclude_fields.add("status") + + return {"exporters": [exporter.model_dump(mode="json", exclude=exclude_fields) for exporter in self.exporters]} - return { - "exporters": [ - exporter.model_dump(mode="json", exclude=exclude_fields) for exporter in self.exporters - ] - } class LeaseList(BaseModel): leases: list[Lease] diff --git a/packages/jumpstarter/jumpstarter/common/__init__.py b/packages/jumpstarter/jumpstarter/common/__init__.py index 13058cb0..08645b47 100644 --- a/packages/jumpstarter/jumpstarter/common/__init__.py +++ b/packages/jumpstarter/jumpstarter/common/__init__.py @@ -1,4 +1,12 @@ +from .enums import ExporterStatus, LogSource from .metadata import Metadata from .tempfile import TemporarySocket, TemporaryTcpListener, TemporaryUnixListener -__all__ = ["Metadata", "TemporarySocket", "TemporaryUnixListener", "TemporaryTcpListener"] +__all__ = [ + "ExporterStatus", + "LogSource", + "Metadata", + "TemporarySocket", + "TemporaryUnixListener", + "TemporaryTcpListener", +] diff --git a/packages/jumpstarter/jumpstarter/common/enums.py b/packages/jumpstarter/jumpstarter/common/enums.py new file mode 100644 index 00000000..ce6a79c2 --- /dev/null +++ b/packages/jumpstarter/jumpstarter/common/enums.py @@ -0,0 +1,76 @@ +"""Human-readable enum wrappers for protobuf-generated constants.""" + +from enum import IntEnum + +from jumpstarter_protocol.jumpstarter.v1 import common_pb2 + + +class ExporterStatus(IntEnum): + """Exporter status states.""" + + UNSPECIFIED = common_pb2.EXPORTER_STATUS_UNSPECIFIED + """Unknown/unspecified exporter status""" + + OFFLINE = common_pb2.EXPORTER_STATUS_OFFLINE + """The exporter is currently offline""" + + AVAILABLE = common_pb2.EXPORTER_STATUS_AVAILABLE + """Exporter is available to be leased""" + + BEFORE_LEASE_HOOK = common_pb2.EXPORTER_STATUS_BEFORE_LEASE_HOOK + """Exporter is leased, but currently executing before lease hook""" + + LEASE_READY = common_pb2.EXPORTER_STATUS_LEASE_READY + """Exporter is leased and ready to accept commands""" + + AFTER_LEASE_HOOK = common_pb2.EXPORTER_STATUS_AFTER_LEASE_HOOK + """Lease was released, but exporter is executing after lease hook""" + + BEFORE_LEASE_HOOK_FAILED = common_pb2.EXPORTER_STATUS_BEFORE_LEASE_HOOK_FAILED + """The before lease hook failed and the exporter is no longer available""" + + AFTER_LEASE_HOOK_FAILED = common_pb2.EXPORTER_STATUS_AFTER_LEASE_HOOK_FAILED + """The after lease hook failed and the exporter is no longer available""" + + def __str__(self): + return self.name + + @classmethod + def from_proto(cls, value: int) -> "ExporterStatus": + """Convert from protobuf integer to enum.""" + return cls(value) + + def to_proto(self) -> int: + """Convert to protobuf integer.""" + return self.value + + +class LogSource(IntEnum): + """Log source types.""" + + UNSPECIFIED = common_pb2.LOG_SOURCE_UNSPECIFIED + """Unspecified/unknown log source""" + + DRIVER = common_pb2.LOG_SOURCE_DRIVER + """Logs produced by a Jumpstarter driver""" + + BEFORE_LEASE_HOOK = common_pb2.LOG_SOURCE_BEFORE_LEASE_HOOK + """Logs produced by a before lease hook""" + + AFTER_LEASE_HOOK = common_pb2.LOG_SOURCE_AFTER_LEASE_HOOK + """Logs produced by an 
after lease hook""" + + SYSTEM = common_pb2.LOG_SOURCE_SYSTEM + """System/exporter logs""" + + def __str__(self): + return self.name + + @classmethod + def from_proto(cls, value: int) -> "LogSource": + """Convert from protobuf integer to enum.""" + return cls(value) + + def to_proto(self) -> int: + """Convert to protobuf integer.""" + return self.value diff --git a/packages/jumpstarter/jumpstarter/config/client.py b/packages/jumpstarter/jumpstarter/config/client.py index 8cb24c50..fb71e396 100644 --- a/packages/jumpstarter/jumpstarter/config/client.py +++ b/packages/jumpstarter/jumpstarter/config/client.py @@ -160,12 +160,14 @@ async def list_exporters( filter: str | None = None, include_leases: bool = False, include_online: bool = False, + include_status: bool = False, ): svc = ClientService(channel=await self.channel(), namespace=self.metadata.namespace) exporters_response = await svc.ListExporters(page_size=page_size, page_token=page_token, filter=filter) - # Set the include_online flag for display purposes + # Set the include flags for display purposes exporters_response.include_online = include_online + exporters_response.include_status = include_status if not include_leases: return exporters_response diff --git a/packages/jumpstarter/jumpstarter/config/exporter.py b/packages/jumpstarter/jumpstarter/config/exporter.py index efd4724b..3e7b88b2 100644 --- a/packages/jumpstarter/jumpstarter/config/exporter.py +++ b/packages/jumpstarter/jumpstarter/config/exporter.py @@ -18,6 +18,33 @@ from jumpstarter.driver import Driver +class HookInstanceConfigV1Alpha1(BaseModel): + """Configuration for a specific lifecycle hook.""" + + model_config = ConfigDict(populate_by_name=True) + + script: str = Field(alias="script", description="The j script to execute for this hook") + timeout: int = Field(default=120, description="The hook execution timeout in seconds (default: 120s)") + exit_code: int = Field(alias="exitCode", default=0, description="The expected exit code (default: 0)") + on_failure: Literal["pass", "block", "warn"] = Field( + default="pass", + alias="onFailure", + description=( + "Action to take when the expected exit code is not returned: 'pass' continues normally, " + "'block' takes the exporter offline and blocks leases, 'warn' continues and prints a warning" + ), + ) + + +class HookConfigV1Alpha1(BaseModel): + """Configuration for lifecycle hooks.""" + + model_config = ConfigDict(populate_by_name=True) + + before_lease: HookInstanceConfigV1Alpha1 | None = Field(default=None, alias="beforeLease") + after_lease: HookInstanceConfigV1Alpha1 | None = Field(default=None, alias="afterLease") + + class ExporterConfigV1Alpha1DriverInstanceProxy(BaseModel): ref: str @@ -52,7 +79,7 @@ def instantiate(self) -> Driver: description=self.root.description, methods_description=self.root.methods_description, children=children, - **self.root.config + **self.root.config, ) case ExporterConfigV1Alpha1DriverInstanceComposite(): @@ -93,6 +120,7 @@ class ExporterConfigV1Alpha1(BaseModel): description: str | None = None export: dict[str, ExporterConfigV1Alpha1DriverInstance] = Field(default_factory=dict) + hooks: HookConfigV1Alpha1 = Field(default_factory=HookConfigV1Alpha1) path: Path | None = Field(default=None) @@ -127,7 +155,7 @@ def list(cls) -> ExporterConfigListV1Alpha1: @classmethod def dump_yaml(self, config: Self) -> str: - return yaml.safe_dump(config.model_dump(mode="json", exclude={"alias", "path"}), sort_keys=False) + return yaml.safe_dump(config.model_dump(mode="json", by_alias=True, 
exclude={"alias", "path"}), sort_keys=False) @classmethod def save(cls, config: Self, path: Optional[str] = None) -> Path: @@ -138,7 +166,7 @@ def save(cls, config: Self, path: Optional[str] = None) -> Path: else: config.path = Path(path) with config.path.open(mode="w") as f: - yaml.safe_dump(config.model_dump(mode="json", exclude={"alias", "path"}), f, sort_keys=False) + yaml.safe_dump(config.model_dump(mode="json", by_alias=True, exclude={"alias", "path"}), f, sort_keys=False) return config.path @classmethod @@ -185,6 +213,16 @@ async def channel_factory(): ) return aio_secure_channel(self.endpoint, credentials, self.grpcOptions) + # Create hook executor if hooks are configured + hook_executor = None + if self.hooks.before_lease or self.hooks.after_lease: + from jumpstarter.exporter.hooks import HookExecutor + + hook_executor = HookExecutor( + config=self.hooks, + device_factory=ExporterConfigV1Alpha1DriverInstance(children=self.export).instantiate, + ) + exporter = None entered = False try: @@ -197,6 +235,7 @@ async def channel_factory(): ).instantiate, tls=self.tls, grpc_options=self.grpcOptions, + hook_executor=hook_executor, ) # Initialize the exporter (registration, etc.) await exporter.__aenter__() diff --git a/packages/jumpstarter/jumpstarter/config/exporter_test.py b/packages/jumpstarter/jumpstarter/config/exporter_test.py index e9fb4863..68d0e3f4 100644 --- a/packages/jumpstarter/jumpstarter/config/exporter_test.py +++ b/packages/jumpstarter/jumpstarter/config/exporter_test.py @@ -101,3 +101,56 @@ def test_exporter_config(monkeypatch: pytest.MonkeyPatch, tmp_path: Path): ExporterConfigV1Alpha1.save(config) assert config == ExporterConfigV1Alpha1.load("test") + + +def test_exporter_config_with_hooks(monkeypatch: pytest.MonkeyPatch, tmp_path: Path): + monkeypatch.setattr(ExporterConfigV1Alpha1, "BASE_PATH", tmp_path) + + path = tmp_path / "test-hooks.yaml" + + text = """apiVersion: jumpstarter.dev/v1alpha1 +kind: ExporterConfig +metadata: + namespace: default + name: test-hooks +endpoint: "jumpstarter.my-lab.com:1443" +token: "test-token" +hooks: + beforeLease: + script: | + echo "Pre-lease hook for $LEASE_NAME" + j power on + timeout: 600 + afterLease: + script: | + echo "Post-lease hook for $LEASE_NAME" + j power off + timeout: 600 +export: + power: + type: "jumpstarter_driver_power.driver.PduPower" +""" + path.write_text( + text, + encoding="utf-8", + ) + + config = ExporterConfigV1Alpha1.load("test-hooks") + + assert config.hooks.before_lease.script == 'echo "Pre-lease hook for $LEASE_NAME"\nj power on\n' + assert config.hooks.after_lease.script == 'echo "Post-lease hook for $LEASE_NAME"\nj power off\n' + + # Test that it round-trips correctly + path.unlink() + ExporterConfigV1Alpha1.save(config) + reloaded_config = ExporterConfigV1Alpha1.load("test-hooks") + + assert reloaded_config.hooks.before_lease.script == config.hooks.before_lease.script + assert reloaded_config.hooks.after_lease.script == config.hooks.after_lease.script + + # Test that the YAML uses camelCase + yaml_output = ExporterConfigV1Alpha1.dump_yaml(config) + assert "beforeLease:" in yaml_output + assert "afterLease:" in yaml_output + assert "before_lease:" not in yaml_output + assert "after_lease:" not in yaml_output diff --git a/packages/jumpstarter/jumpstarter/driver/base.py b/packages/jumpstarter/jumpstarter/driver/base.py index 78273ef4..b735c32b 100644 --- a/packages/jumpstarter/jumpstarter/driver/base.py +++ b/packages/jumpstarter/jumpstarter/driver/base.py @@ -27,8 +27,9 @@ MARKER_STREAMCALL, 
MARKER_STREAMING_DRIVERCALL, ) -from jumpstarter.common import Metadata +from jumpstarter.common import LogSource, Metadata from jumpstarter.common.resources import ClientStreamResource, PresignedRequestResource, Resource, ResourceMetadata +from jumpstarter.exporter.logging import get_logger from jumpstarter.common.serde import decode_value, encode_value from jumpstarter.common.streams import ( DriverStreamRequest, @@ -85,7 +86,7 @@ def __post_init__(self): if hasattr(super(), "__post_init__"): super().__post_init__() - self.logger = logging.getLogger(self.__class__.__name__) + self.logger = get_logger(f"driver.{self.__class__.__name__}", LogSource.DRIVER) self.logger.setLevel(self.log_level) def close(self): diff --git a/packages/jumpstarter/jumpstarter/exporter/exporter.py b/packages/jumpstarter/jumpstarter/exporter/exporter.py index a33a6a9b..301c57bb 100644 --- a/packages/jumpstarter/jumpstarter/exporter/exporter.py +++ b/packages/jumpstarter/jumpstarter/exporter/exporter.py @@ -8,6 +8,7 @@ from anyio import ( AsyncContextManagerMixin, CancelScope, + Event, connect_unix, create_memory_object_stream, create_task_group, @@ -21,10 +22,11 @@ jumpstarter_pb2_grpc, ) -from jumpstarter.common import Metadata +from jumpstarter.common import ExporterStatus, Metadata from jumpstarter.common.streams import connect_router_stream from jumpstarter.config.tls import TLSConfigV1Alpha1 from jumpstarter.driver import Driver +from jumpstarter.exporter.hooks import HookContext, HookExecutionError, HookExecutor from jumpstarter.exporter.session import Session logger = logging.getLogger(__name__) @@ -37,11 +39,17 @@ class Exporter(AsyncContextManagerMixin, Metadata): lease_name: str = field(init=False, default="") tls: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1) grpc_options: dict[str, str] = field(default_factory=dict) + hook_executor: HookExecutor | None = field(default=None) registered: bool = field(init=False, default=False) _unregister: bool = field(init=False, default=False) _stop_requested: bool = field(init=False, default=False) _started: bool = field(init=False, default=False) _tg: TaskGroup | None = field(init=False, default=None) + _current_client_name: str = field(init=False, default="") + _pre_lease_ready: Event | None = field(init=False, default=None) + _current_status: ExporterStatus = field(init=False, default=ExporterStatus.OFFLINE) + _current_session: Session | None = field(init=False, default=None) + _session_socket_path: str | None = field(init=False, default=None) def stop(self, wait_for_lease_exit=False, should_unregister=False): """Signal the exporter to stop. 
@@ -60,6 +68,26 @@ def stop(self, wait_for_lease_exit=False, should_unregister=False): self._stop_requested = True logger.info("Exporter marked for stop upon lease exit") + async def _update_status(self, status: ExporterStatus, message: str = ""): + """Update exporter status with the controller and session.""" + self._current_status = status + + # Update session status if available + if self._current_session: + self._current_session.update_status(status, message) + + try: + controller = jumpstarter_pb2_grpc.ControllerServiceStub(await self.channel_factory()) + await controller.UpdateStatus( + jumpstarter_pb2.UpdateStatusRequest( + status=status.to_proto(), + status_message=message, + ) + ) + logger.info(f"Updated status to {status}: {message}") + except Exception as e: + logger.error(f"Failed to update status: {e}") + @asynccontextmanager async def __asynccontextmanager__(self) -> AsyncGenerator[Self]: try: @@ -73,6 +101,7 @@ async def __asynccontextmanager__(self) -> AsyncGenerator[Self]: channel = await self.channel_factory() try: controller = jumpstarter_pb2_grpc.ControllerServiceStub(channel) + await self._update_status(ExporterStatus.OFFLINE, "Exporter shutting down") await controller.Unregister( jumpstarter_pb2.UnregisterRequest( reason="Exporter shutdown", @@ -105,20 +134,27 @@ async def session(self): labels=self.labels, root_device=self.device_factory(), ) as session: - async with session.serve_unix_async() as path: - async with grpc.aio.secure_channel( - f"unix://{path}", grpc.local_channel_credentials(grpc.LocalConnectionType.UDS) - ) as channel: - response = await jumpstarter_pb2_grpc.ExporterServiceStub(channel).GetReport(empty_pb2.Empty()) - logger.info("Registering exporter with controller") - await controller.Register( - jumpstarter_pb2.RegisterRequest( - labels=self.labels, - reports=response.reports, + # Store session reference for status updates + self._current_session = session + try: + async with session.serve_unix_async() as path: + async with grpc.aio.secure_channel( + f"unix://{path}", grpc.local_channel_credentials(grpc.LocalConnectionType.UDS) + ) as channel: + response = await jumpstarter_pb2_grpc.ExporterServiceStub(channel).GetReport(empty_pb2.Empty()) + logger.info("Registering exporter with controller") + await controller.Register( + jumpstarter_pb2.RegisterRequest( + labels=self.labels, + reports=response.reports, + ) ) - ) - self.registered = True - yield path + self.registered = True + await self._update_status(ExporterStatus.AVAILABLE, "Exporter registered and available") + yield path + finally: + # Clear session reference + self._current_session = None async def handle(self, lease_name, tg): logger.info("Listening for incoming connection requests on lease %s", lease_name) @@ -148,7 +184,18 @@ async def listen(retries=5, backoff=3): tg.start_soon(listen) + # Create session before hooks run async with self.session() as path: + # Store socket path for hook execution + self._session_socket_path = path + + # Wait for before-lease hook to complete before processing connections + if self._pre_lease_ready is not None: + logger.info("Waiting for before-lease hook to complete before accepting connections") + await self._pre_lease_ready.wait() + logger.info("before-lease hook completed, now accepting connections") + + # Process client connections async for request in listen_rx: logger.info("Handling new connection request on lease %s", lease_name) tg.start_soon( @@ -190,19 +237,144 @@ async def status(retries=5, backoff=3): tg.start_soon(status) async for status in 
status_rx: if self.lease_name != "" and self.lease_name != status.lease_name: + # After-lease hook for the previous lease + if self.hook_executor and self._current_client_name: + hook_context = HookContext( + lease_name=self.lease_name, + client_name=self._current_client_name, + ) + # Shield the after-lease hook from cancellation and await it + with CancelScope(shield=True): + await self.run_after_lease_hook(hook_context) + self.lease_name = status.lease_name logger.info("Lease status changed, killing existing connections") + # Reset event for next lease + self._pre_lease_ready = None self.stop() break + + # Check for lease state transitions + previous_leased = hasattr(self, "_previous_leased") and self._previous_leased + current_leased = status.leased + self.lease_name = status.lease_name if not self._started and self.lease_name != "": self._started = True + # Create event for pre-lease synchronization + self._pre_lease_ready = Event() tg.start_soon(self.handle, self.lease_name, tg) - if status.leased: + + if current_leased: logger.info("Currently leased by %s under %s", status.client_name, status.lease_name) + self._current_client_name = status.client_name + + # Before-lease hook when transitioning from unleased to leased + if not previous_leased: + if self.hook_executor: + hook_context = HookContext( + lease_name=status.lease_name, + client_name=status.client_name, + ) + tg.start_soon(self.run_before_lease_hook, hook_context) + else: + # No hook configured, set event immediately + await self._update_status(ExporterStatus.LEASE_READY, "Ready for commands") + if self._pre_lease_ready: + self._pre_lease_ready.set() else: logger.info("Currently not leased") + + # After-lease hook when transitioning from leased to unleased + if previous_leased and self.hook_executor and self._current_client_name: + hook_context = HookContext( + lease_name=self.lease_name, + client_name=self._current_client_name, + ) + # Shield the after-lease hook from cancellation and await it + with CancelScope(shield=True): + await self._update_status(ExporterStatus.AFTER_LEASE_HOOK, "Running afterLease hooks") + # Pass the current session to hook executor for logging + self.hook_executor.main_session = self._current_session + # Use session socket if available, otherwise create new session + await self.hook_executor.execute_after_lease_hook( + hook_context, socket_path=self._session_socket_path + ) + await self._update_status(ExporterStatus.AVAILABLE, "Available for new lease") + + self._current_client_name = "" + # Reset event for next lease + self._pre_lease_ready = None + if self._stop_requested: self.stop(should_unregister=True) break + + self._previous_leased = current_leased self._tg = None + + async def run_before_lease_hook(self, hook_ctx: HookContext): + """ + Execute the before-lease hook for the current exporter session. 
+ + Args: + hook_ctx (HookContext): The current hook execution context + """ + try: + await self._update_status(ExporterStatus.BEFORE_LEASE_HOOK, "Running beforeLease hooks") + # Pass the current session to hook executor for logging + self.hook_executor.main_session = self._current_session + + # Wait for socket path to be available + while self._session_socket_path is None: + await sleep(0.1) + + # Execute hook with main session socket + await self.hook_executor.execute_before_lease_hook(hook_ctx, socket_path=self._session_socket_path) + await self._update_status(ExporterStatus.LEASE_READY, "Ready for commands") + logger.info("beforeLease hook completed successfully") + except HookExecutionError as e: + # Hook failed with on_failure='block' - end lease and set failed status + logger.error("beforeLease hook failed (on_failure=block): %s", e) + await self._update_status( + ExporterStatus.BEFORE_LEASE_HOOK_FAILED, f"beforeLease hook failed (on_failure=block): {e}" + ) + # Note: We don't take the exporter offline for before_lease hook failures + # The lease is simply not ready, and the exporter remains available for future leases + except Exception as e: + # Unexpected error during hook execution + logger.error("beforeLease hook failed with unexpected error: %s", e, exc_info=True) + await self._update_status(ExporterStatus.BEFORE_LEASE_HOOK_FAILED, f"beforeLease hook failed: {e}") + finally: + # Always set the event to unblock connections + if self._pre_lease_ready: + self._pre_lease_ready.set() + + async def run_after_lease_hook(self, hook_ctx: HookContext): + """ + Execute the after-lease hook for the current exporter session. + + Args: + hook_ctx (HookContext): The current hook execution context + """ + try: + await self._update_status(ExporterStatus.AFTER_LEASE_HOOK, "Running afterLease hooks") + # Pass the current session to hook executor for logging + self.hook_executor.main_session = self._current_session + # Use session socket if available, otherwise create new session + await self.hook_executor.execute_after_lease_hook(hook_ctx, socket_path=self._session_socket_path) + await self._update_status(ExporterStatus.AVAILABLE, "Available for new lease") + logger.info("afterLease hook completed successfully") + except HookExecutionError as e: + # Hook failed with on_failure='block' - set failed status and shut down exporter + logger.error("afterLease hook failed (on_failure=block): %s", e) + await self._update_status( + ExporterStatus.AFTER_LEASE_HOOK_FAILED, f"afterLease hook failed (on_failure=block): {e}" + ) + # Shut down the exporter after after_lease hook failure with on_failure='block' + logger.error("Shutting down exporter due to afterLease hook failure") + self.stop() + except Exception as e: + # Unexpected error during hook execution + logger.error("afterLease hook failed with unexpected error: %s", e, exc_info=True) + await self._update_status(ExporterStatus.AFTER_LEASE_HOOK_FAILED, f"afterLease hook failed: {e}") diff --git a/packages/jumpstarter/jumpstarter/exporter/hooks.py b/packages/jumpstarter/jumpstarter/exporter/hooks.py new file mode 100644 index 00000000..77803b28 --- /dev/null +++ b/packages/jumpstarter/jumpstarter/exporter/hooks.py @@ -0,0 +1,265 @@ +"""Lifecycle hooks for Jumpstarter exporters.""" + +import asyncio +import logging +import os +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from typing import Callable + +from jumpstarter.common import LogSource +from jumpstarter.config.env import JMP_DRIVERS_ALLOW, JUMPSTARTER_HOST 
+from jumpstarter.config.exporter import HookConfigV1Alpha1, HookInstanceConfigV1Alpha1 +from jumpstarter.driver import Driver +from jumpstarter.exporter.logging import get_logger +from jumpstarter.exporter.session import Session + +logger = logging.getLogger(__name__) + + +class HookExecutionError(Exception): + """Raised when a hook fails and on_failure is set to 'block'.""" + + pass + + +@dataclass(kw_only=True) +class HookContext: + """Context information passed to hooks.""" + + lease_name: str + client_name: str = "" + lease_duration: str = "" + exporter_name: str = "" + exporter_namespace: str = "" + + +@dataclass(kw_only=True) +class HookExecutor: + """Executes lifecycle hooks with access to the j CLI.""" + + config: HookConfigV1Alpha1 + device_factory: Callable[[], Driver] + main_session: Session | None = field(default=None) + + @asynccontextmanager + async def _create_hook_environment(self, context: HookContext): + """Create a local session and Unix socket for j CLI access.""" + with Session( + root_device=self.device_factory(), + # Use hook context for metadata + labels={ + "jumpstarter.dev/hook-context": "true", + "jumpstarter.dev/lease": context.lease_name, + }, + ) as session: + async with session.serve_unix_async() as unix_path: + # Create environment variables for the hook + hook_env = os.environ.copy() + hook_env.update( + { + JUMPSTARTER_HOST: str(unix_path), + JMP_DRIVERS_ALLOW: "UNSAFE", # Allow all drivers for local access + "LEASE_NAME": context.lease_name, + "CLIENT_NAME": context.client_name, + "LEASE_DURATION": context.lease_duration, + "EXPORTER_NAME": context.exporter_name, + "EXPORTER_NAMESPACE": context.exporter_namespace, + } + ) + + yield session, hook_env + + async def _execute_hook( + self, + hook_config: HookInstanceConfigV1Alpha1, + context: HookContext, + log_source: LogSource, + socket_path: str | None = None, + ): + """Execute a single hook command. + + Args: + hook_config: Hook configuration including script, timeout, exit_code, and on_failure + context: Hook context information + log_source: Log source for hook output + socket_path: Optional Unix socket path to reuse existing session. + If provided, hooks will access the main session instead of creating their own. 
+ """ + command = hook_config.script + if not command or not command.strip(): + logger.debug("Hook command is empty, skipping") + return + + logger.info("Executing hook: %s", command.strip().split("\n")[0][:100]) + + # If socket_path provided, use existing session; otherwise create new one + if socket_path is not None: + # Reuse existing session - create environment without session creation + hook_env = os.environ.copy() + hook_env.update( + { + JUMPSTARTER_HOST: str(socket_path), + JMP_DRIVERS_ALLOW: "UNSAFE", + "LEASE_NAME": context.lease_name, + "CLIENT_NAME": context.client_name, + "LEASE_DURATION": context.lease_duration, + "EXPORTER_NAME": context.exporter_name, + "EXPORTER_NAMESPACE": context.exporter_namespace, + } + ) + + # Use main session for logging (must be available when socket_path is provided) + logging_session = self.main_session + if logging_session is None: + raise ValueError("main_session must be set when reusing socket_path") + + return await self._execute_hook_process(hook_config, context, log_source, hook_env, logging_session) + else: + # Create new session for hook execution (fallback/standalone mode) + async with self._create_hook_environment(context) as (session, hook_env): + # Determine which session to use for logging + logging_session = self.main_session if self.main_session is not None else session + return await self._execute_hook_process(hook_config, context, log_source, hook_env, logging_session) + + async def _execute_hook_process( + self, + hook_config: HookInstanceConfigV1Alpha1, + context: HookContext, + log_source: LogSource, + hook_env: dict, + logging_session: Session, + ): + """Execute the hook process with the given environment and logging session.""" + command = hook_config.script + timeout = hook_config.timeout + expected_exit_code = hook_config.exit_code + on_failure = hook_config.on_failure + + try: + # Execute the hook command using shell + process = await asyncio.create_subprocess_shell( + command, + env=hook_env, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + + try: + # Create a logger with automatic source registration + hook_logger = get_logger(f"hook.{context.lease_name}", log_source, logging_session) + + # Stream output line-by-line for real-time logging + output_lines = [] + + async def read_output(): + while True: + line = await process.stdout.readline() + if not line: + break + line_decoded = line.decode().rstrip() + output_lines.append(line_decoded) + # Route hook output through the logging system + hook_logger.info(line_decoded) + + # Run output reading and process waiting concurrently with timeout + await asyncio.wait_for(asyncio.gather(read_output(), process.wait()), timeout=timeout) + + # Check if exit code matches expected + if process.returncode == expected_exit_code: + logger.info("Hook executed successfully with exit code %d", process.returncode) + return + else: + # Exit code mismatch - handle according to on_failure setting + error_msg = f"Hook failed: expected exit code {expected_exit_code}, got {process.returncode}" + + if on_failure == "pass": + logger.info("%s (on_failure=pass, continuing)", error_msg) + return + elif on_failure == "warn": + logger.warning("%s (on_failure=warn, continuing)", error_msg) + return + else: # on_failure == "block" + logger.error("%s (on_failure=block, raising exception)", error_msg) + raise HookExecutionError(error_msg) + + except asyncio.TimeoutError as e: + error_msg = f"Hook timed out after {timeout} seconds" + logger.error(error_msg) + try: + 
process.terminate() + await asyncio.wait_for(process.wait(), timeout=5) + except asyncio.TimeoutError: + process.kill() + await process.wait() + + # Handle timeout according to on_failure setting + if on_failure == "pass": + logger.info("%s (on_failure=pass, continuing)", error_msg) + return + elif on_failure == "warn": + logger.warning("%s (on_failure=warn, continuing)", error_msg) + return + else: # on_failure == "block" + raise HookExecutionError(error_msg) from e + + except HookExecutionError: + # Re-raise HookExecutionError to propagate to exporter + raise + except Exception as e: + error_msg = f"Error executing hook: {e}" + logger.error(error_msg, exc_info=True) + + # Handle exception according to on_failure setting + if on_failure == "pass": + logger.info("%s (on_failure=pass, continuing)", error_msg) + return + elif on_failure == "warn": + logger.warning("%s (on_failure=warn, continuing)", error_msg) + return + else: # on_failure == "block" + raise HookExecutionError(error_msg) from e + + async def execute_before_lease_hook(self, context: HookContext, socket_path: str | None = None): + """Execute the before-lease hook. + + Args: + context: Hook context information + socket_path: Optional Unix socket path to reuse existing session + + Raises: + HookExecutionError: If hook fails and on_failure is set to 'block' + """ + if not self.config.before_lease: + logger.debug("No before-lease hook configured") + return + + logger.info("Executing before-lease hook for lease %s", context.lease_name) + await self._execute_hook( + self.config.before_lease, + context, + LogSource.BEFORE_LEASE_HOOK, + socket_path, + ) + + async def execute_after_lease_hook(self, context: HookContext, socket_path: str | None = None): + """Execute the after-lease hook. + + Args: + context: Hook context information + socket_path: Optional Unix socket path to reuse existing session + + Raises: + HookExecutionError: If hook fails and on_failure is set to 'block' + """ + if not self.config.after_lease: + logger.debug("No after-lease hook configured") + return + + logger.info("Executing after-lease hook for lease %s", context.lease_name) + await self._execute_hook( + self.config.after_lease, + context, + LogSource.AFTER_LEASE_HOOK, + socket_path, + ) diff --git a/packages/jumpstarter/jumpstarter/exporter/hooks_test.py b/packages/jumpstarter/jumpstarter/exporter/hooks_test.py new file mode 100644 index 00000000..0e18d332 --- /dev/null +++ b/packages/jumpstarter/jumpstarter/exporter/hooks_test.py @@ -0,0 +1,495 @@ +import asyncio +from unittest.mock import AsyncMock, Mock, call, patch + +import pytest + +from jumpstarter.config.env import JMP_DRIVERS_ALLOW, JUMPSTARTER_HOST +from jumpstarter.config.exporter import HookConfigV1Alpha1, HookInstanceConfigV1Alpha1 +from jumpstarter.driver import Driver +from jumpstarter.exporter.hooks import HookContext, HookExecutionError, HookExecutor + +pytestmark = pytest.mark.anyio + + +class MockDriver(Driver): + @classmethod + def client(cls) -> str: + return "test.MockClient" + + def close(self): + pass + + def reset(self): + pass + + +@pytest.fixture +def mock_device_factory(): + def factory(): + return MockDriver() + + return factory + + +@pytest.fixture +def hook_config(): + return HookConfigV1Alpha1( + before_lease=HookInstanceConfigV1Alpha1(script="echo 'Pre-lease hook executed'", timeout=10), + after_lease=HookInstanceConfigV1Alpha1(script="echo 'Post-lease hook executed'", timeout=10), + ) + + +@pytest.fixture +def hook_context(): + return HookContext( + 
lease_name="test-lease-123", + client_name="test-client", + lease_duration="30m", + exporter_name="test-exporter", + exporter_namespace="default", + ) + + +class TestHookExecutor: + async def test_hook_executor_creation(self, hook_config, mock_device_factory): + executor = HookExecutor( + config=hook_config, + device_factory=mock_device_factory, + ) + + assert executor.config == hook_config + assert executor.device_factory == mock_device_factory + + async def test_empty_hook_execution(self, mock_device_factory, hook_context): + empty_config = HookConfigV1Alpha1() + executor = HookExecutor( + config=empty_config, + device_factory=mock_device_factory, + ) + + # Both hooks should return True for empty/None commands + assert await executor.execute_before_lease_hook(hook_context) is True + assert await executor.execute_after_lease_hook(hook_context) is True + + async def test_successful_hook_execution(self, mock_device_factory, hook_context): + hook_config = HookConfigV1Alpha1( + before_lease=HookInstanceConfigV1Alpha1(script="echo 'Pre-lease hook executed'", timeout=10), + ) + # Mock the Session and serve_unix_async + with patch("jumpstarter.exporter.hooks.Session") as mock_session_class: + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + + # Mock the async context manager for serve_unix_async + mock_session.serve_unix_async.return_value.__aenter__ = AsyncMock(return_value="/tmp/test_socket") + mock_session.serve_unix_async.return_value.__aexit__ = AsyncMock(return_value=None) + + # Mock asyncio.create_subprocess_shell to simulate successful execution + mock_process = AsyncMock() + mock_process.returncode = 0 + # Mock stdout.readline to simulate line-by-line output + mock_process.stdout.readline.side_effect = [ + b"Pre-lease hook executed\n", + b"", # EOF + ] + mock_process.wait = AsyncMock(return_value=None) + + with patch("asyncio.create_subprocess_shell", return_value=mock_process) as mock_subprocess: + executor = HookExecutor( + config=hook_config, + device_factory=mock_device_factory, + ) + + result = await executor.execute_before_lease_hook(hook_context) + + assert result is True + + # Verify subprocess was called with correct environment + mock_subprocess.assert_called_once() + call_args = mock_subprocess.call_args + command = call_args[0][0] + env = call_args[1]["env"] + + assert command == "echo 'Pre-lease hook executed'" + assert JUMPSTARTER_HOST in env + assert env[JUMPSTARTER_HOST] == "/tmp/test_socket" + assert env[JMP_DRIVERS_ALLOW] == "UNSAFE" + assert env["LEASE_NAME"] == "test-lease-123" + assert env["CLIENT_NAME"] == "test-client" + + async def test_failed_hook_execution(self, mock_device_factory, hook_context): + failed_config = HookConfigV1Alpha1( + before_lease=HookInstanceConfigV1Alpha1( + script="exit 1", timeout=10, on_failure="block" + ), # Command that will fail with on_failure="block" + ) + + with patch("jumpstarter.exporter.hooks.Session") as mock_session_class: + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + + mock_session.serve_unix_async.return_value.__aenter__ = AsyncMock(return_value="/tmp/test_socket") + mock_session.serve_unix_async.return_value.__aexit__ = AsyncMock(return_value=None) + + # Mock failed process + mock_process = AsyncMock() + mock_process.returncode = 1 + # Mock stdout.readline for failed process + mock_process.stdout.readline.side_effect = [ + b"Command failed\n", + b"", # EOF + ] + mock_process.wait = AsyncMock(return_value=None) + + with 
patch("asyncio.create_subprocess_shell", return_value=mock_process): + executor = HookExecutor( + config=failed_config, + device_factory=mock_device_factory, + ) + + # Should raise HookExecutionError since on_failure="block" + with pytest.raises(HookExecutionError, match="expected exit code 0, got 1"): + await executor.execute_before_lease_hook(hook_context) + + async def test_hook_timeout(self, mock_device_factory, hook_context): + timeout_config = HookConfigV1Alpha1( + before_lease=HookInstanceConfigV1Alpha1( + script="sleep 60", timeout=1, on_failure="block" + ), # Command that will timeout with on_failure="block" + ) + + with patch("jumpstarter.exporter.hooks.Session") as mock_session_class: + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + + mock_session.serve_unix_async.return_value.__aenter__ = AsyncMock(return_value="/tmp/test_socket") + mock_session.serve_unix_async.return_value.__aexit__ = AsyncMock(return_value=None) + + # Mock process that times out + mock_process = AsyncMock() + mock_process.terminate.return_value = None + mock_process.wait.return_value = None + + with ( + patch("asyncio.create_subprocess_shell", return_value=mock_process), + patch("asyncio.wait_for", side_effect=asyncio.TimeoutError()), + ): + executor = HookExecutor( + config=timeout_config, + device_factory=mock_device_factory, + ) + + # Should raise HookExecutionError since on_failure="block" + with pytest.raises(HookExecutionError, match="timed out after 1 seconds"): + await executor.execute_before_lease_hook(hook_context) + + mock_process.terminate.assert_called_once() + + async def test_hook_environment_variables(self, mock_device_factory, hook_context): + hook_config = HookConfigV1Alpha1( + before_lease=HookInstanceConfigV1Alpha1(script="echo 'Pre-lease hook executed'", timeout=10), + ) + with patch("jumpstarter.exporter.hooks.Session") as mock_session_class: + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + + mock_session.serve_unix_async.return_value.__aenter__ = AsyncMock(return_value="/tmp/test_socket") + mock_session.serve_unix_async.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_process = AsyncMock() + mock_process.returncode = 0 + # Mock stdout.readline for environment test + mock_process.stdout.readline.side_effect = [ + b"", # EOF (no output) + ] + mock_process.wait = AsyncMock(return_value=None) + + with patch("asyncio.create_subprocess_shell", return_value=mock_process) as mock_subprocess: + executor = HookExecutor( + config=hook_config, + device_factory=mock_device_factory, + ) + + await executor.execute_before_lease_hook(hook_context) + + # Check that all expected environment variables are set + call_args = mock_subprocess.call_args + env = call_args[1]["env"] + + assert env["LEASE_NAME"] == "test-lease-123" + assert env["CLIENT_NAME"] == "test-client" + assert env["LEASE_DURATION"] == "30m" + assert env["EXPORTER_NAME"] == "test-exporter" + assert env["EXPORTER_NAMESPACE"] == "default" + assert env[JUMPSTARTER_HOST] == "/tmp/test_socket" + assert env[JMP_DRIVERS_ALLOW] == "UNSAFE" + + async def test_real_time_output_logging(self, mock_device_factory, hook_context): + """Test that hook output is logged in real-time at INFO level.""" + hook_config = HookConfigV1Alpha1( + before_lease=HookInstanceConfigV1Alpha1( + script="echo 'Line 1'; echo 'Line 2'; echo 'Line 3'", timeout=10 + ), + ) + + with patch("jumpstarter.exporter.hooks.Session") as mock_session_class: + mock_session = 
Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + + mock_session.serve_unix_async.return_value.__aenter__ = AsyncMock(return_value="/tmp/test_socket") + mock_session.serve_unix_async.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_process = AsyncMock() + mock_process.returncode = 0 + # Mock multiple lines of output to verify streaming + mock_process.stdout.readline.side_effect = [ + b"Line 1\n", + b"Line 2\n", + b"Line 3\n", + b"", # EOF + ] + mock_process.wait = AsyncMock(return_value=None) + + # Mock the logger to capture log calls + with ( + patch("jumpstarter.exporter.hooks.logger") as mock_logger, + patch("asyncio.create_subprocess_shell", return_value=mock_process), + ): + executor = HookExecutor( + config=hook_config, + device_factory=mock_device_factory, + ) + + result = await executor.execute_before_lease_hook(hook_context) + + assert result is True + + # Verify that output lines were logged in real-time at INFO level + expected_calls = [ + call("Executing before-lease hook for lease %s", "test-lease-123"), + call("Executing hook: %s", "echo 'Line 1'; echo 'Line 2'; echo 'Line 3'"), + call("Hook executed successfully with exit code %d", 0), + ] + mock_logger.info.assert_has_calls(expected_calls, any_order=False) + + async def test_post_lease_hook_execution_on_completion(self, mock_device_factory, hook_context): + """Test that post-lease hook executes when called directly.""" + hook_config = HookConfigV1Alpha1( + after_lease=HookInstanceConfigV1Alpha1(script="echo 'Post-lease cleanup completed'", timeout=10), + ) + + with patch("jumpstarter.exporter.hooks.Session") as mock_session_class: + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + + mock_session.serve_unix_async.return_value.__aenter__ = AsyncMock(return_value="/tmp/test_socket") + mock_session.serve_unix_async.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_process = AsyncMock() + mock_process.returncode = 0 + # Mock post-lease hook output + mock_process.stdout.readline.side_effect = [ + b"Post-lease cleanup completed\n", + b"", # EOF + ] + mock_process.wait = AsyncMock(return_value=None) + + # Mock the logger to capture log calls + with ( + patch("jumpstarter.exporter.hooks.logger") as mock_logger, + patch("asyncio.create_subprocess_shell", return_value=mock_process), + ): + executor = HookExecutor( + config=hook_config, + device_factory=mock_device_factory, + ) + + result = await executor.execute_after_lease_hook(hook_context) + + assert result is True + + # Verify that post-lease hook output was logged + expected_calls = [ + call("Executing after-lease hook for lease %s", "test-lease-123"), + call("Executing hook: %s", "echo 'Post-lease cleanup completed'"), + call("Hook executed successfully with exit code %d", 0), + ] + mock_logger.info.assert_has_calls(expected_calls, any_order=False) + + async def test_hook_exit_code_matching_success(self, mock_device_factory, hook_context): + """Test that hook succeeds when exit code matches expected value.""" + hook_config = HookConfigV1Alpha1( + before_lease=HookInstanceConfigV1Alpha1(script="exit 0", timeout=10, exit_code=0), + ) + + with patch("jumpstarter.exporter.hooks.Session") as mock_session_class: + mock_session = Mock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.serve_unix_async.return_value.__aenter__ = AsyncMock(return_value="/tmp/test_socket") + mock_session.serve_unix_async.return_value.__aexit__ = 
AsyncMock(return_value=None)
+
+            mock_process = AsyncMock()
+            mock_process.returncode = 0
+            mock_process.stdout.readline.side_effect = [b""]
+            mock_process.wait = AsyncMock(return_value=None)
+
+            with patch("asyncio.create_subprocess_shell", return_value=mock_process):
+                executor = HookExecutor(config=hook_config, device_factory=mock_device_factory)
+                result = await executor.execute_before_lease_hook(hook_context)
+                assert result is True
+
+    async def test_hook_exit_code_matching_custom(self, mock_device_factory, hook_context):
+        """Test that hook succeeds when exit code matches custom expected value."""
+        hook_config = HookConfigV1Alpha1(
+            before_lease=HookInstanceConfigV1Alpha1(script="exit 42", timeout=10, exit_code=42),
+        )
+
+        with patch("jumpstarter.exporter.hooks.Session") as mock_session_class:
+            mock_session = Mock()
+            mock_session_class.return_value.__enter__.return_value = mock_session
+            mock_session.serve_unix_async.return_value.__aenter__ = AsyncMock(return_value="/tmp/test_socket")
+            mock_session.serve_unix_async.return_value.__aexit__ = AsyncMock(return_value=None)
+
+            mock_process = AsyncMock()
+            mock_process.returncode = 42
+            mock_process.stdout.readline.side_effect = [b""]
+            mock_process.wait = AsyncMock(return_value=None)
+
+            with patch("asyncio.create_subprocess_shell", return_value=mock_process):
+                executor = HookExecutor(config=hook_config, device_factory=mock_device_factory)
+                result = await executor.execute_before_lease_hook(hook_context)
+                assert result is True
+
+    async def test_hook_exit_code_mismatch_pass(self, mock_device_factory, hook_context):
+        """Test that hook succeeds when exit code mismatches but on_failure='pass'."""
+        hook_config = HookConfigV1Alpha1(
+            before_lease=HookInstanceConfigV1Alpha1(script="exit 1", timeout=10, exit_code=0, on_failure="pass"),
+        )
+
+        with patch("jumpstarter.exporter.hooks.Session") as mock_session_class:
+            mock_session = Mock()
+            mock_session_class.return_value.__enter__.return_value = mock_session
+            mock_session.serve_unix_async.return_value.__aenter__ = AsyncMock(return_value="/tmp/test_socket")
+            mock_session.serve_unix_async.return_value.__aexit__ = AsyncMock(return_value=None)
+
+            mock_process = AsyncMock()
+            mock_process.returncode = 1
+            mock_process.stdout.readline.side_effect = [b""]
+            mock_process.wait = AsyncMock(return_value=None)
+
+            with (
+                patch("asyncio.create_subprocess_shell", return_value=mock_process),
+                patch("jumpstarter.exporter.hooks.logger") as mock_logger,
+            ):
+                executor = HookExecutor(config=hook_config, device_factory=mock_device_factory)
+                result = await executor.execute_before_lease_hook(hook_context)
+                assert result is True
+                # Verify INFO log was created (using format string)
+                mock_logger.info.assert_any_call(
+                    "%s (on_failure=pass, continuing)", "Hook failed: expected exit code 0, got 1"
+                )
+
+    async def test_hook_exit_code_mismatch_warn(self, mock_device_factory, hook_context):
+        """Test that hook succeeds when exit code mismatches but on_failure='warn'."""
+        hook_config = HookConfigV1Alpha1(
+            before_lease=HookInstanceConfigV1Alpha1(script="exit 1", timeout=10, exit_code=0, on_failure="warn"),
+        )
+
+        with patch("jumpstarter.exporter.hooks.Session") as mock_session_class:
+            mock_session = Mock()
+            mock_session_class.return_value.__enter__.return_value = mock_session
+            mock_session.serve_unix_async.return_value.__aenter__ = AsyncMock(return_value="/tmp/test_socket")
+            mock_session.serve_unix_async.return_value.__aexit__ = AsyncMock(return_value=None)
+
+            mock_process = AsyncMock()
+            mock_process.returncode = 1
+            mock_process.stdout.readline.side_effect = [b""]
+            mock_process.wait = AsyncMock(return_value=None)
+
+            with (
+                patch("asyncio.create_subprocess_shell", return_value=mock_process),
+                patch("jumpstarter.exporter.hooks.logger") as mock_logger,
+            ):
+                executor = HookExecutor(config=hook_config, device_factory=mock_device_factory)
+                result = await executor.execute_before_lease_hook(hook_context)
+                assert result is True
+                # Verify WARNING log was created (using format string)
+                mock_logger.warning.assert_any_call(
+                    "%s (on_failure=warn, continuing)", "Hook failed: expected exit code 0, got 1"
+                )
+
+    async def test_hook_exit_code_mismatch_block(self, mock_device_factory, hook_context):
+        """Test that hook raises exception when exit code mismatches and on_failure='block'."""
+        hook_config = HookConfigV1Alpha1(
+            before_lease=HookInstanceConfigV1Alpha1(script="exit 1", timeout=10, exit_code=0, on_failure="block"),
+        )
+
+        with patch("jumpstarter.exporter.hooks.Session") as mock_session_class:
+            mock_session = Mock()
+            mock_session_class.return_value.__enter__.return_value = mock_session
+            mock_session.serve_unix_async.return_value.__aenter__ = AsyncMock(return_value="/tmp/test_socket")
+            mock_session.serve_unix_async.return_value.__aexit__ = AsyncMock(return_value=None)
+
+            mock_process = AsyncMock()
+            mock_process.returncode = 1
+            mock_process.stdout.readline.side_effect = [b""]
+            mock_process.wait = AsyncMock(return_value=None)
+
+            with patch("asyncio.create_subprocess_shell", return_value=mock_process):
+                executor = HookExecutor(config=hook_config, device_factory=mock_device_factory)
+                with pytest.raises(HookExecutionError, match="expected exit code 0, got 1"):
+                    await executor.execute_before_lease_hook(hook_context)
+
+    async def test_hook_timeout_with_pass(self, mock_device_factory, hook_context):
+        """Test that hook succeeds when timeout occurs but on_failure='pass'."""
+        hook_config = HookConfigV1Alpha1(
+            before_lease=HookInstanceConfigV1Alpha1(script="sleep 60", timeout=1, on_failure="pass"),
+        )
+
+        with patch("jumpstarter.exporter.hooks.Session") as mock_session_class:
+            mock_session = Mock()
+            mock_session_class.return_value.__enter__.return_value = mock_session
+            mock_session.serve_unix_async.return_value.__aenter__ = AsyncMock(return_value="/tmp/test_socket")
+            mock_session.serve_unix_async.return_value.__aexit__ = AsyncMock(return_value=None)
+
+            mock_process = AsyncMock()
+            mock_process.terminate = AsyncMock(return_value=None)
+            mock_process.wait = AsyncMock(return_value=None)
+
+            with (
+                patch("asyncio.create_subprocess_shell", return_value=mock_process),
+                patch("asyncio.wait_for", side_effect=asyncio.TimeoutError()),
+                patch("jumpstarter.exporter.hooks.logger") as mock_logger,
+            ):
+                executor = HookExecutor(config=hook_config, device_factory=mock_device_factory)
+                result = await executor.execute_before_lease_hook(hook_context)
+                assert result is True
+                # Verify INFO log was created
+                assert any("on_failure=pass, continuing" in str(call) for call in mock_logger.info.call_args_list)
+
+    async def test_hook_timeout_with_warn(self, mock_device_factory, hook_context):
+        """Test that hook succeeds when timeout occurs but on_failure='warn'."""
+        hook_config = HookConfigV1Alpha1(
+            before_lease=HookInstanceConfigV1Alpha1(script="sleep 60", timeout=1, on_failure="warn"),
+        )
+
+        with patch("jumpstarter.exporter.hooks.Session") as mock_session_class:
+            mock_session = Mock()
+            mock_session_class.return_value.__enter__.return_value = mock_session
+            mock_session.serve_unix_async.return_value.__aenter__ = AsyncMock(return_value="/tmp/test_socket")
+            mock_session.serve_unix_async.return_value.__aexit__ = AsyncMock(return_value=None)
+
+            mock_process = AsyncMock()
+            mock_process.terminate = AsyncMock(return_value=None)
+            mock_process.wait = AsyncMock(return_value=None)
+
+            with (
+                patch("asyncio.create_subprocess_shell", return_value=mock_process),
+                patch("asyncio.wait_for", side_effect=asyncio.TimeoutError()),
+                patch("jumpstarter.exporter.hooks.logger") as mock_logger,
+            ):
+                executor = HookExecutor(config=hook_config, device_factory=mock_device_factory)
+                result = await executor.execute_before_lease_hook(hook_context)
+                assert result is True
+                # Verify WARNING log was created
+                assert any("on_failure=warn, continuing" in str(call) for call in mock_logger.warning.call_args_list)
diff --git a/packages/jumpstarter/jumpstarter/exporter/logging.py b/packages/jumpstarter/jumpstarter/exporter/logging.py
index 629306c2..6a6e8dad 100644
--- a/packages/jumpstarter/jumpstarter/exporter/logging.py
+++ b/packages/jumpstarter/jumpstarter/exporter/logging.py
@@ -1,23 +1,53 @@
 import logging
 from collections import deque
+from contextlib import contextmanager
+from threading import RLock
 
 from jumpstarter_protocol import jumpstarter_pb2
 
+from .logging_protocol import LoggerRegistration
+
+from jumpstarter.common import LogSource
+
 
 class LogHandler(logging.Handler):
-    def __init__(self, queue: deque):
+    def __init__(self, queue: deque, source: LogSource = LogSource.UNSPECIFIED):
         logging.Handler.__init__(self)
         self.queue = queue
         self.listener = None
+        self.source = source  # LogSource enum value
+        self._lock = RLock()
+        self._child_handlers = {}  # Dict of logger_name -> LogSource mappings
+
+    def add_child_handler(self, logger_name: str, source: LogSource):
+        """Add a child handler that will route logs from a specific logger with a different source."""
+        with self._lock:
+            self._child_handlers[logger_name] = source
+
+    def remove_child_handler(self, logger_name: str):
+        """Remove a child handler mapping."""
+        with self._lock:
+            self._child_handlers.pop(logger_name, None)
+
+    def get_source_for_record(self, record):
+        """Determine the appropriate log source for a record."""
+        with self._lock:
+            # Check if this record comes from a logger with a specific source mapping
+            logger_name = record.name
+            for mapped_logger, source in self._child_handlers.items():
+                if logger_name.startswith(mapped_logger):
+                    return source
+            return self.source
 
     def enqueue(self, record):
         self.queue.append(record)
 
     def prepare(self, record):
+        source = self.get_source_for_record(record)
         return jumpstarter_pb2.LogStreamResponse(
             uuid="",
             severity=record.levelname,
             message=self.format(record),
+            source=source.value,  # Convert to proto value
         )
 
     def emit(self, record):
@@ -25,3 +55,35 @@ def emit(self, record):
             self.enqueue(self.prepare(record))
         except Exception:
             self.handleError(record)
+
+    @contextmanager
+    def context_log_source(self, logger_name: str, source: LogSource):
+        """Context manager to temporarily set a log source for a specific logger."""
+        self.add_child_handler(logger_name, source)
+        try:
+            yield
+        finally:
+            self.remove_child_handler(logger_name)
+
+
+def get_logger(
+    name: str, source: LogSource = LogSource.SYSTEM, session: LoggerRegistration | None = None
+) -> logging.Logger:
+    """
+    Get a logger with automatic LogSource mapping.
+
+    Args:
+        name: Logger name (e.g., __name__ or custom name)
+        source: The LogSource to associate with this logger
+        session: Optional session to register with immediately
+
+    Returns:
+        A standard Python logger instance
+    """
+    logger = logging.getLogger(name)
+
+    # If session provided, register the source mapping
+    if session:
+        session.add_logger_source(name, source)
+
+    return logger
diff --git a/packages/jumpstarter/jumpstarter/exporter/logging_protocol.py b/packages/jumpstarter/jumpstarter/exporter/logging_protocol.py
new file mode 100644
index 00000000..04ed885f
--- /dev/null
+++ b/packages/jumpstarter/jumpstarter/exporter/logging_protocol.py
@@ -0,0 +1,22 @@
+"""Protocol for logger registration to avoid circular dependencies."""
+
+from typing import Protocol
+
+from jumpstarter.common import LogSource
+
+
+class LoggerRegistration(Protocol):
+    """Protocol for objects that can register logger sources.
+
+    This protocol defines the interface for objects that can associate
+    logger names with log sources, enabling proper routing of log messages.
+    """
+
+    def add_logger_source(self, logger_name: str, source: LogSource) -> None:
+        """Register a logger name with its corresponding log source.
+
+        Args:
+            logger_name: Name of the logger to register
+            source: The log source category for this logger
+        """
+        ...
diff --git a/packages/jumpstarter/jumpstarter/exporter/session.py b/packages/jumpstarter/jumpstarter/exporter/session.py
index 63ae2f08..13d1a462 100644
--- a/packages/jumpstarter/jumpstarter/exporter/session.py
+++ b/packages/jumpstarter/jumpstarter/exporter/session.py
@@ -17,7 +17,7 @@
 )
 
 from .logging import LogHandler
-from jumpstarter.common import Metadata, TemporarySocket
+from jumpstarter.common import ExporterStatus, LogSource, Metadata, TemporarySocket
 from jumpstarter.common.streams import StreamRequestMetadata
 from jumpstarter.driver import Driver
 from jumpstarter.streams.common import forward_stream
@@ -39,6 +39,9 @@ class Session(
     _logging_queue: deque = field(init=False)
     _logging_handler: QueueHandler = field(init=False)
+    _current_status: ExporterStatus = field(init=False, default=ExporterStatus.AVAILABLE)
+    _status_message: str = field(init=False, default="")
+    _status_update_event: Event = field(init=False)
 
     @contextmanager
     def __contextmanager__(self) -> Generator[Self]:
@@ -67,7 +70,11 @@ def __init__(self, *args, root_device, **kwargs):
         self.mapping = {u: i for (u, _, _, i) in self.root_device.enumerate()}
 
         self._logging_queue = deque(maxlen=32)
-        self._logging_handler = LogHandler(self._logging_queue)
+        self._logging_handler = LogHandler(self._logging_queue, LogSource.SYSTEM)
+        self._status_update_event = Event()
+
+        # Map all driver logs to DRIVER source
+        self._logging_handler.add_child_handler("driver.", LogSource.DRIVER)
 
     @asynccontextmanager
     async def serve_port_async(self, port):
@@ -139,3 +146,31 @@ async def LogStream(self, request, context):
                 yield self._logging_queue.popleft()
             except IndexError:
                 await sleep(0.5)
+
+    def update_status(self, status: int | ExporterStatus, message: str = ""):
+        """Update the current exporter status for the session."""
+        if isinstance(status, int):
+            self._current_status = ExporterStatus.from_proto(status)
+        else:
+            self._current_status = status
+        self._status_message = message
+
+    def add_logger_source(self, logger_name: str, source: LogSource):
+        """Add a log source mapping for a specific logger."""
+        self._logging_handler.add_child_handler(logger_name, source)
+
+    def remove_logger_source(self, logger_name: str):
+        """Remove a log source mapping for a specific logger."""
+        self._logging_handler.remove_child_handler(logger_name)
+
+    def context_log_source(self, logger_name: str, source: LogSource):
+        """Context manager to temporarily set a log source for a specific logger."""
+        return self._logging_handler.context_log_source(logger_name, source)
+
+    async def GetStatus(self, request, context):
+        """Get the current exporter status."""
+        logger.debug("GetStatus() -> %s", self._current_status)
+        return jumpstarter_pb2.GetStatusResponse(
+            status=self._current_status.to_proto(),
+            status_message=self._status_message,
+        )