Commit 3a94887
ADLR/megatron-lm!4225 - [Dev][NVFP4][MOE] Proper NVFP4 Zero Padding for MOE
Co-authored-by: Zhongbo Zhu <[email protected]>
1 parent 6b7197c commit 3a94887

File tree

4 files changed: 19 additions & 6 deletions


megatron/core/fp4_utils.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ def is_nvfp4tensor(tensor: torch.Tensor) -> bool:
 
 def get_fp4_align_size(fp4_recipe: Fp4Recipe) -> int:
     """
-    Get the alignment size required for FP4 GEMM.
+    Get the alignment size required for FP4 GEMM.
     FP4 GEMM requires Blackwell and later architectures.
 
     The value 32 is a hardware requirement: TMA (Tensor Memory Accelerator) requires
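For context on how this alignment is consumed elsewhere in the commit: token counts fed into an NVFP4 grouped GEMM must be zero-padded up to a multiple of the 32-element alignment documented above. A minimal sketch, assuming a hypothetical helper that is not part of megatron.core:

# Hypothetical helper, not part of megatron.core: round a token count up to the
# alignment that get_fp4_align_size() reports for an NVFP4 recipe.
def round_up_to_align(num_tokens: int, align_size: int = 32) -> int:
    """Return the smallest multiple of align_size that is >= num_tokens."""
    return ((num_tokens + align_size - 1) // align_size) * align_size

# Example: 70 tokens routed to one expert are zero-padded up to 96 rows.
assert round_up_to_align(70, 32) == 96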

megatron/core/transformer/moe/experts.py

Lines changed: 4 additions & 2 deletions
@@ -21,8 +21,8 @@
     ShardedTensorFactory,
 )
 from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
-from megatron.core.fp8_utils import get_fp8_align_size
 from megatron.core.fp4_utils import get_fp4_align_size
+from megatron.core.fp8_utils import get_fp8_align_size
 from megatron.core.fusions.fused_bias_geglu import quick_gelu, weighted_bias_quick_geglu_impl
 from megatron.core.fusions.fused_bias_swiglu import weighted_bias_swiglu_impl
 from megatron.core.fusions.fused_weighted_squared_relu import weighted_squared_relu_impl
@@ -136,7 +136,9 @@ def glu(x):
             and "moe_act" in self.config.recompute_modules
         )
         if self.activation_recompute and (self.config.fp8 or self.config.fp4):
-            raise ValueError("moe_act recompute for fp8 or fp4 cannot work with the legacy GroupedMLP.")
+            raise ValueError(
+                "moe_act recompute for fp8 or fp4 cannot work with the legacy GroupedMLP."
+            )
 
         @jit_fuser
         def activation_func_with_probs(x, probs):

megatron/core/transformer/moe/token_dispatcher.py

Lines changed: 9 additions & 1 deletion
@@ -8,8 +8,8 @@
 
 from megatron.core import utils
 from megatron.core.config import is_experimental_enabled
-from megatron.core.fp8_utils import get_fp8_align_size
 from megatron.core.fp4_utils import get_fp4_align_size
+from megatron.core.fp8_utils import get_fp8_align_size
 from megatron.core.fusions.fused_indices_converter import fused_indices_to_multihot
 from megatron.core.fusions.fused_pad_routing_map import fused_pad_routing_map
 from megatron.core.tensor_parallel import (
@@ -1143,6 +1143,14 @@ def get_restored_hidden_states_by_experts(self, hidden_states: torch.Tensor) ->
         )
         return hidden_states
 
+    def get_align_size_for_quantization(self):
+        """Get the alignment size for quantization."""
+        if self.config.fp8:
+            return get_fp8_align_size(self.config.fp8_recipe)
+        elif self.config.fp4:
+            return get_fp4_align_size(self.config.fp4_recipe)
+        return 16
+
 
 class MoEFlexTokenDispatcher(MoETokenDispatcher):
     """A flexible token dispatcher that abstracts the underlying tensor and expert

megatron/core/transformer/transformer_config.py

Lines changed: 5 additions & 2 deletions
@@ -1313,13 +1313,16 @@ def __post_init__(self):
         if self.moe_router_padding_for_fp8:
             # enable moe_router_padding_for_quantization
             warnings.warn(
-                "--moe-router-padding-for-fp8 is going to be deprecated. Use --moe-router-padding-for-quantization instead."
+                "--moe-router-padding-for-fp8 is going to be deprecated. "
+                "Use --moe-router-padding-for-quantization instead."
             )
             self.moe_router_padding_for_quantization = True
 
         if self.moe_router_padding_for_quantization:
             if self.fp8 is None and self.fp4 is None:
-                raise ValueError("fp8/fp4 must be specified when moe_router_padding_for_quantization is True.")
+                raise ValueError(
+                    "fp8/fp4 must be specified when moe_router_padding_for_quantization is True."
+                )
 
             if self.moe_token_dispatcher_type in ["allgather", "alltoall_seq"]:
                 raise ValueError(
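The __post_init__ changes above fold the deprecated fp8-specific flag into the quantization-agnostic one and require an active fp8 or fp4 recipe before router padding is enabled. A standalone sketch of that validation; the attribute names mirror TransformerConfig, but the class itself is illustrative only:

import warnings

# Standalone sketch of the validation in the hunk above; RouterPaddingOptions
# is an illustrative stand-in, not a class from megatron.core.
class RouterPaddingOptions:
    def __init__(self, fp8=None, fp4=None,
                 moe_router_padding_for_fp8=False,
                 moe_router_padding_for_quantization=False):
        self.fp8, self.fp4 = fp8, fp4
        if moe_router_padding_for_fp8:
            # Deprecated flag is mapped onto the quantization-agnostic one.
            warnings.warn(
                "--moe-router-padding-for-fp8 is going to be deprecated. "
                "Use --moe-router-padding-for-quantization instead."
            )
            moe_router_padding_for_quantization = True
        if moe_router_padding_for_quantization and fp8 is None and fp4 is None:
            raise ValueError(
                "fp8/fp4 must be specified when moe_router_padding_for_quantization is True."
            )
        self.moe_router_padding_for_quantization = moe_router_padding_for_quantization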
