diff --git a/src/accelerate/launchers.py b/src/accelerate/launchers.py
index d811c1622e8..2c553636ca2 100644
--- a/src/accelerate/launchers.py
+++ b/src/accelerate/launchers.py
@@ -149,7 +149,7 @@ def train(*args):
         launcher = PrepareForLaunch(function, distributed_type="XLA")
         print("Launching a training on TPU cores.")
         xmp.spawn(launcher, args=args, start_method="fork")
-    elif in_colab and get_gpu_info()[1] < 2:
+    elif in_colab and (not torch.cuda.is_available() or get_gpu_info()[1] < 2):
         # No need for a distributed launch otherwise as it's either CPU or one GPU.
         if torch.cuda.is_available():
             print("Launching training on one GPU.")
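
Note: the fix relies on Python's `or` short-circuiting. On a CPU-only Colab runtime, `not torch.cuda.is_available()` is already True, so `get_gpu_info()[1]` is never evaluated and the launcher falls through to the single-process (CPU or one GPU) path. Below is a minimal sketch of that behavior, using a hypothetical stand-in for `get_gpu_info` (its exact failure mode on a machine without a GPU is an assumption here, not the library's implementation):

    import torch

    def get_gpu_info():
        # Hypothetical stand-in: assume the real helper cannot report a
        # device count when no GPU is present.
        if not torch.cuda.is_available():
            raise RuntimeError("no GPU to query")
        return torch.cuda.get_device_name(0), torch.cuda.device_count()

    in_colab = True

    # Old condition: `in_colab and get_gpu_info()[1] < 2` called get_gpu_info()
    # even on CPU-only runtimes. New condition: the availability check
    # short-circuits first, so get_gpu_info() only runs when a GPU exists.
    if in_colab and (not torch.cuda.is_available() or get_gpu_info()[1] < 2):
        print("Single-process launch (CPU or one GPU).")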