diff --git a/megatron/arguments.py b/megatron/arguments.py
index 9a5e4b8da7e..29130b5b2e6 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -1305,6 +1305,8 @@ def _add_data_args(parser):
                        help='Warm up mmap files.')
     group.add_argument('--num-workers', type=int, default=2,
                        help="Dataloader number of workers.")
+    group.add_argument('--num-checkpoint-workers', type=int, default=2,
+                       help="Number of async checkpoint worker processes (0 = synchronous checkpointing).")
     group.add_argument('--tokenizer-type', type=str,
                        default=None,
                        choices=['BertWordPieceLowerCase',
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index b7f4b30bde8..135e8c6ba29 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -39,6 +39,8 @@
 _CHECKPOINT_VERSION = None
+_CHECKPOINT_TASK_LIST = []
+_CHECKPOINT_NUM_TASKS = 0

 dlp = Profile("CHECKPOINT")

 def set_checkpoint_version(value):
     global _CHECKPOINT_VERSION
@@ -233,9 +235,11 @@ def get_rng_state():
         rng_state_list = [rng_state]

     return rng_state_list

+from multiprocessing import Process
+
-@dlp.log
-def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
+@dlp.log
+def save_checkpoint_sync(iteration, model, optimizer, opt_param_scheduler):
     """Save a model checkpoint."""
     args = get_args()
     assert args is not None
@@ -339,6 +343,38 @@ def state_dict_for_save_checkpoint_deepspeed(destination=None, prefix='', keep_v
     if torch.distributed.is_initialized():
         torch.distributed.barrier()

+def wait_checkpoint():
+    print_rank_0("waiting for previous checkpointing to finish")
+    global _CHECKPOINT_TASK_LIST
+    for t in _CHECKPOINT_TASK_LIST:
+        t.join()
+    _CHECKPOINT_TASK_LIST = []
+
+def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
+    '''
+    Save a checkpoint asynchronously when --num-checkpoint-workers > 0.
+    '''
+    args = get_args()
+    assert args is not None
+    num_checkpoint_workers = args.num_checkpoint_workers
+    global _CHECKPOINT_TASK_LIST
+    global _CHECKPOINT_NUM_TASKS
+    if args.num_checkpoint_workers > 0:
+        print_rank_0("Async checkpointing")
+        if _CHECKPOINT_NUM_TASKS < num_checkpoint_workers:
+            proc = Process(target=save_checkpoint_sync, args=(iteration, model, optimizer, opt_param_scheduler))
+            proc.start()
+        else:
+            wait_checkpoint()
+            _CHECKPOINT_NUM_TASKS = 0
+            proc = Process(target=save_checkpoint_sync, args=(iteration, model, optimizer, opt_param_scheduler))
+            proc.start()
+        _CHECKPOINT_TASK_LIST.append(proc)
+        _CHECKPOINT_NUM_TASKS += 1
+    else:
+        save_checkpoint_sync(iteration, model, optimizer, opt_param_scheduler)
+
+
 @dlp.log
 def _transpose_first_dim(t, num_splits, num_splits_first, model):
     input_shape = t.size()
diff --git a/megatron/training.py b/megatron/training.py
index a396ed2b4aa..228b194b2a0 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -29,6 +29,7 @@
 # from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
+from megatron.checkpointing import wait_checkpoint
 from megatron.model import Float16Module
 from megatron.model import GPTModel
 from megatron.core.enums import ModelType
@@ -332,6 +333,7 @@ def pretrain(
     else:
         log.info("skipping training (--skip-train is on) ...")
         iteration = args.iteration
+        config = core_transformer_config_from_args(args)

     if args.do_valid:
         prefix = f"iteration {iteration} on {args.eval_iters * args.global_batch_size}-sample draw from validation set"
@@ -360,6 +362,7 @@ def pretrain(
             write_to_tensorboard=not args.skip_train,
             test=True,
         )

+    wait_checkpoint()
     return model

@@ -1797,6 +1800,7 @@ def train(
                     iteration, model, optimizer, opt_param_scheduler
                 )
             print_datetime("exiting program after receiving SIGTERM.")
+            wait_checkpoint()
             sys.exit()
         if args.save and args.save_interval and iteration % args.save_interval == 0:
             save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler)
@@ -1815,6 +1819,7 @@ def train(
                     iteration, model, optimizer, opt_param_scheduler
                 )
                 print_datetime("exiting program after {} minutes".format(train_time))
+                wait_checkpoint()
                 sys.exit()
         # Exiting based on iterations
         if args.exit_interval and iteration % args.exit_interval == 0:
@@ -1822,6 +1827,7 @@ def train(
                 save_checkpoint_and_time(
                     iteration, model, optimizer, opt_param_scheduler
                 )
+                wait_checkpoint()
             torch.distributed.barrier()
             print_datetime("exiting program at iteration {}".format(iteration))
             sys.exit()
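For reference, below is a minimal, self-contained sketch of the async-checkpoint pattern this patch introduces: launch each save in a background process, cap the number in flight, and drain with wait_checkpoint() before exiting. The names fake_save and NUM_CHECKPOINT_WORKERS are illustrative stand-ins, not Megatron APIs; the real patch launches save_checkpoint_sync() and reads the cap from --num-checkpoint-workers.

import time
from multiprocessing import Process

NUM_CHECKPOINT_WORKERS = 2   # stand-in for args.num_checkpoint_workers
_tasks = []                  # stand-in for _CHECKPOINT_TASK_LIST


def fake_save(iteration):
    """Placeholder for the real (blocking) save_checkpoint_sync()."""
    time.sleep(1.0)
    print(f"checkpoint for iteration {iteration} written")


def wait_checkpoint():
    """Block until all in-flight checkpoint processes have finished."""
    for t in _tasks:
        t.join()
    _tasks.clear()


def save_checkpoint(iteration):
    """Start the save in a background process, capping concurrency."""
    if len(_tasks) >= NUM_CHECKPOINT_WORKERS:
        wait_checkpoint()            # too many in flight: drain first
    proc = Process(target=fake_save, args=(iteration,))
    proc.start()
    _tasks.append(proc)


if __name__ == "__main__":
    for it in (100, 200, 300):
        save_checkpoint(it)          # returns immediately; saving overlaps "training"
    wait_checkpoint()                # drain before exit, as train()/pretrain() do in the patch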