Commit d87f1c7

reattempt #2113 with fixes to remove dependency on megatron.training from megatron.core
Signed-off-by: Ananth Subramaniam <[email protected]>
1 parent c4ba666 commit d87f1c7
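
At a glance, the change follows one pattern throughout: instead of a megatron.core helper importing megatron.training.global_vars.get_args() to read num_experts, the caller now derives that value from the model config (or args) and passes it in explicitly. Below is a minimal, self-contained sketch of that pattern, using illustrative stand-in names rather than the real Megatron APIs:

from typing import Optional


def core_helper(num_experts: Optional[int] = None) -> int:
    # Before this commit: num_experts = get_args().num_experts  (pulls megatron.training into megatron.core)
    # After this commit:  the caller passes the value explicitly.
    return num_experts if num_experts else 0


class ToyConfig:
    # Stand-in for a Megatron model config carrying num_moe_experts.
    num_moe_experts = 8


# Callers derive the value from the model config (as the diffs below do via
# model.config.num_moe_experts or args.num_experts) and thread it down.
print(core_helper(ToyConfig.num_moe_experts))  # -> 8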

File tree: 3 files changed (+47, −25 lines)

megatron/core/optimizer/distrib_optimizer.py

Lines changed: 2 additions & 1 deletion
@@ -1153,7 +1153,8 @@ def _param_name(self, param: torch.nn.Parameter) -> str:
                     "Ensure that each model chunk has unique parameter names."
                 )
                 name_to_param.update(_name_to_param)
-            name_to_param = handle_experts_in_state_dict(name_to_param)
+            num_experts = self.model_chunks[0].config.num_moe_experts if self.model_chunks else None
+            name_to_param = handle_experts_in_state_dict(name_to_param, num_experts)
             self.param_to_name = {param: name for name, param in name_to_param.items()}
         assert (
             param in self.param_to_name

megatron/core/transformer/fsdp_dtensor_checkpoint.py

Lines changed: 44 additions & 23 deletions
@@ -47,31 +47,24 @@
 from megatron.core.transformer.transformer_layer import TransformerLayer
 
 
-def get_ep_layer_offset():
+def get_ep_layer_offset(num_experts: int | None = None) -> int:
     """
     Get the expert layer offset for the current model.
-    """
-    from megatron.training.global_vars import get_args
 
-    args = get_args()
+    Args:
+        num_experts: Total number of experts in the model. If None, returns 0.
+
+    Returns:
+        The expert layer offset for the current EP rank.
+    """
     ep_size = parallel_state.get_expert_model_parallel_world_size()
     ep_rank = parallel_state.get_expert_model_parallel_rank()
-    num_local_experts = args.num_experts // ep_size if args.num_experts else 0
+    num_local_experts = num_experts // ep_size if num_experts else 0
     local_expert_offset = ep_rank * num_local_experts
 
     return local_expert_offset
 
 
-def get_total_num_experts():
-    """
-    Get the total number of experts for the current model.
-    """
-    from megatron.training.global_vars import get_args
-
-    args = get_args()
-    return args.num_experts if args.num_experts else 0
-
-
 def get_expert_index_from_key(key):
     """Extract expert index from various expert key formats.
 
@@ -96,12 +89,19 @@ def get_expert_index_from_key(key):
     return None
 
 
-def handle_experts_in_state_dict(state_dict):
+def handle_experts_in_state_dict(state_dict, num_experts: int | None = None):
     """
     Rewrite expert keys in state dict.
+
+    Args:
+        state_dict: The state dictionary to process.
+        num_experts: Total number of experts in the model. If None, no expert processing occurs.
+
+    Returns:
+        The processed state dictionary with rewritten expert keys.
     """
-    local_expert_start = get_ep_layer_offset()
-    local_expert_end = get_total_num_experts()
+    local_expert_start = get_ep_layer_offset(num_experts)
+    local_expert_end = num_experts if num_experts else 0
 
     def should_keep_expert_key(expert_index):
         """Determine if this rank should keep this expert key based on expert index"""
@@ -147,9 +147,17 @@ def replace_expert_index_in_key(key, expert_index, state_dict):
     return state_dict
 
 
-def expert_param_local_key(key):
-    """Get the module parameter corresponding to the key."""
-    local_expert_offset = get_ep_layer_offset()
+def expert_param_local_key(key: str, num_experts: int | None = None) -> str:
+    """Get the module parameter corresponding to the key.
+
+    Args:
+        key: The parameter key to process.
+        num_experts: Total number of experts in the model. If None, no expert processing occurs.
+
+    Returns:
+        The local parameter key with adjusted expert indices.
+    """
+    local_expert_offset = get_ep_layer_offset(num_experts)
     expert_index = get_expert_index_from_key(key)
     if expert_index is not None:
         new_expert_index = expert_index - local_expert_offset
@@ -174,6 +182,9 @@ def handle_swiglu_in_state_dict(model, model_state_dict, optimizer_state_dict):
     """
     assert HAVE_MEGATRON_FSDP, "This function requires Megatron-FSDP to be installed."
 
+    # Extract num_experts from model config for expert parameter processing
+    num_experts = model.config.num_moe_experts if hasattr(model, 'config') else None
+
     def intersection(s1, s2):
         # Only works for step=1
         start = max(s1.start, s2.start)
@@ -297,7 +308,9 @@ def split_swiglu_linear_fc1(data, dist_param, swiglu_shard_axis, is_expert_param
             new_opt_state_dict[f"{key}_w"] = opt_state_dict[key].copy()
             new_opt_state_dict[f"{key}_v"] = opt_state_dict[key].copy()
             for subkey in ["exp_avg", "exp_avg_sq"]:
-                dist_param = model.get_parameter(expert_param_local_key(key[len("module.") :]))
+                dist_param = model.get_parameter(
+                    expert_param_local_key(key[len("module.") :], num_experts)
+                )
                 weight_w, weight_v = split_swiglu_linear_fc1(
                     opt_state_dict[key][subkey],
                     dist_param,
@@ -426,6 +439,13 @@ def validate_loaded_state_dict(state_dict, checkpoint_path):
 def get_global_unique_param_name(model_chunks, param):
     """
     Get the global unique parameter name for a given model and parameter.
+
+    Args:
+        model_chunks: List of model chunks to search for the parameter.
+        param: The parameter to find the name for.
+
+    Returns:
+        The global unique parameter name.
     """
     param_name = None
     for model in model_chunks:
@@ -450,6 +470,7 @@ def get_global_unique_param_name(model_chunks, param):
         param_name = re.sub(r"layers\.(\d+)", f"layers.{tf_layer_number - 1}", param_name)
 
     # Get EP unique parameter name
-    param_name = list(handle_experts_in_state_dict({param_name: None}).keys())[0]
+    num_experts = model_chunks[0].config.num_moe_experts if model_chunks else None
+    param_name = next(iter(handle_experts_in_state_dict({param_name: None}, num_experts).keys()))
 
     return param_name
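
To make the expert-key handling above concrete, here is a small self-contained sketch of what the offset and index shifting amount to; the key format, regex, and function names are simplified assumptions for illustration, not the real helpers:

import re
from typing import Optional


def ep_layer_offset(num_experts: Optional[int], ep_size: int, ep_rank: int) -> int:
    # Each EP rank owns a contiguous block of experts; this is the first global index it owns.
    num_local_experts = num_experts // ep_size if num_experts else 0
    return ep_rank * num_local_experts


def to_local_expert_key(key: str, num_experts: Optional[int], ep_size: int, ep_rank: int) -> str:
    # Shift the global expert index embedded in the key down by this rank's offset.
    offset = ep_layer_offset(num_experts, ep_size, ep_rank)
    match = re.search(r"experts\.(\d+)\.", key)
    if match is None:
        return key
    local_index = int(match.group(1)) - offset
    return key[: match.start(1)] + str(local_index) + key[match.end(1):]


# EP rank 1 of 4 with 8 experts owns global experts 2-3, so global expert 3 is local expert 1.
print(to_local_expert_key("decoder.layers.0.mlp.experts.3.linear_fc1.weight",
                          num_experts=8, ep_size=4, ep_rank=1))
# -> decoder.layers.0.mlp.experts.1.linear_fc1.weight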

megatron/training/checkpointing.py

Lines changed: 1 addition & 1 deletion
@@ -862,7 +862,7 @@ def preprocess_fsdp_dtensor_state_dict(args, raw_state_dict, model):
     )
     state_dict["model"] = model_state_dict
     if args.num_experts:
-        state_dict["model"] = handle_experts_in_state_dict(state_dict["model"])
+        state_dict["model"] = handle_experts_in_state_dict(state_dict["model"], args.num_experts)
     preprocess_state_dict_for_uneven_dtensor(state_dict)
 
     return state_dict
