67 changes: 60 additions & 7 deletions orttraining/orttraining/python/training/optim/fused_adam.py
@@ -10,6 +10,7 @@
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
"""

import warnings
from enum import IntEnum

import torch
@@ -31,7 +32,10 @@ class FusedAdam(torch.optim.Optimizer):
when adam_w_mode = 1 and `torch/Adam <https://github.com/pytorch/pytorch/blob/a217a62e73fd30b658743af8a69966f90327f018/torch/optim/adamw.py#L6>`_
when adam_w_mode = 2

Currently GPU-only.
On CUDA-capable systems this optimizer uses fused CUDA kernels for efficiency.
On CPU-only systems (or when ``torch.cuda.is_available()`` returns ``False``) it
automatically falls back to an equivalent standard PyTorch optimizer with a
one-time :class:`UserWarning`. Performance will be reduced in fallback mode.

This version of fused Adam implements 2 fusions.

@@ -83,16 +87,62 @@ def __init__(
self._adam_w_mode = adam_w_mode
self._set_grad_none = set_grad_none

# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
if torch.cuda.is_available():
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])

from onnxruntime.training.ortmodule.torch_cpp_extensions import fused_ops # noqa: PLC0415
from onnxruntime.training.ortmodule.torch_cpp_extensions import fused_ops # noqa: PLC0415

self._multi_tensor_adam = fused_ops.multi_tensor_adam
self._multi_tensor_applier = MultiTensorApply(2048 * 32)
self._TorchTensorVector = fused_ops.TorchTensorVector
self._multi_tensor_adam = fused_ops.multi_tensor_adam
self._multi_tensor_applier = MultiTensorApply(2048 * 32)
self._TorchTensorVector = fused_ops.TorchTensorVector
self._cpu_fallback_optimizer = None
else:
warnings.warn(
"FusedAdam CUDA kernels are unavailable; falling back to a standard PyTorch optimizer on CPU. "
"Performance will be reduced.",
UserWarning,
stacklevel=2,
)
# Build an equivalent standard PyTorch optimizer for the CPU path.
# Retrieve the flat list of parameters from the already-registered param_groups.
_params = [p for group in self.param_groups for p in group["params"]]
if adam_w_mode == AdamWMode.ADAM_L2_REGULARIZATION:
Comment on lines +107 to +110
Copilot AI Apr 27, 2026

CPU fallback uses a separate torch.optim.* instance, but FusedAdam still exposes its own state/param_groups from super().__init__. This breaks core optimizer APIs in fallback mode: state_dict()/load_state_dict() won’t reflect the fallback’s moment estimates, and runtime edits like opt.param_groups[0]['lr']=... (or add_param_group) won’t affect the optimizer that step() actually uses. Consider delegating state_dict/load_state_dict/add_param_group (and ensuring param_groups/state stay in sync) when _cpu_fallback_optimizer is set.

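For illustration only (not part of the PR diff): one possible shape of the delegation this comment asks for. It assumes the _cpu_fallback_optimizer attribute introduced in this PR; the method names come from torch.optim.Optimizer's public API, everything else is a sketch.

```python
# Hypothetical sketch of methods on FusedAdam, assuming the _cpu_fallback_optimizer
# attribute added in this PR; delegates optimizer state APIs to the fallback.
def state_dict(self):
    if self._cpu_fallback_optimizer is not None:
        return self._cpu_fallback_optimizer.state_dict()
    return super().state_dict()

def load_state_dict(self, state_dict):
    if self._cpu_fallback_optimizer is not None:
        self._cpu_fallback_optimizer.load_state_dict(state_dict)
    else:
        super().load_state_dict(state_dict)

def add_param_group(self, param_group):
    super().add_param_group(param_group)  # keep FusedAdam's own param_groups in sync
    if self._cpu_fallback_optimizer is not None:
        self._cpu_fallback_optimizer.add_param_group(param_group)
```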
self._cpu_fallback_optimizer = torch.optim.Adam(
_params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
)
elif adam_w_mode == AdamWMode.ADAMW_TORCH:
self._cpu_fallback_optimizer = torch.optim.AdamW(
_params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
)
Comment on lines +107 to +117
Copilot AI Apr 27, 2026

The CPU fallback flattens parameters into a single list and rebuilds the optimizer with the top-level lr/betas/eps/weight_decay args. If callers pass parameter-group dicts (different lrs/weight_decay per group), the CUDA path honors them via self.param_groups but the CPU fallback will silently ignore those per-group settings. Please preserve parameter groups (and their options) in the fallback optimizer construction.

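For illustration only (not part of the PR): a minimal sketch of building the fallback from self.param_groups instead of a flat parameter list, so per-group options are preserved. Only the torch.optim.AdamW case is shown; the other modes would be handled the same way.

```python
# Hypothetical sketch: rebuild the CPU fallback from the already-registered
# param_groups so per-group lr/betas/eps/weight_decay are honored.
_fallback_groups = []
for group in self.param_groups:
    fallback_group = {"params": group["params"]}
    for key in ("lr", "betas", "eps", "weight_decay"):
        if key in group:
            fallback_group[key] = group[key]
    _fallback_groups.append(fallback_group)
self._cpu_fallback_optimizer = torch.optim.AdamW(_fallback_groups)
```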
else:
# AdamWMode.ADAMW_TRANSFORMERS (default): prefer transformers.AdamW
try:
from transformers import AdamW as _TransformersAdamW # noqa: PLC0415

self._cpu_fallback_optimizer = _TransformersAdamW(
_params,
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
correct_bias=bias_correction,
)
except ImportError:
warnings.warn(
"transformers package not available; using torch.optim.AdamW as CPU fallback "
"for AdamWMode.ADAMW_TRANSFORMERS.",
UserWarning,
stacklevel=2,
)
self._cpu_fallback_optimizer = torch.optim.AdamW(
_params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay
)
Comment on lines +118 to +140
Copilot AI Apr 27, 2026

In the ADAMW_TRANSFORMERS fallback, the except ImportError path switches to torch.optim.AdamW, which can’t represent correct_bias=bias_correction. If bias_correction=False, CPU fallback behavior will diverge from the intended Transformers/AdamW math without any guard. Consider either (a) requiring transformers for this mode, (b) raising when bias_correction=False and transformers is unavailable, or (c) implementing a small CPU update that matches the Transformers variant.

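For illustration only (not part of the PR): a sketch of option (b) above, failing fast when bias_correction=False and transformers is unavailable, so the CPU path never silently switches to a different update rule. It assumes the surrounding scope of FusedAdam.__init__ (bias_correction in scope).

```python
# Hypothetical guard inside FusedAdam.__init__ for the ADAMW_TRANSFORMERS mode.
try:
    from transformers import AdamW as _TransformersAdamW  # noqa: PLC0415
except ImportError:
    if not bias_correction:
        # torch.optim.AdamW always applies bias correction, so it cannot
        # reproduce correct_bias=False; refuse rather than diverge silently.
        raise RuntimeError(
            "AdamWMode.ADAMW_TRANSFORMERS with bias_correction=False requires the "
            "'transformers' package for an equivalent CPU fallback."
        ) from None
    _TransformersAdamW = None  # fall back to torch.optim.AdamW below
```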

def zero_grad(self, set_to_none=True):
if self._cpu_fallback_optimizer is not None:
self._cpu_fallback_optimizer.zero_grad(set_to_none=self._set_grad_none or set_to_none)
return
if self._set_grad_none or set_to_none:
for group in self.param_groups:
for p in group["params"]:
@@ -109,6 +159,9 @@ def step(self, closure=None):

The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.
"""
if self._cpu_fallback_optimizer is not None:
return self._cpu_fallback_optimizer.step(closure)

loss = None
if closure is not None:
loss = closure()
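Before the new test file, a usage-level illustration (not part of the diff) of what the fallback means for callers on a CPU-only machine; the import path is assumed from the file location.

```python
# Hypothetical usage: on a machine where torch.cuda.is_available() is False,
# FusedAdam now warns and delegates to the CPU fallback instead of failing.
import torch
import torch.nn as nn
from onnxruntime.training.optim.fused_adam import FusedAdam

model = nn.Linear(4, 2)
opt = FusedAdam(model.parameters(), lr=1e-3)   # emits a UserWarning on CPU-only systems
loss = model(torch.randn(8, 4)).sum()
loss.backward()
opt.step()        # delegated to the fallback optimizer's step()
opt.zero_grad()   # likewise delegated
```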
@@ -0,0 +1,128 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Unit tests for FusedAdam CPU fallback (issue #17403).

These tests patch torch.cuda.is_available to return False so they run
deterministically on both CPU-only and CUDA machines.

Import strategy:
* Add the optim/ source directory to sys.path.
* Pre-register "_multi_tensor_apply" in sys.modules by importing it
directly (it is pure Python with no external deps).
* Load fused_adam.py via importlib with __package__ set to the name we
used for the pre-registered _multi_tensor_apply module so that the
Comment on lines +9 to +13
Copilot AI Apr 27, 2026

The module docstring says the import strategy “Add the optim/ source directory to sys.path”, but the test never modifies sys.path. Please update the docstring to reflect what the test actually does (importing via importlib.util.spec_from_file_location and sys.modules pre-registration) so future maintainers aren’t misled.

Suggested change
* Add the optim/ source directory to sys.path.
* Pre-register "_multi_tensor_apply" in sys.modules by importing it
directly (it is pure Python with no external deps).
* Load fused_adam.py via importlib with __package__ set to the name we
used for the pre-registered _multi_tensor_apply module so that the
* Load ``_multi_tensor_apply.py`` directly from the optim/ source
directory via ``importlib.util.spec_from_file_location``.
* Pre-register that module in ``sys.modules`` under a synthetic package
name so the relative import inside ``fused_adam.py`` can resolve.
* Load ``fused_adam.py`` via ``importlib.util.spec_from_file_location``
with ``__package__`` set to that synthetic package name so the

relative import resolves correctly.

This avoids touching the training/__init__.py which requires the compiled
onnxruntime C extension (not available in the source-tree environment).
The CUDA extension import inside ``FusedAdam.__init__`` is guarded by
``if torch.cuda.is_available():`` and never runs with the mock in place.
"""

import importlib.util
import sys
import warnings
from pathlib import Path
from unittest.mock import patch

import torch
import torch.nn as nn

# ---------------------------------------------------------------------------
# Locate the optim source directory.
# File layout:
# orttraining/orttraining/test/python/ <- __file__ (parents[0])
# orttraining/orttraining/test/ <- parents[1]
# orttraining/orttraining/ <- parents[2]
# orttraining/orttraining/python/training/optim/ <- _OPTIM_DIR
# ---------------------------------------------------------------------------
_OPTIM_DIR = Path(__file__).resolve().parents[2] / "python" / "training" / "optim"
assert _OPTIM_DIR.is_dir(), f"optim dir not found: {_OPTIM_DIR}"

# Step 1: load _multi_tensor_apply as a top-level module (no package needed,
# it has zero external imports) and register it under the name that the
# relative import inside fused_adam.py expects.
_PKG = "fused_adam_pkg"
_mta_spec = importlib.util.spec_from_file_location(
f"{_PKG}._multi_tensor_apply",
_OPTIM_DIR / "_multi_tensor_apply.py",
)
_mta_mod = importlib.util.module_from_spec(_mta_spec)
sys.modules[f"{_PKG}._multi_tensor_apply"] = _mta_mod
_mta_spec.loader.exec_module(_mta_mod)

# Step 2: load fused_adam.py with __package__ = _PKG so its relative import
# "from ._multi_tensor_apply import ..." resolves to the entry above.
_fa_spec = importlib.util.spec_from_file_location(
f"{_PKG}.fused_adam",
_OPTIM_DIR / "fused_adam.py",
)
_fa_mod = importlib.util.module_from_spec(_fa_spec)
_fa_mod.__package__ = _PKG
# Patch cuda before executing the module body so the CUDA block is skipped.
with patch("torch.cuda.is_available", return_value=False):
_fa_spec.loader.exec_module(_fa_mod)

AdamWMode = _fa_mod.AdamWMode
FusedAdam = _fa_mod.FusedAdam


def _make_param(shape=(3, 3)):
"""Return an nn.Parameter with a synthetic gradient."""
p = nn.Parameter(torch.randn(*shape))
p.grad = torch.randn(*shape)
return p


@patch("torch.cuda.is_available", return_value=False)
class TestFusedAdamCpuFallback:
"""All tests run with CUDA disabled to exercise the CPU fallback path."""

def test_instantiation_warns_and_succeeds(self, _mock_cuda):
"""FusedAdam must instantiate without error and emit a UserWarning."""
param = nn.Parameter(torch.randn(3, 3))
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
opt = FusedAdam([param], lr=1e-3)

assert opt is not None, "FusedAdam should instantiate on CPU"
user_warnings = [w for w in caught if issubclass(w.category, UserWarning)]
assert len(user_warnings) >= 1, "Expected at least one UserWarning about CPU fallback"
assert any("CPU" in str(w.message) or "fallback" in str(w.message).lower() for w in user_warnings)

def test_step_updates_params_like_adamw(self, _mock_cuda):
"""After one step, params must change in the same direction as torch.optim.AdamW."""
torch.manual_seed(42)
weight_init = torch.randn(4, 4)
grad = torch.randn(4, 4)

# FusedAdam (CPU fallback) path — ADAMW_TORCH maps to torch.optim.AdamW
p_fused = nn.Parameter(weight_init.clone())
p_fused.grad = grad.clone()
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
opt_fused = FusedAdam([p_fused], lr=1e-3, adam_w_mode=AdamWMode.ADAMW_TORCH)
opt_fused.step()

# Reference: plain torch.optim.AdamW with matching hyperparams
p_ref = nn.Parameter(weight_init.clone())
p_ref.grad = grad.clone()
opt_ref = torch.optim.AdamW([p_ref], lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0)
opt_ref.step()

assert not torch.allclose(p_fused.data, weight_init), "Parameters should have changed after step"
assert torch.allclose(p_fused.data, p_ref.data, atol=1e-5), (
f"FusedAdam CPU fallback should match torch.optim.AdamW.\n"
f"Max diff: {(p_fused.data - p_ref.data).abs().max().item()}"
)

def test_adam_l2_mode_instantiates_and_steps(self, _mock_cuda):
"""AdamWMode.ADAM_L2_REGULARIZATION must instantiate and step without error."""
param = _make_param()
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
opt = FusedAdam([param], lr=1e-3, adam_w_mode=AdamWMode.ADAM_L2_REGULARIZATION)

before = param.data.clone()
opt.step()
assert not torch.allclose(param.data, before), "Parameters should change after step"