2 changes: 2 additions & 0 deletions megatron/core/models/gpt/gpt_model.py
@@ -445,6 +445,7 @@ def forward(
*,
inference_params: Optional[BaseInferenceContext] = None,
loss_mask: Optional[Tensor] = None,
padding_mask: Optional[Tensor] = None,
Comment on lines 445 to +448
Contributor

TODO: add comments about the meaning of values in padding_mask
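One possible resolution of this TODO, phrased to match the mask convention documented for the MoE layer later in this PR (True = valid token, False = padding); the wording and shape description below are a suggestion, not part of the diff:

```python
# Suggested Args entry for the GPTModel.forward docstring (hypothetical wording):
#     padding_mask (Tensor, optional): Boolean mask marking real vs. padded positions.
#         True means the position holds a valid (non-padding) token, False means padding.
#         When given, it is forwarded to the decoder so MoE routing and auxiliary losses
#         can exclude padded tokens. Defaults to None.
```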

) -> Tensor:
"""Forward function of the GPT Model This function passes the input tensors
through the embedding layer, and then the decoder and finally into the post
@@ -486,6 +487,7 @@ def forward(
rotary_pos_cos_sin=rotary_pos_cos_sin,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
padding_mask=padding_mask,
**(extra_block_kwargs or {}),
)
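For orientation, a minimal sketch of how a caller might build the mask this new argument expects, assuming the [seq_length, bsz] layout documented for the MoE layer below; the construction and the surrounding call are illustrative, only `padding_mask` itself is part of this diff:

```python
import torch

# Hypothetical batch: 3 samples padded to a common length of 8 tokens.
# Convention used throughout this PR: True = real token, False = padding.
seq_length, lengths = 8, torch.tensor([8, 5, 2])
padding_mask = torch.arange(seq_length).unsqueeze(1) < lengths.unsqueeze(0)
print(padding_mask.shape)  # torch.Size([8, 3]), dtype torch.bool

# The mask is then passed via the new keyword-only argument (other argument names are
# the usual GPTModel.forward parameters, not part of this diff):
# output = gpt_model(input_ids, position_ids, attention_mask, labels=labels,
#                    padding_mask=padding_mask)
```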

26 changes: 19 additions & 7 deletions megatron/core/transformer/moe/moe_layer.py
@@ -178,13 +178,19 @@ def __init__(
self.cudagraph_tensor_store = MoECudaGraphTensorStore()

@maybe_skip_or_early_return_by_cudagraph("route")
def route(self, hidden_states: torch.Tensor):
def route(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
"""Compute token routing for preprocessing.

This method uses the router to determine which experts to send each token to,
producing routing probabilities and a mapping.

Args:
hidden_states (torch.Tensor): Input tensor.
padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
Shape = [seq_length, bsz]. True for valid tokens,
False for padding tokens. Defaults to None.
"""
probs, routing_map = self.router(hidden_states)
probs, routing_map = self.router(hidden_states, padding_mask=padding_mask)
return probs, routing_map

@maybe_skip_or_early_return_by_cudagraph("preprocess")
@@ -270,7 +276,7 @@ def combine(self, output: torch.Tensor, shared_expert_output: Optional[torch.Ten
output = output + shared_expert_output
return output

def forward(self, hidden_states: torch.Tensor):
def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
"""Forward pass for the MoE layer.

The forward pass comprises four main steps:
@@ -281,6 +287,9 @@ def forward(self, hidden_states: torch.Tensor):

Args:
hidden_states (torch.Tensor): The input tensor to the MoE layer.
padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
Shape = [seq_length, bsz]. True for valid tokens,
False for padding tokens. Defaults to None.

Returns:
A tuple containing the output tensor and the MLP bias, if any.
@@ -292,10 +301,10 @@ def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tens
)

# MoE forward: route -> dispatch -> compute -> combine
def custom_forward(hidden_states):
def custom_forward(hidden_states, padding_mask=None):
try:
shared_expert_output = self.shared_experts_compute(hidden_states)
probs, routing_map = self.route(hidden_states)
probs, routing_map = self.route(hidden_states, padding_mask=padding_mask)
hidden_states, probs, residual = self.preprocess(hidden_states, probs, routing_map)
except MoECudaGraphPartialCaptureSignal as e:
# This signal is raised from the maybe_skip_or_early_return_by_cudagraph decorator.
@@ -318,11 +327,14 @@ def custom_forward(hidden_states):
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
hidden_states,
padding_mask,
)
else:
outputs = tensor_parallel.checkpoint(custom_forward, False, hidden_states)
outputs = tensor_parallel.checkpoint(
custom_forward, False, hidden_states, padding_mask
)
else:
outputs = custom_forward(hidden_states)
outputs = custom_forward(hidden_states, padding_mask)

return outputs
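A shape note that may help when reading the aux-loss code in moe_utils.py below: the layer-level mask is documented as [seq_length, bsz], while moe_utils.py works on a flattened [num_tokens] view. A small self-contained sketch of that correspondence (shapes are illustrative, and the flattening order is assumed from the reshape used in compute_tokens_per_expert_with_padding):

```python
import torch

seq_length, bsz, hidden = 4, 2, 8
hidden_states = torch.randn(seq_length, bsz, hidden)
padding_mask = torch.tensor([[True, True],
                             [True, True],
                             [True, False],   # sample 1 is padded from position 2 onward
                             [True, False]])  # [seq_length, bsz]

# Flattening both the same way keeps token i of the flattened hidden states aligned
# with element i of the flattened mask, which is the [num_tokens] view moe_utils.py uses.
tokens = hidden_states.reshape(-1, hidden)   # [num_tokens, hidden]
flat_mask = padding_mask.reshape(-1)         # [num_tokens]
assert tokens.shape[0] == flat_mask.shape[0]
print(flat_mask.tolist())  # [True, True, True, True, True, False, True, False]
```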

116 changes: 111 additions & 5 deletions megatron/core/transformer/moe/moe_utils.py
@@ -34,6 +34,72 @@
_MOE_LAYER_WISE_LOGGING_TRACKER = {}


def compute_tokens_per_expert_with_padding(
routing_map: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
reshape_for_seq_aux: bool = False,
seq_length: Optional[int] = None,
bsz: Optional[int] = None,
num_experts: Optional[int] = None,
):
"""Compute tokens_per_expert and total_num_tokens with optional padding mask.
This function provides a unified way to compute token counts across different aux loss types.

Args:
routing_map (torch.Tensor): Token to expert routing map.
- For aux_loss/global_aux_loss: shape [num_tokens, num_experts]
- For seq_aux_loss: shape [num_tokens, num_experts] but will be reshaped
padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
Shape [num_tokens]. True for valid tokens, False for padding tokens.
reshape_for_seq_aux (bool): If True, reshape routing_map for seq_aux_loss computation.
seq_length (int, optional): Sequence length, required when reshape_for_seq_aux=True.
bsz (int, optional): Batch size, required when reshape_for_seq_aux=True.
num_experts (int, optional): Number of experts, required when reshape_for_seq_aux=True.

Returns:
tuple: (tokens_per_expert, num_valid_tokens)
- tokens_per_expert (torch.Tensor): Number of tokens per expert, shape [num_experts]
or [bsz * num_experts] for seq_aux_loss
- num_valid_tokens (torch.Tensor or int): Number of valid (non-padding) tokens
"""
if reshape_for_seq_aux:
# seq aux loss
assert (
seq_length is not None and bsz is not None and num_experts is not None
), "seq_length, bsz, and num_experts must be provided when reshape_for_seq_aux=True"

if padding_mask is not None:
# Reshape padding_mask to [seq_length, bsz]
padding_mask_reshaped = padding_mask.reshape(seq_length, bsz)
# Expand to match routing_map after reshape [seq_length, bsz * num_experts]
mask_expanded = (
padding_mask_reshaped.unsqueeze(-1)
.expand(-1, -1, num_experts)
.reshape(seq_length, -1)
)
routing_map_masked = routing_map.reshape(seq_length, -1) & mask_expanded
tokens_per_expert = routing_map_masked.sum(dim=0)
# Count valid tokens only
num_valid_tokens = padding_mask.sum()
else:
tokens_per_expert = routing_map.reshape(seq_length, -1).sum(dim=0)
num_valid_tokens = routing_map.shape[0]
else:
# aux_loss or global_aux_loss
if padding_mask is not None:
# routing_map: [num_tokens, num_experts], padding_mask: [num_tokens]
mask_expanded = padding_mask.unsqueeze(-1)
routing_map_masked = routing_map & mask_expanded
tokens_per_expert = routing_map_masked.sum(dim=0)
# Count valid tokens only
num_valid_tokens = padding_mask.sum()
else:
tokens_per_expert = routing_map.sum(dim=0)
num_valid_tokens = routing_map.shape[0]

return tokens_per_expert, num_valid_tokens
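To make the counting concrete, here is a small, self-contained example that mirrors the non-reshape branch above with plain tensor ops (it does not import the new function, and the numbers are purely illustrative):

```python
import torch

# 5 tokens routed over 3 experts; the last two tokens are padding.
routing_map = torch.tensor([[1, 0, 1],
                            [0, 1, 0],
                            [1, 1, 0],
                            [1, 0, 0],
                            [0, 0, 1]], dtype=torch.bool)      # [num_tokens, num_experts]
padding_mask = torch.tensor([True, True, True, False, False])  # [num_tokens]

masked = routing_map & padding_mask.unsqueeze(-1)  # zero out rows of padded tokens
tokens_per_expert = masked.sum(dim=0)
num_valid_tokens = padding_mask.sum()

print(tokens_per_expert.tolist())  # [2, 2, 1] -- padded tokens no longer inflate expert counts
print(int(num_valid_tokens))       # 3
```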


def switch_load_balancing_loss_func(
probs: torch.Tensor,
tokens_per_expert: torch.Tensor,
@@ -42,6 +108,7 @@ def switch_load_balancing_loss_func(
num_experts: int,
moe_aux_loss_coeff: float,
fused: bool = False,
padding_mask: Optional[torch.Tensor] = None,
):
"""Calculate the auxiliary loss for load balancing.
Refer to the Switch Transformer (https://arxiv.org/abs/2101.03961)
@@ -92,9 +159,18 @@ def switch_load_balancing_loss_func(
topk (int): The number of experts selected for each token.
num_experts (int): The number of experts.
moe_aux_loss_coeff (float): The coefficient for the auxiliary loss.
padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
Shape: [num_tokens]. True for valid tokens,
False for padding tokens. Defaults to None.
Comment on lines +162 to +164
Contributor

About the convention of using True or False to mark the padded token,
could you take a look to other frameworks like PyTorch and transformers for the typical choice?

I see the attention mask use False for valid attention in TransformerEngine
https://github.com/NVIDIA/TransformerEngine/tree/main?tab=readme-ov-file#v17-padding-mask-definition-for-pytorch
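For reference while this convention question is open: if the mask were later flipped to match TransformerEngine's padding-mask definition (True = padded position, per the linked README), callers of the API in this PR would only need a single inversion. A hedged sketch:

```python
import torch

# TE-style attention padding mask: True marks a padded position (per the linked README).
te_padding_mask = torch.tensor([False, False, True, True])

# Convention used in this PR's MoE changes: True marks a valid (non-padding) token.
moe_padding_mask = ~te_padding_mask
print(moe_padding_mask.tolist())  # [True, True, False, False]
```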

Returns:
torch.Tensor: The auxiliary loss for load balancing.
"""
# Apply padding mask to probs if provided
if padding_mask is not None:
# padding_mask: [num_tokens], probs: [num_tokens, num_experts]
mask_expanded = padding_mask.unsqueeze(-1)
probs = probs * mask_expanded

if fused:
if not HAVE_TE or fused_moe_aux_loss is None:
raise ValueError("fused_moe_aux_loss is not available. Please install TE >= 2.7.0.")
@@ -114,18 +190,32 @@ def switch_load_balancing_loss_func(
return aux_loss


def z_loss_func(logits, z_loss_coeff):
def z_loss_func(logits, z_loss_coeff, padding_mask: Optional[torch.Tensor] = None):
"""Encourages the router's logits to remain small to enhance stability.
Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.

Args:
logits (torch.Tensor): The logits of the router.
z_loss_coeff (float): The coefficient for the z-loss.
padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
Shape: [num_tokens]. True for valid tokens,
False for padding tokens. Defaults to None.

Returns:
torch.Tensor: The z-loss value.
"""
logsum = torch.logsumexp(logits, dim=-1)
z_loss_values = torch.square(logsum)

if padding_mask is not None:
# Only compute z_loss for non-padding tokens
z_loss_values = z_loss_values * padding_mask
# Compute mean over valid tokens only
num_valid_tokens = padding_mask.sum()
z_loss = z_loss_values.sum() / torch.clamp(num_valid_tokens, min=1.0) * z_loss_coeff
else:
z_loss = torch.mean(z_loss_values) * z_loss_coeff

z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff
return z_loss
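The change to z_loss_func amounts to a masked mean over valid tokens; a small self-contained check of that behavior (values illustrative, written with plain tensor ops rather than the function itself):

```python
import torch

logits = torch.tensor([[2.0, 1.0], [0.5, 0.5], [3.0, -1.0]])  # [num_tokens, num_experts]
padding_mask = torch.tensor([True, True, False])               # last token is padding
z_loss_coeff = 1e-3

z = torch.square(torch.logsumexp(logits, dim=-1))
num_valid = padding_mask.sum().clamp(min=1)
masked_mean = (z * padding_mask).sum() / num_valid
unmasked_mean = z.mean()

# The masked version averages only over the two valid tokens, so the padded token's
# (large) logits no longer contribute to the loss.
print(masked_mean * z_loss_coeff, unmasked_mean * z_loss_coeff)
```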


@@ -623,21 +713,32 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):


def compute_routing_scores_for_aux_loss(
logits: torch.Tensor, topk: int, score_function: str, fused: bool = False
logits: torch.Tensor,
topk: int,
score_function: str,
fused: bool = False,
padding_mask: Optional[torch.Tensor] = None,
):
"""Compute routing scores based on the score function.

Args:
logits (torch.Tensor): The logits tensor after gating, shape: [num_tokens, num_experts].

padding_mask (torch.Tensor, optional): Boolean mask indicating non-padding tokens.
Shape: [num_tokens]. True for valid tokens,
False for padding tokens. Defaults to None.
Returns:
torch.Tensor: The normalized routing scores.
Tuple[torch.Tensor, torch.Tensor]: routing_map and scores.
"""
if fused:
if not HAVE_TE or fused_compute_score_for_moe_aux_loss is None:
raise ValueError(
"fused_compute_score_for_moe_aux_loss is not available. Please install TE >= 2.6.0."
)
# Note: fused implementation does not support padding_mask yet
if padding_mask is not None:
raise ValueError(
"Fused compute_routing_scores does not support padding_mask. Set fused=False."
)
return fused_compute_score_for_moe_aux_loss(
logits=logits, topk=topk, score_function=score_function
)
@@ -652,6 +753,11 @@ def compute_routing_scores_for_aux_loss(

_, top_indices = torch.topk(scores, k=topk, dim=1)
routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()

# Apply padding mask to scores if provided
if padding_mask is not None:
mask_expanded = padding_mask.unsqueeze(-1).to(scores.dtype)
scores = scores * mask_expanded
return routing_map, scores
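For the non-fused path, the new handling only zeroes the score rows of padded tokens after top-k selection; a toy run of the same steps (softmax score function assumed, values illustrative):

```python
import torch

logits = torch.tensor([[1.0, 0.0, -1.0],
                       [0.2, 0.4, 0.1],
                       [5.0, 5.0, 5.0]])          # [num_tokens, num_experts]
padding_mask = torch.tensor([True, True, False])  # third token is padding
topk = 2

scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
_, top_indices = torch.topk(scores, k=topk, dim=1)
routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()

# Padded tokens keep their routing_map rows here; only their scores are zeroed, so the
# aux-loss terms computed from these scores downstream ignore them.
scores = scores * padding_mask.unsqueeze(-1).to(scores.dtype)
print(scores[-1].tolist())  # [0.0, 0.0, 0.0]
```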

