 from megatron.core.transformer.transformer_layer import TransformerLayer
 
 
-def get_ep_layer_offset():
+def get_ep_layer_offset(num_experts: int | None = None) -> int:
     """
     Get the expert layer offset for the current model.
-    """
-    from megatron.training.global_vars import get_args
 
-    args = get_args()
+    Args:
+        num_experts: Total number of experts in the model. If None, returns 0.
+
+    Returns:
+        The expert layer offset for the current EP rank.
+    """
     ep_size = parallel_state.get_expert_model_parallel_world_size()
     ep_rank = parallel_state.get_expert_model_parallel_rank()
-    num_local_experts = args.num_experts // ep_size if args.num_experts else 0
+    num_local_experts = num_experts // ep_size if num_experts else 0
     local_expert_offset = ep_rank * num_local_experts
 
     return local_expert_offset
 
 
-def get_total_num_experts():
-    """
-    Get the total number of experts for the current model.
-    """
-    from megatron.training.global_vars import get_args
-
-    args = get_args()
-    return args.num_experts if args.num_experts else 0
-
-
 def get_expert_index_from_key(key):
     """Extract expert index from various expert key formats.
 
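For intuition, here is a self-contained sketch of the arithmetic the new `get_ep_layer_offset` signature performs, with the `parallel_state` queries replaced by hard-coded example values; `num_experts=8`, `ep_size=4`, and `ep_rank=2` are illustrative assumptions, not values from this change:

```python
# Illustrative only: mirrors the offset arithmetic in get_ep_layer_offset above,
# with the parallel_state queries replaced by assumed example values.
num_experts = 8  # assumed total number of experts
ep_size = 4      # assumed expert-model-parallel world size
ep_rank = 2      # assumed rank of the current process

num_local_experts = num_experts // ep_size if num_experts else 0  # 2 experts per rank
local_expert_offset = ep_rank * num_local_experts                 # rank 2 owns experts 4 and 5

print(num_local_experts, local_expert_offset)  # 2 4
```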
@@ -96,12 +89,19 @@ def get_expert_index_from_key(key):
     return None
 
 
-def handle_experts_in_state_dict(state_dict):
+def handle_experts_in_state_dict(state_dict, num_experts: int | None = None):
     """
     Rewrite expert keys in state dict.
+
+    Args:
+        state_dict: The state dictionary to process.
+        num_experts: Total number of experts in the model. If None, no expert processing occurs.
+
+    Returns:
+        The processed state dictionary with rewritten expert keys.
     """
-    local_expert_start = get_ep_layer_offset()
-    local_expert_end = get_total_num_experts()
+    local_expert_start = get_ep_layer_offset(num_experts)
+    local_expert_end = num_experts if num_experts else 0
 
     def should_keep_expert_key(expert_index):
         """Determine if this rank should keep this expert key based on expert index"""
@@ -147,9 +147,17 @@ def replace_expert_index_in_key(key, expert_index, state_dict):
     return state_dict
 
 
-def expert_param_local_key(key):
-    """Get the module parameter corresponding to the key."""
-    local_expert_offset = get_ep_layer_offset()
+def expert_param_local_key(key: str, num_experts: int | None = None) -> str:
+    """Get the module parameter corresponding to the key.
+
+    Args:
+        key: The parameter key to process.
+        num_experts: Total number of experts in the model. If None, no expert processing occurs.
+
+    Returns:
+        The local parameter key with adjusted expert indices.
+    """
+    local_expert_offset = get_ep_layer_offset(num_experts)
     expert_index = get_expert_index_from_key(key)
     if expert_index is not None:
         new_expert_index = expert_index - local_expert_offset
@@ -174,6 +182,9 @@ def handle_swiglu_in_state_dict(model, model_state_dict, optimizer_state_dict):
174182 """
175183 assert HAVE_MEGATRON_FSDP , "This function requires Megatron-FSDP to be installed."
176184
185+ # Extract num_experts from model config for expert parameter processing
186+ num_experts = model .config .num_moe_experts if hasattr (model , 'config' ) else None
187+
177188 def intersection (s1 , s2 ):
178189 # Only works for step=1
179190 start = max (s1 .start , s2 .start )
@@ -297,7 +308,9 @@ def split_swiglu_linear_fc1(data, dist_param, swiglu_shard_axis, is_expert_param
             new_opt_state_dict[f"{key}_w"] = opt_state_dict[key].copy()
             new_opt_state_dict[f"{key}_v"] = opt_state_dict[key].copy()
             for subkey in ["exp_avg", "exp_avg_sq"]:
-                dist_param = model.get_parameter(expert_param_local_key(key[len("module.") :]))
+                dist_param = model.get_parameter(
+                    expert_param_local_key(key[len("module.") :], num_experts)
+                )
                 weight_w, weight_v = split_swiglu_linear_fc1(
                     opt_state_dict[key][subkey],
                     dist_param,
@@ -426,6 +439,13 @@ def validate_loaded_state_dict(state_dict, checkpoint_path):
 def get_global_unique_param_name(model_chunks, param):
     """
     Get the global unique parameter name for a given model and parameter.
+
+    Args:
+        model_chunks: List of model chunks to search for the parameter.
+        param: The parameter to find the name for.
+
+    Returns:
+        The global unique parameter name.
     """
     param_name = None
     for model in model_chunks:
@@ -450,6 +470,7 @@ def get_global_unique_param_name(model_chunks, param):
         param_name = re.sub(r"layers\.(\d+)", f"layers.{tf_layer_number - 1}", param_name)
 
     # Get EP unique parameter name
-    param_name = list(handle_experts_in_state_dict({param_name: None}).keys())[0]
+    num_experts = model_chunks[0].config.num_moe_experts if model_chunks else None
+    param_name = next(iter(handle_experts_in_state_dict({param_name: None}, num_experts).keys()))
 
     return param_name
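Taken together, the change threads `num_experts` through the expert-handling helpers explicitly instead of importing `get_args()` from `megatron.training.global_vars`, so the helpers no longer depend on training globals being initialized. A minimal caller-side sketch of the new pattern, assuming it runs in the same module as the helpers above and that `model` and `state_dict` are provided by the surrounding checkpoint code (both are assumptions, not shown in this diff):

```python
# Sketch of the new explicit-argument call pattern; `model` and `state_dict`
# are assumed to exist in the calling checkpoint code.
num_experts = model.config.num_moe_experts if hasattr(model, "config") else None

# First expert index owned by this expert-parallel rank.
local_expert_start = get_ep_layer_offset(num_experts)

# Drop/rename expert keys so the state dict only covers this rank's local experts.
state_dict = handle_experts_in_state_dict(state_dict, num_experts)
```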