diff --git a/pretrain_retro.py b/pretrain_retro.py
index 63abbac5e3..a0a8fa4cae 100644
--- a/pretrain_retro.py
+++ b/pretrain_retro.py
@@ -4,6 +4,7 @@
 from functools import partial

 import torch
+from importlib import import_module

 from megatron.training import get_args
 from megatron.training import get_tokenizer
@@ -19,7 +20,6 @@
 from megatron.core.enums import ModelType
 from megatron.core.models.retro import get_retro_decoder_block_spec, RetroConfig, RetroModel
 from megatron.core.models.retro.utils import get_all_true_mask
-from megatron.core.tokenizers import MegatronTokenizer
 from megatron.training import pretrain
 from megatron.training.utils import get_ltor_masks_and_position_ids
 from pretrain_gpt import (
@@ -46,7 +46,7 @@ def core_model_provider(pre_process=True, post_process=True):
     else:
         block_spec = get_retro_decoder_block_spec(config, use_transformer_engine=True)

-    print_rank_0('building GPT model ...')
+    print_rank_0('building Retro model ...')
     model = RetroModel(
         config=config,
         transformer_layer_spec=block_spec,