diff --git a/megatron/core/datasets/data_schedule.py b/megatron/core/datasets/data_schedule.py new file mode 100644 index 0000000000..0f016473b6 --- /dev/null +++ b/megatron/core/datasets/data_schedule.py @@ -0,0 +1,301 @@ +# Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. + +from typing import Any, List, Optional + +import torch + +from megatron.core import parallel_state +from megatron.core.pipeline_parallel.hybrid_cp_schedule import BalancedCPScheduler +from megatron.core.process_groups_config import ProcessGroupCollection + + +class HybridCPDataLoaderWrapper: + """ + A wrapper class that wraps around an existing data_iterator. + For every __next__ call, + 1. Each DP rank pulls a batch of packed samples. + 2. Extracts the sequence lengths of each sub-sample and all-gathers across the DP group. + 3. Schedules the sub-samples to the DPxCP ranks using the BalancedCPScheduler. + 4. Based on the schedule, reroutes the sub-samples to the correct rank using all-to-all. + 5. Returns the assigned sub-samples to this rank. + + Args: + data_iterator: The original data_iterator to wrap around + config: The config object containing the max_seqlen_per_dp_cp_rank + dp_cp_group: Data parallel context parallel group. + """ + + def __init__( + self, data_iterator, config, pg_collection: Optional[ProcessGroupCollection] = None + ): + self.data_iterator = data_iterator + self.config = config + if pg_collection is None: + self.dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) + self.dp_group = parallel_state.get_data_parallel_group() + self.tp_group = parallel_state.get_tensor_model_parallel_group() + else: + self.dp_cp_group = pg_collection.dp_cp + self.dp_group = pg_collection.dp + self.tp_group = pg_collection.tp + assert ( + self.dp_cp_group is not None and self.dp_group is not None and self.tp_group is not None + ), "dp_cp_group, dp_group, tp_group must not be None when using hybrid context parallel" + + self.cp_balancing_scheduler = BalancedCPScheduler( + max_seq_len_per_rank=self.config.max_seqlen_per_dp_cp_rank, dp_cp_group=self.dp_cp_group + ) + + self.total_hdp_gpus = self.dp_cp_group.size() + + def __iter__(self): + """Return self as an iterator.""" + return self + + def get_global_seqlens(self, subsample_seqlens: torch.Tensor) -> List[int]: + """ + Gathers the sequence lengths of all subsamples from all DP ranks. + Each DP rank loads the same number of microbatches but each microbatch + may have a different number of subsamples. + + We find the number of subsamples each rank holds and then gather the + sequence lengths of all subsamples from all ranks. 
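+
+        Illustrative example (assuming DP=2): if rank 0 holds sub-sample lengths
+        [4096, 2048] and rank 1 holds [8192, 1024, 512], every rank ends up with
+        seqlens_gathered = [4096, 2048, 8192, 1024, 512] and offsets = [0, 2],
+        so rank 1's first sub-sample later receives global ID 2.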
+ """ + # Collect the number of subsamples from all ranks + local_len = torch.tensor([subsample_seqlens.shape[0]], dtype=torch.int32).cuda() + dp_subsample_count = [torch.zeros_like(local_len) for _ in range(self.dp_group.size())] + torch.distributed.all_gather(dp_subsample_count, local_len, group=self.dp_group) + + # Find the max number of subsamples across all ranks and pad subsample_seqlens to max length + dp_subsample_counts = torch.stack(dp_subsample_count, dim=0).cpu().view(-1) + max_sub_samples = int(dp_subsample_counts.max().item()) + + if local_len.item() < max_sub_samples: + subsample_seqlens_padded = torch.cat( + [ + subsample_seqlens, + torch.zeros(max_sub_samples - local_len.item(), dtype=torch.int32).cuda(), + ], + dim=0, + ) + else: + subsample_seqlens_padded = subsample_seqlens + + # Gather the subsample_seqlens from all ranks + seqlens_gathered = [ + torch.empty_like(subsample_seqlens_padded) for _ in range(self.dp_group.size()) + ] + torch.distributed.all_gather( + seqlens_gathered, subsample_seqlens_padded, group=self.dp_group + ) + + # Trim each seqlens_gathered to the length of the correct sample + for dp_rank, seqlen in enumerate(seqlens_gathered): + seqlens_gathered[dp_rank] = seqlen[: dp_subsample_counts[dp_rank]] + + seqlens_gathered = torch.cat(seqlens_gathered, dim=0) + seqlens_gathered = seqlens_gathered.cpu().tolist() + + # Calculate the offsets to assign unique global ID to each subsample. + csum = torch.cumsum(dp_subsample_counts, dim=0, dtype=torch.int32) + offsets = torch.cat([torch.zeros(1, dtype=torch.int32), csum[:-1]], dim=0) + + return seqlens_gathered, offsets + + def get_global_id_seqlens(self, num_local_subsamples, offsets, seqlens_gathered): + """ + Calculates the global ID for each subsample. + + We assign a unique global ID to each subsample. + + Returns: + global_id_seqlens: list of (global_id, seqlen) tuples for scheduling. + global_ids_this_rank: list of global IDs locally present on this rank. + """ + dp_rank = self.dp_group.rank() + global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda() + # Create a list of (global_id, seqlen) tuples for scheduling + global_id_seqlens = [(i, seqlens_gathered[i]) for i in range(len(global_ids))] + # Get the global IDs locally present on this rank + global_ids_this_rank = global_ids[ + offsets[dp_rank] : offsets[dp_rank] + num_local_subsamples + ] + + return global_id_seqlens, global_ids_this_rank + + def _gid_to_src_rank(self, gid: int, offsets: List[int]) -> int: + dp_src_rank = torch.bucketize(gid, offsets[1:] - 1) + # Since the torch.distributed.get_process_group_ranks + # provides the global rank, we need to consider TP + hdp_rank = ( + torch.distributed.get_process_group_ranks(self.dp_group)[dp_src_rank] + // self.tp_group.size() + ) + return hdp_rank + + def reroute_samples_to_hdp_ranks( + self, batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets + ): + """ + Reroutes the sub-samples to the correct rank after scheduling. + + For each key in the batch dict, we perform an all-to-all communication + to transfer the data to the correct ranks. + Since all CP ranks within a DP group have the same data, we only need + to transfer data between matching CP ranks. 
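+
+        Illustrative example: if this rank locally holds global IDs {0, 1} with
+        lengths 4096 and 2048 and the schedule assigns ID 1 to HDP rank 3 (a rank
+        in this rank's DP group, i.e. the same CP position), then send_lens_split[3]
+        counts 2048 tokens and rank 3 derives the matching recv_lens_split entry
+        from _gid_to_src_rank, so the all-to-all split sizes agree on both sides.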
+ """ + gid2local_id = {int(gid): i for i, gid in enumerate(global_ids_this_rank)} + hdp_rank = self.dp_cp_group.rank() + dp_ranks = torch.distributed.get_process_group_ranks(self.dp_group) + # Here we actually want to get the DP group's rank within the HDP group, + # we need to consider TP + dp_ranks = [r // self.tp_group.size() for r in dp_ranks] + + data_keys = batch[0].keys() + + # Create the send plan + combined_sample_id_groups: List[List[int]] = [[] for _ in range(self.total_hdp_gpus)] + + for d in range(self.total_hdp_gpus): + for sample_id_group in sample_id_groups: + combined_sample_id_groups[d].extend(sample_id_group[d]) + + for dest_rank in range(self.total_hdp_gpus): + combined_sample_id_groups[dest_rank].sort() + + # Filter out samples that are not present on this rank + send_ids_sorted = [ + gid + for d in dp_ranks + for gid in combined_sample_id_groups[d] + if gid in global_ids_this_rank + ] + # send_counts = [len(combined_sample_id_groups[d]) for d in range(self.total_hdp_gpus)] + + send_lens_split = [0] * self.total_hdp_gpus + for dest_rank in range(self.total_hdp_gpus): + if dest_rank in dp_ranks: + send_lens_split[dest_rank] = sum( + [ + global_id_seqlens[gid][1] + for gid in combined_sample_id_groups[dest_rank] + if gid in global_ids_this_rank + ] + ) + else: + # We only need to share local data with DP ranks that have different data. + send_lens_split[dest_rank] = 0 + + # Create the recv plan + recv_sample_id_groups = [[] for _ in range(self.total_hdp_gpus)] + for gid in combined_sample_id_groups[hdp_rank]: + src_rank = self._gid_to_src_rank(gid, offsets) + recv_sample_id_groups[src_rank].append(gid) + + recv_lens_split = [0] * self.total_hdp_gpus + for src_rank in range(self.total_hdp_gpus): + recv_lens_split[src_rank] = sum( + [global_id_seqlens[gid][1] for gid in recv_sample_id_groups[src_rank]] + ) + + recv_ids_sorted = [ + gid for d in range(self.total_hdp_gpus) for gid in recv_sample_id_groups[d] + ] + recv_counts = [len(recv_sample_id_groups[d]) for d in range(self.total_hdp_gpus)] + + recv_samples = [{k: None for k in data_keys} for _ in range(sum(recv_counts))] + + def _pack_sample_by_key(key: str) -> torch.Tensor: + flattened_tensors = [] + for gid in send_ids_sorted: + t = batch[gid2local_id[gid]][key].to(torch.cuda.current_device(), non_blocking=True) + flattened_tensors.append(t) + return ( + torch.cat(flattened_tensors, dim=0) + if flattened_tensors + else torch.empty(0, device=torch.cuda.current_device(), dtype=batch[0][key].dtype) + ) + + def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor): + cursor = 0 + for i, gid in enumerate(recv_ids_sorted): + sample_len = global_id_seqlens[gid][1] + recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len] + cursor += sample_len + + for key in data_keys: + send_tensor = _pack_sample_by_key(key) + recv_tensor = torch.empty( + sum(recv_lens_split), device=torch.cuda.current_device(), dtype=send_tensor.dtype + ) + torch.distributed.all_to_all_single( + output=recv_tensor, + input=send_tensor, + output_split_sizes=recv_lens_split, + input_split_sizes=send_lens_split, + group=self.dp_cp_group, + ) + _unpack_sample_by_key(key, recv_tensor) + + recv_sample_with_id = { + recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted) + } + return recv_sample_with_id + + def unpack_batch(self, batch): + """ + Unpacks the packed samples into a list of sub-samples. 
+ Since each sub-sample may be routed to different DPxCP ranks, + we unpack the sample here to avoid unnecessarily transferring + the entire packed sample. + """ + batch_unpacked = [] + for sample in batch: + for sub_sample in range(sample["cu_seqlens"].shape[0] - 1): + sub_sample_dict = {} + start_idx = sample["cu_seqlens"][sub_sample] + end_idx = sample["cu_seqlens"][sub_sample + 1] + if end_idx - start_idx == 0: + continue + for key in sample.keys(): + if key in ["cu_seqlens", "batch_idx", "max_seqlen"]: + continue + sub_sample_dict[key] = sample[key][start_idx:end_idx] + batch_unpacked.append(sub_sample_dict) + return batch_unpacked + + def __next__(self) -> Any: + """ + Get the next item from the dataset, pull scheduling metadata and return it. + """ + if self.data_iterator is None: + # TP0 reads from data_iterator, others receive via broadcast. + return None, None + else: + batch = next(self.data_iterator) + subsample_seqlens = [] + for sample in batch: + subsample_seqlens.extend( + [ + int(sample["cu_seqlens"][i + 1] - sample["cu_seqlens"][i]) + for i in range(0, sample["cu_seqlens"].shape[0] - 1) + ] + ) + subsample_seqlens = torch.tensor(subsample_seqlens, dtype=torch.int32).cuda() + subsample_seqlens = subsample_seqlens[subsample_seqlens != 0] + + seqlens_gathered, offsets = self.get_global_seqlens(subsample_seqlens) + + global_id_seqlens, global_ids_this_rank = self.get_global_id_seqlens( + subsample_seqlens.shape[0], offsets, seqlens_gathered + ) + + groups, sample_id_groups = self.cp_balancing_scheduler.get_groups_and_subsamples( + global_id_seqlens, self.config + ) + + batch = self.unpack_batch(batch) + samples_this_rank_with_id = self.reroute_samples_to_hdp_ranks( + batch, global_ids_this_rank, global_id_seqlens, sample_id_groups, offsets + ) + return samples_this_rank_with_id, sample_id_groups diff --git a/megatron/core/datasets/gpt_dataset.py b/megatron/core/datasets/gpt_dataset.py index 710a4c684f..f50a6a77f5 100644 --- a/megatron/core/datasets/gpt_dataset.py +++ b/megatron/core/datasets/gpt_dataset.py @@ -49,6 +49,24 @@ class GPTDatasetConfig(BlendedMegatronDatasetConfig): object_storage_cache_path: Optional[str] = None """Path for caching indices for s3 or msc dataloading.""" + context_parallel_size: int = 1 + """Option to enable context parallelism""" + + data_parallel_size: int = 1 + """Option to enable data parallelism""" + + sequence_parallel_size: int = 0 + """Option to indicate the sequence parallelism size when using TP + Set to 0 if sequence parallel is not enabled regardless of TP size. + """ + + hybrid_context_parallel: bool = False + """Option to enable hybrid context parallelism. When setting this to True, + each sample should be divisible by the data parallel size * context parallel size * 2. + If sequence parallel is enabled, it should be divisible by the + data parallel size * context parallel size * sequence parallel size * 2. 
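+    For example, with data parallel size 2 and context parallel size 4 every
+    sub-sample length must be a multiple of 2 * 4 * 2 = 16; with sequence
+    parallel size 2 on top of that, a multiple of 32.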
+ """ + def __post_init__(self) -> None: """Do asserts and set fields post init""" super().__post_init__() diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 85732c0f7e..80d38d61bf 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -1007,6 +1007,7 @@ def __init__( self.kept_packed_seq_params = set( field.name for field in dataclasses.fields(PackedSeqParams) ) + if get_te_version() < PkgVersion("1.3.0"): # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H # copies (#555) @@ -1057,6 +1058,25 @@ def forward( packed_seq_params: PackedSeqParams = None, ): """Forward.""" + if packed_seq_params is not None: + # If Dynamic CP group is provided, update TE DPA CP group + if packed_seq_params.cp_group is not None: + self.cp_group = packed_seq_params.cp_group + super().set_context_parallel_group( + self.cp_group, + torch.distributed.get_process_group_ranks(self.cp_group), + TEDotProductAttention.cp_stream, + self.cp_comm_type, + ) + # If cp_group is None but local_cp_size is provided, + # Indicates to turn off CP dynamically + elif packed_seq_params.local_cp_size is not None: + assert ( + packed_seq_params.local_cp_size == 1 + ), "local_cp_size must be == 1 if provided without cp_group" + super().set_context_parallel_group(None, None, None, self.cp_comm_type) + self.kept_packed_seq_params.discard("cp_group") + self.kept_packed_seq_params.discard("local_cp_size") packed_seq_kwargs = ( {key: getattr(packed_seq_params, key) for key in self.kept_packed_seq_params} if packed_seq_params is not None diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index e31fcd2577..c799f16e30 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -53,6 +53,22 @@ class ModelParallelConfig: type. """ + max_seqlen_per_dp_cp_rank: Optional[int] = None + """ + Maximum sequence length per DPxCP rank. This is the maximum sequence length each rank + can handle without overflowing the memory. Typically, a good starting point is to set this + to maximum sequence length / context parallel size. + This is used to calculate the number and length of sub-samples assigned to + each rank when using hybrid_context_parallel. + """ + + hybrid_context_parallel: bool = False + """ + If true, enables hybrid context parallel. This is used to balance the workload of + each CP rank when we use packed samples with variable sequence lengths. + Please set max_seqlen_per_dp_cp_rank when using hybrid_context_parallel. + """ + expert_model_parallel_size: int = 1 """Distributes Moe Experts across sub data parallel dimension.""" diff --git a/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/core/models/common/embeddings/rotary_pos_embedding.py index 0d7d5e626d..c7c452d2f8 100644 --- a/megatron/core/models/common/embeddings/rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/rotary_pos_embedding.py @@ -148,13 +148,12 @@ def get_cos_sin(self, max_seq_len: int, offset: int = 0) -> (Tensor, Tensor): return cos, sin @lru_cache(maxsize=32) - def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor: - """Forward pass of RoPE embedding. + def get_emb(self, max_seq_len: int, offset: int = 0) -> Tensor: + """Forward pass of RoPE embedding before CP sharding. Args: max_seq_len (int): Maximum size of sequence offset (int, optional): RoPE offset. 
Defaults to 0. - packed_seq (bool, optional): Whether to use packed sequence. Defaults to False. Returns: Tensor: Embeddings after applying RoPE. @@ -174,10 +173,34 @@ def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) - ) # emb [seq_length, .., dim] emb = emb[:, None, None, :] - if self.cp_group is not None and self.cp_group.size() > 1 and not packed_seq: - # slice rotary_pos_emb along sequence dimension and select the parition of the current - # CP rank - emb = get_pos_emb_on_this_cp_rank(emb, 0, self.cp_group) + return emb + + def forward( + self, max_seq_len: int, offset: int = 0, packed_seq_params: Optional[PackedSeqParams] = None + ) -> Tensor: + """Forward pass of RoPE embedding. + + Args: + max_seq_len (int): Maximum size of sequence + offset (int, optional): RoPE offset. Defaults to 0. + packed_seq_params (PackedSeqParams, optional): Packed sequence params. Defaults to None. + + Returns: + Tensor: Embeddings after applying RoPE. + """ + emb = self.get_emb(max_seq_len, offset) + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + if packed_seq_params is not None and packed_seq_params.local_cp_size is not None: + # Set CP group to dynamic CP group for CP slicing + cp_group = packed_seq_params.cp_group + else: + cp_group = self.cp_group + + if cp_group is not None and cp_group.size() > 1 and not packed_seq: + # slice rotary_pos_emb along sequence dimension + # and select the parition of the current CP rank + emb = get_pos_emb_on_this_cp_rank(emb, 0, cp_group) + return emb def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): @@ -279,13 +302,19 @@ def __init__( else parallel_state.get_context_parallel_group(check_initialized=False) ) - def forward(self, position_ids: torch.Tensor, mrope_section: List[int]) -> Tensor: + def forward( + self, + position_ids: torch.Tensor, + mrope_section: List[int], + packed_seq_params: Optional[PackedSeqParams] = None, + ) -> Tensor: """Forward pass of multimodal RoPE embedding. Args: position_ids (torch.Tensor): A postion_id tensor with shape [3, batchsize, seqlens] mrope_section (list[int]): Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + packed_seq_params (PackedSeqParams, optional): Packed sequence params. Defaults to None. Returns: Tensor: Embeddings after applying RoPE. 
@@ -318,8 +347,17 @@ def forward(self, position_ids: torch.Tensor, mrope_section: List[int]) -> Tenso # shape (seq_length, bs, 1, 2 * dim) emb = emb[..., None, :].transpose(0, 1).contiguous() - if self.cp_group is not None and self.cp_group.size() > 1: + if packed_seq_params is not None and packed_seq_params.local_cp_size is not None: + if packed_seq_params.local_cp_size > 1: + # Set CP group to dynamic CP group for CP slicing + cp_group = packed_seq_params.cp_group + else: + # Set CP group to None to avoid CP slicing + cp_group = None + else: + cp_group = self.cp_group + if cp_group is not None and cp_group.size() > 1: # slice rotary_pos_emb along sequence dimension and select the parition of the current # CP rank - emb = get_pos_emb_on_this_cp_rank(emb, 0, self.cp_group) + emb = get_pos_emb_on_this_cp_rank(emb, 0, cp_group) return emb diff --git a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py index bcbb74b0df..4e45de72c7 100644 --- a/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py +++ b/megatron/core/models/common/embeddings/yarn_rotary_pos_embedding.py @@ -99,13 +99,12 @@ def __init__( ) @lru_cache(maxsize=32) - def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor: + def get_emb(self, max_seq_len: int, offset: int = 0) -> Tensor: """Forward pass of Yarn Rotary Embedding. Args: max_seq_len (int): Maximum size of sequence offset (int, optional): RoPE offset. Defaults to 0. - packed_seq (bool, optional): Whether to use packed sequence. Defaults to False. Returns: Tensor: Embeddings after applying Yarn RoPE. @@ -151,19 +150,43 @@ def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) - emb = torch.cat((freqs, freqs), dim=-1) # emb [seq_length, .., dim] emb = emb[:, None, None, :] - if self.cp_group is not None and self.cp_group.size() > 1 and not packed_seq: + return emb, _mscale + + def forward( + self, max_seq_len: int, offset: int = 0, packed_seq_params: Optional[PackedSeqParams] = None + ) -> Tensor: + """Forward pass of Yarn Rotary Embedding. + + Args: + max_seq_len (int): Maximum size of sequence + offset (int, optional): RoPE offset. Defaults to 0. + packed_seq_params (PackedSeqParams, optional): Packed sequence params. Defaults to None. + + Returns: + Tensor: Embeddings after applying Yarn RoPE. 
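+
+        Note: when packed_seq_params.local_cp_size is set (hybrid context
+        parallel), the dynamic packed_seq_params.cp_group is used for the
+        sequence-dimension slicing below instead of the static self.cp_group;
+        'thd' packed sequences skip the slicing entirely.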
+ """ + emb, _mscale = self.get_emb(max_seq_len, offset) + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + if packed_seq_params is not None and packed_seq_params.local_cp_size is not None: + # Set CP group to dynamic CP group for CP slicing + cp_group = packed_seq_params.cp_group + else: + cp_group = self.cp_group + if cp_group is not None and cp_group.size() > 1 and not packed_seq: # slice rotary_pos_emb along sequence dimension # and select the parition of the current CP rank - emb = get_pos_emb_on_this_cp_rank(emb, 0, self.cp_group) + emb = get_pos_emb_on_this_cp_rank(emb, 0, cp_group) return emb, _mscale - def _set_cos_sin_cache(self, seq_len, offset, dtype, packed_seq=False): + def _set_cos_sin_cache(self, seq_len, offset, dtype, packed_seq_params=None): self.max_seq_len_cached = seq_len self.offset_cached = offset self.dtype_cached = dtype - self.packed_seq_cached = packed_seq + self.packed_seq_cached = ( + packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + ) - emb, _mscale = self.forward(seq_len, offset, packed_seq) + emb, _mscale = self.forward(seq_len, offset, packed_seq_params) self.register_buffer( "cos_cached", (emb.cos() * _mscale).to(dtype).contiguous(), persistent=False ) @@ -172,16 +195,17 @@ def _set_cos_sin_cache(self, seq_len, offset, dtype, packed_seq=False): ) def get_cached_cos_sin( - self, seq_len, offset=0, dtype=torch.get_default_dtype(), packed_seq=False + self, seq_len, offset=0, dtype=torch.get_default_dtype(), packed_seq_params=None ): """Get cached cos and sin values.""" + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' if ( seq_len > self.max_seq_len_cached or offset != self.offset_cached or dtype != self.dtype_cached or packed_seq != self.packed_seq_cached ): - self._set_cos_sin_cache(seq_len, offset, dtype, packed_seq) + self._set_cos_sin_cache(seq_len, offset, dtype, packed_seq_params) return (self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index e840fca99b..9491196c9a 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -344,16 +344,16 @@ def _preprocess( inference_context, self.decoder, decoder_input, self.config, packed_seq_params ) rotary_pos_emb = self.rotary_pos_emb( - rotary_seq_len, - packed_seq=packed_seq_params is not None - and packed_seq_params.qkv_format == 'thd', + rotary_seq_len, packed_seq_params=packed_seq_params ) elif self.position_embedding_type == 'yarn': if self.training or not self.config.flash_decode: rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( inference_context, self.decoder, decoder_input, self.config, packed_seq_params ) - rotary_pos_emb, _ = self.rotary_pos_emb(rotary_seq_len) + rotary_pos_emb, _ = self.rotary_pos_emb( + rotary_seq_len, packed_seq_params=packed_seq_params + ) else: raise NotImplementedError( "Flash decoding uses precomputed cos and sin for RoPE, not implemented in " @@ -361,7 +361,9 @@ def _preprocess( ) elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention: if self.training or not self.config.flash_decode: - rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section) + rotary_pos_emb = self.rotary_pos_emb( + position_ids, self.mrope_section, packed_seq_params=packed_seq_params + ) else: # Flash decoding uses precomputed cos and sin for RoPE raise NotImplementedError( diff --git 
a/megatron/core/packed_seq_params.py b/megatron/core/packed_seq_params.py index 330d0e0347..08ebdac67d 100644 --- a/megatron/core/packed_seq_params.py +++ b/megatron/core/packed_seq_params.py @@ -1,6 +1,7 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from dataclasses import dataclass +import torch.distributed as dist from torch import Tensor @@ -18,3 +19,5 @@ class PackedSeqParams: cu_seqlens_kv_padded: Tensor = None max_seqlen_q: int = None max_seqlen_kv: int = None + local_cp_size: int = None + cp_group: dist.ProcessGroup = None diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index 1e41bf9d8c..ab3082d4dd 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -6,6 +6,7 @@ import os import warnings from datetime import timedelta +from math import log2 from typing import Callable, List, Optional import numpy as np @@ -110,6 +111,8 @@ _CONTEXT_PARALLEL_GLOBAL_RANKS = None # Hierarchical context parallel groups _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS = None +# Hybrid context parallel groups +_HYBRID_DP_CP_GROUPS = {} # Data parallel group information with context parallel combined. _DATA_PARALLEL_GROUP_WITH_CP = None @@ -410,6 +413,31 @@ def create_hierarchical_groups( return hierarchical_groups, hierarchical_groups_gloo +def create_hybrid_dp_cp_groups(rank, ranks, pg_options): + """ + Creates groups required for hybrid DPxCP. + Creates a new group for every power of 2 up to the number of DPxCP ranks. + Returns a dictionary indexed by group size. + """ + hybrid_dp_cp_groups = {} + # Generate group for every power of 2 up to the number of CP ranks + # We limit the allowed group sizes in order to avoid excessive overhead. + group_sizes = [2**i for i in range(int(log2(len(ranks))))][1:] + for group_size in group_sizes: + for i in range(0, len(ranks), group_size): + group = create_group( + ranks[i : i + group_size], + pg_options=pg_options, + group_desc=f"HYBRID_DP_CP_GROUP_{group_size}", + ) + if rank in ranks[i : i + group_size]: + assert ( + group_size not in hybrid_dp_cp_groups + ), f"Rank {rank} appears in multiple Hybrid DP CP groups of size {group_size}" + hybrid_dp_cp_groups[group_size] = group + return hybrid_dp_cp_groups + + class RankGenerator(object): """A class for generating rank groups for different modes of parallelism.""" @@ -519,6 +547,7 @@ def initialize_model_parallel( use_sharp: bool = False, context_parallel_size: int = 1, hierarchical_context_parallel_sizes: Optional[List[int]] = None, + hybrid_context_parallel: bool = False, expert_model_parallel_size: int = 1, num_distributed_optimizer_instances: int = 1, expert_tensor_parallel_size: Optional[int] = None, @@ -881,6 +910,19 @@ def initialize_model_parallel( if "NCCL_COLLNET_ENABLE" in os.environ: del os.environ["NCCL_COLLNET_ENABLE"] + if hybrid_context_parallel: + global _HYBRID_DP_CP_GROUPS + for ranks_with_cp in decoder_rank_generator.get_ranks('dp-cp'): + assert ( + len(ranks_with_cp) % 2 == 0 + ), "Hybrid context parallel requires an even number of ranks" + _HYBRID_DP_CP_GROUPS.update( + create_hybrid_dp_cp_groups( + rank, ranks_with_cp, get_nccl_options("dp_cp", nccl_comm_cfgs) + ) + ) + # TODO: Are gloo groups needed for hybrid cp? 
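+        # Example (assuming 8 DPxCP ranks 0..7): group_sizes is [2, 4], so the
+        # sub-groups created are {0,1},{2,3},{4,5},{6,7} and {0,1,2,3},{4,5,6,7}.
+        # Size 1 needs no communicator, and the full size-8 case reuses the
+        # existing DPxCP group (see get_hybrid_data_context_parallel_groups).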
+ for ranks in decoder_rank_generator.get_ranks('dp'): group = create_group( ranks, @@ -1395,6 +1437,18 @@ def get_hierarchical_context_parallel_groups(check_initialized=True): return _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS +def get_hybrid_data_context_parallel_groups(check_initialized=True, group_size=None): + """Get the hybrid context parallel groups the caller rank belongs to.""" + # If the group size is the same as the entire DPxCP group, return the original group + if get_data_parallel_world_size(with_context_parallel=True) == group_size: + if check_initialized: + assert _DATA_PARALLEL_GROUP_WITH_CP is not None + return _DATA_PARALLEL_GROUP_WITH_CP + if check_initialized: + assert _HYBRID_DP_CP_GROUPS is not None + return _HYBRID_DP_CP_GROUPS[group_size] + + def get_embedding_group(check_initialized=True): """Get the embedding group the caller rank belongs to.""" if check_initialized: diff --git a/megatron/core/pipeline_parallel/hybrid_cp_schedule.py b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py new file mode 100644 index 0000000000..27b5fc8794 --- /dev/null +++ b/megatron/core/pipeline_parallel/hybrid_cp_schedule.py @@ -0,0 +1,660 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +from collections import deque +from functools import lru_cache +from math import ceil, log2 +from typing import Callable, List, Optional, Tuple + +import torch + +from megatron.core import parallel_state +from megatron.core.rerun_state_machine import RerunDataIterator + + +class BalancedCPScheduler: + """ + This class provides the functionality to form groups of sub-samples + such that all DPxCP ranks have a roughly balanced workload in the group. + """ + + def __init__(self, max_seq_len_per_rank: int, dp_cp_group: torch.distributed.ProcessGroup): + self.max_seq_len_per_rank = max_seq_len_per_rank + self.num_subsamples = 0 + self.num_subsamples_processed = 0 + self.free_resources = [] + self.total_hdp_gpus = dp_cp_group.size() + + @lru_cache(maxsize=128) + def get_total_workload(self, seq_length: int, cp_size: Optional[int] = None): + """ + seq_length: sequence length of a sub-sample + cp_size: total number of CP ranks working on this sub-sample + + Note: + This function is used to estimate the relative workload intensity + of a sub-sample. This is not meant to be an accurate flops calculator. + + Returns: workload of a sub-sample + """ + if cp_size is None: + cp_size = self.gpus_needed(seq_length) + return (seq_length * seq_length) / cp_size + + @lru_cache(maxsize=128) + def gpus_needed(self, seq_len: int) -> int: + """ + Calculates the number of GPUs needed for a given sequence length + and max sequence length per CP rank. + This is used to determine the CP size of a sub-sample. + + The number is rounded up to the next power of 2 to match the available + hybrid context parallel process group sizes. + """ + return max(1, 2 ** ceil(log2((seq_len / self.max_seq_len_per_rank)))) + + def make_buckets_equal( + self, + sample_seqlens: List[Tuple[int, int]], # List of (sample_id, sequence_length) tuples + compute_estimator: Callable[[int], float], + ) -> List[deque]: + """ + Makes as many buckets as unique CP sizes needed. + This keeps sample IDs tethered to their sequence lengths throughout the bucketing process. 
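+
+        Illustrative example (max_seq_len_per_rank=8192): lengths
+        [16384, 8192, 4096] need [2, 1, 1] GPUs respectively, so there are two
+        unique CP sizes and k = 2 buckets are formed, each targeting roughly
+        half of the total estimated work.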
+ """ + # Extract just the sequence lengths for determining k + seqlens = [seq_len for _, seq_len in sample_seqlens] + + # Determine k based on unique GPU categories needed + k = len({self.gpus_needed(L) for L in seqlens}) + + # Create a work target for each bucket + # This is the total work divided by the number of buckets + work = [] + for _, s in sample_seqlens: + cp_size = self.gpus_needed(s) + work.append(compute_estimator(s, cp_size)) + total_work = sum(work) + target = total_work / k + buckets, cur, cur_work = [], [], 0.0 + remaining_work = total_work + remaining_k = k + + for i, (sample_id, seq_len) in enumerate(sample_seqlens): + work = compute_estimator(seq_len) + projected = cur_work + work + + # Check if we should close this bucket + if cur and ( + projected > target * 1.1 # Too much work + or len(sample_seqlens) - i <= remaining_k - len(buckets) + ): # Need to save sequences for remaining buckets + buckets.append(deque(cur)) + cur, cur_work = [], 0.0 + remaining_work -= sum(compute_estimator(seq_len) for _, seq_len in cur) + remaining_k -= 1 + + cur.append((sample_id, seq_len)) + cur_work += work + + if cur: + buckets.append(deque(cur)) + + return buckets + + def next_hdp_group( + self, + sample_seqlens: List[Tuple[int, int]], # List of (sample_id, sequence_length) tuples + compute_estimator: Callable[[int], float], + total_gpus: int, + delta: float = 0.05, # balance slack (e.g. 5 %) + strategy: str = "dp", # "dp" or "pp" + eps_bucket: float = 0.10, # ε target for bucket balance + ) -> Tuple[List[List[int]], List[Tuple[int, int]], List[float], List[List[int]]]: + """ + Given a list of (sample_id, sequence_length) tuples, this function aims to assign + sequences in a group such that all GPUs in the DPxCP group have a roughly balanced + workload. Once each group is roughly balanced, we exit and return the + group and the leftover sequences. + + The function performs the following passes in order to form a balanced microbatch: + 1. We create buckets of sequences that are roughly balanced. + We try to create as many buckets as possible CP sizes. + 2. Given a bucket has sequences available, we assign the sample + a. To a new set of GPUs if there are enough free GPUs. + b. To an existing set of GPUs with the lowest load. + 3. We check if the group is balanced whenever we need to move onto a new CP size + in the same set of GPUs. + 4. We trim the group if removing the last added sequence helps improve balance. + 5. If we run out of sequences to assign and there are empty GPUs, + we redistribute work to empty GPUs by recursively increasing the CP size of a + sample until no empty GPUs are left. + + Returns (micro_batches, leftover_sample_seqlens, exec_times, sample_ids_per_gpu). 
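+
+        In the returned values, micro_batches[r] holds the sequence lengths
+        assigned to rank r, exec_times[r] the estimated cost per
+        compute_estimator (roughly seq_len**2 / cp_size per GPU), and
+        sample_ids_per_gpu[r] the matching global sample IDs; a sub-sample that
+        needs CP > 1 appears in the lists of every rank in its group.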
+ """ + if not sample_seqlens: + return ( + [[] for _ in range(total_gpus)], + [], + [0.0 for _ in range(total_gpus)], + [[] for _ in range(total_gpus)], + ) + + # Get buckets of sequences with balanced work + buckets = self.make_buckets_equal(sample_seqlens, compute_estimator) + + # Initialize tracking structures + micro_batches = [[] for _ in range(total_gpus)] + exec_times = [0.0 for _ in range(total_gpus)] + sample_ids_per_gpu = [[] for _ in range(total_gpus)] + + gpu_group_id = [None] * total_gpus + group_members = {} + group_size = {} + next_gid = 0 + + pp_cursor = 0 + prev_needed = None + check_balance = False + + while buckets: + # ---- Step 1 – pick the next sequence we COULD place ------------------ + sample_seq_tuple = bucket_idx = None + needed = None + + scan_order = ( + range(len(buckets)) + if strategy == "dp" + else [(pp_cursor + i) % len(buckets) for i in range(len(buckets))] + ) + + for idx in scan_order: + if not buckets[idx]: + continue + cand_tuple = buckets[idx][0] # This is now (sample_id, seq_len) + cand_seq_len = cand_tuple[1] + needed = self.gpus_needed(cand_seq_len) + + # (a) Do we have an *existing* group of size `needed`? + candidate_gids = [gid for gid, sz in group_size.items() if sz == needed] + + # (b) Or enough completely free GPUs to start a new group? + free_ranks = [r for r, gid in enumerate(gpu_group_id) if gid is None] + if candidate_gids or len(free_ranks) >= needed: + sample_seq_tuple, bucket_idx = cand_tuple, idx + break + + # No place to put any remaining sequence – finish this micro‑batch + if sample_seq_tuple is None: + break + + # TODO[pmannan]: PP not yet supported. Add PP scheduling. + if strategy == "pp": + pp_cursor = (bucket_idx + 1) % len(buckets) + + sample_id, seq_len = sample_seq_tuple + needed = self.gpus_needed(seq_len) + if prev_needed is None: + prev_needed = needed + + # (a) Existing groups of exactly this size + candidate_gids = [gid for gid, sz in group_size.items() if sz == needed] + if candidate_gids: + best_gid, best_load = min( + ( + (gid, max(exec_times[r] for r in group_members[gid])) + for gid in candidate_gids + ), + key=lambda t: t[1], + ) + else: + best_gid, best_load = None, float("inf") + + # (b) Hypothetical **new** group from completely free GPUs + free_ranks = [r for r, gid in enumerate(gpu_group_id) if gid is None] + if len(free_ranks) >= needed: + free_sorted = sorted(free_ranks, key=lambda r: exec_times[r]) + new_members = free_sorted[:needed] + new_load = exec_times[new_members[-1]] + + if new_load < best_load: + best_gid = None + chosen_members = new_members + else: + chosen_members = group_members[best_gid] + else: + chosen_members = group_members[best_gid] + + # ---- Step 2 – if we decided to create a fresh group ---------------- + if best_gid is None: + best_gid = next_gid + next_gid += 1 + group_members[best_gid] = chosen_members + group_size[best_gid] = needed + for r in chosen_members: + gpu_group_id[r] = best_gid + + # ---- Step 3 – assign the sequence to every member of that group ------ + per_gpu_cost = compute_estimator(seq_len) + + for r in chosen_members: + micro_batches[r].append(seq_len) + exec_times[r] += per_gpu_cost + sample_ids_per_gpu[r].append(sample_id) + + # Remove the sequence definitively from its bucket + buckets[bucket_idx].popleft() + + # ---- Step 4 – tidy, balance‑check, maybe early‑exit ------------------ + while buckets and not buckets[0]: + buckets.pop(0) + pp_cursor %= max(1, len(buckets)) + + # TODO: Removing this helps reduce the number of groups when we have + # lots of 
samples with same CP size. + # But because we don't exit as soon as we get balanced, + # even if there is one group available that can take the next sample, + # we will keep adding samples to the same group. + # trim_overload() does not help because it only checks if removing the + # last added sample helps. + # We cannot check after adding every sample because there will always be imbalance + # if we don't wait for future scheduling. + + # IMPORTANT: So we need a solution here + if needed < prev_needed: + # When we get into a lower CP size in the same group, + # we can start checking for balance. There is still a gotcha here. + # Let's say we have a group of 3 GPU 0-2, then we move onto group of 2. + # We keep assigning group of 2 as we do in descending order but GPU 7/15 + # never sees a microbatch assigned to it + # until we run out of samples with CP2. + # This means we are never balanced as min(exec_times) will always be 0. + # We need a smart way of identifying that we have run out of big samples + # and if we are having to assign work to a GPU already working, + # is it because there are empty GPUs? + # Would assigning work to empty GPUs first by moving onto next CP bucket help? + # But we need to remember to come back to this CP size bucket and then + # check for balance. Maybe the scheduling algorithm should look at empty + # GPUs and find work rather than going sequence by sequence. + check_balance = True + + if ( + check_balance + and buckets + and max(exec_times) - min(exec_times) <= delta * max(exec_times) + ): + break + + # Gather leftovers (flatten remaining buckets, preserve order) + leftovers = [] + for b in buckets: + for sample_seq_tuple in b: + leftovers.append(sample_seq_tuple) + + # --------------------------------------------------------------------------- + def trim_overload(): + """ + Iteratively pop the most‑recent sequence from the *most‑loaded group* + whenever doing so reduces the global slack. + """ + while True: + cur_max = max(exec_times) + cur_min = min(exec_times) + cur_slack = cur_max - cur_min + if cur_slack <= delta * cur_max: + # Slack is already within limit. + break + if cur_min == 0: + # There are empty GPUs that will be + # handled in the next step. + break + + max_r = exec_times.index(cur_max) + gid = gpu_group_id[max_r] + members = group_members[gid] + + if not micro_batches[max_r] or len(micro_batches[max_r]) <= 1: + break + + seq = micro_batches[max_r][-1] + need = group_size[gid] + per_gpu_cost = compute_estimator(seq) + + proj_times = exec_times[:] + for r in members: + proj_times[r] -= per_gpu_cost + + proj_slack = max(proj_times) - min(proj_times) + + # Check if trimming the workload helps imbalance + if proj_slack < cur_slack: + sample_id_to_remove = sample_ids_per_gpu[max_r][-1] + for r in members: + micro_batches[r].pop() + exec_times[r] -= per_gpu_cost + sample_ids_per_gpu[r].pop() + leftovers.append((sample_id_to_remove, seq)) + else: + break + + trim_overload() + + # Track samples in this group before redistribution to empty GPUs + total_work_before = sum(len(mb) for mb in micro_batches) + + # Check for empty GPUs and redistribute work + def fill_empty_gpus( + micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size + ): + """ + Recursively check for empty GPUs and redistribute work by increasing + the number of GPUs sharing samples. This ensures all GPUs have work. + GPUs must be allocated consecutively so we may need to push existing + work to other ranks in order to expand samples. 
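+
+            Illustrative example (8 GPUs): if the smallest existing group is a
+            CP2 group on GPUs 2-3 and GPUs 6-7 are idle, that group is expanded
+            to CP4 on GPUs 2-5 and the work previously assigned to GPUs 4-5 is
+            shifted right onto GPUs 6-7.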
+ """ + # Find empty GPUs + empty_gpus = [i for i in range(total_gpus) if not micro_batches[i]] + if not empty_gpus: + return ( + micro_batches, + exec_times, + sample_ids_per_gpu, + group_members, + group_size, + ) # No empty GPUs, we're done + + # Find the smallest group size that exists + existing_group_sizes = set(group_size.values()) + assert ( + existing_group_sizes + ), "There should be at least one group existing, cannot reditribute, " + "try to increase 'max-seqlen-per-cp-rank'." + + min_group_size = min(existing_group_sizes) + # We have Hybrid DPxCP groups for every power of 2 of GPUs or the entire DPxCP group. + next_power = min(min_group_size * 2, total_gpus) + + # Find the first group of min_group_size that can be expanded + expandable_gid = None + expandable_members = None + expandable_new_gpus = None + + for gid, size in group_size.items(): + if size == min_group_size: + members = group_members[gid] + needed_count = next_power - min_group_size + group_start_gpu = members[0] + group_end_gpu = members[-1] + empty_gpu = [idx for idx, work in enumerate(micro_batches) if not work][0] + assert not all( + work for work in micro_batches[empty_gpu : empty_gpu + needed_count] + ), f"Empty GPUs were detected but not enough to expand." + work_to_push = micro_batches[ + group_end_gpu + 1 : empty_gpu + ] # This is work of all other subsequent sub-samples + exec_times_to_push = exec_times[group_end_gpu + 1 : empty_gpu] + sample_ids_to_push = sample_ids_per_gpu[group_end_gpu + 1 : empty_gpu] + + new_micro_batches = [[]] * len(micro_batches) + new_exec_times = [0.0] * len(exec_times) + new_sample_ids_per_gpu = [[]] * len(sample_ids_per_gpu) + + # No change in work until the group selected for expansion + for i in range(group_start_gpu): + new_micro_batches[i] = micro_batches[i] + new_exec_times[i] = exec_times[i] + new_sample_ids_per_gpu[i] = sample_ids_per_gpu[i] + + # The work is distributed across the expanded group + for i in range(group_start_gpu, group_end_gpu + needed_count + 1): + new_micro_batches[i] = micro_batches[group_end_gpu] + new_exec_times[i] = self.get_total_workload( + micro_batches[group_end_gpu][0], next_power + ) + new_sample_ids_per_gpu[i] = sample_ids_per_gpu[group_end_gpu] + + # Any assigned work on expanded GPUs is pushed + for i, work in enumerate(work_to_push): + new_micro_batches[group_end_gpu + needed_count + 1 + i] = work + new_exec_times[group_end_gpu + needed_count + 1 + i] = exec_times_to_push[i] + new_sample_ids_per_gpu[group_end_gpu + needed_count + 1 + i] = ( + sample_ids_to_push[i] + ) + + group_size[gid] = next_power + group_members[gid] = list(range(members[0], members[-1] + needed_count + 1)) + for pushed_gid in group_size.keys(): + if pushed_gid > gid: + group_members[pushed_gid] = [ + x + needed_count for x in group_members[pushed_gid] + ] + + return ( + new_micro_batches, + new_exec_times, + new_sample_ids_per_gpu, + group_members, + group_size, + ) + + empty_gpus = any([not micro_batches[i] for i in range(total_gpus)]) + while empty_gpus: + micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size = ( + fill_empty_gpus( + micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size + ) + ) + empty_gpus = any([not micro_batches[i] for i in range(total_gpus)]) + + # Assert that no sample has been completely removed + total_work_after = sum(len(mb) for mb in micro_batches) + assert ( + total_work_after >= total_work_before + ), f"Samples were removed: {total_work_before} -> {total_work_after}" + + return micro_batches, 
leftovers, exec_times, sample_ids_per_gpu + + def get_groups_and_subsamples(self, sample_id_seqlens, config): + """ + This function recursively forms groups of sub-samples such that all DPxCP ranks + have a roughly balanced workload in the group. + """ + groups = [] + sample_id_groups = [] + # We assign a sample_id to each sub-sample in order to track assignment to each GPU. + sample_id_seqlens = sorted(sample_id_seqlens, key=lambda x: x[1], reverse=True) + while sample_id_seqlens: + mb, sample_id_seqlens, exec_times, sample_ids = self.next_hdp_group( + sample_id_seqlens, self.get_total_workload, self.total_hdp_gpus + ) + groups.append(mb) + if len(sample_ids) < self.total_hdp_gpus: + sample_ids.extend([] * (self.total_hdp_gpus - len(sample_ids))) + sample_id_groups.append(sample_ids) + + return groups, sample_id_groups + + +def hybrid_context_parallel_forward_backward( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + output_tensor_grad, + forward_data_store, + config, + collect_non_loss_data, + first_val_step, + forward_only, + no_sync_func, + total_num_tokens, + check_first_val_step, + model_type, +): + """ + Scheduler for Hybrid Context Parallel. + + This function performs the packed sample scheduling and determines + 1. The number of microbatches to schedule for each CP rank + 2. The number of groups each CP rank should execute + 3. The number of sub-samples per group each CP rank should execute + + A group is defined by a set of samples that can run across the CP domain without any barrier. + There are many reasons why we may not be able to run endless samples within a single group. + For example, if we have 8 GPUs, + if GPU 0-5 are assigned a long sample that requires CP6, + GPU 6-7 are assigned a short sample that requires CP2, + The next sample which requires CP4 can be assigned GPU 4-7. + But GPU 6-7 will finish first and get deadlocked if GPU 4-5 are not participating in the group. + """ + from .schedules import backward_step, forward_step + + def _broadcast(item): + if item is not None: + torch.distributed.broadcast( + item, + parallel_state.get_tensor_model_parallel_src_rank(), + group=parallel_state.get_tensor_model_parallel_group(), + ) + + def _broadcast_num_samples_this_group(num_samples_this_group): + dev = torch.cuda.current_device() + torch.distributed.barrier() + + n = 0 if num_samples_this_group is None else int(num_samples_this_group.numel()) + n = torch.tensor([n], dtype=torch.int64, device=dev) + + _broadcast(n) + n = int(n.item()) + + assert n > 0, "there should be at least 1 sub samples in the group" + num_samples_this_group_broadcast = ( + torch.empty(n, dtype=torch.int32, device=dev) + if num_samples_this_group is None + else num_samples_this_group + ) + _broadcast(num_samples_this_group_broadcast) + return num_samples_this_group_broadcast + + def _get_new_data_iterator(sample_id_in_group, group_id): + if is_first_tp_rank: + sub_sample_id = sample_ids_this_group[sample_id_in_group] + sample = batch[sub_sample_id] + partner_cp_size = len( + [True for sample_ids in sample_id_groups[group_id] if sub_sample_id in sample_ids] + ) + sample["local_cp_size"] = torch.tensor(partner_cp_size, dtype=torch.int32) + new_data_iterator = RerunDataIterator(iter([sample])) + return new_data_iterator + else: + return None + + # We get data once per global batch and schedule the sub-samples. + # TODO(pmannan): Should we wrap the data_iterator here instead of the training.py file? 
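+    # The wrapped iterator yields (samples_with_global_id, sample_id_groups) once
+    # per global batch; sample_id_groups[group][hdp_rank] lists the global
+    # sub-sample IDs this rank runs in that group, and _get_new_data_iterator
+    # attaches a per-sub-sample local_cp_size by counting how many ranks in the
+    # group share that ID.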
+ hdp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) + is_first_tp_rank = parallel_state.get_tensor_model_parallel_rank() == 0 + + if is_first_tp_rank: + data = next(data_iterator) + sample_id_groups = data[1] + batch = data[0] + else: + data, sample_id_groups, batch = None, None, None + + num_samples_this_group = None + if is_first_tp_rank: + num_samples_this_group = torch.tensor( + [len(group[hdp_rank]) for group in sample_id_groups], dtype=torch.int32, device='cuda' + ) + + num_samples_this_group = _broadcast_num_samples_this_group(num_samples_this_group) + num_samples_this_group = num_samples_this_group.cpu().numpy() + num_total_groups = num_samples_this_group.shape[0] + + current_microbatch = 0 + + # Upto last group, we don't need any sync. + with no_sync_func(): + for j in range(num_total_groups - 1): + sample_ids_this_group = sample_id_groups[j][hdp_rank] if is_first_tp_rank else None + for i in range(num_samples_this_group[j]): + # Call forward step for each sub-sample + new_data_iterator = _get_new_data_iterator(i, j) + # TODO: Find the usage of current_microbatch and is_first_microbatch and + # how that may affect my usage. + output_tensor, num_tokens = forward_step( + forward_step_func, + new_data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + is_first_microbatch=check_first_val_step( + first_val_step, forward_only, current_microbatch == 0 + ), + current_microbatch=current_microbatch, + ) + current_microbatch += 1 + total_num_tokens += num_tokens.item() + if not forward_only: + backward_step( + input_tensor, output_tensor, output_tensor_grad, model_type, config + ) + + # Create a barrier at end of each group. + # This barrier ensures that all ranks are prepared to change assigned CP group sizes and + # no rank is starting a sub-sample ahead of it's partner ranks. + torch.distributed.barrier( + parallel_state.get_data_parallel_group(with_context_parallel=True) + ) + + # For the last group, we need to run the last sub-sample out of the context handler. + with no_sync_func(): + sample_ids_this_group = sample_id_groups[-1][hdp_rank] if is_first_tp_rank else None + for i in range(num_samples_this_group[-1] - 1): + new_data_iterator = _get_new_data_iterator(i, -1) + # Call forward step for each sub-sample + output_tensor, num_tokens = forward_step( + forward_step_func, + new_data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + is_first_microbatch=check_first_val_step( + first_val_step, forward_only, current_microbatch == 0 + ), + current_microbatch=current_microbatch, + ) + current_microbatch += 1 + total_num_tokens += num_tokens.item() + if not forward_only: + backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) + + # The last sub-sample of the last group of the last microbatch is + # run out of the context handler. 
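+    # Running this final backward outside no_sync_func lets the gradient
+    # synchronization hooks (e.g. the DDP all-reduce / reduce-scatter) fire, so
+    # gradients are reduced exactly once per global batch.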
+ new_data_iterator = _get_new_data_iterator(-1, -1) + # Call forward step for each sub-sample + output_tensor, num_tokens = forward_step( + forward_step_func, + new_data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + is_first_microbatch=check_first_val_step( + first_val_step, forward_only, current_microbatch == 0 + ), + current_microbatch=current_microbatch, + ) + total_num_tokens += num_tokens.item() + if not forward_only: + backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) + + return forward_data_store, total_num_tokens diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index d0b912349b..2b18d977ca 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -35,6 +35,7 @@ combined_1f1b_schedule_for_interleaved_pipelining, combined_1f1b_schedule_for_no_pipelining, ) +from .hybrid_cp_schedule import hybrid_context_parallel_forward_backward # Types Shape = Union[List[int], torch.Size] @@ -597,6 +598,24 @@ def forward_backward_no_pipelining( total_num_tokens, partial(check_first_val_step, first_val_step, forward_only), ) + elif config.hybrid_context_parallel: + forward_data_store, total_num_tokens = hybrid_context_parallel_forward_backward( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + output_tensor_grad, + forward_data_store, + config, + collect_non_loss_data, + first_val_step, + forward_only, + no_sync_func, + total_num_tokens, + check_first_val_step, + model_type, + ) else: with no_sync_func(): for i in range(num_microbatches - 1): diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 74031f3821..2ecd17db25 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -848,7 +848,7 @@ def forward( ) ) - if packed_seq_params is not None: + if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': query = query.squeeze(1) key = key.squeeze(1) value = value.squeeze(1) @@ -863,7 +863,7 @@ def forward( ): q_pos_emb, k_pos_emb = rotary_pos_emb - if packed_seq_params is not None: + if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': if packed_seq_params.cu_seqlens_q_padded is not None: cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded else: diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index 074523afd7..097cb8e57e 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -503,6 +503,11 @@ def get_query_key_value_tensors( assert ( hidden_states.ndim == 3 ), f"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D" + if packed_seq_params is not None: + assert ( + packed_seq_params.local_cp_size is None + ), "hybrid_context_parallel is not supported with MLA yet and is planned for future. \ + Please disable hybrid_context_parallel." 
inference_context = deprecate_inference_params(inference_context, inference_params) @@ -519,11 +524,13 @@ def get_query_key_value_tensors( rotary_pos_sin = None packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' if self.config.rope_type == "rope": - rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) + rotary_pos_emb = self.rotary_pos_emb( + rotary_seq_len, packed_seq_params=packed_seq_params + ) else: if self.config.apply_rope_fusion: rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb.get_cached_cos_sin( - rotary_seq_len, dtype=hidden_states.dtype, packed_seq=packed_seq + rotary_seq_len, dtype=hidden_states.dtype, packed_seq_params=packed_seq_params ) rotary_pos_emb = None assert inference_context is None, "Inference with MLA RoPE fusion is not supported" @@ -532,9 +539,11 @@ def get_query_key_value_tensors( and fused_apply_mla_rope_for_kv is not None ), "Fused MLA RoPE apply is not imported successfully" else: - rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) + rotary_pos_emb, mscale = self.rotary_pos_emb( + rotary_seq_len, packed_seq_params=packed_seq_params + ) - if packed_seq_params is not None: + if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': if packed_seq_params.cu_seqlens_q_padded is not None: cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded else: diff --git a/megatron/core/utils.py b/megatron/core/utils.py index 9b62b18d40..86712d72f9 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -59,6 +59,15 @@ logger = logging.getLogger(__name__) +try: + # Register the TE CUDA kernels + import transformer_engine # pylint: disable=unused-import + + # Alias the PyTorch wrapper so we can call tex.* APIs + import transformer_engine_torch as tex +except ImportError: + # TE isn’t installed or the torch wrapper is missing + tex = None try: _torch_version = PkgVersion(torch.__version__) @@ -1942,7 +1951,9 @@ def is_submodule(module, parent_module, strict=True): ######################## -def get_batch_on_this_cp_rank(batch: Dict[str, Any]): +def get_batch_on_this_cp_rank( + batch: Dict[str, Any], cp_size: Optional[int] = None, cp_rank: Optional[int] = None +): """Slice batch input along sequence dimension into multiple chunks, which are parallelized across GPUs in a context parallel group. """ @@ -1953,12 +1964,19 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): # we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0 # and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so # that we can get balanced workload among GPUs in a context parallel group. 
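+    # cp_size and cp_rank may now be passed explicitly (the hybrid context
+    # parallel path does this to slice with its dynamic group); when omitted,
+    # they fall back to the static context parallel world size and rank.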
- cp_size = parallel_state.get_context_parallel_world_size() - if cp_size > 1: + if cp_size is not None or cp_rank is not None: + assert ( + cp_size is not None and cp_rank is not None + ), "Both cp_size and cp_rank must be provided for batch slicing" + + if cp_size is None: + cp_size = parallel_state.get_context_parallel_world_size() + if cp_rank is None: cp_rank = parallel_state.get_context_parallel_rank() + if cp_size > 1: for key, val in batch.items(): if val is not None: - seq_dim = 1 if key != "attention_mask" else 2 + seq_dim = 1 if key != 'attention_mask' else 2 val = val.view( *val.shape[0:seq_dim], 2 * cp_size, @@ -1975,6 +1993,102 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): return batch +def get_thd_batch_on_this_cp_rank( + batch: Dict[str, Any], + cu_seqlens: torch.Tensor, + cu_seqlens_padded: torch.Tensor, + max_seqlen: torch.Tensor, + cp_size: Optional[int] = None, + cp_rank: Optional[int] = None, +): + """Slice each sub-sample in a packed sample batch input along + sequence dimension into multiple chunks, which are parallelized + across GPUs in a context parallel group. + """ + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + max_seqlen_q=int(max_seqlen[0].item()), + max_seqlen_kv=int(max_seqlen[0].item()), + ) + + cp_size = get_context_parallel_world_size() if cp_size is None else cp_size + cp_rank = get_context_parallel_rank() if cp_rank is None else cp_rank + if cp_size > 1: # slice batch along sequence dimension for context parallelism + assert tex is not None and is_te_min_version("1.10.0"), ( + "Please update Transformer Engine to >= 1.10 to use " + "Context Parallel with THD format data" + ) + index = tex.thd_get_partitioned_indices( + cu_seqlens_padded, batch['tokens'].size(1), cp_size, cp_rank + ) + for key, data in batch.items(): + if key in {'attention_mask', 'cu_seqlens', 'cu_seqlens_padded', 'max_seqlen'}: + continue + batch[key] = data.index_select(1, index) + + return batch, packed_seq_params + + +################################ +### hybrid context parallel ### +################################ + + +def get_batch_on_this_hybrid_cp_rank( + batch: Dict[str, Any], + local_cp_size: int, + cp_group: Optional[torch.distributed.ProcessGroup] = None, +): + """Slice batch input along sequence dimension into multiple chunks, + which are parallelized across GPUs in a context parallel group. + """ + assert local_cp_size is not None + if cp_group is None: + # Get the local cp group required for as defined by the HybridCPDataLoaderWrapper + if local_cp_size > 1: + cp_group = parallel_state.get_hybrid_data_context_parallel_groups( + group_size=local_cp_size + ) + else: + # If cp group is provided, it must match the local cp size + # as defined by the HybridCPDataLoaderWrapper + assert cp_group.size() == local_cp_size + + # Convert [seqlen] to [1, seqlen] similar to default collate_fn + # as hybrid_context_parallel dataloader wrapper does not go through default collate_fn + for key, data in batch.items(): + if key in ['attention_mask']: + continue + batch[key] = torch.stack([data], 0) + sample_length = batch['tokens'].shape[1] + # TODO(pmannan): Take care of padding tokens here if not divisible by cp_size*2 + # Create packed_seq_params for SBHD format with cp group information. 
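+    # cu_seqlens below describes one unpacked sample of sample_length tokens;
+    # local_cp_size and cp_group let downstream consumers (e.g.
+    # TEDotProductAttention and the RoPE modules) pick the dynamic CP group the
+    # scheduler chose for this sub-sample.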
+ packed_seq_params = PackedSeqParams( + qkv_format="sbhd", + cu_seqlens_q=torch.tensor([0, sample_length], device="cuda", pin_memory=True), + cu_seqlens_kv=torch.tensor([0, sample_length], device="cuda", pin_memory=True), + cu_seqlens_q_padded=torch.tensor([0, sample_length], device="cuda", pin_memory=True), + cu_seqlens_kv_padded=torch.tensor([0, sample_length], device="cuda", pin_memory=True), + max_seqlen_q=sample_length, + max_seqlen_kv=sample_length, + local_cp_size=local_cp_size, + cp_group=cp_group, + ) + + if cp_group is not None and cp_group.size() > 1: + # When using hybrid_context_parallel, each sub-sample of a packed sample is + # required to be divisible by CP*DP*2 or CP*DP*TP*2 (if using sequence parallel) + batch = get_batch_on_this_cp_rank( + batch, cp_group.size(), torch.distributed.get_rank(group=cp_group) + ) + + return batch, packed_seq_params + + ###################### ### NVTX profiling ### ###################### diff --git a/megatron/legacy/data/data_samplers.py b/megatron/legacy/data/data_samplers.py index 1bf1bf5ee9..79bdc7b193 100644 --- a/megatron/legacy/data/data_samplers.py +++ b/megatron/legacy/data/data_samplers.py @@ -34,13 +34,22 @@ def build_pretraining_data_loader(dataset, consumed_samples): data_parallel_rank=mpu.get_data_parallel_rank(), data_parallel_size=mpu.get_data_parallel_world_size()) elif args.dataloader_type == 'single': - # Megatron sampler - batch_sampler = MegatronPretrainingSampler( - total_samples=len(dataset), - consumed_samples=consumed_samples, - micro_batch_size=args.micro_batch_size, - data_parallel_rank=mpu.get_data_parallel_rank(), - data_parallel_size=mpu.get_data_parallel_world_size()) + if args.hybrid_context_parallel: + batch_sampler = HybridCPMegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=args.micro_batch_size, + global_batch_size=args.global_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size()) + else: + # Megatron sampler + batch_sampler = MegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=args.micro_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size()) elif args.dataloader_type == 'cyclic': batch_sampler = MegatronPretrainingRandomSampler( dataset, @@ -59,11 +68,16 @@ def build_pretraining_data_loader(dataset, consumed_samples): args.dataloader_type)) # Torch dataloader. + if args.hybrid_context_parallel: + extra_kwargs = {"collate_fn": lambda x: x,} + else: + extra_kwargs = {} return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True, persistent_workers=True if args.num_workers > 0 else False, + **extra_kwargs, ) class MegatronPretrainingSampler: @@ -114,6 +128,49 @@ def __iter__(self): start_idx, end_idx = self.get_start_end_idx() yield batch[start_idx:end_idx] +class HybridCPMegatronPretrainingSampler(MegatronPretrainingSampler): + """ + Data sampler for hybrid context parallel (Hybrid CP) format. + This data sampler pulls in the entire global batch at once across all data parallel ranks. + This helps provide the Hybrid CP Dataloader Wrapper to schedule and load balance sub-samples + of the entire global batch. 
+ """ + + def __init__(self, total_samples, consumed_samples, micro_batch_size, global_batch_size, + data_parallel_rank, data_parallel_size, drop_last=True): + super().__init__(total_samples, consumed_samples, micro_batch_size, data_parallel_rank, data_parallel_size, drop_last) + self.global_batch_size = global_batch_size + self.data_parallel_size = data_parallel_size + self.num_micro_batches = self.global_batch_size // self.micro_batch_times_data_parallel_size + + def __len__(self): + return self.total_samples + + def get_start_end_idx_global_batch(self): + start_idx = [self.data_parallel_rank * self.micro_batch_size + i * self.micro_batch_size * self.data_parallel_size for i in range(self.num_micro_batches)] + end_idx = [start_idx[i] + self.micro_batch_size for i in range(self.num_micro_batches)] + return start_idx, end_idx + + def __iter__(self): + batch = [] + # Last batch will be dropped if drop_last is not set False + for idx in range(self.consumed_samples, self.total_samples): + batch.append(idx) + if len(batch) == self.micro_batch_times_data_parallel_size * self.num_micro_batches: + start_idx, end_idx = self.get_start_end_idx_global_batch() + global_batch_idx = [] + for i in range(self.num_micro_batches): + global_batch_idx.extend(batch[start_idx[i]:end_idx[i]]) + yield global_batch_idx + batch = [] + + # Check the last partial batch and see drop_last is set + if len(batch) > 0 and not self.drop_last: + start_idx, end_idx = self.get_start_end_idx_global_batch() + global_batch_idx = [] + for i in range(self.num_micro_batches): + global_batch_idx.extend(batch[start_idx[i]:end_idx[i]]) + yield global_batch_idx class RandomSeedDataset(Dataset): diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index bb1b17e9ba..57ac82906e 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -956,6 +956,13 @@ def validate_args(args, defaults={}): if args.tp_comm_overlap: assert args.sequence_parallel == True, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled' + if args.hybrid_context_parallel: + assert not args.pipeline_model_parallel_size > 1, 'Hybrid context parallelism not supported with pipeline parallelism' + assert not args.enable_cuda_graph, 'Hybrid context parallelism not supported with CUDA Graph' + assert not args.use_megatron_fsdp, 'Hybrid context parallelism not supported with Megatron FSDP' + assert args.dataloader_type == 'single', 'Hybrid context parallelism only supported with single dataloader type' + assert args.calculate_per_token_loss, 'Hybrid context parallelism must be used with --calculate-per-token-loss' + # disable async_tensor_model_parallel_allreduce when # model parallel memory optimization is enabled if (args.tensor_model_parallel_size > 1 or args.context_parallel_size > 1) \ @@ -2858,6 +2865,13 @@ def _add_distributed_args(parser): '--hierarchical-context-parallel-sizes 2 4 indicates every two adjacent gpus ' 'forms the first level of cp groups and the cp ranks with the same odevity ' 'forms the second level of cp groups.') + group.add_argument('--max-seqlen-per-cp-rank', type=int, default=None, + help='Maximum sequence length per CP rank. This is used to calculate the ' + 'number of sub-samples assigned to each CP rank when using heterogeneous context parallel.') + group.add_argument('--hybrid-context-parallel', action='store_true', default=False, + help='Enables hybrid context parallel. 
This is used to balance the workload ' + 'of each CP rank when we use packed samples with variable sequence lengths. ' + 'Requires --max-seqlen-per-cp-rank to be set.') group.add_argument('--nccl-communicator-config-path', type=str, default=None, help='Path to the yaml file with NCCL communicator ' 'configurations. The number of min/max thread groups and thread ' diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py index 8b585fdd87..fb9a3aa273 100644 --- a/megatron/training/initialize.py +++ b/megatron/training/initialize.py @@ -369,6 +369,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s use_sharp=args.use_sharp, context_parallel_size=args.context_parallel_size, hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes, + hybrid_context_parallel=args.hybrid_context_parallel, expert_model_parallel_size=args.expert_model_parallel_size, num_distributed_optimizer_instances=args.num_distributed_optimizer_instances, expert_tensor_parallel_size=args.expert_tensor_parallel_size, diff --git a/megatron/training/training.py b/megatron/training/training.py index 9986f93164..84cf83a1df 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -88,6 +88,7 @@ from megatron.training.initialize import set_jit_fusion_options from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank from megatron.legacy.data.data_samplers import build_pretraining_data_loader +from megatron.core.datasets.data_schedule import HybridCPDataLoaderWrapper from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler from megatron.core.transformer.moe import upcycling_utils from megatron.core.transformer.moe.moe_utils import track_moe_metrics @@ -1449,28 +1450,14 @@ def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_sch for key in losses_reduced[0].keys(): val = [x[key].view(-1) for x in losses_reduced] if val[0].numel() == 2: - if args.sft: - # in mcore the normalization happens on micro batch instead of global - val = torch.vstack(val) - val = val[:, 0] / val[:, 1] - val = val.mean() - torch.distributed.all_reduce( - val, - group=mpu.get_data_parallel_group(with_context_parallel=True) - ) - val /= torch.distributed.get_world_size( - group=mpu.get_data_parallel_group(with_context_parallel=True) - ) - loss_reduced[key] = val - else: - # there is one dict per microbatch. in new reporting, we average - # over the total number of tokens across the global batch. - val = torch.vstack(val).sum(dim=0) - torch.distributed.all_reduce( - val, - group=mpu.get_data_parallel_group(with_context_parallel=True) - ) - loss_reduced[key] = val[0] / val[1] + # there is one dict per microbatch. in new reporting, we average + # over the total number of tokens across the global batch. 
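The per-token averaging described in this comment (and implemented in the lines that follow) can be checked with a small local sketch; the all-reduce over the DP x CP group is omitted here and the numbers are made up:

    # Illustrative only: per-token loss averaging with [loss_sum, num_tokens] pairs.
    # Each microbatch reports both values; the global average is the ratio of the totals
    # (in training the totals are additionally all-reduced over the DP x CP group).
    import torch

    per_microbatch = [torch.tensor([12.0, 4.0]), torch.tensor([3.0, 2.0])]
    totals = torch.vstack([v.view(-1) for v in per_microbatch]).sum(dim=0)
    average_loss = totals[0] / totals[1]
    print(average_loss)  # 15 / 6 = 2.5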
+ val = torch.vstack(val).sum(dim=0) + torch.distributed.all_reduce( + val, + group=mpu.get_data_parallel_group(with_context_parallel=True) + ) + loss_reduced[key] = val[0] / val[1] elif val[0].numel() == 1: # legacy behavior, we average over the number of microbatches val = torch.cat(val).mean() @@ -2155,6 +2142,9 @@ def train( energy_monitor = get_energy_monitor() one_logger = get_one_logger() + if args.hybrid_context_parallel: + train_data_iterator = iter(HybridCPDataLoaderWrapper(train_data_iterator, config)) + if args.run_workload_inspector_server: try: from workload_inspector.utils.webserver import run_server diff --git a/megatron/training/utils.py b/megatron/training/utils.py index 52a3bf36d8..341a2d41a2 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -541,19 +541,58 @@ def _broadcast(item): else data["attention_mask"].cuda(non_blocking=True) ), 'position_ids': data["position_ids"].cuda(non_blocking=True), + 'cu_seqlens': ( + None + if "cu_seqlens" not in data + else data["cu_seqlens"].cuda(non_blocking=True) + ), + 'max_seqlen': ( + None + if "max_seqlen" not in data + else data["max_seqlen"].cuda(non_blocking=True) + ), + 'local_cp_size': ( + None + if "local_cp_size" not in data + else data["local_cp_size"].cuda(non_blocking=True) + ), } + def _broadcast_cu_seqlens(cu_seqlens): + dev = torch.cuda.current_device() + n = 0 if cu_seqlens is None else int(cu_seqlens.numel()) + n_tensor = torch.tensor(n, dtype=torch.int64, device=dev) + _broadcast(n_tensor) + + if n == 0: + buf = torch.empty(0, dtype=torch.int32, device=dev) + else: + assert isinstance(cu_seqlens, torch.Tensor) + assert cu_seqlens.dtype == torch.int32 + assert cu_seqlens.shape[0] == 1, "micro-batch-size must be 1 for packing" + buf = cu_seqlens.to(device=dev, non_blocking=True).contiguous() + _broadcast(buf) + + if args.hybrid_context_parallel: + seq_len = torch.tensor(batch['tokens'].shape[0], dtype=torch.int32, device=torch.cuda.current_device()) + _broadcast(seq_len) + if args.pipeline_model_parallel_size == 1 or mtp_on_this_rank: _broadcast(batch['tokens']) _broadcast(batch['labels']) _broadcast(batch['loss_mask']) _broadcast(batch['attention_mask']) _broadcast(batch['position_ids']) + _broadcast_cu_seqlens(batch['cu_seqlens']) + _broadcast(batch['max_seqlen']) + _broadcast(batch['local_cp_size']) elif mpu.is_pipeline_first_stage(): _broadcast(batch['tokens']) _broadcast(batch['attention_mask']) _broadcast(batch['position_ids']) + _broadcast_cu_seqlens(batch['cu_seqlens']) + _broadcast(batch['max_seqlen']) elif mpu.is_pipeline_last_stage(): # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding. 
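The sender-side _broadcast_cu_seqlens helper above transmits the element count before the payload so that receivers can size their buffers. A hypothetical standalone version of that "size first, then payload" pattern is sketched below; it assumes torch.distributed is already initialized and that every rank in the group calls it with the same src (this is a sketch, not the PR's helper):

    # Illustrative only: broadcast a variable-length int32 tensor in two steps.
    import torch
    import torch.distributed as dist

    def broadcast_varlen_int32(tensor, src, group):
        dev = torch.cuda.current_device()
        n = torch.tensor(0 if tensor is None else tensor.numel(), dtype=torch.int64, device=dev)
        dist.broadcast(n, src, group=group)      # 1) agree on the element count
        if int(n.item()) == 0:
            return None
        if dist.get_rank() == src:
            buf = tensor.to(device=dev).contiguous()
        else:
            buf = torch.empty(int(n.item()), dtype=torch.int32, device=dev)
        dist.broadcast(buf, src, group=group)    # 2) ship the payload
        return buf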
@@ -564,42 +603,79 @@ def _broadcast(item):
             _broadcast(batch['attention_mask'])
     else:
-
+        if args.hybrid_context_parallel:
+            seq_len = torch.tensor(0, dtype=torch.int32, device=torch.cuda.current_device())
+            _broadcast(seq_len)
+            shape = (seq_len.item(),)
+        else:
+            shape = (args.micro_batch_size, args.seq_length)
+
         tokens = torch.empty(
-            (args.micro_batch_size, args.seq_length),
+            shape,
             dtype=torch.int64,
             device=torch.cuda.current_device(),
         )
         labels = torch.empty(
-            (args.micro_batch_size, args.seq_length),
+            shape,
             dtype=torch.int64,
             device=torch.cuda.current_device(),
         )
         loss_mask = torch.empty(
-            (args.micro_batch_size, args.seq_length),
+            shape,
             dtype=torch.float32,
             device=torch.cuda.current_device(),
         )
         if args.create_attention_mask_in_dataloader:
+            shape_attention_mask = (args.micro_batch_size, 1, args.seq_length, args.seq_length) if not args.hybrid_context_parallel else (1, 1, shape[0], shape[0])
             attention_mask = torch.empty(
-                (args.micro_batch_size, 1, args.seq_length, args.seq_length),
+                shape_attention_mask,
                 dtype=torch.bool,
                 device=torch.cuda.current_device(),
             )
         else:
             attention_mask = None
         position_ids = torch.empty(
-            (args.micro_batch_size, args.seq_length),
+            shape,
             dtype=torch.int64,
             device=torch.cuda.current_device(),
         )
+        cu_seqlens = None
+        max_seqlen = torch.empty(
+            1,
+            dtype=torch.int32,
+            device=torch.cuda.current_device(),
+        ) if not args.hybrid_context_parallel else None
+        local_cp_size = torch.empty(
+            1,
+            dtype=torch.int32,
+            device=torch.cuda.current_device(),
+        ) if args.hybrid_context_parallel else None
+
+        def _broadcast_cu_seqlens():
+            dev = torch.cuda.current_device()
+
+            n = torch.empty((), dtype=torch.int64, device=dev)
+            _broadcast(n)
+            n = int(n.item())
+
+            if n == 0:
+                cu_seqlens = torch.empty(0, dtype=torch.int32, device=dev)
+            else:
+                cu_seqlens = torch.empty((args.micro_batch_size, n), dtype=torch.int32, device=dev)
+            _broadcast(cu_seqlens)
+
+            return cu_seqlens if n > 0 else None
+
         if args.pipeline_model_parallel_size == 1 or mtp_on_this_rank:
             _broadcast(tokens)
             _broadcast(labels)
             _broadcast(loss_mask)
             _broadcast(attention_mask)
             _broadcast(position_ids)
+            cu_seqlens = _broadcast_cu_seqlens()
+            _broadcast(max_seqlen)
+            _broadcast(local_cp_size)
         elif mpu.is_pipeline_first_stage():
             labels = None
@@ -608,6 +684,8 @@ def _broadcast(item):
             _broadcast(tokens)
             _broadcast(attention_mask)
             _broadcast(position_ids)
+            cu_seqlens = _broadcast_cu_seqlens()
+            _broadcast(max_seqlen)
         elif mpu.is_pipeline_last_stage():
             # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding.
@@ -615,7 +693,8 @@ def _broadcast(item):
             # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage.
tokens = None position_ids = None - + cu_seqlens = None + max_seqlen = None _broadcast(labels) _broadcast(loss_mask) _broadcast(attention_mask) @@ -626,6 +705,9 @@ def _broadcast(item): 'loss_mask': loss_mask, 'attention_mask': attention_mask, 'position_ids': position_ids, + 'cu_seqlens': cu_seqlens, + 'max_seqlen': max_seqlen, + 'local_cp_size': local_cp_size, } return batch diff --git a/pretrain_gpt.py b/pretrain_gpt.py index ecb7163ff7..e976f5aff7 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -14,9 +14,9 @@ from megatron.core.enums import ModelType from megatron.core.models.gpt import GPTModel from megatron.core.rerun_state_machine import get_rerun_state_machine +from megatron.core.utils import get_attr_wrapped_model, get_thd_batch_on_this_cp_rank, get_batch_on_this_hybrid_cp_rank, StragglerDetector from megatron.core.tokenizers.text.utils.build_tokenizer import build_tokenizer from megatron.core.transformer.multi_token_prediction import mtp_on_this_rank, get_mtp_ranks -from megatron.core.utils import StragglerDetector, get_attr_wrapped_model from megatron.training.arguments import core_transformer_config_from_args from megatron.training import get_args, get_timers, get_tokenizer, inprocess_restart, pretrain, print_rank_0 from megatron.training.datasets.sft_dataset import SFTDataset @@ -46,7 +46,7 @@ def get_batch(data_iterator, vp_stage: Optional[int] = None): # TODO: this is pretty hacky, find a better way if not is_first_or_last_pipeline_stage(vp_stage) and ( (not mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage))): - return None, None, None, None, None + return None, None, None, None, None, None # get batches based on the TP rank you are on batch = get_batch_on_this_tp_rank( @@ -54,10 +54,24 @@ def get_batch(data_iterator, vp_stage: Optional[int] = None): mtp_on_this_rank=mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage) ) - # slice batch along sequence dimension for context parallelism - batch = get_batch_on_this_cp_rank(batch) - - return batch.values() + cu_seqlens = batch.pop('cu_seqlens', None) + cu_seqlens_padded = batch.pop('cu_seqlens_padded', None) + max_seqlen = batch.pop('max_seqlen', None) + local_cp_size = batch.pop('local_cp_size', None) + if local_cp_size is not None: + local_cp_size = int(local_cp_size.item()) + + if cu_seqlens is None and local_cp_size is None: + # slice batch along sequence dimension for context parallelism + batch = get_batch_on_this_cp_rank(batch) # The implementation of this function is in MCore + packed_seq_params = None + elif local_cp_size is None: # Packed THD format + assert max_seqlen.dim() == 1 + batch, packed_seq_params = get_thd_batch_on_this_cp_rank(batch, cu_seqlens, cu_seqlens_padded, max_seqlen) + else: # Hybrid CP format + batch, packed_seq_params = get_batch_on_this_hybrid_cp_rank(batch, local_cp_size) + + return (*batch.values(), packed_seq_params) # define spiky loss as a loss that's 10x the max loss observed @@ -142,7 +156,7 @@ def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = Fa global stimer with stimer(bdata=True): vp_stage = get_attr_wrapped_model(model, "vp_stage") - tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator, vp_stage) + tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = get_batch(data_iterator, vp_stage) timers('batch-generator').stop() with stimer: @@ -158,7 +172,7 @@ def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = Fa return schedule_plan, partial(loss_func, 
loss_mask, model=model) else: output_tensor = model( - tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask + tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask, packed_seq_params=packed_seq_params ) # [ModelOpt]: model is needed to access ModelOpt distillation losses @@ -204,6 +218,10 @@ def core_gpt_dataset_config_from_args(args): object_storage_cache_path=args.object_storage_cache_path, mid_level_dataset_surplus=args.mid_level_dataset_surplus, allow_ambiguous_pad_tokens=args.allow_ambiguous_pad_tokens, + context_parallel_size=args.context_parallel_size, + data_parallel_size=args.data_parallel_size, + sequence_parallel_size=args.tensor_model_parallel_size*args.sequence_parallel, + hybrid_context_parallel=args.hybrid_context_parallel, ) diff --git a/tests/unit_tests/test_parallel_state.py b/tests/unit_tests/test_parallel_state.py index 7218ed5b6e..0c722ee025 100644 --- a/tests/unit_tests/test_parallel_state.py +++ b/tests/unit_tests/test_parallel_state.py @@ -1,5 +1,7 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from math import log2 + import pytest import torch @@ -499,3 +501,32 @@ def golden_rank_result_from_past_code( assert expert_dp_group == expert_rank_generator.get_ranks( "dp" ), f"{expert_dp_group} != {expert_rank_generator.get_ranks('dp')}." + + +@pytest.mark.parametrize( + "world_size, tp_size, cp_size, dp_size", + [(8, 1, 2, 4), (8, 1, 1, 8)], # 8 GPUs, 1 TP, 2 CP, 4 DP # 8 GPUs, 1 TP, 1 CP, 8 DP +) +def test_hybrid_dp_cp_groups(world_size, tp_size, cp_size, dp_size): + """ + Test that hybrid DPxCP groups are created correctly. + """ + Utils.destroy_model_parallel() + + # Skip if world size doesn't match + actual_world_size = torch.cuda.device_count() + if actual_world_size != world_size: + pytest.skip(f"Test requires world_size={world_size}, but got {actual_world_size}") + Utils.initialize_model_parallel( + tensor_model_parallel_size=tp_size, + context_parallel_size=cp_size, + hybrid_context_parallel=True, + ) + + dp_cp_size = ps.get_data_parallel_world_size(with_context_parallel=True) + group_sizes = [2**i for i in range(int(log2(dp_cp_size)))][1:] + for group_size in group_sizes: + group = ps.get_hybrid_data_context_parallel_groups(group_size=group_size) + assert group.size() == group_size + + Utils.destroy_model_parallel()
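For reference, the group sizes this test iterates over are the powers of two strictly between 1 and the DP x CP world size; the minimal sketch below enumerates them (the contiguous rank layout shown is an assumption for illustration, not taken from parallel_state):

    # Illustrative only: candidate hybrid CP group sizes and a hypothetical contiguous layout.
    from math import log2

    def hybrid_group_size_candidates(dp_cp_size: int):
        return [2 ** i for i in range(1, int(log2(dp_cp_size)))]

    def contiguous_blocks(dp_cp_size: int, group_size: int):
        return [list(range(start, start + group_size)) for start in range(0, dp_cp_size, group_size)]

    for size in hybrid_group_size_candidates(8):  # [2, 4]
        print(size, contiguous_blocks(8, size))
    # 2 [[0, 1], [2, 3], [4, 5], [6, 7]]
    # 4 [[0, 1, 2, 3], [4, 5, 6, 7]]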