diff --git a/src/drunc/controller/controller_driver.py b/src/drunc/controller/controller_driver.py index dd88d419b..31153af48 100644 --- a/src/drunc/controller/controller_driver.py +++ b/src/drunc/controller/controller_driver.py @@ -15,12 +15,10 @@ class ControllerDriver(GRPCDriver): def __init__(self, address: str, token, **kwargs): - super(ControllerDriver, self).__init__( + super().__init__( name="controller_driver", address=address, token=token, **kwargs ) - - def create_stub(self, channel): - return ControllerStub(channel) + self.stub = ControllerStub(self.channel) def pack_empty_addressed_command(cmd): @wraps(cmd) diff --git a/src/drunc/fsm/actions/usvc_elisa_logbook.py b/src/drunc/fsm/actions/usvc_elisa_logbook.py index 190759f74..d8fed3bfa 100644 --- a/src/drunc/fsm/actions/usvc_elisa_logbook.py +++ b/src/drunc/fsm/actions/usvc_elisa_logbook.py @@ -51,6 +51,7 @@ def __init__(self): ) raise DotDruncJsonIncorrectFormat(err_msg) from KeyError else: + err_msg: str = "" if default_elisa_logbook: err_msg = ( diff --git a/src/drunc/process_manager/interface/commands.py b/src/drunc/process_manager/interface/commands.py index 04a594eab..9d24221f3 100644 --- a/src/drunc/process_manager/interface/commands.py +++ b/src/drunc/process_manager/interface/commands.py @@ -45,9 +45,9 @@ def boot( log = get_logger("process_manager.shell") processes = obj.get_driver("process_manager").ps(ProcessQuery(user=user)) - if len(processes.data.values) > 0: + if len(processes.values) > 0: click.confirm( - f"You already have {len(processes.data.values)} processes running, are you sure you want to boot a session?", + f"You already have {len(processes.values)} processes running, are you sure you want to boot a session?", abort=True, ) @@ -67,7 +67,7 @@ def boot( if not result: break log.debug( - f"'{result.data.process_description.metadata.name}' ({result.data.uuid.uuid}) process started" + f"'{result.values[0].process_description.metadata.name}' ({result.values[0].uuid.uuid}) process started" ) except InterruptedCommand: return @@ -144,7 +144,7 @@ def dummy_boot( if not result: break log.debug( - f"'{result.data.process_description.metadata.name}' ({result.data.uuid.uuid}) process started" + f"'{result.values[0].process_description.metadata.name}' ({result.values[0].uuid.uuid}) process started" ) except InterruptedCommand: return @@ -159,11 +159,9 @@ def terminate(obj: ProcessManagerContext) -> None: if not result: return obj.print( - tabulate_process_instance_list(result.data, "Terminated process", False) + tabulate_process_instance_list(result, "Terminated process", False) ) # rich tables require console printing - obj.delete_driver("controller") - @click.command("kill") @add_query_options(at_least_one=True) @@ -171,15 +169,13 @@ def terminate(obj: ProcessManagerContext) -> None: def kill(obj: ProcessManagerContext, query: ProcessQuery) -> None: log = get_logger("process_manager.shell") log.debug(f"Killing with query {query}") - result = obj.get_driver("process_manager").kill(query=query) + result = obj.get_driver("process_manager").kill(query) if not result: return obj.print( - tabulate_process_instance_list(result.data, "Killed process", False) + tabulate_process_instance_list(result, "Killed process", False) ) # rich tables require console printing - obj.delete_driver("controller") - @click.command("flush") @add_query_options(at_least_one=False, all_processes_by_default=True) @@ -187,11 +183,11 @@ def kill(obj: ProcessManagerContext, query: ProcessQuery) -> None: def flush(obj: ProcessManagerContext, query: ProcessQuery) -> None: log = get_logger("process_manager.shell") log.debug(f"Flushing with query {query}") - result = obj.get_driver("process_manager").flush(query=query) + result = obj.get_driver("process_manager").flush(query) if not result: return obj.print( - tabulate_process_instance_list(result.data, "Flushed process", False) + tabulate_process_instance_list(result, "Flushed process", False) ) # rich tables require console printing @@ -216,13 +212,12 @@ def logs( query=query, ) - result = obj.get_driver("process_manager").logs(log_req).data + result = obj.get_driver("process_manager").logs(log_req) if result.uuid.uuid is not None: obj.rule(f"[yellow]{result.uuid.uuid}[/yellow] logs") for line in result.lines: - if line == "": obj.print("") continue @@ -248,7 +243,7 @@ def logs( def restart(obj: ProcessManagerContext, query: ProcessQuery) -> None: log = get_logger("process_manager.shell") log.debug(f"Restarting with query {query}") - obj.get_driver("process_manager").restart(query=query) + obj.get_driver("process_manager").restart(query) @click.command("ps") @@ -262,16 +257,14 @@ def restart(obj: ProcessManagerContext, query: ProcessQuery) -> None: help="Whether to have a long output", ) @click.pass_obj -def ps( - obj: ProcessManagerContext, query: ProcessQuery, long_format: bool -) -> None: +def ps(obj: ProcessManagerContext, query: ProcessQuery, long_format: bool) -> None: log = get_logger("process_manager.shell") log.debug(f"Running ps with query {query}") - results = obj.get_driver("process_manager").ps(query=query) + results = obj.get_driver("process_manager").ps(query) if not results: return obj.print( tabulate_process_instance_list( - results.data, title="Processes running", long=long_format + results, title="Processes running", long=long_format ) ) diff --git a/src/drunc/process_manager/interface/shell.py b/src/drunc/process_manager/interface/shell.py index a192ede09..b7c6370da 100644 --- a/src/drunc/process_manager/interface/shell.py +++ b/src/drunc/process_manager/interface/shell.py @@ -48,7 +48,6 @@ def process_manager_shell(ctx, process_manager_address: str, log_level: str) -> ctx.obj.reset(address=process_manager_address) try: - desc = ctx.obj.get_driver("process_manager").describe() except ServerUnreachable as e: process_manager_shell_log = get_logger( @@ -63,7 +62,7 @@ def process_manager_shell(ctx, process_manager_address: str, log_level: str) -> process_manager_log = get_logger( logger_name="process_manager", - log_file_path=desc.data.info, + log_file_path=desc.info, override_log_file=False, rich_handler=True, ) @@ -72,10 +71,10 @@ def process_manager_shell(ctx, process_manager_address: str, log_level: str) -> f"[green]{getpass.getuser()}[/green] connected to the process manager through a [green]drunc-process-manager-shell[/green] via address [green]{process_manager_address}[/green]" ) process_manager_shell_log.info( - f"Connected to {process_manager_address}, running '{desc.data.name}.{desc.data.session}' (name.session), starting listening..." + f"Connected to {process_manager_address}, running '{desc.name}.{desc.session}' (name.session), starting listening..." ) - if desc.data.HasField("broadcast"): - ctx.obj.start_listening(desc.data.broadcast) + if desc.HasField("broadcast"): + ctx.obj.start_listening(desc.broadcast) def cleanup(): ctx.obj.terminate() diff --git a/src/drunc/process_manager/k8s_process_manager.py b/src/drunc/process_manager/k8s_process_manager.py index 2eaa4b8bc..e528de2cd 100644 --- a/src/drunc/process_manager/k8s_process_manager.py +++ b/src/drunc/process_manager/k8s_process_manager.py @@ -16,6 +16,7 @@ ProcessRestriction, ProcessUUID, ) +from druncschema.request_response_pb2 import ResponseFlag from kubernetes import client, config from drunc.exceptions import DruncCommandException, DruncException @@ -368,20 +369,29 @@ def _return_code(self, podname, session): def _terminate(self): self.log.info("Terminating") - def _logs_impl(self, log_request: LogRequest) -> LogLines: uuids = self._get_process_uid(log_request.query, in_boot_request=True) uuid = self._ensure_one_process(uuids, in_boot_request=True) for uuid in self._get_process_uid(log_request.query): podname = self.boot_request[uuid].process_description.metadata.name session = self.boot_request[uuid].process_description.metadata.session - return [LogLines(line=log) for log in self._core_v1_api.read_namespaced_pod_log( - podname, session, tail_lines=log_request.how_far - ).split("\n")] - - def _boot_impl(self, boot_request: BootRequest) -> ProcessUUID: + return [ + LogLines(line=log) + for log in self._core_v1_api.read_namespaced_pod_log( + podname, session, tail_lines=log_request.how_far + ).split("\n") + ] + + def _boot_impl(self, boot_request: BootRequest) -> ProcessInstanceList: + self.log.debug(f"{self.name} running boot command") this_uuid = str(uuid.uuid4()) - return self.__boot(boot_request, this_uuid) + process = self.__boot(boot_request, this_uuid) + return ProcessInstanceList( + name=self.name, + token=None, + values=[process], + flag=ResponseFlag.EXECUTED_SUCCESSFULLY, + ) def __boot(self, boot_request: BootRequest, uuid: str) -> ProcessInstance: session = boot_request.process_description.metadata.session @@ -422,12 +432,15 @@ def _ps_impl( ret.append(self._get_pi(proc_uuid, podname, session, return_code)) - pil = ProcessInstanceList(values=ret) - - return pil + return ProcessInstanceList( + name=self.name, + token=None, + values=ret, + flag=ResponseFlag.EXECUTED_SUCCESSFULLY, + ) def _restart_impl(self, query: ProcessQuery) -> ProcessInstanceList: - # ret = [] + ret = [] uuids = self._get_process_uid(query, in_boot_request=True) uuid = self._ensure_one_process(uuids, in_boot_request=True) for uuid in self._get_process_uid(query): @@ -437,7 +450,6 @@ def _restart_impl(self, query: ProcessQuery) -> ProcessInstanceList: self._kill_pod(podname, session) - same_uuid_br = [] same_uuid_br = BootRequest() same_uuid_br.CopyFrom(self.boot_request[uuid]) same_uuid = uuid @@ -445,17 +457,18 @@ def _restart_impl(self, query: ProcessQuery) -> ProcessInstanceList: del self.boot_request[uuid] del uuid - ret = self.__boot(same_uuid_br, same_uuid) + ret = [self.__boot(same_uuid_br, same_uuid)] + # ret.append(self._get_pi(uuid, podname, session)) del same_uuid_br del same_uuid - # ret.append(self._get_pi(uuid, podname, session)) - - # pil = ProcessInstanceList( - # values=ret - # ) - return ret + return ProcessInstanceList( + name=self.name, + token=None, + values=ret, + flag=ResponseFlag.EXECUTED_SUCCESSFULLY, + ) # # ORDER MATTERS! # @broadcasted # outer most wrapper 1st step diff --git a/src/drunc/process_manager/process_manager.py b/src/drunc/process_manager/process_manager.py index a60a7fef2..c52acfd21 100644 --- a/src/drunc/process_manager/process_manager.py +++ b/src/drunc/process_manager/process_manager.py @@ -5,7 +5,7 @@ from druncschema.authoriser_pb2 import ActionType, SystemType from druncschema.broadcast_pb2 import BroadcastType -from druncschema.description_pb2 import CommandDescription, OldDescription +from druncschema.description_pb2 import CommandDescription, Description from druncschema.opmon.process_manager_pb2 import ProcessStatus from druncschema.process_manager_pb2 import ( BootRequest, @@ -21,7 +21,6 @@ from druncschema.process_manager_pb2_grpc import ProcessManagerServicer from druncschema.request_response_pb2 import ( Request, - Response, ResponseFlag, ) from google.rpc import code_pb2 @@ -40,10 +39,7 @@ ) from drunc.utils.configuration import ConfTypes from drunc.utils.grpc_utils import ( - UnpackingError, pack_to_any, - unpack_any, - unpack_error_response, ) from drunc.utils.utils import get_logger, pid_info_str @@ -105,7 +101,7 @@ def __init__( name="describe", data_type=["None"], help="Describe self (return a list of commands, the type of endpoint, the name and session).", - return_type="description_pb2.OldDescription", + return_type="description_pb2.Description", ), CommandDescription( name="kill", @@ -117,13 +113,13 @@ def __init__( name="restart", data_type=["process_manager_pb2.ProcessQuery"], help="Restart the process from the process query (which must correspond to one process).", - return_type="process_manager_pb2.ProcessInstance", + return_type="process_manager_pb2.ProcessInstanceList", ), CommandDescription( name="boot", data_type=["generic_pb2.BootRequest", "None"], help="Start a process.", - return_type="process_manager_pb2.ProcessInstance", + return_type="process_manager_pb2.ProcessInstanceList", ), CommandDescription( name="terminate", @@ -147,7 +143,7 @@ def __init__( name="ps", data_type=["process_manager_pb2.ProcessQuery"], help="Get the status of the listed process from the process query input (can be multiple).", - return_type="process_manager_pb2.ProcessInstance", + return_type="process_manager_pb2.ProcessInstanceList", ), ] @@ -182,10 +178,7 @@ def publish(self, q: ProcessQuery, interval_s: float = 10.0): if process.status_code == ProcessInstance.StatusCode.DEAD ) n_session = len( - { - process.process_description.metadata.session - for process in results.values - } + {process.process_description.metadata.session for process in results.values} ) self.opmon_publisher.publish( message=ProcessStatus( @@ -231,7 +224,7 @@ def interrupt_with_exception(self, *args, **kwargs): ) @abc.abstractmethod - def _boot_impl(self, br: BootRequest) -> ProcessInstance: + def _boot_impl(self, boot_request: BootRequest) -> ProcessInstanceList: raise NotImplementedError # ORDER MATTERS! @@ -239,36 +232,26 @@ def _boot_impl(self, br: BootRequest) -> ProcessInstance: @authentified_and_authorised( action=ActionType.CREATE, system=SystemType.PROCESS_MANAGER ) # 2nd step - def boot(self, request: Request, context: ServicerContext) -> Response: - try: - data = unpack_any(request.data, BootRequest) - except UnpackingError as e: - return unpack_error_response(self.__class__.__name__, str(e), request.token) - + def boot( + self, request: BootRequest, context: ServicerContext + ) -> ProcessInstanceList: self.log.debug( "{self.name} booting '{data.process_description.metadata.name}' " "from session '{data.process_description.metadata.session}'" ) try: - resp = self._boot_impl(data) - - return Response( - name=self.name, - token=None, - data=pack_to_any(resp), - flag=ResponseFlag.EXECUTED_SUCCESSFULLY, - children=[], - ) + response = self._boot_impl(request) except NotImplementedError: - return Response( + return ProcessInstanceList( name=self.name, token=None, - data=pack_to_any(resp), + values=[], flag=ResponseFlag.NOT_EXECUTED_NOT_IMPLEMENTED, - children=[], ) + return response + @abc.abstractmethod def _terminate_impl(self) -> ProcessInstanceList: raise NotImplementedError @@ -278,28 +261,25 @@ def _terminate_impl(self) -> ProcessInstanceList: @authentified_and_authorised( action=ActionType.DELETE, system=SystemType.PROCESS_MANAGER ) # 2nd step - def terminate(self, request: Request, context: ServicerContext) -> Response: - self.log.debug(f"{self.name} terminating") + def terminate( + self, request: Request, context: ServicerContext + ) -> ProcessInstanceList: + self.log.debug(f"{self.name} running terminate") + try: - resp = self._terminate_impl() - return Response( - name=self.name, - token=None, - data=pack_to_any(resp), - flag=ResponseFlag.EXECUTED_SUCCESSFULLY, - children=[], - ) + response = self._terminate_impl() except NotImplementedError: - return Response( + return ProcessInstanceList( name=self.name, token=None, - data=pack_to_any(resp), + values=[], flag=ResponseFlag.NOT_EXECUTED_NOT_IMPLEMENTED, - children=[], ) + return response + @abc.abstractmethod - def _restart_impl(self, q: ProcessQuery) -> ProcessInstanceList: + def _restart_impl(self, query: ProcessQuery) -> ProcessInstanceList: raise NotImplementedError # ORDER MATTERS! @@ -307,34 +287,25 @@ def _restart_impl(self, q: ProcessQuery) -> ProcessInstanceList: @authentified_and_authorised( action=ActionType.DELETE, system=SystemType.PROCESS_MANAGER ) # 2nd step - def restart(self, request: Request, context: ServicerContext) -> Response: - try: - data = unpack_any(request.data, ProcessQuery) - except UnpackingError as e: - return unpack_error_response(self.__class__.__name__, str(e), request.token) - + def restart( + self, request: ProcessQuery, context: ServicerContext + ) -> ProcessInstanceList: self.log.debug(f"{self.name} running restart") try: - resp = self._restart_impl(data) - return Response( - name=self.name, - token=None, - data=pack_to_any(resp), - flag=ResponseFlag.EXECUTED_SUCCESSFULLY, - children=[], - ) + response = self._restart_impl(request) except NotImplementedError: - return Response( + return ProcessInstanceList( name=self.name, token=None, - data=pack_to_any(resp), + values=[], flag=ResponseFlag.NOT_EXECUTED_NOT_IMPLEMENTED, - children=[], ) + return response + @abc.abstractmethod - def _kill_impl(self, q: ProcessQuery) -> ProcessInstanceList: + def _kill_impl(self, query: ProcessQuery) -> ProcessInstanceList: raise NotImplementedError # ORDER MATTERS! @@ -342,34 +313,25 @@ def _kill_impl(self, q: ProcessQuery) -> ProcessInstanceList: @authentified_and_authorised( action=ActionType.DELETE, system=SystemType.PROCESS_MANAGER ) # 2nd step - def kill(self, request: Request, context: ServicerContext) -> Response: - try: - data = unpack_any(request.data, ProcessQuery) - except UnpackingError as e: - return unpack_error_response(self.__class__.__name__, str(e), request.token) - + def kill( + self, request: ProcessQuery, context: ServicerContext + ) -> ProcessInstanceList: self.log.debug(f"{self.name} running kill") try: - resp = self._kill_impl(data) - return Response( - name=self.name, - token=None, - data=pack_to_any(resp), - flag=ResponseFlag.EXECUTED_SUCCESSFULLY, - children=[], - ) + response = self._kill_impl(request) except NotImplementedError: - return Response( + return ProcessInstanceList( name=self.name, token=None, - data=pack_to_any(resp), + values=[], flag=ResponseFlag.NOT_EXECUTED_NOT_IMPLEMENTED, - children=[], ) + return response + @abc.abstractmethod - def _ps_impl(self, q: ProcessQuery) -> ProcessInstanceList: + def _ps_impl(self, query: ProcessQuery) -> ProcessInstanceList: raise NotImplementedError # ORDER MATTERS! @@ -377,47 +339,35 @@ def _ps_impl(self, q: ProcessQuery) -> ProcessInstanceList: @authentified_and_authorised( action=ActionType.READ, system=SystemType.PROCESS_MANAGER ) # 2nd step - def ps(self, request: Request, context: ServicerContext) -> Response: - try: - data = unpack_any(request.data, ProcessQuery) - except UnpackingError as e: - return unpack_error_response(self.__class__.__name__, str(e), request.token) - + def ps( + self, request: ProcessQuery, context: ServicerContext + ) -> ProcessInstanceList: self.log.debug(f"{self.name} running ps") try: - resp = self._ps_impl(data) - return Response( - name=self.name, - token=None, - data=pack_to_any(resp), - flag=ResponseFlag.EXECUTED_SUCCESSFULLY, - children=[], - ) + response = self._ps_impl(request) except NotImplementedError: - return Response( + return ProcessInstanceList( name=self.name, token=None, - data=pack_to_any(resp), + values=[], flag=ResponseFlag.NOT_EXECUTED_NOT_IMPLEMENTED, - children=[], ) + return response + # ORDER MATTERS! @broadcasted # outer most wrapper 1st step @authentified_and_authorised( action=ActionType.DELETE, system=SystemType.PROCESS_MANAGER ) # 2nd step - def flush(self, request: Request, context: ServicerContext) -> Response: - try: - data = unpack_any(request.data, ProcessQuery) - except UnpackingError as e: - return unpack_error_response(self.__class__.__name__, str(e), request.token) - + def flush( + self, request: ProcessQuery, context: ServicerContext + ) -> ProcessInstanceList: self.log.debug(f"{self.name} running flush") ret = [] - for uuid in self._get_process_uid(data): + for uuid in self._get_process_uid(request): if uuid not in self.boot_request: pu = ProcessUUID(uuid=uuid) pi = ProcessInstance( @@ -460,14 +410,11 @@ def flush(self, request: Request, context: ServicerContext) -> Response: del self.process_store[uuid] ret += [pi] - pil = ProcessInstanceList(values=ret) - - return Response( + return ProcessInstanceList( name=self.name, token=None, - data=pack_to_any(pil), + values=ret, flag=ResponseFlag.EXECUTED_SUCCESSFULLY, - children=[], ) # ORDER MATTERS! @@ -475,29 +422,27 @@ def flush(self, request: Request, context: ServicerContext) -> Response: @authentified_and_authorised( action=ActionType.READ, system=SystemType.PROCESS_MANAGER ) # 2nd step - def describe(self, request: Request, context: ServicerContext) -> Response: + def describe(self, request: Request, context: ServicerContext) -> Description: self.log.debug(f"{self.name} running describe") - bd = self.describe_broadcast() - d = OldDescription( + + description = Description( type="process_manager", name=self.name, info=self.configuration.log_path, session="no_session" if not self.session else self.session, commands=self.commands, - ) - if bd: - d.broadcast.CopyFrom(pack_to_any(bd)) - - return Response( - name=self.name, - token=None, - data=pack_to_any(d), - flag=ResponseFlag.EXECUTED_SUCCESSFULLY, children=[], + flag=ResponseFlag.EXECUTED_SUCCESSFULLY, + token=None, ) + if broadcast_description := self.describe_broadcast(): + description.broadcast.CopyFrom(pack_to_any(broadcast_description)) + + return description + @abc.abstractmethod - def _logs_impl(self, request_data: LogRequest) -> LogLines: + def _logs_impl(self, log_request: LogRequest) -> LogLines: raise NotImplementedError # ORDER MATTERS! @@ -505,41 +450,34 @@ def _logs_impl(self, request_data: LogRequest) -> LogLines: @authentified_and_authorised( action=ActionType.READ, system=SystemType.PROCESS_MANAGER ) # 2nd step - def logs(self, request: Request, context: ServicerContext) -> Response: + def logs(self, request: LogRequest, context: ServicerContext) -> LogLines: """Fetch logs for a process. Args: request: The incoming request. context: The gRPC context (not used). - Yields: - Response objects containing log lines. + Returns: + A response containing log lines. """ self.log.debug("Getting logs") try: - data = unpack_any(request.data, LogRequest) - except UnpackingError as e: - return unpack_error_response(self.__class__.__name__, str(e), request.token) - - try: - return Response( - name=self.name, - token=None, - data=pack_to_any(self._logs_impl(data)), - flag=ResponseFlag.EXECUTED_SUCCESSFULLY, - children=[], - ) + response = self._logs_impl(request) except NotImplementedError: - return Response( + return LogLines( name=self.name, token=None, - data=None, + uuid=None, + lines=[], flag=ResponseFlag.NOT_EXECUTED_NOT_IMPLEMENTED, - children=[], ) - def _ensure_one_process(self, uuids: [str], in_boot_request: bool = False) -> str: + return response + + def _ensure_one_process( + self, uuids: list[str], in_boot_request: bool = False + ) -> str: if uuids == []: raise BadQuery("The process corresponding to the query doesn't exist") elif len(uuids) > 1: @@ -562,7 +500,7 @@ def _get_process_uid( query: ProcessQuery, in_boot_request: bool = False, order_by: str = "random", - ) -> [str]: + ) -> list[str]: order_by = order_by.lower() if order_by not in ["random", "leaf_first", "root_first"]: raise DruncCommandException(f"Order by '{order_by}' is not supported") diff --git a/src/drunc/process_manager/process_manager_driver.py b/src/drunc/process_manager/process_manager_driver.py index 89ca38d05..359d8132f 100644 --- a/src/drunc/process_manager/process_manager_driver.py +++ b/src/drunc/process_manager/process_manager_driver.py @@ -4,27 +4,30 @@ import signal import tempfile import time +from collections.abc import Iterator from time import sleep -from druncschema.description_pb2 import OldDescription +import grpc +from druncschema.description_pb2 import Description from druncschema.process_manager_pb2 import ( BootRequest, LogLines, LogRequest, ProcessDescription, - ProcessInstance, ProcessInstanceList, ProcessMetadata, ProcessQuery, ProcessRestriction, ) from druncschema.process_manager_pb2_grpc import ProcessManagerStub +from druncschema.request_response_pb2 import Request from drunc.connectivity_service.client import ConnectivityServiceClient from drunc.connectivity_service.exceptions import ApplicationLookupUnsuccessful from drunc.controller.utils import get_segment_lookup_timeout from drunc.exceptions import DruncSetupException, DruncShellException from drunc.process_manager.utils import get_log_path, get_rte_script +from drunc.utils.grpc_utils import copy_token, handle_grpc_error from drunc.utils.shell_utils import GRPCDriver from drunc.utils.utils import ( get_control_type_and_uri_from_connectivity_service, @@ -38,13 +41,10 @@ class ProcessManagerDriver(GRPCDriver): controller_address = "" def __init__(self, address: str, token, **kwargs): - super(ProcessManagerDriver, self).__init__( - name="process_manager.driver", address=address, token=token, **kwargs + super().__init__( + name="process_manager_driver", address=address, token=token, **kwargs ) - self.log.debug("set up process_manager.driver") - - def create_stub(self, channel): - return ProcessManagerStub(channel) + self.stub = ProcessManagerStub(self.channel) def _convert_oks_to_boot_request( self, @@ -54,7 +54,7 @@ def _convert_oks_to_boot_request( db, session_name: str, override_logs: bool, - ) -> BootRequest: + ) -> Iterator[BootRequest]: from drunc.process_manager.oks_parser import collect_apps, collect_infra_apps env = { @@ -133,6 +133,7 @@ def _convert_oks_to_boot_request( self.log.debug(f"{name}'s env:\n{env}") breq = BootRequest( + token=copy_token(self.token), process_description=ProcessDescription( metadata=ProcessMetadata( user=user, @@ -164,7 +165,7 @@ def boot( int | float ) = 0, # This may be useful if you have are using SSHPM, and have SSHD's maxstartups setting set to a low value. **kwargs, - ) -> ProcessInstance: + ) -> Iterator[ProcessInstanceList]: from daqconf.consolidate import consolidate_db self.log.info(f"Booting session [green]{session_name}[/green]") @@ -204,7 +205,7 @@ def boot( last_boot_on_host_at = {} previous_host = None - for br in self._convert_oks_to_boot_request( + for request in self._convert_oks_to_boot_request( oks_conf=conf_file, user=user, session_dal=session_dal, @@ -214,14 +215,14 @@ def boot( **kwargs, ): if ( - br.process_description.metadata.name + request.process_description.metadata.name not in [app.id for app in session_dal.infrastructure_applications] and csc and not csc.is_ready(timeout=10) ): raise DruncSetupException("Connectivity service is not ready in time") - this_host = next(iter(br.process_restriction.allowed_hosts)) + this_host = next(iter(request.process_restriction.allowed_hosts)) time_diff = time.time() - last_boot_on_host_at.get(this_host, 0) @@ -233,12 +234,13 @@ def boot( previous_host = this_host last_boot_on_host_at[this_host] = time.time() - yield self.send_command( - "boot", - data=br, - outformat=ProcessInstance, - timeout=timeout, - ) + + try: + response = self.stub.boot(request, timeout=timeout) + except grpc.RpcError as e: + handle_grpc_error(e) + + yield response top_controller_name = session_dal.segment.controller.id @@ -340,7 +342,7 @@ def dummy_boot( sleep: int, n_sleeps: int, timeout: int | float = 60, - ): # -> ProcessInstance: + ) -> Iterator[ProcessInstanceList]: pwd = os.getcwd() # Construct the list of commands to send to the dummy_boot process @@ -359,7 +361,8 @@ def dummy_boot( ) for process in range(n_processes): - breq = BootRequest( + request = BootRequest( + token=copy_token(self.token), process_description=ProcessDescription( metadata=ProcessMetadata( user=user, @@ -376,70 +379,92 @@ def dummy_boot( ), process_restriction=ProcessRestriction(allowed_hosts=["localhost"]), ) - self.log.debug(f"{breq=}\n\n") + self.log.debug(f"{request=}\n\n") - yield self.send_command( - "boot", - data=breq, - outformat=ProcessInstance, - timeout=timeout, - ) + try: + response = self.stub.boot(request, timeout=timeout) + except grpc.RpcError as e: + handle_grpc_error(e) + + yield response def terminate( self, timeout: int | float = 60, ) -> ProcessInstanceList: - return self.send_command( - "terminate", outformat=ProcessInstanceList, timeout=timeout - ) + request = Request(token=copy_token(self.token)) - def kill(self, query: ProcessQuery, timeout: int | float = 60) -> ProcessInstance: - return self.send_command( - "kill", - data=query, - outformat=ProcessInstanceList, - timeout=timeout, - ) + try: + response = self.stub.terminate(request, timeout=timeout) + except grpc.RpcError as e: + handle_grpc_error(e) - def logs(self, req: LogRequest, timeout: int | float = 60) -> LogLines: - return self.send_command( - "logs", - data=req, - outformat=LogLines, - timeout=timeout, - ) + return response - def ps(self, query: ProcessQuery, timeout: int | float = 60) -> ProcessInstanceList: - return self.send_command( - "ps", - data=query, - outformat=ProcessInstanceList, - timeout=timeout, - ) + def kill( + self, request: ProcessQuery, timeout: int | float = 60 + ) -> ProcessInstanceList: + request.token.CopyFrom(self.token) + + try: + response = self.stub.kill(request, timeout=timeout) + except grpc.RpcError as e: + handle_grpc_error(e) + + return response + + def logs(self, request: LogRequest, timeout: int | float = 60) -> LogLines: + request.token.CopyFrom(self.token) + + try: + response = self.stub.logs(request, timeout=timeout) + except grpc.RpcError as e: + handle_grpc_error(e) + + return response + + def ps( + self, request: ProcessQuery, timeout: int | float = 60 + ) -> ProcessInstanceList: + request.token.CopyFrom(self.token) + + try: + response = self.stub.ps(request, timeout=timeout) + except grpc.RpcError as e: + handle_grpc_error(e) + + return response def flush( - self, query: ProcessQuery, timeout: int | float = 60 + self, request: ProcessQuery, timeout: int | float = 60 ) -> ProcessInstanceList: - return self.send_command( - "flush", - data=query, - outformat=ProcessInstanceList, - timeout=timeout, - ) + request.token.CopyFrom(self.token) + + try: + response = self.stub.flush(request, timeout=timeout) + except grpc.RpcError as e: + handle_grpc_error(e) + + return response def restart( - self, query: ProcessQuery, timeout: int | float = 60 - ) -> ProcessInstance: - return self.send_command( - "restart", - data=query, - outformat=ProcessInstance, - timeout=timeout, - ) + self, request: ProcessQuery, timeout: int | float = 60 + ) -> ProcessInstanceList: + request.token.CopyFrom(self.token) - def describe(self, timeout: int | float = 60) -> OldDescription: - return self.send_command( - "describe", - outformat=OldDescription, - timeout=timeout, - ) + try: + response = self.stub.restart(request, timeout=timeout) + except grpc.RpcError as e: + handle_grpc_error(e) + + return response + + def describe(self, timeout: int | float = 60) -> Description: + request = Request(token=copy_token(self.token)) + + try: + response = self.stub.describe(request, timeout=timeout) + except grpc.RpcError as e: + handle_grpc_error(e) + + return response diff --git a/src/drunc/process_manager/ssh_process_manager.py b/src/drunc/process_manager/ssh_process_manager.py index 14652bec3..305b0e43d 100644 --- a/src/drunc/process_manager/ssh_process_manager.py +++ b/src/drunc/process_manager/ssh_process_manager.py @@ -19,6 +19,7 @@ ProcessRestriction, ProcessUUID, ) +from druncschema.request_response_pb2 import ResponseFlag from drunc.exceptions import DruncCommandException from drunc.process_manager.process_manager import ProcessManager @@ -117,20 +118,30 @@ def kill_processes(self, uuids: list) -> ProcessInstanceList: ] del self.process_store[proc_uuid] - pil = ProcessInstanceList(values=ret) - return pil + return ProcessInstanceList( + name=self.name, + token=None, + values=ret, + flag=ResponseFlag.EXECUTED_SUCCESSFULLY, + ) def _terminate_impl(self) -> ProcessInstanceList: self.log.info("Terminating") + if self.process_store: self.log.info("Killing all the known processes before exiting") uuids = self._get_process_uid( query=ProcessQuery(names=[".*"]), order_by="leaf_first" ) return self.kill_processes(uuids) - else: - self.log.info("No known process to kill before exiting") - return ProcessInstanceList() + + self.log.info("No known process to kill before exiting") + return ProcessInstanceList( + name=self.name, + token=None, + values=[], + flag=ResponseFlag.EXECUTED_SUCCESSFULLY, + ) def _logs_impl(self, log_request: LogRequest) -> LogLines: self.log.debug(f"Retrieving logs for {log_request.query}") @@ -176,26 +187,35 @@ def _logs_impl(self, log_request: LogRequest) -> LogLines: _err_to_out=True, ) except Exception as e: - if uid not in self.process_store: - return LogLines( - uuid=ProcessUUID(uuid=uid), lines=[f"Could not retrieve logs: {e!s}"] - ) - else: - return LogLines( - uuid=ProcessUUID(uuid=uid), - lines=[ - f"Could not retrieve logs: {e!s}", + lines = [f"Could not retrieve logs: {e!s}"] + if uid in self.process_store: + lines.extend( + [ f"stdout: {self.process_store[uid].stdout}", - f"stderr: {self.process_store[uid].stderr}" - ], + f"stderr: {self.process_store[uid].stderr}", + ] ) + return LogLines( + name=self.name, + token=None, + uuid=ProcessUUID(uuid=uid), + lines=lines, + flag=ResponseFlag.UNHANDLED_EXCEPTION_THROWN, + ) + f.close() with open(f.name) as fi: lines = fi.readlines() if "Connection to " in lines[-1] and " closed." in lines[-1]: lines = lines[:-1] - return LogLines(uuid=ProcessUUID(uuid=uid), lines=lines) + return LogLines( + name=self.name, + token=None, + uuid=ProcessUUID(uuid=uid), + lines=lines, + flag=ResponseFlag.EXECUTED_SUCCESSFULLY, + ) os.remove(f.name) @@ -400,21 +420,29 @@ def _ps_impl(self, query: ProcessQuery) -> ProcessInstanceList: ) ret += [pi] - pil = ProcessInstanceList(values=ret) - - return pil + return ProcessInstanceList( + name=self.name, + token=None, + values=ret, + flag=ResponseFlag.EXECUTED_SUCCESSFULLY, + ) - def _boot_impl(self, boot_request: BootRequest) -> ProcessInstance: - self.log.debug(f"{self.name} running _boot_impl") + def _boot_impl(self, boot_request: BootRequest) -> ProcessInstanceList: + self.log.debug(f"{self.name} running boot command") this_uuid = str(uuid.uuid4()) - return self.__boot(boot_request, this_uuid) + process = self.__boot(boot_request, this_uuid) + return ProcessInstanceList( + name=self.name, + token=None, + values=[process], + flag=ResponseFlag.EXECUTED_SUCCESSFULLY, + ) def _restart_impl(self, query: ProcessQuery) -> ProcessInstanceList: self.log.info(f"{self.name} restarting {query.names} in session {self.session}") uuids = self._get_process_uid(query, in_boot_request=True) uuid = self._ensure_one_process(uuids, in_boot_request=True) - same_uuid_br = [] same_uuid_br = BootRequest() same_uuid_br.CopyFrom(self.boot_request[uuid]) same_uuid = uuid @@ -428,18 +456,29 @@ def _restart_impl(self, query: ProcessQuery) -> ProcessInstanceList: del self.boot_request[uuid] del uuid - ret = self.__boot(same_uuid_br, same_uuid) + ret = [self.__boot(same_uuid_br, same_uuid)] del same_uuid_br del same_uuid - return ret + return ProcessInstanceList( + name=self.name, + token=None, + values=ret, + flag=ResponseFlag.EXECUTED_SUCCESSFULLY, + ) def _kill_impl(self, query: ProcessQuery) -> ProcessInstanceList: self.log.info(f"{self.name} killing {query.names} in session {self.session}") + if self.process_store: uuids = self._get_process_uid(query, order_by="leaf_first") return self.kill_processes(uuids) - else: - self.log.info("No known process to kill before exiting") - return ProcessInstanceList() + + self.log.info("No known process to kill before exiting") + return ProcessInstanceList( + name=self.name, + token=None, + values=[], + flag=ResponseFlag.EXECUTED_SUCCESSFULLY, + ) diff --git a/src/drunc/session_manager/session_manager.py b/src/drunc/session_manager/session_manager.py index a1ce5b59c..e7f4c78f3 100644 --- a/src/drunc/session_manager/session_manager.py +++ b/src/drunc/session_manager/session_manager.py @@ -82,7 +82,6 @@ def describe(self, request: Request, context: ServicerContext) -> Description: return Description( type="session_manager", name=self.name, - session=self.name, commands=commands, children=[], flag=ResponseFlag.EXECUTED_SUCCESSFULLY, diff --git a/src/drunc/session_manager/session_manager_driver.py b/src/drunc/session_manager/session_manager_driver.py index c4c6b79d0..a52a6c7ab 100644 --- a/src/drunc/session_manager/session_manager_driver.py +++ b/src/drunc/session_manager/session_manager_driver.py @@ -1,12 +1,14 @@ """Driver for the session manager service.""" +import grpc from druncschema.description_pb2 import Description +from druncschema.request_response_pb2 import Request from druncschema.session_manager_pb2 import AllActiveSessions, AllConfigKeys from druncschema.session_manager_pb2_grpc import SessionManagerStub from druncschema.token_pb2 import Token -from grpc import Channel -from drunc.utils.shell_utils import DecodedResponse, GRPCDriver +from drunc.utils.grpc_utils import copy_token, handle_grpc_error +from drunc.utils.shell_utils import GRPCDriver class SessionManagerDriver(GRPCDriver): @@ -27,38 +29,58 @@ def __init__(self, address: str, token: Token, **kwargs): super().__init__( name="session_manager_driver", address=address, token=token, **kwargs ) + self.stub = SessionManagerStub(self.channel) - def create_stub(self, channel: Channel) -> SessionManagerStub: - """Create gRPC stubs for the session manager service. + def describe(self, timeout: int | float = 60) -> Description: + """Describe the session manager service. Args: - channel: The gRPC channel to use for communication. + timeout: The timeout for the gRPC call in seconds. Returns: - An object containing session manager service method stubs. + A response containing the description of the service. """ - return SessionManagerStub(channel) + request = Request(token=copy_token(self.token)) - def describe(self) -> DecodedResponse | None: - """Describe the session manager service. + try: + response = self.stub.describe(request, timeout=timeout) + except grpc.RpcError as e: + handle_grpc_error(e) - Returns: - A response containing the description of the service. - """ - return self.send_command("describe", outformat=Description) + return response - def list_all_sessions(self) -> DecodedResponse | None: + def list_all_sessions(self, timeout: int | float = 60) -> AllActiveSessions: """List all active sessions managed by the session manager. + Args: + timeout: The timeout for the gRPC call in seconds. + Returns: A response containing a list of all active sessions. """ - return self.send_command("list_all_sessions", outformat=AllActiveSessions) + request = Request(token=copy_token(self.token)) + + try: + response = self.stub.list_all_sessions(request, timeout=timeout) + except grpc.RpcError as e: + handle_grpc_error(e) - def list_all_configs(self) -> DecodedResponse | None: + return response + + def list_all_configs(self, timeout: int | float = 60) -> AllConfigKeys: """List all available configurations in the session manager. + Args: + timeout: The timeout for the gRPC call in seconds. + Returns: A response containing all available configuration keys. """ - return self.send_command("list_all_configs", outformat=AllConfigKeys) + request = Request(token=copy_token(self.token)) + + try: + response = self.stub.list_all_configs(request, timeout=timeout) + except grpc.RpcError as e: + handle_grpc_error(e) + + return response diff --git a/src/drunc/unified_shell/commands.py b/src/drunc/unified_shell/commands.py index 9f2ceef89..a5b662b92 100644 --- a/src/drunc/unified_shell/commands.py +++ b/src/drunc/unified_shell/commands.py @@ -30,9 +30,9 @@ def boot( ProcessQuery(user=user, session=session_name) ) - if len(processes.data.values) > 0: + if len(processes.values) > 0: click.confirm( - f"You already have {len(processes.data.values)} processes running in session {session_name}, are you sure you want to boot a session?", + f"You already have {len(processes.values)} processes running in session {session_name}, are you sure you want to boot a session?", abort=True, ) @@ -50,7 +50,7 @@ def boot( if not result: break log.debug( - f"'{result.data.process_description.metadata.name}' ({result.data.uuid.uuid}) started" + f"'{result.values[0].process_description.metadata.name}' ({result.values[0].uuid.uuid}) started" ) except InterruptedCommand: log.warning("Booting interrupted") diff --git a/src/drunc/unified_shell/context.py b/src/drunc/unified_shell/context.py index 2ff5c1c95..772b95fe9 100644 --- a/src/drunc/unified_shell/context.py +++ b/src/drunc/unified_shell/context.py @@ -50,11 +50,14 @@ def set_controller_driver(self, address_controller, **kwargs) -> None: del self._drivers["controller"] return - self._drivers["controller"] = ControllerDriver( + driver = ControllerDriver( self.address_controller, self._token, ) + # This will raise an exception if the driver already exists + self.set_driver("controller", driver) + def create_token(self, **kwargs) -> Token: from drunc.utils.shell_utils import create_dummy_token_from_uname diff --git a/src/drunc/unified_shell/shell.py b/src/drunc/unified_shell/shell.py index 29a8a4506..259151a73 100644 --- a/src/drunc/unified_shell/shell.py +++ b/src/drunc/unified_shell/shell.py @@ -187,7 +187,6 @@ def unified_shell( unified_shell_log.debug("Runnning [green]describe[/green]") try: desc = ctx.obj.get_driver().describe() - desc = desc.data except Exception as e: unified_shell_log.error( f"[red]Could not connect to the process manager at the address[/red] [green]{process_manager_address}[/]" @@ -281,7 +280,6 @@ def cleanup(): controller_name = session_dal.segment.controller.id unified_shell_log.debug("Initializing the [green]ControllerConfHandler[/green]") - controller_configuration = ControllerConfHandler( type=ConfTypes.OKSFileName, data=ctx.obj.configuration_file, diff --git a/src/drunc/unified_shell/shell_utils.py b/src/drunc/unified_shell/shell_utils.py index 6c2e2282f..120f44fbe 100644 --- a/src/drunc/unified_shell/shell_utils.py +++ b/src/drunc/unified_shell/shell_utils.py @@ -19,7 +19,7 @@ def run_fsm_sequence(sequence_commands, cmd_to_options_and_args, ctx, obj, **kwa if command == "boot": pmd = obj.get_driver("process_manager", quiet_fail=True) process_list = pmd.ps(ProcessQuery(names=[".*"])) - if not process_list.data.values: # We haven't started anything yet + if not process_list.values: # We haven't started anything yet accepted_command.append("boot") if cd: accepted_command_raw = cd.describe_fsm() diff --git a/src/drunc/utils/grpc_utils.py b/src/drunc/utils/grpc_utils.py index a4a01a9f6..9aea97b2d 100644 --- a/src/drunc/utils/grpc_utils.py +++ b/src/drunc/utils/grpc_utils.py @@ -1,3 +1,4 @@ +from typing import NoReturn import grpc from druncschema.generic_pb2 import PlainText @@ -92,9 +93,34 @@ def rethrow_if_timeout(grpc_error): raise ServerTimeout(grpc_error._state.details) from grpc_error +def handle_grpc_error(error: grpc.RpcError) -> NoReturn: + """Handle gRPC errors by rethrowing them with appropriate context. + + Args: + error: The gRPC error to handle. + """ + rethrow_if_unreachable_server(error) + rethrow_if_timeout(error) + raise error + + def interrupt_if_unreachable_server(grpc_error): if not server_is_reachable(grpc_error): if hasattr(grpc_error, "_state"): return grpc_error._state.details elif hasattr(grpc_error, "_details"): return grpc_error._details + + +def copy_token(token: Token) -> Token: + """Create a copy of the original token. + + Args: + token: The original token to copy. + + Returns: + A copy of the original token. + """ + token_copy = Token() + token_copy.CopyFrom(token) + return token_copy diff --git a/src/drunc/utils/shell_utils.py b/src/drunc/utils/shell_utils.py index 3438f1aa8..d02957cd2 100644 --- a/src/drunc/utils/shell_utils.py +++ b/src/drunc/utils/shell_utils.py @@ -15,12 +15,7 @@ DruncSetupException, DruncShellException, ) -from drunc.utils.grpc_utils import ( - UnpackingError, - rethrow_if_timeout, - rethrow_if_unreachable_server, - unpack_any, -) +from drunc.utils.grpc_utils import UnpackingError, handle_grpc_error, unpack_any from drunc.utils.utils import get_logger @@ -95,18 +90,10 @@ def __init__(self, name: str, address: str, token: Token): ) self.address = address - - self.channel = grpc.insecure_channel(self.address) - - self.stub = self.create_stub(self.channel) self.token = Token() self.token.CopyFrom(token) - @abc.abstractmethod - def create_stub(self, channel) -> object: - pass - def _create_request(self, payload=None) -> Request: token2 = Token() token2.CopyFrom(self.token) @@ -119,11 +106,6 @@ def _create_request(self, payload=None) -> Request: else: return Request(token=token2) - def __handle_grpc_error(self, error, command): - rethrow_if_unreachable_server(error) - rethrow_if_timeout(error) - raise error - def handle_response(self, response, command, outformat): dr = DecodedResponse( name=response.name, @@ -139,15 +121,6 @@ def handle_response(self, response, command, outformat): self.log.error(f"Error unpacking data: {e}") dr.data = response.data - for c_response in response.children: - try: - dr.children.append( - self.handle_response(c_response, command, outformat) - ) - except DruncServerSideError as e: - self.log.error(f"Exception thrown from child: {e}") - return dr - else: def text(verb="not executed", reason=""): @@ -162,7 +135,6 @@ def text(verb="not executed", reason=""): if response.data.Is(Stacktrace.DESCRIPTOR): stack = unpack_any(response.data, Stacktrace) dr.data = stack - # stack_txt = 'Stacktrace [bold red]on remote server![/bold red]\n' # Temporary - bold doesn't work stack_txt = "Stacktrace on remote server!\n" last_one = "" @@ -184,30 +156,25 @@ def text(verb="not executed", reason=""): elif response.flag in [ ResponseFlag.NOT_EXECUTED_NOT_IN_CONTROL, ]: - self.log.warn(text()) + self.log.warning(text()) else: self.log.error(text("failed", error_txt)) - for c_response in response.children: - try: - dr.children.append( - self.handle_response(c_response, command, outformat) - ) - except DruncServerSideError as e: - self.log.error(f"Exception thrown from child: {e}") - return dr + for c_response in response.children: + try: + dr.children.append(self.handle_response(c_response, command, outformat)) + except DruncServerSideError as e: + self.log.error(f"Exception thrown from child: {e}") + + return dr def send_command( self, command: str, data=None, outformat=None, - decode_children=False, timeout: int | float = 60, ): - if not self.stub: - raise DruncShellException("No stub initialised") - cmd = getattr(self.stub, command) # this throws if the command doesn't exist request = self._create_request(data) @@ -215,13 +182,12 @@ def send_command( try: response = cmd(request, timeout=timeout) except grpc.RpcError as e: - self.__handle_grpc_error(e, command) + handle_grpc_error(e) # TODO: TEMP HACK UNTIL UNPACKING IS REMOVED from druncschema.description_pb2 import Description - from druncschema.session_manager_pb2 import AllActiveSessions, AllConfigKeys - if isinstance(response, (Description, AllActiveSessions, AllConfigKeys)): + if isinstance(response, Description): return response return self.handle_response(response, command, outformat)