
[Models]: Add MoM #442


Open · wants to merge 31 commits into main


Conversation

WKX933

@WKX933 WKX933 commented Jun 13, 2025

Changes Made

New Modules

  • fla/layers/
    • Added mom.py
    • Introduced mom_varlen.py (variable-length sequence variant)

Model Architecture

  • models/
    • Implemented mom_gated_deltanet (Gated DeltaNet with MoM integration)

Summary by CodeRabbit

  • New Features
    • Introduced the Mixture-of-Memories (MoM) model architecture with a memory-organized gated delta attention layer and specialized transformer blocks.
    • Added the Mom model, configuration, and causal language modeling support with integration into Hugging Face Transformers' auto-model utilities.
    • Implemented an auxiliary load balancing loss to promote balanced routing across memory units (a rough sketch of such a loss follows this list).
    • Enabled efficient generation with caching and input preparation for causal language modeling tasks.
    • Added comprehensive tests validating model forward/backward passes and generation capabilities.
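To make the load-balancing item above concrete: such an auxiliary loss typically compares the fraction of tokens routed to each memory with the mean routing probability that memory receives, summed over layers. Below is a minimal, hypothetical sketch, not the PR's actual implementation; the function name and the top-k convention are assumptions.

import torch
import torch.nn.functional as F


def load_balancing_loss(gate_logits, num_memories: int, top_k: int = 2) -> torch.Tensor:
    # gate_logits: tuple of per-layer router logits, each of shape [batch, seq_len, num_memories]
    losses = []
    for layer_logits in gate_logits:
        probs = F.softmax(layer_logits, dim=-1)                           # routing probabilities per token
        _, selected = torch.topk(probs, top_k, dim=-1)                    # indices of the top-k memories per token
        chosen = F.one_hot(selected, num_memories).amax(dim=-2).float()   # 1 if a memory was selected for a token
        tokens_per_memory = chosen.mean(dim=(0, 1))                       # fraction of tokens routed to each memory
        prob_per_memory = probs.mean(dim=(0, 1))                          # average routing weight per memory
        losses.append(num_memories * torch.sum(tokens_per_memory * prob_per_memory))
    return torch.stack(losses).mean()                                     # average over layers (cf. use_layer_wise_balance)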


coderabbitai bot commented Jun 13, 2025

Walkthrough

This change introduces a new Mixture-of-Memories (MoM) architecture to the fla package. It adds a memory-organized gated delta network layer (MomAttention), a configuration class (MomConfig), and model classes (MomModel, MomForCausalLM). These components are integrated with the Hugging Face Transformers interface and made available in the package's public API.
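For orientation, instantiating the model through the public API would look roughly like the sketch below. This is illustrative only: the class names come from the walkthrough above, and the hyperparameter values are arbitrary small settings, not validated defaults.

import torch

from fla.models import MomConfig, MomForCausalLM

# Tiny illustrative configuration; the real hyperparameters live in configuration_mom.py.
config = MomConfig(hidden_size=256, num_hidden_layers=2, num_memories=4, topk=2, vocab_size=1024)
model = MomForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 32))
outputs = model(input_ids=input_ids, labels=input_ids)
print(outputs.logits.shape)  # expected: (1, 32, 1024)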

Changes

  • Package exports (fla/__init__.py, fla/layers/__init__.py, fla/models/__init__.py): Expose MomAttention, MomConfig, MomModel, and MomForCausalLM in the package and submodule namespaces via imports and __all__.
  • Memory-organized attention layer (fla/layers/mom.py): Adds new MomAttention layer implementing memory routing, gating, capacity-aware token routing utilities (_upad_input, transform, reconstruct), recurrent/convolutional state caching, and output normalization (a simplified routing sketch follows this list).
  • Model registration (fla/models/mom/__init__.py): Registers MomConfig, MomModel, and MomForCausalLM with Hugging Face Transformers auto classes and exposes them via __all__.
  • Model configuration (fla/models/mom/configuration_mom.py): Introduces MomConfig class defining model hyperparameters, validation logic, and enforcing backend constraints for MoM models.
  • Model implementation (fla/models/mom/modeling_mom.py): Implements MomBlock, MomPreTrainedModel, MomModel, MomForCausalLM, output dataclasses, and a layer-wise load balancing loss function; supports caching, gradient checkpointing, generation, and routing-aware loss calculation.
  • Tests (tests/models/test_modeling_mom.py): Adds parameterized tests for MomConfig model forward/backward passes and generation, covering multiple configurations and data types.
  • Test configuration (pyproject.toml): Adds pytest configuration to include current directory in Python path during tests.
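The transform/reconstruct utilities listed for fla/layers/mom.py implement capacity-aware token routing: conceptually, each token scores every memory, keeps its top-k memories, and renormalizes the kept weights before the per-memory updates run, after which outputs are merged back in token order. A simplified, hypothetical sketch of just the scoring step follows (the real transform also enforces per-memory capacity and handles padding):

import torch


def route_tokens(router_logits: torch.Tensor, top_k: int = 2):
    # router_logits: [batch, seq_len, num_memories]
    probs = torch.softmax(router_logits, dim=-1)
    weights, indices = torch.topk(probs, top_k, dim=-1)    # top-k memories per token
    weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize over the selected memories
    # `indices` is what a transform-style scatter would use to group tokens per memory,
    # and `weights` is what a reconstruct-style merge would use to combine the outputs.
    return weights, indices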

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant MomForCausalLM
    participant MomModel
    participant MomBlock
    participant MomAttention
    participant MomMLP

    User->>MomForCausalLM: forward(input_ids, ...)
    MomForCausalLM->>MomModel: forward(inputs_embeds, ...)
    MomModel->>MomBlock: for each layer, forward(hidden_states, ...)
    MomBlock->>MomAttention: forward(hidden_states, ...)
    MomAttention->>MomAttention: Route tokens to memories (transform)
    MomAttention->>MomAttention: Compute attention/gating/convolutions
    MomAttention->>MomAttention: Reconstruct outputs (reconstruct)
    MomAttention-->>MomBlock: output, cache, router_logits
    MomBlock->>MomMLP: forward(hidden_states)
    MomMLP-->>MomBlock: output
    MomBlock-->>MomModel: hidden_states, cache, router_logits
    MomModel-->>MomForCausalLM: hidden_states, cache, router_logits
    MomForCausalLM-->>User: logits, loss, cache, router_logits

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~90 minutes

Poem

🐇
A memory mix, a clever new trick,
With gates and with tokens that swiftly can pick.
New configs and models, attention anew,
Routing and balancing—so much to do!
In the garden of code, these features now bloom,
The MoM rabbit hops, making plenty of room!



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 16

🧹 Nitpick comments (10)
fla/layers/__init__.py (1)

55-57: Minor style – keep __all__ alphabetically sorted & add trailing comma.

Maintaining alphabetical order and a trailing comma prevents merge conflicts and keeps diffs minimal.

-    'MesaNet',
-    'MomGatedDeltaNet'
+    'MesaNet',
+    'MomGatedDeltaNet',
fla/models/__init__.py (2)

30-30: Import list spacing / readability.

Missing space after the first comma is inconsistent with the rest of the file:

-from fla.models.mom_gated_deltanet import MomGatedDeltaNetConfig,MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel
+from fla.models.mom_gated_deltanet import MomGatedDeltaNetConfig, MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel

56-56: Trailing comma & alphabetical order in __all__.

-    'MomGatedDeltaNetConfig','MomGatedDeltaNetForCausalLM','MomGatedDeltaNetModel'
+    'MomGatedDeltaNetConfig', 'MomGatedDeltaNetForCausalLM', 'MomGatedDeltaNetModel',
fla/__init__.py (1)

23-26: Top-level heavy imports slow down import fla.

Adding yet another large model (MomGatedDeltaNet) to the mandatory import path further increases import time and memory footprint of the package. Consider a lazy-import pattern (e.g. via importlib.metadata.EntryPoints or a simple LazyModule helper) so that users who do not need this model are not penalised.
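A minimal sketch of one such lazy-import approach, using a module-level __getattr__ (PEP 562); the mapping below is illustrative and not the package's actual layout:

# fla/__init__.py (illustrative excerpt)
import importlib

_LAZY_IMPORTS = {
    "MomGatedDeltaNetConfig": "fla.models.mom_gated_deltanet",
    "MomGatedDeltaNetModel": "fla.models.mom_gated_deltanet",
    "MomGatedDeltaNetForCausalLM": "fla.models.mom_gated_deltanet",
}


def __getattr__(name):
    # Import the heavy submodule only when one of its symbols is first accessed.
    if name in _LAZY_IMPORTS:
        module = importlib.import_module(_LAZY_IMPORTS[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")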

fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py (1)

11-45: Huge argument list – consider dataclass / kwargs container.

32 positional parameters hurt readability and trigger pylint R0913/R0917. A dataclass (or grouping related flags into small parameter objects) would make the public API friendlier and future-proof.

🧰 Tools
🪛 Pylint (3.3.7)

[refactor] 11-11: Too many arguments (32/5)

(R0913)


[refactor] 11-11: Too many positional arguments (32/5)

(R0917)


[refactor] 11-11: Too many local variables (33/15)

(R0914)
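Returning to the refactor suggestion above, one possible shape for it, sketched purely as an illustration (field names mirror a subset of the current keyword arguments; the wrapper class itself is hypothetical):

from dataclasses import dataclass


@dataclass
class MomRoutingOptions:
    # Groups the MoM-specific routing flags that currently sit in the flat argument list.
    num_memories: int = 8
    topk: int = 2
    capacity: float = 1.0
    use_layer_wise_balance: bool = True
    aux_loss_scale: float = 0.01
    shared_mem: bool = False
    single_kv_proj: bool = False

The config constructor could then accept a single routing argument (defaulting to MomRoutingOptions()) instead of seven separate flags, which keeps the signature short.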

fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py (2)

470-470: Consider using professional comments without emojis.

While emojis add visual emphasis, professional codebases typically avoid them in comments.

Apply this diff:

-            use_layer_wise_balance=self.config.use_layer_wise_balance,  # ✨
+            use_layer_wise_balance=self.config.use_layer_wise_balance,  # Layer-wise balance flag

516-517: Consider removing emojis from code comments for professionalism.

Apply this diff:

-    # ✨ Here is the fix for balance loss in Mixtral.
+    # Here is the fix for balance loss in Mixtral.
-        # ✨ balance loss for this layer
+        # Balance loss for this layer
-    all_balance_losses = torch.cat(all_balance_losses).mean()  # ✨
+    all_balance_losses = torch.cat(all_balance_losses).mean()

Also applies to: 559-559

fla/layers/mom.py (1)

443-443: Remove trailing comma to avoid creating a tuple.

The trailing comma creates an unnecessary tuple.

Apply this diff:

-            q, k, v = self.silu(q), self.silu(k), self.silu(v),
+            q, k, v = self.silu(q), self.silu(k), self.silu(v)
🧰 Tools
🪛 Pylint (3.3.7)

[refactor] 443-443: Disallow trailing comma tuple

(R1707)

fla/layers/mom_varlen.py (2)

505-505: Remove unused variable assignment.

The variable batchsize is assigned but never used.

Apply this diff:

-        batchsize,q_len = hidden_states.shape[0],hidden_states.shape[1]
+        batch_size, q_len = hidden_states.shape[0], hidden_states.shape[1]

Note: If you intended to use batch_size later in the code, make sure to update all references accordingly.

🧰 Tools
🪛 Ruff (0.11.9)

505-505: Local variable batchsize is assigned to but never used

Remove assignment to unused variable batchsize

(F841)


561-561: Remove trailing comma to avoid creating a tuple.

Apply this diff:

-            q, k, v = self.silu(q), self.silu(k), self.silu(v),
+            q, k, v = self.silu(q), self.silu(k), self.silu(v)
🧰 Tools
🪛 Pylint (3.3.7)

[refactor] 561-561: Disallow trailing comma tuple

(R1707)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 162ef52 and 2356954.

📒 Files selected for processing (8)
  • fla/__init__.py (3 hunks)
  • fla/layers/__init__.py (2 hunks)
  • fla/layers/mom.py (1 hunks)
  • fla/layers/mom_varlen.py (1 hunks)
  • fla/models/__init__.py (2 hunks)
  • fla/models/mom_gated_deltanet/__init__.py (1 hunks)
  • fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py (1 hunks)
  • fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (2)
fla/layers/__init__.py (2)
fla/layers/mom.py (1)
  • MomGatedDeltaNet (192-613)
fla/layers/mom_varlen.py (1)
  • MomGatedDeltaNet (309-724)
fla/models/__init__.py (2)
fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py (1)
  • MomGatedDeltaNetConfig (8-91)
fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py (2)
  • MomGatedDeltaNetForCausalLM (320-486)
  • MomGatedDeltaNetModel (209-313)
🪛 Pylint (3.3.7)
fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py

[refactor] 8-8: Too many instance attributes (27/7)

(R0902)


[refactor] 11-11: Too many arguments (32/5)

(R0913)


[refactor] 11-11: Too many positional arguments (32/5)

(R0917)


[refactor] 11-11: Too many local variables (33/15)

(R0914)


[refactor] 8-8: Too few public methods (0/2)

(R0903)

fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py

[refactor] 11-11: Use 'from torch import nn' instead

(R0402)


[refactor] 37-37: Too many instance attributes (8/7)

(R0902)


[refactor] 39-39: Too many arguments (7/5)

(R0913)


[refactor] 39-39: Too many positional arguments (7/5)

(R0917)


[refactor] 37-37: Too few public methods (1/2)

(R0903)


[refactor] 128-128: Too many arguments (6/5)

(R0913)


[refactor] 128-128: Too many positional arguments (6/5)

(R0917)


[refactor] 82-82: Too few public methods (1/2)

(R0903)


[refactor] 161-161: Too few public methods (0/2)

(R0903)


[refactor] 230-230: Too many arguments (9/5)

(R0913)


[refactor] 230-230: Too many positional arguments (9/5)

(R0917)


[refactor] 230-230: Too many local variables (17/15)

(R0914)


[error] 307-313: Unexpected keyword argument 'last_hidden_state' in constructor call

(E1123)


[error] 307-313: Unexpected keyword argument 'past_key_values' in constructor call

(E1123)


[error] 307-313: Unexpected keyword argument 'hidden_states' in constructor call

(E1123)


[error] 307-313: Unexpected keyword argument 'attentions' in constructor call

(E1123)


[refactor] 230-230: Too many branches (13/12)

(R0912)


[refactor] 358-367: Unnecessary "else" after "raise", remove the "else" and de-indent the code inside it

(R1720)


[refactor] 369-369: Too many arguments (7/5)

(R0913)


[refactor] 369-369: Too many positional arguments (7/5)

(R0917)


[refactor] 403-403: Too many arguments (11/5)

(R0913)


[refactor] 403-403: Too many positional arguments (11/5)

(R0917)


[refactor] 403-403: Too many local variables (21/15)

(R0914)


[error] 478-486: Unexpected keyword argument 'loss' in constructor call

(E1123)


[error] 478-486: Unexpected keyword argument 'logits' in constructor call

(E1123)


[error] 478-486: Unexpected keyword argument 'past_key_values' in constructor call

(E1123)


[error] 478-486: Unexpected keyword argument 'hidden_states' in constructor call

(E1123)


[error] 478-486: Unexpected keyword argument 'attentions' in constructor call

(E1123)

fla/layers/mom.py

[refactor] 443-443: Disallow trailing comma tuple

(R1707)


[refactor] 10-10: Use 'from torch import nn' instead

(R0402)


[refactor] 34-34: Too many local variables (29/15)

(R0914)


[refactor] 129-129: Too many arguments (8/5)

(R0913)


[refactor] 129-129: Too many positional arguments (8/5)

(R0917)


[refactor] 129-129: Too many local variables (19/15)

(R0914)


[refactor] 192-192: Too many instance attributes (39/7)

(R0902)


[refactor] 239-239: Too many arguments (17/5)

(R0913)


[refactor] 239-239: Too many positional arguments (17/5)

(R0917)


[refactor] 239-239: Too many local variables (25/15)

(R0914)


[refactor] 239-239: Too many statements (68/50)

(R0915)


[refactor] 367-367: Too many arguments (6/5)

(R0913)


[refactor] 367-367: Too many positional arguments (6/5)

(R0917)


[refactor] 367-367: Too many local variables (42/15)

(R0914)


[error] 486-496: Unexpected keyword argument 'head_first' in function call

(E1123)


[error] 510-510: Undefined variable 'o'

(E0602)


[refactor] 367-367: Too many branches (22/12)

(R0912)


[refactor] 367-367: Too many statements (87/50)

(R0915)


[refactor] 531-531: Too many arguments (8/5)

(R0913)


[refactor] 531-531: Too many positional arguments (8/5)

(R0917)


[refactor] 531-531: Too many local variables (19/15)

(R0914)


[error] 599-609: Unexpected keyword argument 'head_first' in function call

(E1123)

fla/layers/mom_varlen.py

[refactor] 561-561: Disallow trailing comma tuple

(R1707)


[refactor] 10-10: Use 'from torch import nn' instead

(R0402)


[refactor] 64-64: Too many arguments (7/5)

(R0913)


[refactor] 64-64: Too many positional arguments (7/5)

(R0917)


[refactor] 64-64: Too many local variables (17/15)

(R0914)


[error] 108-108: Possibly using variable 'index_first_axis' before assignment

(E0606)


[error] 129-129: Possibly using variable 'unpad_input' before assignment

(E0606)


[refactor] 149-149: Too many local variables (31/15)

(R0914)


[refactor] 246-246: Too many arguments (8/5)

(R0913)


[refactor] 246-246: Too many positional arguments (8/5)

(R0917)


[refactor] 246-246: Too many local variables (19/15)

(R0914)


[refactor] 309-309: Too many instance attributes (39/7)

(R0902)


[refactor] 356-356: Too many arguments (17/5)

(R0913)


[refactor] 356-356: Too many positional arguments (17/5)

(R0917)


[refactor] 356-356: Too many local variables (25/15)

(R0914)


[refactor] 356-356: Too many statements (68/50)

(R0915)


[refactor] 484-484: Too many arguments (6/5)

(R0913)


[refactor] 484-484: Too many positional arguments (6/5)

(R0917)


[refactor] 484-484: Too many local variables (52/15)

(R0914)


[error] 595-595: Possibly using variable 'pad_input' before assignment

(E0606)


[error] 602-612: Unexpected keyword argument 'head_first' in function call

(E1123)


[error] 621-621: Undefined variable 'o'

(E0602)


[refactor] 484-484: Too many branches (16/12)

(R0912)


[refactor] 484-484: Too many statements (85/50)

(R0915)


[refactor] 642-642: Too many arguments (8/5)

(R0913)


[refactor] 642-642: Too many positional arguments (8/5)

(R0917)


[refactor] 642-642: Too many local variables (19/15)

(R0914)


[error] 710-720: Unexpected keyword argument 'head_first' in function call

(E1123)

🪛 Ruff (0.11.9)
fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py

243-243: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


359-365: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)

fla/layers/mom.py

425-425: Use kwargs.get('seq_idx') instead of kwargs.get('seq_idx', None)

Replace kwargs.get('seq_idx', None) with kwargs.get('seq_idx')

(SIM910)


456-456: Use kwargs.get('offsets') instead of kwargs.get('offsets', None)

Replace kwargs.get('offsets', None) with kwargs.get('offsets')

(SIM910)


537-537: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


538-538: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


539-539: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


555-555: Use kwargs.get('seq_idx') instead of kwargs.get('seq_idx', None)

Replace kwargs.get('seq_idx', None) with kwargs.get('seq_idx')

(SIM910)


584-584: Use kwargs.get('offsets') instead of kwargs.get('offsets', None)

Replace kwargs.get('offsets', None) with kwargs.get('offsets')

(SIM910)

fla/layers/mom_varlen.py

29-29: Redefinition of unused pad_input from line 28

Remove definition: pad_input

(F811)


505-505: Local variable batchsize is assigned to but never used

Remove assignment to unused variable batchsize

(F841)


543-543: Use kwargs.get('seq_idx') instead of kwargs.get('seq_idx', None)

Replace kwargs.get('seq_idx', None) with kwargs.get('seq_idx')

(SIM910)


579-579: Local variable offsets is assigned to but never used

Remove assignment to unused variable offsets

(F841)


579-579: Use kwargs.get('offsets') instead of kwargs.get('offsets', None)

Replace kwargs.get('offsets', None) with kwargs.get('offsets')

(SIM910)


648-648: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


649-649: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


650-650: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


666-666: Use kwargs.get('seq_idx') instead of kwargs.get('seq_idx', None)

Replace kwargs.get('seq_idx', None) with kwargs.get('seq_idx')

(SIM910)


695-695: Use kwargs.get('offsets') instead of kwargs.get('offsets', None)

Replace kwargs.get('offsets', None) with kwargs.get('offsets')

(SIM910)

🔇 Additional comments (1)
fla/models/mom_gated_deltanet/__init__.py (1)

10-12: Repeated registration can raise on re-import.

AutoConfig.register, AutoModel.register, … will raise a ValueError if the key is already registered (e.g. in unit-test re-imports or interactive sessions). A defensive guard avoids noisy stack traces:

-AutoConfig.register(MomGatedDeltaNetConfig.model_type, MomGatedDeltaNetConfig)
-AutoModel.register(MomGatedDeltaNetConfig, MomGatedDeltaNetModel)
-AutoModelForCausalLM.register(MomGatedDeltaNetConfig, MomGatedDeltaNetForCausalLM)
+for registry, key, value in [
+    (AutoConfig, MomGatedDeltaNetConfig.model_type, MomGatedDeltaNetConfig),
+    (AutoModel, MomGatedDeltaNetConfig, MomGatedDeltaNetModel),
+    (AutoModelForCausalLM, MomGatedDeltaNetConfig, MomGatedDeltaNetForCausalLM),
+]:
+    try:
+        registry.register(key, value)
+    except ValueError:
+        # Already registered – safe to ignore in repeated imports.
+        pass

@@ -25,6 +25,7 @@
from .rodimus import RodimusAttention, SlidingWindowSharedKeyAttention
from .rwkv6 import RWKV6Attention
from .rwkv7 import RWKV7Attention
from .mom import MomGatedDeltaNet

⚠️ Potential issue

Ambiguous class resolution – two different MomGatedDeltaNet implementations exist.

Both fla/layers/mom.py and fla/layers/mom_varlen.py declare a class named MomGatedDeltaNet, but only the one from mom.py is imported here. Call-sites that expect the variable-length behaviour implemented in mom_varlen.py will silently get the other variant, which is extremely hard to debug.

Consider:

-from .mom import MomGatedDeltaNet
+# Prefer an explicit alias to make the chosen variant obvious
+from .mom import MomGatedDeltaNet as MomGatedDeltaNetFixedLen
+# If you need the var-len version as well, export it under a distinct name
+from .mom_varlen import MomGatedDeltaNet as MomGatedDeltaNetVarLen

…and expose the chosen symbol(s) accordingly in __all__.
Failing to disambiguate the two will almost certainly bite you the moment both variants are needed in the same runtime.

📝 Committable suggestion


Suggested change
from .mom import MomGatedDeltaNet
# Prefer an explicit alias to make the chosen variant obvious
from .mom import MomGatedDeltaNet as MomGatedDeltaNetFixedLen
# If you need the var-len version as well, export it under a distinct name
from .mom_varlen import MomGatedDeltaNet as MomGatedDeltaNetVarLen
🤖 Prompt for AI Agents
In fla/layers/__init__.py at line 28, there is ambiguity because two different
classes named MomGatedDeltaNet exist in mom.py and mom_varlen.py, but only the
one from mom.py is imported. To fix this, explicitly import both classes with
distinct aliases or names to differentiate them, for example, MomGatedDeltaNet
and MomGatedDeltaNetVarLen. Then, update the __all__ list to expose these
distinct names clearly. This will prevent silent confusion and ensure call-sites
use the intended variant.

Comment on lines 11 to 45
    def __init__(
        self,
        attn_mode: str = "chunk",
        hidden_size: int = 2048,
        expand_v: int = 2,
        use_gate: bool = True,
        use_short_conv: bool = True,
        conv_size: int = 4,
        head_dim: int = 256,
        num_heads: int = 6,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        num_hidden_layers: int = 21,
        norm_first: bool = False,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.02,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        num_memories: int = 8,
        topk: int = 2,
        capacity: float = 1.0,
        use_layer_wise_balance: bool=True,
        aux_loss_scale: float=0.01,
        shared_mem: bool = False,
        single_kv_proj: bool = False,
        **kwargs
    ):

🛠️ Refactor suggestion

Constructor mutates input dict – copy instead.

attn is modified in-place (attn['num_kv_heads'] = …). Callers that reuse the same dict for multiple configs will observe surprising side-effects.

-        self.attn = attn
+        # keep the original object untouched
+        self.attn = dict(attn) if attn is not None else None
🧰 Tools
🪛 Pylint (3.3.7)

[refactor] 11-11: Too many arguments (32/5)

(R0913)


[refactor] 11-11: Too many positional arguments (32/5)

(R0917)


[refactor] 11-11: Too many local variables (33/15)

(R0914)

🤖 Prompt for AI Agents
In fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py around
lines 11 to 45, the constructor modifies the input dictionary 'attn' in-place,
which can cause unexpected side effects if the same dictionary is reused
elsewhere. To fix this, create a shallow copy of the 'attn' dictionary at the
start of the constructor before making any modifications, ensuring the original
input remains unchanged.

Comment on lines 458 to 459
        if loss==0:
            breakpoint()

⚠️ Potential issue

Remove debug breakpoint from production code.

The breakpoint() call will halt execution and should never be in production code.

Apply this diff to remove the debug code:

-        if loss==0:
-            breakpoint()

If you need to debug zero loss issues, consider using proper logging instead:

if loss == 0:
    logger.warning("Computed loss is zero, which may indicate an issue with the model or data.")
🤖 Prompt for AI Agents
In fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py around lines 458
to 459, remove the debug breakpoint() call as it halts execution and should not
be in production code. Replace it with a logging statement that warns when loss
equals zero, such as using logger.warning to indicate a potential issue with the
model or data.


    Args:
        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
            Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, seqeunce_length, num_memories].

⚠️ Potential issue

Fix typo in docstring.

Apply this diff:

-            Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, seqeunce_length, num_memories].
+            Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, sequence_length, num_memories].
📝 Committable suggestion


Suggested change
Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, seqeunce_length, num_memories].
Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, sequence_length, num_memories].
🤖 Prompt for AI Agents
In fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py at line 504,
there is a typo in the docstring where "seqeunce_length" should be corrected to
"sequence_length". Update the docstring to fix this spelling mistake.


    config_class = MomGatedDeltaNetConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ['GatedDeltaNetBlock']

⚠️ Potential issue

Fix incorrect module name in _no_split_modules.

The module name should be MomGatedDeltaNetBlock to match the actual class name.

Apply this diff:

-    _no_split_modules = ['GatedDeltaNetBlock']
+    _no_split_modules = ['MomGatedDeltaNetBlock']
📝 Committable suggestion


Suggested change
_no_split_modules = ['GatedDeltaNetBlock']
_no_split_modules = ['MomGatedDeltaNetBlock']
🤖 Prompt for AI Agents
In fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py at line 165, the
module name in the _no_split_modules list is incorrect. Replace
'GatedDeltaNetBlock' with the correct class name 'MomGatedDeltaNetBlock' to
ensure it matches the actual class name used in the code.

Comment on lines 648 to 650
        conv_state_q = [None, None],
        conv_state_k = [None, None],
        conv_state_v = [None, None],

⚠️ Potential issue

Avoid mutable default arguments.

Using mutable objects as default arguments can lead to unexpected behavior.

Apply this diff:

-        conv_state_q = [None, None],
-        conv_state_k = [None, None],
-        conv_state_v = [None, None],
+        conv_state_q = None,
+        conv_state_k = None,
+        conv_state_v = None,

Then initialize them in the method:

if conv_state_q is None:
    conv_state_q = [None, None]
if conv_state_k is None:
    conv_state_k = [None, None]
if conv_state_v is None:
    conv_state_v = [None, None]
🧰 Tools
🪛 Ruff (0.11.9)

648-648: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


649-649: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


650-650: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)

🤖 Prompt for AI Agents
In fla/layers/mom_varlen.py around lines 648 to 650, the default arguments
conv_state_q, conv_state_k, and conv_state_v are mutable lists, which can cause
unexpected behavior. Change their default values to None in the function
signature, then inside the method initialize them with the provided code snippet
to assign [None, None] if they are None.

@zhiyuan1i
Collaborator

@WKX933 Thank you for your PR!!
Could you please add tests in tests/test_model.py and tests/test_generation.py?
And since I have introduced l2warp to all models, it would be good if you added the corresponding lines to the MoM models as well; see 82bb649.
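For reference, a minimal forward/backward test in the requested files could look roughly like the sketch below; the parametrization and hyperparameter values are placeholders, and the real tests should follow the conventions already used in tests/test_model.py:

import pytest
import torch

from fla.models import MomGatedDeltaNetConfig, MomGatedDeltaNetForCausalLM


@pytest.mark.parametrize("batch_size, seq_len", [(2, 16)])
def test_mom_forward_backward(batch_size, seq_len):
    config = MomGatedDeltaNetConfig(
        hidden_size=128, num_hidden_layers=2, num_heads=2, head_dim=64,
        num_memories=4, topk=2, vocab_size=256,
    )
    model = MomGatedDeltaNetForCausalLM(config)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    outputs = model(input_ids=input_ids, labels=input_ids)
    assert outputs.logits.shape == (batch_size, seq_len, config.vocab_size)
    outputs.loss.backward()  # exercises routing, the auxiliary balance loss, and gradients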


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
fla/layers/__init__.py (1)

29-29: Ambiguous class resolution still unresolved

There are two distinct MomGatedDeltaNet implementations (mom.py, mom_varlen.py), but only the fixed-length one is re-exported here. Call-sites requiring the var-len behaviour will silently receive the wrong class.
Consider aliasing the variants explicitly (e.g. MomGatedDeltaNetFixedLen, MomGatedDeltaNetVarLen) and exporting both.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2356954 and 34305b2.

📒 Files selected for processing (3)
  • fla/__init__.py (4 hunks)
  • fla/layers/__init__.py (2 hunks)
  • fla/models/__init__.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • fla/models/__init__.py
  • fla/__init__.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
fla/layers/__init__.py (2)
fla/layers/mom.py (1)
  • MomGatedDeltaNet (192-613)
fla/layers/mom_varlen.py (1)
  • MomGatedDeltaNet (309-724)
🪛 Ruff (0.11.9)
fla/layers/__init__.py

29-29: .mom.MomGatedDeltaNet imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)

🪛 Pylint (3.3.7)
fla/layers/__init__.py

[error] 49-50: Undefined variable name 'MomGatedDeltaNetMultiScaleRetention' in all

(E0603)

🪛 GitHub Actions: lint
fla/layers/__init__.py

[error] 1-1: Imports are not sorted properly; fixed by isort hook.


[error] 22-23: flake8: F401 '.mom.MomGatedDeltaNet' and '.multiscale_retention.MultiScaleRetention' imported but unused.

@yzhangcs yzhangcs linked an issue Jun 20, 2025 that may be closed by this pull request

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 11

♻️ Duplicate comments (9)
fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py (4)

39-47: Fix incorrect return type annotation.

The __init__ method should return None, not MomGatedDeltaNetMLP.


165-165: Fix incorrect module name in _no_split_modules.

The module name should be MomGatedDeltaNetBlock to match the actual class name.


243-243: Fix incorrect class name in warning message and add stacklevel.

The warning message should reference the correct class name and include proper stacklevel.


505-505: Fix typo in docstring.

The docstring contains a spelling error.

fla/layers/mom_varlen.py (5)

28-29: Remove duplicate import of pad_input.

The pad_input function is imported twice on consecutive lines.


375-375: Fix incorrect return type annotation.

The __init__ method should return None, not MomGatedDeltaNet.


403-403: Fix typo in assertion error message.

The assertion message contains a spelling error.


619-622: Fix critical bug: undefined variable 'o' when shared_mem is True.

The variable o is used before it's defined when self.shared_mem is True.


642-650: Fix mutable default arguments and formatting.

The method signature uses mutable default arguments which can cause unexpected behavior.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e043a67 and 63e1885.

📒 Files selected for processing (3)
  • fla/layers/__init__.py (2 hunks)
  • fla/layers/mom_varlen.py (1 hunks)
  • fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • fla/layers/__init__.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
fla/layers/mom_varlen.py (8)
fla/modules/fused_norm_gate.py (1)
  • FusedRMSNormSwishGate (1082-1098)
fla/modules/convolution.py (1)
  • ShortConvolution (544-741)
fla/modules/l2norm.py (1)
  • l2norm (253-258)
fla/ops/gated_delta_rule/chunk.py (1)
  • chunk_gated_delta_rule (225-342)
fla/ops/gated_delta_rule/fused_recurrent.py (1)
  • fused_recurrent_gated_delta_rule (212-314)
fla/models/utils.py (2)
  • Cache (11-148)
  • update (43-120)
fla/layers/utils.py (2)
  • pad_input (176-197)
  • unpad_input (101-173)
fla/layers/mom.py (6)
  • elu_p1 (25-26)
  • sum_norm (29-30)
  • transform (34-126)
  • reconstruct (129-189)
  • MomGatedDeltaNet (192-613)
  • forward (367-528)
🪛 Ruff (0.11.9)
fla/layers/mom_varlen.py

29-29: Redefinition of unused pad_input from line 28

Remove definition: pad_input

(F811)


505-505: Local variable batchsize is assigned to but never used

Remove assignment to unused variable batchsize

(F841)


543-543: Use kwargs.get('seq_idx') instead of kwargs.get('seq_idx', None)

Replace kwargs.get('seq_idx', None) with kwargs.get('seq_idx')

(SIM910)


579-579: Local variable offsets is assigned to but never used

Remove assignment to unused variable offsets

(F841)


579-579: Use kwargs.get('offsets') instead of kwargs.get('offsets', None)

Replace kwargs.get('offsets', None) with kwargs.get('offsets')

(SIM910)


648-648: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


649-649: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


650-650: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


666-666: Use kwargs.get('seq_idx') instead of kwargs.get('seq_idx', None)

Replace kwargs.get('seq_idx', None) with kwargs.get('seq_idx')

(SIM910)


695-695: Use kwargs.get('offsets') instead of kwargs.get('offsets', None)

Replace kwargs.get('offsets', None) with kwargs.get('offsets')

(SIM910)

fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py

243-243: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


359-365: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)

🪛 GitHub Actions: lint
fla/layers/mom_varlen.py

[error] 30-33: flake8: E302 expected 2 blank lines, found 1


[error] 45-45: flake8: E501 line too long (131 > 127 characters)


[error] 61-61: flake8: E303 too many blank lines (3)


[error] 71-73: flake8: E501 line too long (128-135 > 127 characters)


[error] 98-100: flake8: E501 line too long (170-178 > 127 characters)


[error] 141-146: flake8: E303 too many blank lines (3) and E302 expected 2 blank lines, found 1 and E501 line too long (128 > 127 characters)


[error] 212-212: flake8: E231 missing whitespace after ','


[error] 221-226: flake8: E501 line too long (185-213 > 127 characters)


[error] 239-243: flake8: E231 missing whitespace after ',' and E302 expected 2 blank lines, found 1 and E501 line too long (177 > 127 characters)


[error] 250-291: flake8: E501 line too long (128-137 > 127 characters), E203 whitespace before ',', E231 missing whitespace after ','


[error] 410-412: flake8: E222 multiple spaces after operator and E501 line too long (131-133 > 127 characters)


[error] 413-413: flake8: E501 line too long (132 > 127 characters)


[error] 502-502: flake8: F841 local variable 'batchsize' is assigned to but never used and E231 missing whitespace after ','


[error] 512-521: flake8: E501 line too long (162-223 > 127 characters)


[error] 533-540: flake8: E501 line too long (147-169 > 127 characters) and E225 missing whitespace around operator


[error] 543-553: flake8: E128 continuation line under-indented for visual indent and E231 missing whitespace after ','


[error] 565-568: flake8: E231 missing whitespace after ',' and E225 missing whitespace around operator


[error] 575-576: flake8: E501 line too long (129 > 127 characters) and F841 local variable 'offsets' is assigned to but never used


[error] 593-594: flake8: E231 missing whitespace after ',' (multiple occurrences)


[error] 596-617: flake8: E501 line too long (144-183 > 127 characters) and E231 missing whitespace after ','


[error] 639-647: flake8: E303 too many blank lines (2) and E251 unexpected spaces around keyword/parameter equals


[error] 663-675: flake8: E225 missing whitespace around operator, E128 continuation line under-indented for visual indent, and E231 missing whitespace after ','


[error] pre-commit hooks trailing-whitespace, end-of-file-fixer, and isort modified files indicating formatting issues.

fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py

[error] 202-206: flake8: E302 expected 2 blank lines, found 1


[error] 214-214: flake8: E501 line too long (128 > 127 characters)


[error] 312-317: flake8: E302 expected 2 blank lines, found 1


[error] 533-533: flake8: E303 too many blank lines (2)


[error] pre-commit hook isort modified file indicating import sorting issues.

Comment on lines 359 to 367
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

🛠️ Refactor suggestion

Improve exception handling with proper chaining.

The current exception handling doesn't follow best practices for exception chaining.

Apply this improvement:

         except AttributeError as exception:
             if 'past_key_values' in str(exception):
                 raise AttributeError(
                     f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                     f"which is not supported for {self.__class__.__name__}. "
                     f"Try another generation strategy instead. "
                     f"For the available generation strategies, check this doc: "
                     f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-                )
+                ) from exception
             else:
-                raise exception
+                raise
📝 Committable suggestion


Suggested change
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                ) from exception
            else:
                raise
🧰 Tools
🪛 Ruff (0.11.9)

359-365: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)

🪛 GitHub Actions: lint

[error] pre-commit hook isort modified file indicating import sorting issues.

🤖 Prompt for AI Agents
In fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py around lines 359
to 367, the exception handling raises a new AttributeError without chaining the
original exception, which loses the context of the initial error. Fix this by
using "raise ... from exception" syntax to properly chain the new AttributeError
to the original exception, preserving the traceback and error context.

Comment on lines 202 to 206
        with torch.no_grad():
            p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)

@dataclass
class MomGatedDeltaNetOutputWithPast(BaseModelOutputWithPast):

⚠️ Potential issue

Fix formatting issues to comply with pipeline requirements.

The pipeline failures indicate missing blank lines and line length violations that need to be addressed.

Apply these formatting fixes:

+

 @dataclass
 class MomGatedDeltaNetOutputWithPast(BaseModelOutputWithPast):
     router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None

+
 class MomGatedDeltaNetModel(MomGatedDeltaNetPreTrainedModel):

Also fix the long line 214 by breaking it appropriately:

-        self.layers = nn.ModuleList([MomGatedDeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
+        self.layers = nn.ModuleList([
+            MomGatedDeltaNetBlock(config, layer_idx) 
+            for layer_idx in range(config.num_hidden_layers)
+        ])
📝 Committable suggestion


Suggested change
        with torch.no_grad():
            p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
@dataclass
class MomGatedDeltaNetOutputWithPast(BaseModelOutputWithPast):
        with torch.no_grad():
            p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)


@dataclass
class MomGatedDeltaNetOutputWithPast(BaseModelOutputWithPast):
    router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None


class MomGatedDeltaNetModel(MomGatedDeltaNetPreTrainedModel):
    # …

Suggested change
        with torch.no_grad():
            p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
@dataclass
class MomGatedDeltaNetOutputWithPast(BaseModelOutputWithPast):
        # break the long comprehension onto multiple lines for readability/line-length
        self.layers = nn.ModuleList([
            MomGatedDeltaNetBlock(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])
🧰 Tools
🪛 GitHub Actions: lint

[error] 202-206: flake8: E302 expected 2 blank lines, found 1


[error] pre-commit hook isort modified file indicating import sorting issues.

🤖 Prompt for AI Agents
In fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py around lines 202
to 206, fix formatting issues by adding the required blank lines before the
@dataclass decorator and ensuring proper separation between code blocks.
Additionally, locate line 214 and break the long line into shorter lines to
comply with line length limits, using appropriate indentation for readability.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (4)
fla/layers/mom.py (4)

249-249: Fix incorrect return type annotation.


277-277: Fix typo in assertion error message.


497-500: Fix critical bug: undefined variable 'o' when shared_mem is True.


526-528: Avoid mutable default arguments.

🧹 Nitpick comments (7)
fla/layers/mom.py (7)

393-393: Remove trailing whitespace.

Apply this diff:

-        if self.use_gate:
-            o_g = self.g_proj(hidden_states)
-        
+        if self.use_gate:
+            o_g = self.g_proj(hidden_states)
+

418-429: Fix indentation for method arguments.

The continuation lines should be properly indented for better readability.

Apply this diff:

-            q, conv_state_q[0] = self.q_conv1d(x=q,
-                                            mask=conv_mask,
-                                            cache=conv_state_q[0],
-                                            output_final_state=use_cache,seq_idx=seq_idx)
-            k, conv_state_k[0] = self.k_conv1d(x=k,
-                                            mask=conv_mask,
-                                            cache=conv_state_k[0],
-                                            output_final_state=use_cache,seq_idx=seq_idx)
-            v, conv_state_v[0] = self.v_conv1d(x=v,
-                                            mask=conv_mask,
-                                            cache=conv_state_v[0],
-                                            output_final_state=use_cache,seq_idx=seq_idx)
+            q, conv_state_q[0] = self.q_conv1d(x=q,
+                                                mask=conv_mask,
+                                                cache=conv_state_q[0],
+                                                output_final_state=use_cache,
+                                                seq_idx=seq_idx)
+            k, conv_state_k[0] = self.k_conv1d(x=k,
+                                                mask=conv_mask,
+                                                cache=conv_state_k[0],
+                                                output_final_state=use_cache,
+                                                seq_idx=seq_idx)
+            v, conv_state_v[0] = self.v_conv1d(x=v,
+                                                mask=conv_mask,
+                                                cache=conv_state_v[0],
+                                                output_final_state=use_cache,
+                                                seq_idx=seq_idx)

462-463: Add missing whitespace after comma.

Apply this diff:

-                o_e = o_e[:,-max_len:,:,:].to(dtype=q[e].dtype)
+                o_e = o_e[:, -max_len:, :, :].to(dtype=q[e].dtype)

473-474: Add missing whitespace after comma.

Apply this diff:

-                if not hidden_states[e, 0].any() and hidden_states.shape[1] == 1:
-                    o_list[e] = torch.zeros_like(v[e,:,-max_len:,:,:])
+                if not hidden_states[e, 0].any() and hidden_states.shape[1] == 1:
+                    o_list[e] = torch.zeros_like(v[e, :, -max_len:, :, :])

486-486: Add missing whitespace after comma.

Apply this diff:

-                o_e = o_e[:,-max_len:,:,:]
+                o_e = o_e[:, -max_len:, :, :]

545-556: Fix indentation for method arguments.

The continuation lines should be properly indented for better readability.

Apply this diff:

-            q, conv_state_q[1] = self.q_conv1d(x=self.q_proj(hidden_states),
-                                            mask=conv_mask,
-                                            cache=conv_state_q[1],
-                                            output_final_state=use_cache,seq_idx=seq_idx)
-            k, conv_state_k[1] = self.k_conv1d(x=self.shared_k(hidden_states),
-                                            mask=conv_mask,
-                                            cache=conv_state_k[1],
-                                            output_final_state=use_cache,seq_idx=seq_idx)
-            v, conv_state_v[1] = self.v_conv1d(x=self.shared_v(hidden_states),
-                                            mask=conv_mask,
-                                            cache=conv_state_v[1],
-                                            output_final_state=use_cache,seq_idx=seq_idx)
+            q, conv_state_q[1] = self.q_conv1d(x=self.q_proj(hidden_states),
+                                                mask=conv_mask,
+                                                cache=conv_state_q[1],
+                                                output_final_state=use_cache,
+                                                seq_idx=seq_idx)
+            k, conv_state_k[1] = self.k_conv1d(x=self.shared_k(hidden_states),
+                                                mask=conv_mask,
+                                                cache=conv_state_k[1],
+                                                output_final_state=use_cache,
+                                                seq_idx=seq_idx)
+            v, conv_state_v[1] = self.v_conv1d(x=self.shared_v(hidden_states),
+                                                mask=conv_mask,
+                                                cache=conv_state_v[1],
+                                                output_final_state=use_cache,
+                                                seq_idx=seq_idx)

183-518: Add tests and l2warp support as requested in PR comments.

Based on the PR review comments:

  1. Tests should be added for this new module in tests/test_model.py and tests/test_generation.py
  2. Consider adding l2warp support similar to other models in the codebase

Would you like me to help generate the test cases for the MomGatedDeltaNet module?

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cefb647 and dc7cf66.

📒 Files selected for processing (1)
  • fla/layers/mom.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.11.9)
fla/layers/mom.py

416-416: Use kwargs.get('seq_idx') instead of kwargs.get('seq_idx', None)

Replace kwargs.get('seq_idx', None) with kwargs.get('seq_idx')

(SIM910)


447-447: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)


526-526: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


527-527: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


528-528: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


544-544: Use kwargs.get('seq_idx') instead of kwargs.get('seq_idx', None)

Replace kwargs.get('seq_idx', None) with kwargs.get('seq_idx')

(SIM910)


573-573: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)

🪛 Flake8 (7.2.0)
fla/layers/mom.py

[error] 419-419: continuation line under-indented for visual indent

(E128)


[error] 420-420: continuation line under-indented for visual indent

(E128)


[error] 421-421: continuation line under-indented for visual indent

(E128)


[error] 423-423: continuation line under-indented for visual indent

(E128)


[error] 424-424: continuation line under-indented for visual indent

(E128)


[error] 425-425: continuation line under-indented for visual indent

(E128)


[error] 427-427: continuation line under-indented for visual indent

(E128)


[error] 428-428: continuation line under-indented for visual indent

(E128)


[error] 429-429: continuation line under-indented for visual indent

(E128)


[error] 546-546: continuation line under-indented for visual indent

(E128)


[error] 547-547: continuation line under-indented for visual indent

(E128)


[error] 548-548: continuation line under-indented for visual indent

(E128)


[error] 550-550: continuation line under-indented for visual indent

(E128)


[error] 551-551: continuation line under-indented for visual indent

(E128)


[error] 552-552: continuation line under-indented for visual indent

(E128)


[error] 554-554: continuation line under-indented for visual indent

(E128)


[error] 555-555: continuation line under-indented for visual indent

(E128)


[error] 556-556: continuation line under-indented for visual indent

(E128)

🪛 GitHub Actions: lint
fla/layers/mom.py

[error] 24-24: flake8: line too long (128 > 127 characters) (E501)


[error] 90-90: flake8: missing whitespace after ',' (E231)


[error] 99-99: flake8: line too long (213 > 127 characters) (E501)


[error] 104-104: flake8: line too long (185 > 127 characters) (E501)


[error] 119-119: flake8: expected 2 blank lines, found 1 (E302)


[error] 119-119: flake8: line too long (177 > 127 characters) (E501)


[error] 126-126: flake8: line too long (128 > 127 characters) (E501)


[error] 159-159: flake8: line too long (135 > 127 characters) (E501)


[error] 161-161: flake8: line too long (137 > 127 characters) (E501)


[error] 167-167: flake8: whitespace before ',' (E203) and missing whitespace after ',' (E231)


[error] 286-288: flake8: multiple spaces after operator (E222) and line too long (131-133 > 127 characters) (E501)


[error] 287-289: flake8: line too long (132 > 127 characters) (E501)


[error] 387-387: flake8: line too long (223 > 127 characters) (E501)


[error] 396-396: flake8: line too long (155 > 127 characters) (E501)


[error] 408-408: flake8: line too long (169 > 127 characters) (E501)


[error] 414-415: flake8: line too long (147 > 127 characters) (E501) and missing whitespace around operator (E225)


[error] 418-428: flake8: continuation line under-indented for visual indent (E128) and missing whitespace after ',' (E231)


[error] 445-445: flake8: line too long (145 > 127 characters) (E501)


[error] 462-462: flake8: missing whitespace after ',' (E231)


[error] 466-466: flake8: line too long (183 > 127 characters) (E501)


[error] 473-473: flake8: missing whitespace after ',' (E231)


[error] 485-485: flake8: missing whitespace after ',' (E231)


[error] 494-497: flake8: line too long (144-183 > 127 characters) (E501)


[error] 519-519: flake8: too many blank lines (2) (E303)


[error] 523-527: flake8: unexpected spaces around keyword/parameter equals (E251)


[error] 543-555: flake8: missing whitespace around operator (E225), continuation line under-indented (E128), and missing whitespace after ',' (E231)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-models
🔇 Additional comments (1)
fla/layers/mom.py (1)

415-415: Add missing whitespace before 'if'.

Apply this diff:

-            conv_mask = attention_mask[:, -hidden_states.shape[2]:].repeat_interleave(self.num_memories, 0)if attention_mask is not None else None
+            conv_mask = attention_mask[:, -hidden_states.shape[2]:].repeat_interleave(self.num_memories, 0) if attention_mask is not None else None

Likely an incorrect or invalid review comment.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 9

♻️ Duplicate comments (7)
fla/layers/mom_varlen.py (7)

27-29: Remove duplicate import of pad_input.

The pad_input function is imported twice on consecutive lines.

Apply this diff:

 if is_flash_attn_2_available():
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input 
-    from flash_attn.bert_padding import pad_input

356-375: Fix incorrect return type annotation.

The __init__ method should return None, not MomGatedDeltaNet.

Apply this diff:

-    ) -> MomGatedDeltaNet:
+    ) -> None:

403-403: Fix typo in assertion error message.

Apply this diff:

-        assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
+        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."

586-586: Remove unused variable offsets.

The variable is assigned but never used.

Apply this diff:

         recurrent_state = last_state['recurrent_state'] if last_state is not None else [None for _ in range(1 + self.shared_mem)]
-        offsets = kwargs.get('offsets', None)
         # Note: In the updated version of FLA, "offset" has been renamed to "cu_seqlens".

626-629: Fix critical bug: undefined variable 'o' when shared_mem is True.

The variable o is used before being defined when self.shared_mem is True.

You need to ensure o is defined before this block. The logic suggests that o should already be computed from the reconstruction step above. If shared_mem is True but neither 'chunk' nor 'fused_recurrent' mode was executed, this will fail.

Consider adding a check or initializing o appropriately:

         if self.shared_mem:
+            if 'o' not in locals():
+                raise RuntimeError("Output 'o' not computed before shared memory processing")
             shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, use_cache, conv_state_q, conv_state_k, conv_state_v)
             o += shared_o

655-657: Avoid mutable default arguments.

Using mutable objects as default arguments can lead to unexpected behavior.

Apply this diff:

-        conv_state_q = [None, None],
-        conv_state_k = [None, None],
-        conv_state_v = [None, None],
+        conv_state_q = None,
+        conv_state_k = None,
+        conv_state_v = None,

Then initialize them in the method:

if conv_state_q is None:
    conv_state_q = [None, None]
if conv_state_k is None:
    conv_state_k = [None, None]
if conv_state_v is None:
    conv_state_v = [None, None]
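
For readers unfamiliar with this Python gotcha: the default object is created once, when the function is defined, and is then shared by every call that omits the argument. A minimal standalone illustration (not code from this PR):

def append_item(item, bucket=[]):
    # `bucket` is built once at definition time and reused across calls
    bucket.append(item)
    return bucket

print(append_item(1))  # [1]
print(append_item(2))  # [1, 2]  <- state leaked from the previous call

def append_item_fixed(item, bucket=None):
    # the None-default pattern recommended above creates a fresh list per call
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket

print(append_item_fixed(1))  # [1]
print(append_item_fixed(2))  # [2]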

508-508: Remove unused variable batchsize.

The variable is assigned but never used.

Apply this diff:

-        batchsize,q_len = hidden_states.shape[0],hidden_states.shape[1]
+        batch_size, q_len = hidden_states.shape[0], hidden_states.shape[1]

Note: make sure all references to the batch_size variable defined on line 524 are updated to use this one instead.

🧹 Nitpick comments (3)
fla/models/mom/modeling_mom.py (1)

267-269: Use standard logger.warning method.

The logger.warning_once method may not be available in all logging configurations.

Apply this diff:

         if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+            logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
             use_cache = False
fla/layers/mom_varlen.py (2)

215-218: Fix formatting issues in transform function.

Missing whitespace after comma and inconsistent spacing.

Apply this diff:

-        memory_batch_offset = offset.transpose(0,1)
+        memory_batch_offset = offset.transpose(0, 1)
         batch_offset = torch.arange(0, b*s, s, device=offset.device)

411-422: Consider parameter organization for shared memory projections.

The initialization of projection layers has inconsistent indentation and could be better organized.

Apply this diff for better readability:

         else:
-            self.k_proj =  nn.ModuleList([nn.Linear(self.hidden_size, self.key_dim, bias=False) for _ in range(self.num_memories)])
-            self.v_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.value_dim, bias=False) for _ in range(self.num_memories)])
-            self.b_proj =  nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) for _ in range(self.num_memories)])
-            self.a_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) for _ in range(self.num_memories)])
+            self.k_proj = nn.ModuleList([
+                nn.Linear(self.hidden_size, self.key_dim, bias=False) 
+                for _ in range(self.num_memories)
+            ])
+            self.v_proj = nn.ModuleList([
+                nn.Linear(self.hidden_size, self.value_dim, bias=False) 
+                for _ in range(self.num_memories)
+            ])
+            self.b_proj = nn.ModuleList([
+                nn.Linear(self.hidden_size, self.num_heads, bias=False) 
+                for _ in range(self.num_memories)
+            ])
+            self.a_proj = nn.ModuleList([
+                nn.Linear(self.hidden_size, self.num_heads, bias=False) 
+                for _ in range(self.num_memories)
+            ])
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dc7cf66 and e274e44.

📒 Files selected for processing (6)
  • fla/__init__.py (3 hunks)
  • fla/layers/mom_varlen.py (1 hunks)
  • fla/models/__init__.py (2 hunks)
  • fla/models/mom/__init__.py (1 hunks)
  • fla/models/mom/configuration_mom.py (1 hunks)
  • fla/models/mom/modeling_mom.py (1 hunks)
✅ Files skipped from review due to trivial changes (2)
  • fla/models/__init__.py
  • fla/models/mom/__init__.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • fla/__init__.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
fla/layers/mom_varlen.py (8)
fla/modules/fused_norm_gate.py (1)
  • FusedRMSNormSwishGate (1078-1094)
fla/modules/convolution.py (1)
  • ShortConvolution (548-745)
fla/modules/l2norm.py (1)
  • l2norm (262-267)
fla/ops/gated_delta_rule/chunk.py (1)
  • chunk_gated_delta_rule (221-340)
fla/ops/gated_delta_rule/fused_recurrent.py (1)
  • fused_recurrent_gated_delta_rule (240-351)
fla/models/utils.py (2)
  • Cache (11-150)
  • update (43-122)
fla/layers/utils.py (2)
  • pad_input (176-197)
  • unpad_input (101-173)
fla/layers/mom.py (4)
  • transform (25-117)
  • reconstruct (120-180)
  • MomGatedDeltaNet (183-600)
  • forward (358-517)
🪛 Ruff (0.11.9)
fla/models/mom/configuration_mom.py

78-78: Test for membership should be not in

Convert to not in

(E713)

fla/models/mom/modeling_mom.py

247-247: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


364-370: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)

fla/layers/mom_varlen.py

28-28: Redefinition of unused pad_input from line 27

Remove definition: pad_input

(F811)


508-508: Local variable batchsize is assigned to but never used

Remove assignment to unused variable batchsize

(F841)


586-586: Local variable offsets is assigned to but never used

Remove assignment to unused variable offsets

(F841)


586-586: Use kwargs.get('offsets') instead of kwargs.get('offsets', None)

Replace kwargs.get('offsets', None) with kwargs.get('offsets')

(SIM910)


655-655: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


656-656: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


657-657: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


706-706: Use kwargs.get('offsets') instead of kwargs.get('offsets', None)

Replace kwargs.get('offsets', None) with kwargs.get('offsets')

(SIM910)

🪛 Flake8 (7.2.0)
fla/models/mom/configuration_mom.py

[error] 78-78: test for membership should be 'not in'

(E713)

fla/layers/mom_varlen.py

[error] 28-28: redefinition of unused 'pad_input' from line 27

(F811)


[error] 508-508: local variable 'batchsize' is assigned to but never used

(F841)


[error] 586-586: local variable 'offsets' is assigned to but never used

(F841)

🪛 GitHub Actions: lint
fla/models/mom/configuration_mom.py

[error] 78-78: flake8: test for membership should be 'not in' (E713)

fla/models/mom/modeling_mom.py

[error] 210-210: flake8: expected 2 blank lines, found 1 (E302)


[error] 322-322: flake8: expected 2 blank lines, found 1 (E302)

fla/layers/mom_varlen.py

[error] 47-47: flake8: line too long (131 > 127 characters) (E501)


[error] 72-74: flake8: line too long (128-135 > 127 characters) (E501)


[error] 99-101: flake8: line too long (170-178 > 127 characters) (E501)


[error] 142-142: flake8: too many blank lines (3) (E303)


[error] 147-147: flake8: expected 2 blank lines, found 1 (E302) and line too long (128 > 127 characters) (E501)


[error] 213-213: flake8: missing whitespace after ',' (E231)


[error] 222-227: flake8: line too long (185-213 > 127 characters) (E501)


[error] 244-244: flake8: expected 2 blank lines, found 1 (E302) and line too long (177 > 127 characters) (E501)


[error] 251-251: flake8: line too long (128 > 127 characters) (E501)


[error] 284-286: flake8: line too long (135-137 > 127 characters) (E501)


[error] 292-292: flake8: whitespace before ',' (E203) and missing whitespace after ',' (E231)


[error] 411-413: flake8: multiple spaces after operator (E222) and line too long (131-133 > 127 characters) (E501)


[error] 414-414: flake8: line too long (132 > 127 characters) (E501)


[error] 506-506: flake8: local variable 'batchsize' is assigned to but never used (F841) and missing whitespace after ',' (E231)


[error] 516-516: flake8: line too long (223 > 127 characters) (E501)


[error] 525-537: flake8: line too long (129-169 > 127 characters) (E501) and missing whitespace around operator (E225) on line 576


[error] 584-584: flake8: local variable 'offsets' is assigned to but never used (F841)


[error] 601-602: flake8: missing whitespace after ',' (E231) multiple times


[error] 604-604: flake8: line too long (183 > 127 characters) (E501)


[error] 621-625: flake8: missing whitespace after ',' (E231) multiple times and line too long (144-183 > 127 characters) (E501)


[error] 647-647: flake8: too many blank lines (2) (E303)


[error] 651-655: flake8: unexpected spaces around keyword/parameter equals (E251) multiple times

🔇 Additional comments (2)
fla/models/mom/modeling_mom.py (1)

271-301: Keep all_attns – it’s required for returning attentions

The all_attns variable is neither dead nor local-only: it’s populated when output_attentions=True and included in the method’s return value. Removing it here would break the interface and diverge from the shared pattern in other model implementations.

• In fla/models/mom/modeling_mom.py:
– Line 300: all_attns += (attentions,)
– Line 310: return tuple(i for i in [..., all_attns] if i is not None)
– Line 315: returned as attentions=all_attns

• This same pattern appears across all other modeling_* classes (e.g., Transformer, COMBA, DeltaNet, etc.), ensuring a consistent API.

Please ignore the suggestion to remove all_attns; it’s necessary for users requesting attentions.

Likely an incorrect or invalid review comment.

fla/layers/mom_varlen.py (1)

1-735: Adjust test file locations to match repository structure

The repository doesn’t have tests/test_model.py or tests/test_generation.py. Model tests are under tests/models/ as test_modeling_<model>.py, and there’s no generation test directory at the root. To add coverage for the new MoM modules:

• Create tests/models/test_modeling_mom.py and add unit tests for MomGatedDeltaNet (e.g. forward pass, gating, capacity handling).
• If you need generation‐style tests for MoM (e.g. end‐to‐end text generation), add a new file under a tests/generation/ directory (e.g. tests/generation/test_generation_mom.py), mirroring your existing generation tests setup.

Likely an incorrect or invalid review comment.
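
A rough sketch of what such a forward/backward smoke test in tests/models/test_modeling_mom.py could look like — the config field names and the CUDA/bfloat16 setup are assumptions based on other fla model tests, not the actual test file:

import pytest
import torch

from fla.models import MomConfig, MomForCausalLM


@pytest.mark.parametrize("batch_size,seq_len", [(2, 64)])
def test_mom_forward_backward(batch_size, seq_len):
    # tiny hyperparameters keep the test cheap; exact field names are assumed
    config = MomConfig(hidden_size=64, num_hidden_layers=2, vocab_size=1000)
    # fla's Triton kernels generally expect a CUDA device
    model = MomForCausalLM(config).to(device='cuda', dtype=torch.bfloat16)

    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device='cuda')
    outputs = model(input_ids=input_ids, labels=input_ids)

    assert outputs.logits.shape == (batch_size, seq_len, config.vocab_size)
    outputs.loss.backward()  # routing and aux-loss paths should be differentiable end to end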

expert_mask = torch.nn.functional.one_hot(selected_experts, num_memories)

# For a given token, determine if it was routed to a given expert.
expert_mask = torch.max(expert_mask, axis=-2).values

⚠️ Potential issue

Use 'dim' instead of 'axis' for torch.max.

PyTorch uses dim parameter, not axis.

Apply this diff:

-        expert_mask = torch.max(expert_mask, axis=-2).values
+        expert_mask = torch.max(expert_mask, dim=-2).values
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
expert_mask = torch.max(expert_mask, axis=-2).values
expert_mask = torch.max(expert_mask, dim=-2).values
🤖 Prompt for AI Agents
In fla/models/mom/modeling_mom.py at line 550, replace the use of 'axis' with
'dim' in the torch.max function call. Change torch.max(expert_mask,
axis=-2).values to torch.max(expert_mask, dim=-2).values to comply with
PyTorch's API.
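
For context, the snippet under review is one step of a Switch-Transformer-style balancing loss: per memory, multiply the fraction of tokens routed to it by the mean router probability it receives, then sum and scale by the number of memories. A minimal standalone sketch of that idea (not the PR's load_balancing_loss_func, and omitting the layer-wise option and attention masking):

import torch
import torch.nn.functional as F


def simple_load_balancing_loss(router_logits, num_memories, top_k):
    # router_logits: (num_tokens, num_memories)
    routing_probs = F.softmax(router_logits, dim=-1)
    _, selected_experts = torch.topk(routing_probs, top_k, dim=-1)

    # one-hot over the top-k choices: (num_tokens, top_k, num_memories)
    expert_mask = F.one_hot(selected_experts, num_memories)
    # collapse the top-k axis: 1 if the token was routed to that memory at all
    expert_mask = torch.max(expert_mask, dim=-2).values

    tokens_per_expert = expert_mask.float().mean(dim=0)   # fraction of tokens per memory
    router_prob_per_expert = routing_probs.mean(dim=0)    # mean router probability per memory
    return num_memories * torch.sum(tokens_per_expert * router_prob_per_expert)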

Comment on lines 246 to 248
if output_attentions:
warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.")
output_attentions = False

⚠️ Potential issue

Fix incorrect model name in warning and add stacklevel.

The warning refers to GatedDeltaNetModel but should refer to MomModel. Also, add stacklevel parameter to the warning.

Apply this diff:

         if output_attentions:
-            warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.")
+            warnings.warn("`MomModel` does not support `output_attentions` now, setting it to `False`.", stacklevel=2)
             output_attentions = False
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if output_attentions:
warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.")
output_attentions = False
if output_attentions:
warnings.warn(
"`MomModel` does not support `output_attentions` now, setting it to `False`.",
stacklevel=2,
)
output_attentions = False
🧰 Tools
🪛 Ruff (0.11.9)

247-247: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)

🤖 Prompt for AI Agents
In fla/models/mom/modeling_mom.py around lines 246 to 248, the warning message
incorrectly refers to `GatedDeltaNetModel` instead of `MomModel` and lacks the
`stacklevel` parameter. Update the warning string to mention `MomModel` and add
the `stacklevel=2` argument to the warnings.warn call to correctly point to the
caller in the stack trace.
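
For reference, stacklevel=2 only changes where the warning is attributed: it points at the caller of the warning function instead of at the warnings.warn line itself. A standalone illustration:

import warnings


def forward():
    # with the default stacklevel=1 the warning is attributed to this line;
    # stacklevel=2 attributes it to whoever called forward()
    warnings.warn("`output_attentions` is not supported, setting it to `False`.", stacklevel=2)


def training_step():
    forward()  # the emitted UserWarning now reports this call site


training_step()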

Comment on lines 470 to 476
aux_loss = load_balancing_loss_func(
valid_router_logits,
self.num_memories,
self.topk,
use_layer_wise_balance=self.config.use_layer_wise_balance, # ✨
)
aux_loss *= self.aux_loss_scale

💡 Verification agent

❓ Verification inconclusive

Ensure l2wrap support is included.

According to the PR objectives, l2wrap functionality should be included in the new MoM models.


🏁 Script executed:

#!/bin/bash
# Description: Check if l2wrap is implemented in other models and how it should be integrated

# Search for l2wrap implementation in other models
rg -A 5 "l2wrap" --glob "fla/models/**/*.py"

# Check the commit that introduced l2wrap to all models
git log --grep="l2wrap" --oneline | head -5

Length of output: 193


🏁 Script executed:

#!/bin/bash
set -euxo pipefail

# Search for any l2wrap references in model implementations
rg -A5 "l2wrap" --glob 'fla/models/**/*.py'

# Show the last few commits mentioning l2wrap to understand its introduction
git log --grep="l2wrap" --oneline -n 5

Length of output: 64


🏁 Script executed:

#!/usr/bin/env bash
set -euxo pipefail

# Ensure the repo is cloned so files can be searched
# Search for any l2wrap references under fla/models
rg -A5 "l2wrap" fla/models

# Look for commits mentioning l2wrap to see how and when it was introduced
git log --grep="l2wrap" --oneline -n 5

Length of output: 47


Missing l2wrap integration in MoM model
Our searches didn’t uncover any existing l2wrap references in the repo, so the new MoM implementation currently omits this required component. Please add l2wrap support alongside the load-balancing auxiliary loss.

• File to update: fla/models/mom/modeling_mom.py (at lines 470–476)
• Suggested change (example):

             aux_loss = load_balancing_loss_func(
                 valid_router_logits,
                 self.num_memories,
                 self.topk,
-                use_layer_wise_balance=self.config.use_layer_wise_balance,
+                use_layer_wise_balance=self.config.use_layer_wise_balance,
             )
-            aux_loss *= self.aux_loss_scale
+            # integrate l2wrap auxiliary loss
+            wrap_loss = l2wrap_loss_func(
+                hidden_states,               # or appropriate tensor
+                self.config.l2wrap_scale,    # configurable scale
+            )
+            aux_loss = aux_loss * self.aux_loss_scale + wrap_loss

• Ensure you import l2wrap_loss_func and expose any new config.l2wrap_scale setting in the model’s config.
• Add or update unit tests to cover the l2wrap loss path.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
aux_loss = load_balancing_loss_func(
valid_router_logits,
self.num_memories,
self.topk,
use_layer_wise_balance=self.config.use_layer_wise_balance, # ✨
)
aux_loss *= self.aux_loss_scale
aux_loss = load_balancing_loss_func(
valid_router_logits,
self.num_memories,
self.topk,
use_layer_wise_balance=self.config.use_layer_wise_balance, # ✨
)
# integrate l2wrap auxiliary loss
wrap_loss = l2wrap_loss_func(
hidden_states, # or appropriate tensor
self.config.l2wrap_scale, # configurable scale
)
aux_loss = aux_loss * self.aux_loss_scale + wrap_loss
🤖 Prompt for AI Agents
In fla/models/mom/modeling_mom.py around lines 470 to 476, the MoM model
currently lacks integration of the l2wrap loss function. To fix this, first
import l2wrap_loss_func at the top of the file and add a new configuration
parameter l2wrap_scale to the model's config. Then, compute the l2wrap loss
alongside the existing load_balancing_loss_func call, scaling it by
config.l2wrap_scale, and include it in the total auxiliary loss. Finally, update
or add unit tests to cover the new l2wrap loss functionality to ensure proper
test coverage.
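
If the intended l2wrap is the RWKV-style L2Wrap regularizer (an assumption — the thread never pins down its definition), the idea is to leave the loss value untouched in the forward pass and inject a small extra gradient in the backward pass that pushes down the largest logit at each position, keeping logit magnitudes bounded. A standalone sketch under that assumption; the class name and the 1e-4 factor are illustrative, not fla's API:

import torch


class L2Wrap(torch.autograd.Function):
    """RWKV-style logit regularizer: identity in forward, extra gradient in backward."""

    @staticmethod
    def forward(ctx, loss, logits):
        ctx.save_for_backward(logits)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        logits, = ctx.saved_tensors
        # small penalty that pushes down the max logit at every position
        factor = 1e-4 / (logits.shape[0] * logits.shape[1])
        max_vals, max_ids = torch.max(logits, dim=-1, keepdim=True)
        grad_logits = torch.zeros_like(logits)
        grad_logits.scatter_(-1, max_ids, max_vals * factor)
        return grad_output, grad_logits


# usage: loss = L2Wrap.apply(loss, logits) right after the cross-entropy computation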

Comment on lines 224 to 230
indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1)
# discard tokens exceed capacity and is far from now
# left pad
truncation_indices = indices + batch_memory_tokens.reshape((-1,)).unsqueeze(-1) - max_len
mask = torch.bitwise_and(truncation_indices < flatten_offset.unsqueeze(-1), truncation_indices >= 0)
mask = torch.bitwise_and(mask, truncation_indices >= torch.cat((torch.zeros((1,), dtype=flatten_offset.dtype, device=flatten_offset.device), flatten_offset[:-1])).unsqueeze(-1))
truncation_indices = torch.where(mask, truncation_indices, torch.zeros_like(truncation_indices))

🛠️ Refactor suggestion

Break long lines for better readability.

Several lines exceed the maximum line length limit.

Apply this diff:

-        indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1)
+        indices = (
+            torch.arange(max_len, device=flatten_offset.device)
+            .unsqueeze(0)
+            .expand(b*num_memories, -1) 
+            + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0)
+            .unsqueeze(1)
+        )
         # discard tokens exceed capacity and is far from now
         # left pad
         truncation_indices = indices + batch_memory_tokens.reshape((-1,)).unsqueeze(-1) - max_len
         mask = torch.bitwise_and(truncation_indices < flatten_offset.unsqueeze(-1), truncation_indices >= 0)
-        mask = torch.bitwise_and(mask, truncation_indices >= torch.cat((torch.zeros((1,), dtype=flatten_offset.dtype, device=flatten_offset.device), flatten_offset[:-1])).unsqueeze(-1))
+        mask = torch.bitwise_and(
+            mask, 
+            truncation_indices >= torch.cat(
+                (torch.zeros((1,), dtype=flatten_offset.dtype, device=flatten_offset.device), 
+                 flatten_offset[:-1])
+            ).unsqueeze(-1)
+        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1)
# discard tokens exceed capacity and is far from now
# left pad
truncation_indices = indices + batch_memory_tokens.reshape((-1,)).unsqueeze(-1) - max_len
mask = torch.bitwise_and(truncation_indices < flatten_offset.unsqueeze(-1), truncation_indices >= 0)
mask = torch.bitwise_and(mask, truncation_indices >= torch.cat((torch.zeros((1,), dtype=flatten_offset.dtype, device=flatten_offset.device), flatten_offset[:-1])).unsqueeze(-1))
truncation_indices = torch.where(mask, truncation_indices, torch.zeros_like(truncation_indices))
indices = (
torch.arange(max_len, device=flatten_offset.device)
.unsqueeze(0)
.expand(b * num_memories, -1)
+ torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0)
.unsqueeze(1)
)
# discard tokens exceed capacity and is far from now
# left pad
truncation_indices = indices + batch_memory_tokens.reshape((-1,)).unsqueeze(-1) - max_len
mask = torch.bitwise_and(
truncation_indices < flatten_offset.unsqueeze(-1),
truncation_indices >= 0
)
mask = torch.bitwise_and(
mask,
truncation_indices >= torch.cat(
(torch.zeros((1,), dtype=flatten_offset.dtype, device=flatten_offset.device),
flatten_offset[:-1])
).unsqueeze(-1)
)
truncation_indices = torch.where(mask, truncation_indices, torch.zeros_like(truncation_indices))
🤖 Prompt for AI Agents
In fla/layers/mom_varlen.py around lines 224 to 230, the code contains several
long lines that exceed the maximum line length limit, reducing readability.
Break these long lines into multiple shorter lines by splitting expressions at
logical points such as after operators or commas, and aligning subsequent lines
with proper indentation to maintain clarity and adhere to style guidelines.

Member


@WKX933 I think it would be better to merge the two files mom.py and mom_varlen.py

self.single_kv_proj = single_kv_proj
self.mom_backend = mom_backend

if not self.mom_backend in ['GDN']:
Member


@JusenD Could we change the mom_backend name to the registered model_type, e.g., GatedDeltaNetConfig.model_type?

@@ -19,6 +19,7 @@
from .mamba import Mamba
from .mamba2 import Mamba2
from .mesa_net import MesaNet
from .mom_varlen import MomGatedDeltaNet
Member


@JusenD Let's rename this module to MomAttention


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (3)
fla/layers/mom.py (1)

514-516: Add missing kwargs parameter to shared_o method call.

The shared_o method expects **kwargs but it's not being passed in the call.

         if self.shared_mem:
             shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state,
-                                     use_cache, conv_state_q, conv_state_k, conv_state_v)
+                                     use_cache, conv_state_q, conv_state_k, conv_state_v, **kwargs)
             o += shared_o
fla/layers/mom_varlen.py (2)

516-516: Remove unused variable batchsize.

The variable batchsize is assigned but never used.

-        batchsize, q_len = hidden_states.shape[0], hidden_states.shape[1]
+        q_len = hidden_states.shape[1]

641-643: Add missing kwargs parameter to shared_o method call.

The shared_o method expects **kwargs but it's not being passed in the call.

         if self.shared_mem:
             shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state,
-                                     use_cache, conv_state_q, conv_state_k, conv_state_v)
+                                     use_cache, conv_state_q, conv_state_k, conv_state_v, **kwargs)
             o += shared_o
🧹 Nitpick comments (1)
fla/models/mom/modeling_mom.py (1)

362-368: Improve exception handling with proper chaining.

When re-raising exceptions, use raise ... from to preserve the exception chain.

-            raise AttributeError(
-                f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
-                f"which is not supported for {self.__class__.__name__}. "
-                f"Try another generation strategy instead. "
-                f"For the available generation strategies, check this doc: "
-                f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-            )
+            raise AttributeError(
+                f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
+                f"which is not supported for {self.__class__.__name__}. "
+                f"Try another generation strategy instead. "
+                f"For the available generation strategies, check this doc: "
+                f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
+            ) from exception
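
As a standalone reminder of what the chaining buys: raise ... from err records the original error as __cause__, so the traceback shows it as the direct cause rather than as an unrelated error raised during exception handling. A minimal illustration, independent of this model:

def prepare_cache(past_key_values):
    try:
        past_key_values.seen_tokens  # fails for plain tuples
    except AttributeError as exception:
        raise AttributeError(
            "This decoding strategy manipulates `past_key_values`, which is not supported here."
        ) from exception


try:
    prepare_cache(past_key_values=())
except AttributeError as err:
    print(err)            # the friendly message
    print(err.__cause__)  # the original AttributeError, preserved by `from`
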
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e274e44 and 7ab9b43.

📒 Files selected for processing (6)
  • fla/__init__.py (3 hunks)
  • fla/layers/__init__.py (2 hunks)
  • fla/layers/mom.py (1 hunks)
  • fla/layers/mom_varlen.py (1 hunks)
  • fla/models/mom/configuration_mom.py (1 hunks)
  • fla/models/mom/modeling_mom.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • fla/layers/__init__.py
  • fla/models/mom/configuration_mom.py
  • fla/__init__.py
🧰 Additional context used
🧠 Learnings (1)
fla/layers/mom_varlen.py (1)
Learnt from: toothacher17
PR: fla-org/flash-linear-attention#395
File: fla/layers/mla.py:0-0
Timestamp: 2025-05-13T06:04:24.342Z
Learning: In Flash Attention's `flash_attn_varlen_func`, the scaling factor (typically computed as head_dim ** -0.5) is handled internally by the Flash Attention implementation rather than needing to be applied explicitly to query projections.
🪛 Ruff (0.11.9)
fla/layers/mom.py

429-429: Use kwargs.get('seq_idx') instead of kwargs.get('seq_idx', None)

Replace kwargs.get('seq_idx', None) with kwargs.get('seq_idx')

(SIM910)


461-461: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)


542-542: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


543-543: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


544-544: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


560-560: Use kwargs.get('seq_idx') instead of kwargs.get('seq_idx', None)

Replace kwargs.get('seq_idx', None) with kwargs.get('seq_idx')

(SIM910)


589-589: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)

fla/layers/mom_varlen.py

516-516: Local variable batchsize is assigned to but never used

Remove assignment to unused variable batchsize

(F841)


598-598: Local variable offsets is assigned to but never used

Remove assignment to unused variable offsets

(F841)


598-598: Use kwargs.get('offsets') instead of kwargs.get('offsets', None)

Replace kwargs.get('offsets', None) with kwargs.get('offsets')

(SIM910)


669-669: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


670-670: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


671-671: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


720-720: Use kwargs.get('offsets') instead of kwargs.get('offsets', None)

Replace kwargs.get('offsets', None) with kwargs.get('offsets')

(SIM910)

fla/models/mom/modeling_mom.py

244-244: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


362-368: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)

🪛 Flake8 (7.2.0)
fla/layers/mom.py

[error] 428-428: continuation line under-indented for visual indent

(E128)

fla/layers/mom_varlen.py

[error] 516-516: local variable 'batchsize' is assigned to but never used

(F841)


[error] 598-598: local variable 'offsets' is assigned to but never used

(F841)

🪛 GitHub Actions: lint
fla/layers/mom.py

[error] 24-24: flake8: line too long (128 > 127 characters) (E501)


[error] 100-100: flake8: line too long (135 > 127 characters) (E501)


[error] 123-123: flake8: line too long (177 > 127 characters) (E501)


[error] 130-130: flake8: line too long (128 > 127 characters) (E501)


[error] 398-398: flake8: line too long (150 > 127 characters) (E501)


[error] 427-427: flake8: whitespace before ':' (E203)


[error] 428-428: flake8: continuation line under-indented for visual indent (E128)

fla/layers/mom_varlen.py

[error] 47-47: flake8: line too long (131 > 127 characters) (E501)


[error] 72-72: flake8: line too long (128 > 127 characters) (E501)


[error] 74-74: flake8: line too long (135 > 127 characters) (E501)


[error] 99-99: flake8: line too long (178 > 127 characters) (E501)


[error] 101-101: flake8: line too long (170 > 127 characters) (E501)


[error] 147-147: flake8: line too long (128 > 127 characters) (E501)


[error] 223-223: flake8: line too long (135 > 127 characters) (E501)


[error] 248-248: flake8: line too long (177 > 127 characters) (E501)


[error] 255-255: flake8: line too long (128 > 127 characters) (E501)


[error] 516-516: flake8: local variable 'batchsize' is assigned to but never used (F841)


[error] 527-527: flake8: line too long (150 > 127 characters) (E501)


[error] 598-598: flake8: local variable 'offsets' is assigned to but never used (F841)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-models
🔇 Additional comments (1)
fla/models/mom/modeling_mom.py (1)

398-403: Remove duplicate key in dictionary.

The num_logits_to_keep key appears twice in the model_inputs dictionary, which will cause the first value to be overridden.

         model_inputs.update({
             'past_key_values': past_key_values,
             'use_cache': use_cache,
             'attention_mask': attention_mask,
-            'num_logits_to_keep': num_logits_to_keep,
         })

Likely an incorrect or invalid review comment.
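
The underlying Python behavior, as a one-off illustration: when the same key is assigned twice, the later assignment silently wins, so only one num_logits_to_keep value ever reaches the model.

model_inputs = {'num_logits_to_keep': 1}
model_inputs.update({'num_logits_to_keep': 8})  # the update silently replaces the earlier value
print(model_inputs)  # {'num_logits_to_keep': 8}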


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (4)
fla/layers/mom_varlen.py (2)

372-372: Fix incorrect return type annotation.

The __init__ method should return None, not MomAttention.

Apply this diff:

-    ) -> MomAttention:
+    ) -> None:

659-661: Avoid mutable default arguments.

Using mutable objects as default arguments can lead to unexpected behavior when the same default object is reused across function calls.

Apply this diff:

-        conv_state_q=[None, None],
-        conv_state_k=[None, None],
-        conv_state_v=[None, None],
+        conv_state_q=None,
+        conv_state_k=None,
+        conv_state_v=None,

Then initialize them in the method:

if conv_state_q is None:
    conv_state_q = [None, None]
if conv_state_k is None:
    conv_state_k = [None, None]
if conv_state_v is None:
    conv_state_v = [None, None]
fla/layers/mom.py (2)

253-253: Fix incorrect return type annotation.

The __init__ method should return None, not MomAttention.

Apply this diff:

-    ) -> MomAttention:
+    ) -> None:

541-543: Avoid mutable default arguments.

Using mutable objects as default arguments can lead to unexpected behavior when the same default object is reused across function calls.

Apply this diff:

-        conv_state_q=[None, None],
-        conv_state_k=[None, None],
-        conv_state_v=[None, None],
+        conv_state_q=None,
+        conv_state_k=None,
+        conv_state_v=None,

Then initialize them in the method:

if conv_state_q is None:
    conv_state_q = [None, None]
if conv_state_k is None:
    conv_state_k = [None, None]
if conv_state_v is None:
    conv_state_v = [None, None]
🧹 Nitpick comments (5)
fla/layers/mom_varlen.py (2)

46-46: Fix line length violations.

Multiple lines exceed the 127 character limit, reducing readability.

Break these long lines for better readability. For example:

-        cu_seqlens (`torch.Tensor`):
-            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+        cu_seqlens (`torch.Tensor`):
+            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. 
+            `cu_seqlens` shape is (batch_size + 1,).

Apply similar formatting to other long lines to maintain consistency with the project's style guidelines.

Also applies to: 71-73, 98-100, 140-140, 216-216, 241-241, 248-248, 520-520


591-591: Simplify kwargs.get() calls.

The explicit None default is redundant since it's the default behavior.

Apply this diff:

-        cu_seqlens = kwargs.get('cu_seqlens', None)
+        cu_seqlens = kwargs.get('cu_seqlens')
-        cu_seqlens = kwargs.get('cu_seqlens', None)
+        cu_seqlens = kwargs.get('cu_seqlens')

Also applies to: 710-710

fla/layers/mom.py (3)

23-23: Fix line length violations.

Multiple lines exceed the 127 character limit, reducing readability.

Break these long lines for better readability. For example, line 397:

-        routing_weights_full = torch.zeros((routing_weights.shape[0], routing_weights.shape[1], self.num_memories),
-                                           dtype=routing_weights.dtype, device=routing_weights.device).scatter(-1, selected_memories, routing_weights)
+        routing_weights_full = torch.zeros(
+            (routing_weights.shape[0], routing_weights.shape[1], self.num_memories),
+            dtype=routing_weights.dtype, 
+            device=routing_weights.device
+        ).scatter(-1, selected_memories, routing_weights)

Also applies to: 99-99, 122-122, 129-129, 397-397


426-427: Fix indentation and whitespace issues.

There are whitespace and indentation problems causing pipeline failures.

Apply this diff:

-            conv_mask = attention_mask[:, -hidden_states.shape[2]
-                :].repeat_interleave(self.num_memories, 0) if attention_mask is not None else None
+            conv_mask = (
+                attention_mask[:, -hidden_states.shape[2]:]
+                .repeat_interleave(self.num_memories, 0) 
+                if attention_mask is not None else None
+            )

428-428: Simplify kwargs.get() calls.

The explicit None default is redundant since it's the default behavior.

Apply this diff:

-        seq_idx = kwargs.get('seq_idx', None)
+        seq_idx = kwargs.get('seq_idx')
-        cu_seqlens = kwargs.get('cu_seqlens', None)
+        cu_seqlens = kwargs.get('cu_seqlens')

Apply similar changes to the other occurrences.

Also applies to: 460-460, 559-559, 588-588

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7ab9b43 and 11fa1b6.

📒 Files selected for processing (2)
  • fla/layers/mom.py (1 hunks)
  • fla/layers/mom_varlen.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
fla/layers/mom_varlen.py (1)
Learnt from: toothacher17
PR: fla-org/flash-linear-attention#395
File: fla/layers/mla.py:0-0
Timestamp: 2025-05-13T06:04:24.342Z
Learning: In Flash Attention's `flash_attn_varlen_func`, the scaling factor (typically computed as head_dim ** -0.5) is handled internally by the Flash Attention implementation rather than needing to be applied explicitly to query projections.
🪛 Ruff (0.11.9)
fla/layers/mom_varlen.py

591-591: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)


659-659: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


660-660: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


661-661: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


710-710: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)

fla/layers/mom.py

428-428: Use kwargs.get('seq_idx') instead of kwargs.get('seq_idx', None)

Replace kwargs.get('seq_idx', None) with kwargs.get('seq_idx')

(SIM910)


460-460: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)


541-541: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


542-542: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


543-543: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


559-559: Use kwargs.get('seq_idx') instead of kwargs.get('seq_idx', None)

Replace kwargs.get('seq_idx', None) with kwargs.get('seq_idx')

(SIM910)


588-588: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)

🪛 GitHub Actions: lint
fla/layers/mom_varlen.py

[error] 46-46: flake8: line too long (131 > 127 characters) (E501)


[error] 71-71: flake8: line too long (128 > 127 characters) (E501)


[error] 73-73: flake8: line too long (135 > 127 characters) (E501)


[error] 98-98: flake8: line too long (178 > 127 characters) (E501)


[error] 100-100: flake8: line too long (170 > 127 characters) (E501)


[error] 140-140: flake8: line too long (128 > 127 characters) (E501)


[error] 216-216: flake8: line too long (135 > 127 characters) (E501)


[error] 241-241: flake8: line too long (177 > 127 characters) (E501)


[error] 248-248: flake8: line too long (128 > 127 characters) (E501)


[error] 520-520: flake8: line too long (150 > 127 characters) (E501)

fla/layers/mom.py

[error] 23-23: flake8: line too long (128 > 127 characters) (E501)


[error] 99-99: flake8: line too long (135 > 127 characters) (E501)


[error] 122-122: flake8: line too long (177 > 127 characters) (E501)


[error] 129-129: flake8: line too long (128 > 127 characters) (E501)


[error] 397-397: flake8: line too long (150 > 127 characters) (E501)


[error] 426-426: flake8: whitespace before ':' (E203)


[error] 427-427: flake8: continuation line under-indented for visual indent (E128)

🪛 Flake8 (7.2.0)
fla/layers/mom.py

[error] 427-427: continuation line under-indented for visual indent

(E128)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-models


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (5)
fla/models/mom/modeling_mom.py (3)

43-43: Fix incorrect return type annotation.

The __init__ method should return None, not the class type.

Apply this diff:

-    ) -> MomMLP:
+    ) -> None:

164-164: Fix incorrect module name in _no_split_modules.

The _no_split_modules list references 'GatedDeltaNetBlock' but should reference 'MomBlock' for this model.

-    _no_split_modules = ['GatedDeltaNetBlock']
+    _no_split_modules = ['MomBlock']

244-244: Fix incorrect model name in warning and add stacklevel.

The warning refers to GatedDeltaNetModel but should refer to MomModel. Also, add stacklevel parameter to the warning.

Apply this diff:

-            warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.")
+            warnings.warn("`MomModel` does not support `output_attentions` now, setting it to `False`.", stacklevel=2)
fla/layers/mom_varlen.py (2)

330-330: Fix incorrect return type annotation.

The __init__ method should return None, not the class type.

Apply this diff:

-    ) -> MomAttention:
+    ) -> None:

617-619: Avoid mutable default arguments.

Using mutable objects as default arguments can lead to unexpected behavior.

Apply this diff:

-        conv_state_q=[None, None],
-        conv_state_k=[None, None],
-        conv_state_v=[None, None],
+        conv_state_q=None,
+        conv_state_k=None,
+        conv_state_v=None,

Then initialize them in the method:

if conv_state_q is None:
    conv_state_q = [None, None]
if conv_state_k is None:
    conv_state_k = [None, None]
if conv_state_v is None:
    conv_state_v = [None, None]
🧹 Nitpick comments (4)
fla/layers/mom_varlen.py (4)

216-216: Break long line for better readability.

The line exceeds the maximum length limit. Break it into multiple lines.

Apply this diff:

-        indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(
-            b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1)
+        indices = (
+            torch.arange(max_len, device=flatten_offset.device)
+            .unsqueeze(0)
+            .expand(b*num_memories, -1) 
+            + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0)
+            .unsqueeze(1)
+        )

477-478: Break long line for better readability.

The line exceeds the maximum length limit. Break it into multiple lines.

Apply this diff:

-        routing_weights_full = torch.zeros((routing_weights.shape[0], routing_weights.shape[1], self.num_memories),
-                                           dtype=routing_weights.dtype, device=routing_weights.device).scatter(-1, selected_memories, routing_weights)
+        routing_weights_full = torch.zeros(
+            (routing_weights.shape[0], routing_weights.shape[1], self.num_memories),
+            dtype=routing_weights.dtype, 
+            device=routing_weights.device
+        ).scatter(-1, selected_memories, routing_weights)

549-549: Optimize kwargs.get usage.

Simplify the kwargs.get call by removing the redundant None default.

Apply this diff:

-        cu_seqlens = kwargs.get('cu_seqlens', None)
+        cu_seqlens = kwargs.get('cu_seqlens')

668-668: Optimize kwargs.get usage.

Simplify the kwargs.get call by removing the redundant None default.

Apply this diff:

-        cu_seqlens = kwargs.get('cu_seqlens', None)
+        cu_seqlens = kwargs.get('cu_seqlens')
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 11fa1b6 and 9f62f28.

📒 Files selected for processing (3)
  • fla/layers/mom_varlen.py (1 hunks)
  • fla/models/mom/configuration_mom.py (1 hunks)
  • fla/models/mom/modeling_mom.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • fla/models/mom/configuration_mom.py
🧰 Additional context used
🧠 Learnings (1)
fla/layers/mom_varlen.py (1)
Learnt from: toothacher17
PR: fla-org/flash-linear-attention#395
File: fla/layers/mla.py:0-0
Timestamp: 2025-05-13T06:04:24.342Z
Learning: In Flash Attention's `flash_attn_varlen_func`, the scaling factor (typically computed as head_dim ** -0.5) is handled internally by the Flash Attention implementation rather than needing to be applied explicitly to query projections.
🧬 Code Graph Analysis (1)
fla/layers/mom_varlen.py (9)
fla/modules/fused_norm_gate.py (1)
  • FusedRMSNormGated (993-1054)
fla/modules/layernorm.py (1)
  • RMSNorm (1060-1107)
fla/modules/convolution.py (1)
  • ShortConvolution (548-745)
fla/modules/l2norm.py (1)
  • l2norm (262-267)
fla/ops/gated_delta_rule/chunk.py (1)
  • chunk_gated_delta_rule (221-340)
fla/ops/gated_delta_rule/fused_recurrent.py (1)
  • fused_recurrent_gated_delta_rule (240-351)
fla/models/utils.py (1)
  • Cache (11-150)
fla/layers/utils.py (2)
  • pad_input (176-197)
  • unpad_input (101-173)
fla/layers/mom.py (4)
  • transform (23-117)
  • reconstruct (122-184)
  • MomAttention (187-615)
  • forward (366-533)
🪛 Ruff (0.11.9)
fla/models/mom/modeling_mom.py

244-244: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


362-368: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)

fla/layers/mom_varlen.py

549-549: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)


617-617: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


618-618: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


619-619: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


668-668: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)

🪛 GitHub Actions: lint
fla/layers/mom_varlen.py

[error] 46-46: flake8: line too long (131 > 127 characters) (E501)


[error] 71-71: flake8: line too long (128 > 127 characters) (E501)


[error] 73-73: flake8: line too long (135 > 127 characters) (E501)


[error] 98-98: flake8: line too long (178 > 127 characters) (E501)


[error] 100-100: flake8: line too long (170 > 127 characters) (E501)


[error] 140-140: flake8: line too long (128 > 127 characters) (E501)


[error] 216-216: flake8: line too long (135 > 127 characters) (E501)


[error] 241-241: flake8: line too long (177 > 127 characters) (E501)


[error] 248-248: flake8: line too long (128 > 127 characters) (E501)


[error] 478-478: flake8: line too long (150 > 127 characters) (E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-models
🔇 Additional comments (1)
fla/layers/mom_varlen.py (1)

530-530: Fix syntax error with trailing comma.

Remove the trailing comma after the tuple elements to fix the syntax error.

Apply this diff:

-            q, k, v = self.silu(q), self.silu(k), self.silu(v),
+            q, k, v = self.silu(q), self.silu(k), self.silu(v)

Likely an incorrect or invalid review comment.

Comment on lines +362 to +368
raise AttributeError(
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
f"which is not supported for {self.__class__.__name__}. "
f"Try another generation strategy instead. "
f"For the available generation strategies, check this doc: "
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
)

⚠️ Potential issue

Improve exception handling with proper chaining.

Use raise ... from exception to properly chain exceptions and provide better error context.

Apply this diff:

                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-                )
+                ) from exception
            else:
-                raise exception
+                raise
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
raise AttributeError(
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
f"which is not supported for {self.__class__.__name__}. "
f"Try another generation strategy instead. "
f"For the available generation strategies, check this doc: "
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
)
raise AttributeError(
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
f"which is not supported for {self.__class__.__name__}. "
f"Try another generation strategy instead. "
f"For the available generation strategies, check this doc: "
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
) from exception
else:
raise
🧰 Tools
🪛 Ruff (0.11.9)

362-368: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)

🤖 Prompt for AI Agents
In fla/models/mom/modeling_mom.py around lines 362 to 368, the AttributeError is
raised without chaining the original exception, which loses error context.
Modify the raise statement to use "raise ... from exception" by capturing the
original exception and chaining it to the new AttributeError to improve
debugging and error traceability.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (3)
fla/layers/mom.py (3)

330-330: Fix incorrect return type annotation.

The __init__ method should return None, not MomAttention.

Apply this diff:

-    ) -> MomAttention:
+    ) -> None:

617-619: Fix mutable default arguments.

Using mutable objects as default arguments can lead to unexpected behavior.

Apply this diff:

-        conv_state_q=[None, None],
-        conv_state_k=[None, None],
-        conv_state_v=[None, None],
+        conv_state_q=None,
+        conv_state_k=None,
+        conv_state_v=None,

Then initialize them in the method:

if conv_state_q is None:
    conv_state_q = [None, None]
if conv_state_k is None:
    conv_state_k = [None, None]
if conv_state_v is None:
    conv_state_v = [None, None]

358-358: Fix typo in assertion error message.

There's still a typo "suppoerted" instead of "supported" despite being marked as addressed in previous reviews.

Apply this diff:

-        assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
+        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
🧹 Nitpick comments (3)
fla/layers/mom.py (3)

549-549: Simplify kwargs.get usage.

The None default is redundant since get() returns None by default.

Apply this diff:

-        cu_seqlens = kwargs.get('cu_seqlens', None)
+        cu_seqlens = kwargs.get('cu_seqlens')

668-668: Simplify kwargs.get usage.

The None default is redundant since get() returns None by default.

Apply this diff:

-        cu_seqlens = kwargs.get('cu_seqlens', None)
+        cu_seqlens = kwargs.get('cu_seqlens')

46-46: Fix line length violations.

Multiple lines exceed the 127 character limit and need to be wrapped for better readability.

Consider breaking long lines using parentheses for function arguments, or splitting complex expressions across multiple lines. For example:

-        cu_seqlens (`torch.Tensor`): The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+        cu_seqlens (`torch.Tensor`): 
+            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. 
+            `cu_seqlens` shape is (batch_size + 1,).

Also applies to: 71-71, 73-73, 98-98, 100-100, 140-140, 216-216, 241-241, 248-248, 478-478

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9f62f28 and 4e3a997.

📒 Files selected for processing (2)
  • fla/layers/__init__.py (2 hunks)
  • fla/layers/mom.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • fla/layers/__init__.py
🧰 Additional context used
🪛 Ruff (0.11.9)
fla/layers/mom.py

549-549: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)


617-617: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


618-618: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


619-619: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


668-668: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)

🪛 GitHub Actions: lint
fla/layers/mom.py

[error] 46-46: flake8: line too long (131 > 127 characters) (E501)


[error] 71-71: flake8: line too long (128 > 127 characters) (E501)


[error] 73-73: flake8: line too long (135 > 127 characters) (E501)


[error] 98-98: flake8: line too long (178 > 127 characters) (E501)


[error] 100-100: flake8: line too long (170 > 127 characters) (E501)


[error] 140-140: flake8: line too long (128 > 127 characters) (E501)


[error] 216-216: flake8: line too long (135 > 127 characters) (E501)


[error] 241-241: flake8: line too long (177 > 127 characters) (E501)


[error] 248-248: flake8: line too long (128 > 127 characters) (E501)


[error] 478-478: flake8: line too long (150 > 127 characters) (E501)

🔇 Additional comments (5)
fla/layers/mom.py (5)

1-32: LGTM! Clean imports and utility setup.

The imports are well-organized and the conditional flash attention import with fallback message is good practice.


34-137: LGTM! Standard flash attention utilities.

The helper functions are well-documented and follow standard patterns for flash attention integration.


140-236: LGTM! Well-implemented memory routing with capacity constraints.

This is a complex but crucial function for the MoM architecture. The comprehensive docstring, efficient use of torch.no_grad() for indexing operations, and proper handling of topk routing and capacity constraints demonstrate good implementation quality.
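
For readers new to this routing pattern, a deliberately loop-based toy sketch of top-k routing with a per-memory capacity cap is shown below; the real transform is a batched, GPU-friendly equivalent, and all names here are placeholders.

import torch
import torch.nn.functional as F

def topk_route_with_capacity(router_logits: torch.Tensor, topk: int, capacity: int):
    # router_logits: (num_tokens, num_memories)
    weights = F.softmax(router_logits, dim=-1)
    topk_w, topk_idx = torch.topk(weights, topk, dim=-1)      # each token picks k memories
    kept = torch.zeros_like(topk_idx, dtype=torch.bool)
    counts = torch.zeros(router_logits.shape[-1], dtype=torch.long)
    for t in range(topk_idx.shape[0]):                        # greedy, in token order
        for j in range(topk):
            m = int(topk_idx[t, j])
            if counts[m] < capacity:                          # assignments past capacity are dropped
                counts[m] += 1
                kept[t, j] = True
    return topk_w, topk_idx, kept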


241-303: LGTM! Solid reconstruction logic with safety checks.

The function properly reverses the memory routing process with clear documentation and helpful assertions for debugging.


446-609: LGTM! Core attention mechanism is well-implemented.

The forward method properly implements the MoM attention mechanism with correct routing, state management, and mode handling. The logic flow ensures all variables are defined before use.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🔭 Outside diff range comments (1)
fla/layers/mom.py (1)

1-719: Add missing tests for MoM modules.

I couldn’t find any MoM-related tests in tests/test_model.py or tests/test_generation.py. Per the PR objectives, please:

  • Create or update tests/test_model.py to include unit tests for:

    • transform() and reconstruct() logic (e.g., simple input patterns, capacity truncation, padding behavior).
    • MomAttention constructor and forward pass (test both chunk and fused_recurrent modes, with and without use_output_gate/use_short_conv).
  • Create or update tests/test_generation.py to cover end-to-end behavior:

    • Integrate MomAttention into a minimal model and verify output shapes and basic functionality (e.g., autoregressive generation step).

Ensuring these tests exist will satisfy the PR’s test-coverage requirements.
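
As a starting point, a minimal smoke test could look like the sketch below. The constructor arguments (hidden_size, num_heads, num_memories, topk) and the (output, attentions, past_key_values) return convention are assumptions to verify against the actual MomAttention signature, and a GPU is assumed since the underlying kernels are Triton-based.

import torch
from fla.layers.mom import MomAttention

def test_mom_attention_forward_smoke():
    # Hyperparameters below are hypothetical; align them with the real signature.
    layer = MomAttention(hidden_size=64, num_heads=4, num_memories=4, topk=2)
    layer = layer.to(device='cuda', dtype=torch.bfloat16)
    x = torch.randn(2, 128, 64, device='cuda', dtype=torch.bfloat16, requires_grad=True)
    out = layer(x)[0]            # assumed (output, attentions, past_key_values) tuple
    assert out.shape == x.shape
    out.sum().backward()         # exercises the backward pass as well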

♻️ Duplicate comments (12)
fla/models/mom/modeling_mom.py (5)

114-114: Fix incorrect module name in _no_split_modules.

The list references 'GatedDeltaNetBlock' but should reference 'MomBlock' for this model.

-    _no_split_modules = ['GatedDeltaNetBlock']
+    _no_split_modules = ['MomBlock']

193-195: Fix incorrect model name in warning and add stacklevel.

The warning refers to GatedDeltaNetModel but should refer to MomModel. Also, add stacklevel parameter to the warning.

Apply this diff:

         if output_attentions:
-            warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.")
+            warnings.warn("`MomModel` does not support `output_attentions` now, setting it to `False`.", stacklevel=2)
             output_attentions = False

312-320: Improve exception handling with proper chaining.

Use raise ... from exception to properly chain exceptions and provide better error context.

Apply this diff:

                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-                )
+                ) from exception
            else:
-                raise exception
+                raise

418-425: Missing l2wrap integration in auxiliary loss computation.

According to the PR objectives, l2wrap functionality should be included in the new MoM models. The current implementation only includes load balancing loss.

#!/bin/bash
# Description: Check if l2wrap is implemented in other models to understand the expected integration pattern

# Search for l2wrap implementation patterns in other models
rg -A 10 "l2wrap" --glob "fla/models/**/*.py" --glob "!fla/models/mom/**"

# Search for any l2wrap loss functions
ast-grep --pattern 'def $_($$$ l2wrap $$$) {
  $$$
}'

498-498: Use 'dim' instead of 'axis' for torch.max.

PyTorch uses dim parameter, not axis.

Apply this diff:

-        expert_mask = torch.max(expert_mask, axis=-2).values
+        expert_mask = torch.max(expert_mask, dim=-2).values
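
For reference, the dim fix applies to the standard Switch-Transformer-style formulation, sketched generically below (not the PR's exact load_balancing_loss_func):

import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits: torch.Tensor, num_memories: int, topk: int) -> torch.Tensor:
    # router_logits: (num_tokens, num_memories), concatenated across layers.
    routing_weights = F.softmax(router_logits, dim=-1)               # (N, E)
    _, selected = torch.topk(routing_weights, topk, dim=-1)          # (N, k)
    expert_mask = F.one_hot(selected, num_memories)                  # (N, k, E)
    # Collapse the top-k slots: a memory counts as hit if any slot selected it.
    expert_mask = torch.max(expert_mask, dim=-2).values.float()      # (N, E)
    tokens_per_expert = expert_mask.mean(dim=0)                      # fraction of tokens per memory
    router_prob_per_expert = routing_weights.mean(dim=0)             # mean routing probability per memory
    return num_memories * torch.sum(tokens_per_expert * router_prob_per_expert)
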
fla/layers/mom.py (7)

211-211: Add missing whitespace after comma.

Apply this diff:

-        memory_batch_offset = offset.transpose(0,1)
+        memory_batch_offset = offset.transpose(0, 1)

302-302: Fix whitespace issues around comma.

Apply this diff:

-    resortd_x = torch.zeros((b * s * k, h, d) ,device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_(
+    resortd_x = torch.zeros((b * s * k, h, d), device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_(

341-341: Fix incorrect return type annotation.

The __init__ method should return None, not MomAttention.

Apply this diff:

-    ) -> MomAttention:
+    ) -> None:

369-369: Fix typo in assertion error message.

Apply this diff:

-        assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
+        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."

379-383: Fix excessive whitespace after operator.

Apply this diff:

-            self.k_proj =  nn.ModuleList([nn.Linear(self.hidden_size, self.key_dim, bias=False) for _ in range(self.num_memories)])
+            self.k_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.key_dim, bias=False) for _ in range(self.num_memories)])
-            self.b_proj =  nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) for _ in range(self.num_memories)])
+            self.b_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) for _ in range(self.num_memories)])

Also applies to: 387-390


508-615: Fix critical bug: undefined variable 'o' when shared_mem is True.

The variable o is used before it's defined when self.shared_mem is True. The code tries to add shared_o to o on line 614, but o is only defined within the mode-specific blocks.

The logic needs restructuring. Initialize o before the conditional blocks or ensure it's defined in all code paths before the shared memory addition.
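
A minimal illustration of the required control flow, with placeholder names rather than the layer's real tensors:

def combine_memory_output(mode, chunk_out, recurrent_out, shared_o=None):
    # Bind `o` on every path before the shared-memory addition consumes it.
    if mode == 'chunk':
        o = chunk_out
    elif mode == 'fused_recurrent':
        o = recurrent_out
    else:
        raise NotImplementedError(f"Not supported mode `{mode}`.")
    if shared_o is not None:
        o = o + shared_o
    return o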


640-642: Avoid mutable default arguments.

Using mutable objects as default arguments can lead to unexpected behavior.

Apply this diff:

-        conv_state_q=[None, None],
-        conv_state_k=[None, None],
-        conv_state_v=[None, None],
+        conv_state_q=None,
+        conv_state_k=None,
+        conv_state_v=None,

Then initialize them in the method:

if conv_state_q is None:
    conv_state_q = [None, None]
if conv_state_k is None:
    conv_state_k = [None, None]
if conv_state_v is None:
    conv_state_v = [None, None]
🧹 Nitpick comments (1)
fla/layers/mom.py (1)

570-570: Simplify kwargs.get usage.

When the default value is None, you can omit it from kwargs.get().

Apply this diff:

-        cu_seqlens = kwargs.get('cu_seqlens', None)
+        cu_seqlens = kwargs.get('cu_seqlens')

Also applies to: 689-689

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4e3a997 and 9e2b4b6.

📒 Files selected for processing (3)
  • fla/layers/mom.py (1 hunks)
  • fla/models/mom/configuration_mom.py (1 hunks)
  • fla/models/mom/modeling_mom.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • fla/models/mom/configuration_mom.py
🧰 Additional context used
🪛 Ruff (0.11.9)
fla/models/mom/modeling_mom.py

194-194: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


312-318: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)

fla/layers/mom.py

570-570: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)


640-640: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


641-641: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


642-642: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


689-689: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)

🪛 GitHub Actions: lint
fla/layers/mom.py

[error] 45-45: flake8: line too long (131 > 127 characters) (E501)


[error] 70-70: flake8: line too long (128 > 127 characters) (E501)


[error] 72-72: flake8: line too long (135 > 127 characters) (E501)


[error] 97-97: flake8: line too long (178 > 127 characters) (E501)


[error] 99-99: flake8: line too long (170 > 127 characters) (E501)


[error] 221-221: flake8: line too long (135 > 127 characters) (E501)


[error] 259-259: flake8: line too long (128 > 127 characters) (E501)

🔇 Additional comments (1)
fla/models/mom/modeling_mom.py (1)

32-108: LGTM! Well-structured MomBlock implementation.

The block correctly handles both standard attention and MoM attention backends, with proper error handling for unsupported backends.

indices (`torch.Tensor`):
The indices of non-masked tokens from the flattened input sequence.
cu_seqlens (`torch.Tensor`):
The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).

⚠️ Potential issue

Fix long lines exceeding 127 characters.

Several lines exceed the maximum line length limit. Consider breaking them into multiple lines for better readability.

For example, break the docstring lines:

-            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. 
+            `cu_seqlens` shape is (batch_size + 1,).

Also applies to: 70-70, 72-72, 97-97, 99-99

🧰 Tools
🪛 GitHub Actions: lint

[error] 45-45: flake8: line too long (131 > 127 characters) (E501)

🤖 Prompt for AI Agents
In fla/layers/mom.py at lines 45, 70, 72, 97, and 99, some lines exceed the
127-character limit. To fix this, break these long lines into multiple shorter
lines, especially in docstrings and comments, ensuring each line stays within
the character limit for improved readability and compliance with style
guidelines.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (3)
fla/layers/mom.py (3)

38-40: Fix line length exceeding 127 characters.


65-67: Fix line length exceeding 127 characters.


640-642: Avoid mutable default arguments.

🧹 Nitpick comments (4)
fla/layers/mom.py (4)

25-26: Consider more descriptive function name.

The function name elu_p1 is not very descriptive. Consider renaming it to something like elu_plus_one or shifted_elu to better convey its purpose.

-def elu_p1(x):
+def elu_plus_one(x):
     return (F.elu(x, 1., False) + 1.).to(x)

189-189: Break long line for better readability.

This line exceeds the 127 character limit.

-        indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(
-            b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1)
+        indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(
+            b*num_memories, -1) + torch.cat(
+                [torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0
+            ).unsqueeze(1)

227-227: Fix line length in docstring.

Break this long docstring line for better readability.

-    3. Uses `torch.scatter_add_` to scatter and sum the transformed outputs back to their original positions based on `indices`.
+    3. Uses `torch.scatter_add_` to scatter and sum the transformed outputs back to their original 
+       positions based on `indices`.

499-499: Break long line for better readability.

This line exceeds the 127 character limit.

-            conv_mask = attention_mask[:, -hidden_states.shape[2]:].repeat_interleave(self.num_memories, 0) if attention_mask is not None else None
+            conv_mask = (attention_mask[:, -hidden_states.shape[2]:].repeat_interleave(self.num_memories, 0) 
+                        if attention_mask is not None else None)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2b51b28 and 706d6fd.

📒 Files selected for processing (1)
  • fla/layers/mom.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
fla/layers/mom.py

93-93: Undefined name unpad_input

(F821)


640-640: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


641-641: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


642-642: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


690-690: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)

🪛 Flake8 (7.2.0)
fla/layers/mom.py

[error] 93-93: undefined name 'unpad_input'

(F821)

🪛 GitHub Actions: lint
fla/layers/mom.py

[error] 38-38: flake8: line too long (128 > 127 characters) (E501)


[error] 40-40: flake8: line too long (135 > 127 characters) (E501)


[error] 65-65: flake8: line too long (178 > 127 characters) (E501)


[error] 67-67: flake8: line too long (170 > 127 characters) (E501)


[error] 93-93: flake8: undefined name 'unpad_input' (F821)


[error] 189-189: flake8: line too long (135 > 127 characters) (E501)


[error] 227-227: flake8: line too long (128 > 127 characters) (E501)


[error] 499-499: flake8: line too long (147 > 127 characters) (E501)

shared_mem: bool = False,
single_kv_proj: bool = False,
**kwargs
) -> MomAttention:

⚠️ Potential issue

Return type annotation should be `None` for `__init__`.

The __init__ method should return None, not MomAttention.

-    ) -> MomAttention:
+    ) -> None:
🤖 Prompt for AI Agents
In fla/layers/mom.py at line 309, the __init__ method currently has a return
type annotation of MomAttention, which is incorrect. Change the return type
annotation of the __init__ method to None to properly indicate it does not
return a value.

def _init_weights(
self,
module: nn.Module,
rescale_prenorm_residual: bool = True,
Member

We should disable rescale_prenorm_residual by default.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (7)
fla/models/mom/modeling_mom.py (4)

413-427: Missing l2wrap integration.

According to the PR objectives, l2wrap functionality should be integrated into the MoM models. Currently, only the load balancing auxiliary loss is implemented.

Let me verify if l2wrap exists in the codebase:

#!/bin/bash
# Search for l2wrap implementation in the codebase
rg -A 5 "l2wrap" --type py

# Check if there's a l2wrap loss function
ast-grep --pattern 'def $_l2wrap$_($_) { $$$ }'

Based on the findings, you'll need to integrate l2wrap loss computation alongside the load balancing loss in the forward method.
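
If useful, the usual "l2 wrap" trick (popularized by RWKV) is a gradient-only penalty on each position's largest logit; a generic sketch follows. Whether fla's own l2wrap helper matches this exact form is an assumption, so the real integration should call the library's implementation rather than re-deriving it here.

import torch

class L2Wrap(torch.autograd.Function):
    # Adds a small gradient pushing each position's max logit toward zero,
    # without changing the reported loss value.

    @staticmethod
    def forward(ctx, loss, logits):
        ctx.save_for_backward(logits)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        logits, = ctx.saved_tensors
        max_vals, max_ids = torch.max(logits, dim=-1, keepdim=True)
        grad_logits = torch.zeros_like(logits)
        scale = 1e-4 / (logits.shape[0] * logits.shape[1])   # penalty scaled per token
        grad_logits.scatter_(-1, max_ids, max_vals * scale)
        return grad_output, grad_logits

# Hypothetical usage after the causal LM loss is computed:
#     loss = L2Wrap.apply(loss, logits)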


193-195: Add stacklevel to warning call.

The warning should include stacklevel=2 to correctly point to the caller in the stack trace.

Apply this diff:

         if output_attentions:
-            warnings.warn("`MomModel` does not `output_attentions` now, setting it to `False`.")
+            warnings.warn("`MomModel` does not support `output_attentions` now, setting it to `False`.", stacklevel=2)
             output_attentions = False

307-321: Use proper exception chaining.

Chain the original exception to provide better error context.

Apply this diff:

                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-                )
+                ) from exception
            else:
-                raise exception
+                raise

498-498: Use 'dim' instead of 'axis' for torch.max.

PyTorch uses dim parameter, not axis.

Apply this diff:

-        expert_mask = torch.max(expert_mask, axis=-2).values
+        expert_mask = torch.max(expert_mask, dim=-2).values
fla/layers/mom.py (3)

28-105: Fix undefined function and long docstring lines.

Multiple issues in this function:

  1. unpad_input is called but not defined or imported
  2. Several docstring lines exceed the 127-character limit

First, fix the undefined function by importing it:

-from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
+from fla.layers.utils import get_unpad_data, index_first_axis, pad_input, unpad_input

Also, break the long docstring lines:

-    Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
+    Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to
+    different batches.

-    This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
-    tensors for query, key, value tensors.
+    This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation
+    of the same intermediary tensors for query, key, value tensors.

291-310: Fix return type annotation.

The __init__ method should return None, not MomAttention.

Apply this diff:

-    ) -> MomAttention:
+    ) -> None:

637-647: Avoid mutable default arguments.

Using mutable objects as default arguments can lead to unexpected behavior.

Apply this diff:

-        conv_state_q=[None, None],
-        conv_state_k=[None, None],
-        conv_state_v=[None, None],
+        conv_state_q=None,
+        conv_state_k=None,
+        conv_state_v=None,

Then initialize them in the method:

if conv_state_q is None:
    conv_state_q = [None, None]
if conv_state_k is None:
    conv_state_k = [None, None]  
if conv_state_v is None:
    conv_state_v = [None, None]
🧹 Nitpick comments (2)
fla/layers/mom.py (2)

693-693: Simplify kwargs.get() call.

The None default is unnecessary since get() returns None by default.

Apply this diff:

-        cu_seqlens = kwargs.get('cu_seqlens', None)
+        cu_seqlens = kwargs.get('cu_seqlens')

1-723: Add tests for MoM modules.

As requested in the PR comments, please add tests for the new MoM modules to:

  • tests/test_model.py
  • tests/test_generation.py

This will ensure the new functionality is properly validated.

Would you like me to help generate test cases for the MoM attention layer and model classes?

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 706d6fd and 8a8b6c2.

📒 Files selected for processing (2)
  • fla/layers/mom.py (1 hunks)
  • fla/models/mom/modeling_mom.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
fla/layers/mom.py

94-94: Undefined name unpad_input

(F821)


643-643: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


644-644: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


645-645: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


693-693: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)

fla/models/mom/modeling_mom.py

194-194: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


312-318: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)

🪛 Flake8 (7.2.0)
fla/layers/mom.py

[error] 94-94: undefined name 'unpad_input'

(F821)

🪛 GitHub Actions: lint
fla/layers/mom.py

[error] 38-38: flake8: line too long (128 > 127 characters) (E501)


[error] 40-40: flake8: line too long (135 > 127 characters) (E501)


[error] 65-65: flake8: line too long (178 > 127 characters) (E501)


[error] 67-67: flake8: line too long (170 > 127 characters) (E501)


[error] 94-94: flake8: undefined name 'unpad_input' (F821)


[error] 190-190: flake8: line too long (135 > 127 characters) (E501)


[error] 228-228: flake8: line too long (128 > 127 characters) (E501)


[error] 500-500: flake8: line too long (147 > 127 characters) (E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (1)
fla/models/mom/modeling_mom.py (1)

32-108: MomBlock implementation looks good!

The block structure with conditional attention backend selection and proper residual connections is well implemented.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🔭 Outside diff range comments (1)
tests/models/test_modeling_mom.py (1)

1-58: Add module docstring and extend generation tests with use_l2warp

The helpers in tests/models/test_modeling_base.py confirm that run_test_generation accepts an optional use_l2warp flag (default False). To fully cover both code paths, update tests/models/test_modeling_mom.py as follows:

• Add a module-level docstring at the top.
• Parameterize test_generation over use_l2warp (e.g. True/False) and pass it through to run_test_generation.

Example patch:

--- a/tests/models/test_modeling_mom.py
+++ b/tests/models/test_modeling_mom.py
@@ -1,4 +1,12 @@
 # -*- coding: utf-8 -*-
+
+"""
+Test module for MoM (Mixture-of-Memories) model implementation.
+
+This module tests:
+ - forward/backward pass via run_test_model_forward_backward
+ - K/V cache generation via run_test_generation (with and without l2warp)
+"""

 import pytest
 import torch
@@ -40,16 +48,26 @@ from .test_modeling_base import run_test_generation, run_test_model_forward_backward
 
 # ===================================================================================
 # Test for Generation
 # ===================================================================================
-@pytest.mark.parametrize(
-    ['L', 'B', 'T', 'H', 'D', 'dtype'],
+@pytest.mark.parametrize(
+    ['L', 'B', 'T', 'H', 'D', 'use_l2warp', 'dtype'],
     [
-        pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-{}".format(*test))
-        for test in [
-            (2, 4, 2000, 8, 64, torch.float16),
-        ]
+        pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-use_l2warp{}-{}".format(*test))
+        for test in [
+            # test both with and without l2warp
+            (2, 4, 2000, 8, 64, False, torch.float16),
+            (2, 4, 2000, 8, 64, True,  torch.float16),
+        ]
     ]
 )
 def test_generation(
-    L: int,
-    B: int,
-    T: int,
-    H: int,
-    D: int,
-    dtype: torch.dtype,
+    L: int,
+    B: int,
+    T: int,
+    H: int,
+    D: int,
+    use_l2warp: bool,
+    dtype: torch.dtype,
 ):
-    run_test_generation(L, B, T, H, D, MomConfig, dtype)
+    run_test_generation(
+        L, B, T, H, D, MomConfig, dtype, use_l2warp=use_l2warp
+    )
  • Confirms signature compatibility in tests/models/test_modeling_base.py.
  • Improves coverage by testing both default and l2warp=True generation paths.
  • Adds a clear module-level docstring.
🧹 Nitpick comments (1)
tests/models/test_modeling_mom.py (1)

14-34: Good test coverage for forward/backward pass validation.

The parameterized test covers key variations including L2 warp functionality (addressing the PR reviewer's request). The test cases appropriately exercise different model configurations.

Consider adding test cases with different dtypes (e.g., float16, float32) to ensure broader compatibility:

        for test in [
            (4, 4, 1024, 4, 64, True, torch.bfloat16),
            (4, 4, 1024, 4, 64, False, torch.bfloat16),
            (4, 4, 1024, 4, 128, False, torch.bfloat16),
+           (4, 4, 1024, 4, 64, True, torch.float16),
+           (4, 4, 512, 4, 64, False, torch.float32),
        ]
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8a8b6c2 and 54ec5b6.

📒 Files selected for processing (2)
  • fla/__init__.py (3 hunks)
  • tests/models/test_modeling_mom.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • fla/__init__.py
🔇 Additional comments (1)
tests/models/test_modeling_mom.py (1)

1-8: LGTM: Imports are well-structured.

The imports follow standard conventions and correctly reference the new MoM components and test utilities.

Comment on lines +40 to +57
@pytest.mark.parametrize(
['L', 'B', 'T', 'H', 'D', 'dtype'],
[
pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-{}".format(*test))
for test in [
(2, 4, 2000, 8, 64, torch.float16),
]
]
)
def test_generation(
L: int,
B: int,
T: int,
H: int,
D: int,
dtype: torch.dtype,
):
run_test_generation(L, B, T, H, D, MomConfig, dtype)

🛠️ Refactor suggestion

Consider adding L2 warp parameter to generation test.

The generation test is well-structured but missing the use_l2warp parameter that was specifically requested by the PR reviewer. This parameter should be tested for generation functionality as well.

Add the use_l2warp parameter to ensure complete L2 warp functionality testing:

@pytest.mark.parametrize(
-   ['L', 'B', 'T', 'H', 'D', 'dtype'],
+   ['L', 'B', 'T', 'H', 'D', 'use_l2warp', 'dtype'],
    [
-       pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-{}".format(*test))
+       pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-use_l2warp{}-{}".format(*test))
        for test in [
-           (2, 4, 2000, 8, 64, torch.float16),
+           (2, 4, 2000, 8, 64, True, torch.float16),
+           (2, 4, 2000, 8, 64, False, torch.float16),
        ]
    ]
)
def test_generation(
    L: int,
    B: int,
    T: int,
    H: int,
    D: int,
+   use_l2warp: bool,
    dtype: torch.dtype,
):
-   run_test_generation(L, B, T, H, D, MomConfig, dtype)
+   run_test_generation(L, B, T, H, D, MomConfig, use_l2warp=use_l2warp, dtype=dtype)
🤖 Prompt for AI Agents
In tests/models/test_modeling_mom.py around lines 40 to 57, the test_generation
function lacks the use_l2warp parameter, which is needed to fully test L2 warp
functionality. Modify the pytest.mark.parametrize decorator to include
use_l2warp as a parameter with appropriate test values, and update the
test_generation function signature to accept use_l2warp. Pass this parameter to
the run_test_generation call to ensure the generation test covers the L2 warp
feature.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (3)
fla/layers/mom.py (3)

90-90: Critical issue: undefined function unpad_input still not resolved.

This issue was flagged in previous reviews but remains unaddressed. The function unpad_input is not imported and will cause a NameError at runtime.

You need to either import unpad_input from fla.layers.utils if it exists, or implement an alternative solution. Check the available functions in the utils module:

#!/bin/bash
# Verify available functions in fla.layers.utils
ast-grep --pattern 'def $NAME($_) {
  $$$
}' fla/layers/utils.py

304-304: Fix return type annotation for init method.

The __init__ method should return None, not MomAttention. This issue was flagged in previous reviews but remains unaddressed.

-    ) -> MomAttention:
+    ) -> None:

654-656: Fix mutable default arguments.

Using mutable objects as default arguments can lead to unexpected behavior. This issue was flagged in previous reviews but remains unaddressed.

-        conv_state_q=[None, None],
-        conv_state_k=[None, None],
-        conv_state_v=[None, None],
+        conv_state_q=None,
+        conv_state_k=None,
+        conv_state_v=None,

Then initialize them in the method:

if conv_state_q is None:
    conv_state_q = [None, None]
if conv_state_k is None:
    conv_state_k = [None, None]
if conv_state_v is None:
    conv_state_v = [None, None]
🧹 Nitpick comments (4)
fla/layers/mom.py (4)

193-201: Clean up commented code.

There's a significant amount of commented-out code that should be removed to improve readability and maintainability.

-    # transformed_x = transformed_x * mask.unsqueeze(-1).expand_as(transformed_x)
-    # pad_x = torch.zeros((b * num_memories, capacity_len-max_len, d), dtype=transformed_x.dtype, device=transformed_x.device)
-    # pad_mask = torch.zeros((b * num_memories, capacity_len-max_len), dtype=transformed_x.dtype, device=transformed_x.device)
-    # left pad
-    # transformed_x = torch.cat((pad_x, transformed_x), dim=1).reshape((b, num_memories, capacity_len, d)).transpose(0, 1)
     mask_2 = mask.reshape((b, num_memories, max_len)).transpose(0, 1)
-    # truncation_indices += capacity_len-max_len
-    # if attention_mask is not None:
-    #     mask_2

733-733: Fix whitespace before bracket.

There's incorrect whitespace before the bracket causing a pipeline failure.

-        max_len = (cu_seqlens[1:] - cu_seqlens [:-1]).max().item()
+        max_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

641-641: Fix indentation issue.

The continuation line is under-indented for visual alignment, causing a pipeline failure.

Ensure proper indentation alignment:

-                seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask)
+                        seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask)

441-441: Simplify kwargs.get usage.

Static analysis suggests using kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None) since None is the default.

-        origin_cu_seqlens = kwargs.get('cu_seqlens', None)
+        origin_cu_seqlens = kwargs.get('cu_seqlens')
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 54ec5b6 and 273cce6.

📒 Files selected for processing (3)
  • fla/layers/mom.py (1 hunks)
  • fla/models/mom/__init__.py (1 hunks)
  • pyproject.toml (1 hunks)
✅ Files skipped from review due to trivial changes (2)
  • pyproject.toml
  • fla/models/mom/__init__.py
🧰 Additional context used
🪛 Ruff (0.12.2)
fla/layers/mom.py

441-441: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)


654-654: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


655-655: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


656-656: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


671-671: Undefined name q_len

(F821)

🪛 Flake8 (7.2.0)
fla/layers/mom.py

[error] 641-641: continuation line under-indented for visual indent

(E128)


[error] 671-671: undefined name 'q_len'

(F821)

🪛 GitHub Actions: lint
fla/layers/mom.py

[error] 33-33: flake8: line too long (128 > 127 characters) (E501)


[error] 35-35: flake8: line too long (135 > 127 characters) (E501)


[error] 60-60: flake8: line too long (178 > 127 characters) (E501)


[error] 62-62: flake8: line too long (170 > 127 characters) (E501)


[error] 187-187: flake8: line too long (135 > 127 characters) (E501)


[error] 222-222: flake8: line too long (128 > 127 characters) (E501)


[error] 508-508: flake8: line too long (133 > 127 characters) (E501)


[error] 516-516: flake8: line too long (133 > 127 characters) (E501)


[error] 524-524: flake8: line too long (133 > 127 characters) (E501)


[error] 552-552: flake8: line too long (145 > 127 characters) (E501)


[error] 586-586: flake8: line too long (130 > 127 characters) (E501)


[error] 599-599: flake8: line too long (145 > 127 characters) (E501)


[error] 641-641: flake8: continuation line under-indented for visual indent (E128)


[error] 671-671: flake8: undefined name 'q_len' (F821)


[error] 733-733: flake8: whitespace before '[' (E211)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (2)
fla/layers/mom.py (2)

206-277: Well-implemented reconstruction logic.

The function correctly handles the complex inverse transformation from memory-organized chunks back to the original sequence format, with proper routing weight application and validation assertions.


746-777: Well-implemented state management helpers.

The prepare_recurrent_state and handle_recurrent_state methods correctly handle the complex logic for managing recurrent states across different memory configurations and batch scenarios.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (3)
fla/layers/mom.py (3)

304-304: Fix incorrect return type annotation.

The __init__ method should return None, not MomAttention.

-    ) -> MomAttention:
+    ) -> None:

667-669: Avoid mutable default arguments.

Using mutable objects as default arguments can lead to unexpected behavior.

Apply this diff:

-        conv_state_q=[None, None],
-        conv_state_k=[None, None],
-        conv_state_v=[None, None],
+        conv_state_q=None,
+        conv_state_k=None,
+        conv_state_v=None,

Then initialize them in the method:

if conv_state_q is None:
    conv_state_q = [None, None]
if conv_state_k is None:
    conv_state_k = [None, None]
if conv_state_v is None:
    conv_state_v = [None, None]

686-686: Fix undefined variable q_len.

The variable q_len is used but not defined, which will cause a NameError at runtime.

Looking at the context, it appears q_len should be hidden_states.shape[1]:

-            indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
+            indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -hidden_states.shape[1]:])
🧹 Nitpick comments (8)
fla/layers/mom.py (8)

60-60: Fix line length violations.

These lines exceed the 127-character limit. Break them into multiple lines for better readability.

-        (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
-            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+        (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
+            The cumulative sequence lengths for the target (query) and source (key, value), 
+            used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
-        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
-            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
+        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
+            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, 
+            `max_seqlen_in_batch_k` for the source sequence i.e. key/value).

Also applies to: 62-62


187-189: Fix line length violation and improve readability.

Line 187 exceeds the character limit. Break it for better readability.

-        indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(
-            b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1)
+        indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(
+            b*num_memories, -1
+        ) + torch.cat(
+            [torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0
+        ).unsqueeze(1)

222-222: Fix line length in docstring.

The docstring line flagged by flake8 at line 222 (128 > 127 characters) should be wrapped onto two lines, following the same pattern used for the other long docstring fixes above.

441-441: Simplify kwargs.get usage.

-        origin_cu_seqlens = kwargs.get('cu_seqlens', None)
+        origin_cu_seqlens = kwargs.get('cu_seqlens')

558-558: Fix line length violations in recurrent state handling.

-            recurrent_state[0] = self.handle_recurrent_state(recurrent_state[0], recurrent_state_, cu_seqlens, cu_seqlen_all[0], reverse_indices)
+            recurrent_state[0] = self.handle_recurrent_state(
+                recurrent_state[0], recurrent_state_, cu_seqlens, cu_seqlen_all[0], reverse_indices
+            )

Apply similar formatting to line 605.

Also applies to: 605-605


625-625: Fix indentation issue.

Line 625 has incorrect indentation for line continuation.

         o = reconstruct(o, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size,
-                seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask)
+                        seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask)

644-644: Fix line length violation.

-                shared_hidden_states = index_first_axis(rearrange(shared_hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)
+                shared_hidden_states = index_first_axis(
+                    rearrange(shared_hidden_states, "b s ... -> (b s) ..."), indices
+                ).unsqueeze(0)

750-750: Fix whitespace before bracket.

-        max_len = (cu_seqlens[1:] - cu_seqlens [:-1]).max().item()
+        max_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 273cce6 and 0f80d59.

📒 Files selected for processing (2)
  • fla/layers/mom.py (1 hunks)
  • fla/models/mom/configuration_mom.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • fla/models/mom/configuration_mom.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
fla/layers/mom.py (8)
fla/modules/fused_norm_gate.py (1)
  • FusedRMSNormGated (993-1054)
fla/modules/layernorm.py (1)
  • RMSNorm (1060-1107)
fla/modules/convolution.py (1)
  • ShortConvolution (661-902)
fla/ops/gated_delta_rule/chunk.py (1)
  • chunk_gated_delta_rule (221-340)
fla/ops/gated_delta_rule/fused_recurrent.py (1)
  • fused_recurrent_gated_delta_rule (240-351)
fla/models/utils.py (2)
  • Cache (11-150)
  • update (43-122)
fla/layers/utils.py (3)
  • get_unpad_data (75-98)
  • pad_input (176-197)
  • unpad_input (101-173)
fla/models/mom/modeling_mom.py (3)
  • forward (77-107)
  • forward (181-264)
  • forward (356-440)
🪛 Ruff (0.12.2)
fla/layers/mom.py

441-441: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)


667-667: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


668-668: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


669-669: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)

🪛 Flake8 (7.2.0)
fla/layers/mom.py

[error] 625-625: continuation line under-indented for visual indent

(E128)

🪛 GitHub Actions: lint
fla/layers/mom.py

[error] 33-33: flake8: line too long (128 > 127 characters) (E501)


[error] 35-35: flake8: line too long (135 > 127 characters) (E501)


[error] 60-60: flake8: line too long (178 > 127 characters) (E501)


[error] 62-62: flake8: line too long (170 > 127 characters) (E501)


[error] 187-187: flake8: line too long (135 > 127 characters) (E501)


[error] 222-222: flake8: line too long (128 > 127 characters) (E501)


[error] 514-514: flake8: line too long (133 > 127 characters) (E501)


[error] 522-522: flake8: line too long (133 > 127 characters) (E501)


[error] 530-530: flake8: line too long (133 > 127 characters) (E501)


[error] 558-558: flake8: line too long (145 > 127 characters) (E501)


[error] 592-592: flake8: line too long (130 > 127 characters) (E501)


[error] 605-605: flake8: line too long (145 > 127 characters) (E501)


[error] 625-625: flake8: continuation line under-indented for visual indent (E128)


[error] 644-644: flake8: line too long (134 > 127 characters) (E501)


[error] 750-750: flake8: whitespace before '[' (E211)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: check-pt-python-compatibility

output_final_state=use_cache,
cu_seqlens=conv_cu_seqlens,
)
conv_state_q[0] = self.handle_recurrent_state(conv_state_q[0], conv_q_new, cu_seqlens, cu_seqlen_all[0], reverse_indices)

🛠️ Refactor suggestion

Fix line length violations in convolution state handling.

Multiple lines exceed the 127-character limit.

-            conv_state_q[0] = self.handle_recurrent_state(conv_state_q[0], conv_q_new, cu_seqlens, cu_seqlen_all[0], reverse_indices)
+            conv_state_q[0] = self.handle_recurrent_state(
+                conv_state_q[0], conv_q_new, cu_seqlens, cu_seqlen_all[0], reverse_indices
+            )
             conv_k = self.prepare_recurrent_state(conv_state_k[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size)
             cu_k, conv_k_new = self.k_conv1d(
                 x=cu_k,
                 cache=conv_k,
                 output_final_state=use_cache,
                 cu_seqlens=conv_cu_seqlens,
             )
-            conv_state_k[0] = self.handle_recurrent_state(conv_state_k[0], conv_k_new, cu_seqlens, cu_seqlen_all[0], reverse_indices)
+            conv_state_k[0] = self.handle_recurrent_state(
+                conv_state_k[0], conv_k_new, cu_seqlens, cu_seqlen_all[0], reverse_indices
+            )
             conv_v = self.prepare_recurrent_state(conv_state_v[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size)
             cu_v, conv_v_new = self.v_conv1d(
                 x=cu_v,
                 cache=conv_v,
                 output_final_state=use_cache,
                 cu_seqlens=conv_cu_seqlens,
             )
-            conv_state_v[0] = self.handle_recurrent_state(conv_state_v[0], conv_v_new, cu_seqlens, cu_seqlen_all[0], reverse_indices)
+            conv_state_v[0] = self.handle_recurrent_state(
+                conv_state_v[0], conv_v_new, cu_seqlens, cu_seqlen_all[0], reverse_indices
+            )

Also applies to: 522-522, 530-530

🧰 Tools
🪛 GitHub Actions: lint

[error] 514-514: flake8: line too long (133 > 127 characters) (E501)

🤖 Prompt for AI Agents
In fla/layers/mom.py at lines 514, 522, and 530, the lines exceed the
127-character limit. Refactor these lines by breaking them into multiple shorter
lines or using intermediate variables to ensure no line surpasses the character
limit while preserving the original logic and readability.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

♻️ Duplicate comments (10)
fla/models/mom/modeling_mom.py (7)

120-125: Consider explicitly setting rescale_prenorm_residual default.

Based on the past review comment, rescale_prenorm_residual should be disabled by default. Consider explicitly setting this in the initialization logic or documenting why the current approach is preferred.


160-161: Add missing blank line before class definition.

PEP 8 requires two blank lines before class definitions.

 @dataclass
 class MomOutputWithPast(BaseModelOutputWithPast):
     aux_loss: Optional[torch.FloatTensor] = None
 
+
 class MomModel(MomPreTrainedModel):

272-274: Add missing blank line before class definition.

PEP 8 requires two blank lines before class definitions.

 @dataclass
 class MomCausalLMOutputWithPast(CausalLMOutputWithPast):
     aux_loss: Optional[torch.FloatTensor] = None
 
+
 class MomForCausalLM(MomPreTrainedModel, GenerationMixin):

312-322: Improve exception handling with proper chaining.

Use exception chaining to preserve the error context.

             if 'past_key_values' in str(exception):
                 raise AttributeError(
                     f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                     f"which is not supported for {self.__class__.__name__}. "
                     f"Try another generation strategy instead. "
                     f"For the available generation strategies, check this doc: "
                     f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-                )
+                ) from exception
             else:
-                raise exception
+                raise

415-429: Missing l2wrap integration.

According to the PR objectives, l2wrap functionality should be included in the new MoM models. The current implementation only includes the load balancing auxiliary loss.

Would you like me to help implement the l2wrap loss integration alongside the existing auxiliary loss calculation?


499-499: Use 'dim' instead of 'axis' for torch.max.

PyTorch uses dim parameter, not axis.

-        expert_mask = torch.max(expert_mask, axis=-2).values
+        expert_mask = torch.max(expert_mask, dim=-2).values

194-196: Add stacklevel to warning for proper trace.

The warning should include stacklevel=2 to correctly point to the caller.

         if output_attentions:
-            warnings.warn("`MomModel` does not `output_attentions` now, setting it to `False`.")
+            warnings.warn("`MomModel` does not `output_attentions` now, setting it to `False`.", stacklevel=2)
             output_attentions = False
fla/layers/mom.py (3)

305-305: Fix return type annotation for init.

The __init__ method should return None, not MomAttention.

-    ) -> MomAttention:
+    ) -> None:

671-673: Avoid mutable default arguments.

Using mutable objects as default arguments can lead to unexpected behavior.

-        conv_state_q=[None, None],
-        conv_state_k=[None, None],
-        conv_state_v=[None, None],
+        conv_state_q=None,
+        conv_state_k=None,
+        conv_state_v=None,

Then initialize them in the method:

if conv_state_q is None:
    conv_state_q = [None, None]
if conv_state_k is None:
    conv_state_k = [None, None]
if conv_state_v is None:
    conv_state_v = [None, None]

509-532: Fix line length violations in convolution handling.

Multiple lines exceed the 127-character limit.

-            conv_q = self.prepare_recurrent_state(conv_state_q[0], conv_cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size)
+            conv_q = self.prepare_recurrent_state(
+                conv_state_q[0], conv_cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size
+            )
             cu_q, conv_q_new = self.q_conv1d(
                 x=cu_q,
                 cache=conv_q,
                 output_final_state=use_cache,
                 cu_seqlens=conv_cu_seqlens,
             )
-            conv_state_q[0] = self.handle_recurrent_state(conv_state_q[0], conv_q_new, conv_cu_seqlens, cu_seqlen_all[0], reverse_indices)
+            conv_state_q[0] = self.handle_recurrent_state(
+                conv_state_q[0], conv_q_new, conv_cu_seqlens, cu_seqlen_all[0], reverse_indices
+            )

Apply similar formatting to the k and v convolution blocks.

🧹 Nitpick comments (1)
fla/layers/mom.py (1)

1-823: Missing test coverage.

Per the PR objectives, tests should be added to tests/test_model.py and tests/test_generation.py for the new MoM modules and model. Please ensure comprehensive test coverage is included.

Would you like me to help generate the test cases for the MomAttention layer and MomForCausalLM model?
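
As a starting point, a minimal smoke-test sketch, assuming MomConfig and MomForCausalLM are importable from fla.models as registered in this PR; the config field names and tensor shapes below are assumptions, and as with other fla models the layer kernels may require a CUDA device:

# Illustrative forward/backward smoke test; field names are assumptions.
import torch

from fla.models import MomConfig, MomForCausalLM


def test_mom_forward_backward():
    config = MomConfig(hidden_size=64, num_hidden_layers=2)  # tiny model for CI
    model = MomForCausalLM(config).cuda()
    input_ids = torch.randint(0, config.vocab_size, (2, 32), device='cuda')
    output = model(input_ids=input_ids, labels=input_ids)
    assert output.loss is not None
    output.loss.backward()  # gradients should flow through the routing path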

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0f80d59 and 8ba6bed.

📒 Files selected for processing (2)
  • fla/layers/mom.py (1 hunks)
  • fla/models/mom/modeling_mom.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
fla/layers/mom.py

442-442: Use kwargs.get('cu_seqlens') instead of kwargs.get('cu_seqlens', None)

Replace kwargs.get('cu_seqlens', None) with kwargs.get('cu_seqlens')

(SIM910)


671-671: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


672-672: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


673-673: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)

fla/models/mom/modeling_mom.py

195-195: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


314-320: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)

🪛 Flake8 (7.2.0)
fla/layers/mom.py

[error] 627-627: continuation line under-indented for visual indent

(E128)

🪛 GitHub Actions: lint
fla/layers/mom.py

[error] 33-33: flake8: line too long (128 > 127 characters) (E501)


[error] 35-35: flake8: line too long (135 > 127 characters) (E501)


[error] 60-60: flake8: line too long (178 > 127 characters) (E501)


[error] 62-62: flake8: line too long (170 > 127 characters) (E501)


[error] 187-187: flake8: line too long (135 > 127 characters) (E501)


[error] 222-222: flake8: line too long (128 > 127 characters) (E501)


[error] 509-509: flake8: line too long (130 > 127 characters) (E501)


[error] 516-516: flake8: line too long (138 > 127 characters) (E501)


[error] 517-517: flake8: line too long (130 > 127 characters) (E501)


[error] 524-524: flake8: line too long (138 > 127 characters) (E501)


[error] 525-525: flake8: line too long (130 > 127 characters) (E501)


[error] 532-532: flake8: line too long (138 > 127 characters) (E501)


[error] 560-560: flake8: line too long (145 > 127 characters) (E501)


[error] 594-594: flake8: line too long (130 > 127 characters) (E501)


[error] 607-607: flake8: line too long (145 > 127 characters) (E501)


[error] 627-627: flake8: continuation line under-indented for visual indent (E128)


[error] 646-646: flake8: line too long (134 > 127 characters) (E501)


[error] 754-754: flake8: whitespace before '[' (E211)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (1)
fla/models/mom/modeling_mom.py (1)

32-109: Well-structured transformer block implementation!

The MomBlock class properly implements the transformer architecture with conditional attention mechanism selection and appropriate error handling for unsupported backends.

Comment on lines +33 to +35
Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.

This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
⚠️ Potential issue

Fix line length violations in docstrings.

Several lines exceed the 127-character limit. Break them into multiple lines.

-    Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
+    Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong 
+    to different batches.

-    This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
+    This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation 
+    of the same intermediary

-            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+            The cumulative sequence lengths for the target (query) and source (key, value), used to index into 
+            ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).

-            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
+            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, 
+            `max_seqlen_in_batch_k` for the source sequence i.e. key/value).

Also applies to: 60-62

🧰 Tools
🪛 GitHub Actions: lint

[error] 33-33: flake8: line too long (128 > 127 characters) (E501)


[error] 35-35: flake8: line too long (135 > 127 characters) (E501)

🤖 Prompt for AI Agents
In fla/layers/mom.py at lines 33-35 and 60-62, the docstring lines exceed the
127-character limit. Break these long lines into multiple shorter lines to
comply with the line length restriction, ensuring the text remains clear and
readable without altering the content.

flatten_offset = batch_memory_tokens.flatten().cumsum(dim=0)
max_len = batch_memory_tokens.max()
indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(
b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1)
⚠️ Potential issue

Fix line length violation.

Line exceeds 127-character limit.

-        indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(
-            b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1)
+        indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(
+            b*num_memories, -1
+        ) + torch.cat(
+            [torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0
+        ).unsqueeze(1)
🧰 Tools
🪛 GitHub Actions: lint

[error] 187-187: flake8: line too long (135 > 127 characters) (E501)

🤖 Prompt for AI Agents
In fla/layers/mom.py at line 187, the line exceeds the 127-character limit.
Break the line into multiple shorter lines by splitting the expression at
logical points, such as after the addition operator or before the torch.cat
call, and use proper indentation to maintain readability and comply with the
line length restriction.

Comment on lines +265 to +269
resortd_x = torch.zeros((b * s * k, d), device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_(
0,
indices.reshape(-1).unsqueeze(-1).expand(-1, d),
gathered_x,
)
⚠️ Potential issue

Fix whitespace and indentation issues.

Fix whitespace around comma and continuation line indentation.

-    resortd_x = torch.zeros((b * s * k, d), device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_(
-        0,
-        indices.reshape(-1).unsqueeze(-1).expand(-1, d),
-        gathered_x,
-    )
+    resortd_x = torch.zeros(
+        (b * s * k, d), device=gathered_x.device, dtype=gathered_x.dtype
+    ).scatter_add_(
+        0,
+        indices.reshape(-1).unsqueeze(-1).expand(-1, d),
+        gathered_x,
+    )
🤖 Prompt for AI Agents
In fla/layers/mom.py around lines 265 to 269, fix the whitespace by removing any
extra spaces around commas and adjust the indentation of the continuation lines
to align properly under the opening parenthesis for better readability and
consistency.

Comment on lines +626 to +628
o = reconstruct(o, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size,
seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask)
o = rearrange(o, 'b l (h d) -> b l h d', h=self.num_heads)
⚠️ Potential issue

Fix continuation line indentation.

The line continuation is under-indented.

         o = reconstruct(o, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size,
-                seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask)
+                        seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask)
🧰 Tools
🪛 Flake8 (7.2.0)

[error] 627-627: continuation line under-indented for visual indent

(E128)

🪛 GitHub Actions: lint

[error] 627-627: flake8: continuation line under-indented for visual indent (E128)

🤖 Prompt for AI Agents
In fla/layers/mom.py around lines 626 to 628, the continuation line of the
reconstruct function call is under-indented. Adjust the indentation of the
second line so it aligns properly under the opening parenthesis of the first
line, maintaining consistent indentation style for line continuations.


def cu2pad(self, x, cu_seqlens):
batch_size = cu_seqlens.shape[0] - 1
max_len = (cu_seqlens[1:] - cu_seqlens [:-1]).max().item()
⚠️ Potential issue

Remove whitespace before bracket.

There's an extra space before the bracket.

-        max_len = (cu_seqlens[1:] - cu_seqlens [:-1]).max().item()
+        max_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
🧰 Tools
🪛 GitHub Actions: lint

[error] 754-754: flake8: whitespace before '[' (E211)

🤖 Prompt for AI Agents
In fla/layers/mom.py at line 754, remove the extra whitespace before the bracket
in the expression cu_seqlens [:-1]. Change cu_seqlens [:-1] to cu_seqlens[:-1]
to fix the spacing issue.

expert_mask = torch.nn.functional.one_hot(selected_experts, self.num_memories)

# For a given token, determine if it was routed to a given expert.
expert_mask = torch.max(expert_mask, axis=-2).values
⚠️ Potential issue

Use 'dim' instead of 'axis' for torch.max.

PyTorch uses dim parameter, not axis.

-        expert_mask = torch.max(expert_mask, axis=-2).values
+        expert_mask = torch.max(expert_mask, dim=-2).values
🤖 Prompt for AI Agents
In fla/layers/mom.py at line 811, replace the use of 'axis' with 'dim' in the
torch.max function call. Change torch.max(expert_mask, axis=-2).values to
torch.max(expert_mask, dim=-2).values to comply with PyTorch's API.
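
For context, a minimal sketch of how an expert mask like this typically feeds a Switch-Transformer-style load-balancing term; the function signature and variable names below are assumptions for illustration, not a copy of the code in this file:

# Illustrative balancing term: keep the fraction of tokens routed to each
# memory and the mean router probability per memory close to uniform.
import torch
import torch.nn.functional as F


def load_balancing_loss(router_logits, selected_experts, num_memories):
    # router_logits: (num_tokens, num_memories); selected_experts: (num_tokens, topk)
    routing_probs = F.softmax(router_logits, dim=-1)
    expert_mask = F.one_hot(selected_experts, num_memories)  # (num_tokens, topk, num_memories)
    expert_mask = torch.max(expert_mask, dim=-2).values      # was the token routed to this memory?
    tokens_per_expert = expert_mask.float().mean(dim=0)      # fraction of tokens per memory
    prob_per_expert = routing_probs.mean(dim=0)              # mean router probability per memory
    return num_memories * torch.sum(tokens_per_expert * prob_per_expert)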

Successfully merging this pull request may close these issues.

[RFC] Add Mixture of Memory