Skip to content

Commit c49ffc1

Browse files
committed
GPG sign off
1 parent effebd8 commit c49ffc1

File tree

8 files changed

+469
-57
lines changed

8 files changed

+469
-57
lines changed

megatron/core/models/gpt/gpt_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,7 @@ def forward(
445445
*,
446446
inference_params: Optional[BaseInferenceContext] = None,
447447
loss_mask: Optional[Tensor] = None,
448+
padding_mask: Optional[Tensor] = None,
448449
) -> Tensor:
449450
"""Forward function of the GPT Model This function passes the input tensors
450451
through the embedding layer, and then the decoder and finally into the post
@@ -486,6 +487,7 @@ def forward(
486487
rotary_pos_cos_sin=rotary_pos_cos_sin,
487488
packed_seq_params=packed_seq_params,
488489
sequence_len_offset=sequence_len_offset,
490+
padding_mask=padding_mask,
489491
**(extra_block_kwargs or {}),
490492
)
491493

megatron/core/transformer/moe/moe_layer.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,19 @@ def __init__(
178178
self.cudagraph_tensor_store = MoECudaGraphTensorStore()
179179

180180
@maybe_skip_or_early_return_by_cudagraph("route")
181-
def route(self, hidden_states: torch.Tensor):
181+
def route(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
182182
"""Compute token routing for preprocessing.
183183
184184
This method uses the router to determine which experts to send each token to,
185185
producing routing probabilities and a mapping.
186+
187+
Args:
188+
hidden_states (torch.Tensor): Input tensor.
189+
padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
190+
Shape in [seq_length, bsz]. True for valid tokens,
191+
False for padding tokens. Defaults to None.
186192
"""
187-
probs, routing_map = self.router(hidden_states)
193+
probs, routing_map = self.router(hidden_states, padding_mask=padding_mask)
188194
return probs, routing_map
189195

190196
@maybe_skip_or_early_return_by_cudagraph("preprocess")
@@ -270,7 +276,7 @@ def combine(self, output: torch.Tensor, shared_expert_output: Optional[torch.Ten
270276
output = output + shared_expert_output
271277
return output
272278

273-
def forward(self, hidden_states: torch.Tensor):
279+
def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
274280
"""Forward pass for the MoE layer.
275281
276282
The forward pass comprises four main steps:
@@ -280,7 +286,10 @@ def forward(self, hidden_states: torch.Tensor):
280286
4. Combine: The outputs from the experts are combined and returned.
281287
282288
Args:
283-
hidden_states (torch.Tensor): The input tensor to the MoE layer.
289+
hidden_states (torch.Tensor): The input tensor to the MoE layer.
290+
padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
291+
Shape in [seq_length, bsz]. True for valid tokens,
292+
False for padding tokens. Defaults to None.
284293
285294
Returns:
286295
A tuple containing the output tensor and the MLP bias, if any.
@@ -292,10 +301,10 @@ def forward(self, hidden_states: torch.Tensor):
292301
)
293302

294303
# MoE forward: route -> dispatch -> compute -> combine
295-
def custom_forward(hidden_states):
304+
def custom_forward(hidden_states, padding_mask=None):
296305
try:
297306
shared_expert_output = self.shared_experts_compute(hidden_states)
298-
probs, routing_map = self.route(hidden_states)
307+
probs, routing_map = self.route(hidden_states, padding_mask=padding_mask)
299308
hidden_states, probs, residual = self.preprocess(hidden_states, probs, routing_map)
300309
except MoECudaGraphPartialCaptureSignal as e:
301310
# This signal is raised from the maybe_skip_or_early_return_by_cudagraph decorator.
@@ -318,11 +327,12 @@ def custom_forward(hidden_states):
318327
tensor_parallel.random.get_cuda_rng_tracker,
319328
parallel_state.get_tensor_model_parallel_group(),
320329
hidden_states,
330+
padding_mask,
321331
)
322332
else:
323-
outputs = tensor_parallel.checkpoint(custom_forward, False, hidden_states)
333+
outputs = tensor_parallel.checkpoint(custom_forward, False, hidden_states, padding_mask)
324334
else:
325-
outputs = custom_forward(hidden_states)
335+
outputs = custom_forward(hidden_states, padding_mask)
326336

