diff --git a/mpi_util.py b/mpi_util.py index 7da01dc..d6a6ab7 100644 --- a/mpi_util.py +++ b/mpi_util.py @@ -46,7 +46,7 @@ def guess_available_gpus(n_gpus=None): cuda_visible_divices = os.environ['CUDA_VISIBLE_DEVICES'] cuda_visible_divices = cuda_visible_divices.split(',') return [int(n) for n in cuda_visible_divices] - if 'RCALL_NUM_GPU' not in os.environ: + if 'RCALL_NUM_GPU' in os.environ: n_gpus = int(os.environ['RCALL_NUM_GPU']) return list(range(n_gpus)) nvidia_dir = '/proc/driver/nvidia/gpus/'