14 changes: 12 additions & 2 deletions megatron/core/models/gpt/gpt_model.py
@@ -575,9 +575,19 @@ def _postprocess(
                     runtime_gather_output=runtime_gather_output,
                 )
                 # Calc loss for the current Multi-Token Prediction (MTP) layers.
-                mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)
+                mtp_labels, _ = roll_tensor(
+                    mtp_labels,
+                    shifts=-1,
+                    dims=-1,
+                    cp_group=self.cp_group,
+                    packed_seq_params=packed_seq_params,
+                )
                 loss_mask, num_tokens = roll_tensor(
-                    loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group
+                    loss_mask,
+                    shifts=-1,
+                    dims=-1,
+                    cp_group=self.cp_group,
+                    packed_seq_params=packed_seq_params,
                 )
                 mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
                 mtp_loss = loss_mask * mtp_loss
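To see what the rolled labels and rolled loss mask buy us, here is a small standalone sketch in plain PyTorch (no context parallelism, no packing); the per-token cross entropy is only a stand-in for `compute_language_model_loss`, and every tensor value is made up.

```python
import torch

# Hypothetical 1-sample, 6-token batch. Rolling labels by -1 makes position t
# predict token t+1; rolling the loss mask the same way drops the final
# position, which has no valid next-token target.
labels = torch.tensor([[101, 102, 103, 104, 105, 106]])
loss_mask = torch.ones_like(labels, dtype=torch.float)

mtp_labels = torch.roll(labels, shifts=-1, dims=-1)
mtp_labels[..., -1] = 0                           # vacated slot, masked out below
mtp_mask = torch.roll(loss_mask, shifts=-1, dims=-1)
mtp_mask[..., -1] = 0
num_tokens = mtp_mask.sum()                       # 5 positions still contribute

# Stand-in for compute_language_model_loss: per-token cross entropy.
logits = torch.randn(1, 6, 200)                   # hypothetical vocab size of 200
per_token_loss = torch.nn.functional.cross_entropy(
    logits.transpose(1, 2), mtp_labels, reduction="none"
)
mtp_loss = (mtp_mask * per_token_loss).sum() / num_tokens
```

With packed sequences, the same left shift must additionally stop at every cu_seqlens boundary, which is what the new `packed_seq_params` argument threaded into `roll_tensor` provides.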
118 changes: 112 additions & 6 deletions megatron/core/transformer/multi_token_prediction.py
@@ -126,7 +126,7 @@ def tie_output_layer_state_dict(
     )


-def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None):
+def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None, packed_seq_params=None):
     """Roll the tensor input along the sequence dimension with Context Parallelism (CP) support.

     This function extends the original roll_tensor to support Context Parallelism, which allows
@@ -138,15 +138,24 @@ def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None):
     For CP>1: Splits tensor into chunks, performs rolling within each chunk, then exchanges
     boundary elements between adjacent CP ranks to maintain sequence continuity.

+    For packed sequences: Respects sequence boundaries when rolling to avoid mixing tokens
+    from different sequences.
+
     Args:
         tensor (Tensor): The input tensor to roll.
         shifts (int): The shift of the tensor (typically -1 for MTP).
         dims (int): The dimension to roll (typically -1 for sequence dimension).
         cp_group (ProcessGroup): The context parallelism process group. If None or size=1,
             falls back to standard rolling behavior.
+        packed_seq_params (PackedSeqParams): Parameters for packed sequence processing.
+            If provided, sequence boundaries are respected.
     Returns:
         tuple: (rolled_tensor, sum_of_rolled_tensor)
     """
+    # Handle the packed-sequence case
+    if packed_seq_params is not None:
+        return _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group)
+
     # Standard rolling behavior when CP is not enabled (cp_group is None or size=1)
     if cp_group is None or cp_group.size() == 1:
         rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims)
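The CP behaviour described in the docstring above (roll each chunk locally, then patch each chunk's last slot with the first element of the globally next chunk) can be sanity-checked in a single process. The sketch below is not the library code: it assumes the load-balanced CP layout in which the sequence is split into 2 * cp_size chunks and rank r owns chunks r and 2 * cp_size - 1 - r, and it replaces the isend/irecv pairs with plain assignments.

```python
import torch


def simulate_cp_roll(global_seq, cp_size):
    """Single-process sketch of the chunked roll plus boundary exchange.

    Layout assumption (load-balanced CP): the sequence is split into
    2 * cp_size chunks and rank r owns chunks [r, 2 * cp_size - 1 - r].
    """
    num_chunks = 2 * cp_size
    assert global_seq.shape[-1] % num_chunks == 0
    chunks = list(global_seq.chunk(num_chunks, dim=-1))

    # What each rank holds locally before rolling: a front chunk and a mirrored tail chunk.
    local = {r: [chunks[r].clone(), chunks[num_chunks - 1 - r].clone()] for r in range(cp_size)}

    # Every rank rolls each of its chunks by -1. The element that wraps into the last
    # slot is the chunk's old first element, which is exactly the value the previous
    # global chunk needs for its own boundary.
    send = {r: [c[..., 0].clone() for c in local[r]] for r in range(cp_size)}
    rolled = {r: [torch.roll(c, -1, dims=-1) for c in local[r]] for r in range(cp_size)}

    # Boundary exchange (stands in for the isend/irecv pairs in the patch).
    for r in range(cp_size):
        # Global chunk r is followed by chunk r + 1, owned by rank r + 1 as its front
        # chunk; the last rank's front chunk is followed by its own tail chunk.
        if r != cp_size - 1:
            rolled[r][0][..., -1] = send[r + 1][0]
        else:
            rolled[r][0][..., -1] = send[r][1]
        # Global chunk 2*cp_size-1-r is followed by chunk 2*cp_size-r, owned by rank
        # r - 1 as its tail chunk; rank 0's tail chunk is globally last, so its
        # vacated slot is zero-filled.
        if r != 0:
            rolled[r][1][..., -1] = send[r - 1][1]
        else:
            rolled[r][1][..., -1] = 0

    # Reassemble in global chunk order to check the result.
    out = [None] * num_chunks
    for r in range(cp_size):
        out[r], out[num_chunks - 1 - r] = rolled[r]
    return torch.cat(out, dim=-1)


seq = torch.arange(1, 17)            # hypothetical labels, S=16, cp_size=2
reference = torch.roll(seq, -1)
reference[-1] = 0                    # global left shift with zero fill
assert torch.equal(simulate_cp_roll(seq, cp_size=2), reference)
```

Reassembling the per-rank chunks in global order reproduces a plain left shift with a zero in the final slot, which is what the boundary exchange is there to guarantee.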
@@ -215,6 +224,91 @@ def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None):
     return rolled_tensor, rolled_tensor.sum()


+def _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group=None):
+    """Roll tensor with packed sequence support.
+    This function handles rolling for packed sequences by respecting sequence boundaries.
+    """
+
+    # Note: this is a naive implementation used to verify correctness; a better
+    # solution would sync the boundary tokens only once.
+    assert (
+        dims == -1 or dims == tensor.dim() - 1
+    ), "Packed sequence roll only supports the last dimension."
+    assert shifts == -1, "Packed sequence roll only supports a single-token left shift."
+    cu_seqlens = packed_seq_params.cu_seqlens_q
+    assert cu_seqlens is not None, "Packed sequence parameters must provide cu_seqlens_q."
+
+    rolled_tensor = tensor.clone()
+
+    cp_size = cp_group.size() if cp_group is not None else 1
+    if cp_size == 1:
+        # CP disabled: roll each packed sequence independently within its boundaries
+        for i in range(len(cu_seqlens) - 1):
+            start_idx = cu_seqlens[i]
+            end_idx = cu_seqlens[i + 1]
+            seq_slice = tensor[..., start_idx:end_idx]
+            rolled_seq = torch.roll(seq_slice, shifts=shifts, dims=dims)
+            # Zero out the last position(s) that would cross sequence boundaries
+            rolled_seq[..., shifts:] = 0
+            rolled_tensor[..., start_idx:end_idx] = rolled_seq
+        return rolled_tensor, rolled_tensor.sum()
+
+    # CP enabled: each rank owns two chunks per sequence (front and mirrored tail).
+    local_rank = torch.distributed.get_rank(group=cp_group)
+    global_ranks = torch.distributed.get_process_group_ranks(group=cp_group)
+    next_rank = global_ranks[(local_rank + 1) % cp_size]
+    prev_rank = global_ranks[(local_rank - 1) % cp_size]
+
+    # Iterate over each packed sequence individually
+    for i in range(len(cu_seqlens) - 1):
+        start_idx = cu_seqlens[i]
+        end_idx = cu_seqlens[i + 1]
+
+        # cu_seqlens holds global offsets; divide by cp_size to get this rank's local offsets
+        local_start_idx = start_idx // cp_size
+        local_end_idx = end_idx // cp_size
+        tensor_slice = rolled_tensor[..., local_start_idx:local_end_idx].clone()
+
+        # The following mirrors the CP path of roll_tensor above
+        local_chunks = tensor_slice.chunk(2, dim=dims)
+        rolled_chunks = [torch.roll(chunk, shifts=shifts, dims=dims) for chunk in local_chunks]
+
+        tensor_send_list = []
+        tensor_recv_list = []
+        for chunk in rolled_chunks:
+            boundary = chunk.select(dims, shifts).contiguous().clone()
+            tensor_send_list.append(boundary)
+            tensor_recv_list.append(torch.empty_like(boundary))
+
+        ops = []
+        if local_rank != 0:
+            ops.append(torch.distributed.isend(tensor=tensor_send_list[0], dst=prev_rank))
+            ops.append(torch.distributed.irecv(tensor=tensor_recv_list[1], src=prev_rank))
+        else:
+            tensor_recv_list[1].zero_()
+
+        if local_rank != cp_size - 1:
+            ops.append(torch.distributed.irecv(tensor=tensor_recv_list[0], src=next_rank))
+            ops.append(torch.distributed.isend(tensor=tensor_send_list[1], dst=next_rank))
+        else:
+            tensor_recv_list[0].copy_(tensor_send_list[1])
+
+        for op in ops:
+            op.wait()
+
+        index = [slice(None)] * rolled_chunks[0].dim()
+        index[dims] = shifts
+        for chunk, recv in zip(rolled_chunks, tensor_recv_list):
+            chunk[tuple(index)] = recv
+
+        seq_result = torch.cat(rolled_chunks, dim=dims)
+
+        # Write the rolled sequence back into the full tensor
+        rolled_tensor[..., local_start_idx:local_end_idx] = seq_result
+
+    return rolled_tensor, rolled_tensor.sum()
+
+
 class MTPLossLoggingHelper:
     """Helper class for logging MTP losses."""

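For the CP-disabled branch of _roll_tensor_packed_seq, the intended semantics fit in a few lines of plain PyTorch. This is an illustration with made-up values, not the Megatron helper itself: each packed sub-sequence is shifted left independently, and the slot at every cu_seqlens boundary is zeroed so one sequence never sees its neighbour's first token as a target.

```python
import torch


def roll_packed_reference(tensor, cu_seqlens):
    """Reference semantics for the CP-disabled branch (a sketch, not the library code)."""
    out = tensor.clone()
    for i in range(len(cu_seqlens) - 1):
        start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        # Shift this sub-sequence left by one and zero its final slot so nothing
        # leaks across the boundary into the next packed sequence.
        out[..., start:end - 1] = tensor[..., start + 1:end]
        out[..., end - 1] = 0
    return out, out.sum()


# Two packed sequences of lengths 4 and 3 (hypothetical token ids).
labels = torch.tensor([[11, 12, 13, 14, 21, 22, 23]])
cu_seqlens = torch.tensor([0, 4, 7])
rolled, _ = roll_packed_reference(labels, cu_seqlens)
# -> [[12, 13, 14, 0, 22, 23, 0]]: token 21 never becomes a target for the first sequence.
assert rolled.tolist() == [[12, 13, 14, 0, 22, 23, 0]]
```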
Expand Down Expand Up @@ -595,6 +689,7 @@ def _get_embeddings(
         position_ids: torch.Tensor,
         embedding: Callable,
         hidden_states: torch.Tensor,
+        packed_seq_params: Optional[PackedSeqParams] = None,
     ):
         """
         Preprocesses input data for the Multi-Token Prediction (MTP) layers.
@@ -609,10 +704,23 @@ def _get_embeddings(
                 from gpt model to compute the decoder input.
             hidden_states (torch.Tensor): hidden states tensor of shape [s, b, h] where s is the
                 sequence length, b is the batch size, and h is the hidden size.
+            packed_seq_params (PackedSeqParams): Parameters for packed sequence processing.
         """
         # Calc logits for the current Multi-Token Prediction (MTP) layers.
-        input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1, cp_group=self.cp_group)
-        position_ids, _ = roll_tensor(position_ids, shifts=-1, dims=-1, cp_group=self.cp_group)
+        input_ids, _ = roll_tensor(
+            input_ids,
+            shifts=-1,
+            dims=-1,
+            cp_group=self.cp_group,
+            packed_seq_params=packed_seq_params,
+        )
+        position_ids, _ = roll_tensor(
+            position_ids,
+            shifts=-1,
+            dims=-1,
+            cp_group=self.cp_group,
+            packed_seq_params=packed_seq_params,
+        )
         # embedding
         decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)

@@ -793,15 +901,13 @@ def forward(
                 [s, b, h], and optionally the updated context tensor if cross-attention is used.
         """
         assert context is None, f"multi token prediction + cross attention is not yet supported."
-        assert (
-            packed_seq_params is None
-        ), f"multi token prediction + sequence packing is not yet supported."

         input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings(
             input_ids=input_ids,
             position_ids=position_ids,
             embedding=embedding,
             hidden_states=hidden_states,
+            packed_seq_params=packed_seq_params,
         )

         if self.config.recompute_granularity == 'full' and self.training:
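One assumption baked into the CP branch of the new helper is the local_start_idx = start_idx // cp_size mapping: it treats the global cu_seqlens offsets as evenly divisible across CP ranks, which in practice means every packed sequence length should be a multiple of 2 * cp_size (two chunks per rank). A quick standalone check of that invariant, with hypothetical boundaries:

```python
import torch

cp_size = 2
cu_seqlens = torch.tensor([0, 8, 16, 24])          # hypothetical packed boundaries
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
assert all(int(n) % (2 * cp_size) == 0 for n in seq_lens), (
    "each packed sequence must split evenly into 2 * cp_size chunks"
)
local_cu_seqlens = cu_seqlens // cp_size            # -> tensor([0, 4, 8, 12])
print(local_cu_seqlens.tolist())
```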