327337
return outputs
328338

megatron/core/transformer/moe/moe_utils.py

Lines changed: 103 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,69 @@
3434
_MOE_LAYER_WISE_LOGGING_TRACKER = {}
3535

3636

37+
def compute_tokens_per_expert_with_padding(
    routing_map: torch.Tensor,
    padding_mask: Optional[torch.Tensor] = None,
    reshape_for_seq_aux: bool = False,
    seq_length: Optional[int] = None,
    bsz: Optional[int] = None,
    num_experts: Optional[int] = None,
):
    """Compute per-expert token counts and the valid-token count, honoring padding.

    This function provides a unified way to compute token counts across the
    different aux-loss types (aux_loss, global_aux_loss, seq_aux_loss).

    Args:
        routing_map (torch.Tensor): Token-to-expert routing map.
            - For aux_loss/global_aux_loss: shape [num_tokens, num_experts].
            - For seq_aux_loss: shape [num_tokens, num_experts]; it is reshaped
              internally to [seq_length, bsz * num_experts].
        padding_mask (torch.Tensor, optional): Boolean mask of shape [num_tokens].
            True for valid tokens, False for padding tokens. Defaults to None.
        reshape_for_seq_aux (bool): If True, reshape routing_map for the
            seq_aux_loss computation.
        seq_length (int, optional): Sequence length; required when
            reshape_for_seq_aux=True.
        bsz (int, optional): Batch size; required when reshape_for_seq_aux=True.
        num_experts (int, optional): Number of experts; required when
            reshape_for_seq_aux=True.

    Returns:
        tuple: (tokens_per_expert, num_valid_tokens)
            - tokens_per_expert (torch.Tensor): Number of tokens routed to each
              expert, shape [num_experts], or [bsz * num_experts] for seq_aux_loss.
            - num_valid_tokens (torch.Tensor or int): Number of valid
              (non-padding) tokens.

    Raises:
        ValueError: If reshape_for_seq_aux=True and any of seq_length, bsz, or
            num_experts is missing.
    """
    if reshape_for_seq_aux:
        # seq aux loss
        # Explicit raise instead of `assert` so the check survives `python -O`
        # (asserts are stripped under optimization).
        if seq_length is None or bsz is None or num_experts is None:
            raise ValueError(
                "seq_length, bsz, and num_experts must be provided when "
                "reshape_for_seq_aux=True"
            )

        if padding_mask is not None:
            # Reshape padding_mask to [seq_length, bsz].
            padding_mask_reshaped = padding_mask.reshape(seq_length, bsz)
            # Expand to match routing_map after reshape: [seq_length, bsz * num_experts].
            mask_expanded = (
                padding_mask_reshaped.unsqueeze(-1)
                .expand(-1, -1, num_experts)
                .reshape(seq_length, -1)
            )
            routing_map_masked = routing_map.reshape(seq_length, -1) & mask_expanded
            tokens_per_expert = routing_map_masked.sum(dim=0)
            # Count valid tokens only.
            num_valid_tokens = padding_mask.sum()
        else:
            tokens_per_expert = routing_map.reshape(seq_length, -1).sum(dim=0)
            num_valid_tokens = routing_map.shape[0]
    else:
        # aux_loss or global_aux_loss
        if padding_mask is not None:
            # routing_map: [num_tokens, num_experts], padding_mask: [num_tokens].
            mask_expanded = padding_mask.unsqueeze(-1)
            routing_map_masked = routing_map & mask_expanded
            tokens_per_expert = routing_map_masked.sum(dim=0)
            # Count valid tokens only.
            num_valid_tokens = padding_mask.sum()
        else:
            tokens_per_expert = routing_map.sum(dim=0)
            num_valid_tokens = routing_map.shape[0]

    return tokens_per_expert, num_valid_tokens
