[Models]: Add MoM #442

base: main
Conversation
Walkthrough
This change introduces a new Mixture-of-Memories (MoM) architecture to the `fla` package.
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant User
participant MomForCausalLM
participant MomModel
participant MomBlock
participant MomAttention
participant MomMLP
User->>MomForCausalLM: forward(input_ids, ...)
MomForCausalLM->>MomModel: forward(inputs_embeds, ...)
MomModel->>MomBlock: for each layer, forward(hidden_states, ...)
MomBlock->>MomAttention: forward(hidden_states, ...)
MomAttention->>MomAttention: Route tokens to memories (transform)
MomAttention->>MomAttention: Compute attention/gating/convolutions
MomAttention->>MomAttention: Reconstruct outputs (reconstruct)
MomAttention-->>MomBlock: output, cache, router_logits
MomBlock->>MomMLP: forward(hidden_states)
MomMLP-->>MomBlock: output
MomBlock-->>MomModel: hidden_states, cache, router_logits
MomModel-->>MomForCausalLM: hidden_states, cache, router_logits
MomForCausalLM-->>User: logits, loss, cache, router_logits
```
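For orientation, the "Route tokens to memories (transform)" and "Reconstruct outputs (reconstruct)" steps in the diagram amount to a top-k token-to-memory dispatch followed by a weighted recombination. A minimal standalone sketch of that idea (shapes and names are illustrative assumptions, not the PR's actual `transform`/`reconstruct` implementation):

```python
import torch
import torch.nn.functional as F

def route_tokens(router_logits: torch.Tensor, topk: int):
    # router_logits: [batch, seq, num_memories]
    probs = F.softmax(router_logits, dim=-1)
    weights, indices = probs.topk(topk, dim=-1)            # both [batch, seq, topk]
    weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize over the chosen memories
    return weights, indices

def combine_outputs(memory_outputs: torch.Tensor,
                    weights: torch.Tensor,
                    indices: torch.Tensor) -> torch.Tensor:
    # memory_outputs: [num_memories, batch, seq, dim] – each memory's output for every token.
    # Pick each token's top-k memory outputs and mix them with the router weights.
    d = memory_outputs.shape[-1]
    idx = indices.permute(2, 0, 1).unsqueeze(-1).expand(-1, -1, -1, d)   # [topk, batch, seq, dim]
    picked = memory_outputs.gather(0, idx)                               # [topk, batch, seq, dim]
    return (picked * weights.permute(2, 0, 1).unsqueeze(-1)).sum(dim=0)  # [batch, seq, dim]
```

With the PR's defaults (`num_memories=8`, `topk=2`), each token would be processed by only its two selected memories; the same `router_logits` also feed the balance loss discussed later in this review.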
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~90 minutes
Actionable comments posted: 16
🧹 Nitpick comments (10)
fla/layers/__init__.py (1)
55-57: Minor style – keep `__all__` alphabetically sorted & add a trailing comma. Maintaining alphabetical order and a trailing comma prevents merge conflicts and keeps diffs minimal.

```diff
- 'MesaNet',
- 'MomGatedDeltaNet'
+ 'MesaNet',
+ 'MomGatedDeltaNet',
```

fla/models/__init__.py (2)
30-30: Import list spacing / readability. Missing space after the first comma is inconsistent with the rest of the file:

```diff
-from fla.models.mom_gated_deltanet import MomGatedDeltaNetConfig,MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel
+from fla.models.mom_gated_deltanet import MomGatedDeltaNetConfig, MomGatedDeltaNetForCausalLM, MomGatedDeltaNetModel
```
56-56: Trailing comma & alphabetical order in `__all__`.

```diff
- 'MomGatedDeltaNetConfig','MomGatedDeltaNetForCausalLM','MomGatedDeltaNetModel'
+ 'MomGatedDeltaNetConfig', 'MomGatedDeltaNetForCausalLM', 'MomGatedDeltaNetModel',
```

fla/__init__.py (1)
23-26: Top-level heavy imports slow down `import fla`. Adding yet another large model (`MomGatedDeltaNet`) to the mandatory import path further increases the import time and memory footprint of the package. Consider a lazy-import pattern (e.g. via `importlib.metadata.EntryPoints` or a simple `LazyModule` helper) so that users who do not need this model are not penalised; a sketch follows below.
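A minimal sketch of one such lazy-import approach, using PEP 562 module-level `__getattr__` (the mapping below is illustrative, not the package's actual layout):

```python
# fla/__init__.py (hypothetical excerpt) – defer heavy submodule imports
# until an exported name is first accessed (PEP 562).
import importlib
from typing import Any

# Map exported names to the submodules that define them.
_LAZY_IMPORTS = {
    "MomGatedDeltaNet": "fla.layers.mom",
    "MomGatedDeltaNetConfig": "fla.models.mom_gated_deltanet",
}

def __getattr__(name: str) -> Any:
    # Called only when `name` is not found in the module namespace,
    # so the submodule is imported on first use rather than at import time.
    if name in _LAZY_IMPORTS:
        module = importlib.import_module(_LAZY_IMPORTS[name])
        value = getattr(module, name)
        globals()[name] = value  # cache so subsequent lookups are cheap
        return value
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```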
fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py (1)

11-45: Huge argument list – consider a dataclass / kwargs container. 32 positional parameters exceed readability and trigger pylint R0913/R0917. A `dataclass` (or grouping related flags into small parameter objects) would make the public API friendlier and future-proof; see the sketch after the lint output below.

🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 11-11: Too many arguments (32/5)
(R0913)
[refactor] 11-11: Too many positional arguments (32/5)
(R0917)
[refactor] 11-11: Too many local variables (33/15)
(R0914)
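To illustrate the grouping idea, here is a hypothetical sketch that pulls the router-related knobs out of the flat argument list (parameter names are taken from the constructor shown later in this review; the grouping itself is an assumed refactor, not the PR's API):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MomRoutingParams:
    # Router/memory knobs grouped into one small object (hypothetical).
    num_memories: int = 8
    topk: int = 2
    capacity: float = 1.0
    use_layer_wise_balance: bool = True
    aux_loss_scale: float = 0.01
    shared_mem: bool = False

@dataclass
class MomConvParams:
    # Short-convolution settings (hypothetical grouping).
    use_short_conv: bool = True
    conv_size: int = 4

# The config constructor could then accept a few grouped objects instead of
# 32 positional parameters, e.g.:
#   def __init__(self, routing: Optional[MomRoutingParams] = None, ...):
#       self.routing = routing or MomRoutingParams()
```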
fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py (2)
470-470: Consider using professional comments without emojis. While the emoji adds visual emphasis, professional codebases typically avoid them in comments.

Apply this diff:

```diff
- use_layer_wise_balance=self.config.use_layer_wise_balance, # ✨
+ use_layer_wise_balance=self.config.use_layer_wise_balance, # Layer-wise balance flag
```
516-517: Consider removing emojis from code comments for professionalism.

Apply this diff:

```diff
- # ✨ Here is the fix for balance loss in Mixtral.
+ # Here is the fix for balance loss in Mixtral.
- # ✨ balance loss for this layer
+ # Balance loss for this layer
- all_balance_losses = torch.cat(all_balance_losses).mean() # ✨
+ all_balance_losses = torch.cat(all_balance_losses).mean()
```

Also applies to: 559-559
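For context, the "balance loss" referenced in these comments is the Mixtral/Switch-style load-balancing auxiliary loss over router assignments. A minimal sketch of the standard formulation (assuming softmax router logits and hard top-k dispatch; not necessarily the PR's exact code):

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits: torch.Tensor, topk: int) -> torch.Tensor:
    # router_logits: [num_tokens, num_memories]
    num_memories = router_logits.shape[-1]
    probs = F.softmax(router_logits, dim=-1)
    # Fraction of tokens dispatched to each memory (hard top-k assignment).
    _, selected = probs.topk(topk, dim=-1)
    dispatch = F.one_hot(selected, num_memories).float().sum(dim=1)  # [tokens, memories]
    tokens_per_memory = dispatch.mean(dim=0)
    # Average router probability assigned to each memory (soft assignment).
    prob_per_memory = probs.mean(dim=0)
    # Minimized when routing is uniform across memories.
    return num_memories * torch.sum(tokens_per_memory * prob_per_memory)
```

Computing this per layer and averaging (as the quoted `torch.cat(...).mean()` does) is what the "layer-wise balance" flag appears to control.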
fla/layers/mom.py (1)
443-443: Remove trailing comma to avoid creating a tuple. The trailing comma creates an unnecessary tuple.

Apply this diff:

```diff
- q, k, v = self.silu(q), self.silu(k), self.silu(v),
+ q, k, v = self.silu(q), self.silu(k), self.silu(v)
```

🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 443-443: Disallow trailing comma tuple
(R1707)
fla/layers/mom_varlen.py (2)
505-505: Remove unused variable assignment. The variable `batchsize` is assigned but never used.

Apply this diff:

```diff
- batchsize,q_len = hidden_states.shape[0],hidden_states.shape[1]
+ batch_size, q_len = hidden_states.shape[0], hidden_states.shape[1]
```

Note: If you intended to use `batch_size` later in the code, make sure to update all references accordingly.

🧰 Tools
🪛 Ruff (0.11.9)
505-505: Local variable `batchsize` is assigned to but never used. Remove assignment to unused variable `batchsize`.
(F841)
561-561: Remove trailing comma to avoid creating a tuple.

Apply this diff:

```diff
- q, k, v = self.silu(q), self.silu(k), self.silu(v),
+ q, k, v = self.silu(q), self.silu(k), self.silu(v)
```

🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 561-561: Disallow trailing comma tuple
(R1707)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- fla/__init__.py (3 hunks)
- fla/layers/__init__.py (2 hunks)
- fla/layers/mom.py (1 hunks)
- fla/layers/mom_varlen.py (1 hunks)
- fla/models/__init__.py (2 hunks)
- fla/models/mom_gated_deltanet/__init__.py (1 hunks)
- fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py (1 hunks)
- fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (2)
fla/layers/__init__.py (2)
- fla/layers/mom.py (1): MomGatedDeltaNet (192-613)
- fla/layers/mom_varlen.py (1): MomGatedDeltaNet (309-724)
fla/models/__init__.py (2)
- fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py (1): MomGatedDeltaNetConfig (8-91)
- fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py (2): MomGatedDeltaNetForCausalLM (320-486), MomGatedDeltaNetModel (209-313)
🪛 Pylint (3.3.7)
fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py
[refactor] 8-8: Too many instance attributes (27/7)
(R0902)
[refactor] 11-11: Too many arguments (32/5)
(R0913)
[refactor] 11-11: Too many positional arguments (32/5)
(R0917)
[refactor] 11-11: Too many local variables (33/15)
(R0914)
[refactor] 8-8: Too few public methods (0/2)
(R0903)
fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py
[refactor] 11-11: Use 'from torch import nn' instead
(R0402)
[refactor] 37-37: Too many instance attributes (8/7)
(R0902)
[refactor] 39-39: Too many arguments (7/5)
(R0913)
[refactor] 39-39: Too many positional arguments (7/5)
(R0917)
[refactor] 37-37: Too few public methods (1/2)
(R0903)
[refactor] 128-128: Too many arguments (6/5)
(R0913)
[refactor] 128-128: Too many positional arguments (6/5)
(R0917)
[refactor] 82-82: Too few public methods (1/2)
(R0903)
[refactor] 161-161: Too few public methods (0/2)
(R0903)
[refactor] 230-230: Too many arguments (9/5)
(R0913)
[refactor] 230-230: Too many positional arguments (9/5)
(R0917)
[refactor] 230-230: Too many local variables (17/15)
(R0914)
[error] 307-313: Unexpected keyword argument 'last_hidden_state' in constructor call
(E1123)
[error] 307-313: Unexpected keyword argument 'past_key_values' in constructor call
(E1123)
[error] 307-313: Unexpected keyword argument 'hidden_states' in constructor call
(E1123)
[error] 307-313: Unexpected keyword argument 'attentions' in constructor call
(E1123)
[refactor] 230-230: Too many branches (13/12)
(R0912)
[refactor] 358-367: Unnecessary "else" after "raise", remove the "else" and de-indent the code inside it
(R1720)
[refactor] 369-369: Too many arguments (7/5)
(R0913)
[refactor] 369-369: Too many positional arguments (7/5)
(R0917)
[refactor] 403-403: Too many arguments (11/5)
(R0913)
[refactor] 403-403: Too many positional arguments (11/5)
(R0917)
[refactor] 403-403: Too many local variables (21/15)
(R0914)
[error] 478-486: Unexpected keyword argument 'loss' in constructor call
(E1123)
[error] 478-486: Unexpected keyword argument 'logits' in constructor call
(E1123)
[error] 478-486: Unexpected keyword argument 'past_key_values' in constructor call
(E1123)
[error] 478-486: Unexpected keyword argument 'hidden_states' in constructor call
(E1123)
[error] 478-486: Unexpected keyword argument 'attentions' in constructor call
(E1123)
fla/layers/mom.py
[refactor] 443-443: Disallow trailing comma tuple
(R1707)
[refactor] 10-10: Use 'from torch import nn' instead
(R0402)
[refactor] 34-34: Too many local variables (29/15)
(R0914)
[refactor] 129-129: Too many arguments (8/5)
(R0913)
[refactor] 129-129: Too many positional arguments (8/5)
(R0917)
[refactor] 129-129: Too many local variables (19/15)
(R0914)
[refactor] 192-192: Too many instance attributes (39/7)
(R0902)
[refactor] 239-239: Too many arguments (17/5)
(R0913)
[refactor] 239-239: Too many positional arguments (17/5)
(R0917)
[refactor] 239-239: Too many local variables (25/15)
(R0914)
[refactor] 239-239: Too many statements (68/50)
(R0915)
[refactor] 367-367: Too many arguments (6/5)
(R0913)
[refactor] 367-367: Too many positional arguments (6/5)
(R0917)
[refactor] 367-367: Too many local variables (42/15)
(R0914)
[error] 486-496: Unexpected keyword argument 'head_first' in function call
(E1123)
[error] 510-510: Undefined variable 'o'
(E0602)
[refactor] 367-367: Too many branches (22/12)
(R0912)
[refactor] 367-367: Too many statements (87/50)
(R0915)
[refactor] 531-531: Too many arguments (8/5)
(R0913)
[refactor] 531-531: Too many positional arguments (8/5)
(R0917)
[refactor] 531-531: Too many local variables (19/15)
(R0914)
[error] 599-609: Unexpected keyword argument 'head_first' in function call
(E1123)
fla/layers/mom_varlen.py
[refactor] 561-561: Disallow trailing comma tuple
(R1707)
[refactor] 10-10: Use 'from torch import nn' instead
(R0402)
[refactor] 64-64: Too many arguments (7/5)
(R0913)
[refactor] 64-64: Too many positional arguments (7/5)
(R0917)
[refactor] 64-64: Too many local variables (17/15)
(R0914)
[error] 108-108: Possibly using variable 'index_first_axis' before assignment
(E0606)
[error] 129-129: Possibly using variable 'unpad_input' before assignment
(E0606)
[refactor] 149-149: Too many local variables (31/15)
(R0914)
[refactor] 246-246: Too many arguments (8/5)
(R0913)
[refactor] 246-246: Too many positional arguments (8/5)
(R0917)
[refactor] 246-246: Too many local variables (19/15)
(R0914)
[refactor] 309-309: Too many instance attributes (39/7)
(R0902)
[refactor] 356-356: Too many arguments (17/5)
(R0913)
[refactor] 356-356: Too many positional arguments (17/5)
(R0917)
[refactor] 356-356: Too many local variables (25/15)
(R0914)
[refactor] 356-356: Too many statements (68/50)
(R0915)
[refactor] 484-484: Too many arguments (6/5)
(R0913)
[refactor] 484-484: Too many positional arguments (6/5)
(R0917)
[refactor] 484-484: Too many local variables (52/15)
(R0914)
[error] 595-595: Possibly using variable 'pad_input' before assignment
(E0606)
[error] 602-612: Unexpected keyword argument 'head_first' in function call
(E1123)
[error] 621-621: Undefined variable 'o'
(E0602)
[refactor] 484-484: Too many branches (16/12)
(R0912)
[refactor] 484-484: Too many statements (85/50)
(R0915)
[refactor] 642-642: Too many arguments (8/5)
(R0913)
[refactor] 642-642: Too many positional arguments (8/5)
(R0917)
[refactor] 642-642: Too many local variables (19/15)
(R0914)
[error] 710-720: Unexpected keyword argument 'head_first' in function call
(E1123)
🪛 Ruff (0.11.9)
fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py
243-243: No explicit `stacklevel` keyword argument found. Set `stacklevel=2`.
(B028)
359-365: Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling.
(B904)
fla/layers/mom.py
425-425: Use `kwargs.get('seq_idx')` instead of `kwargs.get('seq_idx', None)`.
(SIM910)
456-456: Use `kwargs.get('offsets')` instead of `kwargs.get('offsets', None)`.
(SIM910)
537-537: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
538-538: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
539-539: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
555-555: Use `kwargs.get('seq_idx')` instead of `kwargs.get('seq_idx', None)`.
(SIM910)
584-584: Use `kwargs.get('offsets')` instead of `kwargs.get('offsets', None)`.
(SIM910)
fla/layers/mom_varlen.py
29-29: Redefinition of unused `pad_input` from line 28. Remove definition: `pad_input`.
(F811)
505-505: Local variable `batchsize` is assigned to but never used. Remove assignment to unused variable `batchsize`.
(F841)
543-543: Use `kwargs.get('seq_idx')` instead of `kwargs.get('seq_idx', None)`.
(SIM910)
579-579: Local variable `offsets` is assigned to but never used.
(F841)
579-579: Use `kwargs.get('offsets')` instead of `kwargs.get('offsets', None)`.
(SIM910)
648-648: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
649-649: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
650-650: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
666-666: Use `kwargs.get('seq_idx')` instead of `kwargs.get('seq_idx', None)`.
(SIM910)
695-695: Use `kwargs.get('offsets')` instead of `kwargs.get('offsets', None)`.
(SIM910)
🔇 Additional comments (1)
fla/models/mom_gated_deltanet/__init__.py (1)
10-12: Repeated registration can raise on re-import. `AutoConfig.register`, `AutoModel.register`, … will raise a `ValueError` if the key is already registered (e.g. in unit-test re-imports or interactive sessions). A defensive guard avoids noisy stack traces:

```diff
-AutoConfig.register(MomGatedDeltaNetConfig.model_type, MomGatedDeltaNetConfig)
-AutoModel.register(MomGatedDeltaNetConfig, MomGatedDeltaNetModel)
-AutoModelForCausalLM.register(MomGatedDeltaNetConfig, MomGatedDeltaNetForCausalLM)
+for registry, key, value in [
+    (AutoConfig, MomGatedDeltaNetConfig.model_type, MomGatedDeltaNetConfig),
+    (AutoModel, MomGatedDeltaNetConfig, MomGatedDeltaNetModel),
+    (AutoModelForCausalLM, MomGatedDeltaNetConfig, MomGatedDeltaNetForCausalLM),
+]:
+    try:
+        registry.register(key, value)
+    except ValueError:
+        # Already registered – safe to ignore in repeated imports.
+        pass
```
fla/layers/__init__.py (Outdated)

```diff
@@ -25,6 +25,7 @@
 from .rodimus import RodimusAttention, SlidingWindowSharedKeyAttention
 from .rwkv6 import RWKV6Attention
 from .rwkv7 import RWKV7Attention
+from .mom import MomGatedDeltaNet
```
Ambiguous class resolution – two different `MomGatedDeltaNet` implementations exist.

Both fla/layers/mom.py and fla/layers/mom_varlen.py declare a class named `MomGatedDeltaNet`, but only the one from mom.py is imported here. Call-sites that expect the variable-length behaviour implemented in mom_varlen.py will silently get the other variant, which is extremely hard to debug.

Consider:

```diff
-from .mom import MomGatedDeltaNet
+# Prefer an explicit alias to make the chosen variant obvious
+from .mom import MomGatedDeltaNet as MomGatedDeltaNetFixedLen
+# If you need the var-len version as well, export it under a distinct name
+from .mom_varlen import MomGatedDeltaNet as MomGatedDeltaNetVarLen
```

…and expose the chosen symbol(s) accordingly in `__all__`.

Failing to disambiguate the two will almost certainly bite you the moment both variants are needed in the same runtime.
🤖 Prompt for AI Agents
In fla/layers/__init__.py at line 28, there is ambiguity because two different
classes named MomGatedDeltaNet exist in mom.py and mom_varlen.py, but only the
one from mom.py is imported. To fix this, explicitly import both classes with
distinct aliases or names to differentiate them, for example, MomGatedDeltaNet
and MomGatedDeltaNetVarLen. Then, update the __all__ list to expose these
distinct names clearly. This will prevent silent confusion and ensure call-sites
use the intended variant.
```python
def __init__(
    self,
    attn_mode: str = "chunk",
    hidden_size: int = 2048,
    expand_v: int = 2,
    use_gate: bool = True,
    use_short_conv: bool = True,
    conv_size: int = 4,
    head_dim: int = 256,
    num_heads: int = 6,
    max_position_embeddings: int = 2048,
    hidden_ratio: Optional[int] = 4,
    intermediate_size: Optional[int] = None,
    hidden_act: str = "swish",
    num_hidden_layers: int = 21,
    norm_first: bool = False,
    norm_eps: float = 1e-6,
    attn: Optional[Dict] = None,
    use_cache: bool = True,
    pad_token_id: int = None,
    bos_token_id: int = 1,
    eos_token_id: int = 2,
    tie_word_embeddings: bool = False,
    initializer_range: float = 0.02,
    fuse_cross_entropy: bool = True,
    vocab_size: int = 32000,
    num_memories: int = 8,
    topk: int = 2,
    capacity: float = 1.0,
    use_layer_wise_balance: bool=True,
    aux_loss_scale: float=0.01,
    shared_mem: bool = False,
    single_kv_proj: bool = False,
    **kwargs
):
```
🛠️ Refactor suggestion
Constructor mutates input dict – copy instead.
`attn` is modified in-place (`attn['num_kv_heads'] = …`). Callers that reuse the same dict for multiple configs will observe surprising side-effects.

```diff
- self.attn = attn
+ # keep the original object untouched
+ self.attn = dict(attn) if attn is not None else None
```
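A small illustration of the side-effect (the import path is the one added in this PR; the `attn` dict contents here are hypothetical):

```python
from fla.models.mom_gated_deltanet import MomGatedDeltaNetConfig

# Hypothetical attention settings reused for two configs.
shared_attn = {"layers": [0, 1], "num_kv_heads": None}

cfg_a = MomGatedDeltaNetConfig(attn=shared_attn)
# If the constructor writes attn['num_kv_heads'] in place, the caller's
# dict now carries cfg_a's value...
cfg_b = MomGatedDeltaNetConfig(attn=shared_attn)
# ...and cfg_b is built from the mutated dict instead of the original input.
```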
🧰 Tools
🪛 Pylint (3.3.7)
[refactor] 11-11: Too many arguments (32/5)
(R0913)
[refactor] 11-11: Too many positional arguments (32/5)
(R0917)
[refactor] 11-11: Too many local variables (33/15)
(R0914)
🤖 Prompt for AI Agents
In fla/models/mom_gated_deltanet/configuration_mom_gated_deltanet.py around
lines 11 to 45, the constructor modifies the input dictionary 'attn' in-place,
which can cause unexpected side effects if the same dictionary is reused
elsewhere. To fix this, create a shallow copy of the 'attn' dictionary at the
start of the constructor before making any modifications, ensuring the original
input remains unchanged.
```python
if loss==0:
    breakpoint()
```
Remove debug breakpoint from production code.
The `breakpoint()` call will halt execution and should never be in production code.

Apply this diff to remove the debug code:

```diff
- if loss==0:
-     breakpoint()
```

If you need to debug zero loss issues, consider using proper logging instead:

```python
if loss == 0:
    logger.warning("Computed loss is zero, which may indicate an issue with the model or data.")
```
🤖 Prompt for AI Agents
In fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py around lines 458
to 459, remove the debug breakpoint() call as it halts execution and should not
be in production code. Replace it with a logging statement that warns when loss
equals zero, such as using logger.warning to indicate a potential issue with the
model or data.
```python
Args:
    gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
        Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, seqeunce_length, num_memories].
```
Fix typo in docstring.
Apply this diff:
```diff
- Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, seqeunce_length, num_memories].
+ Logits from the `gate`, should be a tuple of tensors. Shape: [batch_size, sequence_length, num_memories].
```
🤖 Prompt for AI Agents
In fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py at line 504,
there is a typo in the docstring where "seqeunce_length" should be corrected to
"sequence_length". Update the docstring to fix this spelling mistake.
```python
    config_class = MomGatedDeltaNetConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ['GatedDeltaNetBlock']
```
Fix incorrect module name in `_no_split_modules`.

The module name should be `MomGatedDeltaNetBlock` to match the actual class name.

Apply this diff:

```diff
- _no_split_modules = ['GatedDeltaNetBlock']
+ _no_split_modules = ['MomGatedDeltaNetBlock']
```
🤖 Prompt for AI Agents
In fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py at line 165, the
module name in the _no_split_modules list is incorrect. Replace
'GatedDeltaNetBlock' with the correct class name 'MomGatedDeltaNetBlock' to
ensure it matches the actual class name used in the code.
fla/layers/mom_varlen.py (Outdated)

```python
        conv_state_q = [None, None],
        conv_state_k = [None, None],
        conv_state_v = [None, None],
```
Avoid mutable default arguments.
Using mutable objects as default arguments can lead to unexpected behavior.
Apply this diff:

```diff
- conv_state_q = [None, None],
- conv_state_k = [None, None],
- conv_state_v = [None, None],
+ conv_state_q = None,
+ conv_state_k = None,
+ conv_state_v = None,
```

Then initialize them in the method:

```python
if conv_state_q is None:
    conv_state_q = [None, None]
if conv_state_k is None:
    conv_state_k = [None, None]
if conv_state_v is None:
    conv_state_v = [None, None]
```
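The underlying pitfall, for reference: Python evaluates default values once, at function definition time, so a mutable default is shared across calls. A standalone demonstration (generic example, not the PR's code):

```python
def append_item(item, bucket=[]):  # bucket is created once, at def time
    bucket.append(item)
    return bucket

print(append_item(1))  # [1]
print(append_item(2))  # [1, 2] – the same list object is reused!

def append_item_fixed(item, bucket=None):
    if bucket is None:  # fresh list on every call
        bucket = []
    bucket.append(item)
    return bucket

print(append_item_fixed(1))  # [1]
print(append_item_fixed(2))  # [2]
```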
🧰 Tools
🪛 Ruff (0.11.9)
648-648: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
649-649: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
650-650: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
🤖 Prompt for AI Agents
In fla/layers/mom_varlen.py around lines 648 to 650, the default arguments
conv_state_q, conv_state_k, and conv_state_v are mutable lists, which can cause
unexpected behavior. Change their default values to None in the function
signature, then inside the method initialize them with the provided code snippet
to assign [None, None] if they are None.
Actionable comments posted: 1
♻️ Duplicate comments (1)
fla/layers/__init__.py (1)
29-29: Ambiguous class resolution still unresolved. There are two distinct `MomGatedDeltaNet` implementations (mom.py, mom_varlen.py), but only the fixed-length one is re-exported here. Call-sites requiring the var-len behaviour will silently receive the wrong class.
Consider aliasing the variants explicitly (e.g. `MomGatedDeltaNetFixedLen`, `MomGatedDeltaNetVarLen`) and exporting both.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- fla/__init__.py (4 hunks)
- fla/layers/__init__.py (2 hunks)
- fla/models/__init__.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- fla/models/__init__.py
- fla/__init__.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
fla/layers/__init__.py (2)
- fla/layers/mom.py (1): MomGatedDeltaNet (192-613)
- fla/layers/mom_varlen.py (1): MomGatedDeltaNet (309-724)
🪛 Ruff (0.11.9)
fla/layers/__init__.py
29-29: `.mom.MomGatedDeltaNet` imported but unused; consider removing, adding to `__all__`, or using a redundant alias.
(F401)
🪛 Pylint (3.3.7)
fla/layers/__init__.py
[error] 49-50: Undefined variable name 'MomGatedDeltaNetMultiScaleRetention' in `__all__`
(E0603)
🪛 GitHub Actions: lint
fla/layers/__init__.py
[error] 1-1: Imports are not sorted properly; fixed by isort hook.
[error] 22-23: flake8: F401 '.mom.MomGatedDeltaNet' and '.multiscale_retention.MultiScaleRetention' imported but unused.
Actionable comments posted: 11
♻️ Duplicate comments (9)
fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py (4)
39-47: Fix incorrect return type annotation. The `__init__` method should return `None`, not `MomGatedDeltaNetMLP`.

165-165: Fix incorrect module name in `_no_split_modules`. The module name should be `MomGatedDeltaNetBlock` to match the actual class name.

243-243: Fix incorrect class name in warning message and add stacklevel. The warning message should reference the correct class name and include proper `stacklevel`.

505-505: Fix typo in docstring. The docstring contains a spelling error.
fla/layers/mom_varlen.py (5)
28-29: Remove duplicate import of `pad_input`. The `pad_input` function is imported twice on consecutive lines.

375-375: Fix incorrect return type annotation. The `__init__` method should return `None`, not `MomGatedDeltaNet`.

403-403: Fix typo in assertion error message. The assertion message contains a spelling error.

619-622: Fix critical bug: undefined variable 'o' when shared_mem is True. The variable `o` is used before it's defined when `self.shared_mem` is True.

642-650: Fix mutable default arguments and formatting. The method signature uses mutable default arguments which can cause unexpected behavior.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- fla/layers/__init__.py (2 hunks)
- fla/layers/mom_varlen.py (1 hunks)
- fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- fla/layers/__init__.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
fla/layers/mom_varlen.py (8)
- fla/modules/fused_norm_gate.py (1): FusedRMSNormSwishGate (1082-1098)
- fla/modules/convolution.py (1): ShortConvolution (544-741)
- fla/modules/l2norm.py (1): l2norm (253-258)
- fla/ops/gated_delta_rule/chunk.py (1): chunk_gated_delta_rule (225-342)
- fla/ops/gated_delta_rule/fused_recurrent.py (1): fused_recurrent_gated_delta_rule (212-314)
- fla/models/utils.py (2): Cache (11-148), update (43-120)
- fla/layers/utils.py (2): pad_input (176-197), unpad_input (101-173)
- fla/layers/mom.py (6): elu_p1 (25-26), sum_norm (29-30), transform (34-126), reconstruct (129-189), MomGatedDeltaNet (192-613), forward (367-528)
🪛 Ruff (0.11.9)
fla/layers/mom_varlen.py
29-29: Redefinition of unused `pad_input` from line 28. Remove definition: `pad_input`.
(F811)
505-505: Local variable `batchsize` is assigned to but never used. Remove assignment to unused variable `batchsize`.
(F841)
543-543: Use `kwargs.get('seq_idx')` instead of `kwargs.get('seq_idx', None)`.
(SIM910)
579-579: Local variable `offsets` is assigned to but never used.
(F841)
579-579: Use `kwargs.get('offsets')` instead of `kwargs.get('offsets', None)`.
(SIM910)
648-648: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
649-649: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
650-650: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
666-666: Use `kwargs.get('seq_idx')` instead of `kwargs.get('seq_idx', None)`.
(SIM910)
695-695: Use `kwargs.get('offsets')` instead of `kwargs.get('offsets', None)`.
(SIM910)
fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py
243-243: No explicit `stacklevel` keyword argument found. Set `stacklevel=2`.
(B028)
359-365: Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling.
(B904)
🪛 GitHub Actions: lint
fla/layers/mom_varlen.py
[error] 30-33: flake8: E302 expected 2 blank lines, found 1
[error] 45-45: flake8: E501 line too long (131 > 127 characters)
[error] 61-61: flake8: E303 too many blank lines (3)
[error] 71-73: flake8: E501 line too long (128-135 > 127 characters)
[error] 98-100: flake8: E501 line too long (170-178 > 127 characters)
[error] 141-146: flake8: E303 too many blank lines (3) and E302 expected 2 blank lines, found 1 and E501 line too long (128 > 127 characters)
[error] 212-212: flake8: E231 missing whitespace after ','
[error] 221-226: flake8: E501 line too long (185-213 > 127 characters)
[error] 239-243: flake8: E231 missing whitespace after ',' and E302 expected 2 blank lines, found 1 and E501 line too long (177 > 127 characters)
[error] 250-291: flake8: E501 line too long (128-137 > 127 characters), E203 whitespace before ',', E231 missing whitespace after ','
[error] 410-412: flake8: E222 multiple spaces after operator and E501 line too long (131-133 > 127 characters)
[error] 413-413: flake8: E501 line too long (132 > 127 characters)
[error] 502-502: flake8: F841 local variable 'batchsize' is assigned to but never used and E231 missing whitespace after ','
[error] 512-521: flake8: E501 line too long (162-223 > 127 characters)
[error] 533-540: flake8: E501 line too long (147-169 > 127 characters) and E225 missing whitespace around operator
[error] 543-553: flake8: E128 continuation line under-indented for visual indent and E231 missing whitespace after ','
[error] 565-568: flake8: E231 missing whitespace after ',' and E225 missing whitespace around operator
[error] 575-576: flake8: E501 line too long (129 > 127 characters) and F841 local variable 'offsets' is assigned to but never used
[error] 593-594: flake8: E231 missing whitespace after ',' (multiple occurrences)
[error] 596-617: flake8: E501 line too long (144-183 > 127 characters) and E231 missing whitespace after ','
[error] 639-647: flake8: E303 too many blank lines (2) and E251 unexpected spaces around keyword/parameter equals
[error] 663-675: flake8: E225 missing whitespace around operator, E128 continuation line under-indented for visual indent, and E231 missing whitespace after ','
[error] pre-commit hooks trailing-whitespace, end-of-file-fixer, and isort modified files indicating formatting issues.
fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py
[error] 202-206: flake8: E302 expected 2 blank lines, found 1
[error] 214-214: flake8: E501 line too long (128 > 127 characters)
[error] 312-317: flake8: E302 expected 2 blank lines, found 1
[error] 533-533: flake8: E303 too many blank lines (2)
[error] pre-commit hook isort modified file indicating import sorting issues.
```python
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception
```
🛠️ Refactor suggestion
Improve exception handling with proper chaining.
The current exception handling doesn't follow best practices for exception chaining.
Apply this improvement:
```diff
 except AttributeError as exception:
     if 'past_key_values' in str(exception):
         raise AttributeError(
             f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
             f"which is not supported for {self.__class__.__name__}. "
             f"Try another generation strategy instead. "
             f"For the available generation strategies, check this doc: "
             f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-        )
+        ) from exception
     else:
-        raise exception
+        raise
```
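For reference, a standalone illustration of the `raise ... from` idiom (generic example, unrelated to this model): chaining preserves the original exception as `__cause__`, so both tracebacks stay visible.

```python
def parse_port(text: str) -> int:
    try:
        return int(text)
    except ValueError as err:
        # `from err` chains the original ValueError as __cause__, so the
        # traceback shows both errors instead of hiding the root cause.
        raise RuntimeError(f"invalid port value: {text!r}") from err
```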
🧰 Tools
🪛 Ruff (0.11.9)
359-365: Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling.
(B904)
🪛 GitHub Actions: lint
[error] pre-commit hook isort modified file indicating import sorting issues.
🤖 Prompt for AI Agents
In fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py around lines 359
to 367, the exception handling raises a new AttributeError without chaining the
original exception, which loses the context of the initial error. Fix this by
using "raise ... from exception" syntax to properly chain the new AttributeError
to the original exception, preserving the traceback and error context.
```python
        with torch.no_grad():
            p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)

@dataclass
class MomGatedDeltaNetOutputWithPast(BaseModelOutputWithPast):
```
Fix formatting issues to comply with pipeline requirements.
The pipeline failures indicate missing blank lines and line length violations that need to be addressed.
Apply these formatting fixes:

```diff
+
 @dataclass
 class MomGatedDeltaNetOutputWithPast(BaseModelOutputWithPast):
     router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None
+
 class MomGatedDeltaNetModel(MomGatedDeltaNetPreTrainedModel):
```

Also fix the long line 214 by breaking it appropriately:

```diff
-        self.layers = nn.ModuleList([MomGatedDeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
+        self.layers = nn.ModuleList([
+            MomGatedDeltaNetBlock(config, layer_idx)
+            for layer_idx in range(config.num_hidden_layers)
+        ])
```
🧰 Tools
🪛 GitHub Actions: lint
[error] 202-206: flake8: E302 expected 2 blank lines, found 1
[error] pre-commit hook isort modified file indicating import sorting issues.
🤖 Prompt for AI Agents
In fla/models/mom_gated_deltanet/modeling_mom_gated_deltanet.py around lines 202
to 206, fix formatting issues by adding the required blank lines before the
@dataclass decorator and ensuring proper separation between code blocks.
Additionally, locate line 214 and break the long line into shorter lines to
comply with line length limits, using appropriate indentation for readability.
Actionable comments posted: 3
♻️ Duplicate comments (4)
fla/layers/mom.py (4)
249-249: Fix incorrect return type annotation.

277-277: Fix typo in assertion error message.

497-500: Fix critical bug: undefined variable 'o' when shared_mem is True.

526-528: Avoid mutable default arguments.
🧹 Nitpick comments (7)
fla/layers/mom.py (7)
393-393: Remove trailing whitespace.

Apply this diff:

```diff
-        if self.use_gate: 
-            o_g = self.g_proj(hidden_states)
- 
+        if self.use_gate:
+            o_g = self.g_proj(hidden_states)
+
```

418-429: Fix indentation for method arguments. The continuation lines should be properly indented for better readability.

Apply this diff:

```diff
- q, conv_state_q[0] = self.q_conv1d(x=q,
-                     mask=conv_mask,
-                     cache=conv_state_q[0],
-                     output_final_state=use_cache,seq_idx=seq_idx)
- k, conv_state_k[0] = self.k_conv1d(x=k,
-                     mask=conv_mask,
-                     cache=conv_state_k[0],
-                     output_final_state=use_cache,seq_idx=seq_idx)
- v, conv_state_v[0] = self.v_conv1d(x=v,
-                     mask=conv_mask,
-                     cache=conv_state_v[0],
-                     output_final_state=use_cache,seq_idx=seq_idx)
+ q, conv_state_q[0] = self.q_conv1d(x=q,
+                                    mask=conv_mask,
+                                    cache=conv_state_q[0],
+                                    output_final_state=use_cache,
+                                    seq_idx=seq_idx)
+ k, conv_state_k[0] = self.k_conv1d(x=k,
+                                    mask=conv_mask,
+                                    cache=conv_state_k[0],
+                                    output_final_state=use_cache,
+                                    seq_idx=seq_idx)
+ v, conv_state_v[0] = self.v_conv1d(x=v,
+                                    mask=conv_mask,
+                                    cache=conv_state_v[0],
+                                    output_final_state=use_cache,
+                                    seq_idx=seq_idx)
```

462-463: Add missing whitespace after comma.

Apply this diff:

```diff
- o_e = o_e[:,-max_len:,:,:].to(dtype=q[e].dtype)
+ o_e = o_e[:, -max_len:, :, :].to(dtype=q[e].dtype)
```

473-474: Add missing whitespace after comma.

Apply this diff:

```diff
- if not hidden_states[e, 0].any() and hidden_states.shape[1] == 1:
-     o_list[e] = torch.zeros_like(v[e,:,-max_len:,:,:])
+ if not hidden_states[e, 0].any() and hidden_states.shape[1] == 1:
+     o_list[e] = torch.zeros_like(v[e, :, -max_len:, :, :])
```

486-486: Add missing whitespace after comma.

Apply this diff:

```diff
- o_e = o_e[:,-max_len:,:,:]
+ o_e = o_e[:, -max_len:, :, :]
```

545-556: Fix indentation for method arguments. The continuation lines should be properly indented for better readability.

Apply this diff:

```diff
- q, conv_state_q[1] = self.q_conv1d(x=self.q_proj(hidden_states),
-                     mask=conv_mask,
-                     cache=conv_state_q[1],
-                     output_final_state=use_cache,seq_idx=seq_idx)
- k, conv_state_k[1] = self.k_conv1d(x=self.shared_k(hidden_states),
-                     mask=conv_mask,
-                     cache=conv_state_k[1],
-                     output_final_state=use_cache,seq_idx=seq_idx)
- v, conv_state_v[1] = self.v_conv1d(x=self.shared_v(hidden_states),
-                     mask=conv_mask,
-                     cache=conv_state_v[1],
-                     output_final_state=use_cache,seq_idx=seq_idx)
+ q, conv_state_q[1] = self.q_conv1d(x=self.q_proj(hidden_states),
+                                    mask=conv_mask,
+                                    cache=conv_state_q[1],
+                                    output_final_state=use_cache,
+                                    seq_idx=seq_idx)
+ k, conv_state_k[1] = self.k_conv1d(x=self.shared_k(hidden_states),
+                                    mask=conv_mask,
+                                    cache=conv_state_k[1],
+                                    output_final_state=use_cache,
+                                    seq_idx=seq_idx)
+ v, conv_state_v[1] = self.v_conv1d(x=self.shared_v(hidden_states),
+                                    mask=conv_mask,
+                                    cache=conv_state_v[1],
+                                    output_final_state=use_cache,
+                                    seq_idx=seq_idx)
```

183-518: Add tests and l2warp support as requested in PR comments.

Based on the PR review comments:

- Tests should be added for this new module in tests/test_model.py and tests/test_generation.py
- Consider adding l2warp support similar to other models in the codebase

Would you like me to help generate the test cases for the MomGatedDeltaNet module?
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- fla/layers/mom.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.11.9)
fla/layers/mom.py
416-416: Use `kwargs.get('seq_idx')` instead of `kwargs.get('seq_idx', None)`.
(SIM910)
447-447: Use `kwargs.get('cu_seqlens')` instead of `kwargs.get('cu_seqlens', None)`.
(SIM910)
526-526: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
527-527: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
528-528: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
544-544: Use `kwargs.get('seq_idx')` instead of `kwargs.get('seq_idx', None)`.
(SIM910)
573-573: Use `kwargs.get('cu_seqlens')` instead of `kwargs.get('cu_seqlens', None)`.
(SIM910)
🪛 Flake8 (7.2.0)
fla/layers/mom.py
[error] 419-419: continuation line under-indented for visual indent
(E128)
[error] 420-420: continuation line under-indented for visual indent
(E128)
[error] 421-421: continuation line under-indented for visual indent
(E128)
[error] 423-423: continuation line under-indented for visual indent
(E128)
[error] 424-424: continuation line under-indented for visual indent
(E128)
[error] 425-425: continuation line under-indented for visual indent
(E128)
[error] 427-427: continuation line under-indented for visual indent
(E128)
[error] 428-428: continuation line under-indented for visual indent
(E128)
[error] 429-429: continuation line under-indented for visual indent
(E128)
[error] 546-546: continuation line under-indented for visual indent
(E128)
[error] 547-547: continuation line under-indented for visual indent
(E128)
[error] 548-548: continuation line under-indented for visual indent
(E128)
[error] 550-550: continuation line under-indented for visual indent
(E128)
[error] 551-551: continuation line under-indented for visual indent
(E128)
[error] 552-552: continuation line under-indented for visual indent
(E128)
[error] 554-554: continuation line under-indented for visual indent
(E128)
[error] 555-555: continuation line under-indented for visual indent
(E128)
[error] 556-556: continuation line under-indented for visual indent
(E128)
🪛 GitHub Actions: lint
fla/layers/mom.py
[error] 24-24: flake8: line too long (128 > 127 characters) (E501)
[error] 90-90: flake8: missing whitespace after ',' (E231)
[error] 99-99: flake8: line too long (213 > 127 characters) (E501)
[error] 104-104: flake8: line too long (185 > 127 characters) (E501)
[error] 119-119: flake8: expected 2 blank lines, found 1 (E302)
[error] 119-119: flake8: line too long (177 > 127 characters) (E501)
[error] 126-126: flake8: line too long (128 > 127 characters) (E501)
[error] 159-159: flake8: line too long (135 > 127 characters) (E501)
[error] 161-161: flake8: line too long (137 > 127 characters) (E501)
[error] 167-167: flake8: whitespace before ',' (E203) and missing whitespace after ',' (E231)
[error] 286-288: flake8: multiple spaces after operator (E222) and line too long (131-133 > 127 characters) (E501)
[error] 287-289: flake8: line too long (132 > 127 characters) (E501)
[error] 387-387: flake8: line too long (223 > 127 characters) (E501)
[error] 396-396: flake8: line too long (155 > 127 characters) (E501)
[error] 408-408: flake8: line too long (169 > 127 characters) (E501)
[error] 414-415: flake8: line too long (147 > 127 characters) (E501) and missing whitespace around operator (E225)
[error] 418-428: flake8: continuation line under-indented for visual indent (E128) and missing whitespace after ',' (E231)
[error] 445-445: flake8: line too long (145 > 127 characters) (E501)
[error] 462-462: flake8: missing whitespace after ',' (E231)
[error] 466-466: flake8: line too long (183 > 127 characters) (E501)
[error] 473-473: flake8: missing whitespace after ',' (E231)
[error] 485-485: flake8: missing whitespace after ',' (E231)
[error] 494-497: flake8: line too long (144-183 > 127 characters) (E501)
[error] 519-519: flake8: too many blank lines (2) (E303)
[error] 523-527: flake8: unexpected spaces around keyword/parameter equals (E251)
[error] 543-555: flake8: missing whitespace around operator (E225), continuation line under-indented (E128), and missing whitespace after ',' (E231)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-models
🔇 Additional comments (1)
fla/layers/mom.py (1)
415-415: Add missing whitespace before 'if'.

Apply this diff:

```diff
- conv_mask = attention_mask[:, -hidden_states.shape[2]:].repeat_interleave(self.num_memories, 0)if attention_mask is not None else None
+ conv_mask = attention_mask[:, -hidden_states.shape[2]:].repeat_interleave(self.num_memories, 0) if attention_mask is not None else None
```

Likely an incorrect or invalid review comment.
Actionable comments posted: 9
♻️ Duplicate comments (7)
fla/layers/mom_varlen.py (7)
27-29: Remove duplicate import of `pad_input`. The `pad_input` function is imported twice on consecutive lines.

Apply this diff:

```diff
 if is_flash_attn_2_available():
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
-    from flash_attn.bert_padding import pad_input
```

356-375: Fix incorrect return type annotation. The `__init__` method should return `None`, not `MomGatedDeltaNet`.

Apply this diff:

```diff
-    ) -> MomGatedDeltaNet:
+    ) -> None:
```

403-403: Fix typo in assertion error message.

Apply this diff:

```diff
- assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
```

586-586: Remove unused variable `offsets`. The variable is assigned but never used.

Apply this diff:

```diff
 recurrent_state = last_state['recurrent_state'] if last_state is not None else [None for _ in range(1 + self.shared_mem)]
-offsets = kwargs.get('offsets', None)  # Note: In the updated version of FLA, "offset" has been renamed to "cu_seqlens".
```

626-629: Fix critical bug: undefined variable 'o' when shared_mem is True. The variable `o` is used before being defined when `self.shared_mem` is True.

You need to ensure `o` is defined before this block. The logic suggests that `o` should already be computed from the reconstruction step above. If `shared_mem` is True but neither 'chunk' nor 'fused_recurrent' mode was executed, this will fail. Consider adding a check or initializing `o` appropriately:

```diff
 if self.shared_mem:
+    if 'o' not in locals():
+        raise RuntimeError("Output 'o' not computed before shared memory processing")
     shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, use_cache, conv_state_q, conv_state_k, conv_state_v)
     o += shared_o
```

655-657: Avoid mutable default arguments. Using mutable objects as default arguments can lead to unexpected behavior.

Apply this diff:

```diff
- conv_state_q = [None, None],
- conv_state_k = [None, None],
- conv_state_v = [None, None],
+ conv_state_q = None,
+ conv_state_k = None,
+ conv_state_v = None,
```

Then initialize them in the method:

```python
if conv_state_q is None:
    conv_state_q = [None, None]
if conv_state_k is None:
    conv_state_k = [None, None]
if conv_state_v is None:
    conv_state_v = [None, None]
```

508-508: Remove unused variable `batchsize`. The variable is assigned but never used.

Apply this diff:

```diff
- batchsize,q_len = hidden_states.shape[0],hidden_states.shape[1]
+ batch_size, q_len = hidden_states.shape[0], hidden_states.shape[1]
```

Note: Ensure all references to the `batch_size` variable defined on line 524 are updated to use this one instead.
🧹 Nitpick comments (3)
fla/models/mom/modeling_mom.py (1)
267-269: Use standard logger.warning method. The `logger.warning_once` method may not be available in all logging configurations.

Apply this diff:

```diff
 if self.gradient_checkpointing and self.training and use_cache:
-    logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+    logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
     use_cache = False
```

fla/layers/mom_varlen.py (2)

215-218: Fix formatting issues in transform function. Missing whitespace after comma and inconsistent spacing.

Apply this diff:

```diff
- memory_batch_offset = offset.transpose(0,1)
+ memory_batch_offset = offset.transpose(0, 1)
  batch_offset = torch.arange(0, b*s, s, device=offset.device)
```

411-422: Consider parameter organization for shared memory projections. The initialization of projection layers has inconsistent indentation and could be better organized.

Apply this diff for better readability:

```diff
 else:
-    self.k_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.key_dim, bias=False) for _ in range(self.num_memories)])
-    self.v_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.value_dim, bias=False) for _ in range(self.num_memories)])
-    self.b_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) for _ in range(self.num_memories)])
-    self.a_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) for _ in range(self.num_memories)])
+    self.k_proj = nn.ModuleList([
+        nn.Linear(self.hidden_size, self.key_dim, bias=False)
+        for _ in range(self.num_memories)
+    ])
+    self.v_proj = nn.ModuleList([
+        nn.Linear(self.hidden_size, self.value_dim, bias=False)
+        for _ in range(self.num_memories)
+    ])
+    self.b_proj = nn.ModuleList([
+        nn.Linear(self.hidden_size, self.num_heads, bias=False)
+        for _ in range(self.num_memories)
+    ])
+    self.a_proj = nn.ModuleList([
+        nn.Linear(self.hidden_size, self.num_heads, bias=False)
+        for _ in range(self.num_memories)
+    ])
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- fla/__init__.py (3 hunks)
- fla/layers/mom_varlen.py (1 hunks)
- fla/models/__init__.py (2 hunks)
- fla/models/mom/__init__.py (1 hunks)
- fla/models/mom/configuration_mom.py (1 hunks)
- fla/models/mom/modeling_mom.py (1 hunks)
✅ Files skipped from review due to trivial changes (2)
- fla/models/__init__.py
- fla/models/mom/__init__.py
🚧 Files skipped from review as they are similar to previous changes (1)
- fla/__init__.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
fla/layers/mom_varlen.py (8)
- fla/modules/fused_norm_gate.py (1): FusedRMSNormSwishGate (1078-1094)
- fla/modules/convolution.py (1): ShortConvolution (548-745)
- fla/modules/l2norm.py (1): l2norm (262-267)
- fla/ops/gated_delta_rule/chunk.py (1): chunk_gated_delta_rule (221-340)
- fla/ops/gated_delta_rule/fused_recurrent.py (1): fused_recurrent_gated_delta_rule (240-351)
- fla/models/utils.py (2): Cache (11-150), update (43-122)
- fla/layers/utils.py (2): pad_input (176-197), unpad_input (101-173)
- fla/layers/mom.py (4): transform (25-117), reconstruct (120-180), MomGatedDeltaNet (183-600), forward (358-517)
🪛 Ruff (0.11.9)
fla/models/mom/configuration_mom.py
78-78: Test for membership should be `not in`. Convert to `not in`.
(E713)
fla/models/mom/modeling_mom.py
247-247: No explicit `stacklevel` keyword argument found. Set `stacklevel=2`.
(B028)
364-370: Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling.
(B904)
fla/layers/mom_varlen.py
28-28: Redefinition of unused `pad_input` from line 27. Remove definition: `pad_input`.
(F811)
508-508: Local variable `batchsize` is assigned to but never used. Remove assignment to unused variable `batchsize`.
(F841)
586-586: Local variable `offsets` is assigned to but never used.
(F841)
586-586: Use `kwargs.get('offsets')` instead of `kwargs.get('offsets', None)`.
(SIM910)
655-655: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
656-656: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
657-657: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function.
(B006)
706-706: Use `kwargs.get('offsets')` instead of `kwargs.get('offsets', None)`.
(SIM910)
(SIM910)
🪛 Flake8 (7.2.0)
fla/models/mom/configuration_mom.py
[error] 78-78: test for membership should be 'not in'
(E713)
fla/layers/mom_varlen.py
[error] 28-28: redefinition of unused 'pad_input' from line 27
(F811)
[error] 508-508: local variable 'batchsize' is assigned to but never used
(F841)
[error] 586-586: local variable 'offsets' is assigned to but never used
(F841)
🪛 GitHub Actions: lint
fla/models/mom/configuration_mom.py
[error] 78-78: flake8: test for membership should be 'not in' (E713)
fla/models/mom/modeling_mom.py
[error] 210-210: flake8: expected 2 blank lines, found 1 (E302)
[error] 322-322: flake8: expected 2 blank lines, found 1 (E302)
fla/layers/mom_varlen.py
[error] 47-47: flake8: line too long (131 > 127 characters) (E501)
[error] 72-74: flake8: line too long (128-135 > 127 characters) (E501)
[error] 99-101: flake8: line too long (170-178 > 127 characters) (E501)
[error] 142-142: flake8: too many blank lines (3) (E303)
[error] 147-147: flake8: expected 2 blank lines, found 1 (E302) and line too long (128 > 127 characters) (E501)
[error] 213-213: flake8: missing whitespace after ',' (E231)
[error] 222-227: flake8: line too long (185-213 > 127 characters) (E501)
[error] 244-244: flake8: expected 2 blank lines, found 1 (E302) and line too long (177 > 127 characters) (E501)
[error] 251-251: flake8: line too long (128 > 127 characters) (E501)
[error] 284-286: flake8: line too long (135-137 > 127 characters) (E501)
[error] 292-292: flake8: whitespace before ',' (E203) and missing whitespace after ',' (E231)
[error] 411-413: flake8: multiple spaces after operator (E222) and line too long (131-133 > 127 characters) (E501)
[error] 414-414: flake8: line too long (132 > 127 characters) (E501)
[error] 506-506: flake8: local variable 'batchsize' is assigned to but never used (F841) and missing whitespace after ',' (E231)
[error] 516-516: flake8: line too long (223 > 127 characters) (E501)
[error] 525-537: flake8: line too long (129-169 > 127 characters) (E501) and missing whitespace around operator (E225) on line 576
[error] 584-584: flake8: local variable 'offsets' is assigned to but never used (F841)
[error] 601-602: flake8: missing whitespace after ',' (E231) multiple times
[error] 604-604: flake8: line too long (183 > 127 characters) (E501)
[error] 621-625: flake8: missing whitespace after ',' (E231) multiple times and line too long (144-183 > 127 characters) (E501)
[error] 647-647: flake8: too many blank lines (2) (E303)
[error] 651-655: flake8: unexpected spaces around keyword/parameter equals (E251) multiple times
🔇 Additional comments (2)
fla/models/mom/modeling_mom.py (1)

271-301: Keep `all_attns` – it's required for returning attentions

The `all_attns` variable is neither dead nor local-only: it's populated when `output_attentions=True` and included in the method's return value. Removing it here would break the interface and diverge from the shared pattern in other model implementations.

• In fla/models/mom/modeling_mom.py:
  – Line 300: `all_attns += (attentions,)`
  – Line 310: `return tuple(i for i in [..., all_attns] if i is not None)`
  – Line 315: returned as `attentions=all_attns`
• This same pattern appears across all other `modeling_*` classes (e.g., Transformer, COMBA, DeltaNet, etc.), ensuring a consistent API.

Please ignore the suggestion to remove `all_attns`; it's necessary for users requesting attentions.

Likely an incorrect or invalid review comment.
fla/layers/mom_varlen.py (1)

1-735: Adjust test file locations to match repository structure

The repository doesn't have `tests/test_model.py` or `tests/test_generation.py`. Model tests are under `tests/models/` as `test_modeling_<model>.py`, and there's no generation test directory at the root. To add coverage for the new MoM modules:

• Create `tests/models/test_modeling_mom.py` and add unit tests for `MomGatedDeltaNet` (e.g. forward pass, gating, capacity handling).
• If you need generation-style tests for MoM (e.g. end-to-end text generation), add a new file under a `tests/generation/` directory (e.g. `tests/generation/test_generation_mom.py`), mirroring your existing generation tests setup.

Likely an incorrect or invalid review comment.
expert_mask = torch.nn.functional.one_hot(selected_experts, num_memories)

# For a given token, determine if it was routed to a given expert.
expert_mask = torch.max(expert_mask, axis=-2).values
Use 'dim' instead of 'axis' for torch.max.
PyTorch uses the `dim` parameter, not `axis`.
Apply this diff:
- expert_mask = torch.max(expert_mask, axis=-2).values
+ expert_mask = torch.max(expert_mask, dim=-2).values
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
- expert_mask = torch.max(expert_mask, axis=-2).values
+ expert_mask = torch.max(expert_mask, dim=-2).values
🤖 Prompt for AI Agents
In fla/models/mom/modeling_mom.py at line 550, replace the use of 'axis' with
'dim' in the torch.max function call. Change torch.max(expert_mask,
axis=-2).values to torch.max(expert_mask, dim=-2).values to comply with
PyTorch's API.
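For reference, here is a minimal, self-contained sketch of the Switch-Transformers-style balancing term that a load-balancing loss of this shape typically computes; the function name and signature below are illustrative assumptions, not the exact API used in this PR:

import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits, num_experts, top_k):
    # router_logits: (num_tokens, num_experts), concatenated across layers
    probs = F.softmax(router_logits, dim=-1)
    _, selected_experts = torch.topk(probs, top_k, dim=-1)
    expert_mask = F.one_hot(selected_experts, num_experts)
    # For a given token, determine if it was routed to a given expert
    # (note `dim`, not `axis`, per the fix above).
    expert_mask = torch.max(expert_mask, dim=-2).values.float()
    tokens_per_expert = expert_mask.mean(dim=0)        # fraction of tokens per expert
    router_prob_per_expert = probs.mean(dim=0)         # mean router probability per expert
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)

The product of routed-token fractions and mean router probabilities is minimized when routing is uniform, which is why the mask built from one_hot plus max feeds directly into the balance term.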
fla/models/mom/modeling_mom.py
Outdated
if output_attentions:
    warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.")
    output_attentions = False
Fix incorrect model name in warning and add stacklevel.
The warning refers to `GatedDeltaNetModel` but should refer to `MomModel`. Also, add the `stacklevel` parameter to the warning.
Apply this diff:
if output_attentions:
- warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.")
+ warnings.warn("`MomModel` does not support `output_attentions` now, setting it to `False`.", stacklevel=2)
output_attentions = False
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if output_attentions:
    warnings.warn(
        "`MomModel` does not support `output_attentions` now, setting it to `False`.",
        stacklevel=2,
    )
    output_attentions = False
🧰 Tools
🪛 Ruff (0.11.9)
247-247: No explicit stacklevel
keyword argument found
Set stacklevel=2
(B028)
🤖 Prompt for AI Agents
In fla/models/mom/modeling_mom.py around lines 246 to 248, the warning message
incorrectly refers to `GatedDeltaNetModel` instead of `MomModel` and lacks the
`stacklevel` parameter. Update the warning string to mention `MomModel` and add
the `stacklevel=2` argument to the warnings.warn call to correctly point to the
caller in the stack trace.
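As a quick illustration of why stacklevel=2 matters (a toy sketch; the function below is hypothetical, not the model's actual code):

import warnings

def forward(output_attentions=False):
    if output_attentions:
        # stacklevel=2 attributes the warning to the caller of forward(),
        # not to this warnings.warn line inside the library.
        warnings.warn(
            "`MomModel` does not support `output_attentions` now, setting it to `False`.",
            stacklevel=2,
        )
        output_attentions = False
    return output_attentions

forward(output_attentions=True)

Without stacklevel, the reported location is the library internals, which makes the warning much harder for users to act on.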
fla/models/mom/modeling_mom.py
Outdated
aux_loss = load_balancing_loss_func(
    valid_router_logits,
    self.num_memories,
    self.topk,
    use_layer_wise_balance=self.config.use_layer_wise_balance,  # ✨
)
aux_loss *= self.aux_loss_scale
💡 Verification agent
❓ Verification inconclusive
Ensure l2wrap support is included.
According to the PR objectives, l2wrap functionality should be included in the new MoM models.
🏁 Script executed:
#!/bin/bash
# Description: Check if l2wrap is implemented in other models and how it should be integrated
# Search for l2wrap implementation in other models
rg -A 5 "l2wrap" --glob "fla/models/**/*.py"
# Check the commit that introduced l2wrap to all models
git log --grep="l2wrap" --oneline | head -5
Length of output: 193
🏁 Script executed:
#!/bin/bash
set -euxo pipefail
# Search for any l2wrap references in model implementations
rg -A5 "l2wrap" --glob 'fla/models/**/*.py'
# Show the last few commits mentioning l2wrap to understand its introduction
git log --grep="l2wrap" --oneline -n 5
Length of output: 64
🏁 Script executed:
#!/usr/bin/env bash
set -euxo pipefail
# Ensure the repo is cloned so files can be searched
# Search for any l2wrap references under fla/models
rg -A5 "l2wrap" fla/models
# Look for commits mentioning l2wrap to see how and when it was introduced
git log --grep="l2wrap" --oneline -n 5
Length of output: 47
Missing l2wrap integration in MoM model
Our searches didn’t uncover any existing l2wrap references in the repo, so the new MoM implementation currently omits this required component. Please add l2wrap support alongside the load-balancing auxiliary loss.
• File to update: fla/models/mom/modeling_mom.py (at lines 470–476)
• Suggested change (example):
aux_loss = load_balancing_loss_func(
valid_router_logits,
self.num_memories,
self.topk,
- use_layer_wise_balance=self.config.use_layer_wise_balance,
+ use_layer_wise_balance=self.config.use_layer_wise_balance,
)
- aux_loss *= self.aux_loss_scale
+ # integrate l2wrap auxiliary loss
+ wrap_loss = l2wrap_loss_func(
+ hidden_states, # or appropriate tensor
+ self.config.l2wrap_scale, # configurable scale
+ )
+ aux_loss = aux_loss * self.aux_loss_scale + wrap_loss
• Ensure you import `l2wrap_loss_func` and expose any new `config.l2wrap_scale` setting in the model's config.
• Add or update unit tests to cover the l2wrap loss path.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
aux_loss = load_balancing_loss_func(
    valid_router_logits,
    self.num_memories,
    self.topk,
    use_layer_wise_balance=self.config.use_layer_wise_balance,  # ✨
)
# integrate l2wrap auxiliary loss
wrap_loss = l2wrap_loss_func(
    hidden_states,             # or appropriate tensor
    self.config.l2wrap_scale,  # configurable scale
)
aux_loss = aux_loss * self.aux_loss_scale + wrap_loss
🤖 Prompt for AI Agents
In fla/models/mom/modeling_mom.py around lines 470 to 476, the MoM model
currently lacks integration of the l2wrap loss function. To fix this, first
import l2wrap_loss_func at the top of the file and add a new configuration
parameter l2wrap_scale to the model's config. Then, compute the l2wrap loss
alongside the existing load_balancing_loss_func call, scaling it by
config.l2wrap_scale, and include it in the total auxiliary loss. Finally, update
or add unit tests to cover the new l2wrap loss functionality to ensure proper
test coverage.
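Since no l2wrap implementation exists in the repo yet, any integration is necessarily speculative. One common interpretation, borrowed from RWKV-style training, is a small penalty on the largest logit per position to keep logits from drifting upward; a hedged sketch, with l2wrap_scale as a hypothetical config knob:

import torch

def l2wrap_loss(logits: torch.Tensor, l2wrap_scale: float = 1e-4) -> torch.Tensor:
    # Penalize the largest logit at each position, discouraging
    # unbounded logit growth during long training runs.
    max_logits = logits.float().max(dim=-1).values
    return l2wrap_scale * (max_logits ** 2).mean()

Whether this matches the intended l2wrap semantics should be confirmed with the PR authors before wiring it into the auxiliary loss.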
fla/layers/mom_varlen.py
Outdated
indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1)
# discard tokens exceed capacity and is far from now
# left pad
truncation_indices = indices + batch_memory_tokens.reshape((-1,)).unsqueeze(-1) - max_len
mask = torch.bitwise_and(truncation_indices < flatten_offset.unsqueeze(-1), truncation_indices >= 0)
mask = torch.bitwise_and(mask, truncation_indices >= torch.cat((torch.zeros((1,), dtype=flatten_offset.dtype, device=flatten_offset.device), flatten_offset[:-1])).unsqueeze(-1))
truncation_indices = torch.where(mask, truncation_indices, torch.zeros_like(truncation_indices))
🛠️ Refactor suggestion
Break long lines for better readability.
Several lines exceed the maximum line length limit.
Apply this diff:
- indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1)
+ indices = (
+ torch.arange(max_len, device=flatten_offset.device)
+ .unsqueeze(0)
+ .expand(b*num_memories, -1)
+ + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0)
+ .unsqueeze(1)
+ )
# discard tokens exceed capacity and is far from now
# left pad
truncation_indices = indices + batch_memory_tokens.reshape((-1,)).unsqueeze(-1) - max_len
mask = torch.bitwise_and(truncation_indices < flatten_offset.unsqueeze(-1), truncation_indices >= 0)
- mask = torch.bitwise_and(mask, truncation_indices >= torch.cat((torch.zeros((1,), dtype=flatten_offset.dtype, device=flatten_offset.device), flatten_offset[:-1])).unsqueeze(-1))
+ mask = torch.bitwise_and(
+ mask,
+ truncation_indices >= torch.cat(
+ (torch.zeros((1,), dtype=flatten_offset.dtype, device=flatten_offset.device),
+ flatten_offset[:-1])
+ ).unsqueeze(-1)
+ )
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
indices = (
    torch.arange(max_len, device=flatten_offset.device)
    .unsqueeze(0)
    .expand(b * num_memories, -1)
    + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0)
    .unsqueeze(1)
)
# discard tokens exceed capacity and is far from now
# left pad
truncation_indices = indices + batch_memory_tokens.reshape((-1,)).unsqueeze(-1) - max_len
mask = torch.bitwise_and(
    truncation_indices < flatten_offset.unsqueeze(-1),
    truncation_indices >= 0
)
mask = torch.bitwise_and(
    mask,
    truncation_indices >= torch.cat(
        (torch.zeros((1,), dtype=flatten_offset.dtype, device=flatten_offset.device),
         flatten_offset[:-1])
    ).unsqueeze(-1)
)
truncation_indices = torch.where(mask, truncation_indices, torch.zeros_like(truncation_indices))
🤖 Prompt for AI Agents
In fla/layers/mom_varlen.py around lines 224 to 230, the code contains several
long lines that exceed the maximum line length limit, reducing readability.
Break these long lines into multiple shorter lines by splitting expressions at
logical points such as after operators or commas, and aligning subsequent lines
with proper indentation to maintain clarity and adhere to style guidelines.
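Independent of the formatting question, the indexing trick in this hunk is easier to audit on toy values. A minimal sketch of the left-pad capacity truncation (two memories with 3 and 2 routed tokens, capacity max_len=2; variable names mirror the reviewed code but the setup is invented for illustration):

import torch

flatten_offset = torch.tensor([3, 5])        # cumulative token counts per memory
batch_memory_tokens = torch.tensor([3, 2])   # tokens routed to each memory
max_len = 2                                  # per-memory capacity

starts = torch.cat([torch.zeros(1, dtype=torch.long), flatten_offset[:-1]])
indices = torch.arange(max_len).unsqueeze(0) + starts.unsqueeze(1)
# left pad: shift each row so it ends at that memory's last token
truncation_indices = indices + batch_memory_tokens.unsqueeze(-1) - max_len
mask = (truncation_indices >= starts.unsqueeze(-1)) & \
       (truncation_indices < flatten_offset.unsqueeze(-1))
truncation_indices = torch.where(mask, truncation_indices,
                                 torch.zeros_like(truncation_indices))
print(truncation_indices)  # tensor([[1, 2], [3, 4]]): last 2 tokens per memory

Each row keeps only the most recent tokens that fall inside that memory's [start, end) range, which is exactly what the masked expression above computes at scale.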
fla/layers/mom_varlen.py
Outdated
@WKX933 I think it would be better to merge the two files mom.py and mom_varlen.py.
fla/models/mom/configuration_mom.py
Outdated
self.single_kv_proj = single_kv_proj
self.mom_backend = mom_backend

if not self.mom_backend in ['GDN']:
@JusenD Could we change the `mom_backend` name to the registered `model_type`, e.g., `GatedDeltaNetConfig.model_type`?
fla/layers/__init__.py
Outdated
@@ -19,6 +19,7 @@
 from .mamba import Mamba
 from .mamba2 import Mamba2
 from .mesa_net import MesaNet
+from .mom_varlen import MomGatedDeltaNet
@JusenD Let's rename this module to MomAttention
Actionable comments posted: 2
♻️ Duplicate comments (3)
fla/layers/mom.py (1)
514-516
: Add missing kwargs parameter to shared_o method call.The
shared_o
method expects**kwargs
but it's not being passed in the call.if self.shared_mem: shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, - use_cache, conv_state_q, conv_state_k, conv_state_v) + use_cache, conv_state_q, conv_state_k, conv_state_v, **kwargs) o += shared_ofla/layers/mom_varlen.py (2)
516-516
: Remove unused variablebatchsize
.The variable
batchsize
is assigned but never used.- batchsize, q_len = hidden_states.shape[0], hidden_states.shape[1] + q_len = hidden_states.shape[1]
641-643
: Add missing kwargs parameter to shared_o method call.The
shared_o
method expects**kwargs
but it's not being passed in the call.if self.shared_mem: shared_o = self.shared_o(shared_hidden_states, attention_mask, recurrent_state, - use_cache, conv_state_q, conv_state_k, conv_state_v) + use_cache, conv_state_q, conv_state_k, conv_state_v, **kwargs) o += shared_o
🧹 Nitpick comments (1)
fla/models/mom/modeling_mom.py (1)
362-368
: Improve exception handling with proper chaining.When re-raising exceptions, use
raise ... from
to preserve the exception chain.- raise AttributeError( - f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " - f"which is not supported for {self.__class__.__name__}. " - f"Try another generation strategy instead. " - f"For the available generation strategies, check this doc: " - f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" - ) + raise AttributeError( + f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " + f"which is not supported for {self.__class__.__name__}. " + f"Try another generation strategy instead. " + f"For the available generation strategies, check this doc: " + f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" + ) from exception
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- fla/__init__.py (3 hunks)
- fla/layers/__init__.py (2 hunks)
- fla/layers/mom.py (1 hunks)
- fla/layers/mom_varlen.py (1 hunks)
- fla/models/mom/configuration_mom.py (1 hunks)
- fla/models/mom/modeling_mom.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- fla/layers/__init__.py
- fla/models/mom/configuration_mom.py
- fla/__init__.py
🧰 Additional context used
🧠 Learnings (1)
fla/layers/mom_varlen.py (1)
Learnt from: toothacher17
PR: fla-org/flash-linear-attention#395
File: fla/layers/mla.py:0-0
Timestamp: 2025-05-13T06:04:24.342Z
Learning: In Flash Attention's `flash_attn_varlen_func`, the scaling factor (typically computed as head_dim ** -0.5) is handled internally by the Flash Attention implementation rather than needing to be applied explicitly to query projections.
🪛 Ruff (0.11.9)
fla/layers/mom.py
429-429: Use kwargs.get('seq_idx')
instead of kwargs.get('seq_idx', None)
Replace kwargs.get('seq_idx', None)
with kwargs.get('seq_idx')
(SIM910)
461-461: Use kwargs.get('cu_seqlens')
instead of kwargs.get('cu_seqlens', None)
Replace kwargs.get('cu_seqlens', None)
with kwargs.get('cu_seqlens')
(SIM910)
542-542: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
543-543: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
544-544: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
560-560: Use kwargs.get('seq_idx')
instead of kwargs.get('seq_idx', None)
Replace kwargs.get('seq_idx', None)
with kwargs.get('seq_idx')
(SIM910)
589-589: Use kwargs.get('cu_seqlens')
instead of kwargs.get('cu_seqlens', None)
Replace kwargs.get('cu_seqlens', None)
with kwargs.get('cu_seqlens')
(SIM910)
fla/layers/mom_varlen.py
516-516: Local variable batchsize
is assigned to but never used
Remove assignment to unused variable batchsize
(F841)
598-598: Local variable offsets
is assigned to but never used
Remove assignment to unused variable offsets
(F841)
598-598: Use kwargs.get('offsets')
instead of kwargs.get('offsets', None)
Replace kwargs.get('offsets', None)
with kwargs.get('offsets')
(SIM910)
669-669: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
670-670: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
671-671: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
720-720: Use kwargs.get('offsets')
instead of kwargs.get('offsets', None)
Replace kwargs.get('offsets', None)
with kwargs.get('offsets')
(SIM910)
fla/models/mom/modeling_mom.py
244-244: No explicit stacklevel
keyword argument found
Set stacklevel=2
(B028)
362-368: Within an except
clause, raise exceptions with raise ... from err
or raise ... from None
to distinguish them from errors in exception handling
(B904)
🪛 Flake8 (7.2.0)
fla/layers/mom.py
[error] 428-428: continuation line under-indented for visual indent
(E128)
fla/layers/mom_varlen.py
[error] 516-516: local variable 'batchsize' is assigned to but never used
(F841)
[error] 598-598: local variable 'offsets' is assigned to but never used
(F841)
🪛 GitHub Actions: lint
fla/layers/mom.py
[error] 24-24: flake8: line too long (128 > 127 characters) (E501)
[error] 100-100: flake8: line too long (135 > 127 characters) (E501)
[error] 123-123: flake8: line too long (177 > 127 characters) (E501)
[error] 130-130: flake8: line too long (128 > 127 characters) (E501)
[error] 398-398: flake8: line too long (150 > 127 characters) (E501)
[error] 427-427: flake8: whitespace before ':' (E203)
[error] 428-428: flake8: continuation line under-indented for visual indent (E128)
fla/layers/mom_varlen.py
[error] 47-47: flake8: line too long (131 > 127 characters) (E501)
[error] 72-72: flake8: line too long (128 > 127 characters) (E501)
[error] 74-74: flake8: line too long (135 > 127 characters) (E501)
[error] 99-99: flake8: line too long (178 > 127 characters) (E501)
[error] 101-101: flake8: line too long (170 > 127 characters) (E501)
[error] 147-147: flake8: line too long (128 > 127 characters) (E501)
[error] 223-223: flake8: line too long (135 > 127 characters) (E501)
[error] 248-248: flake8: line too long (177 > 127 characters) (E501)
[error] 255-255: flake8: line too long (128 > 127 characters) (E501)
[error] 516-516: flake8: local variable 'batchsize' is assigned to but never used (F841)
[error] 527-527: flake8: line too long (150 > 127 characters) (E501)
[error] 598-598: flake8: local variable 'offsets' is assigned to but never used (F841)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-models
🔇 Additional comments (1)
fla/models/mom/modeling_mom.py (1)
398-403
: Remove duplicate key in dictionary.The
num_logits_to_keep
key appears twice in the model_inputs dictionary, which will cause the first value to be overridden.model_inputs.update({ 'past_key_values': past_key_values, 'use_cache': use_cache, 'attention_mask': attention_mask, - 'num_logits_to_keep': num_logits_to_keep, })
Likely an incorrect or invalid review comment.
Actionable comments posted: 0
♻️ Duplicate comments (4)
fla/layers/mom_varlen.py (2)
372-372
: Fix incorrect return type annotation.The
__init__
method should returnNone
, notMomAttention
.Apply this diff:
- ) -> MomAttention: + ) -> None:
659-661
: Avoid mutable default arguments.Using mutable objects as default arguments can lead to unexpected behavior when the same default object is reused across function calls.
Apply this diff:
- conv_state_q=[None, None], - conv_state_k=[None, None], - conv_state_v=[None, None], + conv_state_q=None, + conv_state_k=None, + conv_state_v=None,Then initialize them in the method:
if conv_state_q is None: conv_state_q = [None, None] if conv_state_k is None: conv_state_k = [None, None] if conv_state_v is None: conv_state_v = [None, None]fla/layers/mom.py (2)
253-253
: Fix incorrect return type annotation.The
__init__
method should returnNone
, notMomAttention
.Apply this diff:
- ) -> MomAttention: + ) -> None:
541-543
: Avoid mutable default arguments.Using mutable objects as default arguments can lead to unexpected behavior when the same default object is reused across function calls.
Apply this diff:
- conv_state_q=[None, None], - conv_state_k=[None, None], - conv_state_v=[None, None], + conv_state_q=None, + conv_state_k=None, + conv_state_v=None,Then initialize them in the method:
if conv_state_q is None: conv_state_q = [None, None] if conv_state_k is None: conv_state_k = [None, None] if conv_state_v is None: conv_state_v = [None, None]
🧹 Nitpick comments (5)
fla/layers/mom_varlen.py (2)
46-46
: Fix line length violations.Multiple lines exceed the 127 character limit, reducing readability.
Break these long lines for better readability. For example:
- cu_seqlens (`torch.Tensor`): - The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. + `cu_seqlens` shape is (batch_size + 1,).Apply similar formatting to other long lines to maintain consistency with the project's style guidelines.
Also applies to: 71-73, 98-100, 140-140, 216-216, 241-241, 248-248, 520-520
591-591
: Simplify kwargs.get() calls.The explicit
None
default is redundant since it's the default behavior.Apply this diff:
- cu_seqlens = kwargs.get('cu_seqlens', None) + cu_seqlens = kwargs.get('cu_seqlens')- cu_seqlens = kwargs.get('cu_seqlens', None) + cu_seqlens = kwargs.get('cu_seqlens')Also applies to: 710-710
fla/layers/mom.py (3)
23-23
: Fix line length violations.Multiple lines exceed the 127 character limit, reducing readability.
Break these long lines for better readability. For example, line 397:
- routing_weights_full = torch.zeros((routing_weights.shape[0], routing_weights.shape[1], self.num_memories), - dtype=routing_weights.dtype, device=routing_weights.device).scatter(-1, selected_memories, routing_weights) + routing_weights_full = torch.zeros( + (routing_weights.shape[0], routing_weights.shape[1], self.num_memories), + dtype=routing_weights.dtype, + device=routing_weights.device + ).scatter(-1, selected_memories, routing_weights)Also applies to: 99-99, 122-122, 129-129, 397-397
426-427
: Fix indentation and whitespace issues.There are whitespace and indentation problems causing pipeline failures.
Apply this diff:
- conv_mask = attention_mask[:, -hidden_states.shape[2] - :].repeat_interleave(self.num_memories, 0) if attention_mask is not None else None + conv_mask = ( + attention_mask[:, -hidden_states.shape[2]:] + .repeat_interleave(self.num_memories, 0) + if attention_mask is not None else None + )
428-428
: Simplify kwargs.get() calls.The explicit
None
default is redundant since it's the default behavior.Apply this diff:
- seq_idx = kwargs.get('seq_idx', None) + seq_idx = kwargs.get('seq_idx')- cu_seqlens = kwargs.get('cu_seqlens', None) + cu_seqlens = kwargs.get('cu_seqlens')Apply similar changes to the other occurrences.
Also applies to: 460-460, 559-559, 588-588
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- fla/layers/mom.py (1 hunks)
- fla/layers/mom_varlen.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
fla/layers/mom_varlen.py (1)
Learnt from: toothacher17
PR: fla-org/flash-linear-attention#395
File: fla/layers/mla.py:0-0
Timestamp: 2025-05-13T06:04:24.342Z
Learning: In Flash Attention's `flash_attn_varlen_func`, the scaling factor (typically computed as head_dim ** -0.5) is handled internally by the Flash Attention implementation rather than needing to be applied explicitly to query projections.
🪛 Ruff (0.11.9)
fla/layers/mom_varlen.py
591-591: Use kwargs.get('cu_seqlens')
instead of kwargs.get('cu_seqlens', None)
Replace kwargs.get('cu_seqlens', None)
with kwargs.get('cu_seqlens')
(SIM910)
659-659: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
660-660: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
661-661: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
710-710: Use kwargs.get('cu_seqlens')
instead of kwargs.get('cu_seqlens', None)
Replace kwargs.get('cu_seqlens', None)
with kwargs.get('cu_seqlens')
(SIM910)
fla/layers/mom.py
428-428: Use kwargs.get('seq_idx')
instead of kwargs.get('seq_idx', None)
Replace kwargs.get('seq_idx', None)
with kwargs.get('seq_idx')
(SIM910)
460-460: Use kwargs.get('cu_seqlens')
instead of kwargs.get('cu_seqlens', None)
Replace kwargs.get('cu_seqlens', None)
with kwargs.get('cu_seqlens')
(SIM910)
541-541: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
542-542: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
543-543: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
559-559: Use kwargs.get('seq_idx')
instead of kwargs.get('seq_idx', None)
Replace kwargs.get('seq_idx', None)
with kwargs.get('seq_idx')
(SIM910)
588-588: Use kwargs.get('cu_seqlens')
instead of kwargs.get('cu_seqlens', None)
Replace kwargs.get('cu_seqlens', None)
with kwargs.get('cu_seqlens')
(SIM910)
🪛 GitHub Actions: lint
fla/layers/mom_varlen.py
[error] 46-46: flake8: line too long (131 > 127 characters) (E501)
[error] 71-71: flake8: line too long (128 > 127 characters) (E501)
[error] 73-73: flake8: line too long (135 > 127 characters) (E501)
[error] 98-98: flake8: line too long (178 > 127 characters) (E501)
[error] 100-100: flake8: line too long (170 > 127 characters) (E501)
[error] 140-140: flake8: line too long (128 > 127 characters) (E501)
[error] 216-216: flake8: line too long (135 > 127 characters) (E501)
[error] 241-241: flake8: line too long (177 > 127 characters) (E501)
[error] 248-248: flake8: line too long (128 > 127 characters) (E501)
[error] 520-520: flake8: line too long (150 > 127 characters) (E501)
fla/layers/mom.py
[error] 23-23: flake8: line too long (128 > 127 characters) (E501)
[error] 99-99: flake8: line too long (135 > 127 characters) (E501)
[error] 122-122: flake8: line too long (177 > 127 characters) (E501)
[error] 129-129: flake8: line too long (128 > 127 characters) (E501)
[error] 397-397: flake8: line too long (150 > 127 characters) (E501)
[error] 426-426: flake8: whitespace before ':' (E203)
[error] 427-427: flake8: continuation line under-indented for visual indent (E128)
🪛 Flake8 (7.2.0)
fla/layers/mom.py
[error] 427-427: continuation line under-indented for visual indent
(E128)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-models
Actionable comments posted: 1
♻️ Duplicate comments (5)
fla/models/mom/modeling_mom.py (3)
43-43
: Fix incorrect return type annotation.The
__init__
method should returnNone
, not the class type.Apply this diff:
- ) -> MomMLP: + ) -> None:
164-164
: Fix incorrect module name in_no_split_modules
.The
_no_split_modules
list references'GatedDeltaNetBlock'
but should reference'MomBlock'
for this model.- _no_split_modules = ['GatedDeltaNetBlock'] + _no_split_modules = ['MomBlock']
244-244
: Fix incorrect model name in warning and add stacklevel.The warning refers to
GatedDeltaNetModel
but should refer toMomModel
. Also, addstacklevel
parameter to the warning.Apply this diff:
- warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.") + warnings.warn("`MomModel` does not support `output_attentions` now, setting it to `False`.", stacklevel=2)fla/layers/mom_varlen.py (2)
330-330
: Fix incorrect return type annotation.The
__init__
method should returnNone
, not the class type.Apply this diff:
- ) -> MomAttention: + ) -> None:
617-619
: Avoid mutable default arguments.Using mutable objects as default arguments can lead to unexpected behavior.
Apply this diff:
- conv_state_q=[None, None], - conv_state_k=[None, None], - conv_state_v=[None, None], + conv_state_q=None, + conv_state_k=None, + conv_state_v=None,Then initialize them in the method:
if conv_state_q is None: conv_state_q = [None, None] if conv_state_k is None: conv_state_k = [None, None] if conv_state_v is None: conv_state_v = [None, None]
🧹 Nitpick comments (4)
fla/layers/mom_varlen.py (4)
216-216
: Break long line for better readability.The line exceeds the maximum length limit. Break it into multiple lines.
Apply this diff:
- indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand( - b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1) + indices = ( + torch.arange(max_len, device=flatten_offset.device) + .unsqueeze(0) + .expand(b*num_memories, -1) + + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0) + .unsqueeze(1) + )
477-478
: Break long line for better readability.The line exceeds the maximum length limit. Break it into multiple lines.
Apply this diff:
- routing_weights_full = torch.zeros((routing_weights.shape[0], routing_weights.shape[1], self.num_memories), - dtype=routing_weights.dtype, device=routing_weights.device).scatter(-1, selected_memories, routing_weights) + routing_weights_full = torch.zeros( + (routing_weights.shape[0], routing_weights.shape[1], self.num_memories), + dtype=routing_weights.dtype, + device=routing_weights.device + ).scatter(-1, selected_memories, routing_weights)
549-549
: Optimize kwargs.get usage.Simplify the kwargs.get call by removing the redundant
None
default.Apply this diff:
- cu_seqlens = kwargs.get('cu_seqlens', None) + cu_seqlens = kwargs.get('cu_seqlens')
668-668
: Optimize kwargs.get usage.Simplify the kwargs.get call by removing the redundant
None
default.Apply this diff:
- cu_seqlens = kwargs.get('cu_seqlens', None) + cu_seqlens = kwargs.get('cu_seqlens')
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- fla/layers/mom_varlen.py (1 hunks)
- fla/models/mom/configuration_mom.py (1 hunks)
- fla/models/mom/modeling_mom.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- fla/models/mom/configuration_mom.py
🧰 Additional context used
🧠 Learnings (1)
fla/layers/mom_varlen.py (1)
Learnt from: toothacher17
PR: fla-org/flash-linear-attention#395
File: fla/layers/mla.py:0-0
Timestamp: 2025-05-13T06:04:24.342Z
Learning: In Flash Attention's `flash_attn_varlen_func`, the scaling factor (typically computed as head_dim ** -0.5) is handled internally by the Flash Attention implementation rather than needing to be applied explicitly to query projections.
🧬 Code Graph Analysis (1)
fla/layers/mom_varlen.py (9)
- fla/modules/fused_norm_gate.py (1): FusedRMSNormGated (993-1054)
- fla/modules/layernorm.py (1): RMSNorm (1060-1107)
- fla/modules/convolution.py (1): ShortConvolution (548-745)
- fla/modules/l2norm.py (1): l2norm (262-267)
- fla/ops/gated_delta_rule/chunk.py (1): chunk_gated_delta_rule (221-340)
- fla/ops/gated_delta_rule/fused_recurrent.py (1): fused_recurrent_gated_delta_rule (240-351)
- fla/models/utils.py (1): Cache (11-150)
- fla/layers/utils.py (2): pad_input (176-197), unpad_input (101-173)
- fla/layers/mom.py (4): transform (23-117), reconstruct (122-184), MomAttention (187-615), forward (366-533)
🪛 Ruff (0.11.9)
fla/models/mom/modeling_mom.py
244-244: No explicit stacklevel
keyword argument found
Set stacklevel=2
(B028)
362-368: Within an except
clause, raise exceptions with raise ... from err
or raise ... from None
to distinguish them from errors in exception handling
(B904)
fla/layers/mom_varlen.py
549-549: Use kwargs.get('cu_seqlens')
instead of kwargs.get('cu_seqlens', None)
Replace kwargs.get('cu_seqlens', None)
with kwargs.get('cu_seqlens')
(SIM910)
617-617: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
618-618: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
619-619: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
668-668: Use kwargs.get('cu_seqlens')
instead of kwargs.get('cu_seqlens', None)
Replace kwargs.get('cu_seqlens', None)
with kwargs.get('cu_seqlens')
(SIM910)
🪛 GitHub Actions: lint
fla/layers/mom_varlen.py
[error] 46-46: flake8: line too long (131 > 127 characters) (E501)
[error] 71-71: flake8: line too long (128 > 127 characters) (E501)
[error] 73-73: flake8: line too long (135 > 127 characters) (E501)
[error] 98-98: flake8: line too long (178 > 127 characters) (E501)
[error] 100-100: flake8: line too long (170 > 127 characters) (E501)
[error] 140-140: flake8: line too long (128 > 127 characters) (E501)
[error] 216-216: flake8: line too long (135 > 127 characters) (E501)
[error] 241-241: flake8: line too long (177 > 127 characters) (E501)
[error] 248-248: flake8: line too long (128 > 127 characters) (E501)
[error] 478-478: flake8: line too long (150 > 127 characters) (E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-models
🔇 Additional comments (1)
fla/layers/mom_varlen.py (1)
530-530
: Fix syntax error with trailing comma.Remove the trailing comma after the tuple elements to fix the syntax error.
Apply this diff:
- q, k, v = self.silu(q), self.silu(k), self.silu(v), + q, k, v = self.silu(q), self.silu(k), self.silu(v)Likely an incorrect or invalid review comment.
raise AttributeError(
    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
    f"which is not supported for {self.__class__.__name__}. "
    f"Try another generation strategy instead. "
    f"For the available generation strategies, check this doc: "
    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
)
Improve exception handling with proper chaining.
Use raise ... from exception
to properly chain exceptions and provide better error context.
Apply this diff:
raise AttributeError(
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
f"which is not supported for {self.__class__.__name__}. "
f"Try another generation strategy instead. "
f"For the available generation strategies, check this doc: "
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
- )
+ ) from exception
else:
- raise exception
+ raise
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
raise AttributeError(
    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
    f"which is not supported for {self.__class__.__name__}. "
    f"Try another generation strategy instead. "
    f"For the available generation strategies, check this doc: "
    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
) from exception
else:
    raise
🧰 Tools
🪛 Ruff (0.11.9)
362-368: Within an except
clause, raise exceptions with raise ... from err
or raise ... from None
to distinguish them from errors in exception handling
(B904)
🤖 Prompt for AI Agents
In fla/models/mom/modeling_mom.py around lines 362 to 368, the AttributeError is
raised without chaining the original exception, which loses error context.
Modify the raise statement to use "raise ... from exception" by capturing the
original exception and chaining it to the new AttributeError to improve
debugging and error traceability.
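To see the difference chaining makes (a toy sketch, not the model's actual code; seen_tokens is a hypothetical attribute):

def prepare_inputs(past_key_values):
    try:
        return past_key_values.seen_tokens  # hypothetical attribute access
    except AttributeError as exception:
        # `from exception` records the original failure as __cause__,
        # so tracebacks show both errors instead of hiding the first one.
        raise AttributeError(
            "manipulating `past_key_values` is not supported here"
        ) from exception

try:
    prepare_inputs(object())
except AttributeError as err:
    assert err.__cause__ is not None  # chained via `from exception`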
Actionable comments posted: 0
♻️ Duplicate comments (3)
fla/layers/mom.py (3)
330-330
: Fix incorrect return type annotation.The
__init__
method should returnNone
, notMomAttention
.Apply this diff:
- ) -> MomAttention: + ) -> None:
617-619
: Fix mutable default arguments.Using mutable objects as default arguments can lead to unexpected behavior.
Apply this diff:
- conv_state_q=[None, None], - conv_state_k=[None, None], - conv_state_v=[None, None], + conv_state_q=None, + conv_state_k=None, + conv_state_v=None,Then initialize them in the method:
if conv_state_q is None: conv_state_q = [None, None] if conv_state_k is None: conv_state_k = [None, None] if conv_state_v is None: conv_state_v = [None, None]
358-358
: Fix typo in assertion error message.There's still a typo "suppoerted" instead of "supported" despite being marked as addressed in previous reviews.
Apply this diff:
- assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." + assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
🧹 Nitpick comments (3)
fla/layers/mom.py (3)
549-549
: Simplify kwargs.get usage.The
None
default is redundant sinceget()
returnsNone
by default.Apply this diff:
- cu_seqlens = kwargs.get('cu_seqlens', None) + cu_seqlens = kwargs.get('cu_seqlens')
668-668
: Simplify kwargs.get usage.The
None
default is redundant sinceget()
returnsNone
by default.Apply this diff:
- cu_seqlens = kwargs.get('cu_seqlens', None) + cu_seqlens = kwargs.get('cu_seqlens')
46-46
: Fix line length violations.Multiple lines exceed the 127 character limit and need to be wrapped for better readability.
Consider breaking long lines using parentheses for function arguments, or splitting complex expressions across multiple lines. For example:
- cu_seqlens (`torch.Tensor`): The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. + `cu_seqlens` shape is (batch_size + 1,).Also applies to: 71-71, 73-73, 98-98, 100-100, 140-140, 216-216, 241-241, 248-248, 478-478
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- fla/layers/__init__.py (2 hunks)
- fla/layers/mom.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- fla/layers/__init__.py
🧰 Additional context used
🪛 Ruff (0.11.9)
fla/layers/mom.py
549-549: Use kwargs.get('cu_seqlens')
instead of kwargs.get('cu_seqlens', None)
Replace kwargs.get('cu_seqlens', None)
with kwargs.get('cu_seqlens')
(SIM910)
617-617: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
618-618: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
619-619: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
668-668: Use kwargs.get('cu_seqlens')
instead of kwargs.get('cu_seqlens', None)
Replace kwargs.get('cu_seqlens', None)
with kwargs.get('cu_seqlens')
(SIM910)
🪛 GitHub Actions: lint
fla/layers/mom.py
[error] 46-46: flake8: line too long (131 > 127 characters) (E501)
[error] 71-71: flake8: line too long (128 > 127 characters) (E501)
[error] 73-73: flake8: line too long (135 > 127 characters) (E501)
[error] 98-98: flake8: line too long (178 > 127 characters) (E501)
[error] 100-100: flake8: line too long (170 > 127 characters) (E501)
[error] 140-140: flake8: line too long (128 > 127 characters) (E501)
[error] 216-216: flake8: line too long (135 > 127 characters) (E501)
[error] 241-241: flake8: line too long (177 > 127 characters) (E501)
[error] 248-248: flake8: line too long (128 > 127 characters) (E501)
[error] 478-478: flake8: line too long (150 > 127 characters) (E501)
🔇 Additional comments (5)
fla/layers/mom.py (5)
1-32
: LGTM! Clean imports and utility setup.The imports are well-organized and the conditional flash attention import with fallback message is good practice.
34-137
: LGTM! Standard flash attention utilities.The helper functions are well-documented and follow standard patterns for flash attention integration.
140-236
: LGTM! Well-implemented memory routing with capacity constraints.This is a complex but crucial function for the MoM architecture. The comprehensive docstring, efficient use of
torch.no_grad()
for indexing operations, and proper handling of topk routing and capacity constraints demonstrate good implementation quality.
241-303
: LGTM! Solid reconstruction logic with safety checks. The function properly reverses the memory routing process with clear documentation and helpful assertions for debugging; a small sketch of the underlying scatter-add pattern follows after this list.
446-609
: LGTM! Core attention mechanism is well-implemented.The forward method properly implements the MoM attention mechanism with correct routing, state management, and mode handling. The logic flow ensures all variables are defined before use.
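As promised above, a small sketch of the scatter-add pattern that reconstruct relies on (toy shapes; the permutation stands in for the routing indices built in transform):

import torch

n, h, d = 4, 2, 3                      # routed slots, heads, head dim
gathered_x = torch.randn(n, h, d)      # per-(token, expert) outputs
perm = torch.randperm(n)               # where each slot belongs in the flat layout
indices = perm.view(n, 1, 1).expand(n, h, d)
restored = torch.zeros(n, h, d).scatter_add_(0, indices, gathered_x)
assert torch.allclose(restored[perm], gathered_x)

scatter_add_ sums contributions landing on the same destination slot, which is the property that makes the top-k (k > 1) reconstruction correct when several memories write back to one token.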
Actionable comments posted: 1
🔭 Outside diff range comments (1)
fla/layers/mom.py (1)
1-719
: Add missing tests for MoM modules.

I couldn't find any MoM-related tests in `tests/test_model.py` or `tests/test_generation.py`. Per the PR objectives, please:

Create or update `tests/test_model.py` to include unit tests for:
- `transform()` and `reconstruct()` logic (e.g., simple input patterns, capacity truncation, padding behavior).
- The `MomAttention` constructor and `forward` pass (test both chunk and fused_recurrent modes, with and without `use_output_gate`/`use_short_conv`).

Create or update `tests/test_generation.py` to cover end-to-end behavior:
- Integrate `MomAttention` into a minimal model and verify output shapes and basic functionality (e.g., an autoregressive generation step).

Ensuring these tests exist will satisfy the PR's test-coverage requirements.
♻️ Duplicate comments (12)
fla/models/mom/modeling_mom.py (5)
114-114
: Fix incorrect module name in_no_split_modules
.The list references 'GatedDeltaNetBlock' but should reference 'MomBlock' for this model.
- _no_split_modules = ['GatedDeltaNetBlock'] + _no_split_modules = ['MomBlock']
193-195
: Fix incorrect model name in warning and add stacklevel.The warning refers to
GatedDeltaNetModel
but should refer toMomModel
. Also, addstacklevel
parameter to the warning.Apply this diff:
if output_attentions: - warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.") + warnings.warn("`MomModel` does not support `output_attentions` now, setting it to `False`.", stacklevel=2) output_attentions = False
312-320
: Improve exception handling with proper chaining.Use
raise ... from exception
to properly chain exceptions and provide better error context.Apply this diff:
raise AttributeError( f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, " f"which is not supported for {self.__class__.__name__}. " f"Try another generation strategy instead. " f"For the available generation strategies, check this doc: " f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies" - ) + ) from exception else: - raise exception + raise
418-425
: Missing l2wrap integration in auxiliary loss computation.According to the PR objectives, l2wrap functionality should be included in the new MoM models. The current implementation only includes load balancing loss.
#!/bin/bash # Description: Check if l2wrap is implemented in other models to understand the expected integration pattern # Search for l2wrap implementation patterns in other models rg -A 10 "l2wrap" --glob "fla/models/**/*.py" --glob "!fla/models/mom/**" # Search for any l2wrap loss functions ast-grep --pattern 'def $_($$$ l2wrap $$$) { $$$ }'
498-498
: Use 'dim' instead of 'axis' for torch.max.PyTorch uses
dim
parameter, notaxis
.Apply this diff:
- expert_mask = torch.max(expert_mask, axis=-2).values + expert_mask = torch.max(expert_mask, dim=-2).valuesfla/layers/mom.py (7)
211-211
: Add missing whitespace after comma.Apply this diff:
- memory_batch_offset = offset.transpose(0,1) + memory_batch_offset = offset.transpose(0, 1)
302-302
: Fix whitespace issues around comma.Apply this diff:
- resortd_x = torch.zeros((b * s * k, h, d) ,device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_( + resortd_x = torch.zeros((b * s * k, h, d), device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_(
341-341
: Fix incorrect return type annotation.The
__init__
method should returnNone
, notMomAttention
.Apply this diff:
- ) -> MomAttention: + ) -> None:
369-369
: Fix typo in assertion error message.Apply this diff:
- assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." + assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
379-383
: Fix excessive whitespace after operator.Apply this diff:
- self.k_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.key_dim, bias=False) for _ in range(self.num_memories)]) + self.k_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.key_dim, bias=False) for _ in range(self.num_memories)]) - self.b_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) for _ in range(self.num_memories)]) + self.b_proj = nn.ModuleList([nn.Linear(self.hidden_size, self.num_heads, bias=False) for _ in range(self.num_memories)])Also applies to: 387-390
508-615
: Fix critical bug: undefined variable 'o' when shared_mem is True.The variable
o
is used before it's defined whenself.shared_mem
is True. The code tries to addshared_o
too
on line 614, buto
is only defined within the mode-specific blocks.The logic needs restructuring. Initialize
o
before the conditional blocks or ensure it's defined in all code paths before the shared memory addition.
640-642
: Avoid mutable default arguments.Using mutable objects as default arguments can lead to unexpected behavior.
Apply this diff:
- conv_state_q=[None, None], - conv_state_k=[None, None], - conv_state_v=[None, None], + conv_state_q=None, + conv_state_k=None, + conv_state_v=None,Then initialize them in the method:
if conv_state_q is None: conv_state_q = [None, None] if conv_state_k is None: conv_state_k = [None, None] if conv_state_v is None: conv_state_v = [None, None]
🧹 Nitpick comments (1)
fla/layers/mom.py (1)
570-570
: Simplify kwargs.get usage.When the default value is
None
, you can omit it fromkwargs.get()
.Apply this diff:
- cu_seqlens = kwargs.get('cu_seqlens', None) + cu_seqlens = kwargs.get('cu_seqlens')Also applies to: 689-689
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
fla/layers/mom.py
(1 hunks)fla/models/mom/configuration_mom.py
(1 hunks)fla/models/mom/modeling_mom.py
(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- fla/models/mom/configuration_mom.py
🧰 Additional context used
🪛 Ruff (0.11.9)
fla/models/mom/modeling_mom.py
194-194: No explicit stacklevel
keyword argument found
Set stacklevel=2
(B028)
312-318: Within an except
clause, raise exceptions with raise ... from err
or raise ... from None
to distinguish them from errors in exception handling
(B904)
fla/layers/mom.py
570-570: Use kwargs.get('cu_seqlens')
instead of kwargs.get('cu_seqlens', None)
Replace kwargs.get('cu_seqlens', None)
with kwargs.get('cu_seqlens')
(SIM910)
640-640: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
641-641: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
642-642: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
689-689: Use kwargs.get('cu_seqlens')
instead of kwargs.get('cu_seqlens', None)
Replace kwargs.get('cu_seqlens', None)
with kwargs.get('cu_seqlens')
(SIM910)
🪛 GitHub Actions: lint
fla/layers/mom.py
[error] 45-45: flake8: line too long (131 > 127 characters) (E501)
[error] 70-70: flake8: line too long (128 > 127 characters) (E501)
[error] 72-72: flake8: line too long (135 > 127 characters) (E501)
[error] 97-97: flake8: line too long (178 > 127 characters) (E501)
[error] 99-99: flake8: line too long (170 > 127 characters) (E501)
[error] 221-221: flake8: line too long (135 > 127 characters) (E501)
[error] 259-259: flake8: line too long (128 > 127 characters) (E501)
🔇 Additional comments (1)
fla/models/mom/modeling_mom.py (1)
32-108
: LGTM! Well-structured MomBlock implementation.The block correctly handles both standard attention and MoM attention backends, with proper error handling for unsupported backends.
fla/layers/mom.py
Outdated
indices (`torch.Tensor`):
    The indices of non-masked tokens from the flattened input sequence.
cu_seqlens (`torch.Tensor`):
    The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
Fix long lines exceeding 127 characters.
Several lines exceed the maximum line length limit. Consider breaking them into multiple lines for better readability.
For example, break the docstring lines:
- The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+ The cumulative sequence lengths, used to index into ragged (unpadded) tensors.
+ `cu_seqlens` shape is (batch_size + 1,).
Also applies to: 70-70, 72-72, 97-97, 99-99
🧰 Tools
🪛 GitHub Actions: lint
[error] 45-45: flake8: line too long (131 > 127 characters) (E501)
🤖 Prompt for AI Agents
In fla/layers/mom.py at lines 45, 70, 72, 97, and 99, some lines exceed the
127-character limit. To fix this, break these long lines into multiple shorter
lines, especially in docstrings and comments, ensuring each line stays within
the character limit for improved readability and compliance with style
guidelines.
Actionable comments posted: 2
♻️ Duplicate comments (3)
fla/layers/mom.py (3)
38-40
: Fix line length exceeding 127 characters.
65-67
: Fix line length exceeding 127 characters.
640-642
: Avoid mutable default arguments.
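For readers skimming this recurring item, the hazard is the classic Python one: a default list is created once at function definition time and shared across all calls. A minimal sketch of the bug and the None-initialization fix the diffs above apply:

def append_item(item, bucket=[]):      # bug: one shared list for every call
    bucket.append(item)
    return bucket

append_item(1)   # [1]
append_item(2)   # [1, 2]  -- state leaks between unrelated calls

def append_item_fixed(item, bucket=None):
    if bucket is None:
        bucket = []                    # fresh list per call
    bucket.append(item)
    return bucket

The same reasoning applies to conv_state_q/conv_state_k/conv_state_v: defaulting to None and building [None, None] inside forward guarantees each call gets its own state containers.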
🧹 Nitpick comments (4)
fla/layers/mom.py (4)
25-26
: Consider a more descriptive function name.

The function name `elu_p1` is not very descriptive. Consider renaming it to something like `elu_plus_one` or `shifted_elu` to better convey its purpose.

-def elu_p1(x):
+def elu_plus_one(x):
     return (F.elu(x, 1., False) + 1.).to(x)
189-189: Break long line for better readability. This line exceeds the 127-character limit.
- indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(
-     b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1)
+ indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(
+     b*num_memories, -1) + torch.cat(
+     [torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0
+ ).unsqueeze(1)
227-227: Fix line length in docstring. Break this long docstring line for better readability.
- 3. Uses `torch.scatter_add_` to scatter and sum the transformed outputs back to their original positions based on `indices`.
+ 3. Uses `torch.scatter_add_` to scatter and sum the transformed outputs back to their original
+    positions based on `indices`.
499-499: Break long line for better readability. This line exceeds the 127-character limit.
- conv_mask = attention_mask[:, -hidden_states.shape[2]:].repeat_interleave(self.num_memories, 0) if attention_mask is not None else None
+ conv_mask = (attention_mask[:, -hidden_states.shape[2]:].repeat_interleave(self.num_memories, 0)
+              if attention_mask is not None else None)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
fla/layers/mom.py (1 hunks)
🪛 Ruff (0.12.2)
fla/layers/mom.py
93-93: Undefined name unpad_input
(F821)
640-640: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
641-641: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
642-642: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
690-690: Use kwargs.get('cu_seqlens')
instead of kwargs.get('cu_seqlens', None)
Replace kwargs.get('cu_seqlens', None)
with kwargs.get('cu_seqlens')
(SIM910)
🪛 Flake8 (7.2.0)
fla/layers/mom.py
[error] 93-93: undefined name 'unpad_input'
(F821)
🪛 GitHub Actions: lint
fla/layers/mom.py
[error] 38-38: flake8: line too long (128 > 127 characters) (E501)
[error] 40-40: flake8: line too long (135 > 127 characters) (E501)
[error] 65-65: flake8: line too long (178 > 127 characters) (E501)
[error] 67-67: flake8: line too long (170 > 127 characters) (E501)
[error] 93-93: flake8: undefined name 'unpad_input' (F821)
[error] 189-189: flake8: line too long (135 > 127 characters) (E501)
[error] 227-227: flake8: line too long (128 > 127 characters) (E501)
[error] 499-499: flake8: line too long (147 > 127 characters) (E501)
shared_mem: bool = False,
single_kv_proj: bool = False,
**kwargs
) -> MomAttention:
Return type annotation should be `None` for `__init__`.
The `__init__` method should return `None`, not `MomAttention`.
- ) -> MomAttention:
+ ) -> None:
🤖 Prompt for AI Agents
In fla/layers/mom.py at line 309, the __init__ method currently has a return
type annotation of MomAttention, which is incorrect. Change the return type
annotation of the __init__ method to None to properly indicate it does not
return a value.
fla/models/mom/modeling_mom.py
Outdated
def _init_weights(
    self,
    module: nn.Module,
    rescale_prenorm_residual: bool = True,
We should disable `rescale_prenorm_residual` by default.
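A minimal sketch of the suggested change, keeping the signature from the excerpt above; the body is elided here and the default value is the point of the suggestion, not a confirmed implementation:

from torch import nn

def _init_weights(
    self,
    module: nn.Module,
    rescale_prenorm_residual: bool = False,  # disabled by default, per this review comment
):
    # existing initialization logic stays unchanged; prenorm-residual rescaling
    # now only runs when a caller explicitly opts in
    ...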
Actionable comments posted: 0
♻️ Duplicate comments (7)
fla/models/mom/modeling_mom.py (4)
413-427: Missing l2wrap integration. According to the PR objectives, l2wrap functionality should be integrated into the MoM models. Currently, only the load balancing auxiliary loss is implemented.
Let me verify if l2wrap exists in the codebase:
#!/bin/bash
# Search for l2wrap implementation in the codebase
rg -A 5 "l2wrap" --type py
# Check if there's a l2wrap loss function
ast-grep --pattern 'def $_l2wrap$_($_) { $$$ }'
Based on the findings, you'll need to integrate l2wrap loss computation alongside the load balancing loss in the forward method.
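For illustration, a minimal sketch of how the two losses could be combined. The `l2_warp` helper name is an assumption based on the PR objectives, not confirmed API; the other parameter names are illustrative:

import torch.nn.functional as F

def combined_loss(logits, labels, aux_loss, aux_coef, use_l2warp, l2_warp=None):
    # cross-entropy over the vocabulary; label shifting is assumed done by the caller
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
    if use_l2warp and l2_warp is not None:
        loss = l2_warp(loss, logits)  # assumed helper that penalizes oversized logits
    if aux_loss is not None:
        loss = loss + aux_coef * aux_loss  # load-balancing term from the router
    return loss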
193-195: Add stacklevel to warning call. The warning should include `stacklevel=2` to correctly point to the caller in the stack trace. Apply this diff:
 if output_attentions:
-    warnings.warn("`MomModel` does not `output_attentions` now, setting it to `False`.")
+    warnings.warn("`MomModel` does not support `output_attentions` now, setting it to `False`.", stacklevel=2)
     output_attentions = False
307-321: Use proper exception chaining. Chain the original exception to provide better error context. Apply this diff:
 raise AttributeError(
     f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
     f"which is not supported for {self.__class__.__name__}. "
     f"Try another generation strategy instead. "
     f"For the available generation strategies, check this doc: "
     f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-)
+) from exception
 else:
-    raise exception
+    raise
498-498: Use 'dim' instead of 'axis' for torch.max. PyTorch uses the `dim` parameter, not `axis`. Apply this diff:
- expert_mask = torch.max(expert_mask, axis=-2).values
+ expert_mask = torch.max(expert_mask, dim=-2).values
fla/layers/mom.py (3)
28-105: Fix undefined function and long docstring lines. Multiple issues in this function:
- `unpad_input` is called but not defined or imported
- Several docstring lines exceed the 127-character limit
First, fix the undefined function by importing it:
-from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
+from fla.layers.utils import get_unpad_data, index_first_axis, pad_input, unpad_input
Also, break the long docstring lines:
- Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
+ Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to
+ different batches.
- This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
- tensors for query, key, value tensors.
+ This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation
+ of the same intermediary tensors for query, key, value tensors.
291-310: Fix return type annotation. The `__init__` method should return `None`, not `MomAttention`. Apply this diff:
- ) -> MomAttention:
+ ) -> None:
637-647: Avoid mutable default arguments. Using mutable objects as default arguments can lead to unexpected behavior. Apply this diff:
- conv_state_q=[None, None],
- conv_state_k=[None, None],
- conv_state_v=[None, None],
+ conv_state_q=None,
+ conv_state_k=None,
+ conv_state_v=None,
Then initialize them in the method:
if conv_state_q is None:
    conv_state_q = [None, None]
if conv_state_k is None:
    conv_state_k = [None, None]
if conv_state_v is None:
    conv_state_v = [None, None]
🧹 Nitpick comments (2)
fla/layers/mom.py (2)
693-693: Simplify kwargs.get() call. The `None` default is unnecessary since `get()` returns `None` by default. Apply this diff:
- cu_seqlens = kwargs.get('cu_seqlens', None)
+ cu_seqlens = kwargs.get('cu_seqlens')
1-723: Add tests for MoM modules. As requested in the PR comments, please add tests for the new MoM modules to:
tests/test_model.py
tests/test_generation.py
This will ensure the new functionality is properly validated.
Would you like me to help generate test cases for the MoM attention layer and model classes?
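As a starting point, one possible shape for such a test, assuming `MomConfig` and `MomForCausalLM` are exported as the rest of this PR suggests; the tiny config values are illustrative only:

import torch
from fla.models import MomForCausalLM
from fla.models.mom import MomConfig

def test_mom_forward_backward():
    config = MomConfig(hidden_size=64, num_hidden_layers=2)  # small config for CI
    model = MomForCausalLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 32))
    output = model(input_ids, labels=input_ids)
    output.loss.backward()  # both passes must run without error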
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
fla/layers/mom.py (1 hunks)
fla/models/mom/modeling_mom.py (1 hunks)
🪛 Ruff (0.12.2)
fla/layers/mom.py
94-94: Undefined name unpad_input
(F821)
643-643: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
644-644: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
645-645: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
693-693: Use kwargs.get('cu_seqlens')
instead of kwargs.get('cu_seqlens', None)
Replace kwargs.get('cu_seqlens', None)
with kwargs.get('cu_seqlens')
(SIM910)
fla/models/mom/modeling_mom.py
194-194: No explicit stacklevel
keyword argument found
Set stacklevel=2
(B028)
312-318: Within an except
clause, raise exceptions with raise ... from err
or raise ... from None
to distinguish them from errors in exception handling
(B904)
🪛 Flake8 (7.2.0)
fla/layers/mom.py
[error] 94-94: undefined name 'unpad_input'
(F821)
🪛 GitHub Actions: lint
fla/layers/mom.py
[error] 38-38: flake8: line too long (128 > 127 characters) (E501)
[error] 40-40: flake8: line too long (135 > 127 characters) (E501)
[error] 65-65: flake8: line too long (178 > 127 characters) (E501)
[error] 67-67: flake8: line too long (170 > 127 characters) (E501)
[error] 94-94: flake8: undefined name 'unpad_input' (F821)
[error] 190-190: flake8: line too long (135 > 127 characters) (E501)
[error] 228-228: flake8: line too long (128 > 127 characters) (E501)
[error] 500-500: flake8: line too long (147 > 127 characters) (E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: check-pt-python-compatibility
🔇 Additional comments (1)
fla/models/mom/modeling_mom.py (1)
32-108: MomBlock implementation looks good! The block structure with conditional attention backend selection and proper residual connections is well implemented.
Actionable comments posted: 1
🔭 Outside diff range comments (1)
tests/models/test_modeling_mom.py (1)
1-58: Add module docstring and extend generation tests with `use_l2warp`.
The helpers in tests/models/test_modeling_base.py confirm that `run_test_generation` accepts an optional `use_l2warp` flag (default `False`). To fully cover both code paths, update tests/models/test_modeling_mom.py as follows:
• Add a module-level docstring at the top.
• Parameterize `test_generation` over `use_l2warp` (e.g. `True`/`False`) and pass it through to `run_test_generation`.
Example patch:
--- a/tests/models/test_modeling_mom.py
+++ b/tests/models/test_modeling_mom.py
@@ -1,4 +1,12 @@
 # -*- coding: utf-8 -*-
+
+"""
+Test module for MoM (Mixture-of-Memories) model implementation.
+
+This module tests:
+ - forward/backward pass via run_test_model_forward_backward
+ - K/V cache generation via run_test_generation (with and without l2warp)
+"""
 import pytest
 import torch
@@ -40,16 +48,26 @@
 from .test_modeling_base import run_test_generation, run_test_model_forward_backward
 # ===================================================================================
 # Test for Generation
 # ===================================================================================
-@pytest.mark.parametrize(
-    ['L', 'B', 'T', 'H', 'D', 'dtype'],
+@pytest.mark.parametrize(
+    ['L', 'B', 'T', 'H', 'D', 'use_l2warp', 'dtype'],
     [
-        pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-{}".format(*test))
-        for test in [
-            (2, 4, 2000, 8, 64, torch.float16),
-        ]
+        pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-use_l2warp{}-{}".format(*test))
+        for test in [
+            # test both with and without l2warp
+            (2, 4, 2000, 8, 64, False, torch.float16),
+            (2, 4, 2000, 8, 64, True, torch.float16),
+        ]
     ]
 )
 def test_generation(
-    L: int,
-    B: int,
-    T: int,
-    H: int,
-    D: int,
-    dtype: torch.dtype,
+    L: int,
+    B: int,
+    T: int,
+    H: int,
+    D: int,
+    use_l2warp: bool,
+    dtype: torch.dtype,
 ):
-    run_test_generation(L, B, T, H, D, MomConfig, dtype)
+    run_test_generation(
+        L, B, T, H, D, MomConfig, dtype, use_l2warp=use_l2warp
+    )
- Confirms signature compatibility in tests/models/test_modeling_base.py.
- Improves coverage by testing both default and `l2warp=True` generation paths.
- Adds a clear module-level docstring.
🧹 Nitpick comments (1)
tests/models/test_modeling_mom.py (1)
14-34: Good test coverage for forward/backward pass validation. The parameterized test covers key variations including L2 warp functionality (addressing the PR reviewer's request). The test cases appropriately exercise different model configurations.
Consider adding test cases with different dtypes (e.g., float16, float32) to ensure broader compatibility:
 for test in [
     (4, 4, 1024, 4, 64, True, torch.bfloat16),
     (4, 4, 1024, 4, 64, False, torch.bfloat16),
     (4, 4, 1024, 4, 128, False, torch.bfloat16),
+    (4, 4, 1024, 4, 64, True, torch.float16),
+    (4, 4, 512, 4, 64, False, torch.float32),
 ]
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
fla/__init__.py (3 hunks)
tests/models/test_modeling_mom.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- fla/__init__.py
🔇 Additional comments (1)
tests/models/test_modeling_mom.py (1)
1-8: LGTM: Imports are well-structured. The imports follow standard conventions and correctly reference the new MoM components and test utilities.
@pytest.mark.parametrize(
    ['L', 'B', 'T', 'H', 'D', 'dtype'],
    [
        pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-{}".format(*test))
        for test in [
            (2, 4, 2000, 8, 64, torch.float16),
        ]
    ]
)
def test_generation(
    L: int,
    B: int,
    T: int,
    H: int,
    D: int,
    dtype: torch.dtype,
):
    run_test_generation(L, B, T, H, D, MomConfig, dtype)
🛠️ Refactor suggestion
Consider adding L2 warp parameter to generation test.
The generation test is well-structured but missing the `use_l2warp` parameter that was specifically requested by the PR reviewer. This parameter should be tested for generation functionality as well.
Add the `use_l2warp` parameter to ensure complete L2 warp functionality testing:
@pytest.mark.parametrize(
- ['L', 'B', 'T', 'H', 'D', 'dtype'],
+ ['L', 'B', 'T', 'H', 'D', 'use_l2warp', 'dtype'],
[
- pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-{}".format(*test))
+ pytest.param(*test, id="L{}-B{}-T{}-H{}-D{}-use_l2warp{}-{}".format(*test))
for test in [
- (2, 4, 2000, 8, 64, torch.float16),
+ (2, 4, 2000, 8, 64, True, torch.float16),
+ (2, 4, 2000, 8, 64, False, torch.float16),
]
]
)
def test_generation(
L: int,
B: int,
T: int,
H: int,
D: int,
+ use_l2warp: bool,
dtype: torch.dtype,
):
- run_test_generation(L, B, T, H, D, MomConfig, dtype)
+ run_test_generation(L, B, T, H, D, MomConfig, use_l2warp=use_l2warp, dtype=dtype)
🤖 Prompt for AI Agents
In tests/models/test_modeling_mom.py around lines 40 to 57, the test_generation
function lacks the use_l2warp parameter, which is needed to fully test L2 warp
functionality. Modify the pytest.mark.parametrize decorator to include
use_l2warp as a parameter with appropriate test values, and update the
test_generation function signature to accept use_l2warp. Pass this parameter to
the run_test_generation call to ensure the generation test covers the L2 warp
feature.
Actionable comments posted: 2
♻️ Duplicate comments (3)
fla/layers/mom.py (3)
90-90: Critical issue: undefined function `unpad_input` still not resolved. This issue was flagged in previous reviews but remains unaddressed. The function `unpad_input` is not imported and will cause a `NameError` at runtime. You need to either import `unpad_input` from fla.layers.utils if it exists, or implement an alternative solution. Check the available functions in the utils module:
#!/bin/bash
# Verify available functions in fla.layers.utils
ast-grep --pattern 'def $NAME($_) { $$$ }' fla/layers/utils.py
304-304: Fix return type annotation for `__init__`. The `__init__` method should return `None`, not `MomAttention`. This issue was flagged in previous reviews but remains unaddressed.
- ) -> MomAttention:
+ ) -> None:
654-656: Fix mutable default arguments. Using mutable objects as default arguments can lead to unexpected behavior. This issue was flagged in previous reviews but remains unaddressed.
- conv_state_q=[None, None],
- conv_state_k=[None, None],
- conv_state_v=[None, None],
+ conv_state_q=None,
+ conv_state_k=None,
+ conv_state_v=None,
Then initialize them in the method:
if conv_state_q is None:
    conv_state_q = [None, None]
if conv_state_k is None:
    conv_state_k = [None, None]
if conv_state_v is None:
    conv_state_v = [None, None]
🧹 Nitpick comments (4)
fla/layers/mom.py (4)
193-201: Clean up commented code. There's a significant amount of commented-out code that should be removed to improve readability and maintainability.
- # transformed_x = transformed_x * mask.unsqueeze(-1).expand_as(transformed_x)
- # pad_x = torch.zeros((b * num_memories, capacity_len-max_len, d), dtype=transformed_x.dtype, device=transformed_x.device)
- # pad_mask = torch.zeros((b * num_memories, capacity_len-max_len), dtype=transformed_x.dtype, device=transformed_x.device)
- # left pad
- # transformed_x = torch.cat((pad_x, transformed_x), dim=1).reshape((b, num_memories, capacity_len, d)).transpose(0, 1)
  mask_2 = mask.reshape((b, num_memories, max_len)).transpose(0, 1)
- # truncation_indices += capacity_len-max_len
- # if attention_mask is not None:
- #     mask_2
733-733: Fix whitespace before bracket. There's incorrect whitespace before the bracket causing a pipeline failure.
- max_len = (cu_seqlens[1:] - cu_seqlens [:-1]).max().item()
+ max_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
641-641: Fix indentation issue. The continuation line is under-indented for visual alignment, causing a pipeline failure. Ensure the continuation lines up with the opening parenthesis of the `reconstruct(` call (an indentation-only change):
- seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask)
+                 seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask)
441-441: Simplify kwargs.get usage. Static analysis suggests using `kwargs.get('cu_seqlens')` instead of `kwargs.get('cu_seqlens', None)`, since `None` is the default.
- origin_cu_seqlens = kwargs.get('cu_seqlens', None)
+ origin_cu_seqlens = kwargs.get('cu_seqlens')
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
fla/layers/mom.py (1 hunks)
fla/models/mom/__init__.py (1 hunks)
pyproject.toml (1 hunks)
✅ Files skipped from review due to trivial changes (2)
- pyproject.toml
- fla/models/mom/__init__.py
🧰 Additional context used
🪛 Ruff (0.12.2)
fla/layers/mom.py
441-441: Use kwargs.get('cu_seqlens')
instead of kwargs.get('cu_seqlens', None)
Replace kwargs.get('cu_seqlens', None)
with kwargs.get('cu_seqlens')
(SIM910)
654-654: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
655-655: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
656-656: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
671-671: Undefined name q_len
(F821)
🪛 Flake8 (7.2.0)
fla/layers/mom.py
[error] 641-641: continuation line under-indented for visual indent
(E128)
[error] 671-671: undefined name 'q_len'
(F821)
🪛 GitHub Actions: lint
fla/layers/mom.py
[error] 33-33: flake8: line too long (128 > 127 characters) (E501)
[error] 35-35: flake8: line too long (135 > 127 characters) (E501)
[error] 60-60: flake8: line too long (178 > 127 characters) (E501)
[error] 62-62: flake8: line too long (170 > 127 characters) (E501)
[error] 187-187: flake8: line too long (135 > 127 characters) (E501)
[error] 222-222: flake8: line too long (128 > 127 characters) (E501)
[error] 508-508: flake8: line too long (133 > 127 characters) (E501)
[error] 516-516: flake8: line too long (133 > 127 characters) (E501)
[error] 524-524: flake8: line too long (133 > 127 characters) (E501)
[error] 552-552: flake8: line too long (145 > 127 characters) (E501)
[error] 586-586: flake8: line too long (130 > 127 characters) (E501)
[error] 599-599: flake8: line too long (145 > 127 characters) (E501)
[error] 641-641: flake8: continuation line under-indented for visual indent (E128)
[error] 671-671: flake8: undefined name 'q_len' (F821)
[error] 733-733: flake8: whitespace before '[' (E211)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: check-pt-python-compatibility
🔇 Additional comments (2)
fla/layers/mom.py (2)
206-277: Well-implemented reconstruction logic. The function correctly handles the complex inverse transformation from memory-organized chunks back to the original sequence format, with proper routing weight application and validation assertions.
746-777: Well-implemented state management helpers. The `prepare_recurrent_state` and `handle_recurrent_state` methods correctly handle the complex logic for managing recurrent states across different memory configurations and batch scenarios.
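To make the reconstruction step above concrete, a self-contained sketch of the scatter-add pattern on toy shapes; the names are illustrative, not the module's exact signatures:

import torch

b, s, k, d = 2, 4, 2, 8                      # batch, seq len, top-k memories, hidden dim
gathered = torch.randn(b * s * k, d)         # one output row per (token, chosen memory)
weights = torch.rand(b * s * k, 1)           # routing weight for each row
slots = torch.arange(b * s).repeat_interleave(k)  # each row's destination token slot

# scatter-add the weighted rows: the k memory outputs for a token sum into one vector
out = torch.zeros(b * s, d).scatter_add_(
    0, slots.unsqueeze(-1).expand(-1, d), gathered * weights
)
out = out.view(b, s, d)                      # back to (batch, seq, hidden)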
Actionable comments posted: 1
♻️ Duplicate comments (3)
fla/layers/mom.py (3)
304-304: Fix incorrect return type annotation. The `__init__` method should return `None`, not `MomAttention`.
- ) -> MomAttention:
+ ) -> None:
667-669: Avoid mutable default arguments. Using mutable objects as default arguments can lead to unexpected behavior.
Apply this diff:
- conv_state_q=[None, None],
- conv_state_k=[None, None],
- conv_state_v=[None, None],
+ conv_state_q=None,
+ conv_state_k=None,
+ conv_state_v=None,
Then initialize them in the method:
if conv_state_q is None:
    conv_state_q = [None, None]
if conv_state_k is None:
    conv_state_k = [None, None]
if conv_state_v is None:
    conv_state_v = [None, None]
686-686: Fix undefined variable `q_len`. The variable `q_len` is used but not defined, which will cause a `NameError` at runtime. Looking at the context, it appears `q_len` should be `hidden_states.shape[1]`:
- indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
+ indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -hidden_states.shape[1]:])
🧹 Nitpick comments (8)
fla/layers/mom.py (8)
60-60: Fix line length violations. These lines exceed the 127-character limit. Break them into multiple lines for better readability.
- (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
-     The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+ (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
+     The cumulative sequence lengths for the target (query) and source (key, value),
+     used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
-     Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
+     Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
+     `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
Also applies to: 62-62
187-189: Fix line length violation and improve readability. Line 187 exceeds the character limit.
- indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(
-     b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1)
+ indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(
+     b*num_memories, -1
+ ) + torch.cat(
+     [torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0
+ ).unsqueeze(1)
222-222: Fix line length violation in docstring (the change only re-wraps trailing text; the visible content is unchanged).
- 2. Applies the `mask` to zero out invalid positions.
+ 2. Applies the `mask` to zero out invalid positions.
441-441: Simplify kwargs.get usage.
- origin_cu_seqlens = kwargs.get('cu_seqlens', None)
+ origin_cu_seqlens = kwargs.get('cu_seqlens')
558-558: Fix line length violations in recurrent state handling.
- recurrent_state[0] = self.handle_recurrent_state(recurrent_state[0], recurrent_state_, cu_seqlens, cu_seqlen_all[0], reverse_indices)
+ recurrent_state[0] = self.handle_recurrent_state(
+     recurrent_state[0], recurrent_state_, cu_seqlens, cu_seqlen_all[0], reverse_indices
+ )
Apply similar formatting to line 605.
Also applies to: 605-605
625-625: Fix indentation issue. Line 625 has incorrect indentation for line continuation.
  o = reconstruct(o, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size,
-     seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask)
+                 seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask)
644-644: Fix line length violation.
- shared_hidden_states = index_first_axis(rearrange(shared_hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)
+ shared_hidden_states = index_first_axis(
+     rearrange(shared_hidden_states, "b s ... -> (b s) ..."), indices
+ ).unsqueeze(0)
750-750: Fix whitespace before bracket.
- max_len = (cu_seqlens[1:] - cu_seqlens [:-1]).max().item()
+ max_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
fla/layers/mom.py (1 hunks)
fla/models/mom/configuration_mom.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- fla/models/mom/configuration_mom.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
fla/layers/mom.py (8)
fla/modules/fused_norm_gate.py (1)
  FusedRMSNormGated (993-1054)
fla/modules/layernorm.py (1)
  RMSNorm (1060-1107)
fla/modules/convolution.py (1)
  ShortConvolution (661-902)
fla/ops/gated_delta_rule/chunk.py (1)
  chunk_gated_delta_rule (221-340)
fla/ops/gated_delta_rule/fused_recurrent.py (1)
  fused_recurrent_gated_delta_rule (240-351)
fla/models/utils.py (2)
  Cache (11-150)
  update (43-122)
fla/layers/utils.py (3)
  get_unpad_data (75-98)
  pad_input (176-197)
  unpad_input (101-173)
fla/models/mom/modeling_mom.py (3)
  forward (77-107)
  forward (181-264)
  forward (356-440)
🪛 Ruff (0.12.2)
fla/layers/mom.py
441-441: Use kwargs.get('cu_seqlens')
instead of kwargs.get('cu_seqlens', None)
Replace kwargs.get('cu_seqlens', None)
with kwargs.get('cu_seqlens')
(SIM910)
667-667: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
668-668: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
669-669: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
🪛 Flake8 (7.2.0)
fla/layers/mom.py
[error] 625-625: continuation line under-indented for visual indent
(E128)
🪛 GitHub Actions: lint
fla/layers/mom.py
[error] 33-33: flake8: line too long (128 > 127 characters) (E501)
[error] 35-35: flake8: line too long (135 > 127 characters) (E501)
[error] 60-60: flake8: line too long (178 > 127 characters) (E501)
[error] 62-62: flake8: line too long (170 > 127 characters) (E501)
[error] 187-187: flake8: line too long (135 > 127 characters) (E501)
[error] 222-222: flake8: line too long (128 > 127 characters) (E501)
[error] 514-514: flake8: line too long (133 > 127 characters) (E501)
[error] 522-522: flake8: line too long (133 > 127 characters) (E501)
[error] 530-530: flake8: line too long (133 > 127 characters) (E501)
[error] 558-558: flake8: line too long (145 > 127 characters) (E501)
[error] 592-592: flake8: line too long (130 > 127 characters) (E501)
[error] 605-605: flake8: line too long (145 > 127 characters) (E501)
[error] 625-625: flake8: continuation line under-indented for visual indent (E128)
[error] 644-644: flake8: line too long (134 > 127 characters) (E501)
[error] 750-750: flake8: whitespace before '[' (E211)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: check-pt-python-compatibility
fla/layers/mom.py
Outdated
output_final_state=use_cache,
cu_seqlens=conv_cu_seqlens,
)
conv_state_q[0] = self.handle_recurrent_state(conv_state_q[0], conv_q_new, cu_seqlens, cu_seqlen_all[0], reverse_indices)
🛠️ Refactor suggestion
Fix line length violations in convolution state handling.
Multiple lines exceed the 127-character limit.
- conv_state_q[0] = self.handle_recurrent_state(conv_state_q[0], conv_q_new, cu_seqlens, cu_seqlen_all[0], reverse_indices)
+ conv_state_q[0] = self.handle_recurrent_state(
+ conv_state_q[0], conv_q_new, cu_seqlens, cu_seqlen_all[0], reverse_indices
+ )
conv_k = self.prepare_recurrent_state(conv_state_k[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size)
cu_k, conv_k_new = self.k_conv1d(
x=cu_k,
cache=conv_k,
output_final_state=use_cache,
cu_seqlens=conv_cu_seqlens,
)
- conv_state_k[0] = self.handle_recurrent_state(conv_state_k[0], conv_k_new, cu_seqlens, cu_seqlen_all[0], reverse_indices)
+ conv_state_k[0] = self.handle_recurrent_state(
+ conv_state_k[0], conv_k_new, cu_seqlens, cu_seqlen_all[0], reverse_indices
+ )
conv_v = self.prepare_recurrent_state(conv_state_v[0], cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size)
cu_v, conv_v_new = self.v_conv1d(
x=cu_v,
cache=conv_v,
output_final_state=use_cache,
cu_seqlens=conv_cu_seqlens,
)
- conv_state_v[0] = self.handle_recurrent_state(conv_state_v[0], conv_v_new, cu_seqlens, cu_seqlen_all[0], reverse_indices)
+ conv_state_v[0] = self.handle_recurrent_state(
+ conv_state_v[0], conv_v_new, cu_seqlens, cu_seqlen_all[0], reverse_indices
+ )
Also applies to: 522-522, 530-530
🧰 Tools
🪛 GitHub Actions: lint
[error] 514-514: flake8: line too long (133 > 127 characters) (E501)
🤖 Prompt for AI Agents
In fla/layers/mom.py at lines 514, 522, and 530, the lines exceed the
127-character limit. Refactor these lines by breaking them into multiple shorter
lines or using intermediate variables to ensure no line surpasses the character
limit while preserving the original logic and readability.
Actionable comments posted: 6
♻️ Duplicate comments (10)
fla/models/mom/modeling_mom.py (7)
120-125: Consider explicitly setting the `rescale_prenorm_residual` default. Based on the past review comment, `rescale_prenorm_residual` should be disabled by default. Consider explicitly setting this in the initialization logic or documenting why the current approach is preferred.
160-161: Add missing blank line before class definition. PEP 8 requires two blank lines before class definitions.
 @dataclass
 class MomOutputWithPast(BaseModelOutputWithPast):
     aux_loss: Optional[torch.FloatTensor] = None
+
 class MomModel(MomPreTrainedModel):
272-274: Add missing blank line before class definition. PEP 8 requires two blank lines before class definitions.
 @dataclass
 class MomCausalLMOutputWithPast(CausalLMOutputWithPast):
     aux_loss: Optional[torch.FloatTensor] = None
+
 class MomForCausalLM(MomPreTrainedModel, GenerationMixin):
312-322: Improve exception handling with proper chaining. Use exception chaining to preserve the error context.
 if 'past_key_values' in str(exception):
     raise AttributeError(
         f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
         f"which is not supported for {self.__class__.__name__}. "
         f"Try another generation strategy instead. "
         f"For the available generation strategies, check this doc: "
         f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-    )
+    ) from exception
 else:
-    raise exception
+    raise
415-429: Missing l2wrap integration. According to the PR objectives, l2wrap functionality should be included in the new MoM models. The current implementation only includes the load balancing auxiliary loss.
Would you like me to help implement the l2wrap loss integration alongside the existing auxiliary loss calculation?
499-499: Use 'dim' instead of 'axis' for torch.max. PyTorch uses the `dim` parameter, not `axis`. Apply this diff:
- expert_mask = torch.max(expert_mask, axis=-2).values
+ expert_mask = torch.max(expert_mask, dim=-2).values
194-196: Add stacklevel to warning for proper trace. The warning should include `stacklevel=2` to correctly point to the caller.
 if output_attentions:
-    warnings.warn("`MomModel` does not `output_attentions` now, setting it to `False`.")
+    warnings.warn("`MomModel` does not support `output_attentions` now, setting it to `False`.", stacklevel=2)
     output_attentions = False
fla/layers/mom.py (3)
305-305: Fix return type annotation for `__init__`. The `__init__` method should return `None`, not `MomAttention`.
- ) -> MomAttention:
+ ) -> None:
671-673: Avoid mutable default arguments. Using mutable objects as default arguments can lead to unexpected behavior.
- conv_state_q=[None, None],
- conv_state_k=[None, None],
- conv_state_v=[None, None],
+ conv_state_q=None,
+ conv_state_k=None,
+ conv_state_v=None,
Then initialize them in the method:
if conv_state_q is None:
    conv_state_q = [None, None]
if conv_state_k is None:
    conv_state_k = [None, None]
if conv_state_v is None:
    conv_state_v = [None, None]
509-532: Fix line length violations in convolution handling. Multiple lines exceed the 127-character limit.
- conv_q = self.prepare_recurrent_state(conv_state_q[0], conv_cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size)
+ conv_q = self.prepare_recurrent_state(
+     conv_state_q[0], conv_cu_seqlens, cu_seqlen_all[0], reverse_indices, batch_size
+ )
  cu_q, conv_q_new = self.q_conv1d(
      x=cu_q,
      cache=conv_q,
      output_final_state=use_cache,
      cu_seqlens=conv_cu_seqlens,
  )
- conv_state_q[0] = self.handle_recurrent_state(conv_state_q[0], conv_q_new, conv_cu_seqlens, cu_seqlen_all[0], reverse_indices)
+ conv_state_q[0] = self.handle_recurrent_state(
+     conv_state_q[0], conv_q_new, conv_cu_seqlens, cu_seqlen_all[0], reverse_indices
+ )
Apply similar formatting to the k and v convolution blocks.
🧹 Nitpick comments (1)
fla/layers/mom.py (1)
1-823: Missing test coverage. Per the PR objectives, tests should be added to tests/test_model.py and tests/test_generation.py for the new MoM modules and model. Please ensure comprehensive test coverage is included.
Would you like me to help generate the test cases for the MomAttention layer and MomForCausalLM model?
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
fla/layers/mom.py (1 hunks)
fla/models/mom/modeling_mom.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
fla/layers/mom.py
442-442: Use kwargs.get('cu_seqlens')
instead of kwargs.get('cu_seqlens', None)
Replace kwargs.get('cu_seqlens', None)
with kwargs.get('cu_seqlens')
(SIM910)
671-671: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
672-672: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
673-673: Do not use mutable data structures for argument defaults
Replace with None
; initialize within function
(B006)
fla/models/mom/modeling_mom.py
195-195: No explicit stacklevel
keyword argument found
Set stacklevel=2
(B028)
314-320: Within an except
clause, raise exceptions with raise ... from err
or raise ... from None
to distinguish them from errors in exception handling
(B904)
🪛 Flake8 (7.2.0)
fla/layers/mom.py
[error] 627-627: continuation line under-indented for visual indent
(E128)
🪛 GitHub Actions: lint
fla/layers/mom.py
[error] 33-33: flake8: line too long (128 > 127 characters) (E501)
[error] 35-35: flake8: line too long (135 > 127 characters) (E501)
[error] 60-60: flake8: line too long (178 > 127 characters) (E501)
[error] 62-62: flake8: line too long (170 > 127 characters) (E501)
[error] 187-187: flake8: line too long (135 > 127 characters) (E501)
[error] 222-222: flake8: line too long (128 > 127 characters) (E501)
[error] 509-509: flake8: line too long (130 > 127 characters) (E501)
[error] 516-516: flake8: line too long (138 > 127 characters) (E501)
[error] 517-517: flake8: line too long (130 > 127 characters) (E501)
[error] 524-524: flake8: line too long (138 > 127 characters) (E501)
[error] 525-525: flake8: line too long (130 > 127 characters) (E501)
[error] 532-532: flake8: line too long (138 > 127 characters) (E501)
[error] 560-560: flake8: line too long (145 > 127 characters) (E501)
[error] 594-594: flake8: line too long (130 > 127 characters) (E501)
[error] 607-607: flake8: line too long (145 > 127 characters) (E501)
[error] 627-627: flake8: continuation line under-indented for visual indent (E128)
[error] 646-646: flake8: line too long (134 > 127 characters) (E501)
[error] 754-754: flake8: whitespace before '[' (E211)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: check-pt-python-compatibility
🔇 Additional comments (1)
fla/models/mom/modeling_mom.py (1)
32-109: Well-structured transformer block implementation! The MomBlock class properly implements the transformer architecture with conditional attention mechanism selection and appropriate error handling for unsupported backends.
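For readers skimming the thread, the structure being praised is the standard pre-norm pattern with a switchable token mixer; a schematic sketch with illustrative layer names, not the exact MomBlock code:

import torch.nn as nn

class PreNormBlock(nn.Module):
    # schematic: mixer/mlp stand in for MomAttention (or standard attention) and MomMLP
    def __init__(self, mixer: nn.Module, mlp: nn.Module, dim: int):
        super().__init__()
        self.attn_norm, self.mlp_norm = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.attn, self.mlp = mixer, mlp

    def forward(self, x):
        x = x + self.attn(self.attn_norm(x))  # first residual connection
        x = x + self.mlp(self.mlp_norm(x))    # second residual connection
        return x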
Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.

This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
Fix line length violations in docstrings.
Several lines exceed the 127-character limit. Break them into multiple lines.
- Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
+ Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong
+ to different batches.
- This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
+ This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation
+ of the same intermediary
- The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+ The cumulative sequence lengths for the target (query) and source (key, value), used to index into
+ ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
- Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
+ Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
+ `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
Also applies to: 60-62
🧰 Tools
🪛 GitHub Actions: lint
[error] 33-33: flake8: line too long (128 > 127 characters) (E501)
[error] 35-35: flake8: line too long (135 > 127 characters) (E501)
🤖 Prompt for AI Agents
In fla/layers/mom.py at lines 33-35 and 60-62, the docstring lines exceed the
127-character limit. Break these long lines into multiple shorter lines to
comply with the line length restriction, ensuring the text remains clear and
readable without altering the content.
flatten_offset = batch_memory_tokens.flatten().cumsum(dim=0)
max_len = batch_memory_tokens.max()
indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(
    b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1)
Fix line length violation.
Line exceeds 127-character limit.
- indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(
- b*num_memories, -1) + torch.cat([torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0).unsqueeze(1)
+ indices = torch.arange(max_len, device=flatten_offset.device).unsqueeze(0).expand(
+ b*num_memories, -1
+ ) + torch.cat(
+ [torch.tensor([0], device=flatten_offset.device), flatten_offset[:-1]], dim=0
+ ).unsqueeze(1)
🧰 Tools
🪛 GitHub Actions: lint
[error] 187-187: flake8: line too long (135 > 127 characters) (E501)
🤖 Prompt for AI Agents
In fla/layers/mom.py at line 187, the line exceeds the 127-character limit.
Break the line into multiple shorter lines by splitting the expression at
logical points, such as after the addition operator or before the torch.cat
call, and use proper indentation to maintain readability and comply with the
line length restriction.
resortd_x = torch.zeros((b * s * k, d), device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_(
    0,
    indices.reshape(-1).unsqueeze(-1).expand(-1, d),
    gathered_x,
)
Fix whitespace and indentation issues.
Fix whitespace around comma and continuation line indentation.
- resortd_x = torch.zeros((b * s * k, d), device=gathered_x.device, dtype=gathered_x.dtype).scatter_add_(
- 0,
- indices.reshape(-1).unsqueeze(-1).expand(-1, d),
- gathered_x,
- )
+ resortd_x = torch.zeros(
+ (b * s * k, d), device=gathered_x.device, dtype=gathered_x.dtype
+ ).scatter_add_(
+ 0,
+ indices.reshape(-1).unsqueeze(-1).expand(-1, d),
+ gathered_x,
+ )
🤖 Prompt for AI Agents
In fla/layers/mom.py around lines 265 to 269, fix the whitespace by removing any
extra spaces around commas and adjust the indentation of the continuation lines
to align properly under the opening parenthesis for better readability and
consistency.
o = reconstruct(o, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size,
                seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask)
o = rearrange(o, 'b l (h d) -> b l h d', h=self.num_heads)
Fix continuation line indentation.
The line continuation is under-indented.
o = reconstruct(o, indices=indices, sorted_indices=sorted_indices, batch_size=batch_size,
- seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask)
+ seq_len=seq_len, topk=self.topk, routing_weights=routing_weights, mask=mask)
🧰 Tools
🪛 Flake8 (7.2.0)
[error] 627-627: continuation line under-indented for visual indent
(E128)
🪛 GitHub Actions: lint
[error] 627-627: flake8: continuation line under-indented for visual indent (E128)
🤖 Prompt for AI Agents
In fla/layers/mom.py around lines 626 to 628, the continuation line of the
reconstruct function call is under-indented. Adjust the indentation of the
second line so it aligns properly under the opening parenthesis of the first
line, maintaining consistent indentation style for line continuations.
|
def cu2pad(self, x, cu_seqlens):
    batch_size = cu_seqlens.shape[0] - 1
    max_len = (cu_seqlens[1:] - cu_seqlens [:-1]).max().item()
Remove whitespace before bracket.
There's an extra space before the bracket.
- max_len = (cu_seqlens[1:] - cu_seqlens [:-1]).max().item()
+ max_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
🧰 Tools
🪛 GitHub Actions: lint
[error] 754-754: flake8: whitespace before '[' (E211)
🤖 Prompt for AI Agents
In fla/layers/mom.py at line 754, remove the extra whitespace before the bracket
in the expression cu_seqlens [:-1]. Change cu_seqlens [:-1] to cu_seqlens[:-1]
to fix the spacing issue.
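For context, a generic sketch of what a cu2pad-style helper does, re-batching a cu_seqlens-flattened tensor into a right-padded one. This illustrates the layout only and is not the repo's exact implementation:

import torch

def cu2pad(x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
    # x: (total_tokens, d); cu_seqlens: (batch_size + 1,) cumulative sequence boundaries
    batch_size = cu_seqlens.shape[0] - 1
    lengths = cu_seqlens[1:] - cu_seqlens[:-1]
    out = x.new_zeros(batch_size, int(lengths.max()), *x.shape[1:])
    for i in range(batch_size):
        out[i, : lengths[i]] = x[cu_seqlens[i]:cu_seqlens[i + 1]]
    return out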
expert_mask = torch.nn.functional.one_hot(selected_experts, self.num_memories)

# For a given token, determine if it was routed to a given expert.
expert_mask = torch.max(expert_mask, axis=-2).values
Use 'dim' instead of 'axis' for torch.max.
PyTorch uses dim
parameter, not axis
.
- expert_mask = torch.max(expert_mask, axis=-2).values
+ expert_mask = torch.max(expert_mask, dim=-2).values
🤖 Prompt for AI Agents
In fla/layers/mom.py at line 811, replace the use of 'axis' with 'dim' in the
torch.max function call. Change torch.max(expert_mask, axis=-2).values to
torch.max(expert_mask, dim=-2).values to comply with PyTorch's API.
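For context, the one_hot/max pattern above is the Mixtral-style load-balancing computation; a minimal standalone sketch of the simplified, unmasked variant for illustration only:

import torch.nn.functional as F

def load_balancing_loss(router_logits, num_memories, top_k):
    # router_logits: (num_tokens, num_memories)
    probs = router_logits.softmax(dim=-1)
    _, selected = probs.topk(top_k, dim=-1)           # (num_tokens, top_k)
    expert_mask = F.one_hot(selected, num_memories)   # (num_tokens, top_k, num_memories)
    # a token counts for an expert if any of its top-k slots picked that expert
    tokens_per_expert = expert_mask.max(dim=-2).values.float().mean(dim=0)
    router_prob_per_expert = probs.mean(dim=0)
    return num_memories * (tokens_per_expert * router_prob_per_expert).sum()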
Changes Made
New Modules
- fla/layers/mom.py
- fla/layers/mom_varlen.py (for varlen)
Model Architecture
- models/mom_gated_deltanet (Gated DeltaNet with MoM integration)
Summary by CodeRabbit