feat(orttraining): add CPU fallback for FusedAdam optimizer#28233
Rishi-Dave wants to merge 1 commit into microsoft:main
Conversation
FusedAdam previously failed to instantiate on CPU-only PyTorch builds
because __init__ unconditionally allocated a torch.cuda.IntTensor and
imported the CUDA-only fused_ops C++ extension.
Detect torch.cuda.is_available() at construction time. When CUDA is
unavailable, emit a one-time UserWarning and build a standard PyTorch
optimizer that matches the requested AdamWMode:
- ADAM_L2_REGULARIZATION -> torch.optim.Adam
- ADAMW_TORCH -> torch.optim.AdamW
- ADAMW_TRANSFORMERS -> transformers.AdamW (falls back to
torch.optim.AdamW when transformers is
not installed)
step() and zero_grad() delegate to the fallback when present. The
CUDA path is unchanged.
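
For context, a minimal usage sketch of the behavior described above. The import path and constructor keywords are assumptions taken from this PR's file layout and description, not verified against the installed package:

```python
# Hedged sketch: module path and keyword names below are assumed from this PR's
# description, not confirmed API.
import torch

from onnxruntime.training.optim.fused_adam import AdamWMode, FusedAdam

model = torch.nn.Linear(4, 2)

# On a CPU-only build this emits the one-time UserWarning and internally builds
# torch.optim.AdamW; on a CUDA host the fused kernel path is used unchanged.
optimizer = FusedAdam(model.parameters(), lr=1e-3, adam_w_mode=AdamWMode.ADAMW_TORCH)

loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()       # delegates to the CPU fallback when it exists
optimizer.zero_grad()  # likewise
```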
Adds a focused unit test that patches torch.cuda.is_available() so it
runs deterministically on any host.
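
The patching trick itself is standard; a minimal sketch of forcing the CPU branch on any host follows (the real test's module-loading strategy differs, see the review comments below):

```python
# Sketch of forcing the CPU fallback branch deterministically; only the
# torch.cuda.is_available patch is shown, the optimizer construction is elided.
from unittest import mock

import torch

with mock.patch("torch.cuda.is_available", return_value=False):
    assert not torch.cuda.is_available()
    # Constructing FusedAdam at this point takes the CPU fallback path and
    # emits the one-time UserWarning described above.
```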
Fixes microsoft#17403
Pull request overview
Adds a CPU-safe behavior for FusedAdam in ORTTraining so CPU-only PyTorch environments no longer crash at import/initialization time, and introduces unit tests to validate the new fallback path.
Changes:
- Add `torch.cuda.is_available()` gating in `FusedAdam.__init__` and create a CPU fallback optimizer with warnings when CUDA fused kernels are unavailable.
- Route `step()` and `zero_grad()` through the fallback optimizer when running on CPU.
- Add a dedicated unit test file that forces CUDA-off behavior and validates the fallback updates.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `orttraining/orttraining/python/training/optim/fused_adam.py` | Adds CPU fallback construction and delegates step()/zero_grad() when CUDA is unavailable. |
| `orttraining/orttraining/test/python/orttraining_test_fused_adam_cpu_fallback.py` | New tests that patch torch.cuda.is_available() to exercise and validate the CPU fallback path. |
```python
# Build an equivalent standard PyTorch optimizer for the CPU path.
# Retrieve the flat list of parameters from the already-registered param_groups.
_params = [p for group in self.param_groups for p in group["params"]]
if adam_w_mode == AdamWMode.ADAM_L2_REGULARIZATION:
    self._cpu_fallback_optimizer = torch.optim.Adam(
        _params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
    )
elif adam_w_mode == AdamWMode.ADAMW_TORCH:
    self._cpu_fallback_optimizer = torch.optim.AdamW(
        _params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
    )
```
The CPU fallback flattens parameters into a single list and rebuilds the optimizer with the top-level lr/betas/eps/weight_decay args. If callers pass parameter-group dicts (different lrs/weight_decay per group), the CUDA path honors them via self.param_groups but the CPU fallback will silently ignore those per-group settings. Please preserve parameter groups (and their options) in the fallback optimizer construction.
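
One possible shape of that fix, sketched as a standalone helper (illustrative only; `build_cpu_fallback` and its arguments are not part of this PR):

```python
import torch


def build_cpu_fallback(param_groups, defaults):
    """Illustrative: rebuild a torch.optim.AdamW from the already-registered
    param_groups so per-group lr/betas/eps/weight_decay survive on CPU."""
    groups = [
        {
            "params": g["params"],
            "lr": g.get("lr", defaults["lr"]),
            "betas": g.get("betas", defaults["betas"]),
            "eps": g.get("eps", defaults["eps"]),
            "weight_decay": g.get("weight_decay", defaults["weight_decay"]),
        }
        for g in param_groups
    ]
    return torch.optim.AdamW(groups)
```

Inside `__init__` this would replace the flat `_params` list, e.g. `self._cpu_fallback_optimizer = build_cpu_fallback(self.param_groups, self.defaults)`; the Adam and transformers variants would follow the same pattern.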
```python
else:
    # AdamWMode.ADAMW_TRANSFORMERS (default): prefer transformers.AdamW
    try:
        from transformers import AdamW as _TransformersAdamW  # noqa: PLC0415

        self._cpu_fallback_optimizer = _TransformersAdamW(
            _params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            correct_bias=bias_correction,
        )
    except ImportError:
        warnings.warn(
            "transformers package not available; using torch.optim.AdamW as CPU fallback "
            "for AdamWMode.ADAMW_TRANSFORMERS.",
            UserWarning,
            stacklevel=2,
        )
        self._cpu_fallback_optimizer = torch.optim.AdamW(
            _params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
        )
```
In the ADAMW_TRANSFORMERS fallback, the except ImportError path switches to torch.optim.AdamW, which can’t represent correct_bias=bias_correction. If bias_correction=False, CPU fallback behavior will diverge from the intended Transformers/AdamW math without any guard. Consider either (a) requiring transformers for this mode, (b) raising when bias_correction=False and transformers is unavailable, or (c) implementing a small CPU update that matches the Transformers variant.
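
A sketch of option (b), written as a hypothetical standalone helper to show the guard (names and structure are illustrative, not the PR's code):

```python
import warnings

import torch


def build_transformers_style_fallback(params, lr, betas, eps, weight_decay, bias_correction):
    """Illustrative: prefer transformers.AdamW; refuse to silently change the
    math when correct_bias=False cannot be honored by torch.optim.AdamW."""
    try:
        from transformers import AdamW as _TransformersAdamW

        return _TransformersAdamW(
            params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=bias_correction
        )
    except ImportError:
        if not bias_correction:
            # torch.optim.AdamW always applies bias correction, so it cannot
            # reproduce correct_bias=False; fail loudly instead of diverging.
            raise RuntimeError(
                "AdamWMode.ADAMW_TRANSFORMERS with bias_correction=False requires the "
                "'transformers' package on CPU."
            )
        warnings.warn(
            "transformers not installed; using torch.optim.AdamW as the CPU fallback.",
            UserWarning,
            stacklevel=2,
        )
        return torch.optim.AdamW(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
```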
```text
* Add the optim/ source directory to sys.path.
* Pre-register "_multi_tensor_apply" in sys.modules by importing it
  directly (it is pure Python with no external deps).
* Load fused_adam.py via importlib with __package__ set to the name we
  used for the pre-registered _multi_tensor_apply module so that the
```
The module docstring says the import strategy “Add the optim/ source directory to sys.path”, but the test never modifies sys.path. Please update the docstring to reflect what the test actually does (importing via importlib.util.spec_from_file_location and sys.modules pre-registration) so future maintainers aren’t misled.
Suggested change:

```diff
- * Add the optim/ source directory to sys.path.
- * Pre-register "_multi_tensor_apply" in sys.modules by importing it
-   directly (it is pure Python with no external deps).
- * Load fused_adam.py via importlib with __package__ set to the name we
-   used for the pre-registered _multi_tensor_apply module so that the
+ * Load ``_multi_tensor_apply.py`` directly from the optim/ source
+   directory via ``importlib.util.spec_from_file_location``.
+ * Pre-register that module in ``sys.modules`` under a synthetic package
+   name so the relative import inside ``fused_adam.py`` can resolve.
+ * Load ``fused_adam.py`` via ``importlib.util.spec_from_file_location``
+   with ``__package__`` set to that synthetic package name so the
```
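
For reference, a minimal sketch of the loading strategy the suggested docstring describes (the path and the synthetic package name are illustrative, not the test's exact values):

```python
import importlib.util
import sys
from pathlib import Path

optim_dir = Path("orttraining/orttraining/python/training/optim")  # assumed checkout-relative path
pkg_name = "fused_adam_cpu_fallback_test"  # synthetic package name

# Load the pure-Python helper first and register it under the synthetic package
# so the relative import inside fused_adam.py can resolve via sys.modules.
mta_spec = importlib.util.spec_from_file_location(
    f"{pkg_name}._multi_tensor_apply", optim_dir / "_multi_tensor_apply.py"
)
mta_module = importlib.util.module_from_spec(mta_spec)
sys.modules[mta_spec.name] = mta_module
mta_spec.loader.exec_module(mta_module)

# Load fused_adam.py under the same synthetic package name (its spec name sets
# __package__), so "from ._multi_tensor_apply import ..." resolves.
fa_spec = importlib.util.spec_from_file_location(f"{pkg_name}.fused_adam", optim_dir / "fused_adam.py")
fused_adam = importlib.util.module_from_spec(fa_spec)
sys.modules[fa_spec.name] = fused_adam
fa_spec.loader.exec_module(fused_adam)
```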
```python
# Build an equivalent standard PyTorch optimizer for the CPU path.
# Retrieve the flat list of parameters from the already-registered param_groups.
_params = [p for group in self.param_groups for p in group["params"]]
if adam_w_mode == AdamWMode.ADAM_L2_REGULARIZATION:
```
CPU fallback uses a separate torch.optim.* instance, but FusedAdam still exposes its own state/param_groups from super().__init__. This breaks core optimizer APIs in fallback mode: state_dict()/load_state_dict() won’t reflect the fallback’s moment estimates, and runtime edits like opt.param_groups[0]['lr']=... (or add_param_group) won’t affect the optimizer that step() actually uses. Consider delegating state_dict/load_state_dict/add_param_group (and ensuring param_groups/state stay in sync) when _cpu_fallback_optimizer is set.
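
A hedged sketch of the delegation the comment suggests, written as a small mixin so it is self-contained (the `_cpu_fallback_optimizer` attribute name comes from this PR; the mixin itself is hypothetical):

```python
class CpuFallbackDelegationMixin:
    """Illustrative: route core optimizer APIs to the CPU fallback when it exists."""

    def state_dict(self):
        fallback = getattr(self, "_cpu_fallback_optimizer", None)
        return fallback.state_dict() if fallback is not None else super().state_dict()

    def load_state_dict(self, state_dict):
        fallback = getattr(self, "_cpu_fallback_optimizer", None)
        if fallback is not None:
            fallback.load_state_dict(state_dict)
        else:
            super().load_state_dict(state_dict)

    def add_param_group(self, param_group):
        super().add_param_group(param_group)
        fallback = getattr(self, "_cpu_fallback_optimizer", None)
        if fallback is not None:
            # Keep the fallback's groups in sync with FusedAdam's own view.
            fallback.add_param_group(dict(param_group))
```

FusedAdam would then inherit from this mixin ahead of the base optimizer; sharing the group dicts (rather than copying them) would also make runtime edits such as `opt.param_groups[0]["lr"] = ...` visible to the optimizer that `step()` actually uses.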
Summary
- `FusedAdam.__init__` now detects `torch.cuda.is_available()` and falls back to a standard PyTorch optimizer on CPU instead of crashing.
- A `UserWarning` informs the user that the fused CUDA kernel is unavailable and a CPU implementation is in use.
- `step()` and `zero_grad()` delegate to the fallback when present; the CUDA path is unchanged.

Motivation
On CPU-only PyTorch builds, `FusedAdam` raises immediately in `__init__` because it unconditionally:

- allocates `torch.cuda.IntTensor([0])` as an overflow buffer.
- imports `onnxruntime.training.ortmodule.torch_cpp_extensions.fused_ops`.

This makes it impossible to use `FusedAdam` in CPU-only test/dev environments or to write code that transparently works on either device. The maintainer (@baijumeswani) confirmed in the issue that a CPU fallback with a warning is the desired fix.

Fixes #17403
Changes
- `orttraining/orttraining/python/training/optim/fused_adam.py`:
  - Gate the CUDA-only setup behind `if torch.cuda.is_available()`.
  - When CUDA is unavailable, build `self._cpu_fallback_optimizer` based on `adam_w_mode`:
    - `ADAM_L2_REGULARIZATION` → `torch.optim.Adam` (weight_decay applied as L2 regularization)
    - `ADAMW_TORCH` → `torch.optim.AdamW`
    - `ADAMW_TRANSFORMERS` → `transformers.AdamW` (with `torch.optim.AdamW` fallback when `transformers` is not installed, plus a second warning)
  - Emit one `UserWarning` per instance.
  - `step()` and `zero_grad()` early-return through the fallback when set.
- `orttraining/orttraining/test/python/orttraining_test_fused_adam_cpu_fallback.py` (new):
  - Patches `torch.cuda.is_available()` to return `False` so tests run deterministically on any host.
  - Asserts the `UserWarning` is emitted.
  - Asserts `step()` produces parameter updates equivalent to `torch.optim.AdamW`.
  - Asserts `AdamWMode.ADAM_L2_REGULARIZATION` instantiates and steps without raising.

Test Plan
- `python -m pytest orttraining/orttraining/test/python/orttraining_test_fused_adam_cpu_fallback.py -v` (3 passed).
- `lintrunner -a` on both files (clean, no changes applied).