
Commit e4b7259

Authored by jaredcasper, rogerwaleffe, and santhnm2
Fix Mamba TP and remove confusing legacy initialization (#2202)
Co-authored-by: Roger Waleffe <[email protected]>
Co-authored-by: Keshav Santhanam <[email protected]>
1 parent 7dec856 commit e4b7259

File tree (3 files changed: +11, -71 lines)

  megatron/core/ssm/mamba_block.py
  megatron/core/ssm/mamba_mixer.py
  tests/unit_tests/inference/engines/test_dynamic_engine.py


megatron/core/ssm/mamba_block.py (0 additions, 55 deletions)

@@ -5,10 +5,8 @@
 # This source code is licensed under the Apache license found in the
 # LICENSE file in the root directory of this source tree.
 
-import math
 from contextlib import nullcontext
 from dataclasses import dataclass
-from functools import partial
 from typing import Optional, Tuple, Union
 
 import torch
@@ -23,7 +21,6 @@
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols as LayerSymbols
 from megatron.core.ssm.mamba_hybrid_layer_allocation import allocate_layers
-from megatron.core.tensor_parallel import get_cuda_rng_tracker
 from megatron.core.transformer import TransformerConfig
 from megatron.core.transformer.identity_op import IdentityOp
 from megatron.core.transformer.module import MegatronModule
@@ -33,50 +30,6 @@
 from megatron.core.utils import WrappedTensor, deprecate_inference_params, make_viewless_tensor
 
 
-# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
-def _init_weights(
-    module,
-    n_layer,
-    initializer_range=0.02,  # Now only used for embedding layer.
-    rescale_prenorm_residual=True,
-    n_residuals_per_layer=1,  # Change to 2 if we have MLP
-):
-    with get_cuda_rng_tracker().fork():
-        if isinstance(module, nn.Linear):
-            if not getattr(module.weight, "_no_reinit", False):
-                nn.init.normal_(module.weight, std=initializer_range)
-            if module.bias is not None:
-                if not getattr(module.bias, "_no_reinit", False):
-                    nn.init.zeros_(module.bias)
-        elif isinstance(module, nn.Embedding):
-            nn.init.normal_(module.weight, std=initializer_range)
-
-        for name, p in module.named_parameters():
-            if name in ["conv1d.weight", "out_proj.weight"]:
-                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
-            if name in ["in_proj.weight"]:
-                nn.init.normal_(p, mean=0.0, std=initializer_range)
-
-        if rescale_prenorm_residual:
-            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
-            # > A modified initialization which accounts for the accumulation on the
-            # > residual path with model depth. Scale
-            # > the weights of residual layers at initialization by a factor of
-            # > 1/√N where N is the # of residual layers.
-            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
-            #
-            # Reference (Megatron-LM):
-            # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
-            for name, p in module.named_parameters():
-                if name in ["out_proj.weight", "fc2.weight"]:
-                    # Special Scaled Initialization
-                    nn.init.normal_(
-                        p,
-                        mean=0.0,
-                        std=initializer_range / math.sqrt(n_residuals_per_layer * n_layer),
-                    )
-
-
 @dataclass
 class MambaStackSubmodules:
     """
@@ -210,14 +163,6 @@ def __init__(
             eps=self.config.layernorm_epsilon,
         )
 
-        self.apply(
-            partial(
-                _init_weights,
-                n_layer=self.config.num_layers,
-                initializer_range=self.config.init_method_std,
-            )
-        )
-
     def _select_layers_for_pipeline_parallel(self, layer_type_list):
         num_layers_per_pipeline_rank = self.config.num_layers // self.pp_group.size()
 
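Note: with the post-construction `self.apply(partial(_init_weights, ...))` pass removed, the Mamba stack keeps the initialization its layers already get from `TransformerConfig` (driven by `init_method_std`, with a depth-scaled variant for output projections). As a rough illustration of the two scaling conventions involved, here is a minimal standalone sketch; the `2 * num_layers` factor assumes Megatron's usual scaled output-layer init, and the concrete numbers are placeholders, not values from this commit.

import math

import torch.nn as nn

init_method_std = 0.02   # stands in for config.init_method_std
num_layers = 24          # stands in for config.num_layers

# Legacy rescale done by the removed helper: std / sqrt(n_residuals_per_layer * n_layer)
legacy_std = init_method_std / math.sqrt(1 * num_layers)

# Depth-scaled init commonly used for residual/output projections: std / sqrt(2 * num_layers)
scaled_std = init_method_std / math.sqrt(2 * num_layers)

# Initialize a toy projection with the scaled std and compare the empirical spread.
w = nn.Linear(16, 16, bias=False).weight
nn.init.normal_(w, mean=0.0, std=scaled_std)
print(f"legacy std={legacy_std:.5f}, scaled std={scaled_std:.5f}, sample std={w.std().item():.5f}")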

megatron/core/ssm/mamba_mixer.py (3 additions, 9 deletions)

@@ -293,6 +293,8 @@ def __init__(
 
         if self.conv_init is not None:
             nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
+        else:
+            nn.init.kaiming_uniform_(self.conv1d.weight, a=math.sqrt(5))
 
         self.activation = "silu"
         self.act = nn.SiLU()
@@ -311,13 +313,6 @@ def __init__(
         # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
         inv_dt = dt + torch.log(-torch.expm1(-dt))
         self.dt_bias = nn.Parameter(inv_dt)
-        # Our initialization would set all Linear.bias to zero,
-        # need to mark this one as _no_reinit
-        self.dt_bias._no_reinit = True
-        # Just to be explicit. Without this we already don't
-        # put wd on dt_bias because of the check
-        # name.endswith("bias") in param_grouping.py
-        self.dt_bias._no_weight_decay = True
         setattr(self.dt_bias, "tensor_model_parallel", True)
 
         # A parameter
@@ -327,7 +322,6 @@ def __init__(
         ).uniform_(*A_init_range)
         A_log = torch.log(A)  # Keep A_log in fp32
         self.A_log = nn.Parameter(A_log)
-        self.A_log._no_weight_decay = True
         setattr(self.A_log, "tensor_model_parallel", True)
 
         # D "skip" parameter
@@ -337,7 +331,6 @@ def __init__(
                 device=torch.cuda.current_device(),
             )
         )  # Keep in fp32
-        self.D._no_weight_decay = True
         setattr(self.D, "tensor_model_parallel", True)
 
         if self.rmsnorm:
@@ -350,6 +343,7 @@ def __init__(
                     device=torch.cuda.current_device(),
                     dtype=config.params_dtype,
                 )
+            setattr(self.norm.weight, "tensor_model_parallel", True)
 
         # Assume sequence parallelism: input is partitioned along d_inner and
         # output is partitioned along the sequence dimension
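Note: the mixer changes carry the tensor-parallel part of the fix. The gated RMSNorm weight is sharded across TP ranks, so it is now tagged with a `tensor_model_parallel` attribute like the other sharded parameters, while the conv1d fallback init moves into the mixer itself and the `_no_reinit` / `_no_weight_decay` markers become unnecessary once the global re-init pass in mamba_block.py is gone. A minimal sketch of the tagging pattern follows; only the `setattr` line mirrors the diff, and `is_tp_partitioned` is a hypothetical helper, not Megatron's API.

import torch
import torch.nn as nn

# Per-rank shard of a norm weight, standing in for self.norm.weight in the mixer.
norm_weight = nn.Parameter(torch.ones(128))
# Same pattern as the added line in mamba_mixer.py: mark the parameter as TP-sharded.
setattr(norm_weight, "tensor_model_parallel", True)


def is_tp_partitioned(param: torch.Tensor) -> bool:
    """Hypothetical check mirroring how downstream code typically consumes the flag."""
    return getattr(param, "tensor_model_parallel", False)


print(is_tp_partitioned(norm_weight))  # True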

tests/unit_tests/inference/engines/test_dynamic_engine.py (8 additions, 7 deletions)

@@ -347,6 +347,7 @@ def _build_test_env(cls, test_config):
             fp8="hybrid" if test_config.fp8 else None,
             fp8_recipe="tensorwise" if test_config.fp8 else None,
             cuda_graph_scope=test_config.cuda_graph_scope,
+            is_hybrid_model=True,  # Needs to be set for correct out_proj init
         )
 
         # Mamba model.
@@ -557,13 +558,13 @@ def test_simple(self, model_provider, num_cuda_graphs, cuda_graph_scope) -> None
         ]
 
         mamba_expected_generated_tokens = [
-            [74, 72, 83, 59, 1, 70, 15, 89, 30, 52, 82, 70, 64, 16, 83, 5],
-            [25, 54, 42, 57, 33, 64, 60, 13, 28, 74, 8, 4, 56, 68, 87, 82],
-            [31, 55, 77, 25, 96, 13, 32, 49, 40, 54, 73, 10, 50, 2, 64, 96],
-            [72, 80, 35, 72, 77, 85, 98, 36, 4, 97, 37, 46, 79, 95, 83, 85],
-            [8, 80, 56, 4, 87, 1, 15, 98, 85, 7, 31, 38, 91, 28, 18, 80],
-            [9, 94, 48, 60, 87, 57, 25, 76, 91, 34, 69, 86, 73, 24, 63, 97],
-            [17, 5, 62, 66, 15, 52, 32, 75, 66, 18, 69, 5, 67, 37, 94, 51],
+            [74, 72, 9, 59, 1, 70, 15, 89, 30, 52, 82, 70, 64, 16, 83, 5],
+            [25, 54, 28, 14, 87, 27, 60, 92, 28, 74, 8, 63, 60, 68, 87, 82],
+            [31, 21, 87, 25, 96, 13, 32, 49, 40, 54, 55, 68, 73, 2, 64, 96],
+            [72, 80, 35, 72, 77, 85, 98, 36, 4, 97, 37, 46, 79, 95, 83, 25],
+            [8, 80, 56, 4, 87, 1, 43, 98, 85, 7, 50, 38, 24, 28, 18, 80],
+            [9, 94, 36, 16, 87, 57, 25, 76, 64, 92, 47, 86, 73, 72, 71, 97],
+            [17, 5, 62, 66, 15, 52, 32, 75, 66, 18, 90, 14, 67, 37, 94, 33],
             [],
         ]
 
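Note: the test env now builds its Mamba config with `is_hybrid_model=True` so out_proj receives the hybrid-model initialization, and the expected greedy token IDs are regenerated because the initialization (and hence the random weights) changed. A minimal sketch of the config change follows, assuming these keyword arguments land on `TransformerConfig` as the surrounding fp8/cuda_graph_scope arguments suggest; the other field values are placeholders, not the test's real setup.

from megatron.core.transformer import TransformerConfig

# Assumed construction for illustration only; field values are placeholders.
config = TransformerConfig(
    num_layers=2,
    hidden_size=128,
    num_attention_heads=4,
    is_hybrid_model=True,  # needed for correct out_proj init, per the added test line
)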
