Skip to content

Commit 1bb6337

Browse files
committed
Merge branch 'dnarayanan/assertion_check' into 'main'
Some bugfixes in megatron/training.py when the save argument is not provided. See merge request ADLR/megatron-lm!1907
2 parents 5644ed5 + c51503e commit 1bb6337

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

megatron/training/arguments.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,8 @@ def validate_args(args, defaults={}):
521521
if args.decoupled_lr is not None or args.decoupled_min_lr is not None:
522522
assert not args.use_legacy_models, \
523523
'--decoupled-lr and --decoupled-min-lr is not supported in legacy models.'
524-
assert not args.use_dist_ckpt, "Distributed checkpointing does not work with decoupled LR yet."
524+
if args.load is not None or args.save is not None:
525+
assert not args.use_dist_ckpt, "Distributed checkpointing does not work with decoupled LR yet."
525526

526527
# Legacy RoPE arguments
527528
if args.use_rotary_position_embeddings:

megatron/training/training.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,10 +1223,11 @@ def get_e2e_base_metrics():
12231223
if args.exit_signal_handler:
12241224
signal_handler = get_signal_handler()
12251225
if any(signal_handler.signals_received()):
1226-
save_checkpoint_and_time(iteration, model, optimizer,
1227-
opt_param_scheduler,
1228-
num_floating_point_operations_so_far,
1229-
checkpointing_context, train_data_iterator=train_data_iterator)
1226+
if args.save:
1227+
save_checkpoint_and_time(iteration, model, optimizer,
1228+
opt_param_scheduler,
1229+
num_floating_point_operations_so_far,
1230+
checkpointing_context, train_data_iterator=train_data_iterator)
12301231
print_datetime('exiting program after receiving SIGTERM.')
12311232
exit = True
12321233
break
@@ -1259,7 +1260,7 @@ def get_e2e_base_metrics():
12591260
done_cuda, op=torch.distributed.ReduceOp.MAX)
12601261
done = done_cuda.item()
12611262
if done:
1262-
if not saved_checkpoint:
1263+
if args.save and not saved_checkpoint:
12631264
save_checkpoint_and_time(iteration, model, optimizer,
12641265
opt_param_scheduler,
12651266
num_floating_point_operations_so_far,

0 commit comments

Comments
 (0)