
[WIP] Add Log-Linear Attention #524

Open · wants to merge 1 commit into base: main
Conversation

@2022tgoel 2022tgoel commented Jul 12, 2025

Summary by CodeRabbit

  • New Features

    • Introduced a highly optimized chunked log-linear attention mechanism with support for recurrent state management, hierarchical masking, and variable-length sequences.
    • Added a naive log-linear attention implementation for reference and testing.
    • Made new attention operators and state objects available for import.
  • Tests

    • Added comprehensive tests comparing the optimized chunked attention operator to the naive implementation for correctness.

coderabbitai bot commented Jul 12, 2025

Walkthrough

This update introduces a chunked log-linear attention mechanism implemented with custom Triton kernels and a naive PyTorch reference. The new operator and its state class are exposed via package __init__.py files. A comprehensive test suite validates the Triton implementation against the naive version using randomized input data and parameterized configurations.

Changes

File(s) and change summary:

  • fla/ops/__init__.py: Added imports and exports for chunk_log_linear_attn and LogLinearAttentionState.
  • fla/ops/log_linear_attn/__init__.py: New file; imports and exports chunk_log_linear_attn and LogLinearAttentionState.
  • fla/ops/log_linear_attn/chunk.py: New file; implements chunked log-linear attention with Triton kernels, utility functions, and state.
  • fla/ops/log_linear_attn/naive.py: New file; implements naive log-linear attention and supporting functions in PyTorch.
  • tests/ops/test_log_linear_attn.py: New file; parameterized test comparing Triton and naive attention outputs for correctness.

Sequence Diagram(s)

sequenceDiagram
    participant Test as test_log_linear_attn.py
    participant Triton as chunk_log_linear_attn (Triton)
    participant Naive as naive_log_linear_attn (PyTorch)
    Test->>Triton: Call chunk_log_linear_attn(q, k, v, g, l, ...)
    Triton->>Triton: Prepare states, run Triton kernels
    Triton-->>Test: Return output
    Test->>Naive: Call naive_log_linear_attn(q, k, v, g, l)
    Naive-->>Test: Return output
    Test->>Test: Compare outputs for correctness
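
For orientation, the sketch below mirrors the comparison in the diagram; the shapes, gate initialization, device, and tolerances are illustrative assumptions rather than values taken from the PR's test file.

```python
import math

import torch

from fla.ops.log_linear_attn import chunk_log_linear_attn
from fla.ops.log_linear_attn.naive import naive_log_linear_attn

B, T, H, D = 2, 64, 4, 64                  # batch, sequence length (power of two), heads, head dim
L = int(math.log2(T)) + 1                  # number of hierarchical levels for length T

q = torch.randn(B, T, H, D, device="cuda")
k = torch.randn(B, T, H, D, device="cuda")
v = torch.randn(B, T, H, D, device="cuda")
g = torch.nn.functional.logsigmoid(torch.randn(B, T, H, device="cuda"))  # log-space forget gates
level_scales = torch.rand(B, T, H, L, device="cuda")                     # per-level scales (the "l" argument)

o_chunk, _ = chunk_log_linear_attn(q, k, v, g, level_scales)  # Triton path, returns (output, final_state)
o_naive = naive_log_linear_attn(q, k, v, g, level_scales)     # PyTorch reference
torch.testing.assert_close(o_chunk, o_naive, rtol=1e-3, atol=1e-3)
```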

Poem

In fields of memory, kernels churn,
With Triton speed, the queries learn.
Naive and chunked, side by side,
Their outputs checked, no bugs to hide.
A hop, a test, a bunny cheer—
Log-linear dreams are crystal clear!
🐇✨

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 13

🔭 Outside diff range comments (1)
fla/ops/__init__.py (1)

1-53: Fix import order to comply with isort.

The pipeline indicates that isort would modify this file. Please run isort to fix the import ordering.

🧹 Nitpick comments (7)
fla/ops/log_linear_attn/naive.py (1)

37-44: Add docstring and consider using math.log2 instead of numpy.

The function needs documentation. Also, importing numpy just for np.log2 seems unnecessary when math.log2 is available.

+import math
+
 def construct_H_matrix(a, L):
+    """Construct the H matrix for log-linear attention.
+    
+    Args:
+        a: Attention weights of shape (..., T)
+        L: Hierarchical level values of shape (..., num_levels, T)
+    
+    Returns:
+        H matrix of shape (..., T, T)
+    """
     T = a.size(-1)
     A = torch.exp(segsum(a))
     H = torch.zeros_like(A)
-    for level in range(int(np.log2(T)) + 1):
+    for level in range(int(math.log2(T)) + 1):
         mask = construct_level_mask(level, L)
         H += A * mask
     return H
tests/ops/test_log_linear_attn.py (2)

1-9: Fix import order and add module docstring.

Consider adding a module-level docstring to describe what this test file covers.

+"""Tests for log-linear attention implementation."""
 import os
 
 import pytest
 import torch
 
 from fla.ops.log_linear_attn import chunk_log_linear_attn
 from fla.ops.log_linear_attn.naive import naive_log_linear_attn
 from fla.utils import assert_close, device, device_platform

11-25: Consider adding more test coverage.

The current test only covers three configurations. Consider adding:

  • Edge cases (small sequences, single head, etc.)
  • Tests with initial state
  • Tests with variable-length sequences (cu_seqlens)
  • Different data types (bfloat16, float16)

Would you like me to generate additional test cases to improve coverage?
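
One way the extra coverage could look is sketched below; the parameter grid, seed, dtypes, and tolerances are assumptions for illustration, not code taken from the existing test file.

```python
import math

import pytest
import torch

from fla.ops.log_linear_attn import chunk_log_linear_attn
from fla.ops.log_linear_attn.naive import naive_log_linear_attn


@pytest.mark.parametrize("B", [1, 2])
@pytest.mark.parametrize("T", [64, 256])
@pytest.mark.parametrize("H", [1, 4])
@pytest.mark.parametrize("D", [32, 64])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_chunk_matches_naive(B, T, H, D, dtype):
    torch.manual_seed(42)
    L = int(math.log2(T)) + 1  # hierarchical levels needed for length T
    q = torch.randn(B, T, H, D, device="cuda", dtype=dtype)
    k = torch.randn(B, T, H, D, device="cuda", dtype=dtype)
    v = torch.randn(B, T, H, D, device="cuda", dtype=dtype)
    g = torch.nn.functional.logsigmoid(torch.randn(B, T, H, device="cuda", dtype=dtype))
    level_scales = torch.rand(B, T, H, L, device="cuda", dtype=dtype)

    ref = naive_log_linear_attn(q, k, v, g, level_scales)
    out, _ = chunk_log_linear_attn(q, k, v, g, level_scales)
    tol = 1e-3 if dtype == torch.float32 else 2e-2  # looser tolerance for low precision
    torch.testing.assert_close(out, ref, rtol=tol, atol=tol)
```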

fla/ops/log_linear_attn/chunk.py (4)

84-111: Consider refactoring repetitive level initialization code.

The hardcoded creation of 12 KV states (lines 86-110) is repetitive. While this might be necessary for Triton's compilation model, it reduces maintainability.

Consider documenting why this repetitive pattern is necessary (e.g., Triton compilation constraints) and the maximum supported sequence length (2^17 based on the error check).


280-431: Simplify nested if statements for better readability.

The static analyzer correctly identifies that many nested if statements could be combined. This would improve readability without affecting functionality.

Example for lines 287-288:

-        if MIN_LEVEL <= 0 and MAX_LEVEL >= 0:
-            if chunk_index & 1:
+        if MIN_LEVEL <= 0 and MAX_LEVEL >= 0 and chunk_index & 1:
                 p_l = tl.make_block_ptr(

Apply similar simplification to all the nested if blocks flagged by SIM102.


948-957: Consider adding field descriptions to dataclass.

The dataclass fields would benefit from documentation.

 @dataclass
 class LogLinearAttentionState:
+    """State container for log-linear attention across chunks.
+    
+    Attributes:
+        ht: Hidden states for each hierarchical level
+        offsets: Token offsets for each sequence
+        q_prev, k_prev, v_prev: Previous chunk's queries, keys, values
+        g_prev: Previous chunk's gating factors
+        l_prev: Previous chunk's level scales
+    """
     ht: torch.Tensor
     offsets: torch.Tensor
     q_prev: torch.Tensor
     k_prev: torch.Tensor
     v_prev: torch.Tensor
     g_prev: torch.Tensor
-    l_prev: torch.Tensor
+    level_scales_prev: torch.Tensor

1004-1006: Consider making the sequence length limit configurable.

The hardcoded limit of 2^17 for sequence length should be documented or made configurable.

     if MAX_LEVEL > 10:
-        raise ValueError("Sequence length must be less than 2**17")
+        raise ValueError(
+            f"Sequence length must be less than 2^17 (131,072). "
+            f"Current effective length would require {MAX_LEVEL + 1} levels."
+        )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4889a8b and 01827c3.

📒 Files selected for processing (5)
  • fla/ops/__init__.py (2 hunks)
  • fla/ops/log_linear_attn/__init__.py (1 hunks)
  • fla/ops/log_linear_attn/chunk.py (1 hunks)
  • fla/ops/log_linear_attn/naive.py (1 hunks)
  • tests/ops/test_log_linear_attn.py (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: toothacher17
PR: fla-org/flash-linear-attention#395
File: fla/layers/mla.py:0-0
Timestamp: 2025-05-13T06:04:24.342Z
Learning: In Flash Attention's `flash_attn_varlen_func`, the scaling factor (typically computed as head_dim ** -0.5) is handled internally by the Flash Attention implementation rather than needing to be applied explicitly to query projections.
fla/ops/log_linear_attn/naive.py (1)
Learnt from: toothacher17
PR: fla-org/flash-linear-attention#395
File: fla/layers/mla.py:0-0
Timestamp: 2025-05-13T06:04:24.342Z
Learning: In Flash Attention's `flash_attn_varlen_func`, the scaling factor (typically computed as head_dim ** -0.5) is handled internally by the Flash Attention implementation rather than needing to be applied explicitly to query projections.
🧬 Code Graph Analysis (3)
fla/ops/__init__.py (1)
fla/ops/log_linear_attn/chunk.py (2)
  • chunk_log_linear_attn (1152-1208)
  • LogLinearAttentionState (949-956)
fla/ops/log_linear_attn/__init__.py (1)
fla/ops/log_linear_attn/chunk.py (2)
  • chunk_log_linear_attn (1152-1208)
  • LogLinearAttentionState (949-956)
fla/ops/log_linear_attn/chunk.py (3)
fla/ops/utils/cumsum.py (1)
  • chunk_local_cumsum (428-465)
fla/ops/utils/op.py (1)
  • safe_exp (30-31)
fla/utils.py (1)
  • input_guard (131-162)
🪛 GitHub Actions: lint
fla/ops/__init__.py

[error] 1-1: Pre-commit hook 'isort' modified this file to fix import sorting.

fla/ops/log_linear_attn/__init__.py

[error] 1-1: Pre-commit hook 'end-of-file-fixer' modified this file to fix end of file issues.


[error] 1-1: Pre-commit hook 'isort' modified this file to fix import sorting.

tests/ops/test_log_linear_attn.py

[error] 1-1: Pre-commit hook 'end-of-file-fixer' modified this file to fix end of file issues.


[error] 44-44: flake8: ambiguous variable name 'l' (E741)

fla/ops/log_linear_attn/naive.py

[error] 1-1: Pre-commit hook 'isort' modified this file to fix import sorting.


[error] 47-47: flake8: ambiguous variable name 'l' (E741)

fla/ops/log_linear_attn/chunk.py

[error] 1-1: Pre-commit hook 'isort' modified this file to fix import sorting.


[error] 1056-1056: flake8: ambiguous variable name 'l' (E741)

🪛 Ruff (0.11.9)
tests/ops/test_log_linear_attn.py

44-44: Ambiguous variable name: l

(E741)

fla/ops/log_linear_attn/naive.py

47-47: Ambiguous variable name: l

(E741)

fla/ops/log_linear_attn/chunk.py

38-38: Ambiguous variable name: l

(E741)


287-288: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


299-300: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


311-312: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


323-324: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


335-336: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


347-348: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


359-360: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


371-372: Use a single if statement instead of nested if statements

(SIM102)


383-384: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


395-396: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


407-408: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


419-420: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


492-493: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


496-497: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


500-501: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


504-505: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


508-509: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


512-513: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


516-517: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


520-521: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


524-525: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


528-529: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


532-533: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


675-675: Ambiguous variable name: l

(E741)


826-826: Ambiguous variable name: l

(E741)


969-969: Ambiguous variable name: l

(E741)


1056-1056: Ambiguous variable name: l

(E741)


1157-1157: Ambiguous variable name: l

(E741)


1191-1192: Use a single if statement instead of nested if statements

(SIM102)

🪛 Flake8 (7.2.0)
tests/ops/test_log_linear_attn.py

[error] 44-44: ambiguous variable name 'l'

(E741)

fla/ops/log_linear_attn/naive.py

[error] 47-47: ambiguous variable name 'l'

(E741)

fla/ops/log_linear_attn/chunk.py

[error] 1056-1056: ambiguous variable name 'l'

(E741)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-models

Comment on lines +1 to +3
import torch
import numpy as np

⚠️ Potential issue

Fix import order to comply with isort.

The pipeline indicates that isort would modify this file. Please run isort to fix the import ordering.

-import torch
 import numpy as np
+import torch
🧰 Tools
🪛 GitHub Actions: lint

[error] 1-1: Pre-commit hook 'isort' modified this file to fix import sorting.

🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/naive.py at lines 1 to 3, the import statements are
not ordered according to isort standards. Run isort on this file to
automatically reorder the imports so that standard library imports come first,
followed by third-party imports like torch and numpy, ensuring compliance with
the project's import style guidelines.

Comment on lines +47 to +50
def naive_log_linear_attn(q, k, v, g, l):
    H = construct_H_matrix(g.permute(0, 2, 1), l.permute(0, 2, 3, 1))
    M = torch.einsum("bhlc,blhn,bchn->bhlc", H, q, k)
    return torch.einsum("bhlc,bchp->blhp", M, v)

⚠️ Potential issue

Rename ambiguous variable 'l' and add documentation.

The variable name 'l' is ambiguous and flagged by linters. Consider using a more descriptive name like 'levels' or 'level_scales'.

-def naive_log_linear_attn(q, k, v, g, l):
+def naive_log_linear_attn(q, k, v, g, level_scales):
+    """Compute naive log-linear attention.
+    
+    Args:
+        q: Queries of shape (B, T, H, D)
+        k: Keys of shape (B, T, H, D)
+        v: Values of shape (B, T, H, D)
+        g: Gate/forget factors of shape (B, H, T)
+        level_scales: Hierarchical level scales of shape (B, H, L, T)
+    
+    Returns:
+        Attention output of shape (B, T, H, D)
+    """
-    H = construct_H_matrix(g.permute(0, 2, 1), l.permute(0, 2, 3, 1))
+    H = construct_H_matrix(g.permute(0, 2, 1), level_scales.permute(0, 2, 3, 1))
     M = torch.einsum("bhlc,blhn,bchn->bhlc", H, q, k)
     return torch.einsum("bhlc,bchp->blhp", M, v)
🧰 Tools
🪛 Ruff (0.11.9)

47-47: Ambiguous variable name: l

(E741)

🪛 Flake8 (7.2.0)

[error] 47-47: ambiguous variable name 'l'

(E741)

🪛 GitHub Actions: lint

[error] 47-47: flake8: ambiguous variable name 'l' (E741)

🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/naive.py around lines 47 to 50, rename the ambiguous
variable 'l' to a more descriptive name such as 'levels' or 'level_scales'
throughout the naive_log_linear_attn function, including in the function
signature and all its usages. Additionally, add a brief docstring to the
function explaining its purpose and the role of each parameter to improve code
clarity and maintainability.

Comment on lines +5 to +11
def segsum(x):
    T = x.size(-1)
    x_cumsum = torch.cumsum(x, dim=-1)
    x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool))
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum

🛠️ Refactor suggestion

Add docstring and consider memory efficiency for large sequences.

The function lacks documentation. Additionally, creating a full T×T matrix could be memory-intensive for large sequences.

 def segsum(x):
+    """Compute segment sums for the input tensor.
+    
+    Args:
+        x: Input tensor of shape (..., T)
+    
+    Returns:
+        Segment sum matrix of shape (..., T, T) where element [i, j] contains
+        the sum of x[j+1:i+1] for j <= i, and -inf otherwise.
+    """
     T = x.size(-1)
     x_cumsum = torch.cumsum(x, dim=-1)
     x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
     mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool))
     x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
     return x_segsum
🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/naive.py around lines 5 to 11, add a clear docstring
to the segsum function explaining its purpose, inputs, and outputs. To improve
memory efficiency for large sequences, consider refactoring the code to avoid
creating the full T×T matrix explicitly, possibly by using more memory-efficient
operations or chunking strategies.
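
A tiny hand-worked check of that convention (an illustration of the intended semantics, not output captured from the repository):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])
c = torch.cumsum(x, dim=-1)            # [1., 3., 6.]
s = c[:, None] - c[None, :]            # s[i, j] = sum(x[j+1 : i+1])
s = s.masked_fill(~torch.tril(torch.ones(3, 3, dtype=torch.bool)), -torch.inf)
# s == [[0., -inf, -inf],
#       [2.,   0., -inf],
#       [5.,   3.,   0.]]
```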

Comment on lines +14 to +35
def construct_level_mask(level, L):
    T = L.size(-1)
    if level == 0:
        return torch.diag_embed(L[..., level, :])

    indices = torch.cartesian_prod(torch.arange(T), torch.arange(T)).to(L.device)

    mask = torch.where(
        torch.logical_and(
            torch.logical_and(
                indices[:, 0] % (1 << level) >= (1 << (level - 1)),
                indices[:, 1] + (1 << (level - 1))
                >= indices[:, 0] - (indices[:, 0] % (1 << (level - 1))),
            ),
            indices[:, 1] < indices[:, 0] - (indices[:, 0] % (1 << (level - 1))),
        ).view(T, T),
        L[..., level, :].unsqueeze(-1).expand(*([-1] * (len(L.shape) - 2)), T, T),
        0,
    )

    return mask

🛠️ Refactor suggestion

Add documentation and consider optimizing memory usage.

This function implements complex hierarchical masking logic that needs documentation. Also, torch.cartesian_prod creates all T² index pairs upfront, which could be memory-intensive.

 def construct_level_mask(level, L):
+    """Construct a hierarchical level mask for log-linear attention.
+    
+    Args:
+        level: Hierarchical level (0 for diagonal, higher for broader patterns)
+        L: Level values tensor of shape (..., num_levels, T)
+    
+    Returns:
+        Mask tensor of shape (..., T, T) with L values placed according to 
+        the hierarchical pattern for the given level.
+    """
     T = L.size(-1)
     if level == 0:
         return torch.diag_embed(L[..., level, :])
 
     indices = torch.cartesian_prod(torch.arange(T), torch.arange(T)).to(L.device)

Would you like me to suggest a more memory-efficient implementation that avoids creating the full cartesian product?
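
One possible broadcast-based variant is sketched below; it evaluates the same predicate with broadcast row/column indices instead of materializing all T² index pairs. It has not been validated against the PR's kernels and should be read as a sketch of the idea.

```python
import torch


def construct_level_mask_broadcast(level, L):
    # Same semantics as construct_level_mask above, but without torch.cartesian_prod.
    T = L.size(-1)
    if level == 0:
        return torch.diag_embed(L[..., level, :])
    i = torch.arange(T, device=L.device).view(T, 1)   # query (row) index
    j = torch.arange(T, device=L.device).view(1, T)   # key (column) index
    block_start = i - (i % (1 << (level - 1)))        # start of the row's half-block
    cond = (
        (i % (1 << level) >= (1 << (level - 1)))
        & (j + (1 << (level - 1)) >= block_start)
        & (j < block_start)
    )                                                 # (T, T) boolean, same predicate as above
    return torch.where(
        cond,
        L[..., level, :].unsqueeze(-1).expand(*([-1] * (L.dim() - 2)), T, T),
        torch.zeros((), device=L.device, dtype=L.dtype),
    )
```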

🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/naive.py lines 14 to 35, add clear docstrings
explaining the purpose, inputs, outputs, and the hierarchical masking logic of
the construct_level_mask function. To optimize memory usage, replace
torch.cartesian_prod with an approach that computes the mask without generating
all T² index pairs at once, such as using broadcasting or iterative methods to
reduce memory overhead.

Comment on lines +1 to +6
from .chunk import chunk_log_linear_attn, LogLinearAttentionState

__all__ = [
    'chunk_log_linear_attn',
    'LogLinearAttentionState'
]

⚠️ Potential issue

Fix formatting issues: import order and end-of-file.

The pipeline indicates two issues:

  1. isort would modify the import ordering
  2. Missing newline at end of file

Please run the pre-commit hooks to fix these issues.

 from .chunk import chunk_log_linear_attn, LogLinearAttentionState
 
 __all__ = [
     'chunk_log_linear_attn',
     'LogLinearAttentionState'
-]
+]
+
🧰 Tools
🪛 GitHub Actions: lint

[error] 1-1: Pre-commit hook 'end-of-file-fixer' modified this file to fix end of file issues.


[error] 1-1: Pre-commit hook 'isort' modified this file to fix import sorting.

🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/__init__.py at lines 1 to 6, fix the import order
according to isort standards by running the isort pre-commit hook, and add a
newline at the end of the file to comply with formatting requirements. Run the
full set of pre-commit hooks to automatically correct these formatting issues.

Comment on lines +664 to +695
@triton.heuristics(
    {
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def copy_input_kernel(
    q,
    k,
    v,
    g,
    l,
    cu_seqlens,
    q_prev,
    k_prev,
    v_prev,
    g_prev,
    l_prev,
    offsets,
    q_new,
    k_new,
    v_new,
    g_new,
    l_new,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    L: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):

⚠️ Potential issue

Rename ambiguous variable 'l' in copy_input_kernel.

The parameter name 'l' is ambiguous and should be renamed for consistency with the suggested changes in the main kernel.

 def copy_input_kernel(
     q,
     k,
     v,
     g,
-    l,
+    level_scales,
     cu_seqlens,
     q_prev,
     k_prev,
     v_prev,
     g_prev,
-    l_prev,
+    level_scales_prev,
     offsets,
     q_new,
     k_new,
     v_new,
     g_new,
-    l_new,
+    level_scales_new,
     T,
     H: tl.constexpr,
     K: tl.constexpr,
     V: tl.constexpr,
     L: tl.constexpr,
     BT: tl.constexpr,
     IS_VARLEN: tl.constexpr,
 ):
+    """Copy input tensors while incorporating previous state for continuous processing."""
🧰 Tools
🪛 Ruff (0.11.9)

675-675: Ambiguous variable name: l

(E741)

🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/chunk.py between lines 664 and 695, rename the
ambiguous parameter 'l' in the copy_input_kernel function to a more descriptive
and consistent name that aligns with the main kernel's naming conventions.
Update all occurrences of this parameter within the function signature and body
accordingly to maintain clarity and consistency.

Comment on lines +910 to +946
def construct_binary_level_mask(level, T):
    if level == 0:
        return torch.diag(torch.ones(T, dtype=torch.bool))

    indices = torch.cartesian_prod(torch.arange(T), torch.arange(T))

    mask = torch.where(
        torch.logical_and(
            torch.logical_and(
                indices[:, 0] % (1 << level) >= (1 << (level - 1)),
                indices[:, 1] + (1 << (level - 1))
                >= indices[:, 0] - (indices[:, 0] % (1 << (level - 1))),
            ),
            indices[:, 1] < indices[:, 0] - (indices[:, 0] % (1 << (level - 1))),
        ).view(T, T),
        1,
        0,
    )

    return mask


def level_lut(BT, device):
    lut = torch.zeros((BT, BT), dtype=torch.int32, device=device)
    for level in range(1, ceil_log(BT, 2) + 1):
        mask = construct_binary_level_mask(level, BT).to(device)
        lut = torch.where(mask.to(torch.bool), level, lut)
    return lut


def ceil_div(x: int, y: int) -> int:
    return math.ceil(x / y)


def ceil_log(x: int, b: int) -> int:
    return math.ceil(math.log(x, b))

🛠️ Refactor suggestion

Add documentation to utility functions.

These utility functions lack documentation explaining their purpose in the hierarchical attention scheme.

 def construct_binary_level_mask(level, T):
+    """Construct a binary mask for a specific hierarchical level.
+    
+    Args:
+        level: Hierarchical level (0 for diagonal, higher for broader patterns)
+        T: Sequence length
+    
+    Returns:
+        Binary mask tensor of shape (T, T)
+    """
     if level == 0:
         return torch.diag(torch.ones(T, dtype=torch.bool))
 def level_lut(BT, device):
+    """Generate a lookup table mapping positions to hierarchical levels.
+    
+    Args:
+        BT: Block size
+        device: Target device
+    
+    Returns:
+        Lookup table of shape (BT, BT) with level indices
+    """
     lut = torch.zeros((BT, BT), dtype=torch.int32, device=device)
🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/chunk.py around lines 910 to 946, the utility
functions construct_binary_level_mask, level_lut, ceil_div, and ceil_log lack
documentation. Add concise docstrings to each function explaining their purpose
and role in the hierarchical attention scheme, including input parameters and
return values to improve code readability and maintainability.

Comment on lines +815 to +841
@triton.heuristics(
    {
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def copy_last_chunk_kernel(
    q,
    k,
    v,
    g,
    l,
    cu_seqlens,
    q_prev,
    k_prev,
    v_prev,
    g_prev,
    l_prev,
    offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    L: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):

⚠️ Potential issue

Rename ambiguous variable 'l' in copy_last_chunk_kernel.

Maintain consistency with the naming convention changes.

 def copy_last_chunk_kernel(
     q,
     k,
     v,
     g,
-    l,
+    level_scales,
     cu_seqlens,
     q_prev,
     k_prev,
     v_prev,
     g_prev,
-    l_prev,
+    level_scales_prev,
     offsets,
     T,
     H: tl.constexpr,
     K: tl.constexpr,
     V: tl.constexpr,
     L: tl.constexpr,
     BT: tl.constexpr,
     IS_VARLEN: tl.constexpr,
 ):
+    """Copy the last chunk of sequences to preserve state for next iteration."""
🧰 Tools
🪛 Ruff (0.11.9)

826-826: Ambiguous variable name: l

(E741)

🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/chunk.py around lines 815 to 841, the variable named
'l' in the copy_last_chunk_kernel function is ambiguous and should be renamed to
a more descriptive name consistent with the project's naming conventions.
Identify a clear and meaningful name that reflects the variable's purpose and
replace all instances of 'l' with this new name within the function signature
and body.

Comment on lines +1151 to +1190
@torch.compiler.disable
def chunk_log_linear_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    l: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        g (torch.Tensor):
            Forget gates of shape `[B, T, H]`.
        l (torch.Tensor):
            Scales for each level of shape `[B, T, H, L]`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of type `LogLinearAttentionState` if `output_final_state=True` else `None`.

    """

⚠️ Potential issue

Fix ambiguous variable name and document gradient limitation.

The parameter 'l' should be renamed, and the docstring should mention that gradients are not supported.

 def chunk_log_linear_attn(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     g: torch.Tensor,
-    l: torch.Tensor,
+    level_scales: torch.Tensor,
     initial_state: Optional[torch.Tensor] = None,
     output_final_state: bool = False,
     cu_seqlens: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Args:
         q (torch.Tensor):
             queries of shape `[B, T, H, K]`.
         k (torch.Tensor):
             keys of shape `[B, T, H, K]`.
         v (torch.Tensor):
             values of shape `[B, T, H, V]`.
         g (torch.Tensor):
             Forget gates of shape `[B, T, H]`.
-        l (torch.Tensor):
+        level_scales (torch.Tensor):
             Scales for each level of shape `[B, T, H, L]`.
         initial_state (Optional[torch.Tensor]):
             Initial state of shape `[N, H, K, V]` for `N` input sequences.
             For equal-length input sequences, `N` equals the batch size `B`.
             Default: `None`.
         output_final_state (Optional[bool]):
             Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
         cu_seqlens (torch.LongTensor):
             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
             consistent with the FlashAttention API.
 
     Returns:
         o (torch.Tensor):
             Outputs of shape `[B, T, H, V]`.
         final_state (torch.Tensor):
             Final state of type `LogLinearAttentionState` if `output_final_state=True` else `None`.
+    
+    Note:
+        This implementation does not support gradient computation (backward pass).
+        Use with torch.no_grad() or for inference only.
 
     """
🧰 Tools
🪛 Ruff (0.11.9)

1157-1157: Ambiguous variable name: l

(E741)

🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/chunk.py around lines 1151 to 1190, rename the
parameter 'l' to a more descriptive name that reflects its purpose, such as
'scales' or similar, to avoid ambiguity. Additionally, update the docstring to
explicitly state that gradients are not supported for this function or
parameter, clarifying the limitation for users.

Comment on lines +1142 to +1149
    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht):
        raise NotImplementedError(
            "Backward pass is not implemented for log-linear attention."
        )

💡 Verification agent

🧩 Analysis chain

Document that backward pass is not implemented.

The backward pass raises NotImplementedError. This should be clearly documented in the main API function's docstring.


🏁 Script executed:

#!/bin/bash
# Check if this is mentioned in any documentation or if there are any gradient-related tests
rg -A 5 "backward|gradient" --glob "*.md" --glob "*.rst"
rg "requires_grad|backward|gradient" tests/ops/test_log_linear_attn.py

Length of output: 9896


🏁 Script executed:

#!/bin/bash
# Locate user-facing log-linear attention API and where the autograd Function is applied
rg -n "^def .*log_linear.*attention" -n fla/ops/log_linear_attn
rg -n "apply(" -n fla/ops/log_linear_attn

Length of output: 172


🏁 Script executed:

#!/bin/bash
# Identify the autograd Function class and inspect its forward docstring
rg -n "^class .*Attention" -n fla/ops/log_linear_attn/chunk.py
rg -n -B 10 -A 5 "def forward" fla/ops/log_linear_attn/chunk.py

Length of output: 597


🏁 Script executed:

#!/bin/bash
# List files in log_linear_attn directory
ls -1 fla/ops/log_linear_attn

# Find where the autograd function is referenced or applied
rg -n "ChunkLogLinearAttentionFunction" -n fla/ops/log_linear_attn
rg -n "apply\\(" -n fla/ops/log_linear_attn

# List all top‐level function definitions in this folder
rg -n "^def " -n fla/ops/log_linear_attn

Length of output: 1323


🏁 Script executed:

#!/bin/bash
# Show the docstring of the user-facing chunk_log_linear_attn function
rg -n -B 5 -A 20 "def chunk_log_linear_attn" fla/ops/log_linear_attn/chunk.py

Length of output: 956


Add missing backward‐pass note to the chunk_log_linear_attn docstring

The autograd Function’s backward raises NotImplementedError, but the user‐facing chunk_log_linear_attn docstring doesn’t mention this limitation. Please update it to avoid confusion.

• File: fla/ops/log_linear_attn/chunk.py, function chunk_log_linear_attn (around line 1162)
• Action: Insert a “Note” stating that gradients aren’t supported and calling backward will throw.

Suggested diff:

@@ -1162,7 +1162,11 @@ def chunk_log_linear_attn(
-    r"""
+    r"""
+    Note:
+        The backward pass is not implemented and will raise NotImplementedError.
+
+    (Supports forward inference only.)
     Args:
         q (torch.Tensor):
             queries of shape `[B, T, H, K]`.
🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/chunk.py around line 1162, update the docstring of
the function chunk_log_linear_attn to include a note that the backward pass is
not implemented and gradients are not supported, warning users that calling
backward will raise a NotImplementedError. Add this clarification to prevent
confusion about the function's autograd limitations.

@yzhangcs
Member

@2022tgoel Nice work!
