Commits (77)
d50a125
Start heterogeneous CP idea prototyping
parthmannan Jul 15, 2025
676f0bf
Add current_cp_assignment and some supporting code
parthmannan Jul 16, 2025
65a24cc
Add bricks for building heterogeneous CP schedules
parthmannan Jul 16, 2025
0c5fa39
Return sub-sample per CP rank count to schedules
parthmannan Jul 17, 2025
8546b0f
Move wrapper to utils and formatting
parthmannan Jul 18, 2025
b17a2a2
Add TransformerConfig and arguments
parthmannan Jul 18, 2025
98f592a
Fix implementation issues [WIP]
parthmannan Jul 19, 2025
552bd27
Add sharding logic and packed_seq_params attribute for TE support
parthmannan Jul 21, 2025
48913d0
Protect attention shape change logic with THD assumption
parthmannan Jul 21, 2025
056055c
Add sharding logic and TODO for sync across multiple microbatches
parthmannan Jul 22, 2025
f069e0d
Merge branch 'main' of https://gitlab-master.nvidia.com/ADLR/megatron…
parthmannan Aug 13, 2025
8d34369
Updates to working milestone 1 Hybrid CP scheduling
parthmannan Aug 14, 2025
771fc21
Add SFT dataset support
parthmannan Aug 21, 2025
3a7e50a
New updates to use DataIterator wrapper
parthmannan Aug 22, 2025
bd2321d
Test scripts
parthmannan Aug 22, 2025
6a1ab4f
Minor fixes
parthmannan Aug 22, 2025
e01583a
SFT Dataset + GPT with Hybrid CP working example
parthmannan Aug 22, 2025
286129c
Remove new SFT loss calculation in favor of regular loss calc
parthmannan Aug 23, 2025
2674be8
Milestone 2: Add global batch scheduling, data movement and training
parthmannan Aug 29, 2025
3b87385
Fix for 0 seq len in data
parthmannan Sep 15, 2025
bd04cfe
Fix fill_empty_gpus workload assignment
parthmannan Sep 16, 2025
160f514
llama3 8B benchmarking script
parthmannan Sep 16, 2025
b0eb325
Bug fix
parthmannan Sep 16, 2025
eccf31d
Merge branch 'main' of ssh://gitlab-master.nvidia.com:12051/ADLR/mega…
parthmannan Sep 16, 2025
7e9fe90
Revert "ADLR/megatron-lm!3963 - SFT chat template and pad token chang…
parthmannan Sep 16, 2025
b635d15
Turn on legacy tokenizer
parthmannan Sep 16, 2025
dc8139f
chore: Format files
Sep 16, 2025
69ef399
Remove scheduled id and other finished TODO items
parthmannan Sep 16, 2025
b37e3f7
Merge branch 'pmannan/hetero_cp_test_sft' of ssh://gitlab-master.nvid…
parthmannan Sep 16, 2025
3bafa2c
Formatting and cleanup
parthmannan Sep 16, 2025
78dec32
Add copyright
parthmannan Sep 17, 2025
6b1d9c4
Start testing
parthmannan Sep 17, 2025
71c2966
chore: Format files
Sep 17, 2025
bff834c
Fix arg name in test
parthmannan Sep 18, 2025
2860324
Merge branch 'pmannan/hetero_cp_test_sft' of ssh://gitlab-master.nvid…
parthmannan Sep 18, 2025
ff5600b
Fix arg name in test
parthmannan Sep 18, 2025
7ff37e5
Update RoPE logic to handle none packed_seq_params
parthmannan Sep 18, 2025
3c1f66a
Fix test import
parthmannan Sep 18, 2025
aa6c7f7
chore: Format files
Sep 18, 2025
9d761cc
fix some bugs and support TP
xiaoyao0115 Sep 23, 2025
7d6c2b1
fix pad issue, should pad for sequence parallel when TP>1 and conside…
xiaoyao0115 Sep 23, 2025
fbdfe18
little adjustment according to comments
xiaoyao0115 Sep 24, 2025
8e6aa49
little fix to loop parameter
xiaoyao0115 Sep 25, 2025
f050f54
disable cudagraph and fsdp when using hybrid cp
xiaoyao0115 Sep 26, 2025
1e76e8e
Add cp comm type to dynamic CP attn
parthmannan Sep 29, 2025
ab8dc81
Merge branch 'main' of ssh://gitlab-master.nvidia.com:12051/ADLR/mega…
parthmannan Sep 29, 2025
3a39a3a
Cleanup
parthmannan Sep 29, 2025
cb72e4e
Removing benchmarking scripts for MR
parthmannan Oct 2, 2025
aeb9e36
Merge branch 'main' of ssh://gitlab-master.nvidia.com:12051/ADLR/mega…
parthmannan Oct 2, 2025
6aebf37
Restore SFT dataset/tokenizer to main
parthmannan Oct 2, 2025
f5c0e76
Update pre-train script
parthmannan Oct 2, 2025
9d96178
Lint change
parthmannan Oct 2, 2025
b7aff32
chore: Format files
Oct 2, 2025
994e65f
Add assert
parthmannan Oct 3, 2025
f60bfdd
Merge branch 'pmannan/hybrid_dp_cp_feature' of ssh://gitlab-master.nv…
parthmannan Oct 3, 2025
aa6cf10
Add support for Yarn Rope
parthmannan Oct 3, 2025
f3a6e17
Add support for MultimodalRope
parthmannan Oct 3, 2025
a0fb848
General and RoPE support cleanup
parthmannan Oct 3, 2025
b0b0fab
Data sampler cleanup
parthmannan Oct 3, 2025
335994c
chore: Format files
Oct 3, 2025
ba99cad
Cleanup hybrid_cp_schedule
parthmannan Oct 3, 2025
d3c0c03
lint fixes
parthmannan Oct 3, 2025
1712980
Fix MLA + Yarn calls
parthmannan Oct 3, 2025
62817e8
chore: Format files
Oct 3, 2025
e586fa4
ProcessGroupCollection.use_mpu_process_groups() doesn't initialize dp…
kunlunl Oct 9, 2025
4221bdc
Fix num_microbatches always being 0
kunlunl Oct 9, 2025
0856c61
Raise error when not using per-token loss
kunlunl Oct 9, 2025
d2db69e
Minor fix
parthmannan Oct 9, 2025
627753b
Address comments
parthmannan Oct 16, 2025
a3b5ebe
Lint
parthmannan Oct 16, 2025
1c22f5c
chore: Format files
Oct 16, 2025
9c75535
Merge branch 'dev' of https://github.com/NVIDIA/Megatron-LM into pman…
parthmannan Oct 30, 2025
97d8224
Merge branch 'dev' of github.com:NVIDIA/Megatron-LM into pmannan/hybr…
parthmannan Nov 7, 2025
617fb17
Fix import
parthmannan Nov 8, 2025
e95ae83
Move Dataloader wrapper to core/datasets
parthmannan Nov 17, 2025
0a0339a
Formatting
parthmannan Nov 17, 2025
4af66dd
Fix conflict
parthmannan Nov 17, 2025
299 changes: 299 additions & 0 deletions megatron/core/datasets/data_schedule.py
@@ -0,0 +1,299 @@
# Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved.

from typing import Any, List, Optional, Tuple
import torch
from megatron.core import parallel_state
from megatron.core.pipeline_parallel.hybrid_cp_schedule import BalancedCPScheduler
from megatron.core.process_groups_config import ProcessGroupCollection


class HybridCPDataLoaderWrapper:
    """
    A wrapper class that wraps around an existing data_iterator.
    For every __next__ call,
    1. Each DP rank pulls a batch of packed samples.
    2. Extracts the sequence lengths of each sub-sample and all-gathers them across the DP group.
    3. Schedules the sub-samples onto the DPxCP ranks using the BalancedCPScheduler.
    4. Based on the schedule, reroutes the sub-samples to the correct ranks using all-to-all.
    5. Returns the sub-samples assigned to this rank.

    Args:
        data_iterator: The original data_iterator to wrap around.
        config: The config object containing max_seqlen_per_dp_cp_rank.
        pg_collection: Optional ProcessGroupCollection providing the DPxCP, DP,
            and TP process groups; defaults to the parallel_state groups when None.
    """

    def __init__(
        self, data_iterator, config, pg_collection: Optional[ProcessGroupCollection] = None
    ):
        self.data_iterator = data_iterator
        self.config = config
        if pg_collection is None:
            self.dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
            self.dp_group = parallel_state.get_data_parallel_group()
            self.tp_group = parallel_state.get_tensor_model_parallel_group()
        else:
            self.dp_cp_group = pg_collection.dp_cp
            self.dp_group = pg_collection.dp
            self.tp_group = pg_collection.tp
        assert (
            self.dp_cp_group is not None and self.dp_group is not None and self.tp_group is not None
        ), "dp_cp_group, dp_group, tp_group must not be None when using hybrid context parallel"

        self.cp_balancing_scheduler = BalancedCPScheduler(
            max_seq_len_per_rank=self.config.max_seqlen_per_dp_cp_rank, dp_cp_group=self.dp_cp_group
        )

        self.total_hdp_gpus = self.dp_cp_group.size()

    def __iter__(self):
        """Return self as an iterator."""
        return self

    def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> Tuple[List[int], torch.Tensor]:
        """
        Gathers the sequence lengths of all subsamples from all DP ranks.
        Each DP rank loads the same number of microbatches, but each microbatch
        may have a different number of subsamples.

        We find the number of subsamples each rank holds, then gather the
        sequence lengths of all subsamples from all ranks.

        Returns:
            seqlens_gathered: list of subsample sequence lengths across all DP ranks.
            offsets: per-rank offsets used to assign a unique global ID to each subsample.
        """
        # Collect the number of subsamples from all ranks
        local_len = torch.tensor([subsample_seqlens.shape[0]], dtype=torch.int32).cuda()
        dp_subsample_count = [torch.zeros_like(local_len) for _ in range(self.dp_group.size())]
        torch.distributed.all_gather(dp_subsample_count, local_len, group=self.dp_group)

        # Find the max number of subsamples across all ranks and pad subsample_seqlens to max length
        dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1)
        max_sub_samples = int(dp_subsample_counts.max().item())

        if local_len.item() < max_sub_samples:
            subsample_seqlens_padded = torch.cat(
                [
                    subsample_seqlens,
                    torch.zeros(max_sub_samples - local_len.item(), dtype=torch.int32).cuda(),
                ],
                dim=0,
            )
        else:
            subsample_seqlens_padded = subsample_seqlens

        # Gather the subsample_seqlens from all ranks
        seqlens_gathered = [
            torch.empty_like(subsample_seqlens_padded) for _ in range(self.dp_group.size())
        ]
        torch.distributed.all_gather(
            seqlens_gathered, subsample_seqlens_padded, group=self.dp_group
        )

        # Trim each gathered tensor back to that rank's true subsample count
        for dp_rank, seqlen in enumerate(seqlens_gathered):
            seqlens_gathered[dp_rank] = seqlen[: dp_subsample_counts[dp_rank]]

        seqlens_gathered = torch.cat(seqlens_gathered, dim=0)
        seqlens_gathered = seqlens_gathered.cpu().tolist()

        # Calculate the offsets to assign a unique global ID to each subsample.
        csum = torch.cumsum(dp_subsample_counts, dim=0, dtype=torch.int32)
        offsets = torch.cat([torch.zeros(1, dtype=torch.int32), csum[:-1]], dim=0)

        return seqlens_gathered, offsets
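    # Worked example (illustrative): with DP=2, rank 0 holding seqlens [5, 3]
    # and rank 1 holding [7], the counts are [2, 1]; rank 1 pads to [7, 0],
    # the all-gather yields [5, 3] and [7, 0], trimming gives [5, 3, 7], and
    # offsets = [0, 2], so rank 1's only subsample receives global ID 2.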

    def get_global_id_seqlens(self, num_local_subsamples, offsets, seqlens_gathered):
        """
        Assigns a unique global ID to each subsample.

        Returns:
            global_id_seqlens: list of (global_id, seqlen) tuples for scheduling.
            global_ids_this_rank: list of global IDs locally present on this rank.
        """
        dp_rank = self.dp_group.rank()
        global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda()
        # Create a list of (global_id, seqlen) tuples for scheduling
        global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))]
        # Get the global IDs locally present on this rank
        global_ids_this_rank = global_ids[
            offsets[dp_rank] : offsets[dp_rank] + num_local_subsamples
        ]

        return global_id_seqlens, global_ids_this_rank

    def _gid_to_src_rank(self, gid: int, offsets: torch.Tensor) -> int:
        """Maps a global subsample ID to the HDP (DPxCP) rank that holds it."""
        dp_src_rank = torch.bucketize(gid, offsets[1:] - 1)
        # Since torch.distributed.get_process_group_ranks provides global ranks,
        # we need to account for TP when converting to an HDP rank.
        hdp_rank = (
            torch.distributed.get_process_group_ranks(self.dp_group)[dp_src_rank]
            // self.tp_group.size()
        )
        return hdp_rank
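    # Illustrative: with offsets = tensor([0, 2, 5]), the boundaries are [1, 4],
    # so torch.bucketize maps gids 0-1 to DP rank 0, gids 2-4 to DP rank 1,
    # and gids 5+ to DP rank 2, before the TP adjustment above.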

    def reroute_samples_to_hdp_ranks(
        self, batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets
    ):
        """
        Reroutes the sub-samples to the correct ranks after scheduling.

        For each key in the batch dict, we perform an all-to-all communication
        to transfer the data to the correct ranks.
        Since all CP ranks within a DP group have the same data, we only need
        to transfer data between matching CP ranks.
        """
        gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)}
        hdp_rank = self.dp_cp_group.rank()
        dp_ranks = torch.distributed.get_process_group_ranks(self.dp_group)
        # We want each DP rank's position within the HDP group, so we need to
        # account for TP.
        dp_ranks = [r // self.tp_group.size() for r in dp_ranks]

        data_keys = batch[0].keys()

        # Create the send plan
        combined_sample_id_groups: List[List[int]] = [[] for _ in range(self.total_hdp_gpus)]

        for d in range(self.total_hdp_gpus):
            for sample_id_group in sample_id_groups:
                combined_sample_id_groups[d].extend(sample_id_group[d])

        for dest_rank in range(self.total_hdp_gpus):
            combined_sample_id_groups[dest_rank].sort()

        # Filter out samples that are not present on this rank
        send_ids_sorted = [
            gid
            for d in dp_ranks
            for gid in combined_sample_id_groups[d]
            if gid in global_ids_this_rank
        ]

        send_lens_split = [0] * self.total_hdp_gpus
        for dest_rank in range(self.total_hdp_gpus):
            if dest_rank in dp_ranks:
                send_lens_split[dest_rank] = sum(
                    [
                        global_id_seqlens[gid][1]
                        for gid in combined_sample_id_groups[dest_rank]
                        if gid in global_ids_this_rank
                    ]
                )
            else:
                # We only need to share local data with DP ranks that have different data.
                send_lens_split[dest_rank] = 0

        # Create the recv plan
        recv_sample_id_groups = [[] for _ in range(self.total_hdp_gpus)]
        for gid in combined_sample_id_groups[hdp_rank]:
            src_rank = self._gid_to_src_rank(gid, offsets)
            recv_sample_id_groups[src_rank].append(gid)

        recv_lens_split = [0] * self.total_hdp_gpus
        for src_rank in range(self.total_hdp_gpus):
            recv_lens_split[src_rank] = sum(
                [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]]
            )

        recv_ids_sorted = [
            gid for d in range(self.total_hdp_gpus) for gid in recv_sample_id_groups[d]
        ]
        recv_counts = [len(recv_sample_id_groups[d]) for d in range(self.total_hdp_gpus)]

        recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))]

        def _pack_sample_by_key(key: str) -> torch.Tensor:
            flattened_tensors = []
            for gid in send_ids_sorted:
                t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True)
                flattened_tensors.append(t)
            return (
                torch.cat(flattened_tensors, dim=0)
                if flattened_tensors
                else torch.empty(0, device=torch.cuda.current_device(), dtype=batch[0][key].dtype)
            )

        def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor):
            cursor = 0
            for i, gid in enumerate(recv_ids_sorted):
                sample_len = global_id_seqlens[gid][1]
                recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len]
                cursor += sample_len

        for key in data_keys:
            send_tensor = _pack_sample_by_key(key)
            recv_tensor = torch.empty(
                sum(recv_lens_split), device=torch.cuda.current_device(), dtype=send_tensor.dtype
            )
            torch.distributed.all_to_all_single(
                output=recv_tensor,
                input=send_tensor,
                output_split_sizes=recv_lens_split,
                input_split_sizes=send_lens_split,
                group=self.dp_cp_group,
            )
            _unpack_sample_by_key(key, recv_tensor)

        recv_sample_with_id = {
            recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted)
        }
        return recv_sample_with_id
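    # Illustrative: with DP=2, CP=1 (HDP size 2), suppose rank 0 holds gids [0, 1],
    # rank 1 holds gid 2, and the schedule assigns gid 1 to HDP rank 1. Rank 0 then
    # contributes seqlen(gid 1) elements to rank 1's slot in the all-to-all, and
    # rank 1's recv plan records gid 1 with source rank 0.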

    def unpack_batch(self, batch):
        """
        Unpacks the packed samples into a list of sub-samples.
        Since each sub-sample may be routed to a different DPxCP rank,
        we unpack the sample here to avoid unnecessarily transferring
        the entire packed sample.
        """
        batch_unpacked = []
        for sample in batch:
            for sub_sample in range(sample["cu_seqlens"].shape[0] - 1):
                sub_sample_dict = {}
                start_idx = sample["cu_seqlens"][sub_sample]
                end_idx = sample["cu_seqlens"][sub_sample + 1]
                if end_idx - start_idx == 0:
                    continue
                for key in sample.keys():
                    if key in ["cu_seqlens", "batch_idx", "max_seqlen"]:
                        continue
                    sub_sample_dict[key] = sample[key][start_idx:end_idx]
                batch_unpacked.append(sub_sample_dict)
        return batch_unpacked
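    # Illustrative: a packed sample with cu_seqlens = [0, 5, 5, 9] yields two
    # sub-samples (lengths 5 and 4); the zero-length segment in between is skipped.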

    def __next__(self) -> Any:
        """
        Get the next item from the dataset, pull scheduling metadata, and return it.
        """
        if self.data_iterator is None:
            # TP rank 0 reads from the data_iterator; other TP ranks receive via broadcast.
            return None, None
        else:
            batch = next(self.data_iterator)
            subsample_seqlens = []
            for sample in batch:
                subsample_seqlens.extend(
                    [
                        int(sample["cu_seqlens"][i + 1] - sample["cu_seqlens"][i])
                        for i in range(0, sample["cu_seqlens"].shape[0] - 1)
                    ]
                )
            subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32).cuda()
            subsample_seqlens = subsample_seqlens[subsample_seqlens != 0]

            seqlens_gathered, offsets = self.get_global_seqlens(subsample_seqlens)

            global_id_seqlens, global_ids_this_rank = self.get_global_id_seqlens(
                subsample_seqlens.shape[0], offsets, seqlens_gathered
            )

            groups, sample_id_groups = self.cp_balancing_scheduler.get_groups_and_subsamples(
                global_id_seqlens, self.config
            )

            batch = self.unpack_batch(batch)
            samples_this_rank_with_id = self.reroute_samples_to_hdp_ranks(
                batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets
            )
            return samples_this_rank_with_id, sample_id_groups
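A minimal usage sketch of the wrapper above (illustrative only: make_packed_iterator and cfg are hypothetical placeholders, not names from this MR). Each item yielded by the wrapped iterator is a list of packed-sample dicts carrying a "cu_seqlens" tensor, matching the assumption in unpack_batch:

from megatron.core.datasets.data_schedule import HybridCPDataLoaderWrapper

data_iterator = make_packed_iterator()  # yields lists of dicts with "cu_seqlens", "tokens", ...
hybrid_iter = HybridCPDataLoaderWrapper(data_iterator, cfg)  # cfg.max_seqlen_per_dp_cp_rank must be set

# samples_with_id maps global sub-sample IDs to this rank's tensors;
# sample_id_groups carries the schedule produced by BalancedCPScheduler.
samples_with_id, sample_id_groups = next(hybrid_iter)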
18 changes: 18 additions & 0 deletions megatron/core/datasets/gpt_dataset.py
@@ -49,6 +49,24 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig):
    object_storage_cache_path: Optional[str] = None
    """Path for caching indices for s3 or msc dataloading."""

    context_parallel_size: int = 1
    """Size of the context parallel group."""

    data_parallel_size: int = 1
    """Size of the data parallel group."""

    sequence_parallel_size: int = 0
    """Sequence parallel size when using TP.
    Set to 0 if sequence parallelism is not enabled, regardless of TP size.
    """

    hybrid_context_parallel: bool = False
    """Option to enable hybrid context parallelism. When this is set to True,
    each sample's length should be divisible by data parallel size * context parallel size * 2.
    If sequence parallelism is enabled, it should be divisible by
    data parallel size * context parallel size * sequence parallel size * 2.
    """

    def __post_init__(self) -> None:
        """Do asserts and set fields post init"""
        super().__post_init__()
20 changes: 20 additions & 0 deletions megatron/core/extensions/transformer_engine.py
@@ -1007,6 +1007,7 @@ def __init__(
        self.kept_packed_seq_params = set(
            field.name for field in dataclasses.fields(PackedSeqParams)
        )

        if get_te_version() < PkgVersion("1.3.0"):
            # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H
            # copies (#555)
@@ -1049,6 +1050,25 @@
        packed_seq_params: PackedSeqParams = None,
    ):
        """Forward."""
        if packed_seq_params is not None:
            # If a dynamic CP group is provided, update the TE DPA CP group
            if packed_seq_params.cp_group is not None:
                self.cp_group = packed_seq_params.cp_group
                super().set_context_parallel_group(
                    self.cp_group,
                    torch.distributed.get_process_group_ranks(self.cp_group),
                    TEDotProductAttention.cp_stream,
                    self.cp_comm_type,
                )
            # If cp_group is None but local_cp_size is provided,
            # it indicates that CP should be turned off dynamically
            elif packed_seq_params.local_cp_size is not None:
                assert (
                    packed_seq_params.local_cp_size == 1
                ), "local_cp_size must be == 1 if provided without cp_group"
                super().set_context_parallel_group(None, None, None, self.cp_comm_type)
            self.kept_packed_seq_params.discard("cp_group")
            self.kept_packed_seq_params.discard("local_cp_size")
        packed_seq_kwargs = (
            {key: getattr(packed_seq_params, key) for key in self.kept_packed_seq_params}
            if packed_seq_params is not None
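A hedged sketch of how a caller could drive this dynamic CP switch. It assumes the cp_group and local_cp_size fields that this MR adds to PackedSeqParams; cu_seqlens and sub_cp_group are hypothetical placeholders:

from megatron.core.packed_seq_params import PackedSeqParams

# Re-point TE attention at a smaller CP group for this microbatch.
params_with_cp = PackedSeqParams(
    qkv_format="thd",
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_kv=cu_seqlens,
    cp_group=sub_cp_group,
)

# cp_group=None with local_cp_size=1 turns CP off dynamically.
params_no_cp = PackedSeqParams(
    qkv_format="thd",
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_kv=cu_seqlens,
    local_cp_size=1,
)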
16 changes: 16 additions & 0 deletions megatron/core/model_parallel_config.py
@@ -53,6 +53,22 @@ class ModelParallelConfig:
    type.
    """

    max_seqlen_per_dp_cp_rank: Optional[int] = None
    """
    Maximum sequence length per DPxCP rank. This is the maximum sequence length each rank
    can handle without running out of memory. Typically, a good starting point is to set this
    to the maximum sequence length / context parallel size.
    This is used to calculate the number and length of sub-samples assigned to
    each rank when using hybrid_context_parallel.
    """

    hybrid_context_parallel: bool = False
    """
    If true, enables hybrid context parallelism. This is used to balance the workload of
    each CP rank when packed samples have variable sequence lengths.
    Please set max_seqlen_per_dp_cp_rank when using hybrid_context_parallel.
    """

    expert_model_parallel_size: int = 1
    """Distributes MoE experts across the sub data parallel dimension."""

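A hedged configuration sketch using the two new fields above; TransformerConfig subclasses ModelParallelConfig, and all values here are illustrative:

from megatron.core.transformer import TransformerConfig

config = TransformerConfig(
    num_layers=32,
    hidden_size=4096,
    num_attention_heads=32,
    context_parallel_size=4,
    hybrid_context_parallel=True,
    max_seqlen_per_dp_cp_rank=32768 // 4,  # starting point: max seq len / CP size
)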