
Commit ae99e0a

[Dev release cherry pick] Nemotron nano v2 vl + cherrypick #2137 (#2079)

Signed-off-by: Chen Cui <[email protected]>

Parent: b1c616c

3 files changed: 65 additions, 55 deletions

megatron/core/models/multimodal/llava_model.py
2 additions & 1 deletion

@@ -124,6 +124,7 @@ def __init__(
         max_num_tiles: int = 0,
         tokenizer_type: str = "",
         vp_stage: Optional[int] = None,
+        use_vision_backbone_fp8_arch: bool = False,
     ) -> None:
         super().__init__(config=language_transformer_config)

@@ -295,7 +296,7 @@ def __init__(
         ln_post_impl = None
         use_mask_token = False

-        if vision_transformer_config.fp8:
+        if vision_transformer_config.fp8 or use_vision_backbone_fp8_arch:
             # FP8 padding for final sequence length to be a multiple of 16 or 32.
             class_token_len = 32 if vision_transformer_config.fp8_recipe == "mxfp8" else 16
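
Note on the change above: the FP8 branch pads the class-token count to 16 (or 32 under the mxfp8 recipe) so that the vision sequence length comes out as a multiple of 16/32, as the in-code comment says; the new use_vision_backbone_fp8_arch flag forces that same padding even when FP8 itself is disabled, keeping the vision backbone's shapes compatible with FP8 runs. A standalone sketch of the arithmetic, not the Megatron implementation; the 448-pixel image and 14-pixel patch size are assumed example values:

    # Sketch of the sequence-length padding (illustrative values, not Megatron code).
    def vision_seq_len(img_size: int, patch_size: int, class_token_len: int) -> int:
        num_patches = (img_size // patch_size) ** 2
        return num_patches + class_token_len

    # With a single class token the sequence length is usually not FP8-friendly:
    print(vision_seq_len(448, 14, class_token_len=1))   # 1025 -> not a multiple of 16
    # Padding the class tokens to 16 (or 32 for mxfp8) fixes that, provided the
    # patch grid itself is already a multiple of 16/32 (true for the 32x32 grid here):
    print(vision_seq_len(448, 14, class_token_len=16))  # 1040 = 65 * 16
    print(vision_seq_len(448, 14, class_token_len=32))  # 1056 = 33 * 32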

megatron/core/ssm/mamba_block.py
7 additions & 6 deletions

@@ -203,13 +203,14 @@ def __init__(
             eps=self.config.layernorm_epsilon,
         )

-        self.apply(
-            partial(
-                _init_weights,
-                n_layer=self.config.num_layers,
-                initializer_range=self.config.init_method_std,
+        if self.config.perform_initialization:
+            self.apply(
+                partial(
+                    _init_weights,
+                    n_layer=self.config.num_layers,
+                    initializer_range=self.config.init_method_std,
+                )
             )
-        )

     def _select_layers_for_pipeline_parallel(self, layer_type_list):
         num_layers_per_pipeline_rank = self.config.num_layers // self.pp_group.size()
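
The mamba_block.py hunk guards the recursive _init_weights pass behind config.perform_initialization, so the extra initialization work can be skipped when the weights are about to be overwritten anyway, e.g. by a checkpoint load. A self-contained sketch of the same pattern on a toy module; ToyConfig, ToyBlock, and this _init_weights are illustrative stand-ins, not Megatron APIs:

    from dataclasses import dataclass
    from functools import partial

    import torch.nn as nn


    @dataclass
    class ToyConfig:
        num_layers: int = 2
        init_method_std: float = 0.02
        perform_initialization: bool = True


    def _init_weights(module, n_layer, initializer_range):
        # Re-initialize only Linear weights; n_layer is accepted to mirror the
        # real signature but is unused in this toy version.
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)


    class ToyBlock(nn.Module):
        def __init__(self, config: ToyConfig):
            super().__init__()
            self.config = config
            self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(config.num_layers)])
            # Only run the custom init when the config asks for it.
            if self.config.perform_initialization:
                self.apply(
                    partial(
                        _init_weights,
                        n_layer=self.config.num_layers,
                        initializer_range=self.config.init_method_std,
                    )
                )


    # With perform_initialization=False the module keeps PyTorch's default init,
    # the cheap path when a checkpoint will be loaded immediately afterwards.
    block = ToyBlock(ToyConfig(perform_initialization=False))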

megatron/core/ssm/mamba_mixer.py
56 additions & 48 deletions

@@ -268,60 +268,68 @@ def __init__(
         )

         conv_dim = self.d_inner_local_tp + 2 * self.ngroups_local_tp * self.d_state  # x B C
-        with get_cuda_rng_tracker().fork():
-            # weight shape: [conv_dim, 1, d_conv]
-            # bias shape: [conv_dim]
-            self.conv1d = nn.Conv1d(
-                in_channels=conv_dim,
-                out_channels=conv_dim,
-                bias=conv_bias,
-                kernel_size=d_conv,
-                groups=conv_dim,
-                padding=d_conv - 1,
-                device=torch.cuda.current_device(),
-                dtype=config.params_dtype,
-            )
-            setattr(self.conv1d.weight, "tensor_model_parallel", True)
-            setattr(self.conv1d.bias, "tensor_model_parallel", True)
+        # weight shape: [conv_dim, 1, d_conv]
+        # bias shape: [conv_dim]
+        self.conv1d = nn.Conv1d(
+            in_channels=conv_dim,
+            out_channels=conv_dim,
+            bias=conv_bias,
+            kernel_size=d_conv,
+            groups=conv_dim,
+            padding=d_conv - 1,
+            device=torch.cuda.current_device(),
+            dtype=config.params_dtype,
+        )
+        setattr(self.conv1d.weight, "tensor_model_parallel", True)
+        setattr(self.conv1d.bias, "tensor_model_parallel", True)

-            if self.conv_init is not None:
+        if self.config.perform_initialization and self.conv_init is not None:
+            with get_cuda_rng_tracker().fork():
                 nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)

         self.activation = "silu"
         self.act = nn.SiLU()

-        with get_cuda_rng_tracker().fork():
-            # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
-            dt = torch.exp(
-                torch.rand(
-                    self.nheads_local_tp,
-                    device=torch.cuda.current_device(),
-                    dtype=config.params_dtype,
-                )
-                * (math.log(dt_max) - math.log(dt_min))
-                + math.log(dt_min)
-            ).clamp(min=dt_init_floor)
-            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
-            inv_dt = dt + torch.log(-torch.expm1(-dt))
-            self.dt_bias = nn.Parameter(inv_dt)
-            # Our initialization would set all Linear.bias to zero,
-            # need to mark this one as _no_reinit
-            self.dt_bias._no_reinit = True
-            # Just to be explicit. Without this we already don't
-            # put wd on dt_bias because of the check
-            # name.endswith("bias") in param_grouping.py
-            self.dt_bias._no_weight_decay = True
-            setattr(self.dt_bias, "tensor_model_parallel", True)
-
-            # A parameter
-            assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
-            A = torch.empty(
-                self.nheads_local_tp, dtype=torch.float32, device=torch.cuda.current_device()
-            ).uniform_(*A_init_range)
-            A_log = torch.log(A)  # Keep A_log in fp32
-            self.A_log = nn.Parameter(A_log)
-            self.A_log._no_weight_decay = True
-            setattr(self.A_log, "tensor_model_parallel", True)
+        if self.config.perform_initialization:
+            with get_cuda_rng_tracker().fork():
+                # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
+                dt = torch.exp(
+                    torch.rand(
+                        self.nheads_local_tp,
+                        device=torch.cuda.current_device(),
+                        dtype=config.params_dtype,
+                    )
+                    * (math.log(dt_max) - math.log(dt_min))
+                    + math.log(dt_min)
+                ).clamp(min=dt_init_floor)
+                # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+                inv_dt = dt + torch.log(-torch.expm1(-dt))
+        else:
+            inv_dt = torch.empty(
+                self.nheads_local_tp, device=torch.cuda.current_device(), dtype=config.params_dtype
+            )
+
+        self.dt_bias = nn.Parameter(inv_dt)
+        # Our initialization would set all Linear.bias to zero,
+        # need to mark this one as _no_reinit
+        self.dt_bias._no_reinit = True
+        # Just to be explicit. Without this we already don't
+        # put wd on dt_bias because of the check
+        # name.endswith("bias") in param_grouping.py
+        self.dt_bias._no_weight_decay = True
+        setattr(self.dt_bias, "tensor_model_parallel", True)
+
+        # A parameter
+        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
+        A = torch.empty(
+            self.nheads_local_tp, dtype=torch.float32, device=torch.cuda.current_device()
+        )
+        if self.config.perform_initialization:
+            A = A.uniform_(*A_init_range)
+        A_log = torch.log(A)  # Keep A_log in fp32
+        self.A_log = nn.Parameter(A_log)
+        self.A_log._no_weight_decay = True
+        setattr(self.A_log, "tensor_model_parallel", True)

         # D "skip" parameter
         self.D = nn.Parameter(
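
For reference, the dt_bias block that the hunk above puts behind perform_initialization encodes a specific transform: dt is sampled log-uniformly in [dt_min, dt_max], and the parameter stores the softplus inverse so that F.softplus(dt_bias) recovers dt at run time. A standalone sketch with assumed toy values (nheads=4, dt_min=1e-3, dt_max=1e-1, dt_init_floor=1e-4); the real values come from the MambaMixer constructor arguments:

    import math

    import torch
    import torch.nn.functional as F

    # Assumed toy values, not the Megatron defaults.
    nheads, dt_min, dt_max, dt_init_floor = 4, 1e-3, 1e-1, 1e-4

    # Log-uniform sample in [dt_min, dt_max], floored at dt_init_floor.
    dt = torch.exp(
        torch.rand(nheads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
    ).clamp(min=dt_init_floor)

    # Inverse of softplus: if y = log(1 + exp(x)) then x = y + log(-expm1(-y)).
    inv_dt = dt + torch.log(-torch.expm1(-dt))

    # softplus undoes the transform, so the stored bias reproduces the sampled dt.
    assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-6)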
