3434_MOE_LAYER_WISE_LOGGING_TRACKER = {}
3535
3636
def compute_tokens_per_expert_with_padding(
    routing_map: torch.Tensor,
    padding_mask: Optional[torch.Tensor] = None,
    reshape_for_seq_aux: bool = False,
    seq_length: Optional[int] = None,
    bsz: Optional[int] = None,
    num_experts: Optional[int] = None,
):
    """Count routed tokens per expert, optionally ignoring padding tokens.

    Unified token-count helper shared by the different aux-loss flavors.

    Args:
        routing_map (torch.Tensor): Boolean token-to-expert routing map.
            - aux_loss / global_aux_loss: shape [num_tokens, num_experts].
            - seq_aux_loss: same shape, reshaped internally to
              [seq_length, bsz * num_experts].
        padding_mask (torch.Tensor, optional): Boolean mask of shape
            [num_tokens]; True marks valid tokens, False marks padding.
        reshape_for_seq_aux (bool): When True, compute per-(sequence, expert)
            counts for the sequence-level aux loss.
        seq_length (int, optional): Required when reshape_for_seq_aux=True.
        bsz (int, optional): Required when reshape_for_seq_aux=True.
        num_experts (int, optional): Required when reshape_for_seq_aux=True.

    Returns:
        tuple: (tokens_per_expert, num_valid_tokens)
            - tokens_per_expert (torch.Tensor): shape [num_experts], or
              [bsz * num_experts] in the seq_aux case.
            - num_valid_tokens: count of non-padding tokens — a 0-dim tensor
              when padding_mask is given, otherwise a plain int.
    """
    if not reshape_for_seq_aux:
        # aux_loss / global_aux_loss path: reduce straight over the token axis.
        if padding_mask is None:
            return routing_map.sum(dim=0), routing_map.shape[0]
        # Zero out padded rows before counting; broadcast mask over experts.
        valid = padding_mask.unsqueeze(-1)
        return (routing_map & valid).sum(dim=0), padding_mask.sum()

    # seq aux loss path: per-sequence counts are needed, so reshape first.
    assert seq_length is not None and bsz is not None and num_experts is not None, \
        "seq_length, bsz, and num_experts must be provided when reshape_for_seq_aux=True"

    per_seq_map = routing_map.reshape(seq_length, -1)
    if padding_mask is None:
        return per_seq_map.sum(dim=0), routing_map.shape[0]

    # Expand the [seq_length, bsz] mask to [seq_length, bsz * num_experts]
    # so it lines up with the reshaped routing map.
    mask_2d = padding_mask.reshape(seq_length, bsz)
    mask_wide = mask_2d.unsqueeze(-1).expand(-1, -1, num_experts).reshape(seq_length, -1)
    return (per_seq_map & mask_wide).sum(dim=0), padding_mask.sum()
