
Commit 1bb6f7a

remove redundant type convert
1 parent d9f644d commit 1bb6f7a

File tree

1 file changed: +2 −2 lines changed


megatron/core/transformer/moe/moe_utils.py

Lines changed: 2 additions & 2 deletions
@@ -116,7 +116,7 @@ def switch_load_balancing_loss_func(
     # Apply padding mask to probs if provided
     if padding_mask is not None:
         # padding_mask: [num_tokens], probs: [num_tokens, num_experts]
-        mask_expanded = padding_mask.unsqueeze(-1).to(probs.dtype)
+        mask_expanded = padding_mask.unsqueeze(-1)
         probs = probs * mask_expanded
 
     aggregated_probs_per_expert = probs.sum(dim=0)
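
Presumably the dropped cast is redundant because PyTorch's type promotion already upcasts a boolean or integer mask when it is multiplied with a floating-point tensor; the diff itself gives no rationale beyond the commit message. A minimal sketch (not from the repository) assuming an illustrative bool padding_mask and float32 probs:

    import torch

    probs = torch.rand(4, 8)                                # [num_tokens, num_experts], float32
    padding_mask = torch.tensor([True, True, False, True])  # [num_tokens], bool dtype assumed

    masked_old = probs * padding_mask.unsqueeze(-1).to(probs.dtype)  # pre-commit form
    masked_new = probs * padding_mask.unsqueeze(-1)                  # post-commit form

    assert masked_new.dtype == probs.dtype      # promotion keeps float32 either way
    assert torch.equal(masked_old, masked_new)  # identical values, so the explicit cast adds nothing
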
@@ -145,7 +145,7 @@ def z_loss_func(logits, z_loss_coeff, padding_mask: Optional[torch.Tensor] = None):
 
     if padding_mask is not None:
         # Only compute z_loss for non-padding tokens
-        z_loss_values = z_loss_values * padding_mask.to(z_loss_values.dtype)
+        z_loss_values = z_loss_values * padding_mask
         # Compute mean over valid tokens only
         num_valid_tokens = padding_mask.sum()
         z_loss = z_loss_values.sum() / torch.clamp(num_valid_tokens, min=1.0) * z_loss_coeff
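
For context, the surrounding lines implement a masked mean: padded tokens are zeroed out, and the sum is divided by the number of valid tokens, clamped so an all-padding batch never divides by zero. A small worked example with made-up values (the coefficient, per-token losses, and float 0/1 mask below are illustrative, not taken from the repository):

    import torch

    z_loss_coeff = 1e-3                                    # illustrative coefficient
    z_loss_values = torch.tensor([0.5, 0.2, 0.9, 0.4])     # per-token z-loss, made-up values
    padding_mask = torch.tensor([1.0, 1.0, 0.0, 1.0])      # 1 = real token, 0 = padding

    z_loss_values = z_loss_values * padding_mask           # token 2 is zeroed out
    num_valid_tokens = padding_mask.sum()                  # 3 valid tokens
    z_loss = z_loss_values.sum() / torch.clamp(num_valid_tokens, min=1.0) * z_loss_coeff
    # (0.5 + 0.2 + 0.4) / 3 * 1e-3 ≈ 3.67e-4
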
