[Dev] Remove calculation of padding token in moe routing loss #2121
base: dev
Changes from all commits: c49ffc1, 7a34303, 5781e60, 7e4a3f2, f9d9c15, 2903bf9, f1b4e84
```diff
@@ -34,6 +34,72 @@
 _MOE_LAYER_WISE_LOGGING_TRACKER = {}


+def compute_tokens_per_expert_with_padding(
+    routing_map: torch.Tensor,
+    padding_mask: Optional[torch.Tensor] = None,
+    reshape_for_seq_aux: bool = False,
+    seq_length: Optional[int] = None,
+    bsz: Optional[int] = None,
+    num_experts: Optional[int] = None,
+):
+    """Compute tokens_per_expert and total_num_tokens with optional padding mask.
+
+    This function provides a unified way to compute token counts across different aux loss types.
+
+    Args:
+        routing_map (torch.Tensor): Token to expert routing map.
+            - For aux_loss/global_aux_loss: shape [num_tokens, num_experts]
+            - For seq_aux_loss: shape [num_tokens, num_experts] but will be reshaped
+        padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
+            Shape [num_tokens]. True for valid tokens, False for padding tokens.
+        reshape_for_seq_aux (bool): If True, reshape routing_map for seq_aux_loss computation.
+        seq_length (int, optional): Sequence length, required when reshape_for_seq_aux=True.
+        bsz (int, optional): Batch size, required when reshape_for_seq_aux=True.
+        num_experts (int, optional): Number of experts, required when reshape_for_seq_aux=True.
+
+    Returns:
+        tuple: (tokens_per_expert, num_valid_tokens)
+            - tokens_per_expert (torch.Tensor): Number of tokens per expert, shape [num_experts]
+              or [bsz * num_experts] for seq_aux_loss
+            - num_valid_tokens (torch.Tensor or int): Number of valid (non-padding) tokens
+    """
+    if reshape_for_seq_aux:
+        # seq aux loss
+        assert (
+            seq_length is not None and bsz is not None and num_experts is not None
+        ), "seq_length, bsz, and num_experts must be provided when reshape_for_seq_aux=True"
+
+        if padding_mask is not None:
+            # Reshape padding_mask to [seq_length, bsz]
+            padding_mask_reshaped = padding_mask.reshape(seq_length, bsz)
+            # Expand to match routing_map after reshape [seq_length, bsz * num_experts]
+            mask_expanded = (
+                padding_mask_reshaped.unsqueeze(-1)
+                .expand(-1, -1, num_experts)
+                .reshape(seq_length, -1)
+            )
+            routing_map_masked = routing_map.reshape(seq_length, -1) & mask_expanded
+            tokens_per_expert = routing_map_masked.sum(dim=0)
+            # Count valid tokens only
+            num_valid_tokens = padding_mask.sum()
+        else:
+            tokens_per_expert = routing_map.reshape(seq_length, -1).sum(dim=0)
+            num_valid_tokens = routing_map.shape[0]
+    else:
+        # aux_loss or global_aux_loss
+        if padding_mask is not None:
+            # routing_map: [num_tokens, num_experts], padding_mask: [num_tokens]
+            mask_expanded = padding_mask.unsqueeze(-1)
+            routing_map_masked = routing_map & mask_expanded
+            tokens_per_expert = routing_map_masked.sum(dim=0)
+            # Count valid tokens only
+            num_valid_tokens = padding_mask.sum()
+        else:
+            tokens_per_expert = routing_map.sum(dim=0)
+            num_valid_tokens = routing_map.shape[0]
+
+    return tokens_per_expert, num_valid_tokens
+
+
 def switch_load_balancing_loss_func(
     probs: torch.Tensor,
     tokens_per_expert: torch.Tensor,
```
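For reference, a minimal usage sketch of the new helper (values are invented for illustration; it assumes `torch` and the function above are importable in the same module):

```python
import torch

num_tokens, num_experts = 6, 4
routing_map = torch.zeros(num_tokens, num_experts, dtype=torch.bool)
routing_map[torch.arange(num_tokens), torch.tensor([0, 1, 2, 3, 0, 1])] = True
padding_mask = torch.tensor([True, True, True, True, False, False])  # last 2 padded

# aux_loss / global_aux_loss path: per-expert counts over valid tokens only
tokens_per_expert, num_valid = compute_tokens_per_expert_with_padding(
    routing_map, padding_mask=padding_mask
)
# tokens_per_expert -> tensor([1, 1, 1, 1]); num_valid -> tensor(4)

# seq_aux_loss path: routing_map is reshaped to [seq_length, bsz * num_experts]
tokens_per_expert_seq, _ = compute_tokens_per_expert_with_padding(
    routing_map,
    padding_mask=padding_mask,
    reshape_for_seq_aux=True,
    seq_length=3,
    bsz=2,
    num_experts=num_experts,
)
```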
```diff
@@ -42,6 +108,7 @@ def switch_load_balancing_loss_func(
     num_experts: int,
     moe_aux_loss_coeff: float,
     fused: bool = False,
+    padding_mask: Optional[torch.Tensor] = None,
 ):
     """Calculate the auxiliary loss for load balancing.
     Refer to the Switch Transformer (https://arxiv.org/abs/2101.03961)
```
```diff
@@ -92,9 +159,18 @@ def switch_load_balancing_loss_func(
         topk (int): The number of experts selected for each token.
         num_experts (int): The number of experts.
         moe_aux_loss_coeff (float): The coefficient for the auxiliary loss.
+        padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
+            Shape [num_tokens]. True for valid tokens,
+            False for padding tokens. Defaults to None.
```
Comment on lines +162 to +164

Contributor: About the convention of using True or False to mark the padded token: I see the attention mask uses False for valid attention in TransformerEngine.
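To make the reviewer's point concrete, a one-line sketch (hypothetical caller-side code, not part of this PR): a TransformerEngine-style attention mask marks valid positions with False, so it would need to be inverted before being passed as `padding_mask` here.

```python
import torch

te_style_mask = torch.tensor([False, False, False, True])  # TE convention: True = masked/padded
padding_mask = ~te_style_mask                              # PR convention: True = valid token
```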
```diff
     Returns:
         torch.Tensor: The auxiliary loss for load balancing.
     """
+    # Apply padding mask to probs if provided
+    if padding_mask is not None:
+        # padding_mask: [num_tokens], probs: [num_tokens, num_experts]
+        mask_expanded = padding_mask.unsqueeze(-1)
+        probs = probs * mask_expanded
+
     if fused:
         if not HAVE_TE or fused_moe_aux_loss is None:
             raise ValueError("fused_moe_aux_loss is not available. Please install TE >= 2.7.0.")
```
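Zeroing the padded rows of `probs` means both statistics in the load-balancing loss are taken over valid tokens only. A minimal sketch of how they combine, assuming the standard Switch formula `loss = num_experts * sum_i(f_i * P_i) * coeff` (the exact normalization inside `switch_load_balancing_loss_func` may differ):

```python
import torch

num_tokens, num_experts, topk, coeff = 6, 4, 2, 1e-2
probs = torch.softmax(torch.randn(num_tokens, num_experts), dim=-1)
routing_map = torch.zeros(num_tokens, num_experts, dtype=torch.bool)
routing_map.scatter_(1, probs.topk(topk, dim=1).indices, True)
padding_mask = torch.tensor([True, True, True, True, False, False])

tokens_per_expert, num_valid = compute_tokens_per_expert_with_padding(
    routing_map, padding_mask=padding_mask
)
masked_probs = probs * padding_mask.unsqueeze(-1)           # zero out padded rows
fraction = tokens_per_expert.float() / (num_valid * topk)   # f_i over valid tokens
mean_prob = masked_probs.sum(dim=0) / num_valid             # P_i over valid tokens
aux_loss = num_experts * torch.sum(fraction * mean_prob) * coeff
```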
```diff
@@ -114,18 +190,32 @@ def switch_load_balancing_loss_func(
     return aux_loss


-def z_loss_func(logits, z_loss_coeff):
+def z_loss_func(logits, z_loss_coeff, padding_mask: Optional[torch.Tensor] = None):
     """Encourages the router's logits to remain small to enhance stability.
     Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.

     Args:
         logits (torch.Tensor): The logits of the router.
         z_loss_coeff (float): The coefficient for the z-loss.
+        padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
+            Shape [num_tokens]. True for valid tokens,
+            False for padding tokens. Defaults to None.

     Returns:
         torch.Tensor: The logits after applying the z-loss.
     """
-    z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff
+    logsum = torch.logsumexp(logits, dim=-1)
+    z_loss_values = torch.square(logsum)
+
+    if padding_mask is not None:
+        # Only compute z_loss for non-padding tokens
+        z_loss_values = z_loss_values * padding_mask
+        # Compute mean over valid tokens only
+        num_valid_tokens = padding_mask.sum()
+        z_loss = z_loss_values.sum() / torch.clamp(num_valid_tokens, min=1.0) * z_loss_coeff
+    else:
+        z_loss = torch.mean(z_loss_values) * z_loss_coeff
+
     return z_loss
```
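A quick sanity check of the masked branch (hypothetical shapes and coefficient): the result should equal the plain mean of the squared logsumexp over only the valid rows.

```python
import torch

logits = torch.randn(6, 4)                     # [num_tokens, num_experts]
padding_mask = torch.tensor([True] * 4 + [False] * 2)
loss = z_loss_func(logits, z_loss_coeff=1e-3, padding_mask=padding_mask)
manual = torch.square(torch.logsumexp(logits[:4], dim=-1)).mean() * 1e-3
assert torch.allclose(loss, manual)
```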
```diff
@@ -623,21 +713,32 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):


 def compute_routing_scores_for_aux_loss(
-    logits: torch.Tensor, topk: int, score_function: str, fused: bool = False
+    logits: torch.Tensor,
+    topk: int,
+    score_function: str,
+    fused: bool = False,
+    padding_mask: Optional[torch.Tensor] = None,
 ):
     """Compute routing scores based on the score function.

     Args:
         logits (torch.Tensor): The logits tensor after gating, shape: [num_tokens, num_experts].
+        padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
+            Shape [num_tokens]. True for valid tokens,
+            False for padding tokens. Defaults to None.

     Returns:
-        torch.Tensor: The normalized routing scores.
+        Tuple[torch.Tensor, torch.Tensor]: routing_map and scores.
     """
     if fused:
         if not HAVE_TE or fused_compute_score_for_moe_aux_loss is None:
             raise ValueError(
                 "fused_compute_score_for_moe_aux_loss is not available. Please install TE >= 2.6.0."
             )
+        # Note: fused implementation does not support padding_mask yet
+        if padding_mask is not None:
+            raise ValueError(
+                "Fused compute_routing_scores does not support padding_mask. Set fused=False."
+            )
         return fused_compute_score_for_moe_aux_loss(
             logits=logits, topk=topk, score_function=score_function
         )
```
```diff
@@ -652,6 +753,11 @@ def compute_routing_scores_for_aux_loss(
     _, top_indices = torch.topk(scores, k=topk, dim=1)
     routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()

+    # Apply padding mask to scores if provided
+    if padding_mask is not None:
+        mask_expanded = padding_mask.unsqueeze(-1).to(scores.dtype)
+        scores = scores * mask_expanded
     return routing_map, scores
```
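Putting the routing side together, a sketch of the unfused path (hypothetical values; `score_function="softmax"` is assumed to be a supported value):

```python
import torch

logits = torch.randn(6, 4)                     # [num_tokens, num_experts]
padding_mask = torch.tensor([True] * 4 + [False] * 2)

routing_map, scores = compute_routing_scores_for_aux_loss(
    logits, topk=2, score_function="softmax", padding_mask=padding_mask
)
# scores has zeroed rows for the two padded tokens; routing_map still marks
# their top-k picks, so padding_mask must also be passed to the downstream
# token-count and loss functions.
tokens_per_expert, num_valid = compute_tokens_per_expert_with_padding(
    routing_map, padding_mask=padding_mask
)
```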
TODO: add comments about the meaning of values in padding_mask