Changes from 89 commits
Commits (111)
fd8bdc1
update get_rng_state to take in comm pgs
yaoyu-33 Jul 28, 2025
a7477e3
update save and load checkpoint
yaoyu-33 Jul 28, 2025
1f4a19b
update checkpointing utils for m4
yaoyu-33 Jul 28, 2025
bbff062
update `make_sharded_tensor_for_checkpoint` and `make_sharded_tensor_…
yaoyu-33 Jul 28, 2025
8fba285
update fallback
yaoyu-33 Jul 28, 2025
9bad6eb
update `make_sharded_tensor_for_checkpoint` function calls
yaoyu-33 Jul 28, 2025
4a0f021
update `make_tp_sharded_tensor_for_checkpoint` for m4
yaoyu-33 Jul 28, 2025
794c0fc
chore: Format files
Jul 28, 2025
568da27
add dp_cp_group to metadata of ckpt
yaoyu-33 Jul 28, 2025
029a929
lint
yaoyu-33 Jul 28, 2025
790e56d
Merge remote-tracking branch 'origin/yuya/m4_dist_ckpt' into yuya/m4_…
yaoyu-33 Jul 28, 2025
9c45be8
lint
yaoyu-33 Jul 28, 2025
1ecf614
chore: Format files
Jul 28, 2025
bee469a
bug fix
yaoyu-33 Jul 28, 2025
b67586c
Merge remote-tracking branch 'origin/yuya/m4_dist_ckpt' into yuya/m4_…
yaoyu-33 Jul 28, 2025
3a9c8e6
chore: Format files
Jul 28, 2025
81d203d
fallback fix
yaoyu-33 Jul 29, 2025
f893946
chore: Format files
Jul 29, 2025
85c59f4
bug fix
yaoyu-33 Jul 29, 2025
7b59837
bug fix
yaoyu-33 Jul 30, 2025
7f9cce8
Merge branch 'refs/heads/main' into yuya/m4_dist_ckpt
yaoyu-33 Jul 31, 2025
ce6ff71
metadata guard
yaoyu-33 Aug 3, 2025
db5409e
reformats
yaoyu-33 Aug 3, 2025
d4dbe48
chore: Format files
Aug 3, 2025
6508c72
try remove expert_dist_ckpt_decorator
yaoyu-33 Aug 3, 2025
aefc4b6
lint
yaoyu-33 Aug 3, 2025
87c2b13
Merge remote-tracking branch 'origin/yuya/m4_dist_ckpt' into yuya/m4_…
yaoyu-33 Aug 3, 2025
b197a28
condition fix
yaoyu-33 Aug 4, 2025
3bc6dea
Merge branch 'main' into yuya/m4_dist_ckpt
yaoyu-33 Aug 5, 2025
7b1b555
lint
yaoyu-33 Aug 6, 2025
9dcff5e
Merge branch 'main' into yuya/m4_dist_ckpt
yaoyu-33 Aug 7, 2025
86f4292
temp update test
yaoyu-33 Aug 7, 2025
4d9a89e
try bug fix
yaoyu-33 Aug 7, 2025
02248d9
revert and some fix
yaoyu-33 Aug 7, 2025
7dfb88a
bug fix
yaoyu-33 Aug 7, 2025
787c800
lint
yaoyu-33 Aug 7, 2025
c4157b3
Revert "temp update test"
yaoyu-33 Aug 7, 2025
4f100de
bug fix for missing default values
yaoyu-33 Aug 7, 2025
5b42ca3
remove not used
yaoyu-33 Aug 7, 2025
4395f4d
Merge branch 'main' into yuya/m4_dist_ckpt
yaoyu-33 Aug 13, 2025
9bcf4cb
chore: Format files
Aug 13, 2025
ecbbbc7
bug fix for `sharded_state_dict_default`
yaoyu-33 Aug 13, 2025
d0cc504
Merge remote-tracking branch 'origin/yuya/m4_dist_ckpt' into yuya/m4_…
yaoyu-33 Aug 13, 2025
0a8446d
Merge branch 'refs/heads/main' into yuya/m4_dist_ckpt
yaoyu-33 Sep 25, 2025
87e231e
chore: Format files
Sep 25, 2025
bdeb0ed
main branch change
yaoyu-33 Sep 25, 2025
c9ff988
Merge branch 'refs/heads/main' into yuya/m4_dist_ckpt
yaoyu-33 Sep 25, 2025
9c75cb5
interface fixes for checkpoint with m4
yaoyu-33 Sep 25, 2025
c9407dc
chore: Format files
Sep 25, 2025
48be193
revert transformer_engine.py changes
yaoyu-33 Sep 25, 2025
3c30a13
bug fixes
yaoyu-33 Sep 25, 2025
5a671e9
lint
yaoyu-33 Sep 26, 2025
6360922
update to use logger to pass lint
yaoyu-33 Sep 29, 2025
39dcac3
chore: Format files
Sep 29, 2025
1b3cf6e
Guard for cases metadata is not provided
yaoyu-33 Sep 30, 2025
d2258c0
chore: Format files
Sep 30, 2025
ad61ac8
Merge branch 'main' into yuya/m4_dist_ckpt
yaoyu-33 Sep 30, 2025
1da7644
update to ensure_metadata_has_dp_cp_group util
yaoyu-33 Sep 30, 2025
321a68b
Add debug assert
yaoyu-33 Sep 30, 2025
6021b37
chore: Format files
Sep 30, 2025
aed7622
lint
yaoyu-33 Sep 30, 2025
dfc7978
Merge remote-tracking branch 'origin/yuya/m4_dist_ckpt' into yuya/m4_…
yaoyu-33 Sep 30, 2025
e71a0b0
Merge branch 'main' into yuya/m4_dist_ckpt
dimapihtar Oct 30, 2025
10a137f
minor fixes
dimapihtar Oct 30, 2025
386f840
fix te module issue
dimapihtar Nov 3, 2025
ab71e71
fix te module issue
dimapihtar Nov 3, 2025
7438282
add tp_group param
dimapihtar Nov 3, 2025
d02bbf1
minor fix
dimapihtar Nov 3, 2025
070fa5a
clear metadata
dimapihtar Nov 3, 2025
ab24fa4
fix tp_group
dimapihtar Nov 5, 2025
a961335
remove tp_group extra usage
dimapihtar Nov 5, 2025
406dff3
add tp_group to state_dict
dimapihtar Nov 5, 2025
ca4d1bf
add tp_group to state_dict
dimapihtar Nov 5, 2025
d8845f8
add tp_group to state_dict
dimapihtar Nov 5, 2025
df58591
add tp_group to state_dict
dimapihtar Nov 5, 2025
778b01d
fix tp_group
dimapihtar Nov 5, 2025
465d803
fix tp_group usage
dimapihtar Nov 5, 2025
bcf420f
fix typo
dimapihtar Nov 5, 2025
8cd778b
add tp_group param to state_dict
dimapihtar Nov 5, 2025
f19cb10
add tp_group param to state_dict
dimapihtar Nov 5, 2025
bd3f9e4
fix tp_group usage
dimapihtar Nov 5, 2025
787ec6a
fix metadata
dimapihtar Nov 5, 2025
6e9af69
fix unit tests
dimapihtar Nov 5, 2025
146d3c1
fix style
dimapihtar Nov 5, 2025
3a203d8
fix style
dimapihtar Nov 5, 2025
8bf8220
remove unused imports
dimapihtar Nov 5, 2025
ba4e32f
add headers
dimapihtar Nov 5, 2025
b33362d
Merge branch 'main' into yuya/m4_dist_ckpt
dimapihtar Nov 5, 2025
1311d81
fix style
dimapihtar Nov 5, 2025
f4e828a
pass tp_group param
dimapihtar Nov 6, 2025
f5bfcee
fix tp_group usage
dimapihtar Nov 6, 2025
5e375a2
fix tp_group usage
dimapihtar Nov 6, 2025
12cc898
remove debug assertions
dimapihtar Nov 6, 2025
9cbb410
fix style
dimapihtar Nov 6, 2025
a599cd1
pass tp_param to state_dict
dimapihtar Nov 6, 2025
4e11170
add tp_group to state_dict
dimapihtar Nov 6, 2025
088702a
pass tp_group param to sharded_state_dict
dimapihtar Nov 6, 2025
bcc30d7
pass tp_group to sharded_state_dict
dimapihtar Nov 6, 2025
d10b644
pass tp_group to sharded_state_dict
dimapihtar Nov 6, 2025
17cf6c9
fix training/checkpointing
yaoyu-33 Nov 6, 2025
e8a1f84
Merge remote-tracking branch 'dim/yuya/m4_dist_ckpt' into yuya/m4_dis…
yaoyu-33 Nov 6, 2025
93e6099
pass metadata arg
dimapihtar Nov 6, 2025
34c02b0
remove tp_group arg from sharded_state_dict
dimapihtar Nov 6, 2025
2900114
set self.tp_group
dimapihtar Nov 6, 2025
59b9485
pass tp_group & pp_group to get_rng_state
dimapihtar Nov 7, 2025
415b371
minor fix
dimapihtar Nov 7, 2025
10af8c0
remove extra tp_group arg
dimapihtar Nov 7, 2025
e612b04
remove extra tp_group arg
dimapihtar Nov 7, 2025
3456575
fix code style
dimapihtar Nov 7, 2025
0fc5fa0
fix code style
dimapihtar Nov 7, 2025
d59a993
set te_group properly
dimapihtar Nov 7, 2025
65 changes: 50 additions & 15 deletions megatron/core/extensions/transformer_engine.py
@@ -42,6 +42,7 @@
from megatron.core.transformer.mlp import MLP
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import (
ensure_metadata_has_dp_cp_group,
is_layer_window_attention,
make_sharded_tensors_for_checkpoint,
)
@@ -434,7 +435,7 @@ def forward(self, x):
return out
return out, None

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None, tp_group=None):
"""Replicate cross TP/DP."""

# Provide the dist-ckpt support when TELinear is directly used
@@ -443,7 +444,14 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
self.parallel_mode is None
), "TELinear sharded_state_dict can only be used with duplicated parallel mode"
state_dict = self.state_dict(prefix="", keep_vars=True)
return make_sharded_tensors_for_checkpoint(state_dict, prefix, None, sharded_offsets)
return make_sharded_tensors_for_checkpoint(
state_dict,
prefix,
None,
sharded_offsets,
tp_group=tp_group,
dp_cp_group=metadata["dp_cp_group"],
)

def backward_dw(self):
"""Compute weight gradients during the backward pass if delay_wgrad_compute is enabled."""
@@ -622,11 +630,17 @@ def forward(self, x):
return out
return out, None

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None, tp_group=None):
"""Sharding along axis 0, bias sharded"""
metadata = ensure_metadata_has_dp_cp_group(metadata)
state_dict = self.state_dict(prefix="", keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {"weight": 0, "bias": 0}, sharded_offsets
state_dict,
prefix,
{"weight": 0, "bias": 0},
sharded_offsets,
tp_group=tp_group,
dp_cp_group=metadata["dp_cp_group"],
)

def __repr__(self):
@@ -715,11 +729,16 @@ def __init__(
self.bias.zero_()
setattr(self.bias, "allreduce", True)

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None, tp_group=None):
"""Sharding along axis 0, bias sharded"""
state_dict = self.state_dict(prefix="", keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {"weight": 0, "bias": 0}, sharded_offsets
state_dict,
prefix,
{"weight": 0, "bias": 0},
sharded_offsets,
tp_group=tp_group,
dp_cp_group=metadata["dp_cp_group"],
)

def __repr__(self):
@@ -809,11 +828,16 @@ def __init__(
setattr(self.bias, "allreduce", True)
setattr(self.bias, "sequence_parallel", config.sequence_parallel)

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None, tp_group=None):
"""Sharding along axis 1, bias not sharded"""
state_dict = self.state_dict(prefix="", keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {"weight": 1}, sharded_offsets
state_dict,
prefix,
{"weight": 1},
sharded_offsets,
tp_group=tp_group,
dp_cp_group=metadata["dp_cp_group"],
)

def __repr__(self):
@@ -1070,14 +1094,20 @@ def sharded_state_dict(
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
) -> ShardedStateDict:
"""Sharded state dict for the learnable softmax offset parameter"""
if self.config.softmax_type == "learnable":
state_dict = self.state_dict(prefix="", keep_vars=True)
else:
state_dict = {}
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'softmax_offset': 0}, sharded_offsets
state_dict,
prefix,
{'softmax_offset': 0},
sharded_offsets,
tp_group=tp_group,
dp_cp_group=metadata["dp_cp_group"],
)


@@ -1341,7 +1371,7 @@ def _split_extra_state(self, state):
return extra_states

def _sharded_state_dict_grouped(
self, tp_axis_map, prefix="", sharded_offsets=(), metadata=None
self, tp_axis_map, prefix="", sharded_offsets=(), metadata=None, tp_group=None
):
"""
prefix should be module_name to make keys identical to sequential ones.
@@ -1371,7 +1401,12 @@ def _sharded_state_dict_grouped(
(ep_axis, global_expert_idx, num_global_experts),
)
sub_sd = make_sharded_tensors_for_checkpoint(
state_dict, '', tp_axis_map, new_sharded_offsets
state_dict,
'',
tp_axis_map,
new_sharded_offsets,
tp_group=tp_group,
dp_cp_group=metadata["dp_cp_group"],
)
# Remove expert layers indexing from sharded keys
replace_prefix_for_sharding(sub_sd, f"{gemm_idx}.", expert_prefix)
@@ -1440,7 +1475,7 @@ def __init__(
tp_group=tp_group,
)

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None, tp_group=None):
"""
For each gemm, sharding along axis 0, bias sharded.
Assume sharded_offsets[-1] is the expert parallel offset.
@@ -1449,7 +1484,7 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
for gemm_idx in range(self.num_gemms):
tp_axis_map.update({f"{gemm_idx}.weight": 0, f"{gemm_idx}.bias": 0})
return super()._sharded_state_dict_grouped(
tp_axis_map, prefix, sharded_offsets, metadata
tp_axis_map, prefix, sharded_offsets, metadata, tp_group=tp_group
)

class TERowParallelGroupedLinear(TEGroupedLinear):
@@ -1486,14 +1521,14 @@ def __init__(
tp_group=tp_group,
)

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None, tp_group=None):
"""
For each gemm, sharding along axis 1, bias not sharded.
Assume sharded_offsets[-1] is the expert parallel offset.
"""
tp_axis_map = {f"{gemm_idx}.weight": 1 for gemm_idx in range(self.num_gemms)}
return super()._sharded_state_dict_grouped(
tp_axis_map, prefix, sharded_offsets, metadata
tp_axis_map, prefix, sharded_offsets, metadata, tp_group=tp_group
)

else:
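Every `sharded_state_dict` override in this file now follows the same pattern: take an explicit `tp_group` argument and read the data/context-parallel group out of `metadata["dp_cp_group"]` instead of relying on global mpu state. A minimal caller-side sketch of that contract (hypothetical; `layer`, `tp_group`, and `dp_cp_group` are assumed to come from the surrounding training setup, e.g. a `ProcessGroupCollection`):

```python
# Hypothetical caller sketch, not part of this PR: shows how the new keyword
# arguments are threaded into a TE layer's sharded_state_dict.
def build_layer_sharded_state_dict(layer, tp_group, dp_cp_group, prefix="decoder.layers.0."):
    # dp_cp_group travels inside the metadata dict; tp_group is a keyword argument.
    metadata = {"dp_cp_group": dp_cp_group}
    return layer.sharded_state_dict(
        prefix=prefix,
        sharded_offsets=(),
        metadata=metadata,
        tp_group=tp_group,
    )
```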
4 changes: 3 additions & 1 deletion megatron/core/models/bert/bert_model.py
@@ -14,6 +14,7 @@
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.dot_product_attention import (
DotProductAttention as MCoreDotProductAttention,
)
@@ -73,9 +74,10 @@ def __init__(
seq_len_interpolation_factor: Optional[float] = None,
add_binary_head=True,
return_embeddings=False,
pg_collection: Optional[ProcessGroupCollection] = None,
vp_stage: Optional[int] = None,
):
super(BertModel, self).__init__(config=config)
super(BertModel, self).__init__(config=config, pg_collection=pg_collection)

if has_config_logger_enabled(config):
log_config_to_disk(config, locals(), prefix=type(self).__name__)
14 changes: 13 additions & 1 deletion megatron/core/models/common/language_module/language_module.py
@@ -24,7 +24,12 @@
from megatron.core.transformer.enums import AttnBackend
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import is_te_min_version, make_tp_sharded_tensor_for_checkpoint
from megatron.core.transformer.utils import ensure_metadata_has_dp_cp_group
from megatron.core.utils import (
get_tensor_model_parallel_group_if_none,
is_te_min_version,
make_tp_sharded_tensor_for_checkpoint,
)


class LanguageModule(MegatronModule):
@@ -44,6 +49,7 @@ def __init__(
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
self.pg_collection = pg_collection
self.cp_group = pg_collection.cp
self.tp_group = get_tensor_model_parallel_group_if_none(pg_collection.tp)
self.pp_group = pg_collection.pp
assert hasattr(self.pg_collection, 'embd'), (
"pg_collection must have a embd. In previous version, it used default "
@@ -272,6 +278,10 @@ def sharded_state_dict(
ShardedStateDict: sharded state dict for the LanguageModel
"""
assert not sharded_offsets, "Unexpected sharded offsets"

# Guard for cases where metadata is not provided
metadata = ensure_metadata_has_dp_cp_group(metadata)

sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)

first_stage_word_emb_key = f'{prefix}embedding.word_embeddings.weight'
@@ -341,4 +351,6 @@ def tie_embeddings_and_output_weights_state_dict(
key=first_stage_word_emb_key,
replica_id=last_stage_word_emb_replica_id,
allow_shape_mismatch=True,
tp_group=self.tp_group,
dp_cp_group=metadata['dp_cp_group'],
)
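The `ensure_metadata_has_dp_cp_group` guard is what keeps older call sites working when they pass `metadata=None` or a metadata dict without the new key. Its body is not part of this diff; a rough sketch of the behavior it needs, assuming it falls back to the standard `parallel_state` data-parallel group with context parallelism folded in (the real helper in `megatron/core/transformer/utils.py` may differ):

```python
# Sketch only; the actual implementation lives in megatron/core/transformer/utils.py.
from typing import Optional

from megatron.core import parallel_state


def ensure_metadata_has_dp_cp_group(metadata: Optional[dict]) -> dict:
    metadata = dict(metadata) if metadata is not None else {}
    if metadata.get("dp_cp_group") is None:
        # Assumed fallback: the default mpu data-parallel group including CP ranks.
        metadata["dp_cp_group"] = parallel_state.get_data_parallel_group(
            with_context_parallel=True
        )
    return metadata
```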
14 changes: 12 additions & 2 deletions megatron/core/models/gpt/gpt_model.py
@@ -752,7 +752,13 @@ def sharded_state_dict(
if self.mtp_process and not self.pre_process:
emb_weight_key = f'{prefix}embedding.word_embeddings.weight'
emb_weight = self.embedding.word_embeddings.weight
tie_word_embeddings_state_dict(sharded_state_dict, emb_weight, emb_weight_key)
tie_word_embeddings_state_dict(
sharded_state_dict,
emb_weight,
emb_weight_key,
tp_group=self.tp_group,
dp_cp_group=metadata['dp_cp_group'],
)
if self.mtp_process and not self.post_process:
# We only need to tie the output layer weight if share_embeddings_and_output_weights
# is False. Because if share_embeddings_and_output_weights is True, the shared weight
@@ -761,7 +767,11 @@
output_layer_weight_key = f'{prefix}output_layer.weight'
output_layer_weight = self.output_layer.weight
tie_output_layer_state_dict(
sharded_state_dict, output_layer_weight, output_layer_weight_key
sharded_state_dict,
output_layer_weight,
output_layer_weight_key,
tp_group=self.tp_group,
dp_cp_group=metadata['dp_cp_group'],
)

return sharded_state_dict
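Taken together, the caller-visible contract is that `metadata` carries the data/context-parallel group all the way down to the leaf modules. A hypothetical end-to-end save sketch (assumes `model` is a constructed `GPTModel`, the default mpu process groups are initialized, and the checkpoint path is illustrative):

```python
# Hypothetical usage sketch, not part of this PR.
from megatron.core import dist_checkpointing, parallel_state

# dp_cp_group must be present in metadata (or injected by
# ensure_metadata_has_dp_cp_group) before the sharded state dict is built.
metadata = {
    "dp_cp_group": parallel_state.get_data_parallel_group(with_context_parallel=True)
}
sharded_sd = model.sharded_state_dict(prefix="", metadata=metadata)
dist_checkpointing.save(sharded_sd, "/tmp/gpt_checkpoint")
```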
@@ -52,7 +52,7 @@ def __init__(
pin_cpu_grads: bool = True,
pin_cpu_params: bool = True,
overlap_cpu_optimizer_d2h_h2d: bool = True,
**kwargs
**kwargs,
):
super(HybridDeviceOptimizer, self).__init__(
params,
12 changes: 10 additions & 2 deletions megatron/core/post_training/modelopt/layers.py
@@ -1,5 +1,6 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

import logging
from typing import Callable, List, Optional

import torch
@@ -10,6 +11,8 @@
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint

logger = logging.getLogger(__name__)

try:
import transformer_engine as te

@@ -116,6 +119,7 @@ def __init__(
tp_group: Optional[torch.distributed.ProcessGroup] = None,
):
self.config = config
self.tp_group = tp_group

self._return_bias = skip_bias_add and bias

@@ -153,7 +157,11 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
if v.ndim == 0:
state_dict[k] = v.view(1)
sharded_state_dict = make_sharded_tensors_for_checkpoint(
state_dict, prefix, sharded_offsets=sharded_offsets
state_dict,
prefix,
sharded_offsets=sharded_offsets,
tp_group=self.tp_group,
dp_cp_group=metadata['dp_cp_group'],
)
return sharded_state_dict

@@ -229,7 +237,7 @@ def _report_quantize_tensor_info(self):
if not isinstance(v, torch.Tensor):
continue
original_dtype, original_shape = self._original_tensor_info.get(k, ("-", "-"))
print(
logger.info(
"{:<64} {:<16} {:<32} {:<16} {:<32}".format(
k, original_dtype, original_shape, str(v.dtype), str(v.shape)
)
7 changes: 6 additions & 1 deletion megatron/core/ssm/mamba_block.py
@@ -139,6 +139,7 @@ def __init__(
assert pg_collection is not None, "pg_collection must be provided for MambaStack"

self.pp_group = pg_collection.pp
self.tp_group = pg_collection.tp

# Required for pipeline parallel schedules
self.input_tensor = None
@@ -416,7 +417,11 @@ def sharded_state_dict(
if not module is self.layers:
sharded_state_dict.update(
sharded_state_dict_default(
module, f'{prefix}{name}.', sharded_offsets, metadata
module,
f'{prefix}{name}.',
sharded_offsets,
metadata,
tp_group=self.tp_group,
)
)

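MambaStack now forwards its `tp_group` into `sharded_state_dict_default` for submodules that do not shard themselves. The helper's updated body is not shown here; a rough sketch of the expected fallback, assuming it mirrors the existing helper in `megatron/core/transformer/utils.py` (the real implementation, including whether `tp_group` is also forwarded to modules that define `sharded_state_dict`, may differ):

```python
# Sketch only; the actual helper lives in megatron/core/transformer/utils.py.
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint


def sharded_state_dict_default(
    module, prefix="", sharded_offsets=(), metadata=None, tp_group=None
):
    # Modules with their own sharded_state_dict handle sharding themselves.
    if hasattr(module, "sharded_state_dict"):
        return module.sharded_state_dict(
            prefix=prefix, sharded_offsets=sharded_offsets, metadata=metadata
        )
    # Otherwise treat every tensor as replicated and attach the explicit
    # process groups introduced by this PR.
    state_dict = module.state_dict(prefix="", keep_vars=True)
    return make_sharded_tensors_for_checkpoint(
        state_dict,
        prefix,
        None,
        sharded_offsets,
        tp_group=tp_group,
        dp_cp_group=(metadata or {}).get("dp_cp_group"),
    )
```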