
[Models]: Add MoM #442

Open

wants to merge 33 commits into main from mom_

Commits (33)
2356954
mom
Jun 13, 2025
34305b2
Merge branch 'main' into mom_
yzhangcs Jun 16, 2025
36be1c9
Merge branch 'main' into mom_
zhiyuan1i Jun 16, 2025
d5ffc7e
Merge branch 'main' into mom_
yzhangcs Jun 16, 2025
75a2b9b
Add `MomGatedDeltaNet` back to the `__init__.py` module exports
yzhangcs Jun 30, 2025
e043a67
Cleaned up import statements in __init__.py
yzhangcs Jun 30, 2025
63e1885
Fix the issue of gradients being NaN
JusenD Jul 9, 2025
7f0fc0c
Cleanup code
JusenD Jul 9, 2025
cefb647
Merge branch 'main' into mom_
yzhangcs Jul 11, 2025
dc7cf66
Update mom.py
yzhangcs Jul 13, 2025
e274e44
Change model name & Update conv api
JusenD Jul 13, 2025
7ab9b43
Fix format issues
yzhangcs Jul 13, 2025
11fa1b6
Replace old modules
yzhangcs Jul 13, 2025
9f62f28
Update docstring and default config
JusenD Jul 13, 2025
4e3a997
Remove old ops
JusenD Jul 13, 2025
9e2b4b6
Update some old settings
yzhangcs Jul 13, 2025
2b51b28
Fix isort
yzhangcs Jul 13, 2025
706d6fd
support inference
JusenD Jul 22, 2025
cea11c4
Fix inference & support expand_v
JusenD Jul 23, 2025
8a8b6c2
Fix rescale_prenorm_residual
JusenD Jul 23, 2025
b9f958a
Add generation testing
yzhangcs Jul 24, 2025
54ec5b6
Merge branch 'main' into mom_
yzhangcs Jul 24, 2025
b105c83
Add proper pythonpath for pytest
yzhangcs Jul 24, 2025
51d014b
Update registration of MomConfig and related models to allow for exis…
yzhangcs Jul 24, 2025
95dc418
Delete unused act
yzhangcs Jul 24, 2025
2d0f203
Refactor the code using cu_seqlen. test_generation passed wo o_norm
JusenD Jul 26, 2025
da52722
Support cu_seqlens
JusenD Jul 27, 2025
273cce6
Support shared memory & test generation passed wo norm
JusenD Jul 27, 2025
bf5b511
Fix bugs
JusenD Jul 27, 2025
0f80d59
Fix bugs
JusenD Jul 28, 2025
8ba6bed
Remove router_logits in output & fix bugs
JusenD Jul 28, 2025
f5df22a
Fix bugs
JusenD Aug 4, 2025
ad802b0
Fix bugs
JusenD Aug 18, 2025
4 changes: 4 additions & 0 deletions fla/__init__.py
@@ -16,6 +16,7 @@
     LightNetAttention,
     LinearAttention,
     MesaNet,
+    MomAttention,
     MultiheadLatentAttention,
     MultiScaleRetention,
     NativeSparseAttention,
@@ -54,6 +55,8 @@
     MesaNetModel,
     MLAForCausalLM,
     MLAModel,
+    MomForCausalLM,
+    MomModel,
     NSAForCausalLM,
     NSAModel,
     PaTHAttentionForCausalLM,
@@ -86,6 +89,7 @@
     'LightNetAttention', 'LightNetForCausalLM', 'LightNetModel',
     'LinearAttention', 'LinearAttentionForCausalLM', 'LinearAttentionModel',
     'MesaNet', 'MesaNetForCausalLM', 'MesaNetModel',
+    'MomAttention', 'MomForCausalLM', 'MomModel',
     'MultiheadLatentAttention', 'MLAForCausalLM', 'MLAModel',
     'MultiScaleRetention', 'RetNetForCausalLM', 'RetNetModel',
     'NativeSparseAttention', 'NSAForCausalLM', 'NSAModel',
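With these re-exports, the new classes are importable straight from the package root. A quick smoke test (a sketch, assuming this branch is installed):

from fla import MomAttention, MomForCausalLM, MomModel

print(MomAttention.__module__)    # fla.layers.mom
print(MomForCausalLM.__module__)  # fla.models.mom.modeling_mom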
2 changes: 2 additions & 0 deletions fla/layers/__init__.py
@@ -20,6 +20,7 @@
 from .mamba2 import Mamba2
 from .mesa_net import MesaNet
 from .mla import MultiheadLatentAttention
+from .mom import MomAttention
 from .multiscale_retention import MultiScaleRetention
 from .nsa import NativeSparseAttention
 from .path_attn import PaTHAttention
@@ -47,6 +48,7 @@
     'Mamba',
     'Mamba2',
     'MesaNet',
+    'MomAttention',
     'MultiheadLatentAttention',
     'MultiScaleRetention',
     'NativeSparseAttention',
788 changes: 788 additions & 0 deletions fla/layers/mom.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions fla/models/__init__.py
@@ -21,6 +21,7 @@
 from fla.models.mamba2 import Mamba2Config, Mamba2ForCausalLM, Mamba2Model
 from fla.models.mesa_net import MesaNetConfig, MesaNetForCausalLM, MesaNetModel
 from fla.models.mla import MLAConfig, MLAForCausalLM, MLAModel
+from fla.models.mom import MomConfig, MomForCausalLM, MomModel
 from fla.models.nsa import NSAConfig, NSAForCausalLM, NSAModel
 from fla.models.path_attn import PaTHAttentionConfig, PaTHAttentionForCausalLM, PaTHAttentionModel
 from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel
@@ -47,6 +48,7 @@
     'MambaConfig', 'MambaForCausalLM', 'MambaModel',
     'Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model',
     'MesaNetConfig', 'MesaNetForCausalLM', 'MesaNetModel',
+    'MomConfig', 'MomForCausalLM', 'MomModel',
     'MLAConfig', 'MLAForCausalLM', 'MLAModel',
     'NSAConfig', 'NSAForCausalLM', 'NSAModel',
     'PaTHAttentionConfig', 'PaTHAttentionForCausalLM', 'PaTHAttentionModel',
12 changes: 12 additions & 0 deletions fla/models/mom/__init__.py
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.mom.configuration_mom import MomConfig
from fla.models.mom.modeling_mom import MomForCausalLM, MomModel

AutoConfig.register(MomConfig.model_type, MomConfig, exist_ok=True)
AutoModel.register(MomConfig, MomModel, exist_ok=True)
AutoModelForCausalLM.register(MomConfig, MomForCausalLM, exist_ok=True)

__all__ = ['MomConfig', 'MomForCausalLM', 'MomModel']
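Since this module registers MomConfig with the transformers Auto* classes (exist_ok=True keeps repeated imports from clashing), the model is reachable through the standard Auto entry points. A minimal sketch, assuming this branch is installed; the overridden hyperparameter is purely illustrative:

from transformers import AutoConfig, AutoModelForCausalLM

import fla.models  # noqa: F401  # importing the package runs the registrations above

# 'mom' now resolves to MomConfig; shrink the stack for a quick smoke test
config = AutoConfig.for_model('mom', num_hidden_layers=2)
model = AutoModelForCausalLM.from_config(config)
assert model.config.model_type == 'mom'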
101 changes: 101 additions & 0 deletions fla/models/mom/configuration_mom.py
@@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-

from typing import Dict, Optional

from transformers.configuration_utils import PretrainedConfig


class MomConfig(PretrainedConfig):
    model_type = 'mom'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        attn_mode: str = "chunk",
        hidden_size: int = 2048,
        conv_size: int = 4,
        num_heads: int = 4,
        head_dim: int = 256,
        expand_v: float = 1.,
        use_output_gate: bool = True,
        use_short_conv: bool = True,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        num_hidden_layers: int = 24,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.02,
        num_memories: int = 4,
        topk: int = 2,
        capacity: float = 1.0,
        use_layer_wise_balance: bool = True,
        aux_loss_scale: float = 0.01,
        shared_mem: bool = True,
        single_kv_proj: bool = False,
        mom_backend: str = 'gated_deltanet',
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.attn_mode = attn_mode
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.expand_v = expand_v
        self.conv_size = conv_size
        self.use_output_gate = use_output_gate
        self.use_short_conv = use_short_conv
        self.max_position_embeddings = max_position_embeddings

        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_hidden_layers = num_hidden_layers
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.num_memories = num_memories
        self.topk = topk
        self.capacity = capacity
        self.use_layer_wise_balance = use_layer_wise_balance
        self.aux_loss_scale = aux_loss_scale
        self.shared_mem = shared_mem
        self.single_kv_proj = single_kv_proj
        self.mom_backend = mom_backend

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if self.mom_backend not in ['gated_deltanet']:
            raise NotImplementedError(f"The MoM backend {mom_backend} is not currently supported.")

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['window_size'] = attn.get('window_size', None)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
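The MoM-specific knobs sit alongside the usual backbone hyperparameters. A small configuration sketch; the values are illustrative and the comments reflect a plain reading of the field names rather than documented semantics:

from fla.models import MomConfig

config = MomConfig(
    num_memories=8,                # memory states available per layer
    topk=2,                        # memories each token is routed to
    shared_mem=True,               # keep an always-active shared memory
    capacity=1.0,                  # per-memory capacity factor for routing
    aux_loss_scale=0.01,           # weight on the load-balancing auxiliary loss
    mom_backend='gated_deltanet',  # the only backend the config accepts so far
)
assert config.model_type == 'mom'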