98+
99+
37100def switch_load_balancing_loss_func (
38101 probs : torch .Tensor ,
39102 tokens_per_expert : torch .Tensor ,
@@ -42,6 +105,7 @@ def switch_load_balancing_loss_func(
42105 num_experts : int ,
43106 moe_aux_loss_coeff : float ,
44107 fused : bool = False ,
108+ padding_mask : Optional [torch .Tensor ] = None ,
45109):
46110 """Calculate the auxiliary loss for load balancing.
47111 Refer to the Switch Transformer (https://arxiv.org/abs/2101.03961)
@@ -92,9 +156,18 @@ def switch_load_balancing_loss_func(
92156 topk (int): The number of experts selected for each token.
93157 num_experts (int): The number of experts.
94158 moe_aux_loss_coeff (float): The coefficient for the auxiliary loss.
159+ padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
160+ Shape in [num_tokens]. True for valid tokens,
161+ False for padding tokens. Defaults to None.
95162 Returns:
96163 torch.Tensor: The auxiliary loss for load balancing.
97164 """
165+ # Apply padding mask to probs if provided
166+ if padding_mask is not None :
167+ # padding_mask: [num_tokens], probs: [num_tokens, num_experts]
168+ mask_expanded = padding_mask .unsqueeze (- 1 )
169+ probs = probs * mask_expanded
170+
98171 if fused :
99172 if not HAVE_TE or fused_moe_aux_loss is None :
100173 raise ValueError ("fused_moe_aux_loss is not available. Please install TE >= 2.7.0." )
@@ -114,18 +187,32 @@ def switch_load_balancing_loss_func(
114187 return aux_loss
115188
116189
117- def z_loss_func (logits , z_loss_coeff ):
190+ def z_loss_func (logits , z_loss_coeff , padding_mask : Optional [ torch . Tensor ] = None ):
118191 """Encourages the router's logits to remain small to enhance stability.
119192 Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
120193
121194 Args:
122195 logits (torch.Tensor): The logits of the router.
196+ z_loss_coeff (float): The coefficient for the z-loss.
197+ padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
198+ Shape in [num_tokens]. True for valid tokens,
199+ False for padding tokens. Defaults to None.
123200
124201 Returns:
125202 torch.Tensor: The logits after applying the z-loss.
126203 """
204+ logsum = torch .logsumexp (logits , dim = - 1 )
205+ z_loss_values = torch .square (logsum )
206+
207+ if padding_mask is not None :
208+ # Only compute z_loss for non-padding tokens
209+ z_loss_values = z_loss_values * padding_mask
210+ # Compute mean over valid tokens only
211+ num_valid_tokens = padding_mask .sum ()
212+ z_loss = z_loss_values .sum () / torch .clamp (num_valid_tokens , min = 1.0 ) * z_loss_coeff
213+ else :
214+ z_loss = torch .mean (z_loss_values ) * z_loss_coeff
127215
128- z_loss = torch .mean (torch .square (torch .logsumexp (logits , dim = - 1 ))) * z_loss_coeff
129216 return z_loss
130217
131218
@@ -623,21 +710,27 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
623710
624711
625712def compute_routing_scores_for_aux_loss (
626- logits : torch .Tensor , topk : int , score_function : str , fused : bool = False
713+ logits : torch .Tensor , topk : int , score_function : str , fused : bool = False ,
714+ padding_mask : Optional [torch .Tensor ] = None
627715):
628716 """Compute routing scores based on the score function.
629717
630718 Args:
631719 logits (torch.Tensor): The logits tensor after gating, shape: [num_tokens, num_experts].
632-
720+ padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
721+ Shape in [num_tokens]. True for valid tokens,
722+ False for padding tokens. Defaults to None.
633723 Returns:
634- torch.Tensor: The normalized routing scores.
724+ Tuple[ torch.Tensor, torch.Tensor]: routing_map and scores.
635725 """
636726 if fused :
637727 if not HAVE_TE or fused_compute_score_for_moe_aux_loss is None :
638728 raise ValueError (
639729 "fused_compute_score_for_moe_aux_loss is not available. Please install TE >= 2.6.0."
640730 )
731+ # Note: fused implementation does not support padding_mask yet
732+ if padding_mask is not None :
733+ raise ValueError ("Fused compute_routing_scores does not support padding_mask. Set fused=False." )
641734 return fused_compute_score_for_moe_aux_loss (
642735 logits = logits , topk = topk , score_function = score_function
643736 )
@@ -652,6 +745,11 @@ def compute_routing_scores_for_aux_loss(
652745
653746 _ , top_indices = torch .topk (scores , k = topk , dim = 1 )
654747 routing_map = torch .zeros_like (logits ).int ().scatter (1 , top_indices , 1 ).bool ()
748+
749+ # Apply padding mask to scores if provided
750+ if padding_mask is not None :
751+ mask_expanded = padding_mask .unsqueeze (- 1 ).to (scores .dtype )
752+ scores = scores * mask_expanded
655753 return routing_map , scores
656754
657755
0 commit comments