Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
414 changes: 414 additions & 0 deletions .env-

Large diffs are not rendered by default.

25 changes: 25 additions & 0 deletions ALCF/helpers.sh
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,28 @@ setup_run_cmd() {
if [[ -z "${OVERRIDE_CKPT_OPT_PARAM:-}" ]]; then
train_args+=("--use-checkpoint-opt_param-scheduler")
fi

# Add MuP to the model
export MUP_BASE_WIDTH=${MUP_BASE_WIDTH:-256}
export MUP_MUL=$(( $HIDDEN / $MUP_BASE_WIDTH ))
mup_flags+=(
"--enable-mup"
"--mup-coord-check=True"
"--mup-hidden-weights-scale=${MUP_MUL}"
"--mup-hidden-lr-scale=${MUP_MUL}"
)


# Add depth scaling to the model
export DEPTH_BASE=${DEPTH_BASE:-2}
export DEPTH_MUL=$(( $NLAYERS / $DEPTH_BASE ))
depth_scaling_flags+=(
#"--enable-depth-scale"
"--depth-base=${DEPTH_BASE}"
"--depth-multiplier=${DEPTH_MUL}"
"--depth-alpha=0.5")


# "--init-method-std ${INIT_METHOD_STD:-0.0006}"
# "--shuffle-sample"
train_args+=(
Expand Down Expand Up @@ -271,6 +293,9 @@ setup_run_cmd() {
"--num-attention-heads=${HEADS}"
"--data-cache-path=${data_cache_path}"
"--data-file-list=${DATA_FILE_LIST:-${dfl_fallback}}"
# add MUP parameters
"${mup_flags[@]}"
"${depth_scaling_flags[@]}"
)
# "--adam-eps ${ADAM_EPS:-0.00001}"
cache_dir="${PBS_O_WORKDIR}/.cache/"
Expand Down
65 changes: 65 additions & 0 deletions megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
parser = _add_transformer_engine_args(parser)
parser = _add_retro_args(parser)
parser = _add_profiler_args(parser)
parser = _add_mup_args(parser)
parser = _add_depth_scaling_args(parser)

# Custom arguments.
if extra_args_provider is not None:
Expand Down Expand Up @@ -500,6 +502,69 @@ def core_transformer_config_from_args(args):

return TransformerConfig(**kw_args)

def _add_depth_scaling_args(parser):

group = parser.add_argument_group(title='Depth_Scaling')

group.add_argument('--enable-depth-scale', action='store_true',
help='Include in cmd to implement parameterization for model depth scaling', dest='enable_depth_scale')

#group.add_argument('--depth_scaling_enabled', type=bool, default=False,
# help='Include in cmd to implement parameterization for model depth scaling', dest='depth_scaling_enabled')


group.add_argument('--depth-base', type=int, default=1,
help='Specify number of layers in base model', dest='depth_base')

group.add_argument('--depth-multiplier', type=float, default=1.0,
help='Number of layers / Base number of layers', dest='depth_multiplier')

group.add_argument('--depth-alpha', type=float, default=0.5,
help='Value of alpha used in depth scaling', dest='depth_alpha')

return parser

def _add_mup_args(parser):
group = parser.add_argument_group(title='MuP')

group.add_argument('--enable-mup', action='store_true',
help='Include in cmd to implement MuP', dest='enable_mup')

#group.add_argument('--enable-mup', type=bool, default=False,
# help='Set True to use MuP', dest='enable-mup')


group.add_argument('--mup-coord-check', type=bool, default=False,
help='Perform coordinate check for MuP', dest='mup_coord_check')

group.add_argument('--mup-input-weights-scale', type=float, default=1.0,
help='Scalar to multiply initial weights', dest='mup_input_weights_scale')


group.add_argument('--mup-hidden-weights-scale', type=float, default=1.0,
help='Scalar to multiply hidden weights', dest='mup_hidden_weights_scale')

group.add_argument('--mup-output-weights-scale', type=float, default=1.0,
help='Scalar to multiply output weights', dest='mup_output_weights_scale')


group.add_argument('--mup-input-lr-scale', type=float, default=1.0,
help='To scale learning rate for input weights', dest='mup_input_lr_scale')


group.add_argument('--mup-hidden-lr-scale', type=float, default=1.0,
help='To scale learning rate for hidden weights', dest='mup_hidden_lr_scale')


group.add_argument('--mup-output-lr-scale', type=float, default=1.0,
help='To scale learning rate for output weights', dest='mup_output_lr_scale')


return parser




def _add_transformer_engine_args(parser):
group = parser.add_argument_group(title='Transformer-Engine')

Expand Down
26 changes: 26 additions & 0 deletions megatron/core/transformer/transformer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,20 @@ class TransformerConfig(ModelParallelConfig):
recompute_method: str = None
recompute_num_layers: int = None
distribute_saved_activations: bool = None

### Begin MuP Code ###
# MuP enabled
enable_mup: bool = False
mup_coord_check: bool = False
mup_hidden_lr_scale: float = 1.0
mup_hidden_weights_scale: float = 1.0

# Depth scaling enabled
enable_depth_scale: bool = False
depth_multiplier: float = 1.0
depth_alpha: float = 1.0

### End MuP Code ###

def __post_init__(self):
""" Python dataclass method that is used to modify attributes after initialization.
Expand Down Expand Up @@ -212,3 +226,15 @@ def __post_init__(self):
if self.output_layer_init_method is None:
self.output_layer_init_method = scaled_init_method_normal(self.init_method_std, self.num_layers)

### Begin MuP Code ###
# Check if the --enable-mup flag is included in args
if self.enable_mup is None:
self.enable_mup = False

if self.mup_coord_check is None:
self.mup_coord_check = False

if self.enable_depth_scale is None:
self.enable_depth_scale = True
### End MuP Code ###

2 changes: 1 addition & 1 deletion megatron/model/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
def post_language_model_processing(lm_output, labels, logit_weights,
parallel_output,
fp16_lm_cross_entropy):

# Output. Format [s b h]
output = parallel_lm_logits(
lm_output,
Expand Down
24 changes: 21 additions & 3 deletions megatron/model/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,22 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
async_grad_allreduce = False

### Begin MuP Code ###
if args.enable_mup:
log_mult = (args.mup_hidden_weights_scale ** (-1))
else:
log_mult = 1.0
### End MuP Code ###

# Matrix multiply.
logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
logits_parallel = log_mult * tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
input=input_parallel,
weight=word_embeddings_weight,
bias=bias,
gradient_accumulation_fusion=args.gradient_accumulation_fusion,
async_grad_allreduce=async_grad_allreduce,
sequence_parallel=args.sequence_parallel)

# Gather if needed.

if parallel_output:
Expand Down Expand Up @@ -147,16 +155,25 @@ def __init__(self,
super(Embedding, self).__init__()

self.hidden_size = hidden_size
self.init_method = config.init_method
self.init_method = config.init_method ### Begin MuP Comment --- Keeping this because it might be used for initializing position embeddings ?
self.num_tokentypes = num_tokentypes

args = get_args()

# Word embeddings (parallel).
self.embedding_weights_in_fp32 = embedding_weights_in_fp32
self.params_dtype = args.params_dtype

### Begin MuP Code ### -- Do this only for dense inputs
if config.enable_mup:
load_init_function = init_method_normal( config.init_method_std * (vocab_size ** (-1/2)))
else:
load_init_function = config.init_method

self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
vocab_size, self.hidden_size, config=config, init_method=config.init_method)
vocab_size, self.hidden_size, config=config, init_method=load_init_function)
### End MuP Code ###

self._word_embeddings_key = 'word_embeddings'

# Position embedding (serial).
Expand Down Expand Up @@ -505,6 +522,7 @@ def __init__(self,
# embedding tying that also does not have a bias.
bias=False
)

self._output_layer_key = 'output_layer'

def set_input_tensor(self, input_tensor):
Expand Down
Loading