@@ -42,6 +42,7 @@
 from megatron.core.transformer.mlp import MLP
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.utils import (
+    ensure_metadata_has_dp_cp_group,
     is_layer_window_attention,
     make_sharded_tensors_for_checkpoint,
 )
@@ -420,6 +421,9 @@ def __init__(
                 # duplicated across TP ranks
                 setattr(param, "sequence_parallel", self.config.sequence_parallel)
 
+        tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert)
+        self._tp_group = tp_group
+
     def forward(self, x):
         """Forward."""
         _is_first_microbatch = (
@@ -444,7 +448,14 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
             self.parallel_mode is None
         ), "TELinear sharded_state_dict can only be used with duplicated parallel mode"
         state_dict = self.state_dict(prefix="", keep_vars=True)
-        return make_sharded_tensors_for_checkpoint(state_dict, prefix, None, sharded_offsets)
+        return make_sharded_tensors_for_checkpoint(
+            state_dict,
+            prefix,
+            None,
+            sharded_offsets,
+            tp_group=self._tp_group,
+            dp_cp_group=metadata["dp_cp_group"],
+        )
 
     def backward_dw(self):
         """Compute weight gradients during the backward pass if delay_wgrad_compute is enabled."""
@@ -492,6 +503,7 @@ def __init__(
 
         # TODO: For backward compatibility, remove in v0.15.
         tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert)
+        self._tp_group = tp_group
 
         # TE returns a zero length Tensor when bias=False and
         # return_bias=True, but we prefer None. So in that case we
@@ -625,9 +637,15 @@ def forward(self, x):
 
     def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         """Sharding along axis 0, bias sharded"""
+        metadata = ensure_metadata_has_dp_cp_group(metadata)
         state_dict = self.state_dict(prefix="", keep_vars=True)
         return make_sharded_tensors_for_checkpoint(
-            state_dict, prefix, {"weight": 0, "bias": 0}, sharded_offsets
+            state_dict,
+            prefix,
+            {"weight": 0, "bias": 0},
+            sharded_offsets,
+            tp_group=self._tp_group,
+            dp_cp_group=metadata["dp_cp_group"],
         )
 
     def __repr__(self):
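Unlike the TELinear path above, this override first normalizes metadata via ensure_metadata_has_dp_cp_group, so a caller that passes metadata=None still works. The helper's body is not part of this diff; a plausible sketch only, assuming it falls back to the default DP x CP group when the key is missing:

    from megatron.core import parallel_state

    def ensure_metadata_has_dp_cp_group(metadata):
        # Sketch of assumed behavior: guarantee metadata["dp_cp_group"] exists so
        # the sharded_state_dict call can read it unconditionally.
        metadata = {} if metadata is None else dict(metadata)
        if "dp_cp_group" not in metadata:
            metadata["dp_cp_group"] = parallel_state.get_data_parallel_group(
                with_context_parallel=True
            )
        return metadata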
@@ -670,6 +688,7 @@ def __init__(
         if gather_output:
             raise ValueError("Transformer Engine linear layers do not support gather_output = True")
         tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert)
+        self._tp_group = tp_group
         world_size = get_pg_size(tp_group)
         rank = get_pg_rank(tp_group)
 
@@ -720,7 +739,12 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         """Sharding along axis 0, bias sharded"""
         state_dict = self.state_dict(prefix="", keep_vars=True)
         return make_sharded_tensors_for_checkpoint(
-            state_dict, prefix, {"weight": 0, "bias": 0}, sharded_offsets
+            state_dict,
+            prefix,
+            {"weight": 0, "bias": 0},
+            sharded_offsets,
+            tp_group=self._tp_group,
+            dp_cp_group=metadata["dp_cp_group"],
         )
 
     def __repr__(self):
@@ -764,6 +788,7 @@ def __init__(
                 "Transformer Engine linear layers do not support input_is_parallel = False"
             )
         tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert)
+        self._tp_group = tp_group
 
         super().__init__(
             input_size=input_size,
@@ -814,7 +839,12 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         """Sharding along axis 1, bias not sharded"""
         state_dict = self.state_dict(prefix="", keep_vars=True)
         return make_sharded_tensors_for_checkpoint(
-            state_dict, prefix, {"weight": 1}, sharded_offsets
+            state_dict,
+            prefix,
+            {"weight": 1},
+            sharded_offsets,
+            tp_group=self._tp_group,
+            dp_cp_group=metadata["dp_cp_group"],
         )
 
     def __repr__(self):
@@ -901,6 +931,7 @@ def __init__(
             assert hasattr(
                 pg_collection, "hcp"
             ), "TEDotProductAttention pg_collection must have hierarchical cp pg"
+        self._tp_group = pg_collection.tp
 
         if is_te_min_version("0.10.0"):
             extra_kwargs["attention_type"] = attention_type
@@ -1078,7 +1109,12 @@ def sharded_state_dict(
         else:
             state_dict = {}
         return make_sharded_tensors_for_checkpoint(
-            state_dict, prefix, {'softmax_offset': 0}, sharded_offsets
+            state_dict,
+            prefix,
+            {'softmax_offset': 0},
+            sharded_offsets,
+            tp_group=self._tp_group,
+            dp_cp_group=metadata["dp_cp_group"],
         )
 
 
@@ -1138,6 +1174,7 @@ def __init__(
         # The comms between TP and EP group is explicitly handled by MoE token dispatcher.
         # So we disable comms by making TE agnostic of model parallel.
         tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert)
+        self._tp_group = tp_group
         tp_size = get_pg_size(tp_group)
 
         self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
@@ -1372,7 +1409,12 @@ def _sharded_state_dict_grouped(
                 (ep_axis, global_expert_idx, num_global_experts),
             )
             sub_sd = make_sharded_tensors_for_checkpoint(
-                state_dict, '', tp_axis_map, new_sharded_offsets
+                state_dict,
+                '',
+                tp_axis_map,
+                new_sharded_offsets,
+                tp_group=self._tp_group,
+                dp_cp_group=metadata["dp_cp_group"],
             )
             # Remove expert layers indexing from sharded keys
             replace_prefix_for_sharding(sub_sd, f"{gemm_idx}.", expert_prefix)
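Every hunk above repeats one pattern: resolve the module's TP group once in __init__, cache it on self._tp_group, and forward it together with metadata["dp_cp_group"] when building sharded tensors, presumably so sharding and replication are derived from explicit process groups rather than global parallel state. A condensed sketch of that pattern; the module, prefix, and axis map are illustrative:

    # Axis map marks which dim of each param is TP-sharded; the explicit groups
    # replace reliance on the globally tracked parallel state.
    state_dict = module.state_dict(prefix="", keep_vars=True)
    sharded = make_sharded_tensors_for_checkpoint(
        state_dict,
        "decoder.layers.0.mlp.linear_fc1.",  # illustrative prefix
        {"weight": 0, "bias": 0},            # shard weight/bias along dim 0 across TP ranks
        (),                                  # no extra sharded offsets in this example
        tp_group=module._tp_group,
        dp_cp_group=metadata["dp_cp_group"],
    )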