Model description
Mamba2
A pure JAX/Flax implementation of the Mamba2 architecture introduced in "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality".
This implementation focuses on the "State Space Duality" (SSD) mechanism, enabling highly efficient computation via matrix multiplications. This version is pure JAX, making it fully compatible with XLA for seamless execution on TPUs and GPUs.
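For intuition, here is a minimal sketch of the quadratic "dual" (attention-like) form of SSD for a single head, written directly from the recurrence h_t = a_t·h_{t-1} + B_t·x_t, y_t = C_tᵀ·h_t. The function and variable names are illustrative only, not the repository's actual API:

```python
import jax.numpy as jnp

def segsum(log_a):
    """S[t, s] = sum_{r=s+1..t} log_a[r] for t >= s, and -inf above the diagonal."""
    T = log_a.shape[-1]
    cs = jnp.cumsum(log_a, axis=-1)
    S = cs[..., :, None] - cs[..., None, :]
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))
    return jnp.where(mask, S, -jnp.inf)

def ssd_quadratic(x, log_a, B, C):
    """Quadratic-time 'dual' form of SSD for one head.

    x:     (T, P) inputs, log_a: (T,) per-step log decay,
    B, C:  (T, N) input/output state projections.  Returns y: (T, P).
    """
    L = jnp.exp(segsum(log_a))           # (T, T) lower-triangular decay mask
    G = jnp.einsum('tn,sn->ts', C, B)    # (T, T) scores C_t . B_s
    return jnp.einsum('ts,sp->tp', G * L, x)
```

The efficient algorithm from the paper splits the sequence into chunks, uses matmuls like the above within each chunk, and combines chunks with a small inter-chunk scan, keeping the whole computation XLA-friendly.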
It currently supports:
- `Mamba2ForCausalLM` (Language Modeling)
- `Mamba2Forecaster` (Time Series)
- Numerical parity validated against the reference PyTorch implementation.
Open source examples / papers / references
- Paper: Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality (Dao & Gu, ICML 2024)
- Model code: https://github.com/CosmoNaught/mamba2-jax
- Reference implementation: https://github.com/vasqu/mamba2-torch
- Model weights: Initial Bonsai integration will support random initialisation. A follow-up PR can add `params.py` helpers to load/convert official Mamba2 or `mamba2-torch` checkpoints once the architecture is settled.
Why this model?
This model would be a strong addition to the JAX community for several reasons:
- Hardware Flexibility (TPU Support): Because this implementation uses pure JAX operations (scans/matrix muls) rather than custom CUDA/Triton kernels, it runs seamlessly on Google Cloud TPUs (tested on v5e-1) as well as GPUs.
- Educational & Hackable: The functional programming paradigm of JAX offers a cleaner implementation of the SSD algorithm, making it easier for researchers to understand the Mamba2 internals without navigating low-level CUDA kernels.
- Performance: Initial benchmarks show the JAX implementation running ~2x faster per step on CPU than the PyTorch reference, demonstrating the efficiency of XLA for this architecture. NOTE: this was measured on a trivial testbed; further benchmarking is required, but initial results look promising even after accounting for JIT compilation overhead.
- Ecosystem Integration: It is built to integrate naturally with the existing JAX ecosystem (Flax, Optax), facilitating easy research into hybrid architectures.
Brief Implementation Plan
To align with Bonsai's design philosophy and future roadmap:
- NNX Refactor: Refactor the current Linen prototype to Flax NNX (a rough skeleton is sketched below).
- Structure: Consolidate the math logic (currently under `ssd.py`) and the wrapper layers (causal LM and time-series forecasting) into a single `modeling.py` file to adhere to the single-file policy, unless advised otherwise.
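As a rough illustration of what the consolidated `modeling.py` could look like after the NNX refactor, a minimal sketch; the class name, parameter names, and the no-op in place of the SSD scan are placeholders, not the final API:

```python
import jax
import jax.numpy as jnp
from flax import nnx

class Mamba2Block(nnx.Module):
    """Skeleton Mamba2 block; the real SSD core replaces the placeholder below."""

    def __init__(self, d_model: int, d_state: int, *, rngs: nnx.Rngs):
        self.in_proj = nnx.Linear(d_model, 2 * d_model, rngs=rngs)
        self.out_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.d_state = d_state

    def __call__(self, x):
        # x: (batch, seq_len, d_model)
        u, gate = jnp.split(self.in_proj(x), 2, axis=-1)
        y = u  # TODO: apply the SSD core (chunked scan) here
        return self.out_proj(y * jax.nn.silu(gate))

# Usage sketch
block = Mamba2Block(d_model=64, d_state=16, rngs=nnx.Rngs(0))
out = block(jnp.ones((2, 8, 64)))  # -> (2, 8, 64)
```

The causal LM and forecasting heads would then be thin wrappers around a stack of such blocks in the same file.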
Full Implementation Plan
Following the standard Bonsai model directory structure, I'd like to add a new model under:
bonsai/models/mamba2/
├── tests/
│ ├── __init__.py
│ ├── run_model.py
│ └── test_outputs.py
├── README.md
├── modeling.py
└── params.py
Planned pieces:
- `modeling.py`
  - Flax NNX implementation of:
    - SSD core (Mamba2 block)
    - `Mamba2Model`
    - In the initial PR: `Mamba2ForCausalLM`
    - Follow-up: `Mamba2Forecaster`, lagging this to focus on one correct implementation first.
  - SSD and heads will live together here, in line with the single-model, single-file structure.
- `params.py`
  - In the initial PR: a small helper to instantiate a randomly-initialised `Mamba2ForCausalLM`.
  - Follow-up: add weight-loading / conversion utilities for official Mamba2 or `mamba2-torch` checkpoints.
- `tests/` (see the test sketch after this list)
  - `test_outputs.py`: small correctness/shape tests for a tiny Mamba2 config (forward pass runs, expected shapes/dtypes, no NaNs).
  - `run_model.py`: simple smoke test to instantiate the model and run a forward pass.
- `README.md`
  - Short description, links to paper / references, and a small “Tested on” matrix (CPU / GPU / TPU).
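To make the testing plan concrete, a minimal sketch of the kind of check `test_outputs.py` would run (shape, dtype, and NaN assertions on a tiny config). The real test would import the planned `Mamba2ForCausalLM` via the `params.py` helper; a stand-in `nnx.Linear` is used below purely so the sketch runs as-is:

```python
# test_outputs.py-style check, demonstrated on a stand-in module.
import jax.numpy as jnp
from flax import nnx

def build_tiny_model():
    # Placeholder for the planned params.py helper (random init of a tiny Mamba2 config).
    return nnx.Linear(16, 16, rngs=nnx.Rngs(0))

def test_forward_shapes_and_no_nans():
    model = build_tiny_model()
    x = jnp.ones((2, 8, 16), dtype=jnp.float32)   # (batch, seq_len, d_model)
    y = model(x)
    assert y.shape == (2, 8, 16)
    assert y.dtype == jnp.float32
    assert not bool(jnp.isnan(y).any())

if __name__ == "__main__":
    test_forward_shapes_and_no_nans()
    print("ok")
```

`run_model.py` would follow the same pattern, simply instantiating the full model and running one forward pass as a smoke test.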