Hybrid Data x Context Parallelism Feature #2054
Status: Open
parthmannan wants to merge 77 commits into NVIDIA:dev from parthmannan:pmannan/hybrid_dp_cp_feature
Commits (77)
d50a125  Start heterogeneous CP idea prototyping (parthmannan)
676f0bf  Add current_cp_assignment and some supporting code (parthmannan)
65a24cc  Add bricks for building heterogeneous CP schedules (parthmannan)
0c5fa39  Return sub-sample per CP rank count to schedules (parthmannan)
8546b0f  Move wrapper to utils and formatting (parthmannan)
b17a2a2  Add TransformerConfig and arguments (parthmannan)
98f592a  Fix implementation issues [WIP] (parthmannan)
552bd27  Add sharding logic and packed_seq_params attribute for TE support (parthmannan)
48913d0  Protect attention shape change logic with THD assumption (parthmannan)
056055c  Add sharding logic and TODO for sync across multiple microbatches (parthmannan)
f069e0d  Merge branch 'main' of https://gitlab-master.nvidia.com/ADLR/megatron… (parthmannan)
8d34369  Updates to working milestone 1 Hybrid CP scheduling (parthmannan)
771fc21  Add SFT dataset support (parthmannan)
3a7e50a  New updates to use DataIterator wrapper (parthmannan)
bd2321d  Test scripts (parthmannan)
6a1ab4f  Minor fixes (parthmannan)
e01583a  SFT Dataset + GPT with Hybrid CP working example (parthmannan)
286129c  Remove new SFT loss calculation in favor of regular loss calc (parthmannan)
2674be8  Milestone 2: Add global batch scheduling, data movement and training (parthmannan)
3b87385  Fix for 0 seq len in data (parthmannan)
bd04cfe  Fix fill_empty_gpus workload assignment (parthmannan)
160f514  llama3 8B benchmarking script (parthmannan)
b0eb325  Bug fix (parthmannan)
eccf31d  Merge branch 'main' of ssh://gitlab-master.nvidia.com:12051/ADLR/mega… (parthmannan)
7e9fe90  Revert "ADLR/megatron-lm!3963 - SFT chat template and pad token chang… (parthmannan)
b635d15  Turn on legacy tokenizer (parthmannan)
dc8139f  chore: Format files
69ef399  Remove scheduled id and other finished TODO items (parthmannan)
b37e3f7  Merge branch 'pmannan/hetero_cp_test_sft' of ssh://gitlab-master.nvid… (parthmannan)
3bafa2c  Formatting and cleanup (parthmannan)
78dec32  Add copyright (parthmannan)
6b1d9c4  Start testing (parthmannan)
71c2966  chore: Format files
bff834c  Fix arg name in test (parthmannan)
2860324  Merge branch 'pmannan/hetero_cp_test_sft' of ssh://gitlab-master.nvid… (parthmannan)
ff5600b  Fix arg name in test (parthmannan)
7ff37e5  Update RoPE logic to handle none packed_seq_params (parthmannan)
3c1f66a  Fix test import (parthmannan)
aa6c7f7  chore: Format files
9d761cc  fix some bugs and support TP (xiaoyao0115)
7d6c2b1  fix pad issue, should pad for sequence parallel when TP>1 and conside… (xiaoyao0115)
fbdfe18  little adjustment according to comments (xiaoyao0115)
8e6aa49  little fix to loop parameter (xiaoyao0115)
f050f54  disable cudagraph anb fsdp when using hybrid cp (xiaoyao0115)
1e76e8e  Add cp comm type to dynamic CP attn (parthmannan)
ab8dc81  Merge branch 'main' of ssh://gitlab-master.nvidia.com:12051/ADLR/mega… (parthmannan)
3a39a3a  Cleanup (parthmannan)
cb72e4e  Removing benchmarking scripts for MR (parthmannan)
aeb9e36  Merge branch 'main' of ssh://gitlab-master.nvidia.com:12051/ADLR/mega… (parthmannan)
6aebf37  Restore SFT dataset/tokenizer to main (parthmannan)
f5c0e76  Update pre-train script (parthmannan)
9d96178  Lint change (parthmannan)
b7aff32  chore: Format files
994e65f  Add assert (parthmannan)
f60bfdd  Merge branch 'pmannan/hybrid_dp_cp_feature' of ssh://gitlab-master.nv… (parthmannan)
aa6cf10  Add support for Yarn Rope (parthmannan)
f3a6e17  Add support for MultimodalRope (parthmannan)
a0fb848  General and RoPE support cleanup (parthmannan)
b0b0fab  Data sampler cleanup (parthmannan)
335994c  chore: Format files
ba99cad  Cleanup hybrid_cp_schedule (parthmannan)
d3c0c03  lint fixes (parthmannan)
1712980  Fix MLA + Yarn calls (parthmannan)
62817e8  chore: Format files
e586fa4  ProcessGroupCollection.use_mpu_process_groups() doesn't initialize dp… (kunlunl)
4221bdc  Fix num_microbatches always being 0 (kunlunl)
0856c61  Raise error when not using per-token loss (kunlunl)
d2db69e  Minor fix (parthmannan)
627753b  Address comments (parthmannan)
a3b5ebe  Lint (parthmannan)
1c22f5c  chore: Format files
9c75535  Merge branch 'dev' of https://github.com/NVIDIA/Megatron-LM into pman… (parthmannan)
97d8224  Merge branch 'dev' of github.com:NVIDIA/Megatron-LM into pmannan/hybr… (parthmannan)
617fb17  Fix import (parthmannan)
e95ae83  Move Dataloader wrapper to core/datasets (parthmannan)
0a0339a  Formatting (parthmannan)
4af66dd  Fix conflict (parthmannan)
The diff adds a new 299-line module implementing HybridCPDataLoaderWrapper, the data-loader wrapper for hybrid data x context parallelism:

```python
# Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved.

from typing import Any, List, Optional
import torch
from megatron.core import parallel_state
from megatron.core.pipeline_parallel.hybrid_cp_schedule import BalancedCPScheduler
from megatron.core.process_groups_config import ProcessGroupCollection


class HybridCPDataLoaderWrapper:
    """
    A wrapper class that wraps around an existing data_iterator.
    For every __next__ call,
    1. Each DP rank pulls a batch of packed samples.
    2. Extracts the sequence lengths of each sub-sample and all-gathers across the DP group.
    3. Schedules the sub-samples to the DPxCP ranks using the BalancedCPScheduler.
    4. Based on the schedule, reroutes the sub-samples to the correct rank using all-to-all.
    5. Returns the assigned sub-samples to this rank.

    Args:
        data_iterator: The original data_iterator to wrap around
        config: The config object containing the max_seqlen_per_dp_cp_rank
        dp_cp_group: Data parallel context parallel group.
    """

    def __init__(
        self, data_iterator, config, pg_collection: Optional[ProcessGroupCollection] = None
    ):
        self.data_iterator = data_iterator
        self.config = config
        if pg_collection is None:
            self.dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
            self.dp_group = parallel_state.get_data_parallel_group()
            self.tp_group = parallel_state.get_tensor_model_parallel_group()
        else:
            self.dp_cp_group = pg_collection.dp_cp
            self.dp_group = pg_collection.dp
            self.tp_group = pg_collection.tp
        assert (
            self.dp_cp_group is not None and self.dp_group is not None and self.tp_group is not None
        ), "dp_cp_group, dp_group, tp_group must not be None when using hybrid context parallel"

        self.cp_balancing_scheduler = BalancedCPScheduler(
            max_seq_len_per_rank=self.config.max_seqlen_per_dp_cp_rank, dp_cp_group=self.dp_cp_group
        )

        self.total_hdp_gpus = self.dp_cp_group.size()

    def __iter__(self):
        """Return self as an iterator."""
        return self

    def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]:
        """
        Gathers the sequence lengths of all subsamples from all DP ranks.
        Each DP rank loads the same number of microbatches but each microbatch
        may have a different number of subsamples.

        We find the number of subsamples each rank holds and then gather the
        sequence lengths of all subsamples from all ranks.
        """
        # Collect the number of subsamples from all ranks
        local_len = torch.tensor([subsample_seqlens.shape[0]], dtype=torch.int32).cuda()
        dp_subsample_count = [torch.zeros_like(local_len) for _ in range(self.dp_group.size())]
        torch.distributed.all_gather(dp_subsample_count, local_len, group=self.dp_group)

        # Find the max number of subsamples across all ranks and pad subsample_seqlens to max length
        dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1)
        max_sub_samples = int(dp_subsample_counts.max().item())

        if local_len.item() < max_sub_samples:
            subsample_seqlens_padded = torch.cat(
                [
                    subsample_seqlens,
                    torch.zeros(max_sub_samples - local_len.item(), dtype=torch.int32).cuda(),
                ],
                dim=0,
            )
        else:
            subsample_seqlens_padded = subsample_seqlens

        # Gather the subsample_seqlens from all ranks
        seqlens_gathered = [
            torch.empty_like(subsample_seqlens_padded) for _ in range(self.dp_group.size())
        ]
        torch.distributed.all_gather(
            seqlens_gathered, subsample_seqlens_padded, group=self.dp_group
        )

        # Trim each seqlens_gathered to the length of the correct sample
        for dp_rank, seqlen in enumerate(seqlens_gathered):
            seqlens_gathered[dp_rank] = seqlen[: dp_subsample_counts[dp_rank]]

        seqlens_gathered = torch.cat(seqlens_gathered, dim=0)
        seqlens_gathered = seqlens_gathered.cpu().tolist()

        # Calculate the offsets to assign unique global ID to each subsample.
        csum = torch.cumsum(dp_subsample_counts, dim=0, dtype=torch.int32)
        offsets = torch.cat([torch.zeros(1, dtype=torch.int32), csum[:-1]], dim=0)

        return seqlens_gathered, offsets
```
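To make the offset bookkeeping concrete, here is a small standalone sketch (not part of the PR) that replaces the all_gather calls with hard-coded per-rank results; the sub-sample counts are made up for illustration.

```python
import torch

# Suppose 3 DP ranks loaded 3, 2, and 4 sub-samples respectively
# (this is what the all_gather of per-rank counts above would discover).
dp_subsample_counts = torch.tensor([3, 2, 4], dtype=torch.int32)

# Offsets are an exclusive prefix sum of the counts, so the first sub-sample
# loaded by DP rank r receives global ID offsets[r].
csum = torch.cumsum(dp_subsample_counts, dim=0, dtype=torch.int32)
offsets = torch.cat([torch.zeros(1, dtype=torch.int32), csum[:-1]], dim=0)

print(offsets.tolist())  # [0, 3, 5]: rank 0 owns IDs 0-2, rank 1 owns 3-4, rank 2 owns 5-8
```

The wrapper then turns these offsets into per-sub-sample global IDs, and provides a reverse mapping from a global ID back to the rank that loaded it: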
```python
    def get_global_id_seqlens(self, num_local_subsamples, offsets, seqlens_gathered):
        """
        Calculates the global ID for each subsample.

        We assign a unique global ID to each subsample.

        Returns:
            global_id_seqlens: list of (global_id, seqlen) tuples for scheduling.
            global_ids_this_rank: list of global IDs locally present on this rank.
        """
        dp_rank = self.dp_group.rank()
        global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda()
        # Create a list of (global_id, seqlen) tuples for scheduling
        global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))]
        # Get the global IDs locally present on this rank
        global_ids_this_rank = global_ids[
            offsets[dp_rank] : offsets[dp_rank] + num_local_subsamples
        ]

        return global_id_seqlens, global_ids_this_rank

    def _gid_to_src_rank(self, gid: int, offsets: List[int]) -> int:
        """Map a global sub-sample ID to the DPxCP-group rank that originally loaded it."""
        dp_src_rank = torch.bucketize(gid, offsets[1:] - 1)
        # Since the torch.distributed.get_process_group_ranks
        # provides the global rank, we need to consider TP
        hdp_rank = (
            torch.distributed.get_process_group_ranks(self.dp_group)[dp_src_rank]
            // self.tp_group.size()
        )
        return hdp_rank
```
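Continuing the toy numbers above, here is a standalone sketch (again not part of the PR) of how torch.bucketize inverts the offsets in _gid_to_src_rank; the TP adjustment is omitted, so the result is only the source DP rank.

```python
import torch

offsets = torch.tensor([0, 3, 5], dtype=torch.int32)  # per-rank counts were [3, 2, 4]
boundaries = offsets[1:] - 1  # last global ID owned by each rank except the final one: [2, 4]

for gid in [0, 2, 3, 4, 5, 8]:
    dp_src_rank = int(torch.bucketize(gid, boundaries))
    print(f"global ID {gid} was loaded by DP rank {dp_src_rank}")
# IDs 0-2 -> rank 0, IDs 3-4 -> rank 1, IDs 5-8 -> rank 2
```

With global IDs and source ranks in hand, the next method performs the actual data movement decided by the scheduler: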
```python
    def reroute_samples_to_hdp_ranks(
        self, batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets
    ):
        """
        Reroutes the sub-samples to the correct rank after scheduling.

        For each key in the batch dict, we perform an all-to-all communication
        to transfer the data to the correct ranks.
        Since all CP ranks within a DP group have the same data, we only need
        to transfer data between matching CP ranks.
        """
        gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)}
        hdp_rank = self.dp_cp_group.rank()
        dp_ranks = torch.distributed.get_process_group_ranks(self.dp_group)
        # Here we actually want to get the DP group's rank within the HDP group,
        # we need to consider TP
        dp_ranks = [r // self.tp_group.size() for r in dp_ranks]

        data_keys = batch[0].keys()

        # Create the send plan
        combined_sample_id_groups: List[List[int]] = [[] for _ in range(self.total_hdp_gpus)]

        for d in range(self.total_hdp_gpus):
            for sample_id_group in sample_id_groups:
                combined_sample_id_groups[d].extend(sample_id_group[d])

        for dest_rank in range(self.total_hdp_gpus):
            combined_sample_id_groups[dest_rank].sort()

        # Filter out samples that are not present on this rank
        send_ids_sorted = [
            gid
            for d in dp_ranks
            for gid in combined_sample_id_groups[d]
            if gid in global_ids_this_rank
        ]
        # send_counts = [len(combined_sample_id_groups[d]) for d in range(self.total_hdp_gpus)]

        send_lens_split = [0] * self.total_hdp_gpus
        for dest_rank in range(self.total_hdp_gpus):
            if dest_rank in dp_ranks:
                send_lens_split[dest_rank] = sum(
                    [
                        global_id_seqlens[gid][1]
                        for gid in combined_sample_id_groups[dest_rank]
                        if gid in global_ids_this_rank
                    ]
                )
            else:
                # We only need to share local data with DP ranks that have different data.
                send_lens_split[dest_rank] = 0

        # Create the recv plan
        recv_sample_id_groups = [[] for _ in range(self.total_hdp_gpus)]
        for gid in combined_sample_id_groups[hdp_rank]:
            src_rank = self._gid_to_src_rank(gid, offsets)
            recv_sample_id_groups[src_rank].append(gid)

        recv_lens_split = [0] * self.total_hdp_gpus
        for src_rank in range(self.total_hdp_gpus):
            recv_lens_split[src_rank] = sum(
                [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]]
            )

        recv_ids_sorted = [
            gid for d in range(self.total_hdp_gpus) for gid in recv_sample_id_groups[d]
        ]
        recv_counts = [len(recv_sample_id_groups[d]) for d in range(self.total_hdp_gpus)]

        recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))]

        def _pack_sample_by_key(key: str) -> torch.Tensor:
            flattened_tensors = []
            for gid in send_ids_sorted:
                t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True)
                flattened_tensors.append(t)
            return (
                torch.cat(flattened_tensors, dim=0)
                if flattened_tensors
                else torch.empty(0, device=torch.cuda.current_device(), dtype=batch[0][key].dtype)
            )

        def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor):
            cursor = 0
            for i, gid in enumerate(recv_ids_sorted):
                sample_len = global_id_seqlens[gid][1]
                recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len]
                cursor += sample_len

        for key in data_keys:
            send_tensor = _pack_sample_by_key(key)
            recv_tensor = torch.empty(
                sum(recv_lens_split), device=torch.cuda.current_device(), dtype=send_tensor.dtype
            )
            torch.distributed.all_to_all_single(
                output=recv_tensor,
                input=send_tensor,
                output_split_sizes=recv_lens_split,
                input_split_sizes=send_lens_split,
                group=self.dp_cp_group,
            )
            _unpack_sample_by_key(key, recv_tensor)

        recv_sample_with_id = {
            recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted)
        }
        return recv_sample_with_id
```
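The pack/unpack helpers rely on every rank agreeing on per-destination token counts: input_split_sizes and output_split_sizes tell all_to_all_single how many elements of the flat buffer go to, and come from, each rank. Below is a standalone sketch of the flatten-then-slice pattern (not part of the PR; the collective itself is omitted and the lengths are made up).

```python
import torch

# Three sub-samples of token length 4, 2, and 3, as _pack_sample_by_key would see them.
sub_samples = [torch.arange(4), torch.arange(2), torch.arange(3)]
lens = [t.numel() for t in sub_samples]

# Pack: flatten into one contiguous send buffer (the real code hands this to
# all_to_all_single with per-rank split sizes that sum to these lengths).
send_tensor = torch.cat(sub_samples, dim=0)

# Stand-in for the received buffer; in the real code its layout is described
# by recv_lens_split and recv_ids_sorted.
recv_tensor = send_tensor.clone()

# Unpack: slice sub-samples back out by length, as _unpack_sample_by_key does.
cursor, recovered = 0, []
for sample_len in lens:
    recovered.append(recv_tensor[cursor:cursor + sample_len])
    cursor += sample_len
print([t.tolist() for t in recovered])  # [[0, 1, 2, 3], [0, 1], [0, 1, 2]]
```

Note also that, because all CP ranks within a DP group hold the same data, send_lens_split is zeroed for every destination outside this rank's DP group, so tokens only move between matching CP positions. Before any rerouting happens, each packed sample first has to be split into its sub-samples: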
```python
    def unpack_batch(self, batch):
        """
        Unpacks the packed samples into a list of sub-samples.
        Since each sub-sample may be routed to different DPxCP ranks,
        we unpack the sample here to avoid unnecessarily transferring
        the entire packed sample.
        """
        batch_unpacked = []
        for sample in batch:
            for sub_sample in range(sample["cu_seqlens"].shape[0] - 1):
                sub_sample_dict = {}
                start_idx = sample["cu_seqlens"][sub_sample]
                end_idx = sample["cu_seqlens"][sub_sample + 1]
                if end_idx - start_idx == 0:
                    continue
                for key in sample.keys():
                    if key in ["cu_seqlens", "batch_idx", "max_seqlen"]:
                        continue
                    sub_sample_dict[key] = sample[key][start_idx:end_idx]
                batch_unpacked.append(sub_sample_dict)
        return batch_unpacked
```
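A standalone sketch (not part of the PR) of the cu_seqlens splitting that unpack_batch performs; the keys and lengths are invented, and metadata keys such as cu_seqlens itself are dropped from the sub-samples just as in the method above.

```python
import torch

sample = {
    "tokens": torch.arange(9),
    "cu_seqlens": torch.tensor([0, 4, 4, 9]),  # the second sub-sample is empty and gets skipped
}

sub_samples = []
for i in range(sample["cu_seqlens"].shape[0] - 1):
    start, end = sample["cu_seqlens"][i], sample["cu_seqlens"][i + 1]
    if end - start == 0:
        continue
    sub_samples.append({"tokens": sample["tokens"][start:end]})

print([s["tokens"].tolist() for s in sub_samples])  # [[0, 1, 2, 3], [4, 5, 6, 7, 8]]
```

Finally, __next__ ties all of the pieces together for every global batch: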
```python
    def __next__(self) -> Any:
        """
        Get the next item from the dataset, pull scheduling metadata and return it.
        """
        if self.data_iterator is None:
            # TP0 reads from data_iterator, others receive via broadcast.
            return None, None
        else:
            batch = next(self.data_iterator)
            subsample_seqlens = []
            for sample in batch:
                subsample_seqlens.extend(
                    [
                        int(sample["cu_seqlens"][i + 1] - sample["cu_seqlens"][i])
                        for i in range(0, sample["cu_seqlens"].shape[0] - 1)
                    ]
                )
            subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32).cuda()
            subsample_seqlens = subsample_seqlens[subsample_seqlens != 0]

            seqlens_gathered, offsets = self.get_global_seqlens(subsample_seqlens)

            global_id_seqlens, global_ids_this_rank = self.get_global_id_seqlens(
                subsample_seqlens.shape[0], offsets, seqlens_gathered
            )

            groups, sample_id_groups = self.cp_balancing_scheduler.get_groups_and_subsamples(
                global_id_seqlens, self.config
            )

            batch = self.unpack_batch(batch)
            samples_this_rank_with_id = self.reroute_samples_to_hdp_ranks(
                batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets
            )
            return samples_this_rank_with_id, sample_id_groups
```
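To close, a hypothetical usage sketch, not taken from the PR: the iterator, config, and process-group objects stand in for whatever the training script already builds, and only HybridCPDataLoaderWrapper and its (samples, sample_id_groups) return contract come from the code above.

```python
# All names on the right-hand side of these arguments are placeholders.
hybrid_iterator = HybridCPDataLoaderWrapper(
    data_iterator=packed_sample_iterator,  # yields lists of packed samples carrying "cu_seqlens"
    config=config,                         # must define max_seqlen_per_dp_cp_rank
    pg_collection=pg_collection,           # or None to fall back to parallel_state groups
)

for _ in range(num_iterations):
    samples_with_id, sample_id_groups = next(hybrid_iterator)
    # samples_with_id: dict mapping global sub-sample IDs to the sub-samples
    #   this DPxCP rank was assigned.
    # sample_id_groups: the schedule produced by BalancedCPScheduler, which
    #   downstream code uses to form per-sub-sample context-parallel groups.
```

On ranks whose data_iterator is None, the wrapper simply returns None, None, matching the convention noted in __next__ that only TP rank 0 reads data and the rest of the tensor-parallel group receives it via broadcast.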