3 changes: 2 additions & 1 deletion megatron/core/optimizer/distrib_optimizer.py
@@ -1153,7 +1153,8 @@ def _param_name(self, param: torch.nn.Parameter) -> str:
                     "Ensure that each model chunk has unique parameter names."
                 )
                 name_to_param.update(_name_to_param)
-            name_to_param = handle_experts_in_state_dict(name_to_param)
+            num_experts = self.model_chunks[0].config.num_moe_experts if self.model_chunks else None
+            name_to_param = handle_experts_in_state_dict(name_to_param, num_experts)
             self.param_to_name = {param: name for name, param in name_to_param.items()}
         assert (
             param in self.param_to_name
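For reference, a minimal self-contained sketch of the lookup pattern introduced above, with stub dataclasses standing in for Megatron's model chunk and transformer config (the names `_StubConfig`, `_StubModelChunk`, and `resolve_num_experts` are illustrative, not part of the PR):

```python
# Illustrative sketch only: stub types standing in for Megatron's
# TransformerConfig and model chunk, showing the guarded
# "first model chunk" lookup used in the hunk above.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class _StubConfig:
    num_moe_experts: Optional[int] = None


@dataclass
class _StubModelChunk:
    config: _StubConfig


def resolve_num_experts(model_chunks: List[_StubModelChunk]) -> Optional[int]:
    # Mirrors the new call site: fall back to None when there are no
    # model chunks (or no MoE layers are configured).
    return model_chunks[0].config.num_moe_experts if model_chunks else None


print(resolve_num_experts([_StubModelChunk(_StubConfig(num_moe_experts=8))]))  # 8
print(resolve_num_experts([]))                                                 # None
```

The effect of the change is that the expert count travels with the model config rather than being read from global training args inside the checkpoint helpers.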
67 changes: 44 additions & 23 deletions megatron/core/transformer/fsdp_dtensor_checkpoint.py
@@ -47,31 +47,24 @@
 from megatron.core.transformer.transformer_layer import TransformerLayer


-def get_ep_layer_offset():
+def get_ep_layer_offset(num_experts: int | None = None) -> int:
     """
     Get the expert layer offset for the current model.
-    """
-    from megatron.training.global_vars import get_args

-    args = get_args()
+    Args:
+        num_experts: Total number of experts in the model. If None, returns 0.
+
+    Returns:
+        The expert layer offset for the current EP rank.
+    """
     ep_size = parallel_state.get_expert_model_parallel_world_size()
     ep_rank = parallel_state.get_expert_model_parallel_rank()
-    num_local_experts = args.num_experts // ep_size if args.num_experts else 0
+    num_local_experts = num_experts // ep_size if num_experts else 0
     local_expert_offset = ep_rank * num_local_experts

     return local_expert_offset


-def get_total_num_experts():
-    """
-    Get the total number of experts for the current model.
-    """
-    from megatron.training.global_vars import get_args
-
-    args = get_args()
-    return args.num_experts if args.num_experts else 0
-
-
 def get_expert_index_from_key(key):
     """Extract expert index from various expert key formats.

@@ -96,12 +89,19 @@ def get_expert_index_from_key(key):
     return None


-def handle_experts_in_state_dict(state_dict):
+def handle_experts_in_state_dict(state_dict, num_experts: int | None = None):
     """
     Rewrite expert keys in state dict.
+
+    Args:
+        state_dict: The state dictionary to process.
+        num_experts: Total number of experts in the model. If None, no expert processing occurs.
+
+    Returns:
+        The processed state dictionary with rewritten expert keys.
     """
-    local_expert_start = get_ep_layer_offset()
-    local_expert_end = get_total_num_experts()
+    local_expert_start = get_ep_layer_offset(num_experts)
+    local_expert_end = num_experts if num_experts else 0

     def should_keep_expert_key(expert_index):
         """Determine if this rank should keep this expert key based on expert index"""
@@ -147,9 +147,17 @@ def replace_expert_index_in_key(key, expert_index, state_dict):
     return state_dict


-def expert_param_local_key(key):
-    """Get the module parameter corresponding to the key."""
-    local_expert_offset = get_ep_layer_offset()
+def expert_param_local_key(key: str, num_experts: int | None = None) -> str:
+    """Get the module parameter corresponding to the key.
+
+    Args:
+        key: The parameter key to process.
+        num_experts: Total number of experts in the model. If None, no expert processing occurs.
+
+    Returns:
+        The local parameter key with adjusted expert indices.
+    """
+    local_expert_offset = get_ep_layer_offset(num_experts)
     expert_index = get_expert_index_from_key(key)
     if expert_index is not None:
         new_expert_index = expert_index - local_expert_offset
@@ -174,6 +182,9 @@ def handle_swiglu_in_state_dict(model, model_state_dict, optimizer_state_dict):
     """
     assert HAVE_MEGATRON_FSDP, "This function requires Megatron-FSDP to be installed."

+    # Extract num_experts from model config for expert parameter processing
+    num_experts = model.config.num_moe_experts if hasattr(model, 'config') else None
+
     def intersection(s1, s2):
         # Only works for step=1
         start = max(s1.start, s2.start)
@@ -297,7 +308,9 @@ def split_swiglu_linear_fc1(data, dist_param, swiglu_shard_axis, is_expert_param
             new_opt_state_dict[f"{key}_w"] = opt_state_dict[key].copy()
             new_opt_state_dict[f"{key}_v"] = opt_state_dict[key].copy()
             for subkey in ["exp_avg", "exp_avg_sq"]:
-                dist_param = model.get_parameter(expert_param_local_key(key[len("module.") :]))
+                dist_param = model.get_parameter(
+                    expert_param_local_key(key[len("module.") :], num_experts)
+                )
                 weight_w, weight_v = split_swiglu_linear_fc1(
                     opt_state_dict[key][subkey],
                     dist_param,
@@ -426,6 +439,13 @@ def validate_loaded_state_dict(state_dict, checkpoint_path):
 def get_global_unique_param_name(model_chunks, param):
     """
     Get the global unique parameter name for a given model and parameter.
+
+    Args:
+        model_chunks: List of model chunks to search for the parameter.
+        param: The parameter to find the name for.
+
+    Returns:
+        The global unique parameter name.
     """
     param_name = None
     for model in model_chunks:
@@ -450,6 +470,7 @@ def get_global_unique_param_name(model_chunks, param):
         param_name = re.sub(r"layers\.(\d+)", f"layers.{tf_layer_number - 1}", param_name)

         # Get EP unique parameter name
-        param_name = list(handle_experts_in_state_dict({param_name: None}).keys())[0]
+        num_experts = model_chunks[0].config.num_moe_experts if model_chunks else None
+        param_name = next(iter(handle_experts_in_state_dict({param_name: None}, num_experts).keys()))

     return param_name
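To make the expert-parallel bookkeeping above concrete, here is a small sketch of the offset arithmetic that `get_ep_layer_offset` performs, with `ep_size` and `ep_rank` passed in explicitly instead of being read from `parallel_state` (the standalone `ep_layer_offset` helper is illustrative only):

```python
# Illustrative sketch only: the offset arithmetic used by get_ep_layer_offset,
# with the expert-parallel size and rank supplied as plain arguments.
def ep_layer_offset(num_experts, ep_size, ep_rank):
    # Each EP rank owns num_experts // ep_size experts; with no MoE the offset is 0.
    num_local_experts = num_experts // ep_size if num_experts else 0
    return ep_rank * num_local_experts


# 8 experts over 4 EP ranks: each rank's local experts start at rank * 2.
for rank in range(4):
    print(f"EP rank {rank}: local expert offset = {ep_layer_offset(8, 4, rank)}")
```

With 8 experts and an EP size of 4, ranks 0 through 3 get offsets 0, 2, 4, and 6, so each rank addresses a contiguous block of two local experts.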
2 changes: 1 addition & 1 deletion megatron/training/checkpointing.py
@@ -862,7 +862,7 @@ def preprocess_fsdp_dtensor_state_dict(args, raw_state_dict, model):
     )
     state_dict["model"] = model_state_dict
     if args.num_experts:
-        state_dict["model"] = handle_experts_in_state_dict(state_dict["model"])
+        state_dict["model"] = handle_experts_in_state_dict(state_dict["model"], args.num_experts)
     preprocess_state_dict_for_uneven_dtensor(state_dict)

     return state_dict
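Finally, a hedged sketch of the global-to-local expert index remapping that `expert_param_local_key` builds on, which is what lets the checkpoint preprocessing above end up with rank-local keys. The key pattern `experts.<i>.` and the helper name `to_local_expert_key` are assumptions for illustration; the real `get_expert_index_from_key` handles several Megatron expert key formats:

```python
# Illustrative sketch only: rewrite a globally indexed expert key into the
# index space of the local EP rank, mirroring the offset subtraction in
# expert_param_local_key. The "experts.<i>." pattern is a hypothetical format.
import re


def to_local_expert_key(key, num_experts, ep_size, ep_rank):
    num_local_experts = num_experts // ep_size if num_experts else 0
    local_expert_offset = ep_rank * num_local_experts
    match = re.search(r"experts\.(\d+)\.", key)
    if match is None:
        return key  # not an expert parameter, leave untouched
    global_index = int(match.group(1))
    local_index = global_index - local_expert_offset
    return key.replace(f"experts.{global_index}.", f"experts.{local_index}.", 1)


# Expert 5 on EP rank 2 of 4 (8 experts total) becomes local expert 1.
print(to_local_expert_key("decoder.layers.0.mlp.experts.5.linear_fc1.weight", 8, 4, 2))
```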