diff --git a/megatron/core/models/multimodal/llava_model.py b/megatron/core/models/multimodal/llava_model.py
index 2ac2657c1c..dae9a02b78 100644
--- a/megatron/core/models/multimodal/llava_model.py
+++ b/megatron/core/models/multimodal/llava_model.py
@@ -124,6 +124,7 @@ def __init__(
         max_num_tiles: int = 0,
         tokenizer_type: str = "",
         vp_stage: Optional[int] = None,
+        use_vision_backbone_fp8_arch: bool = False,
     ) -> None:
         super().__init__(config=language_transformer_config)
 
@@ -295,7 +296,7 @@ def __init__(
 
             ln_post_impl = None
             use_mask_token = False
-            if vision_transformer_config.fp8:
+            if vision_transformer_config.fp8 or use_vision_backbone_fp8_arch:
                 # FP8 padding for final sequence length to be a multiple of 16 or 32.
                 class_token_len = 32 if vision_transformer_config.fp8_recipe == "mxfp8" else 16
 
diff --git a/megatron/core/ssm/mamba_block.py b/megatron/core/ssm/mamba_block.py
index 7d8ca74c8f..d405c5de4a 100644
--- a/megatron/core/ssm/mamba_block.py
+++ b/megatron/core/ssm/mamba_block.py
@@ -203,13 +203,14 @@ def __init__(
             eps=self.config.layernorm_epsilon,
         )
 
-        self.apply(
-            partial(
-                _init_weights,
-                n_layer=self.config.num_layers,
-                initializer_range=self.config.init_method_std,
+        if self.config.perform_initialization:
+            self.apply(
+                partial(
+                    _init_weights,
+                    n_layer=self.config.num_layers,
+                    initializer_range=self.config.init_method_std,
+                )
             )
-        )
 
     def _select_layers_for_pipeline_parallel(self, layer_type_list):
         num_layers_per_pipeline_rank = self.config.num_layers // self.pp_group.size()
diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py
index 895792ff05..15763be8f1 100644
--- a/megatron/core/ssm/mamba_mixer.py
+++ b/megatron/core/ssm/mamba_mixer.py
@@ -289,7 +289,7 @@ def __init__(
             setattr(self.conv1d.weight, "tensor_model_parallel", True)
             setattr(self.conv1d.bias, "tensor_model_parallel", True)
 
-            if self.conv_init is not None:
+            if self.config.perform_initialization and self.conv_init is not None:
                 nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
 
         self.activation = "silu"
@@ -322,7 +322,9 @@ def __init__(
         assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
         A = torch.empty(
             self.nheads_local_tp, dtype=torch.float32, device=torch.cuda.current_device()
-        ).uniform_(*A_init_range)
+        )
+        if self.config.perform_initialization:
+            A = A.uniform_(*A_init_range)
         A_log = torch.log(A)  # Keep A_log in fp32
         self.A_log = nn.Parameter(A_log)
         self.A_log._no_weight_decay = True
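
For context, a minimal sketch of the initialization-guard pattern the `mamba_block.py` and `mamba_mixer.py` changes apply. This is plain PyTorch, not Megatron code; the module name, flag name, and defaults below are assumptions for illustration. The idea is to keep parameter allocation unconditional (so shapes and attributes still exist) while skipping the random fill when `perform_initialization` is false, e.g. because a checkpoint load will overwrite the values anyway.

```python
import torch
import torch.nn as nn


class TinyConvMixer(nn.Module):
    """Illustrative module; names and defaults here are assumptions."""

    def __init__(self, channels: int = 8, d_conv: int = 4,
                 conv_init: float = 0.02, perform_initialization: bool = True):
        super().__init__()
        # Parameter allocation is unconditional, as in the patch.
        self.conv1d = nn.Conv1d(channels, channels, kernel_size=d_conv,
                                groups=channels, padding=d_conv - 1)

        # Mirror mamba_mixer.py: only run the random init when requested.
        if perform_initialization and conv_init is not None:
            nn.init.uniform_(self.conv1d.weight, -conv_init, conv_init)

        # Mirror the A_log change: allocate the tensor either way so the
        # parameter exists with the right shape, but fill it with random
        # values only when initialization is requested; otherwise the values
        # are placeholders that a checkpoint load is expected to replace.
        A = torch.empty(channels, dtype=torch.float32)
        if perform_initialization:
            A = A.uniform_(1.0, 16.0)
        self.A_log = nn.Parameter(torch.log(A))


# Usage: skip init when a state dict is restored immediately afterwards.
model = TinyConvMixer(perform_initialization=False)
# model.load_state_dict(torch.load("checkpoint.pt"))  # hypothetical path
```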