
Commit 681ce76

rename var

Signed-off-by: Maanu Grover <[email protected]>

1 parent 8103049

File tree

1 file changed: +11 -11 lines changed

megatron/training/checkpointing.py

Lines changed: 11 additions & 11 deletions
@@ -1413,35 +1413,35 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load',
     ignore_rng_state = False
     ignore_rerun_state = True
     if ckpt_format == "torch_dist":
-        state_dict_args = types.SimpleNamespace()
+        ckpt_args = types.SimpleNamespace()
         if state_dict is not None and "args" in state_dict:
-            state_dict_args = state_dict.get("args")
+            ckpt_args = state_dict.get("args")
 
-        if not hasattr(state_dict_args, "tensor_model_parallel_size"):
+        if not hasattr(ckpt_args, "tensor_model_parallel_size"):
             print_rank_0("WARNING: TP size not found in checkpoint args, using 1 as default.")
-        if not hasattr(state_dict_args, "pipeline_model_parallel_size"):
+        if not hasattr(ckpt_args, "pipeline_model_parallel_size"):
             print_rank_0("WARNING: PP size not found in checkpoint args, using 1 as default.")
 
         ckpt_tp_pp = (
-            getattr(state_dict_args, "tensor_model_parallel_size", 1),
-            getattr(state_dict_args, "pipeline_model_parallel_size", 1),
+            getattr(ckpt_args, "tensor_model_parallel_size", 1),
+            getattr(ckpt_args, "pipeline_model_parallel_size", 1),
         )
         run_tp_pp = (
             args.tensor_model_parallel_size,
             args.pipeline_model_parallel_size,
         )
 
-        ckpt_world_size = getattr(state_dict_args, 'world_size', 0)
+        ckpt_world_size = getattr(ckpt_args, 'world_size', 0)
         run_world_size = getattr(args, 'world_size', 0)
-        ckpt_dp = getattr(state_dict_args, 'data_parallel_size', 0)
+        ckpt_dp = getattr(ckpt_args, 'data_parallel_size', 0)
         run_dp = getattr(args, 'data_parallel_size', 0)
         mismatch_msg = "(TP, PP) mismatch after resume ({} vs {} from checkpoint)".format(
             run_tp_pp, ckpt_tp_pp
         )
 
         # Determine if RNG state will be loaded
         if (ckpt_tp_pp == run_tp_pp and not release and not args.finetune and not args.no_load_rng
-            and not getattr(state_dict_args, 'no_save_rng', False)):
+            and not getattr(ckpt_args, 'no_save_rng', False)):
             gen_sd_rng_state = get_rng_state(args.ckpt_format)  # we can load the rng state
         else:
             ignore_rng_state = True
@@ -1456,7 +1456,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load',
             print_rank_0(f'sharded_state_dict metadata loaded from the checkpoint: {sharded_sd_metadata}')
         # Determine if optimizer state will be loaded
         if (not release and not args.finetune and not args.no_load_optim
-            and not getattr(state_dict_args, 'no_save_optim', False)):
+            and not getattr(ckpt_args, 'no_save_optim', False)):
             gen_sd_optim = optimizer
             gen_sd_opt_param_scheduler = opt_param_scheduler
 
@@ -1467,7 +1467,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load',
         # (for MCore v0.13+ checkpoints `sharded_sd_metadata is not None`)
         sharded_sd_metadata = {
             'distrib_optim_sharding_type': ('fully_sharded_model_space'
-                if getattr(state_dict_args, 'ckpt_fully_parallel_save', False)
+                if getattr(ckpt_args, 'ckpt_fully_parallel_save', False)
                 else 'dp_zero_gather_scatter'),
         }
         if (
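
For context, the pattern touched by this rename reads parallelism settings from the args namespace saved in the checkpoint and falls back to defaults when an attribute is absent. A minimal standalone sketch of that behavior (the state_dict contents below are made up for illustration; in checkpointing.py they come from the loaded checkpoint):

import types

# Hypothetical checkpoint payload; the real one is produced when the checkpoint is loaded.
state_dict = {"args": types.SimpleNamespace(tensor_model_parallel_size=2)}

ckpt_args = types.SimpleNamespace()
if state_dict is not None and "args" in state_dict:
    ckpt_args = state_dict.get("args")

# Missing attributes default to 1, matching the warnings printed in the diff above.
ckpt_tp_pp = (
    getattr(ckpt_args, "tensor_model_parallel_size", 1),
    getattr(ckpt_args, "pipeline_model_parallel_size", 1),
)
print(ckpt_tp_pp)  # (2, 1)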
