feat(orttraining): add CPU fallback for FusedAdam optimizer #28233
fused_adam.py

@@ -10,6 +10,7 @@
 This file is adapted from fused adam in NVIDIA/apex, commit a109f85
 """

+import warnings
 from enum import IntEnum

 import torch
@@ -31,7 +32,10 @@ class FusedAdam(torch.optim.Optimizer):
     when adam_w_mode = 1 and `torch/Adam <https://github.com/pytorch/pytorch/blob/a217a62e73fd30b658743af8a69966f90327f018/torch/optim/adamw.py#L6>`_
     when adam_w_mode = 2

-    Currently GPU-only.
+    On CUDA-capable systems this optimizer uses fused CUDA kernels for efficiency.
+    On CPU-only systems (or when ``torch.cuda.is_available()`` returns ``False``) it
+    automatically falls back to an equivalent standard PyTorch optimizer with a
+    one-time :class:`UserWarning`. Performance will be reduced in fallback mode.

     This version of fused Adam implements 2 fusions.
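A minimal sketch of the documented behavior on a CPU-only host (the `onnxruntime.training.optim.FusedAdam` import path is assumed; `_cpu_fallback_optimizer` is the private attribute this PR adds):

    import warnings

    import torch
    from onnxruntime.training.optim import FusedAdam  # import path assumed

    model = torch.nn.Linear(8, 2)

    # On a CPU-only machine, construction should emit the one-time UserWarning
    # described in the new docstring and build a standard PyTorch fallback optimizer.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        opt = FusedAdam(model.parameters(), lr=1e-3)

    print([str(w.message) for w in caught])   # expect the CPU-fallback warning
    print(opt._cpu_fallback_optimizer)        # e.g. a torch.optim.AdamW instance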
@@ -83,16 +87,62 @@ def __init__(
         self._adam_w_mode = adam_w_mode
         self._set_grad_none = set_grad_none

-        # Skip buffer
-        self._dummy_overflow_buf = torch.cuda.IntTensor([0])
-
-        from onnxruntime.training.ortmodule.torch_cpp_extensions import fused_ops  # noqa: PLC0415
-
-        self._multi_tensor_adam = fused_ops.multi_tensor_adam
-        self._multi_tensor_applier = MultiTensorApply(2048 * 32)
-        self._TorchTensorVector = fused_ops.TorchTensorVector
+        if torch.cuda.is_available():
+            # Skip buffer
+            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+
+            from onnxruntime.training.ortmodule.torch_cpp_extensions import fused_ops  # noqa: PLC0415
+
+            self._multi_tensor_adam = fused_ops.multi_tensor_adam
+            self._multi_tensor_applier = MultiTensorApply(2048 * 32)
+            self._TorchTensorVector = fused_ops.TorchTensorVector
+            self._cpu_fallback_optimizer = None
+        else:
+            warnings.warn(
+                "FusedAdam CUDA kernels are unavailable; falling back to a standard PyTorch optimizer on CPU. "
+                "Performance will be reduced.",
+                UserWarning,
+                stacklevel=2,
+            )
+            # Build an equivalent standard PyTorch optimizer for the CPU path.
+            # Retrieve the flat list of parameters from the already-registered param_groups.
+            _params = [p for group in self.param_groups for p in group["params"]]
+            if adam_w_mode == AdamWMode.ADAM_L2_REGULARIZATION:
+                self._cpu_fallback_optimizer = torch.optim.Adam(
+                    _params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
+                )
+            elif adam_w_mode == AdamWMode.ADAMW_TORCH:
+                self._cpu_fallback_optimizer = torch.optim.AdamW(
+                    _params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
+                )
+            else:
+                # AdamWMode.ADAMW_TRANSFORMERS (default): prefer transformers.AdamW
+                try:
+                    from transformers import AdamW as _TransformersAdamW  # noqa: PLC0415
+
+                    self._cpu_fallback_optimizer = _TransformersAdamW(
+                        _params,
+                        lr=lr,
+                        betas=betas,
+                        eps=eps,
+                        weight_decay=weight_decay,
+                        correct_bias=bias_correction,
+                    )
+                except ImportError:
+                    warnings.warn(
+                        "transformers package not available; using torch.optim.AdamW as CPU fallback "
+                        "for AdamWMode.ADAMW_TRANSFORMERS.",
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                    self._cpu_fallback_optimizer = torch.optim.AdamW(
+                        _params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
+                    )

     def zero_grad(self, set_to_none=True):
+        if self._cpu_fallback_optimizer is not None:
+            self._cpu_fallback_optimizer.zero_grad(set_to_none=self._set_grad_none or set_to_none)
+            return
         if self._set_grad_none or set_to_none:
             for group in self.param_groups:
                 for p in group["params"]:
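Reviewers on CUDA machines can still exercise the new branch deterministically by patching `torch.cuda.is_available`, which is the same approach the new unit tests take. A hedged sketch (import path assumed):

    from unittest import mock

    import torch
    from onnxruntime.training.optim import FusedAdam  # import path assumed

    model = torch.nn.Linear(4, 2)

    # Force the CPU path even when CUDA is present, mirroring the new tests.
    with mock.patch("torch.cuda.is_available", return_value=False):
        opt = FusedAdam(model.parameters(), lr=1e-3)

    # With the default adam_w_mode the fallback is transformers.AdamW when transformers
    # is installed, otherwise torch.optim.AdamW (per the branch above).
    print(type(opt._cpu_fallback_optimizer))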
@@ -109,6 +159,9 @@ def step(self, closure=None):

         The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.
         """
+        if self._cpu_fallback_optimizer is not None:
+            return self._cpu_fallback_optimizer.step(closure)
+
         loss = None
         if closure is not None:
             loss = closure()
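Because `zero_grad` and `step` delegate to the fallback instance, existing training loops should run unchanged on CPU-only hosts. A minimal sketch (import path assumed; the model is illustrative):

    import torch
    from onnxruntime.training.optim import FusedAdam  # import path assumed

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Linear(4, 1))
    opt = FusedAdam(model.parameters(), lr=1e-3)   # falls back to a torch optimizer on CPU-only hosts

    x, y = torch.randn(16, 4), torch.randn(16, 1)
    for _ in range(3):
        opt.zero_grad()                            # delegated to the fallback optimizer when set
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()                                 # delegated likewise; a closure is forwarded if given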
New test file: unit tests for the FusedAdam CPU fallback

@@ -0,0 +1,128 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+"""Unit tests for FusedAdam CPU fallback (issue #17403).
+
+These tests patch torch.cuda.is_available to return False so they run
+deterministically on both CPU-only and CUDA machines.
+
+Import strategy:
+* Add the optim/ source directory to sys.path.
+* Pre-register "_multi_tensor_apply" in sys.modules by importing it
+  directly (it is pure Python with no external deps).
+* Load fused_adam.py via importlib with __package__ set to the name we
+  used for the pre-registered _multi_tensor_apply module so that the
Review comment on lines +9 to +13 (suggested change):

-* Add the optim/ source directory to sys.path.
-* Pre-register "_multi_tensor_apply" in sys.modules by importing it
-  directly (it is pure Python with no external deps).
-* Load fused_adam.py via importlib with __package__ set to the name we
-  used for the pre-registered _multi_tensor_apply module so that the
+* Load ``_multi_tensor_apply.py`` directly from the optim/ source
+  directory via ``importlib.util.spec_from_file_location``.
+* Pre-register that module in ``sys.modules`` under a synthetic package
+  name so the relative import inside ``fused_adam.py`` can resolve.
+* Load ``fused_adam.py`` via ``importlib.util.spec_from_file_location``
+  with ``__package__`` set to that synthetic package name so the
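For readers unfamiliar with that pattern, a minimal sketch of the loading strategy the suggested wording describes (the directory path and the synthetic package name `_ort_optim_for_test` are illustrative, not taken from this PR):

    import importlib.util
    import sys
    import types
    from pathlib import Path

    optim_dir = Path("orttraining/orttraining/python/training/optim")  # illustrative path
    pkg_name = "_ort_optim_for_test"                                   # synthetic package name

    # Register an empty package so relative imports inside fused_adam.py can resolve.
    pkg = types.ModuleType(pkg_name)
    pkg.__path__ = [str(optim_dir)]
    sys.modules[pkg_name] = pkg

    # Load _multi_tensor_apply.py and pre-register it under the synthetic package name.
    mta_spec = importlib.util.spec_from_file_location(
        f"{pkg_name}._multi_tensor_apply", optim_dir / "_multi_tensor_apply.py"
    )
    mta = importlib.util.module_from_spec(mta_spec)
    sys.modules[mta_spec.name] = mta
    mta_spec.loader.exec_module(mta)

    # Load fused_adam.py with a dotted name so its relative import resolves to the module above.
    fa_spec = importlib.util.spec_from_file_location(f"{pkg_name}.fused_adam", optim_dir / "fused_adam.py")
    fused_adam = importlib.util.module_from_spec(fa_spec)
    sys.modules[fa_spec.name] = fused_adam
    fa_spec.loader.exec_module(fused_adam)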
The CPU fallback uses a separate `torch.optim.*` instance, but `FusedAdam` still exposes its own `state`/`param_groups` from `super().__init__()`. This breaks core optimizer APIs in fallback mode: `state_dict()`/`load_state_dict()` won't reflect the fallback's moment estimates, and runtime edits like `opt.param_groups[0]['lr'] = ...` (or `add_param_group`) won't affect the optimizer that `step()` actually uses. Consider delegating `state_dict`/`load_state_dict`/`add_param_group` (and ensuring `param_groups`/`state` stay in sync) when `_cpu_fallback_optimizer` is set.
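One possible shape for that delegation, as a hedged sketch (method bodies intended to live inside `FusedAdam`; keeping `param_groups`/`state` fully in sync would still need a decision from the PR author):

    # Inside class FusedAdam(torch.optim.Optimizer): sketch only, not part of this diff.
    def state_dict(self):
        if self._cpu_fallback_optimizer is not None:
            return self._cpu_fallback_optimizer.state_dict()
        return super().state_dict()

    def load_state_dict(self, state_dict):
        if self._cpu_fallback_optimizer is not None:
            self._cpu_fallback_optimizer.load_state_dict(state_dict)
        else:
            super().load_state_dict(state_dict)

    def add_param_group(self, param_group):
        # Register with both views so later groups are visible to the optimizer that steps.
        super().add_param_group(param_group)
        if self._cpu_fallback_optimizer is not None:
            self._cpu_fallback_optimizer.add_param_group(param_group)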