 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Optional, Union

 import torch

     apply_random_logits,
     apply_router_token_dropping,
     compute_routing_scores_for_aux_loss,
+    compute_tokens_per_expert_with_padding,
     router_gating_linear,
     save_to_aux_losses_tracker,
     sinkhorn,
@@ -276,16 +277,12 @@ def _apply_aux_loss(
         if aux_loss_coeff == 0:
             return probs

-        if padding_mask is not None:
-            # routing_map: [num_tokens, num_experts], padding_mask: [num_tokens]
-            mask_expanded = padding_mask.unsqueeze(-1)
-            routing_map_masked = routing_map & mask_expanded
-            tokens_per_expert = routing_map_masked.sum(dim=0)
-            # Count valid tokens only
-            num_tokens = padding_mask.sum()
-        else:
-            tokens_per_expert = routing_map.sum(dim=0)
-            num_tokens = routing_map.shape[0]
+        # Use unified function to compute tokens_per_expert and num_tokens
+        tokens_per_expert, num_tokens = compute_tokens_per_expert_with_padding(
+            routing_map=routing_map,
+            padding_mask=padding_mask,
+            reshape_for_seq_aux=False,
+        )

         tokens_per_expert = reduce_from_tensor_model_parallel_region(
             tokens_per_expert, self.tp_cp_group
@@ -304,7 +301,7 @@ def _apply_aux_loss(
         )
         probs = self.attach_and_log_load_balancing_loss(
             probs, aux_loss_coeff, aux_loss, "load_balancing_loss", self.tp_cp_group,
-            valid_token_count=num_tokens.item() if isinstance(num_tokens, torch.Tensor) else num_tokens
+            valid_token_count=num_tokens
         )
         return probs

@@ -330,21 +327,17 @@ def _apply_seq_aux_loss(

         scores_for_aux_loss = scores_for_aux_loss.reshape(seq_length, -1)

-        if padding_mask is not None:
-            # Reshape padding_mask to [seq_length, bsz]
-            padding_mask_reshaped = padding_mask.reshape(seq_length, bsz)
-            # Expand to match routing_map after reshape [seq_length, bsz * num_experts]
-            mask_expanded = padding_mask_reshaped.unsqueeze(-1).expand(-1, -1, self.config.num_moe_experts).reshape(seq_length, -1)
-            # Apply mask to routing_map for token counting
-            routing_map_masked = routing_map.reshape(seq_length, -1) & mask_expanded
-            tokens_per_expert = routing_map_masked.sum(dim=0)
-            # Count valid tokens per sequence
-            num_valid_tokens_per_seq = padding_mask_reshaped.sum(dim=0)  # [bsz]
-            total_num_tokens = num_valid_tokens_per_seq.sum() * self.tp_cp_group.size()
-        else:
-            tokens_per_expert = routing_map.reshape(seq_length, -1).sum(dim=0)
-            total_num_tokens = seq_length * self.tp_cp_group.size()
-            padding_mask_for_loss = None
+        # Use unified function to compute tokens_per_expert and num_tokens
+        tokens_per_expert, num_tokens = compute_tokens_per_expert_with_padding(
+            routing_map=routing_map,
+            padding_mask=padding_mask,
+            reshape_for_seq_aux=True,
+            seq_length=seq_length,
+            bsz=bsz,
+            num_experts=self.config.num_moe_experts,
+        )
+
+        total_num_tokens = num_tokens * self.tp_cp_group.size()

         tokens_per_expert = reduce_from_tensor_model_parallel_region(
             tokens_per_expert, self.tp_cp_group
@@ -365,8 +358,7 @@ def _apply_seq_aux_loss(
         )
         # Calculate valid token count: seq_length for each batch element
         if padding_mask is not None:
-            num_valid_tokens = padding_mask.sum()
-            valid_token_count = num_valid_tokens.item() if isinstance(num_valid_tokens, torch.Tensor) else num_valid_tokens
+            valid_token_count = padding_mask.sum()
         else:
             valid_token_count = seq_length * bsz

@@ -385,16 +377,12 @@ def _apply_global_aux_loss(
         if global_aux_loss_coeff == 0:
             return probs

-        if padding_mask is not None:
-            # routing_map: [num_tokens, num_experts], padding_mask: [num_tokens]
-            mask_expanded = padding_mask.unsqueeze(-1)
-            routing_map_masked = routing_map & mask_expanded
-            tokens_per_expert = routing_map_masked.sum(dim=0)
-            # Count valid tokens only
-            num_tokens = padding_mask.sum()
-        else:
-            tokens_per_expert = routing_map.sum(dim=0)
-            num_tokens = scores_for_aux_loss.shape[0]
+        # Use unified function to compute tokens_per_expert and num_tokens
+        tokens_per_expert, num_tokens = compute_tokens_per_expert_with_padding(
+            routing_map=routing_map,
+            padding_mask=padding_mask,
+            reshape_for_seq_aux=False,
+        )

         tokens_per_expert = reduce_from_tensor_model_parallel_region(
             tokens_per_expert, self.tp_dp_cp_group
@@ -422,7 +410,7 @@ def _apply_global_aux_loss(
             global_aux_loss,
             "global_load_balancing_loss",
             self.tp_dp_cp_group,
-            valid_token_count=num_tokens.item() if isinstance(num_tokens, torch.Tensor) else num_tokens,
+            valid_token_count=num_tokens,
         )
         return probs

@@ -433,7 +421,7 @@ def attach_and_log_load_balancing_loss(
         aux_loss: torch.Tensor,
         aux_loss_name: str,
         reduce_group: torch.distributed.ProcessGroup,
-        valid_token_count: Optional[int] = None,
+        valid_token_count: Optional[Union[int, torch.Tensor]] = None,
     ):
         """Attach aux loss function to activation and add to logging.

@@ -443,7 +431,8 @@ def attach_and_log_load_balancing_loss(
             aux_loss (torch.Tensor): Computed aux loss.
             aux_loss_name (str): Name of the aux loss for logging.
             reduce_group (torch.distributed.ProcessGroup): Process group for reduction.
-            valid_token_count (int, optional): Number of valid tokens excluding padding tokens.
+            valid_token_count (int or torch.Tensor, optional): Number of valid tokens excluding
+                padding tokens. Can be a Python int or a torch.Tensor (typically a 0-d tensor).
                 If None, uses activation.shape[0]. Defaults to None.
         """
         # TODO (zijiey): fix the per_layer_logging for MTP, currently it will incorrectly
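
The diff only shows the call sites of `compute_tokens_per_expert_with_padding`; the helper itself lives in `moe_utils` and its body is not part of these hunks. Below is a minimal sketch of what it plausibly does, reconstructed from the three removed branches above. The function name and keyword arguments are taken from the call sites, but the body is an assumption and the actual implementation may differ.

# Hypothetical reconstruction of the new moe_utils helper (not shown in this diff);
# it mirrors the padding-aware counting branches that the three call sites replaced.
from typing import Optional, Tuple, Union

import torch


def compute_tokens_per_expert_with_padding(
    routing_map: torch.Tensor,
    padding_mask: Optional[torch.Tensor] = None,
    reshape_for_seq_aux: bool = False,
    seq_length: Optional[int] = None,
    bsz: Optional[int] = None,
    num_experts: Optional[int] = None,
) -> Tuple[torch.Tensor, Union[int, torch.Tensor]]:
    """Count routed tokens per expert, optionally ignoring padded positions.

    Returns (tokens_per_expert, num_tokens); num_tokens is a plain int when no
    padding mask is given and a 0-d tensor otherwise.
    """
    if reshape_for_seq_aux:
        # Sequence-level aux loss: fold batch and experts together so that
        # routing_map has shape [seq_length, bsz * num_experts].
        routing_map = routing_map.reshape(seq_length, -1)
        if padding_mask is not None:
            # padding_mask: [seq_length * bsz] -> [seq_length, bsz * num_experts]
            mask = (
                padding_mask.reshape(seq_length, bsz)
                .unsqueeze(-1)
                .expand(-1, -1, num_experts)
                .reshape(seq_length, -1)
            )
            tokens_per_expert = (routing_map & mask).sum(dim=0)
            num_tokens = padding_mask.sum()
        else:
            tokens_per_expert = routing_map.sum(dim=0)
            num_tokens = seq_length
    else:
        # Token-level aux loss: routing_map is [num_tokens, num_experts],
        # padding_mask (if any) is [num_tokens].
        if padding_mask is not None:
            tokens_per_expert = (routing_map & padding_mask.unsqueeze(-1)).sum(dim=0)
            num_tokens = padding_mask.sum()
        else:
            tokens_per_expert = routing_map.sum(dim=0)
            num_tokens = routing_map.shape[0]
    return tokens_per_expert, num_tokens

Returning the count as a tensor, instead of calling .item() at every call site as the removed code did, also avoids a host-device synchronization per router invocation, which is presumably why valid_token_count now accepts a torch.Tensor as well as an int.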