From c8b4518e6e863638cdfc7178444b713f1f318f78 Mon Sep 17 00:00:00 2001
From: Huihuo Zheng
Date: Fri, 10 May 2024 01:42:19 +0000
Subject: [PATCH] Add async checkpointing support

---
 megatron/arguments.py     |  2 ++
 megatron/checkpointing.py | 39 +++++++++++++++++++++++++++++++++++++--
 megatron/training.py      |  7 ++++++-
 3 files changed, 45 insertions(+), 3 deletions(-)

diff --git a/megatron/arguments.py b/megatron/arguments.py
index 702a4bad093..39344e49045 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -1303,6 +1303,8 @@ def _add_data_args(parser):
                        help='Warm up mmap files.')
     group.add_argument('--num-workers', type=int, default=2,
                        help="Dataloader number of workers.")
+    group.add_argument('--num-checkpoint-workers', type=int, default=2,
+                       help="Number of checkpoint worker processes (0 disables async checkpointing).")
     group.add_argument('--tokenizer-type', type=str,
                        default=None,
                        choices=['BertWordPieceLowerCase',
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index d585baf7177..c5218215741 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -38,7 +38,8 @@
 _CHECKPOINT_VERSION = None
 
-
+_CHECKPOINT_TASK_LIST = []
+_CHECKPOINT_NUM_TASKS = 0
 
 def set_checkpoint_version(value):
     global _CHECKPOINT_VERSION
     if _CHECKPOINT_VERSION is not None:
@@ -232,9 +233,12 @@ def get_rng_state():
         rng_state_list = [rng_state]
 
     return rng_state_list
 
+from multiprocessing import Process
+
 
-def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
+
+def save_checkpoint_sync(iteration, model, optimizer, opt_param_scheduler):
     """Save a model checkpoint."""
     args = get_args()
     assert args is not None
@@ -338,6 +342,37 @@ def state_dict_for_save_checkpoint_deepspeed(destination=None, prefix='', keep_v
     if torch.distributed.is_initialized():
         torch.distributed.barrier()
 
+def wait_checkpoint():
+    print_rank_0("waiting for previous checkpointing to finish")
+    global _CHECKPOINT_TASK_LIST
+    for t in _CHECKPOINT_TASK_LIST:
+        t.join()
+    _CHECKPOINT_TASK_LIST = []
+
+def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
+    '''
+    Save a model checkpoint, asynchronously in a background process when --num-checkpoint-workers > 0.
+    '''
+    args = get_args()
+    assert args is not None
+    num_checkpoint_workers = args.num_checkpoint_workers
+    global _CHECKPOINT_TASK_LIST
+    global _CHECKPOINT_NUM_TASKS
+    if args.num_checkpoint_workers > 0:
+        print_rank_0("Async checkpointing")
+        if _CHECKPOINT_NUM_TASKS < num_checkpoint_workers:
+            proc = Process(target=save_checkpoint_sync, args=(iteration, model, optimizer, opt_param_scheduler))
+            proc.start()
+        else:
+            wait_checkpoint()
+            _CHECKPOINT_NUM_TASKS = 0
+            proc = Process(target=save_checkpoint_sync, args=(iteration, model, optimizer, opt_param_scheduler))
+            proc.start()
+        _CHECKPOINT_TASK_LIST.append(proc)
+        _CHECKPOINT_NUM_TASKS += 1
+    else:
+        save_checkpoint_sync(iteration, model, optimizer, opt_param_scheduler)
+
 
 def _transpose_first_dim(t, num_splits, num_splits_first, model):
     input_shape = t.size()
diff --git a/megatron/training.py b/megatron/training.py
index 2bdf61f9084..0087da3eb4d 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -27,6 +27,7 @@
 from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
+from megatron.checkpointing import wait_checkpoint
 from megatron.model import Float16Module
 from megatron.model import GPTModel
 from megatron.core.enums import ModelType
@@ -285,7 +286,7 @@ def pretrain(
         log.info('skipping training (--skip-train is on) ...')
         iteration = args.iteration
 
-
+
     config = core_transformer_config_from_args(args)
     if args.do_valid:
         prefix = f'iteration {iteration} on {args.eval_iters * args.global_batch_size}-sample draw from validation set'
@@ -300,6 +301,7 @@ def pretrain(
                                    test_data_iterator, model,
                                    iteration, process_non_loss_data_func, config,
                                    verbose=True, write_to_tensorboard=not args.skip_train,
                                    test=True)
 
+    wait_checkpoint()
     return model
@@ -1404,6 +1406,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
                 save_checkpoint_and_time(iteration, model, optimizer,
                                          opt_param_scheduler)
                 print_datetime('exiting program after receiving SIGTERM.')
+                wait_checkpoint()
                 sys.exit()
 
         if args.save and args.save_interval and \
@@ -1425,6 +1428,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
                 save_checkpoint_and_time(iteration, model, optimizer,
                                          opt_param_scheduler)
                 print_datetime('exiting program after {} minutes'.format(train_time))
+                wait_checkpoint()
                 sys.exit()
 
         # Exiting based on iterations
@@ -1432,6 +1436,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
             if args.save and not saved_checkpoint:
                 save_checkpoint_and_time(iteration, model, optimizer,
                                          opt_param_scheduler)
+            wait_checkpoint()
             torch.distributed.barrier()
             print_datetime('exiting program at iteration {}'.format(iteration))
             sys.exit()
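Note on the mechanism: the patch keeps the original synchronous path as save_checkpoint_sync() and, when --num-checkpoint-workers > 0, has save_checkpoint() fork a multiprocessing.Process that runs it in the background. Once the number of outstanding workers reaches the limit, wait_checkpoint() joins them before the next fork, and the training loop calls wait_checkpoint() again before every exit so no in-flight checkpoint is lost. The snippet below is a minimal standalone sketch of the same pattern under assumed, hypothetical names (save_state, save_state_async, wait_pending, MAX_WORKERS); it is not Megatron code.

    # Standalone sketch of the async-checkpoint pattern (hypothetical names,
    # not Megatron APIs): fork one process per save, cap the number of
    # outstanding workers, and join them all before exiting.
    import os
    import pickle
    from multiprocessing import Process

    _PENDING = []      # outstanding workers, analogous to _CHECKPOINT_TASK_LIST
    MAX_WORKERS = 2    # analogous to --num-checkpoint-workers

    def save_state(step, state, ckpt_dir="checkpoints"):
        """Synchronous save executed inside each worker process."""
        os.makedirs(ckpt_dir, exist_ok=True)
        with open(os.path.join(ckpt_dir, f"step_{step}.pkl"), "wb") as f:
            pickle.dump(state, f)

    def wait_pending():
        """Join all outstanding checkpoint workers (the role of wait_checkpoint)."""
        global _PENDING
        for p in _PENDING:
            p.join()
        _PENDING = []

    def save_state_async(step, state):
        """Fork a worker for the save; block only when the worker cap is reached."""
        if len(_PENDING) >= MAX_WORKERS:
            wait_pending()
        p = Process(target=save_state, args=(step, state))
        p.start()
        _PENDING.append(p)

    if __name__ == "__main__":
        for step in range(5):
            save_state_async(step, {"step": step, "weights": [0.0] * 4})
        wait_pending()   # ensure every checkpoint reached disk before exit

With the patch applied, the behaviour is selected at launch time: passing --num-checkpoint-workers N with N > 0 enables the background saves, while --num-checkpoint-workers 0 falls back to the synchronous save_checkpoint_sync() path. The sketch tracks pending workers with len(_PENDING) rather than a separate counter, a simplification of the _CHECKPOINT_NUM_TASKS bookkeeping in the patch.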