
Delta Product Rule Backwards Kernel #526


Open · wants to merge 3 commits into main
3 changes: 2 additions & 1 deletion evals/harness.py
@@ -2,11 +2,12 @@

from __future__ import annotations

-import fla # noqa
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM

+import fla # noqa

Member:
We should make this formatting change somewhere else rather than in this PR.


@register_model('fla')
class FlashLinearAttentionLMWrapper(HFLM):
362 changes: 325 additions & 37 deletions fla/ops/gated_delta_product/chunk.py

Large diffs are not rendered by default.

43 changes: 30 additions & 13 deletions fla/ops/gated_delta_product/chunk_deltaproduct_h.py
@@ -225,6 +225,7 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
chunk_offsets,
scale,
T,
+num_householder: tl.constexpr,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
@@ -237,6 +238,7 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
):
i_v, i_nh = tl.program_id(0), tl.program_id(1)
i_n, i_h = i_nh // H, i_nh % H

if IS_VARLEN:
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
T = eos - bos
@@ -245,7 +247,9 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
else:
bos, eos = i_n * T, i_n * T + T
NT = tl.cdiv(T, BT)
-boh = i_n * NT
+# boh = i_n * tl.cdiv(T, BT)
+# Jinha: update boh to match the chunk_gated_delta_product_fwd_kernel_h_blockdim64 implementation
+boh = i_n * tl.cdiv(T // num_householder, BT)
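+# (The forward kernel stores one hidden state per chunk of T // num_householder
+# true tokens, while T here counts expanded rows, hence the extra divisor.)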

# [BK, BV]
b_dh1 = tl.zeros([64, BV], dtype=tl.float32)
@@ -312,13 +316,13 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
b_g_exp = None

p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
-p_wo = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv2 = tl.make_block_ptr(dv2, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

-b_wo = tl.load(p_wo, boundary_check=(0, 1))
+b_do = tl.load(p_do, boundary_check=(0, 1))
b_dv = tl.zeros([BT, BV], dtype=tl.float32)

-# Update dv
+# Update dv based on hidden state gradients
p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dv += tl.dot(b_k, b_dh1.to(b_k.dtype))
@@ -344,7 +348,8 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
b_dv += tl.load(p_dv, boundary_check=(0, 1))

tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
-# Update dh
+
+# Update hidden state gradients
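+# Recurrence per 64-wide K block, with q and w kept in (K, T) layout:
+# dh <- exp(g_last) * dh + (scale * q * exp(g)) @ do - w @ dv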
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
b_w = tl.load(p_w, boundary_check=(0, 1))
@@ -353,7 +358,8 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
b_dh1 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
b_q = (b_q * scale).to(b_q.dtype)
-b_dh1 += tl.dot(b_q, b_wo.to(b_q.dtype))-tl.dot(b_w, b_dv.to(b_w.dtype))
+b_dh1 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_w, b_dv.to(b_w.dtype))

if K > 64:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
@@ -363,7 +369,8 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
b_dh2 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
b_q = (b_q * scale).to(b_q.dtype)
-b_dh2 += tl.dot(b_q, b_wo.to(b_q.dtype))-tl.dot(b_w, b_dv.to(b_w.dtype))
+b_dh2 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_w, b_dv.to(b_w.dtype))

if K > 128:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
@@ -373,7 +380,8 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
b_dh3 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
b_q = (b_q * scale).to(b_q.dtype)
-b_dh3 += tl.dot(b_q, b_wo.to(b_q.dtype))-tl.dot(b_w, b_dv.to(b_w.dtype))
+b_dh3 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_w, b_dv.to(b_w.dtype))

if K > 192:
p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
@@ -383,7 +391,7 @@ def chunk_gated_delta_product_bwd_kernel_dhu_blockdim64(
b_dh4 *= bg_last_exp
b_q = b_q * b_g_exp[None, :]
b_q = (b_q * scale).to(b_q.dtype)
-b_dh4 += tl.dot(b_q, b_wo.to(b_q.dtype))-tl.dot(b_w, b_dv.to(b_w.dtype))
+b_dh4 += tl.dot(b_q, b_do.to(b_q.dtype)) - tl.dot(b_w, b_dv.to(b_w.dtype))

if USE_INITIAL_STATE:
p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
@@ -460,19 +468,25 @@ def chunk_gated_delta_product_bwd_dhu(
dv: torch.Tensor,
scale: float,
cu_seqlens: Optional[torch.LongTensor] = None,
-chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
+chunk_size: int = 64,
+num_householder: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, H, K, V = *q.shape, do.shape[-1]
+assert T % num_householder == 0, "T must be divisible by num_householder"
+T_true = T // num_householder

# N: the actual number of sequences in the batch with either equal or variable lengths
BT = 64
assert K <= 256, "current kernel does not support head dimension being larger than 256."

-chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
+chunk_indices = prepare_chunk_indices(cu_seqlens // num_householder, chunk_size) if cu_seqlens is not None else None
if cu_seqlens is None:
-N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+N, NT, chunk_offsets = B, triton.cdiv(T_true, BT), None
else:
-N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
+N, NT, chunk_offsets = (
+len(cu_seqlens) - 1, len(chunk_indices),
+prepare_chunk_offsets(cu_seqlens // num_householder, BT)
+)

dh = q.new_empty(B, NT, H, K, V)
dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
@@ -494,9 +508,12 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), N*H)
chunk_offsets=chunk_offsets,
scale=scale,
T=T,
+num_householder=num_householder,
H=H,
K=K,
V=V,
BT=BT,
)
+# could call chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64 instead
+# after adjusting number of tokens
return dh, dh0, dv2
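
A note on the layout this function assumes (inferred from the assert and the `// num_householder` divisions, not stated explicitly in the diff): the time axis counts expanded rows, with `num_householder` rows per true token, so sequence boundaries in `cu_seqlens` must be divided back down before computing chunk indices and offsets. A small illustration:

```python
# Illustrative only; the names and numbers below are made up for this sketch.
import torch

num_householder = 2

# Varlen case: two sequences of 4 and 6 true tokens -> 8 and 12 expanded rows.
cu_seqlens = torch.tensor([0, 8, 20])            # boundaries in expanded rows
cu_seqlens_true = cu_seqlens // num_householder  # tensor([0, 4, 10])

# Fixed-length case: T expanded rows per sequence.
T, BT = 256, 64
T_true = T // num_householder  # 128 true tokens
NT = -(-T_true // BT)          # triton.cdiv(T_true, BT) == 2
# dh is allocated as (B, NT, H, K, V), i.e., one gradient state per
# true-token chunk, which is also why the kernel computes
# boh = i_n * tl.cdiv(T // num_householder, BT).
```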
209 changes: 209 additions & 0 deletions fla/ops/gated_delta_product/chunk_deltaproduct_o.py
@@ -152,3 +152,212 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
BT=BT,
)
return o

# @triton.heuristics({
# 'USE_G': lambda args: args['g'] is not None,
# 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
# })
# @triton.autotune(
# configs=[
# triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
# for BK in BKV_LIST
# for BV in BKV_LIST
# for num_warps in NUM_WARPS
# for num_stages in [2, 3, 4]
# ],
# key=['H', 'K', 'V', 'BT'],
# )
# @triton.jit(do_not_specialize=['T'])
# def chunk_gated_delta_product_bwd_kernel_o(
# q,
# k,
# v,
# h,
# g,
# do,
# dq,
# dk,
# dv,
# dh,
# cu_seqlens,
# chunk_indices,
# scale,
# T,
# num_householder: tl.constexpr,
# H: tl.constexpr,
# K: tl.constexpr,
# V: tl.constexpr,
# BT: tl.constexpr,
# BK: tl.constexpr,
# BV: tl.constexpr,
# USE_G: tl.constexpr,
# IS_VARLEN: tl.constexpr,
# ):
# # same parameters as forward pass
# i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
# i_b, i_h = i_bh // H, i_bh % H

# if IS_VARLEN:
# i_tg = i_t
# i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
# bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
# T = eos - bos
# NT = tl.cdiv(T, BT)
# else:
# NT = tl.cdiv(T, BT)
# i_tg = i_b * NT + i_t
# bos, eos = i_b * T, i_b * T + T

# # offset calculation
# q += (bos * H + i_h) * K
# k += (bos * H + i_h) * K
# v += (bos * H + i_h) * V
# do += (bos * H + i_h) * V
# dq += (bos * H + i_h) * K
# dk += (bos * H + i_h) * K
# dv += (bos * H + i_h) * V
# h += (i_tg * H + i_h).to(tl.int64) * K*V
# dh += (i_tg * H + i_h).to(tl.int64) * K*V

# b_dq = tl.zeros([BT, BK], dtype=tl.float32)
# b_dk = tl.zeros([BT, BK], dtype=tl.float32)
# b_dv = tl.zeros([BT, BV], dtype=tl.float32)
# b_ds = tl.zeros([BT, BT], dtype=tl.float32)

# # Compute gradients from hidden state
# for i_k in range(tl.cdiv(K, BK)):
# p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
# p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
# p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
# p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

# # [BT, BK]
# b_q = tl.load(p_q, boundary_check=(0, 1))
# # [BK, BV]
# b_h = tl.load(p_h, boundary_check=(0, 1))
# b_dh = tl.load(p_dh, boundary_check=(0, 1))
# # [BT, BV]
# b_do = tl.load(p_do, boundary_check=(0, 1))

# # Compute gradients w.r.t. q: dq += do @ h^T
# b_dq += tl.dot(b_do, tl.trans(b_h))

# # Compute gradients w.r.t. h: dh += q^T @ do
# tl.store(p_dh, (b_dh + tl.dot(tl.trans(b_q), b_do)).to(p_dh.dtype.element_ty), boundary_check=(0, 1))

# # Process multiple Householder transformations
# for i_dp in range(num_householder):
# b_A = tl.zeros([BT, BT], dtype=tl.float32)

# # Compute attention matrix A = Q @ K^T for this Householder step
# for i_k in range(tl.cdiv(K, BK)):
# p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
# p_k = tl.make_block_ptr(k+i_dp*H*K, (K, T), (1, num_householder*H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))

# b_q = tl.load(p_q, boundary_check=(0, 1))
# b_k = tl.load(p_k, boundary_check=(0, 1))
# b_A += tl.dot(b_q, b_k)

# # Apply causal mask and gating
# o_t = i_t * BT + tl.arange(0, BT)
# m_t = o_t < T
# if USE_G:
# g += bos * H + i_h
# p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
# b_g = tl.load(p_g, boundary_check=(0,))
# m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
# b_A = tl.where(m_A, b_A * exp(b_g[:, None] - b_g[None, :]), 0)
# else:
# m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
# b_A = tl.where(m_A, b_A, 0)

# # Load values for this Householder step
# p_v = tl.make_block_ptr(v+i_dp*H*V, (T, V), (H*V*num_householder, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# b_v = tl.load(p_v, boundary_check=(0, 1))
# b_do = tl.load(p_do, boundary_check=(0, 1))

# # Gradient w.r.t. values: dv += A^T @ do
# b_dv += tl.dot(tl.trans(b_A.to(b_v.dtype)), b_do)

# # Gradient w.r.t. attention scores: ds = do @ v^T
# b_ds += tl.dot(b_do, tl.trans(b_v))

# # Apply scale and gating to score gradients
# b_ds = b_ds * scale
# if USE_G:
# b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0)
# else:
# b_ds = tl.where(m_A, b_ds, 0)

# # Compute final gradients for each Householder step
# for i_dp in range(num_householder):
# for i_k in range(tl.cdiv(K, BK)):
# p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
# p_k = tl.make_block_ptr(k+i_dp*H*K, (K, T), (1, num_householder*H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
# p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
# p_dk = tl.make_block_ptr(dk+i_dp*H*K, (K, T), (1, num_householder*H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))

# b_q = tl.load(p_q, boundary_check=(0, 1))
# b_k = tl.load(p_k, boundary_check=(0, 1))

# # dq += ds @ k^T
# b_dq += tl.dot(b_ds, tl.trans(b_k))
# # dk for this (i_dp, i_k) block: q^T @ ds
# b_dk = tl.dot(tl.trans(b_q), b_ds)

# tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
# tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

# # Store value gradients
# for i_dp in range(num_householder):
# p_dv = tl.make_block_ptr(dv+i_dp*H*V, (T, V), (H*V*num_householder, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))


# def chunk_gated_delta_product_bwd_o(
# q: torch.Tensor,
# k: torch.Tensor,
# v: torch.Tensor,
# h: torch.Tensor,
# g: Optional[torch.Tensor] = None,
# do: torch.Tensor = None,
# scale: Optional[float] = None,
# cu_seqlens: Optional[torch.LongTensor] = None,
# chunk_size: int = 64,
# num_householder: int = 1,
# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# assert q.shape[1] * num_householder == k.shape[1], "q.shape[1] * num_householder must be equal to k.shape[1]"
# B, T, H, K, V = *q.shape, v.shape[-1]
# BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
# chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
# NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

# dq = torch.zeros_like(q)
# dk = torch.zeros_like(k)
# dv = torch.zeros_like(v)
# dh = torch.zeros_like(h)

# def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
# chunk_gated_delta_product_bwd_kernel_o[grid](
# q,
# k,
# v,
# h,
# g,
# do,
# dq,
# dk,
# dv,
# dh,
# cu_seqlens,
# chunk_indices,
# scale,
# T=T,
# num_householder=num_householder,
# H=H,
# K=K,
# V=V,
# BT=BT,
# )
# return dq, dk, dv, dh
coderabbitai bot commented on lines +156 to +363 (Jul 14, 2025):

🛠️ Refactor suggestion

Consider removing or documenting the purpose of the commented backward kernel code.

The PR objectives mention adding a "backward kernel skeleton", but having 200+ lines of commented code without clear documentation about its purpose or timeline for implementation can lead to confusion and maintenance issues. Consider:

  1. If this is a work-in-progress, add a clear TODO comment explaining the implementation timeline
  2. If this is example/reference code, move it to documentation
  3. If this is no longer needed, remove it entirely
🤖 Prompt for AI Agents
In fla/ops/gated_delta_product/chunk_deltaproduct_o.py between lines 156 and
363, there is over 200 lines of commented-out backward kernel code without
explanation. To fix this, either add a clear TODO comment at the top explaining
that this is a work-in-progress and when it will be implemented, or if it is
reference/example code, move it to separate documentation files. If the code is
obsolete or no longer needed, remove it entirely to avoid confusion and improve
maintainability.
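
One concrete way to address point 1 (owner and wording here are placeholders, not from the PR) is a short status header at the top of the commented-out block:

```python
# TODO(WIP): skeleton of the backward output kernel for chunked gated
# DeltaProduct. Kept commented out until dq/dk/dv accumulation across
# K-blocks and Householder steps is validated against the autograd
# reference; see the discussion on this PR for status.
```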

Member:
@phi-jkim as suggested by coderabbitai

coderabbitai bot:
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

8 changes: 4 additions & 4 deletions fla/ops/simple_gla/README.md
@@ -1,10 +1,10 @@
# Simple GLA

Gating mechanism in [Gated RFA](https://arxiv.org/abs/2103.02143), [Mamba2](https://arxiv.org/abs/2405.21060) and [YOCO](https://arxiv.org/abs/2405.05254) (a.k.a., Gated RetNet).

Compared to GLA, the gating is head-wise instead of elementwise.
As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability.
It is faster than GLA but has less expressive power.
I will use it as a baseline for the GLA.

$S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar.
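
A minimal dense-recurrence sketch of this update (the `[T, H, ...]` layouts and plain multiplicative gates are assumptions for illustration; the actual kernels work chunk-wise and keep gates in log space):

```python
import torch

def simple_gla_recurrence(q, k, v, g):
    # q, k: [T, H, K]; v: [T, H, V]; g: [T, H], one scalar gate per head/step
    T, H, _ = k.shape
    S = k.new_zeros(H, k.shape[-1], v.shape[-1])  # state S_t, one [K, V] per head
    o = torch.empty_like(v)
    for t in range(T):
        # S_{t+1} = g_{t+1} * S_t + K_{t+1} V_{t+1}^T  (per head)
        S = g[t, :, None, None] * S + torch.einsum('hk,hv->hkv', k[t], v[t])
        # readout: o_t = q_t S_t
        o[t] = torch.einsum('hk,hkv->hv', q[t], S)
    return o
```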