diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py
index 8b4740516e..6e093f96f7 100644
--- a/megatron/core/optimizer/distrib_optimizer.py
+++ b/megatron/core/optimizer/distrib_optimizer.py
@@ -1153,7 +1153,8 @@ def _param_name(self, param: torch.nn.Parameter) -> str:
                     "Ensure that each model chunk has unique parameter names."
                 )
                 name_to_param.update(_name_to_param)
-            name_to_param = handle_experts_in_state_dict(name_to_param)
+            num_experts = self.model_chunks[0].config.num_moe_experts if self.model_chunks else None
+            name_to_param = handle_experts_in_state_dict(name_to_param, num_experts)
             self.param_to_name = {param: name for name, param in name_to_param.items()}
         assert (
             param in self.param_to_name
diff --git a/megatron/core/transformer/fsdp_dtensor_checkpoint.py b/megatron/core/transformer/fsdp_dtensor_checkpoint.py
index 9ef3f1f1b8..f7a938aff2 100644
--- a/megatron/core/transformer/fsdp_dtensor_checkpoint.py
+++ b/megatron/core/transformer/fsdp_dtensor_checkpoint.py
@@ -47,31 +47,24 @@
 from megatron.core.transformer.transformer_layer import TransformerLayer
 
 
-def get_ep_layer_offset():
+def get_ep_layer_offset(num_experts: int | None = None) -> int:
     """
     Get the expert layer offset for the current model.
-    """
-    from megatron.training.global_vars import get_args
-    args = get_args()
+    Args:
+        num_experts: Total number of experts in the model. If None, returns 0.
+
+    Returns:
+        The expert layer offset for the current EP rank.
+    """
     ep_size = parallel_state.get_expert_model_parallel_world_size()
     ep_rank = parallel_state.get_expert_model_parallel_rank()
-    num_local_experts = args.num_experts // ep_size if args.num_experts else 0
+    num_local_experts = num_experts // ep_size if num_experts else 0
     local_expert_offset = ep_rank * num_local_experts
     return local_expert_offset
 
 
-def get_total_num_experts():
-    """
-    Get the total number of experts for the current model.
-    """
-    from megatron.training.global_vars import get_args
-
-    args = get_args()
-    return args.num_experts if args.num_experts else 0
-
-
 def get_expert_index_from_key(key):
     """Extract expert index from various expert key formats.
 
@@ -96,12 +89,19 @@
     return None
 
 
-def handle_experts_in_state_dict(state_dict):
+def handle_experts_in_state_dict(state_dict, num_experts: int | None = None):
     """
     Rewrite expert keys in state dict.
+
+    Args:
+        state_dict: The state dictionary to process.
+        num_experts: Total number of experts in the model. If None, no expert processing occurs.
+
+    Returns:
+        The processed state dictionary with rewritten expert keys.
     """
-    local_expert_start = get_ep_layer_offset()
-    local_expert_end = get_total_num_experts()
+    local_expert_start = get_ep_layer_offset(num_experts)
+    local_expert_end = num_experts if num_experts else 0
 
     def should_keep_expert_key(expert_index):
         """Determine if this rank should keep this expert key based on expert index"""
@@ -147,9 +147,17 @@ def replace_expert_index_in_key(key, expert_index, state_dict):
     return state_dict
 
 
-def expert_param_local_key(key):
-    """Get the module parameter corresponding to the key."""
-    local_expert_offset = get_ep_layer_offset()
+def expert_param_local_key(key: str, num_experts: int | None = None) -> str:
+    """Get the module parameter corresponding to the key.
+
+    Args:
+        key: The parameter key to process.
+        num_experts: Total number of experts in the model. If None, no expert processing occurs.
+
+    Returns:
+        The local parameter key with adjusted expert indices.
+    """
+    local_expert_offset = get_ep_layer_offset(num_experts)
     expert_index = get_expert_index_from_key(key)
     if expert_index is not None:
         new_expert_index = expert_index - local_expert_offset
@@ -174,6 +182,9 @@ def handle_swiglu_in_state_dict(model, model_state_dict, optimizer_state_dict):
     """
     assert HAVE_MEGATRON_FSDP, "This function requires Megatron-FSDP to be installed."
 
+    # Extract num_experts from model config for expert parameter processing
+    num_experts = model.config.num_moe_experts if hasattr(model, 'config') else None
+
     def intersection(s1, s2):
         # Only works for step=1
         start = max(s1.start, s2.start)
@@ -297,7 +308,9 @@ def split_swiglu_linear_fc1(data, dist_param, swiglu_shard_axis, is_expert_param
             new_opt_state_dict[f"{key}_w"] = opt_state_dict[key].copy()
             new_opt_state_dict[f"{key}_v"] = opt_state_dict[key].copy()
             for subkey in ["exp_avg", "exp_avg_sq"]:
-                dist_param = model.get_parameter(expert_param_local_key(key[len("module.") :]))
+                dist_param = model.get_parameter(
+                    expert_param_local_key(key[len("module.") :], num_experts)
+                )
                 weight_w, weight_v = split_swiglu_linear_fc1(
                     opt_state_dict[key][subkey],
                     dist_param,
@@ -426,6 +439,13 @@ def validate_loaded_state_dict(state_dict, checkpoint_path):
 def get_global_unique_param_name(model_chunks, param):
     """
     Get the global unique parameter name for a given model and parameter.
+
+    Args:
+        model_chunks: List of model chunks to search for the parameter.
+        param: The parameter to find the name for.
+
+    Returns:
+        The global unique parameter name.
     """
     param_name = None
     for model in model_chunks:
@@ -450,6 +470,7 @@ def get_global_unique_param_name(model_chunks, param):
     param_name = re.sub(r"layers\.(\d+)", f"layers.{tf_layer_number - 1}", param_name)
 
     # Get EP unique parameter name
-    param_name = list(handle_experts_in_state_dict({param_name: None}).keys())[0]
+    num_experts = model_chunks[0].config.num_moe_experts if model_chunks else None
+    param_name = next(iter(handle_experts_in_state_dict({param_name: None}, num_experts).keys()))
 
     return param_name
diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py
index 52e0500e52..b37c63af03 100644
--- a/megatron/training/checkpointing.py
+++ b/megatron/training/checkpointing.py
@@ -862,7 +862,7 @@ def preprocess_fsdp_dtensor_state_dict(args, raw_state_dict, model):
     )
     state_dict["model"] = model_state_dict
     if args.num_experts:
-        state_dict["model"] = handle_experts_in_state_dict(state_dict["model"])
+        state_dict["model"] = handle_experts_in_state_dict(state_dict["model"], args.num_experts)
 
     preprocess_state_dict_for_uneven_dtensor(state_dict)
     return state_dict
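Note: below is a minimal standalone sketch of the offset arithmetic that the patched get_ep_layer_offset(num_experts) performs, with plain integers standing in for the parallel_state EP size/rank queries; local_expert_offset_for_rank is a hypothetical helper used only for illustration and is not part of Megatron-LM.

    # Hypothetical helper mirroring the patched logic: num_experts is passed in
    # explicitly by the caller (e.g. model.config.num_moe_experts or args.num_experts)
    # instead of being read from get_args() inside the function.
    def local_expert_offset_for_rank(num_experts, ep_size, ep_rank):
        # Dense models (num_experts is None or 0) get offset 0, matching the patch.
        num_local_experts = num_experts // ep_size if num_experts else 0
        return ep_rank * num_local_experts

    assert local_expert_offset_for_rank(None, ep_size=4, ep_rank=2) == 0
    assert local_expert_offset_for_rank(8, ep_size=4, ep_rank=2) == 4  # experts 4-5 live on EP rank 2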