Delta Product Rule Backwards Kernel #526
base: main
@@ -152,3 +152,212 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
        BT=BT,
    )
    return o


# @triton.heuristics({
#     'USE_G': lambda args: args['g'] is not None,
#     'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
# })
# @triton.autotune(
#     configs=[
#         triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
#         for BK in BKV_LIST
#         for BV in BKV_LIST
#         for num_warps in NUM_WARPS
#         for num_stages in [2, 3, 4]
#     ],
#     key=['H', 'K', 'V', 'BT'],
# )
# @triton.jit(do_not_specialize=['T'])
# def chunk_gated_delta_product_bwd_kernel_o(
#     q,
#     k,
#     v,
#     h,
#     g,
#     do,
#     dq,
#     dk,
#     dv,
#     dh,
#     cu_seqlens,
#     chunk_indices,
#     scale,
#     T,
#     num_householder: tl.constexpr,
#     H: tl.constexpr,
#     K: tl.constexpr,
#     V: tl.constexpr,
#     BT: tl.constexpr,
#     BK: tl.constexpr,
#     BV: tl.constexpr,
#     USE_G: tl.constexpr,
#     IS_VARLEN: tl.constexpr,
# ):
#     # same parameters as forward pass
#     i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
#     i_b, i_h = i_bh // H, i_bh % H
#
#     if IS_VARLEN:
#         i_tg = i_t
#         i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
#         bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
#         T = eos - bos
#         NT = tl.cdiv(T, BT)
#     else:
#         NT = tl.cdiv(T, BT)
#         i_tg = i_b * NT + i_t
#         bos, eos = i_b * T, i_b * T + T
#
#     # offset calculation
#     q += (bos * H + i_h) * K
#     k += (bos * H + i_h) * K
#     v += (bos * H + i_h) * V
#     do += (bos * H + i_h) * V
#     dq += (bos * H + i_h) * K
#     dk += (bos * H + i_h) * K
#     dv += (bos * H + i_h) * V
#     h += (i_tg * H + i_h).to(tl.int64) * K*V
#     dh += (i_tg * H + i_h).to(tl.int64) * K*V
#
#     b_dq = tl.zeros([BT, BK], dtype=tl.float32)
#     b_dk = tl.zeros([BT, BK], dtype=tl.float32)
#     b_dv = tl.zeros([BT, BV], dtype=tl.float32)
#     b_ds = tl.zeros([BT, BT], dtype=tl.float32)
#
#     # Compute gradients from hidden state
#     for i_k in range(tl.cdiv(K, BK)):
#         p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
#         p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
#         p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
#         p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
#
#         # [BT, BK]
#         b_q = tl.load(p_q, boundary_check=(0, 1))
#         # [BK, BV]
#         b_h = tl.load(p_h, boundary_check=(0, 1))
#         b_dh = tl.load(p_dh, boundary_check=(0, 1))
#         # [BT, BV]
#         b_do = tl.load(p_do, boundary_check=(0, 1))
#
#         # Compute gradients w.r.t. q: dq += do @ h^T
#         b_dq += tl.dot(b_do, tl.trans(b_h))
#
#         # Compute gradients w.r.t. h: dh += q^T @ do
#         tl.store(p_dh, (b_dh + tl.dot(tl.trans(b_q), b_do)).to(p_dh.dtype.element_ty), boundary_check=(0, 1))
#
#     # Process multiple Householder transformations
#     for i_dp in range(num_householder):
#         b_A = tl.zeros([BT, BT], dtype=tl.float32)
#
#         # Compute attention matrix A = Q @ K^T for this Householder step
#         for i_k in range(tl.cdiv(K, BK)):
#             p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
#             p_k = tl.make_block_ptr(k+i_dp*H*K, (K, T), (1, num_householder*H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
#
#             b_q = tl.load(p_q, boundary_check=(0, 1))
#             b_k = tl.load(p_k, boundary_check=(0, 1))
#             b_A += tl.dot(b_q, b_k)
#
#         # Apply causal mask and gating
#         o_t = i_t * BT + tl.arange(0, BT)
#         m_t = o_t < T
#         if USE_G:
#             g += bos * H + i_h
#             p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
#             b_g = tl.load(p_g, boundary_check=(0,))
#             m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
#             b_A = tl.where(m_A, b_A * exp(b_g[:, None] - b_g[None, :]), 0)
#         else:
#             m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
#             b_A = tl.where(m_A, b_A, 0)
#
#         # Load values for this Householder step
#         p_v = tl.make_block_ptr(v+i_dp*H*V, (T, V), (H*V*num_householder, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
#         p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
#         b_v = tl.load(p_v, boundary_check=(0, 1))
#         b_do = tl.load(p_do, boundary_check=(0, 1))
#
#         # Gradient w.r.t. values: dv += A^T @ do
#         b_dv += tl.dot(tl.trans(b_A.to(b_v.dtype)), b_do)
#
#         # Gradient w.r.t. attention scores: ds = do @ v^T
#         b_ds += tl.dot(b_do, tl.trans(b_v))
#
#     # Apply scale and gating to score gradients
#     b_ds = b_ds * scale
#     if USE_G:
#         b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0)
#     else:
#         b_ds = tl.where(m_A, b_ds, 0)
#
#     # Compute final gradients for each Householder step
#     for i_dp in range(num_householder):
#         for i_k in range(tl.cdiv(K, BK)):
#             p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
#             p_k = tl.make_block_ptr(k+i_dp*H*K, (K, T), (1, num_householder*H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
#             p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
#             p_dk = tl.make_block_ptr(dk+i_dp*H*K, (K, T), (1, num_householder*H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
#
#             b_q = tl.load(p_q, boundary_check=(0, 1))
#             b_k = tl.load(p_k, boundary_check=(0, 1))
#
#             # dq += ds @ k^T
#             b_dq += tl.dot(b_ds, tl.trans(b_k))
#             # dk += q^T @ ds
#             b_dk = tl.dot(tl.trans(b_q), b_ds)
#
#             tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
#             tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
#
#     # Store value gradients
#     for i_dp in range(num_householder):
#         p_dv = tl.make_block_ptr(dv+i_dp*H*V, (T, V), (H*V*num_householder, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
#         tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))


# def chunk_gated_delta_product_bwd_o(
#     q: torch.Tensor,
#     k: torch.Tensor,
#     v: torch.Tensor,
#     h: torch.Tensor,
#     g: Optional[torch.Tensor] = None,
#     do: torch.Tensor = None,
#     scale: Optional[float] = None,
#     cu_seqlens: Optional[torch.LongTensor] = None,
#     chunk_size: int = 64,
#     num_householder: int = 1,
# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
#     assert q.shape[1] * num_householder == k.shape[1], "q.shape[1] * num_householder must be equal to k.shape[1]"
#     B, T, H, K, V = *q.shape, v.shape[-1]
#     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
#     chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
#     NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
#
#     dq = torch.zeros_like(q)
#     dk = torch.zeros_like(k)
#     dv = torch.zeros_like(v)
#     dh = torch.zeros_like(h)
#
#     def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
#     chunk_gated_delta_product_bwd_kernel_o[grid](
#         q,
#         k,
#         v,
#         h,
#         g,
#         do,
#         dq,
#         dk,
#         dv,
#         dh,
#         cu_seqlens,
#         chunk_indices,
#         scale,
#         T=T,
#         num_householder=num_householder,
#         H=H,
#         K=K,
#         V=V,
#         BT=BT,
#     )
#     return dq, dk, dv, dh
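For orientation, the block above reads as a chunked backward pass through the hidden state and the intra-chunk attention scores. Below is a minimal, hypothetical single-chunk PyTorch reference (one head, no gating, no extra Householder factors); the forward convention it assumes — where `scale` is applied and how the causal mask enters — is an assumption for illustration, not necessarily this repository's exact convention.

```python
import torch


def naive_chunk_bwd_o(q, k, v, h, do, scale):
    """Hypothetical single-chunk reference for the gradients the kernel targets.

    Assumed forward (one head, one chunk, no gating, no Householder products):
        A = causal_mask(q @ k^T * scale)
        o = (q * scale) @ h + A @ v
    Gradients are obtained via autograd, mirroring the relations sketched in
    the kernel: dq += do @ h^T, dh += q^T @ do, dv += A^T @ do, ds = do @ v^T.

    Shapes: q, k [BT, K]; v, do [BT, V]; h [K, V]; scale is a float.
    """
    # work on detached leaf copies so .grad is populated on them
    q, k, v, h = [t.detach().clone().requires_grad_(True) for t in (q, k, v, h)]
    BT = q.shape[0]
    # lower-triangular causal mask within the chunk
    mask = torch.tril(torch.ones(BT, BT, dtype=torch.bool, device=q.device))
    A = torch.where(mask, (q @ k.transpose(-1, -2)) * scale, q.new_zeros(BT, BT))
    o = (q * scale) @ h + A @ v
    o.backward(do)
    return q.grad, k.grad, v.grad, h.grad
```

Comparing these reference gradients against the kernel's outputs with `torch.testing.assert_close` would be one way to validate the skeleton once it is uncommented.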
Comment on lines +156 to +363
🛠️ Refactor suggestion

Consider removing or documenting the purpose of the commented backward kernel code. The PR objectives mention adding a "backward kernel skeleton", but 200+ lines of commented code without clear documentation of its purpose or a timeline for implementation can lead to confusion and maintenance issues.
@phi-jkim as suggested by coderabbitai
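One lightweight way to act on this suggestion, if the skeleton should stay in the tree for now, is to replace the commented block with a small documented stub so its status is explicit. The sketch below reuses the signature of the commented wrapper; the docstring and error message are illustrative only, not part of this PR:

```python
from typing import Optional, Tuple

import torch


def chunk_gated_delta_product_bwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    do: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
    num_householder: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Backward pass w.r.t. the chunked output, returning (dq, dk, dv, dh).

    Skeleton only: the Triton kernel is still under development in this PR.
    """
    raise NotImplementedError(
        "chunk_gated_delta_product_bwd_o is a skeleton; see PR #526."
    )
```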
@@ -1,10 +1,10 @@
# Simple GLA

Gating mechanism in [Gated RFA](https://arxiv.org/abs/2103.02143), [Mamba2](https://arxiv.org/abs/2405.21060) and [YOCO](https://arxiv.org/abs/2405.05254) (a.k.a., Gated RetNet).

Compared to GLA, the gating is head-wise instead of elementwise.
As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability.
It is faster than GLA but has less expressive power.
I will use it as a baseline for the GLA.

$S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar.
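For illustration, the head-wise recurrence above can be written as a naive PyTorch loop. This is a sketch only; the tensor shapes and the function name are assumptions, not kernels from this repository:

```python
import torch


def simple_gla_recurrence(k, v, g):
    """Naive head-wise gated recurrence: S_{t+1} = g_{t+1} * S_t + k_{t+1} v_{t+1}^T.

    Assumed shapes: k [T, H, K], v [T, H, V], g [T, H] with scalar (per-head) gates.
    Returns the final state S of shape [H, K, V].
    """
    T, H, K = k.shape
    V = v.shape[-1]
    S = k.new_zeros(H, K, V)
    for t in range(T):
        # a single scalar gate per head decays that head's entire state
        S = g[t].view(H, 1, 1) * S + torch.einsum('hk,hv->hkv', k[t], v[t])
    return S
```

The chunked kernels avoid this O(T) Python loop, but the loop makes the role of the scalar (head-wise) gate explicit.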
We should format this part somewhere else rather than in this PR.