Tensor-parallel SSM #333

Open
wants to merge 33 commits into concatenated_dim from tp_mamba
Commits (33)
82eed2b
TP mamba
jlamypoirier Jul 21, 2025
4e310c7
TP mamba
jlamypoirier Jul 22, 2025
3cc4118
fix
jlamypoirier Jul 22, 2025
9f7f75c
fix
jlamypoirier Jul 22, 2025
4054e04
fixes
jlamypoirier Jul 23, 2025
0014cc6
fix
jlamypoirier Jul 23, 2025
47ad548
fixes
jlamypoirier Jul 23, 2025
6a074fa
fixes
jlamypoirier Jul 23, 2025
d66651f
Update external
jlamypoirier Jul 23, 2025
5006328
Merge branch 'main' into tp_mamba
jlamypoirier Jul 24, 2025
13176bd
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
73f591f
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
1feccc8
stuff
jlamypoirier Jul 24, 2025
e528b50
misc
jlamypoirier Jul 24, 2025
bb4dcd9
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
c1b7f44
misc
jlamypoirier Jul 24, 2025
051bb07
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
0a9ff25
misc
jlamypoirier Jul 24, 2025
e7d9636
Parallel discrete mamba 2
jlamypoirier Jul 24, 2025
c14b764
Mamba 2, misc
jlamypoirier Jul 25, 2025
b605bd2
doc
jlamypoirier Jul 25, 2025
0a3e2a7
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 28, 2025
2e6d082
fixes
jlamypoirier Jul 28, 2025
b6c8613
misc
jlamypoirier Jul 28, 2025
acdfab1
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 28, 2025
93e4c94
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Jul 28, 2025
c41efc2
doc
jlamypoirier Jul 28, 2025
0b8bd5d
cleanup
jlamypoirier Jul 28, 2025
c0f1597
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Jul 29, 2025
cef7c15
fix
jlamypoirier Jul 30, 2025
defd6e0
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Aug 8, 2025
8abf258
fixes
jlamypoirier Aug 8, 2025
fd3307d
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Aug 12, 2025
236 changes: 175 additions & 61 deletions fast_llm/layers/ssm/config.py
@@ -1,29 +1,37 @@
import enum
import typing

from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none
from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace
from fast_llm.engine.distributed.config import DistributedDimNames
from fast_llm.functional.config import ActivationType
from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig
from fast_llm.utils import Assert
from fast_llm.utils import Assert, div

if typing.TYPE_CHECKING:
from fast_llm.tensor import Initializer


class SSMDimNames:
model_dim = "model_dim" # Model dimension (D)
state_dim = "state_dim" # State dimension (N)
conv_dim = "conv_dim" # Dimension of the conv1d input in mamba layers
inner_dim = "inner_dim" # Inner dimension after expansion
dt_rank = "dt_rank" # Rank of Ξ”
inner_proj_mamba = "inner_proj_mamba" # Inner projection dimension for mamba
inner_proj_discrete_mamba2 = "inner_proj_discrete_mamba2" # Inner projection dimension for discrete mamba2
inner_proj_mamba2 = "inner_proj_mamba2" # Inner projection dimension for mamba2
x_proj_dim = "x_proj_dim" # X projection dimension
head_dim = "head_dim" # Dimension of the mamba2 head (P)
conv_kernel_size = "conv_kernel_size" # Kernel size of the conv1d in mamba layers
qk_heads = "qk_heads" # Number of QK heads
v_heads = "v_heads" # Number of V heads

# Mamba 2
x_proj_dim_2 = "x_proj_dim_2" # d_xb
c_heads = "c_heads"
# TODO: Use separate tensor space for different mixers so there is no risk of name conflict.
state = "ssm_state" # State dimension (N), aka head size / num channels
head_dim = "ssm_head_dim"
head_groups = "ssm_head_groups"
group_heads = "ssm_group_heads"

convolution_kernel = "ssm_convolution_kernel" # Kernel dimension of the conv1d in mamba layers

dt_rank = "ssm_dt_rank"

# Composite dimensions
composite_heads = "ssm_composite_heads"
composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim"
composite_head_groups_and_state = "ssm_composite_head_groups_and_state"

# Concatenated dimensions
concatenated_convolution = "ssm_concatenated_convolution"
concatenated_x_projection = "ssm_x_concatenated_x_projection"
concatenated_inner_projection = "ssm_concatenated_inner_projection"


class SSMBlockType(enum.StrEnum):
@@ -53,6 +61,16 @@ def get_mixer_class(self):
raise NotImplementedError(self)


class DTInitType(enum.StrEnum):
constant = "constant"
random = "random"

def get_init_method(self, scale: float) -> "Initializer":
from fast_llm.tensor import init_fill_, init_uniform_centered_

return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale)


@config_class()
class SSMConfig(LLMBlockConfig):
_abstract = False
@@ -62,106 +80,126 @@ class SSMConfig(LLMBlockConfig):
desc="Configuration for the normalization layers architecture.",
hint=FieldHint.architecture,
)

# Model dimensions
# TODO: Remove (redundant default)
expansion_factor: int = Field(
default=2,
desc="Expansion factor for Mamba blocks.",
hint=FieldHint.architecture,
valid=check_field(Assert.gt, 0),
)
# head_size [MambaLayer, Mamba2, DiscreteMamba2]
state_size: int = Field(
default=16,
desc="State size for Mamba blocks.",
hint=FieldHint.architecture,
valid=check_field(Assert.gt, 0),
)
# [MambaLayer, Mamba2, DiscreteMamba2]
conv_kernel_dimension: int = Field(
default=4,
desc="Conv kernel dimension for Mamba blocks.",
hint=FieldHint.architecture,
valid=check_field(Assert.gt, 0),
)
# Layer parameters
add_bias_linear: bool = Field(
default=False,
desc="Whether to use bias in SSM layers",
hint=FieldHint.architecture,
)

# [MambaLayer, Mamba2]
dt_rank: None | int = Field(
default=None,
desc="Rank of the Ξ” projection matrix. If 'None', will be set to ceil(hidden_size/16)",
hint=FieldHint.architecture,
)
chunk_size: int = Field(
default=256,
desc="Chunk size for Mamba2 blocks.",
hint=FieldHint.architecture,
)
# head_groups [DiscreteMamba2]
n_qk_heads: int = Field(
default=32,
desc="Number of QK heads for Mamba2 blocks.",
hint=FieldHint.architecture,
)
# heads [DiscreteMamba2]# TODO: Remove? (redundant)
n_v_heads: int = Field(
default=32,
desc="Number of V heads for Mamba2 blocks.",
hint=FieldHint.architecture,
)
activation_type: ActivationType = Field(
# c_size [MambaLayer, Mamba2, DiscreteMamba2]?
d_inner: None | int = Field(
default=None,
desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.",
hint=FieldHint.architecture,
)
dt_min: float = Field(
default=0.001,
desc="Minimum step size for discretization",
desc="Inner dimension for Mamba2 blocks.",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
dt_init_floor: float = Field(
default=1e-4,
desc="Minimum value for initializing dt",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
# xb_size [Mamba2]
d_xb: int = Field(
default=None,
desc="Dimension of the xB in Mamba2 blocks.",
hint=FieldHint.architecture,
)

d_inner: None | int = Field(
# Model options
# add_bias_linear [Mamba2, DiscreteMamba2] [hard-coded to False in MambaLayer]
add_bias_linear: bool = Field(
default=False,
desc="Whether to use bias in SSM layers",
hint=FieldHint.architecture,
)
# activation_type [DiscreteMamba2] [hard-coded to silu in MambaLayer, Mamba2]
activation_type: ActivationType = Field(
default=None,
desc="Inner dimension for Mamba2 blocks.",
hint=FieldHint.core,
hint=FieldHint.architecture,
)
# repeat_xb_before_conv [Mamba2]
repeat_kv_before_conv: bool = Field(
default=True,
desc="Whether to repeat x and B before (True) or after (False) the conv1d in Mamba2 blocks.",
hint=FieldHint.architecture,
)
# chunk_size [DiscreteMamba2]
chunk_size: int = Field(
default=256,
desc="Chunk size for Mamba2 blocks.",
hint=FieldHint.architecture,
)

# Learning rate
# lr_scale [MambaLayer, Mamba2, DiscreteMamba2]
mamba_lr_scale: float | None = Field(
default=None,
desc="Learning rate scale for Mamba blocks.",
hint=FieldHint.feature,
valid=skip_valid_if_none(check_field(Assert.geq, 0)),
)

# Mamba 2
repeat_kv_before_conv: bool = Field(
default=True,
desc="Whether to repeat the KV before the conv1d in Mamba2 blocks.",
hint=FieldHint.architecture,
# Initialization
# dt_weight_initialization_method [Mamba2]
dt_init: DTInitType = Field(
default=DTInitType.random,
desc="Initialization method for dt",
hint=FieldHint.core,
)
d_xb: int = Field(
default=None,
desc="Dimension of the xB in Mamba2 blocks.",
hint=FieldHint.architecture,
# dt_weight_initialization_scale [Mamba2]
dt_scale: float = Field(
default=1.0,
desc="Scale for dt",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
dt_init: str = Field(
default="random",
desc="Initialization method for dt",
# dt_bias_initialization_min [MambaLayer, Mamba2]
dt_min: float = Field(
default=0.001,
desc="Minimum step size for discretization",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
# dt_bias_initialization_max [MambaLayer, Mamba2]
dt_max: float = Field(
default=0.1,
desc="Maximum step size for discretization",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
dt_scale: float = Field(
default=1.0,
desc="Scale for dt",
# dt_bias_initialization_floor [MambaLayer, Mamba2]
dt_init_floor: float = Field(
default=1e-4,
desc="Minimum value for initializing dt",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
@@ -172,3 +210,79 @@ def _validate(self) -> None:
self.activation_type = ActivationType.silu
super()._validate()
Assert.geq(self.dt_max, self.dt_min)

def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None:
tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor)

# Head groups are configured differently depending on the block type.
if block_type == SSMBlockType.mamba:
num_heads = div(self.d_inner, self.state_size)
num_head_groups = num_heads
elif block_type == SSMBlockType.mamba2:
num_heads = div(self.d_inner, self.state_size)
num_head_groups = div(self.d_xb, self.state_size)
elif block_type == SSMBlockType.mamba2_discrete:
# TODO: Use different variables?
num_heads = self.n_v_heads
num_head_groups = self.n_qk_heads
else:
raise NotImplementedError(block_type)

tensor_space.add_tensor_dim(state := TensorDim(SSMDimNames.state, self.state_size))
if block_type == SSMBlockType.mamba2_discrete:
tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, div(self.d_inner, num_heads)))
else:
head_dim = state

tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor))
tensor_space.add_tensor_dim(group_heads := TensorDim(SSMDimNames.group_heads, div(num_heads, num_head_groups)))
tensor_space.add_tensor_dim(
heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads))
)
tensor_space.add_tensor_dim(
heads_and_head_dim := CompositeTensorDim(
SSMDimNames.composite_heads_and_head_dim, (head_groups, group_heads, head_dim)
)
)
tensor_space.add_tensor_dim(
head_groups_and_state := CompositeTensorDim(
SSMDimNames.composite_head_groups_and_state, (head_groups, state)
)
)
tensor_space.add_tensor_dim(TensorDim(SSMDimNames.convolution_kernel, self.conv_kernel_dimension))

# DT projection
if block_type in (SSMBlockType.mamba, SSMBlockType.mamba2):
tensor_space.add_tensor_dim(dt_rank := TensorDim(SSMDimNames.dt_rank, self.dt_rank))

if block_type == SSMBlockType.mamba:
tensor_space.add_tensor_dim(
ConcatenatedTensorDim(SSMDimNames.concatenated_x_projection, (dt_rank, state, state))
)
# TODO: Use composition instead
tensor_space.add_tensor_dim(
ConcatenatedTensorDim(
SSMDimNames.concatenated_inner_projection, (heads_and_head_dim, heads_and_head_dim)
)
)
elif block_type == SSMBlockType.mamba2:
# TODO: Factor out state?
tensor_space.add_tensor_dim(
ConcatenatedTensorDim(
SSMDimNames.concatenated_inner_projection,
(heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim),
)
)
elif block_type == SSMBlockType.mamba2_discrete:
tensor_space.add_tensor_dim(
ConcatenatedTensorDim(
SSMDimNames.concatenated_inner_projection,
(heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim, heads),
)
)
tensor_space.add_tensor_dim(
ConcatenatedTensorDim(
SSMDimNames.concatenated_convolution,
(heads_and_head_dim, head_groups_and_state, head_groups_and_state),
)
)
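
For reference, here is a minimal standalone sketch (not part of this PR) of the head and head-group arithmetic that the new setup_tensor_space method encodes through tensor dimensions. The helper name and the example sizes are assumptions for illustration only; d_inner, d_xb, state_size, n_qk_heads and n_v_heads are used with the same meaning as in SSMConfig.

# Illustrative sketch only: mirrors the head / head-group arithmetic of
# SSMConfig.setup_tensor_space, without depending on fast_llm.

def head_layout(
    block_type: str,
    *,
    d_inner: int,
    state_size: int,
    d_xb: int | None = None,
    n_qk_heads: int | None = None,
    n_v_heads: int | None = None,
) -> tuple[int, int]:
    """Return (num_heads, num_head_groups) for a given SSM block type."""
    if block_type == "mamba":
        # Each head is its own group; the head dimension equals the state size.
        num_heads = d_inner // state_size
        return num_heads, num_heads
    if block_type == "mamba2":
        # Head groups come from the shared x/B projection of width d_xb.
        return d_inner // state_size, d_xb // state_size
    if block_type == "mamba2_discrete":
        # Heads and groups are configured directly; head_dim is d_inner // num_heads.
        return n_v_heads, n_qk_heads
    raise NotImplementedError(block_type)


# Example (hypothetical sizes): d_inner=4096, state_size=16, d_xb=1024 gives
# 256 heads in 64 groups, i.e. 4 heads per group. The head-group dimension is
# the one that carries the tensor-parallel distributed dim in setup_tensor_space.
print(head_layout("mamba2", d_inner=4096, state_size=16, d_xb=1024))  # (256, 64)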