98+
99+
37100
def switch_load_balancing_loss_func(
38101
probs: torch.Tensor,
39102
tokens_per_expert: torch.Tensor,
@@ -42,6 +105,7 @@ def switch_load_balancing_loss_func(
42105
num_experts: int,
43106
moe_aux_loss_coeff: float,
44107
fused: bool = False,
108+
padding_mask: Optional[torch.Tensor] = None,
45109
):
46110
"""Calculate the auxiliary loss for load balancing.
47111
Refer to the Switch Transformer (https://arxiv.org/abs/2101.03961)
@@ -92,9 +156,18 @@ def switch_load_balancing_loss_func(
92156
topk (int): The number of experts selected for each token.
93157
num_experts (int): The number of experts.
94158
moe_aux_loss_coeff (float): The coefficient for the auxiliary loss.
159+
padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
160+
Shape in [num_tokens]. True for valid tokens,
161+
False for padding tokens. Defaults to None.
95162
Returns:
96163
torch.Tensor: The auxiliary loss for load balancing.
97164
"""
165+
# Apply padding mask to probs if provided
166+
if padding_mask is not None:
167+
# padding_mask: [num_tokens], probs: [num_tokens, num_experts]
168+
mask_expanded = padding_mask.unsqueeze(-1)
169+
probs = probs * mask_expanded
170+
98171
if fused:
99172
if not HAVE_TE or fused_moe_aux_loss is None:
100173
raise ValueError("fused_moe_aux_loss is not available. Please install TE >= 2.7.0.")
@@ -114,18 +187,32 @@ def switch_load_balancing_loss_func(
114187
return aux_loss
115188

116189

117-
def z_loss_func(logits, z_loss_coeff):
190+
def z_loss_func(logits, z_loss_coeff, padding_mask: Optional[torch.Tensor] = None):
    """Encourages the router's logits to remain small to enhance stability.

    Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.

    Args:
        logits (torch.Tensor): The logits of the router.
        z_loss_coeff (float): The coefficient for the z-loss.
        padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
            Shape in [num_tokens]. True for valid tokens,
            False for padding tokens. Defaults to None.

    Returns:
        torch.Tensor: The scalar z-loss value (not the logits).
    """
    logsum = torch.logsumexp(logits, dim=-1)
    z_loss_values = torch.square(logsum)

    if padding_mask is not None:
        # Zero out the contribution of padding tokens.
        z_loss_values = z_loss_values * padding_mask
        # Average over valid tokens only; clamp guards against division by
        # zero when the whole batch is padding.
        num_valid_tokens = padding_mask.sum()
        z_loss = z_loss_values.sum() / torch.clamp(num_valid_tokens, min=1.0) * z_loss_coeff
    else:
        z_loss = torch.mean(z_loss_values) * z_loss_coeff

    return z_loss
130217

131218

@@ -623,21 +710,27 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
623710

624711

625712
def compute_routing_scores_for_aux_loss(
626-
logits: torch.Tensor, topk: int, score_function: str, fused: bool = False
713+
logits: torch.Tensor, topk: int, score_function: str, fused: bool = False,
714+
padding_mask: Optional[torch.Tensor] = None
627715
):
628716
"""Compute routing scores based on the score function.
629717
630718
Args:
631719
logits (torch.Tensor): The logits tensor after gating, shape: [num_tokens, num_experts].
632-
720+
padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
721+
Shape in [num_tokens]. True for valid tokens,
722+
False for padding tokens. Defaults to None.
633723
Returns:
634-
torch.Tensor: The normalized routing scores.
724+
Tuple[torch.Tensor, torch.Tensor]: routing_map and scores.
635725
"""
636726
if fused:
637727
if not HAVE_TE or fused_compute_score_for_moe_aux_loss is None:
638728
raise ValueError(
639729
"fused_compute_score_for_moe_aux_loss is not available. Please install TE >= 2.6.0."
640730
)
731+
# Note: fused implementation does not support padding_mask yet
732+
if padding_mask is not None:
733+
raise ValueError("Fused compute_routing_scores does not support padding_mask. Set fused=False.")
641734
return fused_compute_score_for_moe_aux_loss(
642735
logits=logits, topk=topk, score_function=score_function
643736
)
@@ -652,6 +745,11 @@ def compute_routing_scores_for_aux_loss(
652745

653746
_, top_indices = torch.topk(scores, k=topk, dim=1)
654747
routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
748+
749+
# Apply padding mask to scores if provided
750+
if padding_mask is not None:
751+
mask_expanded = padding_mask.unsqueeze(-1).to(scores.dtype)
752+
scores = scores * mask_expanded
655753
return routing_map, scores
656754

657755

0 commit comments

Comments
 (0)