
Commit f2e34db
attempt to fix NVFP4 + selective recompute

Signed-off-by: Zhongbo Zhu <[email protected]>
Parent: 3a94887

6 files changed: +44, -19 lines

megatron/core/transformer/attention.py  (13 additions, 5 deletions)

@@ -205,12 +205,20 @@ def __init__(
         if (
             HAVE_TE
-            and self.config.fp8
-            and self.config.fp8_recipe != 'delayed'
-            and is_te_min_version("2.6.0dev0")
             and isinstance(self.linear_proj, TELinear)
+            and (
+                (
+                    self.config.fp8
+                    and self.config.fp8_recipe != 'delayed'
+                    and is_te_min_version("2.6.0dev0")
+                )
+                or (
+                    self.config.fp4
+                    and is_te_min_version("2.7.0.dev0")
+                )
+            )
         ):
-            # For fp8 training, the output of the fused core_attn is saved by itself, and
+            # For fp8/fp4 training, the output of the fused core_attn is saved by itself, and
             # linear_proj also saves the quantized tensor of this output. Here we set the
             # linear_proj to save the original input tensors to avoid the extra memory usage of
             # the quantized tensor.

@@ -1129,7 +1137,7 @@ def _backward_output_proj(self):
         self.linear_proj.backward_dw()

     def set_for_recompute_input_layernorm(self):
-        """Set the attention layer for recompute input_layernorm. Only needed for fp8."""
+        """Set the attention layer for recompute input_layernorm. Only needed for fp8/fp4."""
         from megatron.core.extensions.transformer_engine import set_save_original_input

         set_save_original_input(self.linear_qkv)
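
The new condition above is the core of the fix: the save-original-input behavior for linear_proj now also triggers for the NVFP4 recipe, not only for non-delayed FP8, with each path gated on its own minimum Transformer Engine version. The standalone sketch below restates that predicate outside of Megatron-LM for readability; the helper name and the simplified config stand-in are illustrative only, not part of the repository.

# Illustrative restatement of the gating logic above; `_should_save_original_input`
# and `_Config` are hypothetical stand-ins, not Megatron-LM or Transformer Engine APIs.
from dataclasses import dataclass
from typing import Optional


@dataclass
class _Config:
    fp8: Optional[str] = None      # recipe name when FP8 is enabled, None otherwise (assumed)
    fp8_recipe: str = "tensorwise"
    fp4: bool = False


def _should_save_original_input(
    config: _Config,
    have_te: bool,
    proj_is_te_linear: bool,
    te_supports_fp8_path: bool,   # stands in for is_te_min_version("2.6.0dev0")
    te_supports_fp4_path: bool,   # stands in for is_te_min_version("2.7.0.dev0")
) -> bool:
    """True when linear_proj should keep its original (unquantized) input tensors."""
    if not (have_te and proj_is_te_linear):
        return False
    fp8_path = bool(config.fp8) and config.fp8_recipe != "delayed" and te_supports_fp8_path
    fp4_path = config.fp4 and te_supports_fp4_path
    return fp8_path or fp4_path


if __name__ == "__main__":
    # NVFP4 run on a new-enough Transformer Engine: the memory-saving path is taken.
    print(_should_save_original_input(_Config(fp4=True), True, True, False, True))  # True
    # Delayed-scaling FP8 is still excluded, as before this commit.
    print(_should_save_original_input(_Config(fp8="hybrid", fp8_recipe="delayed"), True, True, True, False))  # False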

megatron/core/transformer/moe/moe_layer.py  (3 additions, 3 deletions)

@@ -204,7 +204,7 @@ def experts_compute(
         if self.use_shared_expert and not self.shared_expert_overlap:
             # Compute the shared expert separately when not overlapped with communication.
             if self.shared_experts_recompute:
-                if self.config.fp8:
+                if self.config.fp8 or self.config.fp4:
                     shared_expert_output = te_checkpoint(
                         self.shared_experts,
                         False,

@@ -272,7 +272,7 @@ def custom_forward(hidden_states):
             return output, mlp_bias

         if self.moe_layer_recompute:
-            if self.config.fp8:
+            if self.config.fp8 or self.config.fp4:
                 output, mlp_bias = te_checkpoint(
                     custom_forward,
                     False,

@@ -294,7 +294,7 @@ def backward_dw(self):
             self.shared_experts.backward_dw()

     def set_for_recompute_pre_mlp_layernorm(self):
-        """Set the MoE layer for recompute pre_mlp_layernorm. Only needed for fp8."""
+        """Set the MoE layer for recompute pre_mlp_layernorm. Only needed for fp8/fp4."""
         # If shared_experts_recompute is used, nothing needs to be done because the checkpoint
         # function will save the original input tensors.
         if self.shared_experts is not None and not self.shared_experts_recompute:
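
In both recompute branches the only change is the trigger: the Transformer Engine checkpoint path is now taken when either fp8 or fp4 is active, so quantized recompute replays through te_checkpoint rather than the plain activation checkpoint. A minimal, self-contained sketch of that dispatch pattern follows; the stand-in checkpoint callables are placeholders, not the real te_checkpoint signature.

# Toy dispatcher mirroring the pattern above (not Megatron-LM code). The two
# checkpoint callables are placeholders: in the real code they are Transformer
# Engine's te_checkpoint and Megatron's plain activation checkpoint.
from typing import Any, Callable


def checkpointed_call(
    forward_fn: Callable[..., Any],
    *args: Any,
    fp8: bool = False,
    fp4: bool = False,
    te_checkpoint_fn: Callable[..., Any] = lambda fn, *a: fn(*a),
    plain_checkpoint_fn: Callable[..., Any] = lambda fn, *a: fn(*a),
) -> Any:
    # Quantized training must recompute through the TE-aware checkpoint so the
    # replayed forward sees the same quantization handling as the original pass.
    if fp8 or fp4:
        return te_checkpoint_fn(forward_fn, *args)
    return plain_checkpoint_fn(forward_fn, *args)


if __name__ == "__main__":
    out = checkpointed_call(lambda x: x * 2, 3, fp4=True)
    print(out)  # 6, via the (stand-in) TE checkpoint path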

megatron/core/transformer/moe/shared_experts.py  (5 additions, 2 deletions)

@@ -62,8 +62,11 @@ def __init__(
         else:
             self.gate_weight = None

-        if self.config.fp8 and is_te_min_version("2.6.0dev0"):
-            # For fp8 training, the output of pre_mlp_layernorm is saved by router, and
+        if (
+            (self.config.fp8 and is_te_min_version("2.6.0dev0"))
+            or (self.config.fp4 and is_te_min_version("2.7.0.dev0"))
+        ):
+            # For fp8/fp4 training, the output of pre_mlp_layernorm is saved by router, and
             # the shared expert linear_fc1 also saves the quantized tensor of this output.
             # Here we set the linear_fc1 to save the original input tensors to avoid the extra
             # memory usage of the quantized tensor.
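
Both quantized paths are additionally gated on a minimum Transformer Engine version: 2.6.0dev0 for the FP8 branch and 2.7.0.dev0 for the NVFP4 branch. The snippet below is only a stand-in for such a check, using packaging's PEP 440 ordering; it is not Megatron-LM's actual is_te_min_version implementation, just a sketch of the comparison semantics.

# Stand-in sketch of a minimum-version gate (not the real is_te_min_version helper).
from packaging.version import Version


def _min_version_ok(installed: str, required: str) -> bool:
    """True when the installed version is at least the required one (PEP 440 ordering)."""
    return Version(installed) >= Version(required)


if __name__ == "__main__":
    # The NVFP4 path above requires Transformer Engine >= 2.7.0.dev0.
    print(_min_version_ok("2.7.0", "2.7.0.dev0"))  # True: dev releases sort before the final
    print(_min_version_ok("2.6.0", "2.7.0.dev0"))  # False: too old for the fp4 branch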

megatron/core/transformer/multi_latent_attention.py  (15 additions, 6 deletions)

@@ -177,12 +177,20 @@ def __init__(
         if (
             HAVE_TE
-            and self.config.fp8
-            and self.config.fp8_recipe != 'delayed'
-            and is_te_min_version("2.6.0dev0")
             and isinstance(self.linear_proj, TELinear)
+            and (
+                (
+                    self.config.fp8
+                    and self.config.fp8_recipe != 'delayed'
+                    and is_te_min_version("2.6.0dev0")
+                )
+                or (
+                    self.config.fp4
+                    and is_te_min_version("2.7.0.dev0")
+                )
+            )
         ):
-            # For fp8 training, the output of the fused core_attn is saved by itself, and
+            # For fp8/fp4 training, the output of the fused core_attn is saved by itself, and
             # linear_proj also saves the quantized tensor of this output. Here we set the
             # linear_proj to save the original input tensors to avoid the extra memory usage of
             # the quantized tensor.

@@ -781,7 +789,8 @@ def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_po
             return query, key, value

         if self.recompute_up_proj:
-            self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput(fp8=self.config.fp8)
+            quantization = self.config.fp8 or self.config.fp4
+            self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput(fp8=quantization)
             query, key, value = self.qkv_up_checkpoint.checkpoint(
                 qkv_up_proj_and_rope_apply, q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb
             )

@@ -911,7 +920,7 @@ def _backward_output_proj(self):
         self.linear_proj.backward_dw()

     def set_for_recompute_input_layernorm(self):
-        """Set the attention layer for recompute input_layernorm. Only needed for fp8."""
+        """Set the attention layer for recompute input_layernorm. Only needed for fp8/fp4."""
         from megatron.core.extensions.transformer_engine import set_save_original_input

         if self.config.q_lora_rank is not None:
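
One detail worth noting in the recompute_up_proj hunk: `self.config.fp8 or self.config.fp4` uses Python's short-circuit `or`, which returns the first truthy operand, so `quantization` may be a recipe string rather than a literal bool when FP8 is configured as a string-valued option (an assumption about the config field's type). The `fp8=` argument only appears to need truthiness here, so the behavior is the same either way; the tiny sketch below just makes the truth table explicit.

# Explicit truth table for the `quantization = config.fp8 or config.fp4` line above.
# `fp8` is assumed to be a recipe string or None; `fp4` is assumed to be a bool.
def quantized_recompute(fp8, fp4) -> bool:
    return bool(fp8 or fp4)


print(quantized_recompute("hybrid", False))  # True  (FP8 recipe set)
print(quantized_recompute(None, True))       # True  (NVFP4 enabled)
print(quantized_recompute(None, False))      # False (neither quantized format)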

megatron/core/transformer/multi_token_prediction.py  (5 additions, 0 deletions)

@@ -570,6 +570,11 @@ def _proj_and_transformer_layer(
             fp8_context = nullcontext()
             transformer_layer_fp8_context = nullcontext()

+        # TODO: currently no support for FP4 in MTP layers because we need more numerical validation
+        # raise Error here to avoid unexpected behavior
+        if self.config.fp4:
+            raise ValueError("FP4 is not supported for MTP layers yet.")
+
         with rng_context:
             with fp8_context:
                 hidden_states = self._concat_embeddings(hidden_states, decoder_input)
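
The MTP change is purely defensive: FP4 is rejected up front rather than risk an unvalidated numerical path. A hypothetical minimal check of that failure mode is sketched below; the fake config class is not from the repository.

# Hypothetical smoke test (not from the repo) for the guard added above:
# combining FP4 with MTP layers should fail fast with a clear error.
class _FakeConfig:
    fp4 = True


def _mtp_fp4_guard(config) -> None:
    # Mirrors the new guard in _proj_and_transformer_layer.
    if config.fp4:
        raise ValueError("FP4 is not supported for MTP layers yet.")


try:
    _mtp_fp4_guard(_FakeConfig())
except ValueError as err:
    print(f"raised as expected: {err}")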

megatron/core/transformer/transformer_layer.py  (3 additions, 3 deletions)

@@ -391,11 +391,11 @@ def __init__(
             and not self.config.external_cuda_graph
         ):
             self.recompute_input_layernorm = True
-            if self.config.fp8:
+            if self.config.fp8 or self.config.fp4:
                 self.self_attention.set_for_recompute_input_layernorm()
         if not isinstance(self.pre_mlp_layernorm, IdentityOp):
             self.recompute_pre_mlp_layernorm = True
-            if self.config.fp8:
+            if self.config.fp8 or self.config.fp4:
                 if isinstance(self.mlp, MoELayer):
                     self.mlp.set_for_recompute_pre_mlp_layernorm()
                 else:

@@ -595,7 +595,7 @@ def _forward_mlp(self, hidden_states, inference_context=None):
             )

         if self.recompute_mlp:
-            if self.config.fp8:
+            if self.config.fp8 or self.config.fp4:
                 # import here to avoid circular import
                 from megatron.core.extensions.transformer_engine import te_checkpoint
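
Taken together with the attention and MoE changes, the layer-level switches above mean the save-original-input hooks now fire for either quantized format. A self-contained toy of that wiring is shown below; the fake classes only echo what the real set_for_recompute_* hooks are asked to do and are not Megatron-LM code.

# Self-contained toy (not Megatron-LM code) of the quantization-aware recompute wiring above.
class _FakeAttention:
    def set_for_recompute_input_layernorm(self) -> None:
        print("attention: save original inputs for input_layernorm recompute")


class _FakeMoE:
    def set_for_recompute_pre_mlp_layernorm(self) -> None:
        print("moe: save original inputs for pre_mlp_layernorm recompute")


def setup_layernorm_recompute(attention, mlp, fp8=None, fp4=False) -> None:
    # Only quantized training (FP8 or NVFP4) needs the extra save-original-input setup.
    if fp8 or fp4:
        attention.set_for_recompute_input_layernorm()
        if hasattr(mlp, "set_for_recompute_pre_mlp_layernorm"):
            mlp.set_for_recompute_pre_mlp_layernorm()


setup_layernorm_recompute(_FakeAttention(), _FakeMoE(), fp4=True)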