diff --git a/src/madengine/core/console.py b/src/madengine/core/console.py index 9340924a..fee615ba 100644 --- a/src/madengine/core/console.py +++ b/src/madengine/core/console.py @@ -14,18 +14,18 @@ class Console: """Class to run console commands. - + Attributes: shellVerbose (bool): The shell verbose flag. live_output (bool): The live output flag. """ def __init__( - self, - shellVerbose: bool=True, + self, + shellVerbose: bool=True, live_output: bool=False ) -> None: """Constructor of the Console class. - + Args: shellVerbose (bool): The shell verbose flag. live_output (bool): The live output flag. @@ -34,16 +34,16 @@ def __init__( self.live_output = live_output def sh( - self, - command: str, - canFail: bool=False, - timeout: int=60, - secret: bool=False, - prefix: str="", + self, + command: str, + canFail: bool=False, + timeout: int=60, + secret: bool=False, + prefix: str="", env: typing.Optional[typing.Dict[str, str]]=None ) -> str: """Run shell command. - + Args: command (str): The shell command. canFail (bool): The flag to allow failure. @@ -51,7 +51,7 @@ def sh( secret (bool): The flag to hide the command. prefix (str): The prefix of the output. env (typing_extensions.TypedDict): The environment variables. - + Returns: str: The output of the shell command. 
def get_cmd(cmd, known_paths):
    """Resolve a command name to its full executable path.

    Defers to ``shutil.which`` first (i.e. the current ``PATH``); if the
    command is not on PATH, probes each directory in *known_paths* for an
    executable regular file of that name.

    Args:
        cmd (str): Command name, e.g. ``"rocminfo"``.
        known_paths (list): Directories to search when the PATH lookup fails.

    Returns:
        str: Full path to the command.

    Raises:
        FileNotFoundError: If the command cannot be located anywhere.
    """
    cmd_path = shutil.which(cmd)
    if cmd_path is not None:
        return cmd_path

    for path in known_paths:
        # Silently skip known locations that do not exist on this host.
        if not os.path.isdir(path):
            continue

        candidate = os.path.join(path, cmd)
        if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
            return candidate

    raise FileNotFoundError(f'{cmd} not found.')
+ """ + rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm") + known_paths = [os.path.join(rocm_path, "bin")] + + return get_cmd("rocm-smi", known_paths) + + +def get_amdsmi_path(): + """Get the amd-smi command. + + Returns: + str: The absolute path to amd-smi. + """ + rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm") + known_paths = [os.path.join(rocm_path, "bin")] + + return get_cmd("amd-smi", known_paths) + + +def get_nvidiasmi_path(): + """Get the nvidia-smi command. + + Returns: + str: The absolute path to nvidia-smi. + """ + cuda_path = os.environ.get("CUDA_PATH", "/usr/local/cuda") + known_paths = [ + "/usr/bin", + "/usr/local/bin", + os.path.join(cuda_path, "bin") + ] + + return get_cmd("nvidia-smi", known_paths) + class Context: """Class to determine context. - + Attributes: console: The console. ctx: The context. - + Methods: get_ctx_test: Get context test. get_gpu_vendor: Get GPU vendor. @@ -62,16 +142,16 @@ class Context: filter: Filter. """ def __init__( - self, - additional_context: str=None, + self, + additional_context: str=None, additional_context_file: str=None ) -> None: """Constructor of the Context class. - + Args: additional_context: The additional context. additional_context_file: The additional context file. - + Raises: RuntimeError: If the GPU vendor is not detected. RuntimeError: If the GPU architecture is not detected. @@ -118,7 +198,7 @@ def __init__( 'MASTER_ADDR': 'localhost', 'MASTER_PORT': 6006, 'HOST_LIST': '', - 'NCCL_SOCKET_IFNAME': '', + 'NCCL_SOCKET_IFNAME': '', 'GLOO_SOCKET_IFNAME': '' } @@ -129,7 +209,7 @@ def __init__( mad_secrets[key] = os.environ[key] if mad_secrets: update_dict(self.ctx['docker_build_arg'], mad_secrets) - update_dict(self.ctx['docker_env_vars'], mad_secrets) + update_dict(self.ctx['docker_env_vars'], mad_secrets) ## ADD MORE CONTEXTS HERE ## @@ -150,7 +230,7 @@ def __init__( def get_ctx_test(self) -> str: """Get context test. - + Returns: str: The output of the shell command. 
def get_gpu_vendor(self) -> str:
    """Get GPU vendor.

    Returns:
        str: "NVIDIA", "AMD", or "Unable to detect GPU vendor".

    Note:
        What types of GPU vendors are supported?
        - NVIDIA
        - AMD

        nvidia-smi may be installed on hosts without an NVIDIA GPU
        (e.g. leftover driver packages), so finding the binary is not
        sufficient -- it must also run successfully. This matches the
        original shell-based detection, which executed nvidia-smi and
        checked its exit status.
    """
    # Try NVIDIA first: the binary must exist AND execute successfully.
    try:
        nvidiasmi_path = get_nvidiasmi_path()
        # console.sh raises RuntimeError on a non-zero exit status.
        self.console.sh(f"{nvidiasmi_path} > /dev/null 2>&1")
        return "NVIDIA"
    except (FileNotFoundError, RuntimeError):
        # FileNotFoundError: binary absent; RuntimeError: binary present
        # but failed to run (no NVIDIA GPU / driver not loaded).
        pass

    # AMD: presence of amd-smi is sufficient (mirrors prior behavior,
    # which only checked for the file).
    try:
        _ = get_amdsmi_path()
        return "AMD"
    except FileNotFoundError:
        pass

    return "Unable to detect GPU vendor"
- + Note: What types of GPU vendors are supported? - NVIDIA @@ -240,7 +331,8 @@ def get_system_ngpus(self) -> int: """ number_gpus = 0 if self.ctx["docker_env_vars"]["MAD_GPU_VENDOR"] == "AMD": - number_gpus = int(self.console.sh("amd-smi list --csv | tail -n +3 | wc -l")) + amdsmi_path = get_amdsmi_path() + number_gpus = int(self.console.sh(f"{amdsmi_path} list --csv | tail -n +3 | wc -l")) elif self.ctx["docker_env_vars"]["MAD_GPU_VENDOR"] == "NVIDIA": number_gpus = int(self.console.sh("nvidia-smi -L | wc -l")) else: @@ -250,21 +342,22 @@ def get_system_ngpus(self) -> int: def get_system_gpu_architecture(self) -> str: """Get system GPU architecture. - + Returns: str: The GPU architecture. - + Raises: RuntimeError: If the GPU vendor is not detected. RuntimeError: If the GPU architecture is unable to determine. - + Note: What types of GPU vendors are supported? - NVIDIA - AMD """ if self.ctx["docker_env_vars"]["MAD_GPU_VENDOR"] == "AMD": - return self.console.sh("/opt/rocm/bin/rocminfo |grep -o -m 1 'gfx.*'") + rocminfo_cmd = get_rocminfo_path() + return self.console.sh(f"{rocminfo_cmd} | grep -o -m 1 'gfx.*'") elif self.ctx["docker_env_vars"]["MAD_GPU_VENDOR"] == "NVIDIA": return self.console.sh( "nvidia-smi -L | head -n1 | sed 's/(UUID: .*)//g' | sed 's/GPU 0: //g'" @@ -274,14 +367,14 @@ def get_system_gpu_architecture(self) -> str: def get_system_gpu_product_name(self) -> str: """Get system GPU product name. - + Returns: str: The GPU product name (e.g., AMD Instinct MI300X, NVIDIA H100 80GB HBM3). - + Raises: RuntimeError: If the GPU vendor is not detected. RuntimeError: If the GPU product name is unable to determine. - + Note: What types of GPU vendors are supported? - NVIDIA @@ -304,7 +397,7 @@ def get_system_hip_version(self): def get_docker_gpus(self) -> typing.Optional[str]: """Get Docker GPUs. - + Returns: str: The range of GPUs. 
""" @@ -316,7 +409,7 @@ def get_docker_gpus(self) -> typing.Optional[str]: def get_gpu_renderD_nodes(self) -> typing.Optional[typing.List[int]]: """Get GPU renderD nodes from KFD properties. - + Returns: list: The list of GPU renderD nodes. @@ -336,8 +429,9 @@ def get_gpu_renderD_nodes(self) -> typing.Optional[typing.List[int]]: # Check if the GPU vendor is AMD. if self.ctx['docker_env_vars']['MAD_GPU_VENDOR']=='AMD': # get rocm version - rocm_version = self.console.sh("cat /opt/rocm/.info/version | cut -d'-' -f1") - + rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm") + rocm_version = self.console.sh(f"cat {rocm_path}/.info/version | cut -d'-' -f1") + # get renderDs from KFD properties kfd_properties = self.console.sh("grep -r drm_render_minor /sys/devices/virtual/kfd/kfd/topology/nodes").split("\n") kfd_properties = [line for line in kfd_properties if int(line.split()[-1])!=0] # CPUs are 0, skip them @@ -348,7 +442,7 @@ def get_gpu_renderD_nodes(self) -> typing.Optional[typing.List[int]]: if output: data = json.loads(output) else: - raise ValueError("Failed to retrieve AMD GPU data") + raise ValueError("Failed to retrieve AMD GPU data") # get gpu id - renderD mapping using unique id if ROCm < 6.1.2 and node id otherwise # node id is more robust but is only available from 6.1.2 @@ -421,10 +515,10 @@ def set_multi_node_runner(self) -> str: def filter(self, unfiltered: typing.Dict) -> typing.Dict: """Filter the unfiltered dictionary based on the context. - + Args: unfiltered: The unfiltered dictionary. - + Returns: dict: The filtered dictionary. """ diff --git a/src/madengine/core/docker.py b/src/madengine/core/docker.py index 7ed4ff36..5345f146 100644 --- a/src/madengine/core/docker.py +++ b/src/madengine/core/docker.py @@ -8,6 +8,8 @@ # built-in modules import os import typing +import subprocess + # user-defined modules from madengine.core.console import Console @@ -23,6 +25,47 @@ class Docker: groupid (str): The group id. 
""" + @staticmethod + def is_valid_cmd(cmd) -> bool: + """ + Check if the given command is a valid container runtime by running ' container ls'. + This is necessary because a container runtime might be installed but not properly configured. + """ + + try: + result = subprocess.run( + [cmd, 'container', 'ls'], + capture_output=True, + text=True, + timeout=10 + ) + + return result.returncode == 0 + + except (subprocess.TimeoutExpired, subprocess.SubprocessError, FileNotFoundError): + return False + + @staticmethod + def get_container_cmd() -> str: + """ + Check which container runtime is available, docker or podman. + Use docker by default if both are available. + + Returns: + str: The available container command ('docker' or 'podman'). + + Raises: + RuntimeError: If neither docker nor podman is available. + """ + + if Docker.is_valid_cmd('docker'): + return 'docker' + + if Docker.is_valid_cmd('podman'): + return 'podman' + + raise RuntimeError("Neither docker nor podman is available on this system.") + def __init__( self, image: str, @@ -45,8 +88,10 @@ def __init__( console (Console): The console object. Raises: - RuntimeError: If the container name already exists. + RuntimeError: If the container name already exists or no container runtime available. """ + container_cmd = self.get_container_cmd() + # initialize variables self.docker_sha = None self.keep_alive = keep_alive @@ -57,7 +102,7 @@ def __init__( # check if container name exists container_name_exists = self.console.sh( - "docker container ps -a | grep " + container_name + " | wc -l" + f"{container_cmd} container ps -a | grep " + container_name + " | wc -l" ) # if container name exists, raise error. if container_name_exists != "0": @@ -65,12 +110,12 @@ def __init__( "Container with name, " + container_name + " already exists. " - + "Please stop (docker stop --time=1 SHA) and remove this (docker rm -f SHA) to proceed..." 
+ + f"Please stop ({container_cmd} stop --time=1 SHA) and remove this ({container_cmd} rm -f SHA) to proceed..." ) # run docker command command = ( - "docker run -t -d -u " + f"{container_cmd} run -t -d -u " + self.userid + ":" + self.groupid @@ -83,7 +128,7 @@ def __init__( if mounts is not None: for mount in mounts: command += "-v " + mount + ":" + mount + " " - + # add current working directory command += "-v " + cwd + ":/myworkspace/ " @@ -91,7 +136,7 @@ def __init__( if envVars is not None: for evar in envVars.keys(): command += "-e " + evar + "=" + envVars[evar] + " " - + command += "--workdir /myworkspace/ " command += "--name " + container_name + " " command += image + " " @@ -102,48 +147,50 @@ def __init__( # find container sha self.docker_sha = self.console.sh( - "docker ps -aqf 'name=" + container_name + "' " + f"{container_cmd} ps -aqf 'name=" + container_name + "' " ) def sh( - self, - command: str, - timeout: int=60, + self, + command: str, + timeout: int=60, secret: bool=False ) -> str: """Run shell command inside docker. - + Args: command (str): The shell command. timeout (int): The timeout in seconds. secret (bool): The flag to hide the command. - + Returns: str: The output of the shell command. """ + container_cmd = self.get_container_cmd() # run as root! return self.console.sh( - "docker exec " + self.docker_sha + ' bash -c "' + command + '"', + f"{container_cmd} exec " + self.docker_sha + ' bash -c "' + command + '"', timeout=timeout, secret=secret, ) def __del__(self): """Destructor of the Docker class.""" + container_cmd = self.get_container_cmd() # stop and remove docker container, if not keep_alive and docker sha exists, else print docker sha. 
def get_cmd(cmd, known_paths):
    '''Resolve a command name to its full executable path.

    Defers to shutil.which first (i.e. the current PATH); if the command
    is not on PATH, probes each directory in known_paths for an
    executable regular file of that name.

    Args:
        cmd (str): command name.
        known_paths (list): list of known paths to search for the command.

    Returns:
        str: full path to the command.

    Raises:
        FileNotFoundError: if the command cannot be located anywhere.
    '''
    cmd_path = shutil.which(cmd)
    if cmd_path is not None:
        return cmd_path

    for path in known_paths:
        # Skip known locations that do not exist on this host.
        if not os.path.isdir(path):
            continue

        candidate = os.path.join(path, cmd)
        if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
            return candidate

    raise FileNotFoundError(f'{cmd} not found.')
def print_bios_settings(): - cmd = "/usr/sbin/dmidecode" - cmd_info = CommandInfo("dmidecode Information", [cmd]) - return cmd_info + cmd = "dmidecode" + try: + cmd = get_cmd(cmd, ["/usr/sbin", "/sbin"]) + cmd_info = CommandInfo("dmidecode Information", [cmd]) + return cmd_info + except FileNotFoundError as e: + print(f"WARNING: {e}") + return None ## OS information. def print_os_information(): @@ -81,9 +118,14 @@ def print_os_information(): ## Memory Information. def print_memory_information(): - cmd = "/usr/bin/lsmem" - cmd_info = CommandInfo("Memory Information", [cmd]) - return cmd_info + cmd = "lsmem" + try: + cmd = get_cmd(cmd, ["/usr/bin", "/bin"]) + cmd_info = CommandInfo("Memory Information", [cmd]) + return cmd_info + except FileNotFoundError as e: + print(f"WARNING: {e}") + return None ## ROCm version data def print_rocm_version_information(): @@ -148,7 +190,12 @@ def print_rocm_environment_variables(): def print_rocm_smi_details(smi_config): cmd_info = None - cmd = "/opt/rocm/bin/rocm-smi" + cmd = "rocm-smi" + try: + cmd = get_cmd(cmd, ["/opt/rocm/bin"]) + except FileNotFoundError as e: + print(f"WARNING: {e}") + return None if (smi_config == "rocm_smi"): cmd_info = CommandInfo("ROCm SMI", [cmd]) elif (smi_config == "ifwi_version"): @@ -196,9 +243,14 @@ def print_rocm_smi_details(smi_config): return cmd_info def print_rocm_info_details(): - cmd = "/opt/rocm/bin/rocminfo" - cmd_info = CommandInfo("rocminfo", [cmd]) - return cmd_info + cmd = "rocminfo" + try: + cmd = get_cmd(cmd, ["/opt/rocm/bin"]) + cmd_info = CommandInfo("rocminfo", [cmd]) + return cmd_info + except FileNotFoundError as e: + print(f"WARNING: {e}") + return None ## dmesg boot logs - GPU/ATOM/DRM/BIOS def print_dmesg_logs(ignore_prev_boot_logs=True): @@ -239,9 +291,15 @@ def print_dmesg_logs(ignore_prev_boot_logs=True): ## print amdgpu modinfo def print_amdgpu_modinfo(): - cmd = "/sbin/modinfo amdgpu" - cmd_info = CommandInfo("amdgpu modinfo", [cmd]) - return cmd_info + cmd = "modinfo" + 
try: + cmd = get_cmd(cmd, ["/sbin", "/usr/sbin"]) + cmd = f"{cmd} amdgpu" + cmd_info = CommandInfo("amdgpu modinfo", [cmd]) + return cmd_info + except FileNotFoundError as e: + print(f"WARNING: {e}") + return None ## print pip list def print_pip_list_details(): @@ -256,9 +314,15 @@ def print_check_numa_balancing(): ## print cuda version information. def print_cuda_version_information(): - cmd = "nvcc --version" - cmd_info = CommandInfo("CUDA information", [cmd]) - return cmd_info + cmd = "nvcc" + try: + cmd = get_cmd(cmd, ["/usr/local/cuda/bin", "/usr/bin"]) + cmd = f"{cmd} --version" + cmd_info = CommandInfo("CUDA information", [cmd]) + return cmd_info + except FileNotFoundError as e: + print(f"WARNING: {e}") + return None def print_cuda_env_variables(): cmd = "env | /bin/grep -i -E 'cuda|nvidia|pytorch|mpi|openmp|ucx|cu'" @@ -353,7 +417,7 @@ def generate_env_info(gpu_device_type): env_map["rocm_smi_showxgmierr"] = print_rocm_smi_details("rocm_smi_showxgmierr") env_map["rocm_smi_clocks"] = print_rocm_smi_details("rocm_smi_clocks") env_map["rocm_smi_showcompute_partition"] = print_rocm_smi_details("rocm_smi_showcompute_partition") - env_map["rocm_smi_nodesbwi"] = print_rocm_smi_details("rocm_smi_nodesbwi") + env_map["rocm_smi_nodesbw"] = print_rocm_smi_details("rocm_smi_nodesbw") env_map["rocm_smi_gpudeviceid"] = print_rocm_smi_details("rocm_smi_gpudeviceid") env_map["rocm_info"] = print_rocm_info_details() elif gpu_device_type == "NVIDIA": diff --git a/src/madengine/tools/run_models.py b/src/madengine/tools/run_models.py index 4f26450d..57b90ae8 100644 --- a/src/madengine/tools/run_models.py +++ b/src/madengine/tools/run_models.py @@ -42,7 +42,7 @@ # MADEngine modules from madengine.core.console import Console -from madengine.core.context import Context +from madengine.core.context import Context, get_amdsmi_path, get_nvidiasmi_path from madengine.core.dataprovider import Data from madengine.core.docker import Docker from madengine.utils.ops import PythonicTee, 
def clean_up_docker_container(self, is_cleaned: bool = False) -> None:
    """Clean up docker containers and show GPU status.

    Args:
        is_cleaned (bool): When True, list and kill all running
            containers before showing GPU info.

    Note:
        All steps here are best-effort (the shell commands end in
        '|| true'); accordingly, a missing container runtime or a
        missing amd-smi binary is logged and skipped rather than
        raised, matching the pre-existing best-effort semantics.
    """
    if is_cleaned:
        # Resolve the runtime lazily -- only needed for the kill step.
        try:
            container_cmd = Docker.get_container_cmd()
        except RuntimeError as err:
            # No usable docker/podman -- nothing to clean up.
            print(f"WARNING: skipping container cleanup: {err}")
        else:
            self.console.sh(f"{container_cmd} ps -a || true")
            self.console.sh(f"{container_cmd} kill $({container_cmd} ps -q) || true")

    # get gpu vendor
    gpu_vendor = self.context.ctx["docker_env_vars"]["MAD_GPU_VENDOR"]

    # show gpu info (best-effort)
    if gpu_vendor.find("AMD") != -1:
        try:
            amdsmi_path = get_amdsmi_path()
        except FileNotFoundError as err:
            print(f"WARNING: {err}")
        else:
            self.console.sh(f"{amdsmi_path} || true")
    elif gpu_vendor.find("NVIDIA") != -1:
        self.console.sh("nvidia-smi -L || true")
- smi = model_docker.sh("/usr/bin/nvidia-smi || true") + nvidiasmi_path = get_nvidiasmi_path() + smi = model_docker.sh(f"{nvidiasmi_path} || true") else: raise RuntimeError("Unable to determine gpu vendor.")