diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index e95409e08e..bec3f499b3 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -42,6 +42,7 @@ from megatron.core.transformer.mlp import MLP from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.utils import ( + ensure_metadata_has_dp_cp_group, is_layer_window_attention, make_sharded_tensors_for_checkpoint, ) @@ -419,6 +420,9 @@ def __init__( # duplicated across TP ranks setattr(param, "sequence_parallel", self.config.sequence_parallel) + tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert) + self._tp_group = tp_group + def forward(self, x): """Forward.""" _is_first_microbatch = ( @@ -443,7 +447,14 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): self.parallel_mode is None ), "TELinear sharded_state_dict can only be used with duplicated parallel mode" state_dict = self.state_dict(prefix="", keep_vars=True) - return make_sharded_tensors_for_checkpoint(state_dict, prefix, None, sharded_offsets) + return make_sharded_tensors_for_checkpoint( + state_dict, + prefix, + None, + sharded_offsets, + tp_group=self._tp_group, + dp_cp_group=metadata["dp_cp_group"], + ) def backward_dw(self): """Compute weight gradients during the backward pass if delay_wgrad_compute is enabled.""" @@ -491,6 +502,7 @@ def __init__( # TODO: For backward compatibility, remove in v0.15. tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert) + self._tp_group = tp_group # TE returns a zero length Tensor when bias=False and # return_bias=True, but we prefer None. So in that case we @@ -624,9 +636,15 @@ def forward(self, x): def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): """Sharding along axis 0, bias sharded""" + metadata = ensure_metadata_has_dp_cp_group(metadata) state_dict = self.state_dict(prefix="", keep_vars=True) return make_sharded_tensors_for_checkpoint( - state_dict, prefix, {"weight": 0, "bias": 0}, sharded_offsets + state_dict, + prefix, + {"weight": 0, "bias": 0}, + sharded_offsets, + tp_group=self._tp_group, + dp_cp_group=metadata["dp_cp_group"], ) def __repr__(self): @@ -669,6 +687,7 @@ def __init__( if gather_output: raise ValueError("Transformer Engine linear layers do not support gather_output = True") tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert) + self._tp_group = tp_group world_size = get_pg_size(tp_group) rank = get_pg_rank(tp_group) @@ -719,7 +738,12 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): """Sharding along axis 0, bias sharded""" state_dict = self.state_dict(prefix="", keep_vars=True) return make_sharded_tensors_for_checkpoint( - state_dict, prefix, {"weight": 0, "bias": 0}, sharded_offsets + state_dict, + prefix, + {"weight": 0, "bias": 0}, + sharded_offsets, + tp_group=self._tp_group, + dp_cp_group=metadata["dp_cp_group"], ) def __repr__(self): @@ -763,6 +787,7 @@ def __init__( "Transformer Engine linear layers do not support input_is_parallel = False" ) tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert) + self._tp_group = tp_group super().__init__( input_size=input_size, @@ -813,7 +838,12 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): """Sharding along axis 1, bias not sharded""" state_dict = self.state_dict(prefix="", 
keep_vars=True) return make_sharded_tensors_for_checkpoint( - state_dict, prefix, {"weight": 1}, sharded_offsets + state_dict, + prefix, + {"weight": 1}, + sharded_offsets, + tp_group=self._tp_group, + dp_cp_group=metadata["dp_cp_group"], ) def __repr__(self): @@ -900,6 +930,7 @@ def __init__( assert hasattr( pg_collection, "hcp" ), "TEDotProductAttention pg_collection must have hierarchical cp pg" + self._tp_group = pg_collection.tp if is_te_min_version("0.10.0"): extra_kwargs["attention_type"] = attention_type @@ -1077,7 +1108,12 @@ def sharded_state_dict( else: state_dict = {} return make_sharded_tensors_for_checkpoint( - state_dict, prefix, {'softmax_offset': 0}, sharded_offsets + state_dict, + prefix, + {'softmax_offset': 0}, + sharded_offsets, + tp_group=self._tp_group, + dp_cp_group=metadata["dp_cp_group"], ) @@ -1137,6 +1173,7 @@ def __init__( # The comms between TP and EP group is explicitly handled by MoE token dispatcher. # So we disable comms by making TE agnostic of model parallel. tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert) + self._tp_group = tp_group tp_size = get_pg_size(tp_group) self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel) @@ -1371,7 +1408,12 @@ def _sharded_state_dict_grouped( (ep_axis, global_expert_idx, num_global_experts), ) sub_sd = make_sharded_tensors_for_checkpoint( - state_dict, '', tp_axis_map, new_sharded_offsets + state_dict, + '', + tp_axis_map, + new_sharded_offsets, + tp_group=self._tp_group, + dp_cp_group=metadata["dp_cp_group"], ) # Remove expert layers indexing from sharded keys replace_prefix_for_sharding(sub_sd, f"{gemm_idx}.", expert_prefix) diff --git a/megatron/core/models/bert/bert_model.py b/megatron/core/models/bert/bert_model.py index b7b9bfc73f..0655a1e616 100644 --- a/megatron/core/models/bert/bert_model.py +++ b/megatron/core/models/bert/bert_model.py @@ -14,6 +14,7 @@ from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.common.language_module.language_module import LanguageModule +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.dot_product_attention import ( DotProductAttention as MCoreDotProductAttention, ) @@ -73,9 +74,10 @@ def __init__( seq_len_interpolation_factor: Optional[float] = None, add_binary_head=True, return_embeddings=False, + pg_collection: Optional[ProcessGroupCollection] = None, vp_stage: Optional[int] = None, ): - super(BertModel, self).__init__(config=config) + super(BertModel, self).__init__(config=config, pg_collection=pg_collection) if has_config_logger_enabled(config): log_config_to_disk(config, locals(), prefix=type(self).__name__) diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py index d855322c2d..32867d6a9a 100644 --- a/megatron/core/models/common/language_module/language_module.py +++ b/megatron/core/models/common/language_module/language_module.py @@ -24,7 +24,12 @@ from megatron.core.transformer.enums import AttnBackend from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_te_min_version, make_tp_sharded_tensor_for_checkpoint +from megatron.core.transformer.utils import ensure_metadata_has_dp_cp_group +from megatron.core.utils 
import ( + get_tensor_model_parallel_group_if_none, + is_te_min_version, + make_tp_sharded_tensor_for_checkpoint, +) class LanguageModule(MegatronModule): @@ -44,6 +49,7 @@ def __init__( pg_collection = ProcessGroupCollection.use_mpu_process_groups() self.pg_collection = pg_collection self.cp_group = pg_collection.cp + self.tp_group = get_tensor_model_parallel_group_if_none(pg_collection.tp) self.pp_group = pg_collection.pp assert hasattr(self.pg_collection, 'embd'), ( "pg_collection must have a embd. In previous version, it used default " @@ -272,6 +278,10 @@ def sharded_state_dict( ShardedStateDict: sharded state dict for the LanguageModel """ assert not sharded_offsets, "Unexpected sharded offsets" + + # Guard for cases metadata is not provided + metadata = ensure_metadata_has_dp_cp_group(metadata) + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) first_stage_word_emb_key = f'{prefix}embedding.word_embeddings.weight' @@ -280,7 +290,7 @@ def sharded_state_dict( if self.share_embeddings_and_output_weights: self.tie_embeddings_and_output_weights_state_dict( - sharded_state_dict, output_layer_weight_key, first_stage_word_emb_key + sharded_state_dict, output_layer_weight_key, first_stage_word_emb_key, metadata ) elif self.post_process: # Make sure the output layer follows the embeddings padding logic @@ -297,6 +307,7 @@ def tie_embeddings_and_output_weights_state_dict( sharded_state_dict: ShardedStateDict, output_layer_weight_key: str, first_stage_word_emb_key: str, + metadata: dict, ) -> None: """Ties the embedding and output weights in a given sharded state dict. @@ -341,4 +352,6 @@ def tie_embeddings_and_output_weights_state_dict( key=first_stage_word_emb_key, replica_id=last_stage_word_emb_replica_id, allow_shape_mismatch=True, + tp_group=self.tp_group, + dp_cp_group=metadata['dp_cp_group'], ) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index dbc5a88fc8..c2055198e9 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -752,7 +752,13 @@ def sharded_state_dict( if self.mtp_process and not self.pre_process: emb_weight_key = f'{prefix}embedding.word_embeddings.weight' emb_weight = self.embedding.word_embeddings.weight - tie_word_embeddings_state_dict(sharded_state_dict, emb_weight, emb_weight_key) + tie_word_embeddings_state_dict( + sharded_state_dict, + emb_weight, + emb_weight_key, + tp_group=self.tp_group, + dp_cp_group=metadata['dp_cp_group'], + ) if self.mtp_process and not self.post_process: # We only need to tie the output layer weight if share_embeddings_and_output_weights # is False. 
Because if share_embeddings_and_output_weights is True, the shared weight @@ -761,7 +767,11 @@ def sharded_state_dict( output_layer_weight_key = f'{prefix}output_layer.weight' output_layer_weight = self.output_layer.weight tie_output_layer_state_dict( - sharded_state_dict, output_layer_weight, output_layer_weight_key + sharded_state_dict, + output_layer_weight, + output_layer_weight_key, + tp_group=self.tp_group, + dp_cp_group=metadata['dp_cp_group'], ) return sharded_state_dict diff --git a/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py b/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py index 6f9999f080..28487c3b36 100644 --- a/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py +++ b/megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py @@ -52,7 +52,7 @@ def __init__( pin_cpu_grads: bool = True, pin_cpu_params: bool = True, overlap_cpu_optimizer_d2h_h2d: bool = True, - **kwargs + **kwargs, ): super(HybridDeviceOptimizer, self).__init__( params, diff --git a/megatron/core/post_training/modelopt/layers.py b/megatron/core/post_training/modelopt/layers.py index 0ca4a8e407..e502b81ac2 100644 --- a/megatron/core/post_training/modelopt/layers.py +++ b/megatron/core/post_training/modelopt/layers.py @@ -1,5 +1,6 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +import logging from typing import Callable, List, Optional import torch @@ -10,6 +11,8 @@ from megatron.core.transformer.transformer_layer import TransformerLayer from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint +logger = logging.getLogger(__name__) + try: import transformer_engine as te @@ -116,6 +119,7 @@ def __init__( tp_group: Optional[torch.distributed.ProcessGroup] = None, ): self.config = config + self.tp_group = tp_group self._return_bias = skip_bias_add and bias @@ -153,7 +157,11 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): if v.ndim == 0: state_dict[k] = v.view(1) sharded_state_dict = make_sharded_tensors_for_checkpoint( - state_dict, prefix, sharded_offsets=sharded_offsets + state_dict, + prefix, + sharded_offsets=sharded_offsets, + tp_group=self.tp_group, + dp_cp_group=metadata['dp_cp_group'], ) return sharded_state_dict @@ -229,7 +237,7 @@ def _report_quantize_tensor_info(self): if not isinstance(v, torch.Tensor): continue original_dtype, original_shape = self._original_tensor_info.get(k, ("-", "-")) - print( + logger.info( "{:<64} {:<16} {:<32} {:<16} {:<32}".format( k, original_dtype, original_shape, str(v.dtype), str(v.shape) ) diff --git a/megatron/core/ssm/mamba_block.py b/megatron/core/ssm/mamba_block.py index 01b9f4eac6..d17f27556e 100644 --- a/megatron/core/ssm/mamba_block.py +++ b/megatron/core/ssm/mamba_block.py @@ -139,6 +139,7 @@ def __init__( assert pg_collection is not None, "pg_collection must be provided for MambaStack" self.pp_group = pg_collection.pp + self.tp_group = pg_collection.tp # Required for pipeline parallel schedules self.input_tensor = None @@ -416,7 +417,11 @@ def sharded_state_dict( if not module is self.layers: sharded_state_dict.update( sharded_state_dict_default( - module, f'{prefix}{name}.', sharded_offsets, metadata + module, + f'{prefix}{name}.', + sharded_offsets, + metadata, + tp_group=self.tp_group, ) ) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 2caa36fb1e..296ca304fe 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -15,6 +15,7 @@ import torch.nn as nn import torch.nn.functional as F +from 
megatron.core import parallel_state from megatron.core.dist_checkpointing import ShardedTensor from megatron.core.dist_checkpointing.mapping import ReplicaId, ShardedTensorFactory from megatron.core.inference.contexts import BaseInferenceContext, DynamicInferenceContext @@ -24,6 +25,7 @@ from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.utils import ( + ensure_metadata_has_dp_cp_group, make_sharded_tensors_for_checkpoint, sharded_state_dict_default, ) @@ -74,9 +76,16 @@ class ExtendedRMSNorm(RMSNormGated): def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): """Sharding along axis 0, bias not sharded""" + if not hasattr(self, 'tp_group'): + self.tp_group = parallel_state.get_tensor_model_parallel_group() state_dict = self.state_dict(prefix="", keep_vars=True) return make_sharded_tensors_for_checkpoint( - state_dict, prefix, {"weight": 0}, sharded_offsets + state_dict, + prefix, + {"weight": 0}, + sharded_offsets, + tp_group=self.tp_group, + dp_cp_group=metadata["dp_cp_group"], ) @@ -377,6 +386,7 @@ def __init__( D_cp1=self.D, D_has_hdim=self.D_has_hdim, ) + self.tp_group = pg_collection.tp def forward( self, @@ -788,6 +798,9 @@ def _get_states_from_cache(self, inference_context, batch_size, *, inference_par def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): """Provide a sharded state dictionary for distributed checkpointing.""" + # Guard for cases metadata is not provided + metadata = ensure_metadata_has_dp_cp_group(metadata) + sharded_state_dict = {} # Parameters self._save_to_state_dict(sharded_state_dict, "", keep_vars=True) @@ -807,12 +820,17 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): # Add TP sharding for Conv1d module_sd = module.state_dict(prefix="", keep_vars=True) module_sharded_sd = make_sharded_tensors_for_checkpoint( - module_sd, f"{prefix}{name}.", {f"weight": 0, f"bias": 0}, sharded_offsets + module_sd, + f"{prefix}{name}.", + {f"weight": 0, f"bias": 0}, + sharded_offsets, + tp_group=self.tp_group, + dp_cp_group=metadata['dp_cp_group'], ) else: module_sharded_sd = sharded_state_dict_default( - module, f"{prefix}{name}.", sharded_offsets, metadata + module, f"{prefix}{name}.", sharded_offsets, metadata, tp_group=self.tp_group ) sharded_state_dict.update(module_sharded_sd) diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index e6e65425b2..13acddebc7 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -310,6 +310,8 @@ def sharded_state_dict( key=weight_prefix, allow_shape_mismatch=True, prepend_offsets=sharded_offsets, + tp_group=self.tp_group, + dp_cp_group=metadata["dp_cp_group"], ) } @@ -1046,7 +1048,12 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): """Sharding along axis 0, bias sharded""" state_dict = self.state_dict(prefix="", keep_vars=True) return make_sharded_tensors_for_checkpoint( - state_dict, prefix, {"weight": 0, "bias": 0}, sharded_offsets + state_dict, + prefix, + {"weight": 0, "bias": 0}, + sharded_offsets, + tp_group=self.tp_group, + dp_cp_group=metadata['dp_cp_group'], ) def set_extra_state(self, state: Any): @@ -1284,7 +1291,12 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): """Sharding along axis 1, bias not sharded""" state_dict = self.state_dict(prefix="", keep_vars=True) return 
make_sharded_tensors_for_checkpoint( - state_dict, prefix, {"weight": 1}, sharded_offsets + state_dict, + prefix, + {"weight": 1}, + sharded_offsets, + tp_group=self.tp_group, + dp_cp_group=metadata['dp_cp_group'], ) def set_extra_state(self, state: Any): diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 187222bef3..3d0f9e7857 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -170,6 +170,7 @@ def __init__( pg_collection, 'cp' ), "Attention pg_collection must have cp process group" self.pg_collection = pg_collection + self.tp_group = pg_collection.tp # Per attention head and per partition values world_size = get_pg_size(self.pg_collection.tp) diff --git a/megatron/core/transformer/dot_product_attention.py b/megatron/core/transformer/dot_product_attention.py index 2a958722e4..5df276ea05 100644 --- a/megatron/core/transformer/dot_product_attention.py +++ b/megatron/core/transformer/dot_product_attention.py @@ -71,6 +71,8 @@ def __init__( assert hasattr( pg_collection, 'tp' ), "DotProductAttention pg_collection must have tp process group" + self.pg_collection = pg_collection + self.tp_group = self.pg_collection.tp world_size = pg_collection.tp.size() self.hidden_size_per_partition = divide(projection_size, world_size) @@ -252,5 +254,10 @@ def sharded_state_dict( else: state_dict = {} return make_sharded_tensors_for_checkpoint( - state_dict, prefix, {'softmax_offset': 0}, sharded_offsets + state_dict, + prefix, + {'softmax_offset': 0}, + sharded_offsets, + tp_group=self.tp_group, + dp_cp_group=metadata['dp_cp_group'], ) diff --git a/megatron/core/transformer/mlp.py b/megatron/core/transformer/mlp.py index 9602beb2f7..8dcf196da9 100644 --- a/megatron/core/transformer/mlp.py +++ b/megatron/core/transformer/mlp.py @@ -87,7 +87,7 @@ def __init__( self.input_size = input_size if input_size != None else self.config.hidden_size - tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert) + self.tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert) if ffn_hidden_size is None: if is_expert: raise ValueError("MoE MLP requires `ffn_hidden_size`, but it was not provided.") diff --git a/megatron/core/transformer/module.py b/megatron/core/transformer/module.py index 4fdcacb791..a0f735fa5c 100644 --- a/megatron/core/transformer/module.py +++ b/megatron/core/transformer/module.py @@ -11,6 +11,7 @@ from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.utils import ( + ensure_metadata_has_dp_cp_group, make_sharded_tensors_for_checkpoint, sharded_state_dict_default, ) @@ -77,13 +78,24 @@ def sharded_state_dict( sharded_state_dict = {} # Save parameters self._save_to_state_dict(sharded_state_dict, '', keep_vars=True) + if not hasattr(self, 'tp_group'): + # some model interface hasn't updated for m4, fallback needed + self.tp_group = parallel_state.get_tensor_model_parallel_group() + # Guard for cases metadata is not provided + metadata = ensure_metadata_has_dp_cp_group(metadata) sharded_state_dict = make_sharded_tensors_for_checkpoint( - sharded_state_dict, prefix, sharded_offsets=sharded_offsets + sharded_state_dict, + prefix, + sharded_offsets=sharded_offsets, + tp_group=self.tp_group, + dp_cp_group=metadata['dp_cp_group'], ) # Recurse into submodules for name, module in self.named_children(): sharded_state_dict.update( - 
sharded_state_dict_default(module, f'{prefix}{name}.', sharded_offsets, metadata) + sharded_state_dict_default( + module, f'{prefix}{name}.', sharded_offsets, metadata, tp_group=self.tp_group + ) ) return sharded_state_dict diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index d8dd3d03f0..45cc63fff2 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -3,7 +3,7 @@ import copy import itertools from copy import deepcopy -from functools import partial, wraps +from functools import partial from math import ceil from typing import Optional, Tuple @@ -11,7 +11,7 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter -from megatron.core import parallel_state, tensor_parallel +from megatron.core import tensor_parallel from megatron.core.activations import squared_relu from megatron.core.dist_checkpointing import ShardedTensor from megatron.core.dist_checkpointing.mapping import ( @@ -38,6 +38,7 @@ from megatron.core.transformer.spec_utils import build_module from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.utils import ( + ensure_metadata_has_dp_cp_group, make_sharded_object_for_checkpoint, sharded_state_dict_default, ) @@ -54,49 +55,6 @@ HAVE_TE = False -# TODO(Hepteract): delete the usage of the global parallel_state. -# Currently we still have to use the global parallel_state in expert_dist_ckpt_decorator(), -# in order to set sub-module's process group while getting sharded_state_dict. -# After sub-module's refactoring is done, we can pass pg_collection to sub-module -# and delete the function expert_dist_ckpt_decorator. -def expert_dist_ckpt_decorator(func): - """Decorator of shared_state_dict in expert layer for distributed checkpoint. - - Since !1940, the TP size for Expert layer can be different with Attention. - To make distributed checkpoint work in such cases, we use a decorator to - replace the default TP parallel states with expert-TP parallel states. - """ - - @wraps(func) - def wrapper(*args, **kwargs): - # Store original states - original_rank = parallel_state._MPU_TENSOR_MODEL_PARALLEL_RANK - original_size = parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE - original_group = parallel_state._TENSOR_MODEL_PARALLEL_GROUP - try: - # Set new states - parallel_state._MPU_TENSOR_MODEL_PARALLEL_RANK = ( - parallel_state.get_expert_tensor_parallel_rank() - ) - parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = ( - parallel_state.get_expert_tensor_parallel_world_size() - ) - parallel_state._TENSOR_MODEL_PARALLEL_GROUP = ( - parallel_state.get_expert_tensor_parallel_group() - ) - - # Execute the function - result = func(*args, **kwargs) - finally: - # Restore original states - parallel_state._MPU_TENSOR_MODEL_PARALLEL_RANK = original_rank - parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = original_size - parallel_state._TENSOR_MODEL_PARALLEL_GROUP = original_group - return result - - return wrapper - - class GroupedMLP(MegatronModule): """An efficient implementation of the Experts layer using GroupedGEMM. @@ -305,7 +263,6 @@ def forward( return fc2_output, None - @expert_dist_ckpt_decorator def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): """ Maps local expert to global experts. 
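The removal of expert_dist_ckpt_decorator above replaces a pattern that temporarily overwrote the global tensor-model-parallel state with the expert-TP state around each sharded_state_dict call. A minimal sketch of the replacement flow, using only names that appear in the surrounding hunks (self stands for an expert module such as TEGroupedMLP or SequentialMLP; this is an illustration, not an excerpt from the patch):

# Before (removed decorator): mutate parallel_state globals, build the sharded
# tensors, then restore the globals in a finally block.
# After: the expert module keeps the expert-TP group it was constructed with and
# hands it to the checkpoint helpers explicitly, so no global state is touched.
self.tp_group = pg_collection.expt_tp   # set once in __init__
sub_sd = sharded_state_dict_default(
    module, f'{name}.', sharded_offsets, metadata, tp_group=self.tp_group
)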
@@ -764,6 +721,7 @@ def __init__( ), "bias_dropout_fusion is not supported in TEGroupedMLP when add_bias_linear=True" self.ep_group = pg_collection.ep + self.tp_group = pg_collection.expt_tp # Double the output width with gated linear unit, see https://arxiv.org/pdf/2002.05202.pdf ffn_hidden_size = self.config.moe_ffn_hidden_size @@ -962,7 +920,6 @@ def glu(x): return output, output_bias - @expert_dist_ckpt_decorator def sharded_state_dict( self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None ) -> ShardedStateDict: @@ -970,10 +927,14 @@ def sharded_state_dict( Maps local expert to global experts. The sharded state dict is interchangable with SequentialMLP's. """ + # Guard for cases metadata is not provided + metadata = ensure_metadata_has_dp_cp_group(metadata) singleton_local_shards = (metadata or {}).get('singleton_local_shards', False) sharded_state_dict = {} for name, module in self._modules.items(): - sub_sd = sharded_state_dict_default(module, f'{name}.', sharded_offsets, metadata) + sub_sd = sharded_state_dict_default( + module, f'{name}.', sharded_offsets, metadata, tp_group=self.tp_group + ) if name == 'linear_fc1' and self.config.gated_linear_unit: num_global_experts = self.ep_group.size() * self.num_local_experts local_expert_indices_offset = self.ep_group.rank() * self.num_local_experts @@ -1037,6 +998,7 @@ def __init__( self.num_local_experts = num_local_experts self.local_experts = torch.nn.ModuleList() self.ep_group = pg_collection.ep + self.tp_group = pg_collection.expt_tp # use pg_collection.expt_dp_group as data parallel group in this module. # TODO (Hepteract): expt_dp wont be needed here once distributed checkpoint is refactored self.dp_group = pg_collection.expt_dp @@ -1124,9 +1086,11 @@ def backward_dw(self): for expert in self.local_experts: expert.backward_dw() - @expert_dist_ckpt_decorator def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): """Maps local expert to global experts.""" + # Guard for cases metadata is not provided + metadata = ensure_metadata_has_dp_cp_group(metadata) + sharded_state_dict = {} num_global_experts = self.ep_group.size() * self.num_local_experts local_expert_indices_offset = self.ep_group.rank() * self.num_local_experts diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index d5a6be9224..4bc36636e4 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -122,7 +122,7 @@ def __init__( # Initialize router self.router = TopKRouter(config=self.config, pg_collection=pg_collection) - + self.tp_group = pg_collection.tp # Initialize token dispatcher if config.moe_token_dispatcher_type == "allgather": self.token_dispatcher = MoEAllGatherTokenDispatcher( diff --git a/megatron/core/transformer/moe/shared_experts.py b/megatron/core/transformer/moe/shared_experts.py index 93e6ad0453..c63e074e1b 100644 --- a/megatron/core/transformer/moe/shared_experts.py +++ b/megatron/core/transformer/moe/shared_experts.py @@ -49,7 +49,7 @@ def __init__( config.ffn_hidden_size = config.moe_shared_expert_intermediate_size # TODO(Hepteract): pass pg_collection to MLP after refactoring MLP - super().__init__(config=config, submodules=submodules) + super().__init__(config=config, submodules=submodules, tp_group=pg_collection.tp) self.use_shared_expert_gate = gate if self.use_shared_expert_gate: @@ -137,7 +137,11 @@ def sharded_state_dict( state_dict = self.state_dict(prefix='', keep_vars=True) sub_sd = { 
f'{prefix}{name}': make_sharded_tensor_for_checkpoint( - state_dict[name], f'{prefix}{name}', prepend_offsets=sharded_offsets + state_dict[name], + f'{prefix}{name}', + prepend_offsets=sharded_offsets, + tp_group=self.tp_group, + dp_cp_group=metadata['dp_cp_group'], ) } sharded_state_dict.update(sub_sd) diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index bd3aa9c8c9..a25a226a6e 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -25,6 +25,7 @@ from megatron.core.transformer.transformer_block import TransformerBlockSubmodules from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import ( + get_pg_rank, is_torch_min_version, make_tp_sharded_tensor_for_checkpoint, make_viewless_tensor, @@ -53,7 +54,11 @@ def tie_word_embeddings_state_dict( - sharded_state_dict: ShardedStateDict, word_emb_weight: Tensor, word_emb_weight_key: str + sharded_state_dict: ShardedStateDict, + word_emb_weight: Tensor, + word_emb_weight_key: str, + tp_group: torch.distributed.ProcessGroup, + dp_cp_group: torch.distributed.ProcessGroup, ) -> None: """tie the embedding of the mtp processing stage in a given sharded state dict. @@ -61,13 +66,15 @@ def tie_word_embeddings_state_dict( sharded_state_dict (ShardedStateDict): state dict with the weight to tie. word_emb_weight (Tensor): weight of the word embedding. word_emb_weight_key (str): key of the word embedding in the sharded state dict. + tp_group (torch.distributed.ProcessGroup): The tensor parallel group + dp_cp_group (torch.distributed.ProcessGroup): The dp-cp comm group Returns: None, acts in-place """ mtp_word_emb_replica_id = ( 1, # copy of embedding in pre processing stage 0, - parallel_state.get_data_parallel_rank(with_context_parallel=True), + get_pg_rank(dp_cp_group), ) assert word_emb_weight_key in sharded_state_dict del sharded_state_dict[word_emb_weight_key] @@ -76,11 +83,17 @@ def tie_word_embeddings_state_dict( key=word_emb_weight_key, replica_id=mtp_word_emb_replica_id, allow_shape_mismatch=True, + tp_group=tp_group, + dp_cp_group=dp_cp_group, ) def tie_output_layer_state_dict( - sharded_state_dict: ShardedStateDict, output_layer_weight: Tensor, output_layer_weight_key: str + sharded_state_dict: ShardedStateDict, + output_layer_weight: Tensor, + output_layer_weight_key: str, + tp_group: torch.distributed.ProcessGroup, + dp_cp_group: torch.distributed.ProcessGroup, ) -> None: """tie the output layer of the mtp processing stage in a given sharded state dict. @@ -88,13 +101,15 @@ def tie_output_layer_state_dict( sharded_state_dict (ShardedStateDict): state dict with the weight to tie. output_layer_weight (Tensor): weight of the output layer. output_layer_weight_key (str): key of the output layer in the sharded state dict. 
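For reference, the replica-id convention used by the tied MTP weights in the surrounding hunks: only the last component changes in this patch, moving from a parallel_state lookup to the rank inside the explicitly passed dp-cp group. An annotated sketch (get_pg_rank and dp_cp_group as imported and passed in the hunks):

# replica_id is a 3-tuple; the first two entries are unchanged by this patch.
mtp_output_layer_replica_id = (
    1,                         # this rank holds a copy of the post-processing output layer
    0,                         # second replica axis, left at 0 as before
    get_pg_rank(dp_cp_group),  # was parallel_state.get_data_parallel_rank(with_context_parallel=True)
)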
+ tp_group (torch.distributed.ProcessGroup): The tensor parallel group + dp_cp_group (torch.distributed.ProcessGroup): The dp-cp comm group Returns: None, acts in-place """ mtp_output_layer_replica_id = ( 1, # copy of output layer in post processing stage 0, - parallel_state.get_data_parallel_rank(with_context_parallel=True), + get_pg_rank(dp_cp_group), ) assert output_layer_weight_key in sharded_state_dict del sharded_state_dict[output_layer_weight_key] @@ -103,6 +118,8 @@ def tie_output_layer_state_dict( key=output_layer_weight_key, replica_id=mtp_output_layer_replica_id, allow_shape_mismatch=True, + tp_group=tp_group, + dp_cp_group=dp_cp_group, ) diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index aead6133f2..2f2ea5adb3 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -281,6 +281,7 @@ def __init__( if pg_collection is None: pg_collection = ProcessGroupCollection.use_mpu_process_groups() self.pg_collection = pg_collection + self.tp_group = pg_collection.tp pp_group = self.pg_collection.pp if hasattr(self.pg_collection, 'pp') else None pp_rank = get_pg_rank(pp_group) @@ -808,7 +809,11 @@ def sharded_state_dict( if not module is self.layers: sharded_state_dict.update( sharded_state_dict_default( - module, f'{prefix}{name}.', sharded_offsets, metadata + module, + f'{prefix}{name}.', + sharded_offsets, + metadata, + tp_group=self.tp_group, ) ) diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index a5babece9d..b7939448c8 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -272,6 +272,7 @@ def __init__( if pg_collection is None: pg_collection = ProcessGroupCollection.use_mpu_process_groups() self.pg_collection = pg_collection + self.tp_group = pg_collection.tp self.submodules_config = submodules self.layer_number = layer_number + get_transformer_layer_offset( diff --git a/megatron/core/transformer/utils.py b/megatron/core/transformer/utils.py index ac00e6557c..d251a01daa 100644 --- a/megatron/core/transformer/utils.py +++ b/megatron/core/transformer/utils.py @@ -11,6 +11,8 @@ from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedStateDict, StateDict from megatron.core.jit import jit_fuser from megatron.core.utils import ( + get_pg_rank, + get_tensor_model_parallel_group_if_none, make_sharded_tensor_for_checkpoint, make_tp_sharded_tensor_for_checkpoint, ) @@ -79,6 +81,8 @@ def make_sharded_tensors_for_checkpoint( tensor_parallel_layers_axis_map: Optional[Dict[str, int]] = None, sharded_offsets: Iterable[Tuple[int, int, int]] = (), extra_state_suffix: str = '_extra_state', + tp_group: Optional[torch.distributed.ProcessGroup] = None, + dp_cp_group: Optional[torch.distributed.ProcessGroup] = None, ): """Wraps tensors from transformer layers with ShardedTensor or ShardedObject. @@ -96,31 +100,52 @@ def make_sharded_tensors_for_checkpoint( applied (e.g. PP related), passed along to ShardedTensor extra_state_suffix (str, default = '_extra_state'): layers with this suffix will be wrapped with ShardedObject instead of ShardedTensor. + tp_group (Optional[torch.distributed.ProcessGroup], optional): tensor parallel group. + If None, defaults to parallel_state.get_tensor_model_parallel_group() + dp_cp_group (Optional[torch.distributed.ProcessGroup], optional): data parallel group + with context parallel. 
If None, defaults to + parallel_state.get_data_parallel_group(with_context_parallel=True) """ if tensor_parallel_layers_axis_map is None: tensor_parallel_layers_axis_map = {} + if tp_group is None and dp_cp_group is None: + tp_group = get_tensor_model_parallel_group_if_none(tp_group) + dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) + sharded_state_dict = {} for layer_name in state_dict.keys(): tensor = state_dict[layer_name] layer_key = f'{prefix}{layer_name}' if layer_name.endswith(extra_state_suffix): + # Compute replica_id when groups are provided + replica_id = (0, get_pg_rank(tp_group), get_pg_rank(dp_cp_group)) + sharded_state_dict[layer_key] = make_sharded_object_for_checkpoint( - tensor, layer_key, sharded_offsets + tensor, layer_key, sharded_offsets, replica_id=replica_id ) elif layer_name in tensor_parallel_layers_axis_map: tp_axis = tensor_parallel_layers_axis_map[layer_name] sharded_state_dict[layer_key] = make_tp_sharded_tensor_for_checkpoint( - tensor, layer_key, tp_axis, prepend_offsets=sharded_offsets + tensor, + layer_key, + tp_axis, + prepend_offsets=sharded_offsets, + tp_group=tp_group, + dp_cp_group=dp_cp_group, ) else: sharded_state_dict[layer_key] = make_sharded_tensor_for_checkpoint( - tensor, layer_key, prepend_offsets=sharded_offsets + tensor, + layer_key, + prepend_offsets=sharded_offsets, + tp_group=tp_group, + dp_cp_group=dp_cp_group, ) return sharded_state_dict @@ -169,11 +194,27 @@ def _get_extra_state_offsets( return extra_state_shape, extra_state_offset +def ensure_metadata_has_dp_cp_group(metadata: Optional[dict]) -> dict: + """Ensure `metadata` is a dict containing `dp_cp_group` entry. + + If `metadata` is None, a new dict is returned with `dp_cp_group` set. + If `metadata` is a dict and missing `dp_cp_group`, it is updated in-place. + Otherwise, asserts that `dp_cp_group` exists. + """ + if metadata is None: + return {'dp_cp_group': parallel_state.get_data_parallel_group(with_context_parallel=True)} + assert isinstance(metadata, dict), "metadata must be a dict with dp_cp_group as key" + if 'dp_cp_group' not in metadata: + metadata['dp_cp_group'] = parallel_state.get_data_parallel_group(with_context_parallel=True) + return metadata + + def sharded_state_dict_default( module: torch.nn.Module, prefix: str = '', sharded_offsets: Tuple[Tuple[int, int, int]] = (), metadata: Optional[dict] = None, + tp_group: Optional[torch.distributed.ProcessGroup] = None, ) -> ShardedStateDict: """Provides implementation for sharded_state_dict method for non-MegatronModules. @@ -189,11 +230,16 @@ def sharded_state_dict_default( sharded_offsets (Tuple[Tuple[int, int, int]], optional): sharding already applied (e.g. PP related) by sup-modules. Passed along to ShardedTensor metadata (dict, optional): metadata passed to module sharded_state_dict method + tp_group (Optional[torch.distributed.ProcessGroup], optional): tensor parallel group. 
+ If None, defaults to parallel_state.get_tensor_model_parallel_group() Returns: dict: dictionary of state dict keys mapped to ShardedTensors """ + # Guard for cases metadata is not provided + metadata = ensure_metadata_has_dp_cp_group(metadata) + if hasattr(module, 'sharded_state_dict'): module_sharded_sd = module.sharded_state_dict( prefix=prefix, sharded_offsets=sharded_offsets, metadata=metadata @@ -201,7 +247,12 @@ def sharded_state_dict_default( else: module_sd = module.state_dict(prefix='', keep_vars=True) module_sharded_sd = make_sharded_tensors_for_checkpoint( - module_sd, prefix, {}, sharded_offsets + module_sd, + prefix, + {}, + sharded_offsets, + tp_group=tp_group, + dp_cp_group=metadata['dp_cp_group'], ) return module_sharded_sd diff --git a/megatron/core/utils.py b/megatron/core/utils.py index abfaf7f632..52d8032b24 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -793,15 +793,37 @@ def make_tp_sharded_tensor_for_checkpoint( is sharded across TP group. Optionally, can provide offsets which prepend new dimensions to the tensor. + + Args: + tensor: Tensor to shard + key: Key for the sharded tensor + tp_axis: Axis to shard across tensor parallel group (default: 0) + replica_id: Replica ID for the tensor (default: None) + prepend_offsets: Offsets to prepend to tensor dimensions (default: ()) + **kwargs: Additional arguments. May include: + - tp_group: Tensor parallel group (default: None, falls back to parallel_state) + - dp_cp_group: Data parallel + context parallel group + (default: None, falls back to parallel_state) """ + # Pop group parameters from kwargs + tp_group = kwargs.pop('tp_group', None) + dp_cp_group = kwargs.pop('dp_cp_group', None) + prepend_axis_num = len(prepend_offsets) new_offsets = [] - tp_rank = parallel_state.get_tensor_model_parallel_rank() - dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) - tp_size = parallel_state.get_tensor_model_parallel_world_size() - dp_size = parallel_state.get_data_parallel_world_size(with_context_parallel=True) - dp_replica_id = parallel_state.get_data_parallel_rank(with_context_parallel=True) + + # Get groups with fallback to parallel_state + if tp_group is None and dp_cp_group is None: + tp_group = parallel_state.get_tensor_model_parallel_group() + dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) + + # Use local get_pg_rank and get_pg_size functions + tp_rank = get_pg_rank(tp_group) + dp_rank = get_pg_rank(dp_cp_group) + tp_size = get_pg_size(tp_group) + dp_size = get_pg_size(dp_cp_group) + dp_replica_id = get_pg_rank(dp_cp_group) new_offsets.append((tp_axis + prepend_axis_num, tp_rank, tp_size)) @@ -837,14 +859,34 @@ def make_sharded_tensor_for_checkpoint(tensor, key, prepend_offsets=(), replica_ """Helper for instantiating a non-sharded ShardedTensor (replicated across TP and DP group). Optionally, can provide offsets which prepend new dimensions to the tensor. + + Keyword Args: + tensor: Tensor to create sharded tensor for + key: Key for the sharded tensor + prepend_offsets: Offsets to prepend to tensor dimensions (default: ()) + replica_id: Replica ID for the tensor (default: None) + **kwargs: Additional arguments. 
May include: + - tp_group: Tensor parallel group (default: None, falls back to parallel_state) + - dp_cp_group: Data parallel + context parallel group + (default: None, falls back to parallel_state) """ + # Pop group parameters from kwargs + tp_group = kwargs.pop('tp_group', None) + dp_cp_group = kwargs.pop('dp_cp_group', None) prepend_axis_num = len(prepend_offsets) new_offsets = [] - dp_rank = parallel_state.get_data_parallel_rank(with_context_parallel=True) - dp_size = parallel_state.get_data_parallel_world_size(with_context_parallel=True) - dp_replica_id = parallel_state.get_data_parallel_rank(with_context_parallel=True) + + # Get groups with fallback to parallel_state + if tp_group is None and dp_cp_group is None: + tp_group = parallel_state.get_tensor_model_parallel_group() + dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True) + + # Use local get_pg_rank and get_pg_size functions + dp_rank = get_pg_rank(dp_cp_group) + dp_size = get_pg_size(dp_cp_group) + dp_replica_id = get_pg_rank(dp_cp_group) if HAVE_DTENSOR and isinstance(tensor, DTensor): # FSDP2 sharding @@ -853,7 +895,7 @@ def make_sharded_tensor_for_checkpoint(tensor, key, prepend_offsets=(), replica_ new_offsets.append((prepend_axis_num, dp_rank, dp_size)) if replica_id is None: - replica_id = (0, parallel_state.get_tensor_model_parallel_rank(), dp_replica_id) + replica_id = (0, get_pg_rank(tp_group), dp_replica_id) return ShardedTensor.from_rank_offsets( key, diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index f912be8465..a7ca22f336 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -17,6 +17,7 @@ import numpy as np import torch +from typing import Optional, Union, List, Dict, Any from torch.distributed.checkpoint import FileSystemReader, default_planner from megatron.core import dist_checkpointing, mpu, tensor_parallel @@ -28,6 +29,7 @@ ) from megatron.core.msc_utils import MultiStorageClientFeature, open_file from megatron.core.num_microbatches_calculator import update_num_microbatches +from megatron.core.utils import get_pg_rank, get_pg_size from megatron.core.optimizer import DistributedOptimizer from megatron.core.rerun_state_machine import get_rerun_state_machine from megatron.core.utils import get_torch_version, is_torch_min_version @@ -306,7 +308,7 @@ def read_metadata(tracker_filename): return max_iter, release -def get_rng_state(ckpt_format: str): +def get_rng_state(ckpt_format: str, tp_group: torch.distributed.ProcessGroup, pp_group: torch.distributed.ProcessGroup) -> Union[List[Dict[str, Any]], ShardedObject]: """Collect rng state across data parallel ranks.""" args = get_args() rng_state = { @@ -329,10 +331,10 @@ def get_rng_state(ckpt_format: str): rng_state_list = [rng_state] if ckpt_format == "torch_dist": - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() + pp_rank = get_pg_rank(pp_group) + pp_size = get_pg_size(pp_group) + tp_rank = get_pg_rank(tp_group) + tp_size = get_pg_size(tp_group) rng_state_list = ShardedObject('rng_state', rng_state_list, (pp_size, tp_size), (pp_rank, tp_rank), replica_id=mpu.get_data_parallel_rank(with_context_parallel=True)) elif ckpt_format == "fsdp_dtensor": @@ -351,7 +353,25 @@ class CheckpointType(Enum): TORCH_DCP = auto() FSDP_DTENSOR = auto() -def _build_sharded_state_dict_metadata(args: Namespace) -> dict: 
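Both helpers in megatron/core/utils.py keep their previous behavior when no groups are supplied; the sketch below shows the new explicit form. It assumes tp_group and dp_cp_group are live process-group handles obtained from a ProcessGroupCollection or from parallel_state, and the tensor keys are illustrative (not an excerpt from the patch):

from megatron.core.utils import (
    make_sharded_tensor_for_checkpoint,
    make_tp_sharded_tensor_for_checkpoint,
)

# Replicated tensor: the explicit groups replace the previous parallel_state
# lookups when computing rank/size and the replica id.
sh_ten = make_sharded_tensor_for_checkpoint(
    tensor, 'decoder.final_layernorm.weight',
    tp_group=tp_group, dp_cp_group=dp_cp_group,
)

# TP-sharded tensor: tp_group drives the sharding offset along tp_axis,
# dp_cp_group determines the replica id.
tp_sh_ten = make_tp_sharded_tensor_for_checkpoint(
    weight, 'decoder.layers.0.mlp.linear_fc1.weight', tp_axis=0,
    tp_group=tp_group, dp_cp_group=dp_cp_group,
)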
+def _clean_metadata_for_serialization(metadata: dict) -> dict: + """Create a clean copy of metadata for serialization by removing non-serializable objects. + + Args: + metadata: Original metadata dict + + Returns: + Clean metadata dict suitable for serialization + """ + if metadata is None: + return None + + clean_metadata = metadata.copy() + # Remove dp_cp_group as it's not serializable + clean_metadata.pop('dp_cp_group', None) + return clean_metadata + + +def _build_sharded_state_dict_metadata(args: Namespace, dp_cp_group: Optional[torch.distributed.ProcessGroup] = None) -> dict: """Builds metadata used for sharded_state_dict versioning. The whole content metadata is passed to ``shared_state_dict`` model and optimizer methods @@ -361,6 +381,10 @@ def _build_sharded_state_dict_metadata(args: Namespace) -> dict: In particular, a simple integer (or SemVer) versioning flag (e.g. `metadata['version'] = 3.4`) is discouraged, because the metadata serves for all models and optimizers and it's practically impossible to enforce a linearly increasing versioning for this whole space. + + Args: + args: Arguments namespace + dp_cp_group: Data parallel + context parallel group (default: None, falls back to mpu API) """ metadata = {} @@ -389,11 +413,15 @@ def _build_sharded_state_dict_metadata(args: Namespace) -> dict: metadata['singleton_local_shards'] = False metadata['chained_optim_avoid_prefix'] = True + # Add dp_cp_group to metadata. If not provided, fallback to global parallel state. + if dp_cp_group is None: + dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True) + metadata['dp_cp_group'] = dp_cp_group return metadata def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, checkpointing_context=None, pipeline_rank=None, expert_rank=None, tensor_rank=None, pipeline_parallel=None, expert_parallel=None, non_persistent_ckpt=False, - train_data_iterator=None, preprocess_common_state_dict_fn = None, release=False): + train_data_iterator=None, preprocess_common_state_dict_fn = None, release=False, tp_group: Optional[torch.distributed.ProcessGroup] = None, pp_group: Optional[torch.distributed.ProcessGroup] = None, dp_cp_group: Optional[torch.distributed.ProcessGroup] = None): """Save a model, optimizer and optionally dataloader checkpoint. Checkpointing context is used to persist some checkpointing state @@ -407,6 +435,9 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati Dataloader checkpoint is only saved if the dataloader supports it. Currently this applies only to the Megatron Energon dataloader (multimodal) and not the built-in Megatron dataloader (text-only). + + Args: + dp_cp_group: Data parallel + context parallel group (default: None, falls back to mpu API) """ start_ckpt = time() args = get_args() @@ -450,7 +481,10 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati iteration, save_dir, ckpt_format)) # Collect rng state across data parallel ranks. 
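A condensed sketch of the save-side metadata lifecycle these two helpers implement: the live dp-cp group handle rides along in the metadata while sharded state dicts are built, and is stripped again before the content metadata is persisted (args and dp_cp_group are assumed to come from the surrounding save path):

sharded_sd_metadata = _build_sharded_state_dict_metadata(args, dp_cp_group=dp_cp_group)
assert 'dp_cp_group' in sharded_sd_metadata        # consumed by sharded_state_dict() calls

serializable = _clean_metadata_for_serialization(sharded_sd_metadata)
assert 'dp_cp_group' not in serializable           # safe to write to disk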
- rng_state = get_rng_state(args.ckpt_format) + if tp_group is None and pp_group is None: + tp_group = mpu.get_tensor_model_parallel_group() + pp_group = mpu.get_pipeline_model_parallel_group() + rng_state = get_rng_state(args.ckpt_format, tp_group, pp_group) # Collect rerun state across all ranks rerun_state_machine = get_rerun_state_machine() @@ -493,7 +527,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati or mpu.get_expert_data_parallel_rank() == 0 \ or ckpt_type != CheckpointType.LEGACY: if ckpt_type != CheckpointType.LEGACY: - sharded_sd_metadata = _build_sharded_state_dict_metadata(args) + sharded_sd_metadata = _build_sharded_state_dict_metadata(args, dp_cp_group=dp_cp_group) if args.use_distributed_optimizer: print_rank_0(f'Storing distributed optimizer sharded state of type' f' {sharded_sd_metadata["distrib_optim_sharding_type"]}') @@ -545,7 +579,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati async_sharded_save=args.async_save, validate_access_integrity=validate_sharding_integrity, preprocess_common_before_consistancy_check=preprocess_common_state_dict_fn, - content_metadata=sharded_sd_metadata) + content_metadata=_clean_metadata_for_serialization(sharded_sd_metadata)) # [ModelOpt]: save sharded modelopt_state if has_nvidia_modelopt: save_sharded_modelopt_state(model, checkpoint_name, (args.ckpt_format, 1)) @@ -806,7 +840,13 @@ def generate_state_dict( key = f"model{i}" if args.ckpt_format == "torch_dist": - model_sd = model[i].sharded_state_dict(**(model_sd_kwargs or {})) + model_sd = model[i].sharded_state_dict( + **(model_sd_kwargs or { + "metadata": { + "dp_cp_group": mpu.get_data_parallel_group(with_context_parallel=True) + } + }) + ) else: # torch, torch_dcp, fsdp_dtensor model_sd = model[i].state_dict_for_save_checkpoint() @@ -815,10 +855,16 @@ def generate_state_dict( # Optimizer stuff. if not args.no_save_optim: if optimizer is not None and not optimizer.is_stub_optimizer: - optimizer_sd = None if args.ckpt_format == "torch_dist": - optimizer_sd = optimizer.sharded_state_dict(state_dict, **(optim_sd_kwargs or {})) + optimizer_sd = optimizer.sharded_state_dict( + state_dict, + **(optim_sd_kwargs or { + "metadata": { + "dp_cp_group": mpu.get_data_parallel_group(with_context_parallel=True) + } + }) + ) elif args.ckpt_format == "fsdp_dtensor": if optim_sd_kwargs is None: optim_sd_kwargs = {} @@ -1361,7 +1407,7 @@ def _set_arg(arg_name, old_arg_name=None, force=False): def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', strict=True, - checkpointing_context=None, skip_load_to_model_and_opt=False): + checkpointing_context=None, skip_load_to_model_and_opt=False, tp_group: Optional[torch.distributed.ProcessGroup] = None, pp_group: Optional[torch.distributed.ProcessGroup] = None, dp_cp_group: Optional[torch.distributed.ProcessGroup] = None): """Load a model checkpoint and return the iteration. strict (bool): whether to strictly enforce that the keys in :attr:`state_dict` of the checkpoint match the names of @@ -1369,6 +1415,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', skip_load_to_model_and_opt (bool): whether to call `load_state_dict` for :attr:`model` and :attr:`optimizer`. In case of running FSDP2 with mcore distributed checkpointing, the tensors are already loaded in-place by `_load_base_checkpoint`. 
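Caller-side sketch for the extended save_checkpoint signature: all three group arguments are optional, and omitting them falls back to the mpu lookups shown above (training-loop variables such as iteration, model, optimizer and checkpointing_context are assumed to exist):

from megatron.core import mpu
from megatron.training.checkpointing import save_checkpoint

save_checkpoint(
    iteration, model, optimizer, opt_param_scheduler,
    num_floating_point_operations_so_far,
    checkpointing_context=checkpointing_context,
    tp_group=mpu.get_tensor_model_parallel_group(),
    pp_group=mpu.get_pipeline_model_parallel_group(),
    dp_cp_group=mpu.get_data_parallel_group(with_context_parallel=True),
)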
+ dp_cp_group: Data parallel + context parallel group (default: None, falls back to mpu API) """ args = get_args() load_dir = getattr(args, load_arg) @@ -1442,7 +1489,10 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', # Determine if RNG state will be loaded if (ckpt_tp_pp == run_tp_pp and not release and not args.finetune and not args.no_load_rng and not getattr(ckpt_args, 'no_save_rng', False)): - gen_sd_rng_state = get_rng_state(args.ckpt_format) # we can load the rng state + if tp_group is None and pp_group is None: + tp_group = mpu.get_tensor_model_parallel_group() + pp_group = mpu.get_pipeline_model_parallel_group() + gen_sd_rng_state = get_rng_state(args.ckpt_format, tp_group, pp_group) # we can load the rng state else: ignore_rng_state = True gen_sd_rng_state = None @@ -1454,6 +1504,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', else: sharded_sd_metadata = dist_checkpointing.load_content_metadata(preloaded_state_dict=state_dict) print_rank_0(f'sharded_state_dict metadata loaded from the checkpoint: {sharded_sd_metadata}') + # Determine if optimizer state will be loaded if (not release and not args.finetune and not args.no_load_optim and not getattr(ckpt_args, 'no_save_optim', False)): @@ -1487,6 +1538,15 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', gen_sd_optim = None gen_sd_opt_param_scheduler = None + if dp_cp_group is None: + dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True) + + # dist_checkpointing.load_content_metadata(...) may return None. + # Ensure we have a dict before updating to avoid NoneType AttributeError. + if sharded_sd_metadata is None: + sharded_sd_metadata = {} + sharded_sd_metadata["dp_cp_group"] = dp_cp_group + optim_sd_kwargs = dict(metadata=sharded_sd_metadata, is_loading=True) model_sd_kwargs = dict(metadata=sharded_sd_metadata) @@ -1528,12 +1588,15 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', elif args.ckpt_format == "torch_dcp": model_sd = model[0].state_dict() optimizer_sd = optimizer.state_dict(is_loading=True) + if tp_group is None and pp_group is None: + tp_group = mpu.get_tensor_model_parallel_group() + pp_group = mpu.get_pipeline_model_parallel_group() sharded_state_dict = { "model": model_sd, "optimizer": optimizer_sd, "args": None, "iteration": 1, - "rng_state": get_rng_state(args.ckpt_format), + "rng_state": get_rng_state(args.ckpt_format, tp_group, pp_group), "checkpoint_version": None, "opt_param_scheduler": opt_param_scheduler.state_dict(), "num_floating_point_operations_so_far": 0, @@ -1556,7 +1619,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', data_iterator=None, ckpt_format=ckpt_format, force=True, ) if not args.no_load_rng: - gen_sd_rng_state = get_rng_state(args.ckpt_format) + gen_sd_rng_state = get_rng_state(args.ckpt_format, tp_group, pp_group) if not args.no_load_optim: gen_sd_optim = optimizer gen_sd_opt_param_scheduler = opt_param_scheduler diff --git a/tests/functional_tests/python_test_utils/test_inference_regular_pipeline.py b/tests/functional_tests/python_test_utils/test_inference_regular_pipeline.py index b9c70046a4..252aa85c38 100644 --- a/tests/functional_tests/python_test_utils/test_inference_regular_pipeline.py +++ b/tests/functional_tests/python_test_utils/test_inference_regular_pipeline.py @@ -1,3 +1,5 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
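Matching load-side sketch: the groups are equally optional on load_checkpoint, and internally the dp-cp group is injected into the (possibly None) checkpoint metadata before the sharded state dicts are generated (ddp_model, optimizer and opt_param_scheduler are assumed to come from the training setup):

from megatron.core import mpu
from megatron.training.checkpointing import load_checkpoint

load_checkpoint(
    ddp_model, optimizer, opt_param_scheduler,
    dp_cp_group=mpu.get_data_parallel_group(with_context_parallel=True),
)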
+ import json import logging import math diff --git a/tests/functional_tests/python_test_utils/test_pretraining_resume_checkpoint_pipeline.py b/tests/functional_tests/python_test_utils/test_pretraining_resume_checkpoint_pipeline.py index 6aeb412a8f..a35a72651d 100644 --- a/tests/functional_tests/python_test_utils/test_pretraining_resume_checkpoint_pipeline.py +++ b/tests/functional_tests/python_test_utils/test_pretraining_resume_checkpoint_pipeline.py @@ -1,3 +1,5 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + import logging from typing import Dict @@ -20,9 +22,7 @@ def test_resume_checkpoint_pipeline( model_config = yaml.safe_load(f) checks_types = ( - model_config["METRICS"] - if "METRICS" in model_config - else ["lm loss", "num-zeros"] + model_config["METRICS"] if "METRICS" in model_config else ["lm loss", "num-zeros"] ) checks = { metric: test_pretraining_regular_pipeline.CHECK_THRESHOLDS[metric] diff --git a/tests/unit_tests/dist_checkpointing/models/common.py b/tests/unit_tests/dist_checkpointing/models/common.py index 31b5d9db3c..8cb1dc4df6 100644 --- a/tests/unit_tests/dist_checkpointing/models/common.py +++ b/tests/unit_tests/dist_checkpointing/models/common.py @@ -91,7 +91,8 @@ def common_test_parallel_reconfiguration_e2e( save(gpt_model_A.sharded_state_dict(metadata=metadata), ckpt_dir_A, save_strategy) regular_state_dict_A = gpt_model_A.state_dict() Utils.destroy_model_parallel() - + if metadata is not None: + metadata.pop("dp_cp_group") # Load checkpoint A with different TP/PP and save as checkpoint B # No FPS this time, only FPL Utils.initialize_model_parallel(*dest_tp_pp, **(dst_tp_pp_kwargs or {}), order=store_order) diff --git a/tests/unit_tests/dist_checkpointing/models/test_mamba.py b/tests/unit_tests/dist_checkpointing/models/test_mamba.py index ff2c630997..85fbe5dd04 100644 --- a/tests/unit_tests/dist_checkpointing/models/test_mamba.py +++ b/tests/unit_tests/dist_checkpointing/models/test_mamba.py @@ -130,6 +130,8 @@ def test_parallel_reconfiguration_e2e( ) save(sharded_state_dict, ckpt_dir_A, save_strategy) Utils.destroy_model_parallel() + if metadata is not None: + metadata.pop("dp_cp_group") # Load checkpoint A with different TP/PP/expert/CP and save as checkpoint B # No FPS this time, only FPL diff --git a/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py b/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py index 18cfbf67ce..0970e2adc8 100644 --- a/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py +++ b/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py @@ -71,6 +71,9 @@ def test_parallel_reconfiguration_e2e( save(mlp_A.sharded_state_dict(prefix=layer_prefix, metadata=metadata), ckpt_dir_A) Utils.destroy_model_parallel() + if "dp_cp_group" in metadata.keys(): + del metadata["dp_cp_group"] + # Load checkpoint A with different TP/PP and save as checkpoint B Utils.initialize_model_parallel(*dest_tp_pp) mlp_B = initialize_mlp() diff --git a/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py b/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py index b116d2cb60..ca546d746a 100644 --- a/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py +++ b/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py @@ -190,6 +190,9 @@ def test_parallel_reconfiguration_e2e( save(sharded_state_dict, ckpt_dir_A, save_strategy) Utils.destroy_model_parallel() + if "dp_cp_group" in metadata.keys(): + del metadata["dp_cp_group"] + # Load checkpoint A with different TP/PP/EP and save as checkpoint B 
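The dist-checkpointing test changes in this area follow one pattern, sketched below: after Utils.destroy_model_parallel() the stored dp_cp_group handle points at destroyed process groups, so it is dropped from the reused metadata dict and ensure_metadata_has_dp_cp_group re-injects a fresh group once the new parallel state is initialized (a condensed illustration; the individual tests use pop() or del as shown in the hunks):

# Between the save and load phases of a reconfiguration test:
Utils.destroy_model_parallel()
if metadata is not None:
    metadata.pop("dp_cp_group", None)   # stale handle from the previous parallel state
Utils.initialize_model_parallel(*dest_tp_pp)
# Subsequent sharded_state_dict(metadata=metadata) calls receive a fresh dp-cp
# group via ensure_metadata_has_dp_cp_group().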
# No FPS this time, only FPL Utils.initialize_model_parallel( @@ -276,6 +279,9 @@ def test_sequential_grouped_mlp_interchangeable( save(sharded_state_dict, ckpt_dir_A, save_strategy) Utils.destroy_model_parallel() + if "dp_cp_group" in metadata.keys(): + del metadata["dp_cp_group"] + Utils.initialize_model_parallel(dest_tp, dest_pp, expert_model_parallel_size=dest_exp) model_B = initialize_expert_layer(1, use_glu, expert_type=dest_module) load_strategy = None @@ -351,6 +357,9 @@ def test_sequential_grouped_mlp_extra_state( save(sharded_state_dict, ckpt_dir_A, save_strategy) Utils.destroy_model_parallel() + if "dp_cp_group" in metadata.keys(): + del metadata["dp_cp_group"] + Utils.initialize_model_parallel(dest_tp, dest_pp, expert_model_parallel_size=dest_exp) load_strategy = None diff --git a/tests/unit_tests/dist_checkpointing/test_local.py b/tests/unit_tests/dist_checkpointing/test_local.py index 1b8597e1f1..5ce3422c72 100644 --- a/tests/unit_tests/dist_checkpointing/test_local.py +++ b/tests/unit_tests/dist_checkpointing/test_local.py @@ -26,6 +26,7 @@ LocalCheckpointManager, ) +from megatron.core import parallel_state from megatron.core.dist_checkpointing import ShardedTensor from megatron.core.dist_checkpointing.dict_utils import diff from megatron.core.dist_checkpointing.mapping import ShardedBase, ShardedTensorFactory @@ -78,7 +79,11 @@ def test_sharded_tensors(self, tp, pp, use_torch_fsdp2): opt_param_scheduler = None rng_state = None iteration = None - optim_sd_kwargs = dict(sharding_type='fully_sharded_model_space') + metadata = dict( + dp_cp_group=parallel_state.get_data_parallel_group(with_context_parallel=True) + ) + model_sd_kwargs = dict(metadata=metadata) + optim_sd_kwargs = dict(sharding_type='fully_sharded_model_space', metadata=metadata) mock_args = parse_args(ignore_unknown_args=True) mock_args.no_save_optim = False mock_args.no_save_rng = True @@ -91,6 +96,7 @@ def test_sharded_tensors(self, tp, pp, use_torch_fsdp2): opt_param_scheduler, rng_state, iteration=iteration, + model_sd_kwargs=model_sd_kwargs, optim_sd_kwargs=optim_sd_kwargs, ) sharded_tensor_factories = find_matching_values( diff --git a/tests/unit_tests/post_training/test_modelopt_module_spec.py b/tests/unit_tests/post_training/test_modelopt_module_spec.py index f27a22390f..ec80fcb1a7 100644 --- a/tests/unit_tests/post_training/test_modelopt_module_spec.py +++ b/tests/unit_tests/post_training/test_modelopt_module_spec.py @@ -6,7 +6,7 @@ import torch from packaging.version import Version -from megatron.core import dist_checkpointing +from megatron.core import dist_checkpointing, parallel_state from megatron.core.inference.contexts import StaticInferenceContext from megatron.core.models.gpt.gpt_layer_specs import ( get_gpt_decoder_block_spec, @@ -92,8 +92,11 @@ def setup_method(self, method): def test_sharded_state_dict_restore(self, tmp_path_dist_ckpt): """Save with the default TE spec and restore using the ModelOpt spec.""" _dist_checkpoint_name = "default_model" - te_fused_sharded_state_dict = self.default_model.sharded_state_dict() - modelopt_sharded_state_dict = self.modelopt_model.sharded_state_dict() + metadata = { + "dp_cp_group": parallel_state.get_data_parallel_group(with_context_parallel=True) + } + te_fused_sharded_state_dict = self.default_model.sharded_state_dict(metadata=metadata) + modelopt_sharded_state_dict = self.modelopt_model.sharded_state_dict(metadata=metadata) with TempNamedDir(tmp_path_dist_ckpt / _dist_checkpoint_name, sync=True) as tmpdirname: 
dist_checkpointing.save(te_fused_sharded_state_dict, tmpdirname)
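Overall, the unit-test updates converge on the same calling pattern, sketched here for completeness (model stands for any module exposing sharded_state_dict; checkpoint_dir is an assumed path):

from megatron.core import dist_checkpointing, parallel_state

metadata = {
    "dp_cp_group": parallel_state.get_data_parallel_group(with_context_parallel=True)
}
sharded_sd = model.sharded_state_dict(metadata=metadata)
dist_checkpointing.save(sharded_sd, checkpoint_dir)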