
Commit f654cad

Author: Haochen Yuan
Commit message: add UT
1 parent 1bb6f7a commit f654cad

4 files changed, +273 −42 lines changed


megatron/core/transformer/moe/moe_utils.py

Lines changed: 64 additions & 0 deletions
@@ -34,6 +34,70 @@
 _MOE_LAYER_WISE_LOGGING_TRACKER = {}
 
 
+def compute_tokens_per_expert_with_padding(
+    routing_map: torch.Tensor,
+    padding_mask: Optional[torch.Tensor] = None,
+    reshape_for_seq_aux: bool = False,
+    seq_length: Optional[int] = None,
+    bsz: Optional[int] = None,
+    num_experts: Optional[int] = None,
+):
+    """Compute tokens_per_expert and total_num_tokens with optional padding mask.
+
+    This function provides a unified way to compute token counts across different aux loss types.
+
+    Args:
+        routing_map (torch.Tensor): Token to expert routing map.
+            - For aux_loss/global_aux_loss: shape [num_tokens, num_experts]
+            - For seq_aux_loss: shape [num_tokens, num_experts] but will be reshaped
+        padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
+            Shape [num_tokens]. True for valid tokens, False for padding tokens.
+        reshape_for_seq_aux (bool): If True, reshape routing_map for seq_aux_loss computation.
+        seq_length (int, optional): Sequence length, required when reshape_for_seq_aux=True.
+        bsz (int, optional): Batch size, required when reshape_for_seq_aux=True.
+        num_experts (int, optional): Number of experts, required when reshape_for_seq_aux=True.
+
+    Returns:
+        tuple: (tokens_per_expert, num_valid_tokens)
+            - tokens_per_expert (torch.Tensor): Number of tokens per expert, shape [num_experts]
+              or [bsz * num_experts] for seq_aux_loss
+            - num_valid_tokens (torch.Tensor or int): Number of valid (non-padding) tokens
+    """
+    if reshape_for_seq_aux:
+        # seq aux loss
+        assert seq_length is not None and bsz is not None and num_experts is not None, \
+            "seq_length, bsz, and num_experts must be provided when reshape_for_seq_aux=True"
+
+        if padding_mask is not None:
+            # Reshape padding_mask to [seq_length, bsz]
+            padding_mask_reshaped = padding_mask.reshape(seq_length, bsz)
+            # Expand to match routing_map after reshape [seq_length, bsz * num_experts]
+            mask_expanded = padding_mask_reshaped.unsqueeze(-1).expand(
+                -1, -1, num_experts
+            ).reshape(seq_length, -1)
+            routing_map_masked = routing_map.reshape(seq_length, -1) & mask_expanded
+            tokens_per_expert = routing_map_masked.sum(dim=0)
+            # Count valid tokens only
+            num_valid_tokens = padding_mask.sum()
+        else:
+            tokens_per_expert = routing_map.reshape(seq_length, -1).sum(dim=0)
+            num_valid_tokens = routing_map.shape[0]
+    else:
+        # aux_loss or global_aux_loss
+        if padding_mask is not None:
+            # routing_map: [num_tokens, num_experts], padding_mask: [num_tokens]
+            mask_expanded = padding_mask.unsqueeze(-1)
+            routing_map_masked = routing_map & mask_expanded
+            tokens_per_expert = routing_map_masked.sum(dim=0)
+            # Count valid tokens only
+            num_valid_tokens = padding_mask.sum()
+        else:
+            tokens_per_expert = routing_map.sum(dim=0)
+            num_valid_tokens = routing_map.shape[0]
+
+    return tokens_per_expert, num_valid_tokens
+
+
 def switch_load_balancing_loss_func(
     probs: torch.Tensor,
     tokens_per_expert: torch.Tensor,
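
For reference, a minimal sketch of how the new helper behaves on toy inputs (shapes and values are illustrative only, not taken from this commit; it assumes the function above is importable from its module path):

import torch

from megatron.core.transformer.moe.moe_utils import compute_tokens_per_expert_with_padding

# 6 tokens routed to 4 experts; every token picks experts 0 and 1.
routing_map = torch.zeros(6, 4, dtype=torch.bool)
routing_map[:, :2] = True
# Last two tokens are padding.
padding_mask = torch.tensor([True, True, True, True, False, False])

# aux_loss / global_aux_loss mode: flat per-expert counts.
tokens_per_expert, num_valid = compute_tokens_per_expert_with_padding(
    routing_map=routing_map, padding_mask=padding_mask
)
assert tokens_per_expert.tolist() == [4, 4, 0, 0]  # padding rows excluded
assert num_valid.item() == 4

# seq_aux_loss mode: counts are kept per (batch element, expert) pair.
tokens_per_expert_seq, _ = compute_tokens_per_expert_with_padding(
    routing_map=routing_map,
    padding_mask=padding_mask,
    reshape_for_seq_aux=True,
    seq_length=3,
    bsz=2,
    num_experts=4,
)
assert tokens_per_expert_seq.shape == (2 * 4,)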

megatron/core/transformer/moe/router.py

Lines changed: 31 additions & 42 deletions
@@ -1,7 +1,7 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 
@@ -14,6 +14,7 @@
     apply_random_logits,
     apply_router_token_dropping,
     compute_routing_scores_for_aux_loss,
+    compute_tokens_per_expert_with_padding,
     router_gating_linear,
     save_to_aux_losses_tracker,
     sinkhorn,
@@ -276,16 +277,12 @@ def _apply_aux_loss(
         if aux_loss_coeff == 0:
             return probs
 
-        if padding_mask is not None:
-            # routing_map: [num_tokens, num_experts], padding_mask: [num_tokens]
-            mask_expanded = padding_mask.unsqueeze(-1)
-            routing_map_masked = routing_map & mask_expanded
-            tokens_per_expert = routing_map_masked.sum(dim=0)
-            # Count valid tokens only
-            num_tokens = padding_mask.sum()
-        else:
-            tokens_per_expert = routing_map.sum(dim=0)
-            num_tokens = routing_map.shape[0]
+        # Use unified function to compute tokens_per_expert and num_tokens
+        tokens_per_expert, num_tokens = compute_tokens_per_expert_with_padding(
+            routing_map=routing_map,
+            padding_mask=padding_mask,
+            reshape_for_seq_aux=False,
+        )
 
         tokens_per_expert = reduce_from_tensor_model_parallel_region(
             tokens_per_expert, self.tp_cp_group
@@ -304,7 +301,7 @@ def _apply_aux_loss(
         )
         probs = self.attach_and_log_load_balancing_loss(
             probs, aux_loss_coeff, aux_loss, "load_balancing_loss", self.tp_cp_group,
-            valid_token_count=num_tokens.item() if isinstance(num_tokens, torch.Tensor) else num_tokens
+            valid_token_count=num_tokens
         )
         return probs
 
@@ -330,21 +327,17 @@ def _apply_seq_aux_loss(
 
         scores_for_aux_loss = scores_for_aux_loss.reshape(seq_length, -1)
 
-        if padding_mask is not None:
-            # Reshape padding_mask to [seq_length, bsz]
-            padding_mask_reshaped = padding_mask.reshape(seq_length, bsz)
-            # Expand to match routing_map after reshape [seq_length, bsz * num_experts]
-            mask_expanded = padding_mask_reshaped.unsqueeze(-1).expand(-1, -1, self.config.num_moe_experts).reshape(seq_length, -1)
-            # Apply mask to routing_map for token counting
-            routing_map_masked = routing_map.reshape(seq_length, -1) & mask_expanded
-            tokens_per_expert = routing_map_masked.sum(dim=0)
-            # Count valid tokens per sequence
-            num_valid_tokens_per_seq = padding_mask_reshaped.sum(dim=0)  # [bsz]
-            total_num_tokens = num_valid_tokens_per_seq.sum() * self.tp_cp_group.size()
-        else:
-            tokens_per_expert = routing_map.reshape(seq_length, -1).sum(dim=0)
-            total_num_tokens = seq_length * self.tp_cp_group.size()
-            padding_mask_for_loss = None
+        # Use unified function to compute tokens_per_expert and num_tokens
+        tokens_per_expert, num_tokens = compute_tokens_per_expert_with_padding(
+            routing_map=routing_map,
+            padding_mask=padding_mask,
+            reshape_for_seq_aux=True,
+            seq_length=seq_length,
+            bsz=bsz,
+            num_experts=self.config.num_moe_experts,
+        )
+
+        total_num_tokens = num_tokens * self.tp_cp_group.size()
 
         tokens_per_expert = reduce_from_tensor_model_parallel_region(
             tokens_per_expert, self.tp_cp_group
@@ -365,8 +358,7 @@ def _apply_seq_aux_loss(
         )
         # Calculate valid token count: seq_length for each batch element
         if padding_mask is not None:
-            num_valid_tokens = padding_mask.sum()
-            valid_token_count = num_valid_tokens.item() if isinstance(num_valid_tokens, torch.Tensor) else num_valid_tokens
+            valid_token_count = padding_mask.sum()
         else:
             valid_token_count = seq_length * bsz
 
@@ -385,16 +377,12 @@ def _apply_global_aux_loss(
         if global_aux_loss_coeff == 0:
             return probs
 
-        if padding_mask is not None:
-            # routing_map: [num_tokens, num_experts], padding_mask: [num_tokens]
-            mask_expanded = padding_mask.unsqueeze(-1)
-            routing_map_masked = routing_map & mask_expanded
-            tokens_per_expert = routing_map_masked.sum(dim=0)
-            # Count valid tokens only
-            num_tokens = padding_mask.sum()
-        else:
-            tokens_per_expert = routing_map.sum(dim=0)
-            num_tokens = scores_for_aux_loss.shape[0]
+        # Use unified function to compute tokens_per_expert and num_tokens
+        tokens_per_expert, num_tokens = compute_tokens_per_expert_with_padding(
+            routing_map=routing_map,
+            padding_mask=padding_mask,
+            reshape_for_seq_aux=False,
+        )
 
         tokens_per_expert = reduce_from_tensor_model_parallel_region(
             tokens_per_expert, self.tp_dp_cp_group
@@ -422,7 +410,7 @@ def _apply_global_aux_loss(
             global_aux_loss,
             "global_load_balancing_loss",
             self.tp_dp_cp_group,
-            valid_token_count=num_tokens.item() if isinstance(num_tokens, torch.Tensor) else num_tokens,
+            valid_token_count=num_tokens,
         )
         return probs
 
@@ -433,7 +421,7 @@ def attach_and_log_load_balancing_loss(
         aux_loss: torch.Tensor,
         aux_loss_name: str,
         reduce_group: torch.distributed.ProcessGroup,
-        valid_token_count: Optional[int] = None,
+        valid_token_count: Optional[Union[int, torch.Tensor]] = None,
     ):
         """Attach aux loss function to activation and add to logging.
 
@@ -443,7 +431,8 @@ def attach_and_log_load_balancing_loss(
             aux_loss (torch.Tensor): Computed aux loss.
            aux_loss_name (str): Name of the aux loss for logging.
            reduce_group (torch.distributed.ProcessGroup): Process group for reduction.
-            valid_token_count (int, optional): Number of valid tokens excluding padding tokens.
+            valid_token_count (int or torch.Tensor, optional): Number of valid tokens excluding
+                padding tokens. Can be a Python int or a torch.Tensor (typically 0-d tensor).
                 If None, uses activation.shape[0]. Defaults to None.
         """
         # TODO (zijiey): fix the per_layer_logging for MTP, currently it will incorrectly
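
With valid_token_count now typed Optional[Union[int, torch.Tensor]], the call sites above can pass num_tokens through unchanged instead of calling .item() at each site, which plausibly defers the GPU-to-CPU synchronization that .item() forces until the value is actually consumed. The hunks above do not show how attach_and_log_load_balancing_loss uses the value internally, so the following is only a hypothetical sketch of what accepting both types entails:

import torch
from typing import Optional, Union

def resolve_valid_token_count(
    valid_token_count: Optional[Union[int, torch.Tensor]],
    fallback: int,
) -> int:
    # Hypothetical helper, not part of this commit: fall back to
    # activation.shape[0] when None (per the docstring above), and only
    # synchronize a tensor count at the point where an int is required.
    if valid_token_count is None:
        return fallback
    if isinstance(valid_token_count, torch.Tensor):
        return int(valid_token_count.item())
    return int(valid_token_count)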

tests/unit_tests/transformer/moe/test_aux_loss.py

Lines changed: 142 additions & 0 deletions
@@ -575,3 +575,145 @@ def test_force_balanced_aux_loss(self, tp_size, ep_size, cp_size):
             reduce_from_tensor_model_parallel_region(aux_loss, router.tp_cp_group)
             assert aux_loss.item() == 1, f"{aux_loss_type}: {aux_loss.item()}"
             clear_aux_losses_tracker()
+
+
+class TestPaddingMaskAuxLoss:
+    """Test padding mask support in various aux loss types."""
+
+    def setup_method(self, method):
+        Utils.initialize_model_parallel(1, 1)
+        _set_random_seed(seed_=123, data_parallel_random_init=False)
+
+        # Default configuration
+        self.default_transformer_config = TransformerConfig(
+            num_layers=1,
+            hidden_size=12,
+            num_attention_heads=8,
+            num_moe_experts=32,
+            use_cpu_initialization=True,
+            moe_router_load_balancing_type="aux_loss",
+            moe_router_topk=8,
+            moe_aux_loss_coeff=1.0,
+            bf16=True,
+            params_dtype=torch.bfloat16,
+            add_bias_linear=False,
+        )
+
+    def new_router(self, **kwargs):
+        """Create a new router with updated configuration."""
+        pg_collection = get_default_pg_collection()
+        new_transformer_config = dataclasses.replace(self.default_transformer_config, **kwargs)
+        router = TopKRouter(config=new_transformer_config, pg_collection=pg_collection)
+        router.set_layer_number(0)
+        return router
+
+    def teardown_method(self, method):
+        Utils.destroy_model_parallel()
+
+    @pytest.mark.internal
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.parametrize("aux_loss_type", ["aux_loss", "seq_aux_loss", "global_aux_loss"])
+    def test_padding_mask_removes_padding_tokens(self, aux_loss_type):
+        """Test that padding tokens are correctly excluded from aux loss calculation."""
+        clear_aux_losses_tracker()
+
+        router = self.new_router(
+            moe_router_load_balancing_type=aux_loss_type,
+            moe_aux_loss_coeff=1.0,
+            moe_router_dtype="fp64",
+        ).cuda()
+
+        seq_len = 32
+        batch_size = 2
+        hidden_size = router.config.hidden_size
+
+        # Create input with padding
+        hidden_states_full = torch.randn(
+            (seq_len, batch_size, hidden_size), dtype=torch.bfloat16, device='cuda'
+        )
+
+        # Create padding mask: first half valid, second half padding
+        padding_mask = torch.ones((seq_len, batch_size), dtype=torch.bool, device='cuda')
+        padding_mask[seq_len // 2:, :] = False
+
+        # Test with padding mask
+        router.weight.grad = None
+        scores_with_mask, routing_map_with_mask = router(hidden_states_full, padding_mask=padding_mask)
+        scores_with_mask.backward(torch.zeros_like(scores_with_mask))
+
+        loss_name = {
+            "aux_loss": "load_balancing_loss",
+            "seq_aux_loss": "seq_load_balancing_loss",
+            "global_aux_loss": "global_load_balancing_loss",
+        }[aux_loss_type]
+
+        tracker = get_moe_layer_wise_logging_tracker()
+        aux_loss_with_mask = tracker[loss_name]["values"][0].clone()
+        grad_with_mask = router.weight.grad.clone()
+
+        # Test without padding (with only half of the tokens)
+        clear_aux_losses_tracker()
+        router.weight.grad = None
+        hidden_states_valid = hidden_states_full[:seq_len // 2, :, :]
+        scores_without_mask, routing_map_without_mask = router(hidden_states_valid)
+        scores_without_mask.backward(torch.zeros_like(scores_without_mask))
+
+        aux_loss_without_mask = tracker[loss_name]["values"][0].clone()
+        grad_without_mask = router.weight.grad.clone()
+
+        # The aux loss with mask should be close to the aux loss without mask
+        torch.testing.assert_close(aux_loss_with_mask, aux_loss_without_mask, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(grad_with_mask, grad_without_mask, rtol=1e-2, atol=1e-3)
+
+        clear_aux_losses_tracker()
+
+    @pytest.mark.internal
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_padding_mask_with_z_loss(self):
+        """Test that padding mask works correctly with z_loss."""
+        clear_aux_losses_tracker()
+
+        router = self.new_router(
+            moe_router_load_balancing_type="aux_loss",
+            moe_aux_loss_coeff=0.0,
+            moe_z_loss_coeff=1.0,
+            moe_router_dtype="fp32",
+        ).cuda()
+
+        seq_len = 32
+        batch_size = 2
+        hidden_size = router.config.hidden_size
+
+        # Create input
+        hidden_states_full = torch.randn(
+            (seq_len, batch_size, hidden_size), dtype=torch.bfloat16, device='cuda'
+        )
+
+        # Create padding mask: first half valid, second half padding
+        padding_mask = torch.ones((seq_len, batch_size), dtype=torch.bool, device='cuda')
+        padding_mask[seq_len // 2:, :] = False
+
+        # Test with padding mask
+        router.weight.grad = None
+        scores_with_mask, _ = router(hidden_states_full, padding_mask=padding_mask)
+        scores_with_mask.sum().backward()
+
+        tracker = get_moe_layer_wise_logging_tracker()
+        z_loss_with_mask = tracker["z_loss"]["values"][0].clone()
+        grad_with_mask = router.weight.grad.clone()
+
+        # Test without padding (with only half of the tokens)
+        clear_aux_losses_tracker()
+        router.weight.grad = None
+        hidden_states_valid = hidden_states_full[:seq_len // 2, :, :]
+        scores_without_mask, _ = router(hidden_states_valid)
+        scores_without_mask.sum().backward()
+
+        z_loss_without_mask = tracker["z_loss"]["values"][0].clone()
+        grad_without_mask = router.weight.grad.clone()
+
+        # The z_loss with mask should be close to the z_loss without mask
+        torch.testing.assert_close(z_loss_with_mask, z_loss_without_mask, rtol=1e-2, atol=1e-3)
+        torch.testing.assert_close(grad_with_mask, grad_without_mask, rtol=1e-2, atol=1e-3)
+
+        clear_aux_losses_tracker()
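
Both tests lean on the same invariant: computing the loss over all tokens with the second half masked out should match computing it over only the first half of the tokens. A self-contained sketch of that invariant for the token-counting step alone (toy shapes, illustrative only, not part of the test suite):

import torch

seq_len, bsz, num_experts = 8, 2, 4
torch.manual_seed(0)

# Random top-1 routing for seq_len * bsz tokens, as a boolean map.
expert_choice = torch.randint(num_experts, (seq_len * bsz,))
routing_map = torch.nn.functional.one_hot(expert_choice, num_experts).bool()

# First half of each sequence valid, second half padding, flattened seq-major.
padding_mask = torch.ones(seq_len, bsz, dtype=torch.bool)
padding_mask[seq_len // 2:, :] = False
padding_mask = padding_mask.reshape(-1)

# Masking padded rows out of the per-expert counts...
masked_counts = (routing_map & padding_mask.unsqueeze(-1)).sum(dim=0)
# ...is equivalent to dropping those rows before counting.
sliced_counts = routing_map[padding_mask].sum(dim=0)
assert torch.equal(masked_counts, sliced_counts)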
