@@ -1413,35 +1413,35 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load',
     ignore_rng_state = False
     ignore_rerun_state = True
     if ckpt_format == "torch_dist":
-        state_dict_args = types.SimpleNamespace()
+        ckpt_args = types.SimpleNamespace()
         if state_dict is not None and "args" in state_dict:
-            state_dict_args = state_dict.get("args")
+            ckpt_args = state_dict.get("args")
 
-        if not hasattr(state_dict_args, "tensor_model_parallel_size"):
+        if not hasattr(ckpt_args, "tensor_model_parallel_size"):
             print_rank_0("WARNING: TP size not found in checkpoint args, using 1 as default.")
-        if not hasattr(state_dict_args, "pipeline_model_parallel_size"):
+        if not hasattr(ckpt_args, "pipeline_model_parallel_size"):
             print_rank_0("WARNING: PP size not found in checkpoint args, using 1 as default.")
 
         ckpt_tp_pp = (
-            getattr(state_dict_args, "tensor_model_parallel_size", 1),
-            getattr(state_dict_args, "pipeline_model_parallel_size", 1),
+            getattr(ckpt_args, "tensor_model_parallel_size", 1),
+            getattr(ckpt_args, "pipeline_model_parallel_size", 1),
         )
         run_tp_pp = (
             args.tensor_model_parallel_size,
             args.pipeline_model_parallel_size,
         )
 
-        ckpt_world_size = getattr(state_dict_args, 'world_size', 0)
+        ckpt_world_size = getattr(ckpt_args, 'world_size', 0)
         run_world_size = getattr(args, 'world_size', 0)
-        ckpt_dp = getattr(state_dict_args, 'data_parallel_size', 0)
+        ckpt_dp = getattr(ckpt_args, 'data_parallel_size', 0)
         run_dp = getattr(args, 'data_parallel_size', 0)
         mismatch_msg = "(TP, PP) mismatch after resume ({} vs {} from checkpoint)".format(
             run_tp_pp, ckpt_tp_pp
         )
 
         # Determine if RNG state will be loaded
         if (ckpt_tp_pp == run_tp_pp and not release and not args.finetune and not args.no_load_rng
-                and not getattr(state_dict_args, 'no_save_rng', False)):
+                and not getattr(ckpt_args, 'no_save_rng', False)):
             gen_sd_rng_state = get_rng_state(args.ckpt_format)  # we can load the rng state
         else:
             ignore_rng_state = True
@@ -1456,7 +1456,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load',
         print_rank_0(f'sharded_state_dict metadata loaded from the checkpoint: {sharded_sd_metadata}')
         # Determine if optimizer state will be loaded
         if (not release and not args.finetune and not args.no_load_optim
-                and not getattr(state_dict_args, 'no_save_optim', False)):
+                and not getattr(ckpt_args, 'no_save_optim', False)):
             gen_sd_optim = optimizer
             gen_sd_opt_param_scheduler = opt_param_scheduler
 
@@ -1467,7 +1467,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load',
                     # (for MCore v0.13+ checkpoints `sharded_sd_metadata is not None`)
                     sharded_sd_metadata = {
                         'distrib_optim_sharding_type': ('fully_sharded_model_space'
-                                                        if getattr(state_dict_args, 'ckpt_fully_parallel_save', False)
+                                                        if getattr(ckpt_args, 'ckpt_fully_parallel_save', False)
                                                         else 'dp_zero_gather_scatter'),
                     }
                 if (
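For context, here is a minimal runnable sketch (not part of the diff above; the `state_dict` value is a hypothetical stand-in) of the fallback pattern these hunks rely on: checkpoint args are read into a namespace, and `getattr` with a default keeps loading backward compatible when an older checkpoint lacks a field such as `pipeline_model_parallel_size`.

```python
import types

# Hypothetical stand-in for a loaded checkpoint state_dict whose saved args
# predate pipeline parallelism (no pipeline_model_parallel_size attribute).
state_dict = {"args": types.SimpleNamespace(tensor_model_parallel_size=2)}

# Fall back to an empty namespace when the checkpoint carries no args entry at all.
ckpt_args = types.SimpleNamespace()
if state_dict is not None and "args" in state_dict:
    ckpt_args = state_dict.get("args")

# getattr with a default keeps older checkpoints loadable: a missing field resolves to 1.
ckpt_tp_pp = (
    getattr(ckpt_args, "tensor_model_parallel_size", 1),
    getattr(ckpt_args, "pipeline_model_parallel_size", 1),
)
print(ckpt_tp_pp)  # (2, 1)
```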