
Chunkwise gated linear attention reaching 60~80 TFLOP/s, with step-by-step optimization records #88

Open
learning-chip wants to merge 31 commits into main from linear_attn
Conversation


@learning-chip learning-chip commented Apr 5, 2026

TL;DR: this is a much (3~5x) faster version of the Triton chunk_o kernel in vllm-ascend.

Note on the chosen algorithm: it mirrors the chunk_fwd_o part of chunk_gated_delta_rule_fwd used in vllm-ascend's Qwen3.5 prefill. The "chunk_o" part reduces to plain Gated Linear Attention (the same chunkwise form as Mamba-2). This PR does not cover the "delta rule" part yet, which only appears in the "chunk_h" phase.
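
For readers unfamiliar with the chunkwise form, here is a minimal PyTorch sketch of what the "chunk_o" phase computes (ungated case; tensor shapes, names, and the scale factor are illustrative assumptions, not this PR's actual API). Each chunk's output is an intra-chunk causal attention plus an inter-chunk term read from the running K^T V state; the state update at the end is really the "chunk_h" part, included here only to make the reference self-contained.

    import torch

    def chunk_o_reference(q, k, v, chunk_size=64):
        """Naive chunkwise linear attention output (ungated).

        q, k: [B, H, T, K], v: [B, H, T, V]; assumes T is a multiple of chunk_size.
        """
        B, H, T, K = q.shape
        o = torch.empty_like(v)
        mask = torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device))
        S = torch.zeros(B, H, K, v.shape[-1], dtype=torch.float32, device=q.device)  # running K^T V state
        for start in range(0, T, chunk_size):
            sl = slice(start, start + chunk_size)
            q_c = q[..., sl, :].float() * K ** -0.5            # scale is an assumption, may differ from the kernel
            k_c, v_c = k[..., sl, :].float(), v[..., sl, :].float()
            A = (q_c @ k_c.transpose(-1, -2)).masked_fill(~mask, 0.0)  # intra-chunk causal scores
            o_c = A @ v_c + q_c @ S                            # intra-chunk + inter-chunk (state) contribution
            o[..., sl, :] = o_c.to(o.dtype)
            S = S + k_c.transpose(-1, -2) @ v_c                # "chunk_h": state update for the next chunk
        return o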

Step-by-step perf optimizations

  • Initial static-shape starting point -- c226f0a (generated from tilelang linear_attention_causal; directly compiling dynamic shapes in tilelang fails, so I sent a few PRs to tilelang)
  • First dynamic-shape version, ~1 TFLOP/s -- 41aeecc
  • Inline all wrappers and reduce branching, ~5 TFLOP/s -- 5f1ac35
  • Precompute and cache the causal mask, ~30 TFLOP/s -- a9b54ed
  • Chunk size 64 -> 128, ~50 TFLOP/s -- bd954f9
  • L0 ping-pong buffer, ~58 TFLOP/s -- 7b811b0
  • Two-slot C-V pipelining, ~75 TFLOP/s -- 3350511
  • L1 prefetching, ~78 TFLOP/s and ~570 GiB/s bandwidth -- 26aac37

Added an optimize_step_by_step directory to reproduce the above step-by-step performance gains at the latest commit (a rough FLOP-counting convention for these TFLOP/s figures is sketched below).
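
For context on how TFLOP/s figures like the above can be derived, one plausible counting convention for the chunk_o phase is to count the two intra-chunk matmuls and the one inter-chunk matmul per chunk, sum over chunks, and divide by measured time. The benchmark's actual convention may differ; this is only a sketch.

    def chunk_o_flops(batch, heads, seqlen, k_dim, v_dim, chunk_size=128):
        """Rough FLOP count for the chunk_o phase only (excludes the chunk_h state update)."""
        c = chunk_size
        per_chunk = (
            2 * c * c * k_dim        # Q @ K^T
            + 2 * c * c * v_dim      # A @ V
            + 2 * c * k_dim * v_dim  # Q @ S (inter-chunk)
        )
        return batch * heads * (seqlen // c) * per_chunk

    # TFLOP/s ~ chunk_o_flops(...) / elapsed_seconds / 1e12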

Feature list:

  • Dynamic batch/seq dims, templated head/hidden dims
  • Precision and performance comparison against the Triton baseline (see triton_baseline/performance_summary.md; the PTO version is 3~4x faster)
  • Support for BSND (seq-first) and BNSD (head-first) layouts (seq-first works but drops throughput from ~75 to ~60 TFLOP/s)
  • Variable-seqlen input for the BSND case
  • Support for a scalar gating factor, matching FLA's simple GLA chunk (a gated reference is sketched after this list)
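
The following is a rough sketch of how a per-token scalar gate changes the ungated reference shown earlier, following the chunkwise form of FLA's simple GLA. Names such as g for the per-token log decay are assumptions for illustration, not the kernel's API.

    import torch

    def gated_chunk_o_reference(q, k, v, g, chunk_size=64):
        """Scalar-gated chunkwise linear attention. g: [B, H, T] per-token log decay (<= 0)."""
        B, H, T, K = q.shape
        o = torch.empty_like(v)
        mask = torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device))
        S = torch.zeros(B, H, K, v.shape[-1], dtype=torch.float32, device=q.device)
        for start in range(0, T, chunk_size):
            sl = slice(start, start + chunk_size)
            q_c, k_c, v_c = q[..., sl, :].float(), k[..., sl, :].float(), v[..., sl, :].float()
            b = g[..., sl].float().cumsum(-1)                   # [B, H, C] local cumulative log decay
            # intra-chunk: decay-weighted causal scores; mask applied on the exponent to avoid overflow
            decay = torch.exp((b[..., :, None] - b[..., None, :]).masked_fill(~mask, float("-inf")))
            A = (q_c @ k_c.transpose(-1, -2)) * decay
            # inter-chunk: previous chunks seen through the state, decayed up to each position
            o_c = A @ v_c + (q_c * torch.exp(b)[..., None]) @ S
            o[..., sl, :] = o_c.to(o.dtype)
            # state update ("chunk_h"): decay the old state across the chunk, add this chunk's decayed K^T V
            b_last = b[..., -1:]
            S = torch.exp(b_last)[..., None] * S + (k_c * torch.exp(b_last - b)[..., None]).transpose(-1, -2) @ v_c
        return o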

Compiles and runs correctly with the pto-isa headers in /usr/local/Ascend/cann-8.5.1/include (the CANN version in the quay.io/ascend/vllm-ascend:v0.18.0rc1 image). The GitCode tag 8.5.0 headers also compile and run fine.

Remaining issues for this PR

  • Although all unit tests pass, the largely AI-generated code still needs careful human review, cleaning, and annotation.
    • Better to keep this PR as-is, then manually distill the minimal, cleanest code in a separate PR and merge that.
  • Compile error when using a newer pto-isa header (04/03 commit); to be fixed later.
    • Now fixed by adding #include <pto/common/pto_tile.hpp> to the kernel sources and changing the use of TTRI. See commit -- 382153e

Minor issues:
To avoid slow mask construction via scalar-core loops, the causal mask is currently precomputed and passed in as an extra argument (a host-side sketch is shown after this subsection).

In vllm-ascend the on-SRAM mask is built by:

        o_i = tl.arange(0, BT).to(tl.float32)
        m_A = o_i[:, None] >= o_i[None, :]
  • Need to pick an efficient PTO instruction for this masking
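
Below is a host-side sketch of the precompute-and-pass-in approach. The repo's jit_shared.get_causal_mask presumably does something similar; the function name, dtype, and device handling here are assumptions.

    from functools import lru_cache

    import torch
    import torch_npu  # noqa: F401

    @lru_cache(maxsize=None)
    def make_causal_mask(chunk_size: int) -> torch.Tensor:
        # Build the lower-triangular [chunk_size, chunk_size] mask once on the host,
        # cache it, and hand it to the kernel as an extra tensor argument.
        idx = torch.arange(chunk_size)
        return (idx[:, None] >= idx[None, :]).to(torch.float32).npu()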

Remaining issues for future PRs

On algorithm side:

  • Add the "chunk_h" part (including chunk_scaled_dot_kkt_fwd / recompute_w_u_fwd / chunk_gated_delta_rule_fwd_h) for GatedDeltaNet and Kimi Delta Attention -- see Complete chunkwise GatedDeltaNet #91 (a naive delta-rule recurrence is sketched after this list)
  • Merge into one "GDN layer megakernel" and integrate into vllm/sglang
  • Generalize to the rest of the FLA repo's kernel collection (give the agent some samples of Triton GPU -> PTO-ISA porting; this PR might be enough)
  • Support backward kernels for training
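
For reference, here is a naive token-by-token sketch of the gated delta rule recurrence that the chunk_h work above would need to realize in chunkwise form. Parameter names alpha/beta follow the GatedDeltaNet convention and, like the assumption that k is normalized, are illustrative rather than this repo's API.

    import torch

    def gated_delta_rule_reference(q, k, v, alpha, beta):
        """q, k: [B, H, T, K] (k assumed normalized), v: [B, H, T, V],
        alpha: [B, H, T] gating/decay in (0, 1], beta: [B, H, T] update step size."""
        B, H, T, K = q.shape
        S = torch.zeros(B, H, K, v.shape[-1], dtype=torch.float32, device=q.device)
        o = torch.empty(B, H, T, v.shape[-1], dtype=torch.float32, device=q.device)
        for t in range(T):
            k_t = k[..., t, :].float()                  # [B, H, K]
            v_t = v[..., t, :].float()                  # [B, H, V]
            a_t = alpha[..., t, None, None].float()     # [B, H, 1, 1]
            b_t = beta[..., t, None].float()            # [B, H, 1]
            S = a_t * S                                 # gated decay of the state
            # delta rule: move the value stored under k_t toward v_t, scaled by beta
            v_old = torch.einsum("bhk,bhkv->bhv", k_t, S)
            S = S + torch.einsum("bhk,bhv->bhkv", k_t, b_t * (v_t - v_old))
            o[..., t, :] = torch.einsum("bhk,bhkv->bhv", q[..., t, :].float(), S)
        return o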

On C++ framework side:

  • Test the new TPUSH/TPOP abstraction for C-V communication
  • Test generalization to the A5 backend
  • Test the new AUTO mode for synchronization

On Python DSL framework side:

  • Implement this simple chunk_o part of linear attention as the first useful mixed-kernel example in pto-dsl
  • Enable the ptoas auto-sync pass and compare it with the manual sync plan

@learning-chip learning-chip changed the title from "WIP Linear Attention" to "Chunkwise linear attention with step-by-step optimization to reach ~80 TFLOP/s" on Apr 5, 2026
@learning-chip learning-chip changed the title from "Chunkwise linear attention with step-by-step optimization to reach ~80 TFLOP/s" to "Chunkwise linear attention reaching ~80 TFLOP/s with step-by-step optimization history" on Apr 5, 2026
@learning-chip learning-chip marked this pull request as ready for review April 5, 2026 22:26
@learning-chip learning-chip changed the title from "Chunkwise linear attention reaching ~80 TFLOP/s with step-by-step optimization history" to "Chunkwise gated linear attention reaching 60~80 TFLOP/s, with step-by-step optimization records" on Apr 7, 2026