Delta Product Rule Backwards Kernel #526
base: main
Conversation
Walkthrough
This update introduces a new explicit backward function and kernel for the chunked gated delta product operation, enhancing gradient computation with modular and detailed logic. It also adds support for a "num_householder" parameter in relevant kernels and functions. Additional changes include minor import reorderings, whitespace adjustments, and the addition of newlines in configuration files.
Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Model
    participant BackwardFunction
    participant BackwardKernel
    User->>Model: Forward pass (chunked gated delta product)
    Model->>BackwardFunction: Save intermediates for backward
    User->>Model: Backward pass (gradients)
    Model->>BackwardFunction: Recompute intermediates
    BackwardFunction->>BackwardKernel: Launch backward kernel with parameters (incl. num_householder)
    BackwardKernel-->>BackwardFunction: Compute gradients dq, dk, dv, dh, etc.
    BackwardFunction-->>Model: Return gradients
Actionable comments posted: 2
🧹 Nitpick comments (4)
fla/ops/simple_gla/README.md (1)
3-7: Trailing-whitespace trimmed – nice cleanup.
Consider adding an editorconfig or pre-commit hook to keep whitespace tidy automatically.
evals/harness.py (1)
14-18: Nit: incorrect return-type annotation on `__init__`.
`__init__` should return `None`, not the class itself. While not harmful at runtime, type-checkers will flag it.

- def __init__(self, **kwargs) -> FlashLinearAttentionLMWrapper:
+ def __init__(self, **kwargs) -> None:

legacy/training/flame/parser.py (1)
10-17: Shadowing the imported `TrainingArguments` is intentional but can confuse IDEs.
Redefining `TrainingArguments` via subclassing is common, yet some static-analysis tools raise warnings. Consider renaming the custom class (e.g., `GLATrainingArguments`) to avoid confusion.

fla/ops/gated_delta_product/chunk.py (1)
427-438: Consider caching intermediate values to avoid recomputation.
The backward pass recomputes `h` and `v_new`, which were already computed in the forward pass. While the comment asks "why can't we store this originally in ctx", storing these values would trade memory for computation time. Consider:

- Analyzing the memory vs. computation trade-off
- Adding a flag to optionally cache these values for users who prioritize speed over memory (a sketch follows below)
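To make the second bullet concrete, here is a minimal sketch of an opt-in cache flag on the autograd Function. The class name, the `_forward_impl` stand-in, and the `cache_intermediates` flag are all hypothetical; this only illustrates the `ctx.save_for_backward` trade-off, not the actual kernels in `chunk.py`.

```python
import torch

def _forward_impl(q, k, v):
    # toy stand-in for the real chunked forward kernels; returns (h, v_new, o)
    h = k.transpose(-1, -2) @ v
    v_new = v - k @ h
    o = q @ h
    return h, v_new, o

class ChunkGatedDeltaProduct(torch.autograd.Function):  # hypothetical wrapper
    @staticmethod
    def forward(ctx, q, k, v, cache_intermediates=False):
        h, v_new, o = _forward_impl(q, k, v)
        ctx.cache_intermediates = cache_intermediates
        if cache_intermediates:
            # trade memory for speed: keep h / v_new instead of recomputing them
            ctx.save_for_backward(q, k, v, h, v_new)
        else:
            ctx.save_for_backward(q, k, v)
        return o

    @staticmethod
    def backward(ctx, do):
        if ctx.cache_intermediates:
            q, k, v, h, v_new = ctx.saved_tensors
        else:
            q, k, v = ctx.saved_tensors
            h, v_new, _ = _forward_impl(q, k, v)  # recompute, as the current code does
        dq = do @ h.transpose(-1, -2)
        # ...remaining gradients would follow the real backward kernels
        return dq, None, None, None
```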
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
- evals/harness.py (1 hunks)
- fla/ops/gated_delta_product/chunk.py (2 hunks)
- fla/ops/gated_delta_product/chunk_deltaproduct_h.py (11 hunks)
- fla/ops/gated_delta_product/chunk_deltaproduct_o.py (1 hunks)
- fla/ops/simple_gla/README.md (1 hunks)
- legacy/training/README.md (6 hunks)
- legacy/training/configs/gla_1B.json (1 hunks)
- legacy/training/configs/gla_340M.json (1 hunks)
- legacy/training/configs/gla_7B.json (1 hunks)
- legacy/training/configs/transformer_340M.json (1 hunks)
- legacy/training/flame/logging.py (1 hunks)
- legacy/training/flame/parser.py (1 hunks)
- legacy/training/run.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
legacy/training/README.md (1)
Learnt from: toothacher17
PR: fla-org/flash-linear-attention#395
File: fla/layers/mla.py:0-0
Timestamp: 2025-05-13T06:04:24.342Z
Learning: In Flash Attention's `flash_attn_varlen_func`, the scaling factor (typically computed as head_dim ** -0.5) is handled internally by the Flash Attention implementation rather than needing to be applied explicitly to query projections.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (13)
legacy/training/configs/gla_7B.json (1)
28-29: No-op change – newline only.
Nothing to flag. 👍
legacy/training/configs/gla_340M.json (1)
24-25: No-op change – newline only.
Formatting consistency looks good.
legacy/training/configs/transformer_340M.json (1)
18-19: No-op change – newline only.
All good here.
legacy/training/configs/gla_1B.json (1)
25-26: No-op change – newline only.
LGTM.
legacy/training/README.md (1)
10-11: Pure formatting change – no action required.
Only whitespace/markdown-layout tweaks were made; content and references remain intact. No further review comments.
Also applies to: 38-40, 106-110, 114-115, 141-145, 174-176
evals/harness.py (1)
9-10: Confirm side-effect import order still works.
`import fla  # noqa` is kept for its registration side-effects. By moving it below the Hugging Face imports you’ve avoided circular-import risk, but please double-check that any `@register_model` decorators inside `fla` still execute before `cli_evaluate()` introspects the registry.

@@
-import fla  # noqa
+import fla  # noqa  # side-effects: registers FLA model classes

If unit tests cover `lm_eval` model loading this is already verified; otherwise consider a quick smoke test.
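If a quick check is wanted, something like the sketch below could serve as that smoke test. It assumes lm_eval exposes `MODEL_REGISTRY` in `lm_eval.api.registry` and that the wrapper registers under a name containing "fla" — both assumptions to verify against the actual harness.

```python
import runpy

from lm_eval.api.registry import MODEL_REGISTRY

# Executing the script (not as __main__) runs the @register_model decorator
# without kicking off an evaluation; the path is the repo-relative location.
runpy.run_path("evals/harness.py")

fla_models = [name for name in MODEL_REGISTRY if "fla" in name.lower()]
assert fla_models, f"no FLA wrapper registered; registry has: {sorted(MODEL_REGISTRY)}"
print("registered:", fla_models)
```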
legacy/training/flame/logging.py (1)
9-10: LGTM – import consolidation only.
Switching to a single-line import has no functional impact and passes black/isort settings.
legacy/training/run.py (1)
7-10: Verify that delayed `fla` import does not defer monkey-patches needed before model instantiation.
Some `fla` sub-modules patch HF layers at import time. Because `AutoModelForCausalLM.from_pretrained` is now executed before any `fla` code runs, ensure no critical monkey-patch is required for forward/backward compatibility. If patches are essential, move the import back above the HF call or perform an explicit `fla.patch_all()` right after importing.
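For illustration, the safer ordering described above would look roughly like this; the checkpoint path is a placeholder, and this is a sketch rather than the actual `legacy/training/run.py`.

```python
import fla  # noqa: F401  # run registration/patching side effects first

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "path/or/hub-id-of-an-fla-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
```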
fla/ops/gated_delta_product/chunk_deltaproduct_h.py (3)
250-252: Good fix for chunk offset calculation consistency.
The update to the `boh` calculation to include `num_householder` correctly aligns the backward kernel with the forward kernel's chunking logic.
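As a rough illustration of why the offset must account for `num_householder` (assumed layout: Householder steps are interleaved along the time axis, so each sequence contributes `T * num_householder` rows; the kernel's exact indexing may differ):

```python
import math

T, BT, num_householder = 512, 64, 2
chunks_per_seq = math.ceil(T * num_householder / BT)  # 16 chunks, not 8
boh_seq3 = 3 * chunks_per_seq                         # chunk offset of the 4th sequence
print(chunks_per_seq, boh_seq3)
```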
319-319: Correct variable naming fix.
Good catch on fixing the variable names from `p_wo`/`b_wo` to `p_do`/`b_do` to accurately represent the gradient of the output.
Also applies to: 322-322
475-475: Verify divisibility assertion is appropriate.
The assertion `T % num_householder == 0` enforces strict divisibility. Ensure this requirement is documented and validated at higher levels to provide better error messages to users.
Also applies to: 485-485
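A higher-level check could look like the sketch below; the function name and error wording are suggestions, not existing code.

```python
def validate_householder_layout(T: int, num_householder: int) -> None:
    # friendlier error than a bare kernel-side assert on T % num_householder == 0
    if num_householder < 1:
        raise ValueError(f"num_householder must be >= 1, got {num_householder}")
    if T % num_householder != 0:
        raise ValueError(
            f"sequence length T={T} must be divisible by num_householder={num_householder}; "
            "pad or trim the interleaved inputs before calling the backward kernels"
        )
```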
fla/ops/gated_delta_product/chunk.py (2)
249-251: Add clarification for the TODO comment.
The TODO comment "implement delta product version of chunk_bwd_dqkwg" suggests using a non-optimized version. Please clarify:

- What specific optimizations are missing?
- Is there a performance impact?
- Is there a timeline for implementing the optimized version?
292-293: Good implementation of gradient accumulation.
The gradient accumulation using in-place addition and the clear comment explaining the multivariate chain rule make the code easier to understand and maintain.
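For readers less familiar with the pattern, a toy PyTorch check of the chain-rule accumulation referred to here (illustrative only, unrelated to the actual kernels):

```python
import torch

k = torch.randn(4, 8, requires_grad=True)
out = (k * 2).sum() + (k ** 2).sum()      # k feeds two branches of the graph
out.backward()
manual = 2 * torch.ones_like(k) + 2 * k   # dk = dk_branch1 + dk_branch2
assert torch.allclose(k.grad, manual)
```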
# if USE_G:
# g += bos * H + i_h
# p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
# b_g = tl.load(p_g, boundary_check=(0,))
# m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
# b_A = tl.where(m_A, b_A * exp(b_g[:, None] - b_g[None, :]), 0)
# else:
# m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
# b_A = tl.where(m_A, b_A, 0)

# # Load values for this Householder step
# p_v = tl.make_block_ptr(v+i_dp*H*V, (T, V), (H*V*num_householder, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# b_v = tl.load(p_v, boundary_check=(0, 1))
# b_do = tl.load(p_do, boundary_check=(0, 1))

# # Gradient w.r.t. values: dv += A^T @ do
# b_dv += tl.dot(tl.trans(b_A.to(b_v.dtype)), b_do)

# # Gradient w.r.t. attention scores: ds = do @ v^T
# b_ds += tl.dot(b_do, tl.trans(b_v))

# # Apply scale and gating to score gradients
# b_ds = b_ds * scale
# if USE_G:
# b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0)
# else:
# b_ds = tl.where(m_A, b_ds, 0)

# # Compute final gradients for each Householder step
# for i_dp in range(num_householder):
# for i_k in range(tl.cdiv(K, BK)):
# p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
# p_k = tl.make_block_ptr(k+i_dp*H*K, (K, T), (1, num_householder*H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
# p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
# p_dk = tl.make_block_ptr(dk+i_dp*H*K, (K, T), (1, num_householder*H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))

# b_q = tl.load(p_q, boundary_check=(0, 1))
# b_k = tl.load(p_k, boundary_check=(0, 1))

# # dq += ds @ k^T
# b_dq += tl.dot(b_ds, tl.trans(b_k))
# # dk += q^T @ ds
# b_dk = tl.dot(tl.trans(b_q), b_ds)

# tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
# tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

# # Store value gradients
# for i_dp in range(num_householder):
# p_dv = tl.make_block_ptr(dv+i_dp*H*V, (T, V), (H*V*num_householder, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

# def chunk_gated_delta_product_bwd_o(
# q: torch.Tensor,
# k: torch.Tensor,
# v: torch.Tensor,
# h: torch.Tensor,
# g: Optional[torch.Tensor] = None,
# do: torch.Tensor = None,
# scale: Optional[float] = None,
# cu_seqlens: Optional[torch.LongTensor] = None,
# chunk_size: int = 64,
# num_householder: int = 1,
# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# assert q.shape[1] * num_householder == k.shape[1], "q.shape[1] * num_householder must be equal to k.shape[1]"
# B, T, H, K, V = *q.shape, v.shape[-1]
# BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
# chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
# NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

# dq = torch.zeros_like(q)
# dk = torch.zeros_like(k)
# dv = torch.zeros_like(v)
# dh = torch.zeros_like(h)

# def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
# chunk_gated_delta_product_bwd_kernel_o[grid](
# q,
# k,
# v,
# h,
# g,
# do,
# dq,
# dk,
# dv,
# dh,
# cu_seqlens,
# chunk_indices,
# scale,
# T=T,
# num_householder=num_householder,
# H=H,
# K=K,
# V=V,
# BT=BT,
# )
# return dq, dk, dv, dh
🛠️ Refactor suggestion
Consider removing or documenting the purpose of the commented backward kernel code.
The PR objectives mention adding a "backward kernel skeleton", but having 200+ lines of commented code without clear documentation about its purpose or timeline for implementation can lead to confusion and maintenance issues. Consider:
- If this is a work-in-progress, add a clear TODO comment explaining the implementation timeline
- If this is example/reference code, move it to documentation
- If this is no longer needed, remove it entirely
🤖 Prompt for AI Agents
In fla/ops/gated_delta_product/chunk_deltaproduct_o.py between lines 156 and
363, there is over 200 lines of commented-out backward kernel code without
explanation. To fix this, either add a clear TODO comment at the top explaining
that this is a work-in-progress and when it will be implemented, or if it is
reference/example code, move it to separate documentation files. If the code is
obsolete or no longer needed, remove it entirely to avoid confusion and improve
maintainability.
@phi-jkim as suggested by coderabbitai
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!
# from fla.ops.gated_delta_product.chunk_deltaproduct_h import chunk_gated_delta_product_bwd_dhu
# from fla.ops.gated_delta_product.chunk_deltaproduct_o import chunk_gated_delta_product_bwd_o

# cu_seqlens_dp = cu_seqlens * num_householder if cu_seqlens is not None else None

# # compute gradients wrt q, k, v
# # dv_new_o is gradient w.r.t. v_new
# # might not be fully parallelizable due to gradient Q including previous states H_0 to H_T
# dq_o, dv_new_o, dk_o, dh = chunk_gated_delta_product_bwd_o(
# q=q,
# k=k,
# v=v_new, # v_new = U[i] - W[i]H[i]^T
# h=h,
# g=g, # forward_h takes in g instead of g_interleaved
# scale=scale,
# cu_seqlens=cu_seqlens,
# num_householder=num_householder,
# do=do, # gradient of the output
# )

# # recompute w, u from WY representation to compute gradient
# from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd as gdn_recompute_w_u_fwd
# from fla.ops.delta_rule.wy_fast import recompute_w_u_fwd as dn_recompute_w_u_fwd

# if g_interleaved is not None:
# w, u = gdn_recompute_w_u_fwd(
# k=k, v=v, beta=beta, A=A, g=g_interleaved, cu_seqlens=cu_seqlens_dp,
# )
# else:
# w, u = dn_recompute_w_u_fwd(
# k=k, v=v, beta=beta, A=A, cu_seqlens=cu_seqlens_dp,
# )

# # compute gradients with respect to u and w
# # but need to account for gradients used for sequential computation of hidden states of H_0 to H_T (sequential)
# dh0, du, dw = chunk_gated_delta_product_bwd_dhu(
# q=q,
# k=k,
# w=w,
# u=u,
# g=g_interleaved,
# initial_state=initial_state, # H_0
# cu_seqlens=cu_seqlens_dp,
# num_householder=num_householder,
# dht=dht, # gradient w.r.t to last hidden state
# dv=dv_new_o, # gradient w.r.t. v_new
# scale=scale,
# )

# # compute gradients w.r.t. WY representation (dk, dv, dbeta, dg)
# # This involves computing gradients through the Householder transformations
# from fla.ops.gated_delta_rule.wy_fast import prepare_wy_repr_bwd

# # gradient descent from W and U
# # g is used for computing W and U in the forward pass
# # dk2 accounts for the gradient of hidden state wrt to k
# # dv_final is the gradient of hidden state wrt to v
# # this can be fully parallelized
# dk2, dv, dbeta, dg = prepare_wy_repr_bwd(
# k=k,
# v=v,
# beta=beta,
# g=g_interleaved,
# A=A,
# dw=dw, # Use key gradients from output as weights gradients
# du=du, # Use value gradients from hidden state backward
# cu_seqlens=cu_seqlens_dp,
# )

# # accumulate gradients
# # dk_final = dk_o + dk2
# # dk_final = dk2 # should there be (Q[i] K[i]^T \cdot M)
# dv_final = dv
# dg_final = dg

# # process gating gradients with local cumsum (reverse)
# if g is not None:
# from fla.ops.utils import chunk_local_cumsum
# dg_final = chunk_local_cumsum(dg_final, chunk_size=64, reverse=True, cu_seqlens=cu_seqlens_dp)

# # Convert interleaved gating gradients back to original format
# dg_final = rearrange(dg_final, 'b (l n) h -> b l n h', n=num_householder)[:, :, 0].contiguous()
# else:
# dg_final = None

# return dq_o, dk, dv_final, dg_final, dbeta, dh0
🛠️ Refactor suggestion
Remove commented alternative implementation.
The function contains 80+ lines of commented code that appears to be an alternative implementation approach. This makes the code harder to read and maintain. If this alternative approach is needed for reference, consider moving it to documentation or a separate development branch.
🤖 Prompt for AI Agents
In fla/ops/gated_delta_product/chunk.py between lines 111 and 197, there is a
large block of commented-out code representing an alternative implementation.
This commented code should be removed to improve code readability and
maintainability. If the alternative implementation needs to be preserved for
reference, move it to documentation or a separate development branch instead of
keeping it commented in the main code.
@phi-jkim very nice job! Could you please revert the changes to the legacy code and the simple GLA README, which I think are unrelated to this PR? Feel free to create another one if necessary.
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM

import fla  # noqa
We should format this part somewhere else rather than in this PR.
@phi-jkim Seems that the current GDP code can't pass our modeling & generation tests in the H100 CIs.
[Gated Delta Product] Optimize kernels and implement backward pass
[Core] Add comprehensive backward pass implementation for gated delta product rule
[Kernels] Enhance chunk_deltaproduct_o.py with backward kernel infrastructure
[Kernels] Extend chunk_deltaproduct_h.py with hidden state backward functionality
Summary by CodeRabbit
New Features
Documentation
Style
No changes to user-facing interfaces or configuration semantics were made.