9 changes: 7 additions & 2 deletions src/accelerate/accelerator.py

@@ -577,8 +577,11 @@ def __init__(
                 "musa",
                 "hpu",
                 "sdaa",
+                "mps",
             ) or is_torch_xla_available(check_is_tpu=True):
-                raise ValueError(f"fp16 mixed precision requires a GPU (not {self.device.type!r}).")
+                raise ValueError(f"fp16 mixed precision requires a GPU or MPS device (not {self.device.type!r}).")
+            if self.device.type == "mps" and not is_torch_version(">=", "2.5.0"):
+                raise ValueError("fp16 mixed precision with an MPS device requires PyTorch >= 2.5.0.")
             kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}

             # FSDP2 doesn't use ShardedGradScaler, don't want to modify `get_grad_scaler`, rather create a simple utility
@@ -595,8 +598,10 @@ def __init__(
                 self.native_amp = True
             else:
                 self.native_amp = is_bf16_available(True)
-            if mixed_precision == "bf16" and not self.native_amp and not is_torch_xla_available():
+            if not self.native_amp and not is_torch_xla_available():
                 raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")
+            if self.native_amp and self.device.type == "mps" and not is_torch_version(">=", "2.6.0"):
+                raise ValueError("bf16 mixed precision with an MPS device requires PyTorch >= 2.6.0.")

             # for DeepSpeed, self.state.mixed_precision is always "bf16",
             # see https://github.com/huggingface/accelerate/blob/main/src/accelerate/state.py#L968 and
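For context, a hedged sketch of the user-facing flow these checks now allow: fp16 mixed precision on Apple silicon through Accelerator. The model, optimizer, and data below are illustrative placeholders, and the sketch assumes PyTorch >= 2.5.0 on an MPS-capable Mac.

import torch
from accelerate import Accelerator

# fp16 is now accepted on MPS; the check above raises if torch < 2.5.0
accelerator = Accelerator(mixed_precision="fp16")
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

batch = torch.randn(4, 8, device=accelerator.device)
loss = model(batch).sum()
accelerator.backward(loss)  # accelerate handles any loss scaling internally
optimizer.step()
optimizer.zero_grad()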
4 changes: 3 additions & 1 deletion src/accelerate/test_utils/scripts/test_script.py

@@ -35,10 +35,12 @@
     gather,
     gather_object,
     is_bf16_available,
+    is_cuda_available,
     is_datasets_available,
     is_fp16_available,
     is_hpu_available,
     is_ipex_available,
+    is_mps_available,
     is_pytest_available,
     is_xpu_available,
     set_seed,
@@ -534,7 +536,7 @@ def training_check(use_seedable_sampler=False):
     accelerator.print("Training yielded the same results on one CPU or distributed setup with batch split.")

     # FP32 wrapper check
-    if torch.cuda.is_available():
+    if is_cuda_available() or is_mps_available():
         # Mostly a test that model.forward will have autocast when running unwrap_model(model, keep_fp32_wrapper=True)
         print("Keep fp32 wrapper check.")
         AcceleratorState._reset_state()
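The wrapper check that this gate now also runs on MPS boils down to something like the following hedged sketch (the tiny Linear model is illustrative; it assumes a CUDA or MPS machine):

import torch
from accelerate import Accelerator
from accelerate.utils import is_cuda_available, is_mps_available

if is_cuda_available() or is_mps_available():
    accelerator = Accelerator(mixed_precision="fp16")
    model = accelerator.prepare(torch.nn.Linear(4, 4))
    # keep_fp32_wrapper=True keeps the autocast forward wrapper on the
    # unwrapped model, so its outputs are cast back to float32
    unwrapped = accelerator.unwrap_model(model, keep_fp32_wrapper=True)
    out = unwrapped(torch.randn(2, 4, device=accelerator.device))
    assert out.dtype == torch.float32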
2 changes: 1 addition & 1 deletion src/accelerate/utils/imports.py

@@ -182,7 +182,7 @@ def is_bf16_available(ignore_tpu=False):
     if is_xpu_available():
         return torch.xpu.is_bf16_supported()
     if is_mps_available():
-        return False
+        return torch.backends.mps.is_macos_or_newer(14, 0)
     return True
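In other words, bf16 is now reported as available on MPS when running on macOS 14 (Sonoma) or newer, which PyTorch exposes via torch.backends.mps.is_macos_or_newer. A small hedged sketch of the equivalent standalone check:

import torch

def mps_supports_bf16() -> bool:
    # bf16 on Apple silicon needs macOS >= 14; guard on MPS availability first
    return torch.backends.mps.is_available() and torch.backends.mps.is_macos_or_newer(14, 0)

if mps_supports_bf16():
    x = torch.ones(2, 2, device="mps", dtype=torch.bfloat16)
    print(x.dtype)  # torch.bfloat16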
4 changes: 4 additions & 0 deletions src/accelerate/utils/modeling.py

@@ -2136,6 +2136,10 @@ def get_grad_scaler(distributed_type: DistributedType = None, **kwargs):
         return torch.amp.GradScaler("hpu", **kwargs)
     elif is_xpu_available():
         return torch.amp.GradScaler("xpu", **kwargs)
+    elif is_mps_available():
+        if not is_torch_version(">=", "2.8.0"):
+            raise ValueError("GradScaler with an MPS device requires PyTorch >= 2.8.0.")
+        return torch.amp.GradScaler("mps", **kwargs)
     else:
         if is_torch_version(">=", "2.3"):
             return torch.amp.GradScaler("cuda", **kwargs)
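A hedged sketch of a manual fp16 training step using the scaler this branch returns (assumes an MPS machine with PyTorch >= 2.8.0; the model and data are placeholders):

import torch
from accelerate.utils import get_grad_scaler

model = torch.nn.Linear(8, 8).to("mps")
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = get_grad_scaler()  # torch.amp.GradScaler("mps") on this setup

with torch.autocast(device_type="mps", dtype=torch.float16):
    loss = model(torch.randn(4, 8, device="mps")).sum()
scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscale gradients, then step if no inf/nan
scaler.update()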