diff --git a/.ruff.toml b/.ruff.toml index a025d6ce4..60cc75877 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -5,4 +5,4 @@ ignore = ["ALL"] select = [ "F", # PyFlakes "I" # Isort -] +] \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index b26590997..bff417e18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ drunc-fsm-tests = "drunc.tests.fsm:main" application-registry-service = "drunc.apps.app_connectivity_server:main" drunc-ssh-validator = "drunc.apps.ssh_validator:main" drunc-ssh-doctor = "drunc.apps.ssh_doctor:main" +drunc-process-wrapper = "drunc.apps.process_wrapper:main" [tool.setuptools.packages.find] where = ["src"] @@ -69,6 +70,10 @@ addopts = "-v --tb=short --cov=drunc --cov=src/drunc tests/" source = ["drunc"] omit = ["tests/*"] +[tool.ruff] +exclude = ["tests/"] + + # * See https://docs.astral.sh/ruff/rules/ for details on Ruff's linting options [tool.ruff.lint] select = [ diff --git a/src/drunc/apps/process_wrapper.py b/src/drunc/apps/process_wrapper.py new file mode 100644 index 000000000..3624947a9 --- /dev/null +++ b/src/drunc/apps/process_wrapper.py @@ -0,0 +1,49 @@ +import os +import signal +import subprocess +import time + +import click + +from drunc.process_manager.subprocess_process_manager import on_parent_exit + + +def terminate_all(sig, frame): + pgrp = os.getpgid(os.getpid()) + os.killpg(pgrp, signal.SIGKILL) + + +@click.command() +@click.argument("cmd") +@click.option( + "-l", + "--log", + "log_path", + type=click.Path(file_okay=True, dir_okay=False), + required=True, +) +def main(cmd: str, log_path: str): + signal.signal(signal.SIGTERM, terminate_all) + + with open(log_path, "w") as logfile: + proc = subprocess.Popen( + cmd, + shell=True, + stdout=logfile, + stderr=logfile, + preexec_fn=on_parent_exit( + signal.SIGTERM, # Propagate SIGTERM to child processes, SIGKILL doesn't seem to kill gunicorn... 
+ setsid=False, # Don't create a new session, so that the process group can be killed + ), + ) + + return_code = None + while True: + return_code = proc.poll() + if return_code is not None: + return return_code + time.sleep(0.1) + + +if __name__ == "__main__": + main() diff --git a/src/drunc/broadcast/server/broadcast_sender.py b/src/drunc/broadcast/server/broadcast_sender.py index f138a3f5a..2ee371e5a 100644 --- a/src/drunc/broadcast/server/broadcast_sender.py +++ b/src/drunc/broadcast/server/broadcast_sender.py @@ -60,8 +60,7 @@ def __init__( def describe_broadcast(self): if self.implementation: return self.implementation.describe_broadcast() - else: - return None + return None def can_broadcast(self): if not self.implementation: diff --git a/src/drunc/broadcast/utils.py b/src/drunc/broadcast/utils.py index 1492996ef..ee8df8446 100644 --- a/src/drunc/broadcast/utils.py +++ b/src/drunc/broadcast/utils.py @@ -26,5 +26,4 @@ def get_broadcast_level_from_broadcast_type( bt = BroadcastType.Name(btype) if bt not in levels: return logger.info - else: - return getattr(logger, levels[bt].lower()) + return getattr(logger, levels[bt].lower()) diff --git a/src/drunc/connectivity_service/client.py b/src/drunc/connectivity_service/client.py index 368fb6de7..0c707e6af 100644 --- a/src/drunc/connectivity_service/client.py +++ b/src/drunc/connectivity_service/client.py @@ -194,11 +194,10 @@ def resolve(self, uid_regex: str, data_type: str, ntries=50) -> dict: content = response.json() if content: return content - else: - self.log.debug( - f"Could not find the address of '{uid_regex}' on the application registry" - ) - time.sleep(0.2) + self.log.debug( + f"Could not find the address of '{uid_regex}' on the application registry" + ) + time.sleep(0.2) except (HTTPError, ConnectionError, ReadTimeout) as e: self.log.debug(e) diff --git a/src/drunc/controller/children_interface/grpc_child.py b/src/drunc/controller/children_interface/grpc_child.py index 7de703afa..20c15ddb9 100644 --- 
a/src/drunc/controller/children_interface/grpc_child.py +++ b/src/drunc/controller/children_interface/grpc_child.py @@ -22,9 +22,7 @@ from drunc.broadcast.client.broadcast_handler import BroadcastHandler from drunc.broadcast.client.configuration import BroadcastClientConfHandler -from drunc.connectivity_service.exceptions import ( - ApplicationLookupUnsuccessful, -) +from drunc.connectivity_service.exceptions import ApplicationLookupUnsuccessful from drunc.controller.children_interface.child_node import ChildNode from drunc.exceptions import DruncSetupException from drunc.utils.configuration import ConfHandler, ConfTypes @@ -116,10 +114,8 @@ def _setup_connection(self): if tries_remaining == 0: raise server_unreachable_error self.log.info( - ( - f"Could not connect to the controller ({self.uri}). " - f"Trying {tries_remaining} more times..." - ) + f"Could not connect to the controller ({self.uri}). " + f"Trying {tries_remaining} more times..." ) time.sleep(5) diff --git a/src/drunc/controller/children_interface/rest_api_child.py b/src/drunc/controller/children_interface/rest_api_child.py index 46ce43b14..72a813baa 100644 --- a/src/drunc/controller/children_interface/rest_api_child.py +++ b/src/drunc/controller/children_interface/rest_api_child.py @@ -294,8 +294,7 @@ def send_app_command( raise CouldnotSendCommand( f"Connection error to {self.app_url}" ) from e - else: - self.log.error("Trying again...") + self.log.error("Trying again...") self.log.debug(f"Ack to {self.app}: {ack.status_code}") self.sent_cmd = cmd_id @@ -326,13 +325,12 @@ def check_response(self, timeout: int = 0) -> dict: raise NoResponse( f"No response available from {self.app} for command {self.sent_cmd}" ) - else: - self.log.error( - f"Timeout while waiting for a reply from {self.app} for command {self.sent_cmd}" - ) - raise ResponseTimeout( - f"Timeout while waiting for a reply from {self.app} for command {self.sent_cmd}" - ) + self.log.error( + f"Timeout while waiting for a reply from 
{self.app} for command {self.sent_cmd}" + ) + raise ResponseTimeout( + f"Timeout while waiting for a reply from {self.app} for command {self.sent_cmd}" + ) return r diff --git a/src/drunc/controller/controller.py b/src/drunc/controller/controller.py index a22e5a57a..c816f6485 100644 --- a/src/drunc/controller/controller.py +++ b/src/drunc/controller/controller.py @@ -3,9 +3,10 @@ import threading import time import traceback +from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor, as_completed from functools import wraps -from typing import Callable, List, TypeVar +from typing import TypeVar from druncschema.authoriser_pb2 import ActionType, SystemType from druncschema.broadcast_pb2 import BroadcastType @@ -51,10 +52,7 @@ from drunc.exceptions import DruncCommandException, DruncException from drunc.fsm.actions.utils import get_dotdrunc_json from drunc.fsm.configuration import FSMConfHandler -from drunc.fsm.exceptions import ( - DotDruncJsonIncorrectFormat, - DotDruncJsonNotFound, -) +from drunc.fsm.exceptions import DotDruncJsonIncorrectFormat, DotDruncJsonNotFound from drunc.fsm.utils import convert_fsm_transition from drunc.utils.grpc_utils import UnpackingError, pack_to_any, unpack_any from drunc.utils.utils import get_logger @@ -208,7 +206,7 @@ def wrap(obj, request, context): class Controller(ControllerServicer): - children_nodes: List[ChildNode] = [] + children_nodes: list[ChildNode] = [] def __init__(self, configuration, name: str, session: str, token: Token): super().__init__() @@ -629,7 +627,7 @@ def propagate_to_child( ) self.log.error( - f"Failed to propagate {command_name} to {child.name} ({child.name}) EXCEPTION THROWN: {str(e)}" + f"Failed to propagate {command_name} to {child.name} ({child.name}) EXCEPTION THROWN: {e!s}" ) threads = [] diff --git a/src/drunc/controller/controller_actor.py b/src/drunc/controller/controller_actor.py index 3d5f1affe..643916c66 100644 --- a/src/drunc/controller/controller_actor.py +++ 
b/src/drunc/controller/controller_actor.py @@ -1,5 +1,4 @@ import threading -from typing import Optional from druncschema.token_pb2 import Token @@ -8,7 +7,7 @@ class ControllerActor: - def __init__(self, token: Optional[Token] = None): + def __init__(self, token: Token | None = None): self.log = get_logger("controller.actor") self._token = Token(token="", user_name="") if token is not None: @@ -21,7 +20,7 @@ def get_token(self) -> Token: def get_user_name(self) -> str: return self._token.user_name - def _update_actor(self, token: Optional[Token] = None) -> None: + def _update_actor(self, token: Token | None = None) -> None: self._lock.acquire() self._token = Token(token="", user_name="") if token is not None: diff --git a/src/drunc/controller/controller_driver.py b/src/drunc/controller/controller_driver.py index e31b0a544..76b53ae52 100644 --- a/src/drunc/controller/controller_driver.py +++ b/src/drunc/controller/controller_driver.py @@ -20,11 +20,7 @@ from druncschema.token_pb2 import Token from drunc.exceptions import DruncServerSideError -from drunc.utils.grpc_utils import ( - UnpackingError, - handle_grpc_error, - unpack_any, -) +from drunc.utils.grpc_utils import UnpackingError, handle_grpc_error, unpack_any from drunc.utils.shell_utils import DecodedResponse from drunc.utils.utils import get_logger @@ -297,7 +293,7 @@ def text(verb="not executed", reason=""): elif response.data.Is(PlainText.DESCRIPTOR): txt = unpack_any(response.data, PlainText) - error_txt = txt.text # noqa: F841 (might need to revisit this) + error_txt = txt.text dr.data = error_txt if response.flag in [ diff --git a/src/drunc/controller/interface/commands.py b/src/drunc/controller/interface/commands.py index bbedd55f3..98a4831e9 100644 --- a/src/drunc/controller/interface/commands.py +++ b/src/drunc/controller/interface/commands.py @@ -362,7 +362,7 @@ def expert_command( if string: data = json.loads(command) else: - with open(command, "r") as f: + with open(command) as f: data = 
json.load(f) except FileNotFoundError: diff --git a/src/drunc/controller/interface/context.py b/src/drunc/controller/interface/context.py index 082c42a00..70caadad3 100644 --- a/src/drunc/controller/interface/context.py +++ b/src/drunc/controller/interface/context.py @@ -6,10 +6,7 @@ from drunc.broadcast.client.configuration import BroadcastClientConfHandler from drunc.controller.controller_driver import ControllerDriver from drunc.utils.configuration import ConfTypes -from drunc.utils.shell_utils import ( - ShellContext, - create_dummy_token_from_uname, -) +from drunc.utils.shell_utils import ShellContext, create_dummy_token_from_uname from drunc.utils.utils import resolve_localhost_to_hostname diff --git a/src/drunc/controller/interface/shell_utils.py b/src/drunc/controller/interface/shell_utils.py index 4a453812e..f3ee4072b 100644 --- a/src/drunc/controller/interface/shell_utils.py +++ b/src/drunc/controller/interface/shell_utils.py @@ -297,7 +297,7 @@ def controller_setup(ctx, controller_address): if state == "initialising": log.error("Controller did not initialise in time") - return + return None log.debug(f"Taking control of the controller as {ctx.get_token()}") try: @@ -375,12 +375,11 @@ def tree_prefix(i, n): last = "└── " if i == 0 and n == 1: return first_one - elif i == 0: + if i == 0: return first_many - elif i == n - 1: + if i == n - 1: return last - else: - return next + return next def validate_and_format_fsm_arguments( diff --git a/src/drunc/controller/stateful_node.py b/src/drunc/controller/stateful_node.py index e877ecd7d..05a873303 100644 --- a/src/drunc/controller/stateful_node.py +++ b/src/drunc/controller/stateful_node.py @@ -1,7 +1,8 @@ from __future__ import annotations import abc -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any from druncschema.opmon.FSM_pb2 import FSMStatus @@ -32,9 +33,7 @@ def value(self, value): self.stateful_node.log.info(f"{self._name} changed to {value}") 
self.stateful_node.publish_state() - def __init__( - self, name: str, stateful_node=None, initial_value: Optional[str] = None - ): + def __init__(self, name: str, stateful_node=None, initial_value: str | None = None): self._name = name self.stateful_node = stateful_node self._value = initial_value @@ -89,7 +88,7 @@ class StatefulNode(abc.ABC): def __init__( self, fsm_configuration, - publisher: Optional[Callable[[Any], None]] = None, + publisher: Callable[[Any], None] | None = None, init_state: str = "", session: str = "", name: str = "", diff --git a/src/drunc/data/process_manager/schema/process_manager.schema.json b/src/drunc/data/process_manager/schema/process_manager.schema.json index ce8445233..b321cbb9b 100644 --- a/src/drunc/data/process_manager/schema/process_manager.schema.json +++ b/src/drunc/data/process_manager/schema/process_manager.schema.json @@ -7,76 +7,156 @@ { "type": "object", "properties": { - "type": {"type": "string", "enum": ["ssh", "k8s"]}, - "name": {"type": "string"}, + "type": { + "type": "string", + "enum": [ + "ssh", + "k8s", + "subprocess" + ] + }, + "name": { + "type": "string" + }, "authoriser": { "type": "object", - "properties": {"type": {"type": "string"}}, - "required": ["type"], + "properties": { + "type": { + "type": "string" + } + }, + "required": [ + "type" + ], "additionalProperties": true }, - "environment": {"type": "object"}, + "environment": { + "type": "object" + }, "opmon_uri": { "type": "object", "properties": { - "path": {"type": "string"}, - "type": {"type": "string", "enum": ["file", "stream"]} + "path": { + "type": "string" + }, + "type": { + "type": "string", + "enum": [ + "file", + "stream" + ] + } }, - "required": ["path", "type"], + "required": [ + "path", + "type" + ], "additionalProperties": true }, "opmon_conf": { "type": "object", "properties": { - "level": {"type": "string"}, - "interval_s": {"type": ["number", "integer"]} + "level": { + "type": "string" + }, + "interval_s": { + "type": [ + "number", + 
"integer" + ] + } }, "additionalProperties": true }, "broadcaster": { "type": "object", "properties": { - "type": {"type": "string"}, - "kafka_address": {"type": "string"}, - "publish_timeout": {"type": ["integer", "number"]} + "type": { + "type": "string" + }, + "kafka_address": { + "type": "string" + }, + "publish_timeout": { + "type": [ + "integer", + "number" + ] + } }, - "required": ["type", "kafka_address", "publish_timeout"], + "required": [ + "type", + "kafka_address", + "publish_timeout" + ], "additionalProperties": true }, - "command_address": {"type": "string"} + "command_address": { + "type": "string" + } }, - "required": ["type", "name"], + "required": [ + "type", + "name" + ], "additionalProperties": true }, { - "if": {"properties": {"type": {"const": "ssh"}}}, + "if": { + "properties": { + "type": { + "const": "ssh" + } + } + }, "then": { "type": "object", "properties": { "settings": { "type": "object", "properties": { - "disable_localhost_host_key_check": {"type": "boolean"}, - "disable_host_key_check": {"type": "boolean"} + "disable_localhost_host_key_check": { + "type": "boolean" + }, + "disable_host_key_check": { + "type": "boolean" + } }, "additionalProperties": true }, - "command_address": {"type": "string"}, - "kill_timeout": {"type": ["number", "integer"]} + "command_address": { + "type": "string" + }, + "kill_timeout": { + "type": [ + "number", + "integer" + ] + } }, "additionalProperties": true } }, { - "if": {"properties": {"type": {"const": "k8s"}}}, + "if": { + "properties": { + "type": { + "const": "k8s" + } + } + }, "then": { "type": "object", "properties": { - "command_address": {"type": "string"}, - "image": {"type": "string"} + "command_address": { + "type": "string" + }, + "image": { + "type": "string" + } }, "additionalProperties": true } } ] -} +} \ No newline at end of file diff --git a/src/drunc/data/process_manager/subprocess-standalone.json b/src/drunc/data/process_manager/subprocess-standalone.json new file mode 100644 index 
000000000..5ff7e8f0e --- /dev/null +++ b/src/drunc/data/process_manager/subprocess-standalone.json @@ -0,0 +1,19 @@ +{ + "type": "subprocess", + "name": "SubProcessProcessManager", + "kill_timeout": 0.5, + "authoriser": { + "type": "dummy" + }, + "environment": { + "GRPC_ENABLE_FORK_SUPPORT": "false" + }, + "opmon_uri": { + "path": "./info.json", + "type": "file" + }, + "opmon_conf": { + "level": "info", + "interval_s": 10.0 + } +} \ No newline at end of file diff --git a/src/drunc/fsm/actions/file_logbook.py b/src/drunc/fsm/actions/file_logbook.py index b5cd530a9..ecdb85a0e 100644 --- a/src/drunc/fsm/actions/file_logbook.py +++ b/src/drunc/fsm/actions/file_logbook.py @@ -1,5 +1,3 @@ -from typing import Optional - from drunc.fsm.core import FSMAction from drunc.utils.utils import now_str @@ -11,7 +9,7 @@ def __init__(self, configuration): self.file = self.conf_dict["file_name"] def post_start( - self, _input_data, _context, file_logbook_post: Optional[str] = None, **kwargs + self, _input_data, _context, file_logbook_post: str | None = None, **kwargs ): with open(self.file, "a") as f: f.write( diff --git a/src/drunc/fsm/actions/user_provided_run_number.py b/src/drunc/fsm/actions/user_provided_run_number.py index ea97dec5f..83789749f 100644 --- a/src/drunc/fsm/actions/user_provided_run_number.py +++ b/src/drunc/fsm/actions/user_provided_run_number.py @@ -1,5 +1,4 @@ import time -from typing import Optional from drunc.fsm.actions.utils import validate_run_type from drunc.fsm.core import FSMAction @@ -14,9 +13,9 @@ def pre_start( _input_data: dict, _context, run_number: int, - run_type: Optional[str] = "TEST", + run_type: str | None = "TEST", disable_data_storage: bool = False, - trigger_rate: Optional[float] = None, + trigger_rate: float | None = None, **kwargs, ): run_type = validate_run_type(run_type.upper()) diff --git a/src/drunc/fsm/actions/usvc_elisa_logbook.py b/src/drunc/fsm/actions/usvc_elisa_logbook.py index d67c3fd38..e959535c6 100644 --- 
a/src/drunc/fsm/actions/usvc_elisa_logbook.py +++ b/src/drunc/fsm/actions/usvc_elisa_logbook.py @@ -1,6 +1,5 @@ import json import os -from typing import Optional import requests @@ -83,10 +82,10 @@ def __init__(self): self.timeout = 5 def post_start( - self, _input_data: dict, _context, elisa_post: Optional[str] = None, **kwargs + self, _input_data: dict, _context, elisa_post: str | None = None, **kwargs ): if self.elisa_hardware in self.no_publish_hardware: - return + return None text = "" self.thread_id = None # Clear this value here, so that if it fails stop can't reply to an old message @@ -142,10 +141,10 @@ def post_start( return _input_data def post_drain_dataflow( - self, _input_data, _context, elisa_post: Optional[str] = None, **kwargs + self, _input_data, _context, elisa_post: str | None = None, **kwargs ): if self.elisa_hardware in self.no_publish_hardware: - return + return None text = "" if elisa_post is not None: self.log.info( diff --git a/src/drunc/fsm/actions/usvc_provided_run_number.py b/src/drunc/fsm/actions/usvc_provided_run_number.py index 000c96957..62bc256eb 100644 --- a/src/drunc/fsm/actions/usvc_provided_run_number.py +++ b/src/drunc/fsm/actions/usvc_provided_run_number.py @@ -1,5 +1,4 @@ import time -from typing import Optional import requests @@ -31,7 +30,7 @@ def pre_start( _context, run_type: str, disable_data_storage: bool = False, - trigger_rate: Optional[float] = None, + trigger_rate: float | None = None, **kwargs, ): run_type = validate_run_type(run_type.upper()) diff --git a/src/drunc/fsm/core.py b/src/drunc/fsm/core.py index 2b6b1c081..bd1bfd823 100644 --- a/src/drunc/fsm/core.py +++ b/src/drunc/fsm/core.py @@ -229,8 +229,7 @@ def get_destination_state(self, source_state, transition) -> str: if self.can_execute_transition(source_state, transition): if tr.destination == "": return source_state - else: - return tr.destination + return tr.destination def get_executable_transitions(self, source_state) -> list[Transition]: 
valid_transitions = [] diff --git a/src/drunc/process_manager/configuration.py b/src/drunc/process_manager/configuration.py index 6c6a46172..77fb6b0e2 100644 --- a/src/drunc/process_manager/configuration.py +++ b/src/drunc/process_manager/configuration.py @@ -3,7 +3,7 @@ import sys from enum import Enum from importlib import resources -from typing import Any, Dict, Union +from typing import Any from urllib.parse import unquote, urlparse from jsonschema import ValidationError @@ -23,6 +23,7 @@ class ProcessManagerTypes(Enum): Unknown = 0 SSH = 1 K8s = 2 + SubProcess = 3 class ProcessManagerConfData: @@ -59,6 +60,9 @@ def _parse_dict(self, data): case "ssh": new_data.type = ProcessManagerTypes.SSH new_data.kill_timeout = data.get("kill_timeout", 0.5) + case "subprocess": + new_data.type = ProcessManagerTypes.SubProcess + new_data.kill_timeout = data.get("kill_timeout", 0.5) case "k8s": new_data.type = ProcessManagerTypes.K8s new_data.image = data.get("image", "ghcr.io/dune-daq/alma9:latest") @@ -168,7 +172,7 @@ def get_process_manager_configuration(process_manager_conf_filename: str) -> str return process_manager_conf_filename -def _load_pm_schema_from_package() -> Dict[str, Any]: +def _load_pm_schema_from_package() -> dict[str, Any]: """Load JSON Schema from packaged file; raise if missing or unreadable.""" try: # Package path for schema JSON: drunc/data/process_manager/schema/process_manager.schema.json @@ -188,7 +192,7 @@ def _load_pm_schema_from_package() -> Dict[str, Any]: ) -def _load_config_from_source(source: Union[str, Dict[str, Any]]) -> Dict[str, Any]: +def _load_config_from_source(source: str | dict[str, Any]) -> dict[str, Any]: """ Accepts: - file URLs (file:///...), @@ -207,7 +211,7 @@ def _load_config_from_source(source: Union[str, Dict[str, Any]]) -> Dict[str, An u = urlparse(source) if u.scheme == "file": path = unquote(u.path) - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: return json.load(f) raise 
FileNotFoundError(f"Unsupported URL scheme: {u.scheme}") @@ -215,7 +219,7 @@ def _load_config_from_source(source: Union[str, Dict[str, Any]]) -> Dict[str, An resolved_url = get_process_manager_configuration(source) u = urlparse(resolved_url) path = unquote(u.path) - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: return json.load(f) # Already a dict @@ -225,7 +229,7 @@ def _load_config_from_source(source: Union[str, Dict[str, Any]]) -> Dict[str, An raise TypeError("validate_config() expects dict, path, URL, or raw JSON text") -def validate_pm_config(config_or_source: Union[str, Dict[str, Any]]) -> bool: +def validate_pm_config(config_or_source: str | dict[str, Any]) -> bool: try: pm_conf = _load_config_from_source(config_or_source) schema = _load_pm_schema_from_package() diff --git a/src/drunc/process_manager/interface/context.py b/src/drunc/process_manager/interface/context.py index a81396130..615e41487 100644 --- a/src/drunc/process_manager/interface/context.py +++ b/src/drunc/process_manager/interface/context.py @@ -6,10 +6,7 @@ from drunc.broadcast.client.configuration import BroadcastClientConfHandler from drunc.process_manager.process_manager_driver import ProcessManagerDriver from drunc.utils.configuration import ConfTypes -from drunc.utils.shell_utils import ( - ShellContext, - create_dummy_token_from_uname, -) +from drunc.utils.shell_utils import ShellContext, create_dummy_token_from_uname from drunc.utils.utils import get_logger, resolve_localhost_to_hostname diff --git a/src/drunc/process_manager/k8s_process_manager.py b/src/drunc/process_manager/k8s_process_manager.py index 584796beb..013960378 100644 --- a/src/drunc/process_manager/k8s_process_manager.py +++ b/src/drunc/process_manager/k8s_process_manager.py @@ -333,10 +333,9 @@ def _verify_host_in_cluster(self, target_host): if cached: self.log.debug(f"Host '{target_host}' cached (valid)") return True - else: - raise DruncK8sNodeException( - f"Host 
'{target_host}' was previously verified as unavailable" - ) + raise DruncK8sNodeException( + f"Host '{target_host}' was previously verified as unavailable" + ) try: target_node = self._core_v1_api.read_node(name=target_host) @@ -365,7 +364,7 @@ def _verify_host_in_cluster(self, target_host): raise DruncK8sNodeException( f"Target host '{target_host}' is not part of the Kubernetes cluster" ) - elif e.status in [401, 403]: + if e.status in [401, 403]: raise DruncK8sException( f"Permission denied accessing cluster to verify '{target_host}': {e}" ) diff --git a/src/drunc/process_manager/oks_parser.py b/src/drunc/process_manager/oks_parser.py index 302d9b982..0458e77e0 100644 --- a/src/drunc/process_manager/oks_parser.py +++ b/src/drunc/process_manager/oks_parser.py @@ -1,5 +1,4 @@ import os -from typing import Dict, List import confmodel_dal @@ -8,7 +7,7 @@ from drunc.utils.utils import get_logger -def collect_variables(variables, env_dict: Dict[str, str]) -> None: +def collect_variables(variables, env_dict: dict[str, str]) -> None: """!Process a dal::Variable object, placing key/value pairs in a dictionary @param variables A Variable/VariableSet object @@ -34,11 +33,11 @@ def collect_apps( db, session_obj, segment_obj, - env: Dict[str, str], + env: dict[str, str], tree_prefix=[ 0, ], -) -> List[Dict]: +) -> list[dict]: """! Recustively collect (daq) application belonging to segment and its subsegments @param session_obj The session the segment belongs to @@ -168,7 +167,7 @@ def collect_apps( return apps -def collect_infra_apps(session, env: Dict[str, str], tree_prefix) -> List[Dict]: +def collect_infra_apps(session, env: dict[str, str], tree_prefix) -> list[dict]: """! 
Collect infrastructure applications @param session The session diff --git a/src/drunc/process_manager/process_manager.py b/src/drunc/process_manager/process_manager.py index 589e8ea13..14ae24a86 100644 --- a/src/drunc/process_manager/process_manager.py +++ b/src/drunc/process_manager/process_manager.py @@ -19,10 +19,7 @@ ProcessUUID, ) from druncschema.process_manager_pb2_grpc import ProcessManagerServicer -from druncschema.request_response_pb2 import ( - Request, - ResponseFlag, -) +from druncschema.request_response_pb2 import Request, ResponseFlag from google.rpc import code_pb2 from grpc import ServicerContext @@ -487,7 +484,7 @@ def _ensure_one_process( ) -> str: if uuids == []: raise BadQuery("The process corresponding to the query doesn't exist") - elif len(uuids) > 1: + if len(uuids) > 1: raise BadQuery("There are more than 1 processes corresponding to the query") if in_boot_request: @@ -567,13 +564,17 @@ def get(conf, **kwargs): log.debug("Starting [green]SSH process_manager[/green]") return SSHProcessManager(conf, **kwargs) - elif conf.data.type == ProcessManagerTypes.K8s: + if conf.data.type == ProcessManagerTypes.SubProcess: + from drunc.process_manager.subprocess_process_manager import ( + SubProcessProcessManager, + ) + + log.info("Starting [green]SubProcess process_manager[/green]") + return SubProcessProcessManager(conf, **kwargs) + if conf.data.type == ProcessManagerTypes.K8s: from drunc.process_manager.k8s_process_manager import K8sProcessManager log.debug("Starting [green]K8s process_manager[/green]") return K8sProcessManager(conf, **kwargs) - else: - log.error(f"ProcessManager type {conf.get('type')} is unsupported!") - raise RuntimeError( - f"ProcessManager type {conf.get('type')} is unsupported!" 
- ) + log.error(f"ProcessManager type {conf.get('type')} is unsupported!") + raise RuntimeError(f"ProcessManager type {conf.get('type')} is unsupported!") diff --git a/src/drunc/process_manager/process_manager_driver.py b/src/drunc/process_manager/process_manager_driver.py index 22b4f04b6..1c2975702 100644 --- a/src/drunc/process_manager/process_manager_driver.py +++ b/src/drunc/process_manager/process_manager_driver.py @@ -6,7 +6,7 @@ import time from collections.abc import Iterator from time import sleep -from typing import Any, Dict, List +from typing import Any import grpc from druncschema.description_pb2 import Description @@ -124,7 +124,7 @@ def boot( previous_host = this_host last_boot_on_host_at[this_host] = time.time() - + self.log.debug(f"Boot request: {request}") try: response = self.stub.boot(request, timeout=timeout) except grpc.RpcError as e: @@ -142,7 +142,7 @@ def _collect_all_apps( session_dal, db, session_name: str, - ) -> List[Dict]: + ) -> list[dict]: from drunc.process_manager.oks_parser import collect_apps, collect_infra_apps env = { @@ -172,8 +172,8 @@ def _collect_all_apps( return apps def _prepare_exec_and_args( - self, session_dal, exe: str, args: List[str] - ) -> List[ProcessDescription.ExecAndArgs]: + self, session_dal, exe: str, args: list[str] + ) -> list[ProcessDescription.ExecAndArgs]: """ Prepare """ @@ -207,7 +207,7 @@ def _prepare_exec_and_args( def _build_boot_request( self, - app: Dict, + app: dict, user: str, session_name: str, session_dal, @@ -384,7 +384,7 @@ def get_controller_address(session_dal, session_name): connection_server, connection_port, ) - return + return None return uri.replace("grpc://", "") diff --git a/src/drunc/process_manager/subprocess_process_manager.py b/src/drunc/process_manager/subprocess_process_manager.py new file mode 100644 index 000000000..360b03446 --- /dev/null +++ b/src/drunc/process_manager/subprocess_process_manager.py @@ -0,0 +1,595 @@ +import getpass +import os +import signal +import 
tempfile +import threading +from ctypes import CDLL +from subprocess import Popen +from time import sleep + +import sh +from druncschema.broadcast_pb2 import BroadcastType +from druncschema.process_manager_pb2 import ( + BootRequest, + LogLines, + LogRequest, + ProcessDescription, + ProcessInstance, + ProcessInstanceList, + ProcessMetadata, + ProcessQuery, + ProcessRestriction, + ProcessUUID, +) +from druncschema.request_response_pb2 import ResponseFlag + +from drunc.exceptions import DruncCommandException, DruncException +from drunc.process_manager.process_manager import ( + ProcessManager, + ProcessManagerConfHandler, +) + +# # ------------------------------------------------ +# # pexpect.spawn(...,preexec_fn=on_parent_exit('SIGTERM')) + +# Constant taken from http://linux.die.net/include/linux/prctl.h +PR_SET_PDEATHSIG = 1 + + +class PrCtlError(DruncException): + pass + + +def on_parent_exit(signum, setsid=True): + """ + Return a function to be run in a child process which will trigger + SIGNAME to be sent when the parent process dies + """ + + def set_parent_exit_signal(): + # http://linux.die.net/man/2/prctl + result = CDLL("libc.so.6").prctl(PR_SET_PDEATHSIG, signum) + if result != 0: + raise PrCtlError("prctl failed with error code %s" % result) + + if setsid: + os.setsid() + + return set_parent_exit_signal + + +# ------------------------------------------------ + + +class AppProcessWatcherThread(threading.Thread): + def __init__(self, pm, name, user, session, process): + threading.Thread.__init__(self) + self.pm = pm + self.user = user + self.session = session + self.name = name + self.process = process + + def run(self): + self.process.wait() + self.pm.notify_join( + name=self.name, session=self.session, user=self.user, exec=self.process + ) + + +class SubProcessProcessManager(ProcessManager): + """ + A process manager that uses subprocess.Popen to launch and manage processes locally. + Used for testing as a CI tool. 
+ """ + + def __init__(self, configuration: ProcessManagerConfHandler, **kwargs): + """ + Initialize the SubProcessProcessManager with the given configuration. + + Args: + configuration (ProcessManagerConfHandler): The configuration handler for the + process manager. + """ + self.session: str = getpass.getuser() # unfortunate + super().__init__(configuration=configuration, session=self.session, **kwargs) + + self.watchers: list[AppProcessWatcherThread] = [] + + def kill_processes(self, uuids: list) -> ProcessInstanceList: + """ + Kill the processes with the given UUIDs. + + Args: + uuids (list): List of process UUIDs to kill. + + Returns: + ProcessInstanceList: List of process instances that were killed. + """ + + # Make a list of the killed processes to return + ret: list[ProcessInstance] = [] + + # Iterate over the UUIDs and kill each process + for proc_uuid in uuids: + # Retrieve the process from the store + process: sh.RunningCommand = self.process_store[proc_uuid] + + # Get the application name from the boot request metadata + app_name: str = self.boot_request[ + proc_uuid + ].process_description.metadata.name + + # Kill the process if it is still running + if process.poll() is None: + sequence: list[signal.Signals] = [ + signal.SIGQUIT, + signal.SIGTERM, + signal.SIGKILL, # Kept as nuclear option + ] + for sig in sequence: + if process.poll() is not None: + self.log.info( + f"Process '{app_name}' already dead with PID {proc_uuid}" + ) + break + self.log.info( + f"Sending signal '{str(sig).split('.')[-1]}' to '{app_name}' with UUID {proc_uuid}" + ) + process.send_signal(sig) # TODO grab this from the inputs + if process.poll() is not None: + break + sleep(self.configuration.data.kill_timeout) + + # Construct the ProcessInstance to return + pd = ProcessDescription() + pd.CopyFrom(self.boot_request[proc_uuid].process_description) + + pr = ProcessRestriction() + pr.CopyFrom(self.boot_request[proc_uuid].process_restriction) + + pu = ProcessUUID(uuid=proc_uuid) + + 
return_code = self.process_store[proc_uuid].poll() + + ret += [ + ProcessInstance( + process_description=pd, + process_restriction=pr, + status_code=ProcessInstance.StatusCode.DEAD, + return_code=return_code, + uuid=pu, + ) + ] + del self.process_store[proc_uuid] + + return ProcessInstanceList(values=ret) + + def _terminate_impl(self) -> ProcessInstanceList: + """ + Terminate all running processes. + + Returns: + ProcessInstanceList: List of process instances that were terminated. + """ + + self.log.info("Terminating") + + # If there are known processes, kill them + if self.process_store: + self.log.info("Killing all the known processes before exiting") + + # Get all the process UUIDs + uuids = self._get_process_uid( + query=ProcessQuery(names=[".*"]), order_by="leaf_first" + ) + return self.kill_processes(uuids) + self.log.info("No known process to kill before exiting") + return ProcessInstanceList() + + async def _logs_impl(self, log_request: LogRequest) -> LogLines: + """ + Retrieve logs for the specified process. + + Runs the `tail` command to get the last `how_far` lines from the log file as a + subprocess, yielding each line as a LogLines object. This is the most efficient + way to retrieve logs without loading the entire file into memory. + + Args: + log_request (LogRequest): The log request containing the query and how far + to retrieve. + + Yields: + LogLines: The log lines retrieved for the process. 
+ """ + + self.log.debug(f"Retrieving logs for {log_request.query}") + + # Ensure only one process matches the query, get its log file + uid: str = self._ensure_one_process(self._get_process_uid(log_request.query)) + logfile = self.boot_request[uid].process_description.process_logs_path + + # Use a temporary file to store the logs + f = tempfile.NamedTemporaryFile(delete=False) + f_file = open(f.name, "w") + + # Determine how many lines to retrieve + nlines = log_request.how_far + if not nlines: + nlines = 100 + + # Run the tail command to get the logs + try: + cmd = [ + "tail", + f"-{nlines}", + logfile, + ] + p = Popen( + cmd, + stdout=f_file, + stderr=f_file, + ) + p.wait() + except Exception as e: + ll = LogLines( + uuid=ProcessUUID(uuid=uid), line=f"Could not retrieve logs: {e!s}" + ) + yield ll + if uid in self.process_store: + llstdout = LogLines( + uuid=ProcessUUID(uuid=uid), + line=f"stdout: {self.process_store[uid].stdout}", + ) + llstderr = LogLines( + uuid=ProcessUUID(uuid=uid), + line=f"stderr: {self.process_store[uid].stderr}", + ) + yield llstdout + yield llstderr + + # Close the temporary file and read its contents + f.close() + with open(f.name) as fi: + lines = fi.readlines() + for line in lines: + ll = LogLines(uuid=ProcessUUID(uuid=uid), line=line) + yield ll + + # Clean up the temporary file + os.remove(f.name) + + def notify_join( + self, name: str, session: str, user: str, exec: sh.RunningCommand + ) -> None: + """ + Notify that a process has exited and perform cleanup. + + Args: + name (str): The name of the process. + session (str): The session associated with the process. + user (str): The user who started the process. + exec (sh.RunningCommand): The process that has exited. 
+ + Returns: + None + """ + self.log.debug(f"{self.name} joining processes from the event loop") + exit_code = exec.poll() + + end_msg: str = ( + f"Process '{name}' from session '{session}' with PID {exec.pid} " + f"exited with code {exit_code}" + ) + self.log.info(end_msg) + + if exec: + self.log.debug(name + str(exec)) + + self.broadcast(end_msg, BroadcastType.SUBPROCESS_STATUS_UPDATE) + return + + def _watch( + self, name: str, session: str, user: str, process: sh.RunningCommand + ) -> None: + """ + Start a watcher thread to monitor the given process. + + Args: + name (str): The name of the process. + session (str): The session associated with the process. + user (str): The user who started the process. + process (sh.RunningCommand): The process to watch. + + Returns: + None + """ + + self.log.debug(f"{self.name} watching process {name}") + t = AppProcessWatcherThread( + pm=self, session=session, user=user, name=name, process=process + ) + t.start() + self.watchers.append(t) + + def __boot(self, boot_request: BootRequest) -> ProcessInstance: + """ + Boot a new process based on the provided BootRequest. + + Args: + boot_request (BootRequest): The request containing process description and restrictions. + + Returns: + ProcessInstance: The instance of the booted process. + """ + + # Validate the boot request + meta: ProcessMetadata = boot_request.process_description.metadata + if len(boot_request.process_restriction.allowed_hosts) < 1: + raise DruncCommandException("No allowed host provided! 
bailing") + + error: str = "" + pid: int | None = None + for host in boot_request.process_restriction.allowed_hosts: + # We can only run processes on localhost + if host != "localhost": + raise DruncCommandException( + "SubProcess process manager does not support remote hosts" + ) + + try: + # Extract necessary information from the boot request + hostname: str = host + log_file: str = boot_request.process_description.process_logs_path + env_var: dict[str, str] = boot_request.process_description.env + + # Setup the command to run + cmd = ( + f"SubProcessPM: Starting process {os.getpid()} on host " + f"{os.uname().nodename} as user {getpass.getuser()}; " + ) + + # Add exported environment variables + env_setup_cmd: str = "; ".join( + [f"export {n}='{v}'" for n, v in env_var.items()] + ) + + # Change to the specified execution directory + exec_dir: str = ( + boot_request.process_description.process_execution_directory + ) + cmd = f"cd {exec_dir}; " + + # Add the executable and its arguments + for ( + exe_arg + ) in boot_request.process_description.executable_and_arguments: + cmd += exe_arg.exec + for arg in exe_arg.args: + cmd += f" {arg}" + cmd += "; " + + if cmd[-1] == ";": + cmd = cmd[:-1] + + # Setup the cli command to run + wrapped_cmd: str = ( + f'drunc-process-wrapper --log {log_file} "{env_setup_cmd}; {cmd}"' + ) + + # Log the wrapped command, splitting on ';' and ':' for readability + wrapped_cmd_fmt_for_logging = wrapped_cmd.replace(";", ";\n").replace( + ":", ":\n" + ) + self.log.debug(f"Running command:\n{wrapped_cmd_fmt_for_logging}") + + process: Popen = Popen( + wrapped_cmd, + shell=True, + preexec_fn=on_parent_exit(signal.SIGTERM), + ) + self.log.debug(f"Started process with PID {process.pid}") + self.process_store[str(process.pid)] = process + pid: str = str(process.pid) + + self._watch( + name=meta.name, + user=meta.user, + session=meta.session, + process=self.process_store[pid], + ) + self.log.debug("Watcher started") + break + + except Exception as 
e: + error += str(e) + print(f"Couldn't start on host {host}, reason:\n{e!s}") + continue + + # Add the boot request to the boot_request store + self.boot_request[pid] = BootRequest() + self.boot_request[pid].CopyFrom(boot_request) + hostname: str = "localhost" + self.boot_request[pid].process_description.metadata.hostname = hostname + + self.log.info( + f"Booted '{boot_request.process_description.metadata.name}' from session '" + f"{boot_request.process_description.metadata.session}' with PID {pid}" + ) + + # Construct the ProcessInstance to return + pd = ProcessDescription() + pd.CopyFrom(self.boot_request[pid].process_description) + pr = ProcessRestriction() + pr.CopyFrom(self.boot_request[pid].process_restriction) + pu = ProcessUUID(uuid=pid) + + # If the process failed to start, return a DEAD instance + if pid not in self.process_store: + pi = ProcessInstance( + process_description=pd, + process_restriction=pr, + status_code=ProcessInstance.StatusCode.DEAD, ## should be unknown + return_code=None, + uuid=pu, + ) + return pi + + # If the process started, return a RUNNING instance + return_code: int | None = self.process_store[pid].poll() + alive: bool = return_code is not None + pi = ProcessInstance( + process_description=pd, + process_restriction=pr, + status_code=ProcessInstance.StatusCode.RUNNING + if alive + else ProcessInstance.StatusCode.DEAD, + return_code=return_code, + uuid=pu, + ) + return pi + + def _ps_impl(self, query: ProcessQuery) -> ProcessInstanceList: + """ + List processes matching the given query. + + Args: + query (ProcessQuery): The query to filter processes. + + Returns: + ProcessInstanceList: List of process instances matching the query. 
+ """ + + self.log.debug(f"{self.name} running ps") + ret: list[ProcessInstance] = [] + + for proc_uuid in self._get_process_uid(query): + if proc_uuid not in self.process_store: + pu = ProcessUUID(uuid=proc_uuid) + pi = ProcessInstance( + process_description=ProcessDescription(), + process_restriction=ProcessRestriction(), + status_code=ProcessInstance.StatusCode.DEAD, # should be unknown + return_code=None, + uuid=pu, + ) + ret += [pi] + continue + pd = ProcessDescription() + pd.CopyFrom(self.boot_request[proc_uuid].process_description) + pr = ProcessRestriction() + pr.CopyFrom(self.boot_request[proc_uuid].process_restriction) + pu = ProcessUUID(uuid=proc_uuid) + return_code = None + if self.process_store[proc_uuid].poll() is None: + try: + return_code = self.process_store[proc_uuid].exit_code + except Exception: + pass + + pi = ProcessInstance( + process_description=pd, + process_restriction=pr, + status_code=ProcessInstance.StatusCode.RUNNING + if self.process_store[proc_uuid].poll() is None + else ProcessInstance.StatusCode.DEAD, + return_code=return_code, + uuid=pu, + ) + ret += [pi] + + pil = ProcessInstanceList(values=ret) + + return pil + + def _boot_impl(self, boot_request: BootRequest) -> ProcessInstance: + """ + Boot a new process based on the provided BootRequest. + + Overwrites the base class method to call the internal __boot method. + + Args: + boot_request (BootRequest): The request containing process description and + restrictions. + + Returns: + ProcessInstance: The instance of the booted process. 
+ """ + + self.log.debug(f"{self.name} running _boot_impl") + try: + pi: ProcessInstance = self.__boot(boot_request) + return ProcessInstanceList( + name=self.name, + token=None, + values=[pi], + flag=ResponseFlag.EXECUTED_SUCCESSFULLY, + ) + except Exception as e: + self.log.error(f"Exception during boot: {e!s}") + return ProcessInstanceList( + name=self.name, + token=None, + values=[], + flag=ResponseFlag.EXECUTED_SUCCESSFULLY, + ) + + def _restart_impl(self, query: ProcessQuery) -> ProcessInstanceList: + """ + Restart the process matching the given query. + + Args: + query (ProcessQuery): The query to identify the process to restart. + + Returns: + ProcessInstanceList: List of process instances that were restarted. + """ + + self.log.info(f"{self.name} restarting {query.names} in session {self.session}") + + # Ensure only one process matches the query + uuids: list[str] = self._get_process_uid(query, in_boot_request=True) + uuid: str = self._ensure_one_process(uuids, in_boot_request=True) + + # Make copies of the boot request and uuid to avoid mutation issues + same_uuid_br = BootRequest() + same_uuid_br.CopyFrom(self.boot_request[uuid]) + same_uuid = uuid + + # Terminate the existing process if it is running + if uuid in self.process_store: + process = self.process_store[uuid] + if process.poll() is None: + process.terminate() + + # Clean up the existing process from the stores + del self.process_store[uuid] + del self.boot_request[uuid] + del uuid + + # Boot a new process with the same boot request + ret = self.__boot(same_uuid_br, same_uuid) + + # Clean up temporary copies + del same_uuid_br + del same_uuid + + return ret + + def _kill_impl(self, query: ProcessQuery) -> ProcessInstanceList: + """ + Kill the processes matching the given query. + + Args: + query (ProcessQuery): The query to identify the processes to kill. + + Returns: + ProcessInstanceList: List of process instances that were killed. 
+ """ + + self.log.info(f"{self.name} killing {query.names} in session {self.session}") + if self.process_store: + uuids = self._get_process_uid(query, order_by="leaf_first") + return self.kill_processes(uuids) + self.log.info("No known process to kill before exiting") + return ProcessInstanceList() diff --git a/src/drunc/session_manager/session_manager.py b/src/drunc/session_manager/session_manager.py index 5fb22de12..482df8d08 100644 --- a/src/drunc/session_manager/session_manager.py +++ b/src/drunc/session_manager/session_manager.py @@ -6,10 +6,7 @@ from conffwk import Configuration from druncschema.description_pb2 import CommandDescription, Description -from druncschema.request_response_pb2 import ( - Request, - ResponseFlag, -) +from druncschema.request_response_pb2 import Request, ResponseFlag from druncschema.session_manager_pb2 import ( ActiveSession, AllActiveSessions, diff --git a/src/drunc/tests/apps/test_process_wrapper.py b/src/drunc/tests/apps/test_process_wrapper.py new file mode 100644 index 000000000..9a1921ce6 --- /dev/null +++ b/src/drunc/tests/apps/test_process_wrapper.py @@ -0,0 +1,124 @@ +import pathlib +import tempfile + +import pytest +from click.exceptions import MissingParameter + +from drunc.apps.process_wrapper import main as process_wrapper_main + + +@pytest.fixture(scope="function") +def tmp_path(): + with tempfile.TemporaryDirectory() as tmpdirname: + yield pathlib.Path(tmpdirname) + + +def test_process_wrapper_success(tmp_path): + """ + Test that the process wrapper correctly runs a successful command and logs output. + Validates that the command's output is captured in the log file and that the return + code is 0. + Args: + tmp_path: A temporary directory path provided by pytest fixture. 
+ """ + log_file = tmp_path / "process.log" + cmd = 'echo "Test Success"' + result = process_wrapper_main.main( + args=["--log", str(log_file), cmd], standalone_mode=False + ) + assert result == 0 + with open(log_file) as f: + log_content = f.read() + assert "Test Success" in log_content + + +def test_process_wrapper_failure(tmp_path): + """ + Test that the process wrapper correctly runs a failing command and logs output. + Validates that the command's output is captured in the log file and that the return + code is non-zero. + Args: + tmp_path: A temporary directory path provided by pytest fixture. + """ + log_file = tmp_path / "process.log" + cmd = "exit 42" + result = process_wrapper_main.main( + args=["--log", str(log_file), cmd], standalone_mode=False + ) + assert result == 42 + with open(log_file) as f: + log_content = f.read() + assert log_content == "" or log_content.isspace() + + +def test_process_wrapper_no_log(tmp_path): + """ + Test that the process wrapper raises a MissingParameter error when no log file is specified. + Args: + tmp_path: A temporary directory path provided by pytest fixture. + """ + cmd = 'echo "No log file"' + with pytest.raises(MissingParameter): + process_wrapper_main.main(args=[cmd], standalone_mode=False) + + +def test_process_wrapper_invalid_command(tmp_path): + """ + Test that the process wrapper handles an invalid command gracefully. + Validates that the command's failure is captured in the log file and that the return + code is non-zero. + Args: + tmp_path: A temporary directory path provided by pytest fixture. 
+ """ + log_file = tmp_path / "process.log" + cmd = "nonexistent_command_123" + result = process_wrapper_main.main( + args=["--log", str(log_file), cmd], standalone_mode=False + ) + assert result != 0 + with open(log_file) as f: + log_content = f.read() + assert ( + "not found" in log_content + or "No such file" in log_content + or "command not found" in log_content + ) + + +def test_process_wrapper_logs_stderr(tmp_path): + """ + Test that the process wrapper captures stderr output in the log file. + Validates that stderr output from the command is present in the log file and that the + return code is 0. + Args: + tmp_path: A temporary directory path provided by pytest fixture. + """ + log_file = tmp_path / "process.log" + cmd = "python -c \"import sys; sys.stderr.write('error\\n')\"" + result = process_wrapper_main.main( + args=["--log", str(log_file), cmd], standalone_mode=False + ) + assert result == 0 + with open(log_file) as f: + log_content = f.read() + assert "error" in log_content + + +def test_process_wrapper_multiple_commands(tmp_path): + """ + Test that the process wrapper correctly runs multiple commands and logs output. + Validates that the output from all commands is captured in the log file and that the + return code is 0. + Args: + tmp_path: A temporary directory path provided by pytest fixture. 
+ """ + log_file = tmp_path / "process.log" + cmd = 'echo "first" && echo "second"' + result = process_wrapper_main.main( + args=["--log", str(log_file), cmd], standalone_mode=False + ) + assert result == 0 + with open(log_file) as f: + log_content = f.read() + assert "first" in log_content + assert "second" in log_content diff --git a/src/drunc/utils/grpc_utils.py b/src/drunc/utils/grpc_utils.py index 972a91e2a..c65cb5673 100644 --- a/src/drunc/utils/grpc_utils.py +++ b/src/drunc/utils/grpc_utils.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, NoReturn, Optional +from typing import NoReturn import grpc from druncschema.generic_pb2 import PlainText @@ -105,7 +105,7 @@ def rethrow_if_unreachable_server(grpc_error: grpc.RpcError) -> NoReturn: if not server_is_reachable(grpc_error): if hasattr(grpc_error, "_state"): raise ServerUnreachable(grpc_error._state.details) from grpc_error - elif hasattr(grpc_error, "_details"): + if hasattr(grpc_error, "_details"): raise ServerUnreachable(grpc_error._details) from grpc_error @@ -139,7 +139,7 @@ def handle_grpc_error(error: grpc.RpcError) -> NoReturn: raise error -def interrupt_if_unreachable_server(grpc_error: grpc.RpcError) -> Optional[str]: +def interrupt_if_unreachable_server(grpc_error: grpc.RpcError) -> str | None: """ Interrupt if server is not reachable and return the error details. 
@@ -153,7 +153,7 @@ def interrupt_if_unreachable_server(grpc_error: grpc.RpcError) -> Optional[str]: if not server_is_reachable(grpc_error): if hasattr(grpc_error, "_state"): return grpc_error._state.details - elif hasattr(grpc_error, "_details"): + if hasattr(grpc_error, "_details"): return grpc_error._details @@ -185,7 +185,7 @@ class GrpcErrorDetails: code: str message: str - details: List[str] + details: list[str] def __str__(self): """ diff --git a/src/drunc/utils/shell_utils.py b/src/drunc/utils/shell_utils.py index d0fefc151..07988c73c 100644 --- a/src/drunc/utils/shell_utils.py +++ b/src/drunc/utils/shell_utils.py @@ -108,11 +108,11 @@ def set_driver(self, name: str, driver: object) -> None: raise DruncShellException(f"Driver {name} already present in this context") self._drivers[name] = driver - def get_driver(self, name: str = None, quiet_fail: bool = False) -> object: + def get_driver(self, name: str | None = None, quiet_fail: bool = False) -> object: try: if name: return self._drivers[name] - elif len(self._drivers) > 1: + if len(self._drivers) > 1: raise DruncShellException("More than one driver in this context") return list(self._drivers.values())[0] except KeyError: diff --git a/src/drunc/utils/utils.py b/src/drunc/utils/utils.py index 71ffe8e39..95cf8c526 100644 --- a/src/drunc/utils/utils.py +++ b/src/drunc/utils/utils.py @@ -102,8 +102,7 @@ def get_new_port(): def now_str(posix_friendly=False): if not posix_friendly: return datetime.now().strftime("%m/%d/%Y,%H:%M:%S") - else: - return datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + return datetime.now().strftime("%Y-%m-%d-%H-%M-%S") def expand_path(path, turn_to_abs_path=False): @@ -248,7 +247,7 @@ def parent_death_pact(signal=signal.SIGHUP): # last three args are unused for PR_SET_PDEATHSIG retcode = libc.prctl(PR_SET_PDEATHSIG, signal, 0, 0, 0) if retcode != 0: - raise Exception("prctl() returned nonzero retcode %d" % retcode) + raise Exception(f"prctl() returned nonzero retcode {retcode:d}") 
class IncorrectAddress(DruncException): @@ -330,7 +329,7 @@ def get_control_type_and_uri_from_cli(CLAs: list[str]) -> ControlType: return ControlType.REST_API, resolve_localhost_and_127_ip_to_network_ip( CLA.replace("rest://", "") ) - elif CLA.startswith("grpc://"): + if CLA.startswith("grpc://"): return ControlType.gRPC, resolve_localhost_and_127_ip_to_network_ip( CLA.replace("grpc://", "") ) @@ -345,7 +344,7 @@ def get_control_type_and_uri_from_connectivity_service( timeout: int = 10, # seconds retry_wait: float = 0.1, # seconds progress_bar: bool = False, - title: str = None, + title: str | None = None, ) -> tuple[ControlType, str]: uris = [] logger = get_logger("utils.get_control_type_and_uri_from_connectivity_service") @@ -374,8 +373,7 @@ def get_control_type_and_uri_from_connectivity_service( ) if len(uris) == 0: raise ApplicationLookupUnsuccessful - else: - break + break except ApplicationLookupUnsuccessful: elapsed = time.time() - start @@ -394,8 +392,7 @@ def get_control_type_and_uri_from_connectivity_service( ) if len(uris) == 0: raise ApplicationLookupUnsuccessful - else: - break + break except ApplicationLookupUnsuccessful: elapsed = time.time() - start diff --git a/tests/conftest.py b/tests/conftest.py index 3aaebdac6..336baaf42 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,7 +28,8 @@ def load_test_config(): def boot_session(configuration_name, request): - from drunc.process_manager.oks_parser import collect_apps, collect_infra_apps + from drunc.process_manager.oks_parser import (collect_apps, + collect_infra_apps) req_name = request.node.name configuration_file = f"{configuration_name}.data.xml" @@ -36,7 +37,8 @@ def boot_session(configuration_name, request): from daqconf.consolidate import consolidate_db consolidate_db(configuration_file, configuration_consolidated_file) - from daqconf.set_connectivity_service_port import set_connectivity_service_port + from daqconf.set_connectivity_service_port import \ + 
set_connectivity_service_port set_connectivity_service_port(configuration_consolidated_file, configuration_name) session_name = f"{req_name}-{configuration_name}" @@ -91,7 +93,7 @@ def boot_session(configuration_name, request): for _ in range(10): if os.path.exists(processes["local-connection-server"][1]): - with open(processes["local-connection-server"][1], "r") as f: + with open(processes["local-connection-server"][1]) as f: if "[INFO] Starting gunicorn" in f.readline(): break time.sleep(0.1) diff --git a/tests/controller/test_controller.py b/tests/controller/test_controller.py index f17bf9233..4fb46b6e9 100644 --- a/tests/controller/test_controller.py +++ b/tests/controller/test_controller.py @@ -15,7 +15,7 @@ def test_controller_init(one_controller_running): while time_inc < timeout: if os.path.exists(controller_process[1]): - with open(controller_process[1], "r") as f: + with open(controller_process[1]) as f: for line in f.readlines(): if "Controller ready" in line: found = True diff --git a/tests/fsm/actions/test_utils.py b/tests/fsm/actions/test_utils.py index 06a1b4844..0962a5a33 100644 --- a/tests/fsm/actions/test_utils.py +++ b/tests/fsm/actions/test_utils.py @@ -6,7 +6,8 @@ from drunc.exceptions import DruncException from drunc.fsm.actions.utils import get_dotdrunc_json, validate_run_type -from drunc.fsm.exceptions import DotDruncJsonIncorrectFormat, DotDruncJsonNotFound +from drunc.fsm.exceptions import (DotDruncJsonIncorrectFormat, + DotDruncJsonNotFound) dotdrunc_json = { "run_registry_configuration": { diff --git a/tests/process_manager/conftest.py b/tests/process_manager/conftest.py index 7e66fb392..615465ac2 100644 --- a/tests/process_manager/conftest.py +++ b/tests/process_manager/conftest.py @@ -10,18 +10,12 @@ import google.protobuf.any_pb2 import pytest from druncschema.description_pb2 import Description -from druncschema.process_manager_pb2 import ( - BootRequest, - LogLines, - LogRequest, - ProcessDescription, - ProcessInstance, - 
ProcessInstanceList, - ProcessMetadata, - ProcessQuery, - ProcessRestriction, - ProcessUUID, -) +from druncschema.process_manager_pb2 import (BootRequest, LogLines, LogRequest, + ProcessDescription, + ProcessInstance, + ProcessInstanceList, + ProcessMetadata, ProcessQuery, + ProcessRestriction, ProcessUUID) from druncschema.request_response_pb2 import Request, ResponseFlag from druncschema.token_pb2 import Token diff --git a/tests/process_manager/dummy_requests.py b/tests/process_manager/dummy_requests.py index 360156d8f..ed08280fe 100644 --- a/tests/process_manager/dummy_requests.py +++ b/tests/process_manager/dummy_requests.py @@ -3,15 +3,10 @@ """ import google.protobuf.any_pb2 -from druncschema.process_manager_pb2 import ( - BootRequest, - LogRequest, - ProcessDescription, - ProcessMetadata, - ProcessQuery, - ProcessRestriction, - ProcessUUID, -) +from druncschema.process_manager_pb2 import (BootRequest, LogRequest, + ProcessDescription, + ProcessMetadata, ProcessQuery, + ProcessRestriction, ProcessUUID) from druncschema.request_response_pb2 import Request from druncschema.token_pb2 import Token diff --git a/tests/process_manager/dummy_responses.py b/tests/process_manager/dummy_responses.py index 8d9d1793b..c0edf8256 100644 --- a/tests/process_manager/dummy_responses.py +++ b/tests/process_manager/dummy_responses.py @@ -3,12 +3,8 @@ """ from druncschema.description_pb2 import Description -from druncschema.process_manager_pb2 import ( - LogLines, - ProcessInstance, - ProcessInstanceList, - ProcessUUID, -) +from druncschema.process_manager_pb2 import (LogLines, ProcessInstance, + ProcessInstanceList, ProcessUUID) from druncschema.request_response_pb2 import ResponseFlag from druncschema.token_pb2 import Token diff --git a/tests/process_manager/interface/test_commands.py b/tests/process_manager/interface/test_commands.py index 6b98b1544..0dc5721f7 100644 --- a/tests/process_manager/interface/test_commands.py +++ b/tests/process_manager/interface/test_commands.py 
@@ -7,17 +7,10 @@ import pytest from click.testing import CliRunner -from drunc.process_manager.interface.commands import ( - InterruptedCommand, - boot, - dummy_boot, - flush, - kill, - logs, - ps, - restart, - terminate, -) +from drunc.process_manager.interface.commands import (InterruptedCommand, boot, + dummy_boot, flush, kill, + logs, ps, restart, + terminate) @pytest.fixture diff --git a/tests/process_manager/process_manager_mock_impls.py b/tests/process_manager/process_manager_mock_impls.py index c45159269..fe123ca79 100644 --- a/tests/process_manager/process_manager_mock_impls.py +++ b/tests/process_manager/process_manager_mock_impls.py @@ -6,20 +6,12 @@ """ -from typing import Optional from unittest.mock import Mock -from druncschema.process_manager_pb2 import ( - BootRequest, - LogLines, - LogRequest, - ProcessInstanceList, - ProcessQuery, -) - -from drunc.process_manager.configuration import ( - ProcessManagerConfHandler, -) +from druncschema.process_manager_pb2 import (BootRequest, LogLines, LogRequest, + ProcessInstanceList, ProcessQuery) + +from drunc.process_manager.configuration import ProcessManagerConfHandler from drunc.process_manager.process_manager import ProcessManager, ResponseFlag @@ -32,7 +24,7 @@ def __init__( self, configuration: ProcessManagerConfHandler = Mock(), name: str = "process_manager_no_impl", - session: Optional[str] = None, + session: str | None = None, **kwargs, ): """ diff --git a/tests/process_manager/test_grpc_fields.py b/tests/process_manager/test_grpc_fields.py index e7ed02d62..926c3e59a 100644 --- a/tests/process_manager/test_grpc_fields.py +++ b/tests/process_manager/test_grpc_fields.py @@ -223,11 +223,8 @@ def test_process_description_field_init(): """ Test ProcessDescription fields properly populated """ - from druncschema.process_manager_pb2 import ( - ProcessDescription, - ProcessMetadata, - ProcessUUID, - ) + from druncschema.process_manager_pb2 import (ProcessDescription, + ProcessMetadata, ProcessUUID) metadata 
= ProcessMetadata( uuid=ProcessUUID(uuid="test-uuid"), @@ -261,12 +258,10 @@ def test_process_instance_field_init(): """ Test ProcessInstance fields properly populated """ - from druncschema.process_manager_pb2 import ( - ProcessDescription, - ProcessInstance, - ProcessRestriction, - ProcessUUID, - ) + from druncschema.process_manager_pb2 import (ProcessDescription, + ProcessInstance, + ProcessRestriction, + ProcessUUID) process_description = ProcessDescription() process_restriction = ProcessRestriction() @@ -293,7 +288,8 @@ def test_process_instance_list_field_init(): """ Test ProcessInstanceList fields properly populated """ - from druncschema.process_manager_pb2 import ProcessInstance, ProcessInstanceList + from druncschema.process_manager_pb2 import (ProcessInstance, + ProcessInstanceList) from druncschema.request_response_pb2 import ResponseFlag from druncschema.token_pb2 import Token @@ -316,11 +312,9 @@ def test_boot_request_field_init(): """ Test BootRequest fields properly populated """ - from druncschema.process_manager_pb2 import ( - BootRequest, - ProcessDescription, - ProcessRestriction, - ) + from druncschema.process_manager_pb2 import (BootRequest, + ProcessDescription, + ProcessRestriction) from druncschema.token_pb2 import Token token = Token() diff --git a/tests/process_manager/test_process_manager_driver.py b/tests/process_manager/test_process_manager_driver.py index 0a766634f..930d924f0 100644 --- a/tests/process_manager/test_process_manager_driver.py +++ b/tests/process_manager/test_process_manager_driver.py @@ -13,12 +13,9 @@ import grpc import pytest -from druncschema.process_manager_pb2 import ( - BootRequest, - ProcessDescription, - ProcessMetadata, - ProcessRestriction, -) +from druncschema.process_manager_pb2 import (BootRequest, ProcessDescription, + ProcessMetadata, + ProcessRestriction) from druncschema.token_pb2 import Token from drunc.connectivity_service.exceptions import ApplicationLookupUnsuccessful diff --git 
a/tests/process_manager/test_process_manager_endpoints.py b/tests/process_manager/test_process_manager_endpoints.py index 6aefb85a1..c7544c2ce 100644 --- a/tests/process_manager/test_process_manager_endpoints.py +++ b/tests/process_manager/test_process_manager_endpoints.py @@ -18,9 +18,8 @@ import pytest from druncschema.process_manager_pb2 import DESCRIPTOR -from tests.process_manager.process_manager_mock_impls import ( - ConcreteProcessManager, -) +from tests.process_manager.process_manager_mock_impls import \ + ConcreteProcessManager @pytest.fixture(scope="function") diff --git a/tests/process_manager/test_process_manager_serialisation.py b/tests/process_manager/test_process_manager_serialisation.py index a9fabe062..753b3d765 100644 --- a/tests/process_manager/test_process_manager_serialisation.py +++ b/tests/process_manager/test_process_manager_serialisation.py @@ -19,9 +19,7 @@ import grpc import pytest from druncschema.process_manager_pb2_grpc import ( - ProcessManagerStub, - add_ProcessManagerServicer_to_server, -) + ProcessManagerStub, add_ProcessManagerServicer_to_server) class ProcessManagerSerialisationTestSuite: diff --git a/tests/session_manager/conftest.py b/tests/session_manager/conftest.py index a4f18185d..f793006f4 100644 --- a/tests/session_manager/conftest.py +++ b/tests/session_manager/conftest.py @@ -6,13 +6,9 @@ import pytest from druncschema.description_pb2 import CommandDescription, Description from druncschema.request_response_pb2 import Request, ResponseFlag -from druncschema.session_manager_pb2 import ( - DESCRIPTOR, - ActiveSession, - AllActiveSessions, - AllConfigKeys, - ConfigKey, -) +from druncschema.session_manager_pb2 import (DESCRIPTOR, ActiveSession, + AllActiveSessions, AllConfigKeys, + ConfigKey) from druncschema.token_pb2 import Token from drunc.session_manager.session_manager import SessionManager diff --git a/tests/session_manager/test_session_manager_endpoints.py b/tests/session_manager/test_session_manager_endpoints.py 
index aeaab31b3..0961b23d4 100644 --- a/tests/session_manager/test_session_manager_endpoints.py +++ b/tests/session_manager/test_session_manager_endpoints.py @@ -6,9 +6,7 @@ import grpc import pytest -from druncschema.session_manager_pb2 import ( - DESCRIPTOR, -) +from druncschema.session_manager_pb2 import DESCRIPTOR @pytest.mark.parametrize( diff --git a/tests/session_manager/test_session_manager_serialisation.py b/tests/session_manager/test_session_manager_serialisation.py index e858d48ea..c2cde7acb 100644 --- a/tests/session_manager/test_session_manager_serialisation.py +++ b/tests/session_manager/test_session_manager_serialisation.py @@ -11,9 +11,7 @@ import grpc import pytest from druncschema.session_manager_pb2_grpc import ( - SessionManagerStub, - add_SessionManagerServicer_to_server, -) + SessionManagerStub, add_SessionManagerServicer_to_server) from grpc._channel import _InactiveRpcError from drunc.session_manager.session_manager import SessionManager diff --git a/tests/session_manager/test_session_manager_servicer.py b/tests/session_manager/test_session_manager_servicer.py index e6874fde5..59a026f32 100644 --- a/tests/session_manager/test_session_manager_servicer.py +++ b/tests/session_manager/test_session_manager_servicer.py @@ -3,12 +3,8 @@ from druncschema.description_pb2 import Description from druncschema.request_response_pb2 import ResponseFlag -from druncschema.session_manager_pb2 import ( - ActiveSession, - AllActiveSessions, - AllConfigKeys, - ConfigKey, -) +from druncschema.session_manager_pb2 import (ActiveSession, AllActiveSessions, + AllConfigKeys, ConfigKey) def test_describe( diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 94b3573d6..19d308255 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -8,22 +8,14 @@ import pytest from drunc.exceptions import DruncSetupException -from drunc.utils.utils import ( - ControlType, - IncorrectAddress, - expand_path, - get_control_type_and_uri_from_cli, - 
get_new_port, - get_random_string, - host_is_local, - https_or_http_present, - now_str, - parent_death_pact, - regex_match, - resolve_localhost_and_127_ip_to_network_ip, - resolve_localhost_to_hostname, - validate_command_facility, -) +from drunc.utils.utils import (ControlType, IncorrectAddress, expand_path, + get_control_type_and_uri_from_cli, get_new_port, + get_random_string, host_is_local, + https_or_http_present, now_str, + parent_death_pact, regex_match, + resolve_localhost_and_127_ip_to_network_ip, + resolve_localhost_to_hostname, + validate_command_facility) def test_get_random_string():