From 6118e57fc12ab6b4d087cc218fd56a107b7956ce Mon Sep 17 00:00:00 2001 From: "S. M. Mohiuddin Khan Shiam" Date: Mon, 9 Jun 2025 00:39:10 +0600 Subject: [PATCH] Fixes and Enhancements for Mamba Inference and Reference Implementations This pull request addresses several bugs and limitations within the Mamba codebase, primarily aimed at improving inference robustness in the Mamba2 module and increasing the accuracy of reference implementations. Key changes include: In mamba_ssm/modules/mamba2.py: Resolved an issue in _get_states_from_cache to correctly handle dynamic batch sizes during inference, ensuring proper state re-initialization when batch sizes change. Removed the batch == 1 assertion in the forward method for variable-length sequence inference, enabling batched processing for these inputs. Updated the fallback path in the step method to support ngroups > 1, allowing grouped SSM inference even if Triton kernels are not available. In mamba_ssm/ops/selective_scan_interface.py: Added optional RMS Normalization for B, C, and delta tensors within mamba_inner_ref to better match the main MambaInnerFn's behavior and improve numerical consistency. Corrected a shape comment in selective_scan_ref for clarity. In mamba_ssm/models/mixer_seq_simple.py: Removed a redundant comment in the _init_weights function. In mamba_ssm/utils/hf.py: Addressed a bug in load_state_dict_hf to ensure correct dtype conversion and device placement when loading Hugging Face model weights. These modifications enhance the stability, flexibility, and correctness of the Mamba library. --- mamba_ssm/models/mixer_seq_simple.py | 72 +++++++++------ mamba_ssm/modules/mamba2.py | 105 ++++++++++++++++++---- mamba_ssm/ops/selective_scan_interface.py | 31 +++++-- mamba_ssm/utils/hf.py | 2 +- setup.py | 33 ++++--- 5 files changed, 175 insertions(+), 68 deletions(-) diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py index fae2257a..9f7187a6 100644 --- a/mamba_ssm/models/mixer_seq_simple.py +++ b/mamba_ssm/models/mixer_seq_simple.py @@ -26,20 +26,25 @@ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None +from typing import Dict, List, Optional, Tuple, Union +import torch +from torch import nn, Tensor + + def create_block( - d_model, - d_intermediate, - ssm_cfg=None, - attn_layer_idx=None, - attn_cfg=None, - norm_epsilon=1e-5, - rms_norm=False, - residual_in_fp32=False, - fused_add_norm=False, - layer_idx=None, - device=None, - dtype=None, -): + d_model: int, + d_intermediate: int, + ssm_cfg: Optional[Dict] = None, + attn_layer_idx: Optional[List[int]] = None, + attn_cfg: Optional[Dict] = None, + norm_epsilon: float = 1e-5, + rms_norm: bool = False, + residual_in_fp32: bool = False, + fused_add_norm: bool = False, + layer_idx: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, +) -> nn.Module: if ssm_cfg is None: ssm_cfg = {} if attn_layer_idx is None: @@ -88,7 +93,7 @@ def _init_weights( n_layer, initializer_range=0.02, # Now only used for embedding layer. rescale_prenorm_residual=True, - n_residuals_per_layer=1, # Change to 2 if we have MLP + n_residuals_per_layer=1, ): if isinstance(module, nn.Linear): if module.bias is not None: @@ -122,16 +127,16 @@ def __init__( n_layer: int, d_intermediate: int, vocab_size: int, - ssm_cfg=None, - attn_layer_idx=None, - attn_cfg=None, + ssm_cfg: Optional[Dict] = None, + attn_layer_idx: Optional[List[int]] = None, + attn_cfg: Optional[Dict] = None, norm_epsilon: float = 1e-5, rms_norm: bool = False, - initializer_cfg=None, - fused_add_norm=False, - residual_in_fp32=False, - device=None, - dtype=None, + initializer_cfg: Optional[Dict] = None, + fused_add_norm: bool = False, + residual_in_fp32: bool = False, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -187,7 +192,12 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs) for i, layer in enumerate(self.layers) } - def forward(self, input_ids, inference_params=None, **mixer_kwargs): + def forward( + self, + input_ids: Tensor, + inference_params = None, + **mixer_kwargs + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: hidden_states = self.embedding(input_ids) residual = None for layer in self.layers: @@ -213,13 +223,12 @@ def forward(self, input_ids, inference_params=None, **mixer_kwargs): class MambaLMHeadModel(nn.Module, GenerationMixin): - def __init__( self, config: MambaConfig, - initializer_cfg=None, - device=None, - dtype=None, + initializer_cfg: Optional[Dict] = None, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, ) -> None: self.config = config d_model = config.d_model @@ -271,7 +280,14 @@ def tie_weights(self): def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs): + def forward( + self, + input_ids: Tensor, + position_ids: Optional[Tensor] = None, + inference_params = None, + num_last_tokens: int = 0, + **mixer_kwargs + ) -> Union[Tensor, Tuple[Tensor, Dict[str, Tensor]]]: """ "position_ids" is just to be compatible with Transformer generation. We don't use it. num_last_tokens: if > 0, only return the logits for the last n tokens diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py index 36b16d47..6fd8d893 100644 --- a/mamba_ssm/modules/mamba2.py +++ b/mamba_ssm/modules/mamba2.py @@ -10,18 +10,31 @@ try: from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -except ImportError: - causal_conv1d_fn, causal_conv1d_update = None, None +except ImportError as e: + raise ImportError( + "causal_conv1d package not found. Please install it with: " + "pip install causal-conv1d>=1.4.0" + ) from e try: from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states -except ImportError: +except ImportError as e: causal_conv1d_varlen_states = None + import warnings + warnings.warn( + "causal_conv1d_varlen module not found. Variable length sequences will not be supported. " + "Install the latest causal_conv1d package for full functionality." + ) try: from mamba_ssm.ops.triton.selective_state_update import selective_state_update -except ImportError: +except ImportError as e: selective_state_update = None + import warnings + warnings.warn( + "selective_state_update module not found. Performance may be degraded. " + "Make sure to install with the 'triton' extra: pip install mamba-ssm[triton]" + ) from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated @@ -221,9 +234,12 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W) else: assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package" - assert batch == 1, "varlen inference only supports batch dimension 1" + # The 'batch' variable here might be misleading when cu_seqlens is used. + # The actual number of sequences is cu_seqlens.shape[0] - 1. + # conv_state is already shaped (inference_batch, ...). + # xBC should be (total_tokens, features) when cu_seqlens is present. conv_varlen_states = causal_conv1d_varlen_states( - xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1] + xBC, cu_seqlens, state_len=conv_state.shape[-1] ) conv_state.copy_(conv_varlen_states) assert self.activation in ["silu", "swish"] @@ -308,16 +324,55 @@ def step(self, hidden_states, conv_state, ssm_state): # SSM step if selective_state_update is None: - assert self.ngroups == 1, "Only support ngroups=1 for this inference code path" + assert self.nheads % self.ngroups == 0, "nheads must be divisible by ngroups for PyTorch step fallback" + k = self.nheads // self.ngroups + # Discretize A and B + # dt is already (batch, nheads) from xBC split and projection dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads) - dA = torch.exp(dt * A) # (batch, nheads) - x = rearrange(x, "b (h p) -> b h p", p=self.headdim) - dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x) - ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx) - y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C) - y = y + rearrange(self.D.to(dtype), "h -> h 1") * x - y = rearrange(y, "b h p -> b (h p)") + # A is (nheads,) + + # Reshape for grouped operations + # x: (B, d_ssm) -> (B, ngroups, k, headdim) + x_r = rearrange(x, "b (g k p) -> b g k p", g=self.ngroups, k=k, p=self.headdim) + # dt: (B, nheads) -> (B, ngroups, k) + dt_r = rearrange(dt, "b (g k) -> b g k", g=self.ngroups, k=k) + # A: (nheads,) -> (ngroups, k) + A_r = rearrange(A, "(g k) -> g k", g=self.ngroups, k=k) + # dA: (B, ngroups, k) + dA_r = torch.exp(dt_r * A_r.unsqueeze(0)) + + # B: (B, ngroups * d_state) -> (B, ngroups, d_state) + B_r = rearrange(B, "b (g n) -> b g n", g=self.ngroups) + # C: (B, ngroups * d_state) -> (B, ngroups, d_state) + C_r = rearrange(C, "b (g n) -> b g n", g=self.ngroups) + # ssm_state: (B, nheads, headdim, d_state) -> (B, ngroups, k, headdim, d_state) + ssm_state_r = rearrange(ssm_state, "b (g k) p n -> b g k p n", g=self.ngroups, k=k) + + # SSM recurrence: h_new = dA * h_old + dB * x + # dB = dt * B + # dB_scaled_by_dt: (B, ngroups, k, d_state) + dB_scaled_by_dt = torch.einsum("bgk,bgn->bgkn", dt_r, B_r) + # dBx: (B, ngroups, k, headdim, d_state) + dBx = torch.einsum("bgkp,bgkn->bgkpn", x_r, dB_scaled_by_dt) + + ssm_state_new_r = dA_r.unsqueeze(-1).unsqueeze(-1) * ssm_state_r + dBx + ssm_state.copy_(rearrange(ssm_state_new_r, "b g k p n -> b (g k) p n")) + + # Output: y = C * h_new + D * x + # y_interim: (B, ngroups, k, headdim) + y_interim = torch.einsum("bgkpn,bgn->bgkp", ssm_state_new_r.to(dtype), C_r) + + D_param = self.D.to(dtype) + if self.D_has_hdim: # D is (d_ssm) = (nheads * headdim) + D_r = rearrange(D_param, "(g k p) -> g k p", g=self.ngroups, k=k, p=self.headdim) + y_r = y_interim + D_r.unsqueeze(0) * x_r + else: # D is (nheads) + D_r = rearrange(D_param, "(g k) -> g k", g=self.ngroups, k=k) + y_r = y_interim + D_r.unsqueeze(0).unsqueeze(-1) * x_r + + y = rearrange(y_r, "b g k p -> b (g k p)") # (B, d_ssm) + if not self.rmsnorm: y = y * self.act(z) # (B D) else: @@ -376,8 +431,26 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) else: conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] - # TODO: What if batch size changes between generation, and we reuse the same states? - if initialize_states: + # Handle batch size changes or explicit initialization + if initialize_states or conv_state.shape[0] != batch_size or ssm_state.shape[0] != batch_size: + # Re-initialize states if batch size changed or if explicitly requested + conv_state = torch.zeros( + batch_size, + self.conv1d.weight.shape[0], # out_channels + self.d_conv, # kernel_size + device=self.conv1d.weight.device, + dtype=self.conv1d.weight.dtype, + ) + ssm_state = torch.zeros( + batch_size, + self.nheads, + self.headdim, + self.d_state, + device=self.in_proj.weight.device, + dtype=self.in_proj.weight.dtype, + ) + inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) + elif initialize_states: # Original condition if batch sizes matched but re-init was true conv_state.zero_() ssm_state.zero_() return conv_state, ssm_state diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index a41f1359..162ca3b7 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -173,7 +173,7 @@ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta if y.is_complex(): y = y.real * 2 ys.append(y) - y = torch.stack(ys, dim=2) # (batch dim L) + y = torch.stack(ys, dim=2) # (batch, dim, L) out = y if D is None else y + u * rearrange(D, "d -> d 1") if z is not None: out = out * F.silu(z) @@ -385,7 +385,8 @@ def mamba_inner_ref( xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, - C_proj_bias=None, delta_softplus=True + C_proj_bias=None, delta_softplus=True, + b_rms_weight=None, c_rms_weight=None, dt_rms_weight=None, b_c_dt_rms_eps=1e-6 ): assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d." L = xz.shape[-1] @@ -399,21 +400,39 @@ def mamba_inner_ref( x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d) delta = delta_proj_weight @ x_dbl[:, :delta_rank].t() delta = rearrange(delta, "d (b l) -> b d l", l=L) + + if dt_rms_weight is not None: + delta_reshaped = rearrange(delta, "b d l -> (b l) d").contiguous() + delta_reshaped = rms_norm_forward(delta_reshaped, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps) + delta = rearrange(delta_reshaped, "(b l) d -> b d l", l=L).contiguous() + if B is None: # variable B B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d) if B_proj_bias is not None: B = B + B_proj_bias.to(dtype=B.dtype) if not A.is_complex(): - B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() + B = rearrange(B, "(b l) dstate -> b dstate l", l=L) + else: + B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2) + if b_rms_weight is not None: + B_reshaped = rearrange(B, "b dstate l -> (b l) dstate").contiguous() + B_reshaped = rms_norm_forward(B_reshaped, b_rms_weight, bias=None, eps=b_c_dt_rms_eps) + B = rearrange(B_reshaped, "(b l) dstate -> b dstate l", l=L).contiguous() else: - B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() + B = B.contiguous() # Ensure contiguity if not already handled by RMSNorm path if C is None: # variable B C = x_dbl[:, -d_state:] # (bl d) if C_proj_bias is not None: C = C + C_proj_bias.to(dtype=C.dtype) if not A.is_complex(): - C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() + C = rearrange(C, "(b l) dstate -> b dstate l", l=L) + else: + C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2) + if c_rms_weight is not None: + C_reshaped = rearrange(C, "b dstate l -> (b l) dstate").contiguous() + C_reshaped = rms_norm_forward(C_reshaped, c_rms_weight, bias=None, eps=b_c_dt_rms_eps) + C = rearrange(C_reshaped, "(b l) dstate -> b dstate l", l=L).contiguous() else: - C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() + C = C.contiguous() # Ensure contiguity if not already handled by RMSNorm path y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) diff --git a/mamba_ssm/utils/hf.py b/mamba_ssm/utils/hf.py index 0d7555ac..778bd185 100644 --- a/mamba_ssm/utils/hf.py +++ b/mamba_ssm/utils/hf.py @@ -15,7 +15,7 @@ def load_state_dict_hf(model_name, device=None, dtype=None): # If not fp32, then we don't want to load directly to the GPU mapped_device = "cpu" if dtype not in [torch.float32, None] else device resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) - return torch.load(resolved_archive_file, map_location=mapped_device) + state_dict = torch.load(resolved_archive_file, map_location=mapped_device) # Convert dtype before moving to GPU to save memory if dtype is not None: state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} diff --git a/setup.py b/setup.py index f61ca90d..26bd83f3 100755 --- a/setup.py +++ b/setup.py @@ -99,27 +99,28 @@ def get_torch_hip_version(): return None -def check_if_hip_home_none(global_option: str) -> None: - - if HIP_HOME is not None: - return - # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary - # in that case. +def check_if_hip_home_none(global_option: str): + if HIP_HOME is None: + raise RuntimeError( + f"{global_option} was requested, but the ROCm/HIP installation is incomplete. " + 'Please make sure ROCm is properly installed and HIP_HOME environment variable is set.\n' + 'On Ubuntu, you may need to install: rocm-libs hipcc hiprt hipcub rocprim rocrand rocthrust rocblas hipblas rocsolver hipsparse rocsparse hipfft rocfft rocthrust rocrand' + ) warnings.warn( f"{global_option} was requested, but hipcc was not found. Are you sure your environment has hipcc available?" ) def check_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary - # in that case. - warnings.warn( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) + if CUDA_HOME is None: + raise RuntimeError( + f"{global_option} was requested, but CUDA installation was not found. " + 'Please ensure CUDA is properly installed and the CUDA_HOME environment variable is set.\n' + 'Common solutions include:\n' + '1. Install CUDA from NVIDIA: https://developer.nvidia.com/cuda-downloads\n' + '2. Set CUDA_HOME to your CUDA installation directory (e.g., /usr/local/cuda-11.8)\n' + '3. Add CUDA to your PATH: export PATH=$PATH:$CUDA_HOME/bin' + ) def append_nvcc_threads(nvcc_extra_args): @@ -158,8 +159,6 @@ def append_nvcc_threads(nvcc_extra_args): UserWarning ) - cc_flag.append("-DBUILD_PYTHON_PACKAGE") - else: check_if_cuda_home_none(PACKAGE_NAME) # Check, if CUDA11 is installed for compute capability 8.0