Commit c22c2aa

lhb8125 and Hongbin Liu authored
[Was PR1912][Dev] feat(moe): Fine-grained activation offloading (#1969)
Signed-off-by: Hongbin Liu <[email protected]>
Signed-off-by: Hongbin Liu <[email protected]>
Co-authored-by: Hongbin Liu <[email protected]>
1 parent 13edb58 commit c22c2aa

File tree

27 files changed: +2,736 additions, −61 deletions

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
# Fine-grained Activation Offloading (in collaboration with rednote)

Memory capacity is increasingly important with the rise of extremely sparse MoE models such as DeepSeek-V3 and Qwen3-235B. Fine-grained recomputation reduces the memory footprint at the cost of extra recomputation, while offloading can exploit the host-device bandwidth to achieve nearly zero overhead. Fine-grained Activation Offloading offloads activations at the granularity of specific modules, so the amount of offloaded activation can be calibrated to maximize training throughput.

Currently, the supported offloading modules are `"attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act"`, which can be combined with fine-grained recomputation to release almost all activations of a transformer layer.

**Features**

* Support PP=1 / PP / interleaved PP
* Compatible with fine-grained recomputation
* Support FP8
* Support MTP
* Support mixed dense & MoE layers
* Support A2A overlap
* Support CUDA Graph
  * (Temporary) the CUDA graph scope cannot contain the offloading modules

**Usage**

```bash
# Enable fine-grained activation offloading
--fine-grained-activation-offloading

# Specify which modules will offload their inputs
# Choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".
--offload-modules expert_fc1
```

**Compatible with Fine-grained Recomputation**

- For modules with minor recompute overhead, such as layernorm or moe_act, use recomputation to reduce the memory footprint;
- For other modules, use offloading to reduce the memory footprint;
- Make sure the offloading/reloading can be overlapped with computation (an illustrative flag combination follows the figure below).

![Fine-grained Activation Offloading and Fine-grained Recomputation](../images/fine_grained_activation_offloading/offloading_and_recomputing.png)
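As an illustration of the split described above, a hypothetical flag combination might look like the sketch below. The recomputation flag names (`--recompute-granularity`, `--recompute-modules`) and the multi-module form of `--offload-modules` are assumptions for this sketch, not taken from this commit.

```bash
# Hypothetical sketch: recompute the cheap modules, offload the expensive ones.
# Flag names for recomputation and the multi-module syntax are assumptions.
--recompute-granularity selective
--recompute-modules layernorm moe_act

--fine-grained-activation-offloading
--offload-modules expert_fc1 core_attn
```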

docs/source/api-guide/index.rst

Lines changed: 1 addition & 0 deletions
@@ -22,3 +22,4 @@ API Guide
    optimizer_cpu_offload
    multi_token_prediction
    tokenizers
+   fine_grained_activation_offloading
(new image file, 325 KB)

megatron/core/extensions/transformer_engine.py

Lines changed: 11 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

 import dataclasses
 import inspect
@@ -299,6 +299,7 @@ def __init__(
             extra_kwargs["delay_wgrad_compute"] = self.config.delay_wgrad_compute
         else:
             raise RuntimeError("Only TE with version >=2.3.0 supports delay_wgrad_compute now.")
+
         if (
             self.config.tp_comm_overlap
             and tp_comm_buffer_name
@@ -2116,3 +2117,12 @@ def set_save_original_input(module):
         "set_save_original_input is only needed on transformer-engine modules that save "
         "quantized tensors by default. It needs transformer-engine>=2.6.0dev0."
     )
+
+
+try:
+    # pylint: disable=unused-import
+    from transformer_engine.pytorch import cpu_offload
+    from transformer_engine.pytorch.float8_tensor import Float8Tensor
+except ImportError:
+    Float8Tensor = None
+    cpu_offload = None

megatron/core/models/common/model_chunk_schedule_plan.py

Lines changed: 8 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

 from contextlib import nullcontext
 from typing import Optional
@@ -8,6 +8,9 @@
 from megatron.core.enums import Fp8Recipe
 from megatron.core.fp8_utils import get_fp8_context
+from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
+    fine_grained_offloading_set_last_layer,
+)
 from megatron.core.pipeline_parallel.utils import (
     AbstractSchedulePlan,
     NoopScheduleNode,
@@ -450,6 +453,8 @@ def run(
             f_layer = f_schedule_plan.get_layer(i)
             b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
             torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_num_layers - 1 - i}b")
+            if f_layer.layer.config.fine_grained_activation_offloading:
+                fine_grained_offloading_set_last_layer(i == f_num_layers - 1)
             f_input, b_grad = TransformerLayerSchedulePlan.run(
                 f_layer,
                 b_layer,
@@ -472,6 +477,8 @@ def run(
         for i in range(overlapped_layers, f_num_layers):
             f_layer = f_schedule_plan.get_layer(i)
             torch.cuda.nvtx.range_push(f"layer_{i}f")
+            if f_layer.layer.config.fine_grained_activation_offloading:
+                fine_grained_offloading_set_last_layer(i == f_num_layers - 1)
             f_input, _ = TransformerLayerSchedulePlan.run(f_layer, None, f_input=f_input)
             torch.cuda.nvtx.range_pop()

megatron/core/models/gpt/fine_grained_callables.py

Lines changed: 18 additions & 5 deletions
@@ -1,4 +1,4 @@
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

 import weakref
 from contextlib import nullcontext
@@ -8,6 +8,11 @@
 import torch

 from megatron.core import tensor_parallel
+from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
+    fine_grained_offloading_group_commit,
+    fine_grained_offloading_group_start,
+    get_fine_grained_offloading_context,
+)
 from megatron.core.pipeline_parallel.utils import ScheduleNode, make_viewless
 from megatron.core.transformer.module import float16_to_fp32
 from megatron.core.transformer.moe.moe_layer import MoELayer
@@ -350,13 +355,17 @@ def submodule_post_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor)
     Run forward pass for computations between attention and dispatch:
     pre mlp layernorm->router->dispatch preprocess
     """
+    if layer.offload_mlp_norm:
+        hidden_states = fine_grained_offloading_group_start(hidden_states, name="mlp_norm")
     if layer.recompute_pre_mlp_layernorm:
         layer.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
-        pre_mlp_layernorm_output = layer.pre_mlp_norm_checkpoint.checkpoint(
-            layer.pre_mlp_layernorm, hidden_states
-        )
+        with get_fine_grained_offloading_context(layer.offload_mlp_norm):
+            pre_mlp_layernorm_output = layer.pre_mlp_norm_checkpoint.checkpoint(
+                layer.pre_mlp_layernorm, hidden_states
+            )
     else:
-        pre_mlp_layernorm_output = layer.pre_mlp_layernorm(hidden_states)
+        with get_fine_grained_offloading_context(layer.offload_mlp_norm):
+            pre_mlp_layernorm_output = layer.pre_mlp_layernorm(hidden_states)

     local_tokens, probs, _ = layer.mlp.router_and_preprocess(pre_mlp_layernorm_output)

@@ -437,6 +446,10 @@ def submodule_combine_forward(
         hidden_states = layer.mlp_bda(layer.training, layer.config.bias_dropout_fusion)(
             mlp_output_with_bias, residual, layer.hidden_dropout
         )
+        if layer.offload_mlp_norm:
+            (hidden_states,) = fine_grained_offloading_group_commit(
+                hidden_states, name="mlp_norm", forced_released_tensors=[residual]
+            )
         output = make_viewless_tensor(
             inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
         )

megatron/core/models/gpt/gpt_model.py

Lines changed: 28 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

 from collections import OrderedDict
 from typing import Dict, Literal, Optional
@@ -18,6 +18,9 @@
 )
 from megatron.core.models.common.language_module.language_module import LanguageModule
 from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
+    fine_grained_offloading_init_chunk_handler,
+)
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.quantization.utils import get_quant_config_or_none
 from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
@@ -117,6 +120,7 @@ def __init__(
         self.parallel_output = parallel_output
         self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
         self.vp_stage = vp_stage
+        self.disable_param_offloading = True

         if hasattr(self.config, 'position_embedding_type'):
             self.position_embedding_type = self.config.position_embedding_type
@@ -410,6 +414,24 @@ def _preprocess(

         return preproc_output

+    def preprocess_for_fine_grained_offloading(self):
+        """Preprocess for fine-grained activation offloading."""
+        fine_grained_offloading_init_chunk_handler(
+            vp_size=self.config.virtual_pipeline_model_parallel_size,
+            vp_stage=self.vp_stage,
+            min_offloaded_tensor_size=self.config.min_offloaded_tensor_size,
+        )
+        if self.disable_param_offloading:
+            for param in self.decoder.parameters():
+                param.offloading_activation = False
+            if self.mtp_process:
+                for param in self.mtp.parameters():
+                    param.offloading_activation = False
+            if self.post_process:
+                for param in self.output_layer.parameters():
+                    param.offloading_activation = False
+            self.disable_param_offloading = False
+
     def forward(
         self,
         input_ids: Tensor,
@@ -435,6 +457,8 @@ def forward(
             runtime_gather_output (bool): Gather output at runtime. Default None means
                 `parallel_output` arg in the constructor will be used.
         """
+        if self.config.fine_grained_activation_offloading:
+            self.preprocess_for_fine_grained_offloading()

         inference_context = deprecate_inference_params(inference_context, inference_params)

@@ -701,6 +725,9 @@ def build_schedule_plan(
             TransformerModelChunkSchedulePlan: The model chunk schedule plan.
         """

+        if self.config.fine_grained_activation_offloading:
+            self.preprocess_for_fine_grained_offloading()
+
         from ..common.model_chunk_schedule_plan import TransformerModelChunkSchedulePlan

         return TransformerModelChunkSchedulePlan(
