[WIP] Add Log-Linear Attention #524
base: main
Conversation
Walkthrough
This update introduces a chunked log-linear attention mechanism implemented with custom Triton kernels, along with a naive PyTorch reference implementation. The new operator and its state class are exposed via the package exports.
Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test as test_log_linear_attn.py
    participant Triton as chunk_log_linear_attn (Triton)
    participant Naive as naive_log_linear_attn (PyTorch)
    Test->>Triton: Call chunk_log_linear_attn(q, k, v, g, l, ...)
    Triton->>Triton: Prepare states, run Triton kernels
    Triton-->>Test: Return output
    Test->>Naive: Call naive_log_linear_attn(q, k, v, g, l)
    Naive-->>Test: Return output
    Test->>Test: Compare outputs for correctness
```
Actionable comments posted: 13
🔭 Outside diff range comments (1)
fla/ops/__init__.py (1)
1-53: Fix import order to comply with isort.
The pipeline indicates that isort would modify this file. Please run isort to fix the import ordering.
🧹 Nitpick comments (7)
fla/ops/log_linear_attn/naive.py (1)
37-44: Add a docstring and consider using math.log2 instead of numpy.
The function needs documentation. Also, importing numpy just for `np.log2` seems unnecessary when `math.log2` is available.

```diff
+import math
+
 def construct_H_matrix(a, L):
+    """Construct the H matrix for log-linear attention.
+
+    Args:
+        a: Attention weights of shape (..., T)
+        L: Hierarchical level values of shape (..., num_levels, T)
+
+    Returns:
+        H matrix of shape (..., T, T)
+    """
     T = a.size(-1)
     A = torch.exp(segsum(a))
     H = torch.zeros_like(A)
-    for level in range(int(np.log2(T)) + 1):
+    for level in range(int(math.log2(T)) + 1):
         mask = construct_level_mask(level, L)
         H += A * mask
     return H
```

tests/ops/test_log_linear_attn.py (2)
1-9: Fix import order and add a module docstring.
Consider adding a module-level docstring to describe what this test file covers.

```diff
+"""Tests for log-linear attention implementation."""
 import os
 import pytest
 import torch
 from fla.ops.log_linear_attn import chunk_log_linear_attn
 from fla.ops.log_linear_attn.naive import naive_log_linear_attn
 from fla.utils import assert_close, device, device_platform
```
11-25: Consider adding more test coverage.
The current test only covers three configurations. Consider adding:
- Edge cases (small sequences, single head, etc.)
- Tests with initial state
- Tests with variable-length sequences (cu_seqlens)
- Different data types (bfloat16, float16)
Would you like me to generate additional test cases to improve coverage?
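A sketch of what such additional cases might look like (the `assert_close` signature, the parametrization style, and the level-count convention are assumptions based on this test file and the naive reference, not verified against them):

```python
import math

import pytest
import torch

from fla.ops.log_linear_attn import chunk_log_linear_attn
from fla.ops.log_linear_attn.naive import naive_log_linear_attn
from fla.utils import assert_close, device


@pytest.mark.parametrize("B,T,H,D", [(1, 64, 1, 32), (2, 256, 2, 64)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_chunk_matches_naive(B, T, H, D, dtype):
    torch.manual_seed(42)
    num_levels = int(math.log2(T)) + 1  # assumed: one scale per hierarchical level
    q = torch.randn(B, T, H, D, dtype=dtype, device=device)
    k = torch.randn(B, T, H, D, dtype=dtype, device=device)
    v = torch.randn(B, T, H, D, dtype=dtype, device=device)
    g = torch.nn.functional.logsigmoid(torch.randn(B, T, H, dtype=dtype, device=device))
    levels = torch.rand(B, T, H, num_levels, dtype=dtype, device=device)

    # Compare the Triton kernel against the fp32 naive reference.
    o, _ = chunk_log_linear_attn(q, k, v, g, levels)
    ref = naive_log_linear_attn(q.float(), k.float(), v.float(), g.float(), levels.float())
    assert_close("o", ref, o.float(), ratio=0.005)
```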
fla/ops/log_linear_attn/chunk.py (4)
84-111: Consider refactoring the repetitive level-initialization code.
The hardcoded creation of 12 KV states (lines 86-110) is repetitive. While this might be necessary for Triton's compilation model, it reduces maintainability.
Consider documenting why this repetitive pattern is necessary (e.g., Triton compilation constraints) and the maximum supported sequence length (2^17 based on the error check).
280-431: Simplify nested if statements for better readability.
The static analyzer correctly identifies that many nested if statements could be combined. This would improve readability without affecting functionality.
Example for lines 287-288:
```diff
-    if MIN_LEVEL <= 0 and MAX_LEVEL >= 0:
-        if chunk_index & 1:
+    if MIN_LEVEL <= 0 and MAX_LEVEL >= 0 and chunk_index & 1:
         p_l = tl.make_block_ptr(
```

Apply a similar simplification to all the nested if blocks flagged by SIM102.
948-957: Consider adding field descriptions to the dataclass.
The dataclass fields would benefit from documentation.

```diff
 @dataclass
 class LogLinearAttentionState:
+    """State container for log-linear attention across chunks.
+
+    Attributes:
+        ht: Hidden states for each hierarchical level
+        offsets: Token offsets for each sequence
+        q_prev, k_prev, v_prev: Previous chunk's queries, keys, values
+        g_prev: Previous chunk's gating factors
+        l_prev: Previous chunk's level scales
+    """
     ht: torch.Tensor
     offsets: torch.Tensor
     q_prev: torch.Tensor
     k_prev: torch.Tensor
     v_prev: torch.Tensor
     g_prev: torch.Tensor
-    l_prev: torch.Tensor
+    level_scales_prev: torch.Tensor
```
1004-1006: Consider making the sequence length limit configurable.
The hardcoded limit of 2^17 for the sequence length should be documented or made configurable.

```diff
 if MAX_LEVEL > 10:
-    raise ValueError("Sequence length must be less than 2**17")
+    raise ValueError(
+        f"Sequence length must be less than 2^17 (131,072). "
+        f"Current effective length would require {MAX_LEVEL + 1} levels."
+    )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- fla/ops/__init__.py (2 hunks)
- fla/ops/log_linear_attn/__init__.py (1 hunks)
- fla/ops/log_linear_attn/chunk.py (1 hunks)
- fla/ops/log_linear_attn/naive.py (1 hunks)
- tests/ops/test_log_linear_attn.py (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: toothacher17
PR: fla-org/flash-linear-attention#395
File: fla/layers/mla.py:0-0
Timestamp: 2025-05-13T06:04:24.342Z
Learning: In Flash Attention's `flash_attn_varlen_func`, the scaling factor (typically computed as head_dim ** -0.5) is handled internally by the Flash Attention implementation rather than needing to be applied explicitly to query projections.
fla/ops/log_linear_attn/naive.py (1)
Learnt from: toothacher17
PR: fla-org/flash-linear-attention#395
File: fla/layers/mla.py:0-0
Timestamp: 2025-05-13T06:04:24.342Z
Learning: In Flash Attention's `flash_attn_varlen_func`, the scaling factor (typically computed as head_dim ** -0.5) is handled internally by the Flash Attention implementation rather than needing to be applied explicitly to query projections.
🧬 Code Graph Analysis (3)
fla/ops/__init__.py (1)
fla/ops/log_linear_attn/chunk.py (2)
- chunk_log_linear_attn (1152-1208)
- LogLinearAttentionState (949-956)
fla/ops/log_linear_attn/__init__.py (1)
fla/ops/log_linear_attn/chunk.py (2)
- chunk_log_linear_attn (1152-1208)
- LogLinearAttentionState (949-956)
fla/ops/log_linear_attn/chunk.py (3)
fla/ops/utils/cumsum.py (1)
- chunk_local_cumsum (428-465)
fla/ops/utils/op.py (1)
- safe_exp (30-31)
fla/utils.py (1)
- input_guard (131-162)
🪛 GitHub Actions: lint
fla/ops/__init__.py
[error] 1-1: Pre-commit hook 'isort' modified this file to fix import sorting.
fla/ops/log_linear_attn/__init__.py
[error] 1-1: Pre-commit hook 'end-of-file-fixer' modified this file to fix end of file issues.
[error] 1-1: Pre-commit hook 'isort' modified this file to fix import sorting.
tests/ops/test_log_linear_attn.py
[error] 1-1: Pre-commit hook 'end-of-file-fixer' modified this file to fix end of file issues.
[error] 44-44: flake8: ambiguous variable name 'l' (E741)
fla/ops/log_linear_attn/naive.py
[error] 1-1: Pre-commit hook 'isort' modified this file to fix import sorting.
[error] 47-47: flake8: ambiguous variable name 'l' (E741)
fla/ops/log_linear_attn/chunk.py
[error] 1-1: Pre-commit hook 'isort' modified this file to fix import sorting.
[error] 1056-1056: flake8: ambiguous variable name 'l' (E741)
🪛 Ruff (0.11.9)
tests/ops/test_log_linear_attn.py
- 44-44: Ambiguous variable name: `l` (E741)

fla/ops/log_linear_attn/naive.py
- 47-47: Ambiguous variable name: `l` (E741)

fla/ops/log_linear_attn/chunk.py
- 38-38, 675-675, 826-826, 969-969, 1056-1056, 1157-1157: Ambiguous variable name: `l` (E741)
- 287-288, 299-300, 311-312, 323-324, 335-336, 347-348, 359-360, 371-372, 383-384, 395-396, 407-408, 419-420, 492-493, 496-497, 500-501, 504-505, 508-509, 512-513, 516-517, 520-521, 524-525, 528-529, 532-533, 1191-1192: Use a single `if` statement instead of nested `if` statements, combining them with `and` (SIM102)
🪛 Flake8 (7.2.0)
tests/ops/test_log_linear_attn.py
[error] 44-44: ambiguous variable name 'l'
(E741)
fla/ops/log_linear_attn/naive.py
[error] 47-47: ambiguous variable name 'l'
(E741)
fla/ops/log_linear_attn/chunk.py
[error] 1056-1056: ambiguous variable name 'l'
(E741)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-models
```python
import torch
import numpy as np
```
Fix import order to comply with isort.
The pipeline indicates that isort would modify this file. Please run isort to fix the import ordering.
```diff
-import torch
 import numpy as np
+import torch
```
📝 Committable suggestion
```python
import numpy as np
import torch
```
🧰 Tools
🪛 GitHub Actions: lint
[error] 1-1: Pre-commit hook 'isort' modified this file to fix import sorting.
🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/naive.py at lines 1 to 3, the import statements are
not ordered according to isort standards. Run isort on this file to
automatically reorder the imports so that standard library imports come first,
followed by third-party imports like torch and numpy, ensuring compliance with
the project's import style guidelines.
```python
def naive_log_linear_attn(q, k, v, g, l):
    H = construct_H_matrix(g.permute(0, 2, 1), l.permute(0, 2, 3, 1))
    M = torch.einsum("bhlc,blhn,bchn->bhlc", H, q, k)
    return torch.einsum("bhlc,bchp->blhp", M, v)
```
Rename ambiguous variable 'l' and add documentation.
The variable name 'l' is ambiguous and flagged by linters. Consider using a more descriptive name like 'levels' or 'level_scales'.
```diff
-def naive_log_linear_attn(q, k, v, g, l):
+def naive_log_linear_attn(q, k, v, g, level_scales):
+    """Compute naive log-linear attention.
+
+    Args:
+        q: Queries of shape (B, T, H, D)
+        k: Keys of shape (B, T, H, D)
+        v: Values of shape (B, T, H, D)
+        g: Gate/forget factors of shape (B, H, T)
+        level_scales: Hierarchical level scales of shape (B, H, L, T)
+
+    Returns:
+        Attention output of shape (B, T, H, D)
+    """
-    H = construct_H_matrix(g.permute(0, 2, 1), l.permute(0, 2, 3, 1))
+    H = construct_H_matrix(g.permute(0, 2, 1), level_scales.permute(0, 2, 3, 1))
     M = torch.einsum("bhlc,blhn,bchn->bhlc", H, q, k)
     return torch.einsum("bhlc,bchp->blhp", M, v)
```
📝 Committable suggestion
```python
def naive_log_linear_attn(q, k, v, g, level_scales):
    """Compute naive log-linear attention.

    Args:
        q: Queries of shape (B, T, H, D)
        k: Keys of shape (B, T, H, D)
        v: Values of shape (B, T, H, D)
        g: Gate/forget factors of shape (B, H, T)
        level_scales: Hierarchical level scales of shape (B, H, L, T)

    Returns:
        Attention output of shape (B, T, H, D)
    """
    H = construct_H_matrix(
        g.permute(0, 2, 1),
        level_scales.permute(0, 2, 3, 1),
    )
    M = torch.einsum("bhlc,blhn,bchn->bhlc", H, q, k)
    return torch.einsum("bhlc,bchp->blhp", M, v)
```
🧰 Tools
🪛 Ruff (0.11.9)
47-47: Ambiguous variable name: l
(E741)
🪛 Flake8 (7.2.0)
[error] 47-47: ambiguous variable name 'l'
(E741)
🪛 GitHub Actions: lint
[error] 47-47: flake8: ambiguous variable name 'l' (E741)
🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/naive.py around lines 47 to 50, rename the ambiguous
variable 'l' to a more descriptive name such as 'levels' or 'level_scales'
throughout the naive_log_linear_attn function, including in the function
signature and all its usages. Additionally, add a brief docstring to the
function explaining its purpose and the role of each parameter to improve code
clarity and maintainability.
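As a shape sanity check for the naive path, a small standalone sketch (tensor sizes are illustrative; it assumes the `[B, T, H, D]`, `[B, T, H]`, and `[B, T, H, L]` layouts implied by the permutes above, with one level scale per power-of-two level):

```python
import math
import torch

from fla.ops.log_linear_attn.naive import naive_log_linear_attn

B, T, H, D = 1, 16, 2, 8
num_levels = int(math.log2(T)) + 1  # assumed from construct_H_matrix's level loop

q = torch.randn(B, T, H, D)
k = torch.randn(B, T, H, D)
v = torch.randn(B, T, H, D)
g = torch.nn.functional.logsigmoid(torch.randn(B, T, H))  # log-space decay per step
levels = torch.rand(B, T, H, num_levels)                   # per-level scales

o = naive_log_linear_attn(q, k, v, g, levels)
print(o.shape)  # torch.Size([1, 16, 2, 8]) — matches [B, T, H, D]
```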
```python
def segsum(x):
    T = x.size(-1)
    x_cumsum = torch.cumsum(x, dim=-1)
    x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool))
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum
```
🛠️ Refactor suggestion
Add docstring and consider memory efficiency for large sequences.
The function lacks documentation. Additionally, creating a full T×T matrix could be memory-intensive for large sequences.
```diff
 def segsum(x):
+    """Compute segment sums for the input tensor.
+
+    Args:
+        x: Input tensor of shape (..., T)
+
+    Returns:
+        Segment sum matrix of shape (..., T, T) where element [i, j] contains
+        the sum of x[j:i+1] for j <= i, and -inf otherwise.
+    """
     T = x.size(-1)
     x_cumsum = torch.cumsum(x, dim=-1)
     x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
     mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool))
     x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
     return x_segsum
```
📝 Committable suggestion
```python
def segsum(x):
    """Compute segment sums for the input tensor.

    Args:
        x: Input tensor of shape (..., T)

    Returns:
        Segment sum matrix of shape (..., T, T) where element [i, j] contains
        the sum of x[j:i+1] for j <= i, and -inf otherwise.
    """
    T = x.size(-1)
    x_cumsum = torch.cumsum(x, dim=-1)
    x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool))
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum
```
🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/naive.py around lines 5 to 11, add a clear docstring
to the segsum function explaining its purpose, inputs, and outputs. To improve
memory efficiency for large sequences, consider refactoring the code to avoid
creating the full T×T matrix explicitly, possibly by using more memory-efficient
operations or chunking strategies.
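For intuition, here is a small self-contained check of what `segsum` produces (the function body is repeated from the snippet above so this runs standalone; the input values are arbitrary):

```python
import torch

def segsum(x):
    # Same definition as in naive.py, repeated here so the snippet runs on its own.
    T = x.size(-1)
    x_cumsum = torch.cumsum(x, dim=-1)
    x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool))
    return x_segsum.masked_fill(~mask, -torch.inf)

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
S = segsum(x)
# S[i, j] holds the sum of x over the half-open segment (j, i]; entries above
# the diagonal are -inf, so exp(S) is lower-triangular.
# Row 3 (0-indexed) is [9., 7., 4., 0.]: sums of x[1:4], x[2:4], x[3:4], and the empty segment.
print(S)
```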
```python
def construct_level_mask(level, L):
    T = L.size(-1)
    if level == 0:
        return torch.diag_embed(L[..., level, :])

    indices = torch.cartesian_prod(torch.arange(T), torch.arange(T)).to(L.device)

    mask = torch.where(
        torch.logical_and(
            torch.logical_and(
                indices[:, 0] % (1 << level) >= (1 << (level - 1)),
                indices[:, 1] + (1 << (level - 1))
                >= indices[:, 0] - (indices[:, 0] % (1 << (level - 1))),
            ),
            indices[:, 1] < indices[:, 0] - (indices[:, 0] % (1 << (level - 1))),
        ).view(T, T),
        L[..., level, :].unsqueeze(-1).expand(*([-1] * (len(L.shape) - 2)), T, T),
        0,
    )

    return mask
```
🛠️ Refactor suggestion
Add documentation and consider optimizing memory usage.
This function implements complex hierarchical masking logic that needs documentation. Also, `torch.cartesian_prod` creates all T² index pairs upfront, which could be memory-intensive.
```diff
 def construct_level_mask(level, L):
+    """Construct a hierarchical level mask for log-linear attention.
+
+    Args:
+        level: Hierarchical level (0 for diagonal, higher for broader patterns)
+        L: Level values tensor of shape (..., num_levels, T)
+
+    Returns:
+        Mask tensor of shape (..., T, T) with L values placed according to
+        the hierarchical pattern for the given level.
+    """
     T = L.size(-1)
     if level == 0:
         return torch.diag_embed(L[..., level, :])
     indices = torch.cartesian_prod(torch.arange(T), torch.arange(T)).to(L.device)
```
Would you like me to suggest a more memory-efficient implementation that avoids creating the full cartesian product?
📝 Committable suggestion
```python
def construct_level_mask(level, L):
    """Construct a hierarchical level mask for log-linear attention.

    Args:
        level: Hierarchical level (0 for diagonal, higher for broader patterns)
        L: Level values tensor of shape (..., num_levels, T)

    Returns:
        Mask tensor of shape (..., T, T) with L values placed according to
        the hierarchical pattern for the given level.
    """
    T = L.size(-1)
    if level == 0:
        return torch.diag_embed(L[..., level, :])
    indices = torch.cartesian_prod(torch.arange(T), torch.arange(T)).to(L.device)
    mask = torch.where(
        torch.logical_and(
            torch.logical_and(
                indices[:, 0] % (1 << level) >= (1 << (level - 1)),
                indices[:, 1] + (1 << (level - 1))
                >= indices[:, 0] - (indices[:, 0] % (1 << (level - 1))),
            ),
            indices[:, 1] < indices[:, 0] - (indices[:, 0] % (1 << (level - 1))),
        ).view(T, T),
        L[..., level, :].unsqueeze(-1).expand(*([-1] * (len(L.shape) - 2)), T, T),
        0,
    )
    return mask
```
🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/naive.py lines 14 to 35, add clear docstrings
explaining the purpose, inputs, outputs, and the hierarchical masking logic of
the construct_level_mask function. To optimize memory usage, replace
torch.cartesian_prod with an approach that computes the mask without generating
all T² index pairs at once, such as using broadcasting or iterative methods to
reduce memory overhead.
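One possible shape of such a refactor (a sketch only, not the reviewed implementation): build the boolean pattern by broadcasting two 1-D index vectors instead of materializing the T² cartesian product. The function name below is hypothetical, and the logic is intended to mirror `construct_level_mask` above; it should be verified against the original before being adopted.

```python
import torch

def construct_level_mask_broadcast(level, L):
    # Same contract as construct_level_mask, but the (T, T) boolean pattern is
    # built by broadcasting row/column index vectors, avoiding cartesian_prod.
    T = L.size(-1)
    if level == 0:
        return torch.diag_embed(L[..., level, :])
    rows = torch.arange(T, device=L.device)[:, None]   # query positions i
    cols = torch.arange(T, device=L.device)[None, :]   # key positions j
    half = 1 << (level - 1)
    block_start = rows - (rows % half)                  # i - (i % 2^(level-1))
    pattern = (
        (rows % (1 << level) >= half)
        & (cols + half >= block_start)
        & (cols < block_start)
    )
    values = L[..., level, :].unsqueeze(-1).expand(*([-1] * (L.dim() - 2)), T, T)
    return torch.where(pattern, values, torch.zeros((), device=L.device, dtype=L.dtype))
```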
```python
from .chunk import chunk_log_linear_attn, LogLinearAttentionState

__all__ = [
    'chunk_log_linear_attn',
    'LogLinearAttentionState'
]
```
Fix formatting issues: import order and end-of-file.
The pipeline indicates two issues:
- isort would modify the import ordering
- Missing newline at end of file
Please run the pre-commit hooks to fix these issues.
```diff
 from .chunk import chunk_log_linear_attn, LogLinearAttentionState
 __all__ = [
     'chunk_log_linear_attn',
     'LogLinearAttentionState'
-]
+]
+
```
📝 Committable suggestion
```python
from .chunk import chunk_log_linear_attn, LogLinearAttentionState

__all__ = [
    'chunk_log_linear_attn',
    'LogLinearAttentionState'
]
```
🧰 Tools
🪛 GitHub Actions: lint
[error] 1-1: Pre-commit hook 'end-of-file-fixer' modified this file to fix end of file issues.
[error] 1-1: Pre-commit hook 'isort' modified this file to fix import sorting.
🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/__init__.py at lines 1 to 6, fix the import order
according to isort standards by running the isort pre-commit hook, and add a
newline at the end of the file to comply with formatting requirements. Run the
full set of pre-commit hooks to automatically correct these formatting issues.
```python
@triton.heuristics(
    {
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def copy_input_kernel(
    q,
    k,
    v,
    g,
    l,
    cu_seqlens,
    q_prev,
    k_prev,
    v_prev,
    g_prev,
    l_prev,
    offsets,
    q_new,
    k_new,
    v_new,
    g_new,
    l_new,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    L: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
```
Rename ambiguous variable 'l' in copy_input_kernel.
The parameter name 'l' is ambiguous and should be renamed for consistency with the suggested changes in the main kernel.
```diff
 def copy_input_kernel(
     q,
     k,
     v,
     g,
-    l,
+    level_scales,
     cu_seqlens,
     q_prev,
     k_prev,
     v_prev,
     g_prev,
-    l_prev,
+    level_scales_prev,
     offsets,
     q_new,
     k_new,
     v_new,
     g_new,
-    l_new,
+    level_scales_new,
     T,
     H: tl.constexpr,
     K: tl.constexpr,
     V: tl.constexpr,
     L: tl.constexpr,
     BT: tl.constexpr,
     IS_VARLEN: tl.constexpr,
 ):
+    """Copy input tensors while incorporating previous state for continuous processing."""
```
📝 Committable suggestion
```python
@triton.heuristics(
    {
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def copy_input_kernel(
    q,
    k,
    v,
    g,
    level_scales,
    cu_seqlens,
    q_prev,
    k_prev,
    v_prev,
    g_prev,
    level_scales_prev,
    offsets,
    q_new,
    k_new,
    v_new,
    g_new,
    level_scales_new,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    L: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Copy input tensors while incorporating previous state for continuous processing."""
    # … rest of the kernel implementation …
```
🧰 Tools
🪛 Ruff (0.11.9)
675-675: Ambiguous variable name: l
(E741)
🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/chunk.py between lines 664 and 695, rename the
ambiguous parameter 'l' in the copy_input_kernel function to a more descriptive
and consistent name that aligns with the main kernel's naming conventions.
Update all occurrences of this parameter within the function signature and body
accordingly to maintain clarity and consistency.
```python
def construct_binary_level_mask(level, T):
    if level == 0:
        return torch.diag(torch.ones(T, dtype=torch.bool))

    indices = torch.cartesian_prod(torch.arange(T), torch.arange(T))

    mask = torch.where(
        torch.logical_and(
            torch.logical_and(
                indices[:, 0] % (1 << level) >= (1 << (level - 1)),
                indices[:, 1] + (1 << (level - 1))
                >= indices[:, 0] - (indices[:, 0] % (1 << (level - 1))),
            ),
            indices[:, 1] < indices[:, 0] - (indices[:, 0] % (1 << (level - 1))),
        ).view(T, T),
        1,
        0,
    )

    return mask


def level_lut(BT, device):
    lut = torch.zeros((BT, BT), dtype=torch.int32, device=device)
    for level in range(1, ceil_log(BT, 2) + 1):
        mask = construct_binary_level_mask(level, BT).to(device)
        lut = torch.where(mask.to(torch.bool), level, lut)
    return lut


def ceil_div(x: int, y: int) -> int:
    return math.ceil(x / y)


def ceil_log(x: int, b: int) -> int:
    return math.ceil(math.log(x, b))
```
🛠️ Refactor suggestion
Add documentation to utility functions.
These utility functions lack documentation explaining their purpose in the hierarchical attention scheme.
```diff
 def construct_binary_level_mask(level, T):
+    """Construct a binary mask for a specific hierarchical level.
+
+    Args:
+        level: Hierarchical level (0 for diagonal, higher for broader patterns)
+        T: Sequence length
+
+    Returns:
+        Binary mask tensor of shape (T, T)
+    """
     if level == 0:
         return torch.diag(torch.ones(T, dtype=torch.bool))

 def level_lut(BT, device):
+    """Generate a lookup table mapping positions to hierarchical levels.
+
+    Args:
+        BT: Block size
+        device: Target device
+
+    Returns:
+        Lookup table of shape (BT, BT) with level indices
+    """
     lut = torch.zeros((BT, BT), dtype=torch.int32, device=device)
```
📝 Committable suggestion
```python
def construct_binary_level_mask(level, T):
    """Construct a binary mask for a specific hierarchical level.

    Args:
        level: Hierarchical level (0 for diagonal, higher for broader patterns)
        T: Sequence length

    Returns:
        Binary mask tensor of shape (T, T)
    """
    if level == 0:
        return torch.diag(torch.ones(T, dtype=torch.bool))
    indices = torch.cartesian_prod(torch.arange(T), torch.arange(T))
    mask = torch.where(
        torch.logical_and(
            torch.logical_and(
                indices[:, 0] % (1 << level) >= (1 << (level - 1)),
                indices[:, 1] + (1 << (level - 1))
                >= indices[:, 0] - (indices[:, 0] % (1 << (level - 1))),
            ),
            indices[:, 1] < indices[:, 0] - (indices[:, 0] % (1 << (level - 1))),
        ).view(T, T),
        1,
        0,
    )
    return mask


def level_lut(BT, device):
    """Generate a lookup table mapping positions to hierarchical levels.

    Args:
        BT: Block size
        device: Target device

    Returns:
        Lookup table of shape (BT, BT) with level indices
    """
    lut = torch.zeros((BT, BT), dtype=torch.int32, device=device)
    for level in range(1, ceil_log(BT, 2) + 1):
        mask = construct_binary_level_mask(level, BT).to(device)
        lut = torch.where(mask.to(torch.bool), level, lut)
    return lut
```
🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/chunk.py around lines 910 to 946, the utility
functions construct_binary_level_mask, level_lut, ceil_div, and ceil_log lack
documentation. Add concise docstrings to each function explaining their purpose
and role in the hierarchical attention scheme, including input parameters and
return values to improve code readability and maintainability.
```python
@triton.heuristics(
    {
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def copy_last_chunk_kernel(
    q,
    k,
    v,
    g,
    l,
    cu_seqlens,
    q_prev,
    k_prev,
    v_prev,
    g_prev,
    l_prev,
    offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    L: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
```
Rename ambiguous variable 'l' in copy_last_chunk_kernel.
Maintain consistency with the naming convention changes.
```diff
 def copy_last_chunk_kernel(
     q,
     k,
     v,
     g,
-    l,
+    level_scales,
     cu_seqlens,
     q_prev,
     k_prev,
     v_prev,
     g_prev,
-    l_prev,
+    level_scales_prev,
     offsets,
     T,
     H: tl.constexpr,
     K: tl.constexpr,
     V: tl.constexpr,
     L: tl.constexpr,
     BT: tl.constexpr,
     IS_VARLEN: tl.constexpr,
 ):
+    """Copy the last chunk of sequences to preserve state for next iteration."""
```
📝 Committable suggestion
```python
@triton.heuristics(
    {
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def copy_last_chunk_kernel(
    q,
    k,
    v,
    g,
    level_scales,
    cu_seqlens,
    q_prev,
    k_prev,
    v_prev,
    g_prev,
    level_scales_prev,
    offsets,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    L: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Copy the last chunk of sequences to preserve state for next iteration."""
```
🧰 Tools
🪛 Ruff (0.11.9)
826-826: Ambiguous variable name: l
(E741)
🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/chunk.py around lines 815 to 841, the variable named
'l' in the copy_last_chunk_kernel function is ambiguous and should be renamed to
a more descriptive name consistent with the project's naming conventions.
Identify a clear and meaningful name that reflects the variable's purpose and
replace all instances of 'l' with this new name within the function signature
and body.
```python
@torch.compiler.disable
def chunk_log_linear_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    l: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        g (torch.Tensor):
            Forget gates of shape `[B, T, H]`.
        l (torch.Tensor):
            Scales for each level of shape `[B, T, H, L]`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of type `LogLinearAttentionState` if `output_final_state=True` else `None`.
    """
```
Fix ambiguous variable name and document gradient limitation.
The parameter 'l' should be renamed, and the docstring should mention that gradients are not supported.
```diff
 def chunk_log_linear_attn(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     g: torch.Tensor,
-    l: torch.Tensor,
+    level_scales: torch.Tensor,
     initial_state: Optional[torch.Tensor] = None,
     output_final_state: bool = False,
     cu_seqlens: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     Args:
         q (torch.Tensor):
             queries of shape `[B, T, H, K]`.
         k (torch.Tensor):
             keys of shape `[B, T, H, K]`.
         v (torch.Tensor):
             values of shape `[B, T, H, V]`.
         g (torch.Tensor):
             Forget gates of shape `[B, T, H]`.
-        l (torch.Tensor):
+        level_scales (torch.Tensor):
             Scales for each level of shape `[B, T, H, L]`.
         initial_state (Optional[torch.Tensor]):
             Initial state of shape `[N, H, K, V]` for `N` input sequences.
             For equal-length input sequences, `N` equals the batch size `B`.
             Default: `None`.
         output_final_state (Optional[bool]):
             Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
         cu_seqlens (torch.LongTensor):
             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
             consistent with the FlashAttention API.
     Returns:
         o (torch.Tensor):
             Outputs of shape `[B, T, H, V]`.
         final_state (torch.Tensor):
             Final state of type `LogLinearAttentionState` if `output_final_state=True` else `None`.
+
+    Note:
+        This implementation does not support gradient computation (backward pass).
+        Use with torch.no_grad() or for inference only.
     """
```
📝 Committable suggestion
```python
@torch.compiler.disable
def chunk_log_linear_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    level_scales: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        g (torch.Tensor):
            Forget gates of shape `[B, T, H]`.
        level_scales (torch.Tensor):
            Scales for each level of shape `[B, T, H, L]`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of type `LogLinearAttentionState` if `output_final_state=True` else `None`.

    Note:
        This implementation does not support gradient computation (backward pass).
        Use with torch.no_grad() or for inference only.
    """
    # ... rest of implementation ...
```
🧰 Tools
🪛 Ruff (0.11.9)
1157-1157: Ambiguous variable name: l
(E741)
🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/chunk.py around lines 1151 to 1190, rename the
parameter 'l' to a more descriptive name that reflects its purpose, such as
'scales' or similar, to avoid ambiguity. Additionally, update the docstring to
explicitly state that gradients are not supported for this function or
parameter, clarifying the limitation for users.
```python
    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do, dht):
        raise NotImplementedError(
            "Backward pass is not implemented for log-linear attention."
        )
```
💡 Verification agent
🧩 Analysis chain
Document that backward pass is not implemented.
The backward pass raises NotImplementedError. This should be clearly documented in the main API function's docstring.
🏁 Script executed:

```bash
#!/bin/bash
# Check if this is mentioned in any documentation or if there are any gradient-related tests
rg -A 5 "backward|gradient" --glob "*.md" --glob "*.rst"
rg "requires_grad|backward|gradient" tests/ops/test_log_linear_attn.py
```

Length of output: 9896

🏁 Script executed:

```bash
#!/bin/bash
# Locate user-facing log-linear attention API and where the autograd Function is applied
rg -n "^def .*log_linear.*attention" -n fla/ops/log_linear_attn
rg -n "apply(" -n fla/ops/log_linear_attn
```

Length of output: 172

🏁 Script executed:

```bash
#!/bin/bash
# Identify the autograd Function class and inspect its forward docstring
rg -n "^class .*Attention" -n fla/ops/log_linear_attn/chunk.py
rg -n -B 10 -A 5 "def forward" fla/ops/log_linear_attn/chunk.py
```

Length of output: 597

🏁 Script executed:

```bash
#!/bin/bash
# List files in log_linear_attn directory
ls -1 fla/ops/log_linear_attn
# Find where the autograd function is referenced or applied
rg -n "ChunkLogLinearAttentionFunction" -n fla/ops/log_linear_attn
rg -n "apply\\(" -n fla/ops/log_linear_attn
# List all top-level function definitions in this folder
rg -n "^def " -n fla/ops/log_linear_attn
```

Length of output: 1323

🏁 Script executed:

```bash
#!/bin/bash
# Show the docstring of the user-facing chunk_log_linear_attn function
rg -n -B 5 -A 20 "def chunk_log_linear_attn" fla/ops/log_linear_attn/chunk.py
```

Length of output: 956
Add the missing backward-pass note to the `chunk_log_linear_attn` docstring.
The autograd Function's `backward` raises `NotImplementedError`, but the user-facing `chunk_log_linear_attn` docstring doesn't mention this limitation. Please update it to avoid confusion.
• File: `fla/ops/log_linear_attn/chunk.py`, function `chunk_log_linear_attn` (around line 1162)
• Action: Insert a "Note" stating that gradients aren't supported and calling backward will throw.
Suggested diff:
```diff
@@ -1162,7 +1162,11 @@ def chunk_log_linear_attn(
-    r"""
+    r"""
+    Note:
+        The backward pass is not implemented and will raise NotImplementedError.
+
+        (Supports forward inference only.)
     Args:
         q (torch.Tensor):
             queries of shape `[B, T, H, K]`.
```
📝 Committable suggestion
```python
def chunk_log_linear_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    chunk_size: int,
    *,
    use_single_stream: bool = False,
) -> torch.Tensor:
    r"""
    Note:
        The backward pass is not implemented and will raise NotImplementedError.

        (Supports forward inference only.)
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, U, H, K]`.
        v (torch.Tensor):
            values of shape `[B, U, H, K]`.
    """
    # ... rest of implementation ...
```
🤖 Prompt for AI Agents
In fla/ops/log_linear_attn/chunk.py around line 1162, update the docstring of
the function chunk_log_linear_attn to include a note that the backward pass is
not implemented and gradients are not supported, warning users that calling
backward will raise a NotImplementedError. Add this clarification to prevent
confusion about the function's autograd limitations.
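For reference, a minimal inference-only usage sketch based on the docstring above (the tensor shapes, the level count, and the log-sigmoid gating are assumptions; a CUDA device is assumed for the Triton kernels, and everything runs under `torch.no_grad()` because the backward pass is not implemented):

```python
import math
import torch

from fla.ops.log_linear_attn import chunk_log_linear_attn

B, T, H, K, V = 2, 1024, 4, 64, 64
num_levels = int(math.log2(T)) + 1  # assumed: one scale per hierarchical level

q = torch.randn(B, T, H, K, device="cuda")
k = torch.randn(B, T, H, K, device="cuda")
v = torch.randn(B, T, H, V, device="cuda")
g = torch.nn.functional.logsigmoid(torch.randn(B, T, H, device="cuda"))  # log-space forget gates
level_scales = torch.rand(B, T, H, num_levels, device="cuda")

# Backward raises NotImplementedError, so keep autograd out of the picture.
with torch.no_grad():
    o, final_state = chunk_log_linear_attn(q, k, v, g, level_scales, output_final_state=True)

print(o.shape)  # expected: (B, T, H, V)
```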
@2022tgoel Nice work!
Summary by CodeRabbit
New Features
Tests