
feat(orttraining): add CPU fallback for FusedAdam optimizer#28233

Open
Rishi-Dave wants to merge 1 commit into microsoft:main from Rishi-Dave:rishidave/feat/fused-adam-cpu-fallback

Conversation

@Rishi-Dave
Contributor

Summary

  • FusedAdam.__init__ now detects torch.cuda.is_available() and falls back to a standard PyTorch optimizer on CPU instead of crashing.
  • A one-time UserWarning informs the user that the fused CUDA kernel is unavailable and a CPU implementation is in use.
  • step() and zero_grad() delegate to the fallback when present; the CUDA path is unchanged.
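The detect-and-delegate pattern described above can be sketched in a few lines. This is an illustrative stand-in, not the PR's code: `_StubOptimizer` replaces `torch.optim.AdamW` so the sketch runs without PyTorch, and in the real change the check lives in `FusedAdam.__init__`.

```python
import warnings

class _StubOptimizer:
    """Stands in for torch.optim.AdamW so the sketch runs without PyTorch."""
    def __init__(self):
        self.steps = 0
    def step(self):
        self.steps += 1
    def zero_grad(self):
        pass

class FusedAdamSketch:
    def __init__(self, cuda_available):
        self._cpu_fallback_optimizer = None
        if cuda_available:
            # Unchanged CUDA path: overflow buffer + fused_ops extension import.
            pass
        else:
            warnings.warn(
                "Fused CUDA kernel unavailable; using a CPU optimizer instead.",
                UserWarning,
                stacklevel=2,
            )
            self._cpu_fallback_optimizer = _StubOptimizer()

    def step(self):
        if self._cpu_fallback_optimizer is not None:
            return self._cpu_fallback_optimizer.step()
        # ... fused CUDA step (unchanged) ...

    def zero_grad(self):
        if self._cpu_fallback_optimizer is not None:
            return self._cpu_fallback_optimizer.zero_grad()
        # ... fused path (unchanged) ...
```

Constructing the sketch with `cuda_available=False` emits the one-time `UserWarning`, and every subsequent `step()`/`zero_grad()` call routes to the fallback.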

Motivation

On CPU-only PyTorch builds, FusedAdam raises immediately in __init__ because it unconditionally:

  1. Allocates torch.cuda.IntTensor([0]) as an overflow buffer.
  2. Imports the CUDA-only C++ extension onnxruntime.training.ortmodule.torch_cpp_extensions.fused_ops.

This makes it impossible to use FusedAdam in CPU-only test/dev environments or to write code that transparently works on either device. The maintainer (@baijumeswani) confirmed in the issue that a CPU fallback with a warning is the desired fix.

Fixes #17403

Changes

orttraining/orttraining/python/training/optim/fused_adam.py:

  • Wrap the two CUDA-specific allocations in if torch.cuda.is_available().
  • On CPU, build self._cpu_fallback_optimizer based on adam_w_mode:
    • ADAM_L2_REGULARIZATION → torch.optim.Adam (weight_decay applied as L2 regularization)
    • ADAMW_TORCH → torch.optim.AdamW
    • ADAMW_TRANSFORMERS → transformers.AdamW (with a torch.optim.AdamW fallback, plus a second warning, when transformers is not installed)
  • Emit a single UserWarning per instance.
  • step() and zero_grad() early-return through the fallback when set.
  • Update the docstring to drop the "GPU-only" claim.
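The mode dispatch can be sketched as a small pure-Python function. This is an illustration only: the enum mirrors the PR's AdamWMode names, and transformers availability is reduced to a boolean flag.

```python
from enum import Enum

class AdamWMode(Enum):
    # Mirrors the modes FusedAdam accepts.
    ADAM_L2_REGULARIZATION = 0
    ADAMW_TORCH = 1
    ADAMW_TRANSFORMERS = 2

def pick_cpu_fallback(mode, transformers_available):
    """Return the dotted name of the optimizer the CPU fallback would build."""
    if mode is AdamWMode.ADAM_L2_REGULARIZATION:
        return "torch.optim.Adam"  # weight_decay applied as L2 regularization
    if mode is AdamWMode.ADAMW_TORCH:
        return "torch.optim.AdamW"
    # ADAMW_TRANSFORMERS: prefer transformers.AdamW; warn and fall back otherwise.
    return "transformers.AdamW" if transformers_available else "torch.optim.AdamW"
```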

orttraining/orttraining/test/python/orttraining_test_fused_adam_cpu_fallback.py (new):

  • Patches torch.cuda.is_available() to return False so tests run deterministically on any host.
  • Asserts instantiation succeeds and emits a UserWarning.
  • Asserts a single step() produces parameter updates equivalent to torch.optim.AdamW.
  • Asserts AdamWMode.ADAM_L2_REGULARIZATION instantiates and steps without raising.
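The determinism trick follows the standard unittest.mock patching pattern. A self-contained sketch using a SimpleNamespace stand-in for the torch module (the real test patches torch.cuda.is_available itself):

```python
import types
from unittest import mock

# Stand-in for the torch module so the sketch runs without PyTorch installed.
fake_torch = types.SimpleNamespace(
    cuda=types.SimpleNamespace(is_available=lambda: True)
)

def chosen_path():
    """Mimics the branch FusedAdam.__init__ takes at construction time."""
    return "fused_cuda" if fake_torch.cuda.is_available() else "cpu_fallback"

with mock.patch.object(fake_torch.cuda, "is_available", return_value=False):
    assert chosen_path() == "cpu_fallback"  # deterministic on any host

assert chosen_path() == "fused_cuda"  # patch is undone on context exit
```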

Test Plan

  • python -m pytest orttraining/orttraining/test/python/orttraining_test_fused_adam_cpu_fallback.py -v — 3 passed.
  • lintrunner -a on both files — clean, no changes applied.
  • The CUDA code path is unchanged except for being wrapped in a conditional; there is no behavioral change for existing GPU users.

FusedAdam previously failed to instantiate on CPU-only PyTorch builds
because __init__ unconditionally allocated a torch.cuda.IntTensor and
imported the CUDA-only fused_ops C++ extension.

Detect torch.cuda.is_available() at construction time. When CUDA is
unavailable, emit a one-time UserWarning and build a standard PyTorch
optimizer that matches the requested AdamWMode:

- ADAM_L2_REGULARIZATION -> torch.optim.Adam
- ADAMW_TORCH            -> torch.optim.AdamW
- ADAMW_TRANSFORMERS     -> transformers.AdamW (falls back to
                            torch.optim.AdamW when transformers is
                            not installed)

step() and zero_grad() delegate to the fallback when present. The
CUDA path is unchanged.

Adds a focused unit test that patches torch.cuda.is_available() so it
runs deterministically on any host.

Fixes microsoft#17403

Copilot AI left a comment


Pull request overview

Adds a CPU-safe behavior for FusedAdam in ORTTraining so CPU-only PyTorch environments no longer crash at import/initialization time, and introduces unit tests to validate the new fallback path.

Changes:

  • Add torch.cuda.is_available() gating in FusedAdam.__init__ and create a CPU fallback optimizer with warnings when CUDA fused kernels are unavailable.
  • Route step() and zero_grad() through the fallback optimizer when running on CPU.
  • Add a dedicated unit test file that forces CUDA-off behavior and validates the fallback updates.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

  • orttraining/orttraining/python/training/optim/fused_adam.py: Adds CPU fallback construction and delegates step()/zero_grad() when CUDA is unavailable.
  • orttraining/orttraining/test/python/orttraining_test_fused_adam_cpu_fallback.py: New tests that patch torch.cuda.is_available() to exercise and validate the CPU fallback path.


Comment on lines +107 to +117
# Build an equivalent standard PyTorch optimizer for the CPU path.
# Retrieve the flat list of parameters from the already-registered param_groups.
_params = [p for group in self.param_groups for p in group["params"]]
if adam_w_mode == AdamWMode.ADAM_L2_REGULARIZATION:
    self._cpu_fallback_optimizer = torch.optim.Adam(
        _params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
    )
elif adam_w_mode == AdamWMode.ADAMW_TORCH:
    self._cpu_fallback_optimizer = torch.optim.AdamW(
        _params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
    )

Copilot AI Apr 27, 2026


The CPU fallback flattens parameters into a single list and rebuilds the optimizer with the top-level lr/betas/eps/weight_decay args. If callers pass parameter-group dicts (different lrs/weight_decay per group), the CUDA path honors them via self.param_groups but the CPU fallback will silently ignore those per-group settings. Please preserve parameter groups (and their options) in the fallback optimizer construction.
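One way to preserve per-group options, sketched without torch: the group dicts mimic self.param_groups, and the rebuilt list is what a torch.optim constructor would accept in place of a flat parameter list. Names here are illustrative, not the PR's code.

```python
def build_fallback_groups(param_groups, defaults):
    """Rebuild optimizer input while preserving per-group overrides.

    Each group keeps its own 'params' plus any per-group options
    (lr, betas, eps, weight_decay); options missing from a group fall
    back to the top-level defaults, matching how torch.optim optimizers
    treat parameter-group dicts.
    """
    rebuilt = []
    for group in param_groups:
        g = {"params": group["params"]}
        for key in ("lr", "betas", "eps", "weight_decay"):
            if key in group:
                g[key] = group[key]
            elif key in defaults:
                g[key] = defaults[key]
        rebuilt.append(g)
    return rebuilt
```

Passing `build_fallback_groups(self.param_groups, defaults)` to the fallback constructor keeps per-group learning rates and weight decay intact instead of silently flattening them.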

Comment on lines +118 to +140
else:
    # AdamWMode.ADAMW_TRANSFORMERS (default): prefer transformers.AdamW
    try:
        from transformers import AdamW as _TransformersAdamW  # noqa: PLC0415

        self._cpu_fallback_optimizer = _TransformersAdamW(
            _params,
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            correct_bias=bias_correction,
        )
    except ImportError:
        warnings.warn(
            "transformers package not available; using torch.optim.AdamW as CPU fallback "
            "for AdamWMode.ADAMW_TRANSFORMERS.",
            UserWarning,
            stacklevel=2,
        )
        self._cpu_fallback_optimizer = torch.optim.AdamW(
            _params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
        )

Copilot AI Apr 27, 2026


In the ADAMW_TRANSFORMERS fallback, the except ImportError path switches to torch.optim.AdamW, which can’t represent correct_bias=bias_correction. If bias_correction=False, CPU fallback behavior will diverge from the intended Transformers/AdamW math without any guard. Consider either (a) requiring transformers for this mode, (b) raising when bias_correction=False and transformers is unavailable, or (c) implementing a small CPU update that matches the Transformers variant.
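Option (b) from the comment could be sketched as a small guard. This is a hedged illustration (function name and string return values are placeholders): torch.optim.AdamW always applies bias correction, so it can only substitute for transformers.AdamW when bias_correction is True.

```python
import warnings

def resolve_transformers_fallback(transformers_available, bias_correction):
    """Decide the ADAMW_TRANSFORMERS CPU fallback, guarding correct_bias."""
    if transformers_available:
        return "transformers.AdamW"
    if not bias_correction:
        # torch.optim.AdamW has no way to disable bias correction, so the
        # substitution would silently change the optimizer math.
        raise RuntimeError(
            "bias_correction=False requires the transformers package on CPU; "
            "torch.optim.AdamW cannot disable bias correction."
        )
    warnings.warn(
        "transformers not installed; using torch.optim.AdamW.", UserWarning
    )
    return "torch.optim.AdamW"
```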

Comment on lines +9 to +13
* Add the optim/ source directory to sys.path.
* Pre-register "_multi_tensor_apply" in sys.modules by importing it
directly (it is pure Python with no external deps).
* Load fused_adam.py via importlib with __package__ set to the name we
used for the pre-registered _multi_tensor_apply module so that the

Copilot AI Apr 27, 2026


The module docstring says the import strategy “Add the optim/ source directory to sys.path”, but the test never modifies sys.path. Please update the docstring to reflect what the test actually does (importing via importlib.util.spec_from_file_location and sys.modules pre-registration) so future maintainers aren’t misled.

Suggested change:

Before:
* Add the optim/ source directory to sys.path.
* Pre-register "_multi_tensor_apply" in sys.modules by importing it
  directly (it is pure Python with no external deps).
* Load fused_adam.py via importlib with __package__ set to the name we
  used for the pre-registered _multi_tensor_apply module so that the

After:
* Load ``_multi_tensor_apply.py`` directly from the optim/ source
  directory via ``importlib.util.spec_from_file_location``.
* Pre-register that module in ``sys.modules`` under a synthetic package
  name so the relative import inside ``fused_adam.py`` can resolve.
* Load ``fused_adam.py`` via ``importlib.util.spec_from_file_location``
  with ``__package__`` set to that synthetic package name so the

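The import strategy this suggestion describes boils down to the following importlib pattern. This is a generic sketch; the function, file, and package names are placeholders, not the test's actual identifiers.

```python
import importlib.util
import sys

def load_module_from_path(name, path, package=None):
    """Load a module from an explicit file path and pre-register it.

    Registering the module in sys.modules before execution lets dependent
    modules import it by name, and setting __package__ lets relative
    imports inside the loaded file resolve against the synthetic package.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    if package is not None:
        module.__package__ = package
    sys.modules[name] = module
    spec.loader.exec_module(module)
    return module
```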
Comment on lines +107 to +110
# Build an equivalent standard PyTorch optimizer for the CPU path.
# Retrieve the flat list of parameters from the already-registered param_groups.
_params = [p for group in self.param_groups for p in group["params"]]
if adam_w_mode == AdamWMode.ADAM_L2_REGULARIZATION:

Copilot AI Apr 27, 2026


CPU fallback uses a separate torch.optim.* instance, but FusedAdam still exposes its own state/param_groups from super().__init__. This breaks core optimizer APIs in fallback mode: state_dict()/load_state_dict() won’t reflect the fallback’s moment estimates, and runtime edits like opt.param_groups[0]['lr']=... (or add_param_group) won’t affect the optimizer that step() actually uses. Consider delegating state_dict/load_state_dict/add_param_group (and ensuring param_groups/state stay in sync) when _cpu_fallback_optimizer is set.
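The delegation the comment asks for could be sketched as a mixin. This is hedged illustration, not the PR's code: `_cpu_fallback_optimizer` follows the PR's attribute name, the method names match torch.optim.Optimizer's public API, and the super() calls stand in for the real base-class behavior.

```python
class CpuFallbackDelegationMixin:
    """Forward optimizer-state APIs to the CPU fallback when it exists.

    Intended to sit in front of the real optimizer base class so that
    state_dict()/load_state_dict()/add_param_group() act on the optimizer
    that step() actually uses in fallback mode.
    """
    _cpu_fallback_optimizer = None

    def state_dict(self):
        if self._cpu_fallback_optimizer is not None:
            return self._cpu_fallback_optimizer.state_dict()
        return super().state_dict()

    def load_state_dict(self, state_dict):
        if self._cpu_fallback_optimizer is not None:
            return self._cpu_fallback_optimizer.load_state_dict(state_dict)
        return super().load_state_dict(state_dict)

    def add_param_group(self, param_group):
        if self._cpu_fallback_optimizer is not None:
            # Keep both views in sync: the fallback does the work,
            # the base class records the group for introspection.
            self._cpu_fallback_optimizer.add_param_group(param_group)
        return super().add_param_group(param_group)
```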



Development

Successfully merging this pull request may close these issues.

[Feature Request] [Training] Support Fused Adam for CPU

2 participants