
Delta Product Rule Backwards Kernel #526


Open · wants to merge 3 commits into main

Conversation

@phi-jkim phi-jkim commented Jul 14, 2025

[Gated Delta Product] Optimize kernels and implement backward pass

[Core] Add comprehensive backward pass implementation for gated delta product rule

- Implement the chunk_gated_delta_product_bwd function with gradient computation for the q, k, v, and g parameters
- Add support for sequential hidden state gradient computation through H_0 to H_T
- Integrate WY representation gradient handling with Householder transformations

[Kernels] Enhance chunk_deltaproduct_o.py with backward kernel infrastructure

- Add backward kernel skeleton for output computation gradients

[Kernels] Extend chunk_deltaproduct_h.py with hidden state backward functionality

- Modify the kernel to take a num_householder parameter for multi-step updates per token (a naive reference of this recurrence is sketched below)
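
For reference, here is a naive PyTorch sketch of the recurrence this backward pass differentiates: per token, apply the gate once and then num_householder delta (Householder) updates to the state. The gating placement, state layout, and argument shapes are my assumptions, not the PR's code; a reference like this is only meant for tiny-shape gradient checks against the fused kernels.

```python
import torch

def naive_gated_delta_product(q, k, v, g, beta, num_householder, scale):
    """Naive reference: o_t = scale * q_t^T S_t, with S updated per token by a
    gate followed by num_householder steps of S <- (I - b k k^T) S + b k v^T."""
    B, T, H, K = q.shape                       # q: [B, T, H, K]
    V = v.shape[-1]                            # k, v, beta: [B, T*num_householder, H, ...]
    S = q.new_zeros(B, H, K, V)                # running state H_t
    outs = []
    for t in range(T):
        S = S * g[:, t].exp()[..., None, None]              # assumed per-token gating
        for i in range(num_householder):                     # delta sub-steps for token t
            kt = k[:, t * num_householder + i]                # [B, H, K]
            vt = v[:, t * num_householder + i]                # [B, H, V]
            bt = beta[:, t * num_householder + i][..., None, None]
            S = (S
                 - bt * kt.unsqueeze(-1) * (kt.unsqueeze(-2) @ S)   # -(b k k^T) S
                 + bt * kt.unsqueeze(-1) * vt.unsqueeze(-2))        # + b k v^T
        outs.append((scale * q[:, t].unsqueeze(-2) @ S).squeeze(-2))
    return torch.stack(outs, dim=1), S         # outputs [B, T, H, V] and final state
```

Running torch.autograd.gradcheck on a double-precision copy of such a reference against the fused op (tiny B, T, H, K, V) is one way to validate the dq/dk/dv/dg paths described above.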

Summary by CodeRabbit

  • New Features

    • Added advanced backward computation and modular gradient support for chunked gated delta product operations, improving differentiability and consistency for models using this mechanism.
    • Introduced support for Householder transformations in chunked backward operations, enhancing flexibility for sequence modeling.
  • Documentation

    • Minor formatting and whitespace improvements in various documentation and configuration files for better readability.
  • Style

    • Refactored and reordered import statements in several files to improve code clarity and maintainability.

No changes to user-facing interfaces or configuration semantics were made.

coderabbitai bot commented Jul 14, 2025

Walkthrough

This update introduces a new explicit backward function and kernel for the chunked gated delta product operation, enhancing gradient computation with modular and detailed logic. It also adds support for a "num_householder" parameter in relevant kernels and functions. Additional changes include minor import reorderings, whitespace adjustments, and the addition of newlines in configuration files.
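
As a small illustration of the interleaved layout assumed behind num_householder (not taken from the PR's code): with n Householder steps, token t owns rows t*n .. t*n + n - 1 of the key/value/beta tensors, and the 'b (l n) h -> b l n h' rearrange used for the gating gradients converts between the two views.

```python
import torch
from einops import rearrange

B, T, n, H, K = 2, 4, 3, 2, 8
k_interleaved = torch.randn(B, T * n, H, K)        # layout the kernels consume
k_per_step = rearrange(k_interleaved, 'b (t n) h k -> b t n h k', n=n)
# token 1, Householder step 0 lives at interleaved row 1*n + 0
assert torch.equal(k_per_step[:, 1, 0], k_interleaved[:, 1 * n + 0])
```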

Changes

| File(s) | Change Summary |
| --- | --- |
| fla/ops/gated_delta_product/chunk.py | Introduced a new explicit backward function (chunk_gated_delta_product_bwd) for the gated delta product operation, integrated it into the autograd function, and removed old backward-function imports. |
| fla/ops/gated_delta_product/chunk_deltaproduct_h.py | Added a num_householder parameter to the backward kernel and Python wrapper, adjusted chunking/indexing logic, and corrected variable naming for output gradients. |
| fla/ops/gated_delta_product/chunk_deltaproduct_o.py | Added a new Triton backward kernel and Python function to compute gradients for the chunked gated delta product with Householder transformations, supporting gating and variable-length sequences. |
| fla/ops/simple_gla/README.md, legacy/training/README.md | Removed trailing spaces and made minor whitespace/formatting adjustments in README files. |
| legacy/training/configs/gla_1B.json, legacy/training/configs/gla_340M.json, legacy/training/configs/gla_7B.json, legacy/training/configs/transformer_340M.json | Added a newline character at the end of each JSON configuration file. |
| evals/harness.py, legacy/training/flame/logging.py, legacy/training/flame/parser.py, legacy/training/run.py | Reordered or consolidated import statements; no changes to logic or exported entities. |

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Model
    participant BackwardFunction
    participant BackwardKernel

    User->>Model: Forward pass (chunked gated delta product)
    Model->>BackwardFunction: Save intermediates for backward
    User->>Model: Backward pass (gradients)
    Model->>BackwardFunction: Recompute intermediates
    BackwardFunction->>BackwardKernel: Launch backward kernel with parameters (incl. num_householder)
    BackwardKernel-->>BackwardFunction: Compute gradients dq, dk, dv, dh, etc.
    BackwardFunction-->>Model: Return gradients

Suggested reviewers

  • yzhangcs

Poem

In the warren of code, a new path unfurled,
With gradients explicit, and Householders twirled.
Kernels now backward, chunked with great care,
Imports aligned, and configs with air.
A hop and a skip—this merge is complete!
🐇✨


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (4)
fla/ops/simple_gla/README.md (1)

3-7: Trailing-whitespace trimmed – nice cleanup.

Consider adding an editorconfig or pre-commit hook to keep whitespace tidy automatically.

evals/harness.py (1)

14-18: Nit: incorrect return-type annotation on __init__.

__init__ should return None, not the class itself. While not harmful at runtime, type-checkers will flag it.

-    def __init__(self, **kwargs) -> FlashLinearAttentionLMWrapper:
+    def __init__(self, **kwargs) -> None:
legacy/training/flame/parser.py (1)

10-17: Shadowing the imported TrainingArguments is intentional but can confuse IDEs.

Redefining TrainingArguments via subclassing is common, yet some static-analysis tools raise warnings.
Consider renaming the custom class (e.g., GLATrainingArguments) to avoid confusion.
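
A minimal sketch of the rename (the extra field is a placeholder, not what flame/parser.py actually defines):

```python
from dataclasses import dataclass, field
from transformers import TrainingArguments

@dataclass
class GLATrainingArguments(TrainingArguments):
    # project-specific options live on the subclass instead of shadowing the
    # imported TrainingArguments name
    context_length: int = field(default=2048, metadata={"help": "training context length"})
```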

fla/ops/gated_delta_product/chunk.py (1)

427-438: Consider caching intermediate values to avoid recomputation.

The backward pass recomputes h and v_new which were already computed in the forward pass. While the comment asks "why can't we store this originally in ctx", storing these values would trade memory for computation time. Consider:

  1. Analyzing the memory vs computation trade-off
  2. Adding a flag to optionally cache these values for users who prioritize speed over memory (see the sketch below)
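
A hedged sketch of option 2, using a toy op instead of the real kernels (the flag name and the tanh stand-in are illustrative only):

```python
import torch

class CachedOrRecomputed(torch.autograd.Function):
    """Toy autograd.Function showing a cache-vs-recompute flag."""

    @staticmethod
    def forward(ctx, x, cache_intermediates=False):
        h = torch.tanh(x)                    # stand-in for h / v_new
        ctx.cache_intermediates = cache_intermediates
        if cache_intermediates:
            ctx.save_for_backward(x, h)      # more memory, no recompute in backward
        else:
            ctx.save_for_backward(x)         # less memory, recompute in backward
        return 2.0 * h

    @staticmethod
    def backward(ctx, grad_out):
        if ctx.cache_intermediates:
            x, h = ctx.saved_tensors
        else:
            (x,) = ctx.saved_tensors
            h = torch.tanh(x)                # recompute the intermediate
        return grad_out * 2.0 * (1.0 - h * h), None   # d(2*tanh)/dx, no grad for the flag

# usage: CachedOrRecomputed.apply(torch.randn(4, requires_grad=True), True)
```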
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 91a73fe and 35a7acf.

📒 Files selected for processing (13)
  • evals/harness.py (1 hunks)
  • fla/ops/gated_delta_product/chunk.py (2 hunks)
  • fla/ops/gated_delta_product/chunk_deltaproduct_h.py (11 hunks)
  • fla/ops/gated_delta_product/chunk_deltaproduct_o.py (1 hunks)
  • fla/ops/simple_gla/README.md (1 hunks)
  • legacy/training/README.md (6 hunks)
  • legacy/training/configs/gla_1B.json (1 hunks)
  • legacy/training/configs/gla_340M.json (1 hunks)
  • legacy/training/configs/gla_7B.json (1 hunks)
  • legacy/training/configs/transformer_340M.json (1 hunks)
  • legacy/training/flame/logging.py (1 hunks)
  • legacy/training/flame/parser.py (1 hunks)
  • legacy/training/run.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
legacy/training/README.md (1)
Learnt from: toothacher17
PR: fla-org/flash-linear-attention#395
File: fla/layers/mla.py:0-0
Timestamp: 2025-05-13T06:04:24.342Z
Learning: In Flash Attention's `flash_attn_varlen_func`, the scaling factor (typically computed as head_dim ** -0.5) is handled internally by the Flash Attention implementation rather than needing to be applied explicitly to query projections.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (13)
legacy/training/configs/gla_7B.json (1)

28-29: No-op change – newline only.

Nothing to flag. 👍

legacy/training/configs/gla_340M.json (1)

24-25: No-op change – newline only.

Formatting consistency looks good.

legacy/training/configs/transformer_340M.json (1)

18-19: No-op change – newline only.

All good here.

legacy/training/configs/gla_1B.json (1)

25-26: No-op change – newline only.

LGTM.

legacy/training/README.md (1)

10-11: Pure formatting change – no action required.

Only whitespace/markdown-layout tweaks were made; content and references remain intact.
No further review comments.

Also applies to: 38-40, 106-110, 114-115, 141-145, 174-176

evals/harness.py (1)

9-10: Confirm side-effect import order still works.

import fla # noqa is kept for its registration side-effects.
By moving it below the Hugging Face imports you’ve avoided circular-import risk, but please double-check that any @register_model decorators inside fla still execute before cli_evaluate() introspects the registry.

@@
-import fla  # noqa
+import fla  # noqa  # side-effects: registers FLA model classes

If unit tests cover lm_eval model loading this is already verified; otherwise consider a quick smoke test.
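
A possible smoke test along those lines; the module path and the registered name "fla" are assumptions, not something this diff shows:

```python
def test_fla_lm_wrapper_registered():
    # loading the harness runs `import fla  # noqa` and applies lm_eval's
    # @register_model decorator to the wrapper class
    import evals.harness as harness            # assumption: evals/ is importable
    from lm_eval.api.registry import get_model

    # assumption: the wrapper registers under "fla"; use whatever name
    # @register_model(...) actually passes in harness.py
    assert get_model("fla") is harness.FlashLinearAttentionLMWrapper
```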

legacy/training/flame/logging.py (1)

9-10: LGTM – import consolidation only.

Switching to a single-line import has no functional impact and passes black/isort settings.

legacy/training/run.py (1)

7-10: Verify that delayed fla import does not defer monkey-patches needed before model instantiation.

Some fla sub-modules patch HF layers at import time. Because AutoModelForCausalLM.from_pretrained is now executed before any fla code runs, ensure no critical monkey-patch is required for forward/backward compatibility. If patches are essential, move the import back above the HF call or perform an explicit fla.patch_all() right after importing.

fla/ops/gated_delta_product/chunk_deltaproduct_h.py (3)

250-252: Good fix for chunk offset calculation consistency.

The update to boh calculation to include num_householder correctly aligns the backward kernel with the forward kernel's chunking logic.


319-319: Correct variable naming fix.

Good catch on fixing the variable names from p_wo/b_wo to p_do/b_do to accurately represent the gradient of the output.

Also applies to: 322-322


475-475: Verify divisibility assertion is appropriate.

The assertion T % num_householder == 0 enforces strict divisibility. Ensure this requirement is documented and validated at higher levels to provide better error messages to users.

Also applies to: 485-485
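
For instance, a friendlier check at the Python wrapper level could look like this (a sketch, not the PR's code):

```python
def check_num_householder(T: int, num_householder: int) -> None:
    """Validate the divisibility requirement with an actionable error message."""
    if num_householder <= 0:
        raise ValueError(f"num_householder must be positive, got {num_householder}")
    if T % num_householder != 0:
        raise ValueError(
            f"sequence length T={T} is not a multiple of num_householder="
            f"{num_householder}; the interleaved k/v/beta layout requires exactly "
            f"num_householder rows per token"
        )
```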

fla/ops/gated_delta_product/chunk.py (2)

249-251: Add clarification for the TODO comment.

The TODO comment "implement delta product version of chunk_bwd_dqkwg" suggests using a non-optimized version. Please clarify:

  1. What specific optimizations are missing?
  2. Is there a performance impact?
  3. Is there a timeline for implementing the optimized version?

292-293: Good implementation of gradient accumulation.

The gradient accumulation using in-place addition and the clear comment explaining the multivariate chain rule makes the code easier to understand and maintain.
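
Spelled out (notation is mine, hedged against the actual variable names), the accumulation implements the multivariate chain rule for a tensor that reaches the loss through two paths:

```latex
% k enters the loss both through the chunk outputs o and through the
% WY/Householder construction of the hidden states H, so the partials add:
\frac{\partial L}{\partial k}
  = \frac{\partial L}{\partial o}\,\frac{\partial o}{\partial k}
  + \frac{\partial L}{\partial H}\,\frac{\partial H}{\partial k}
```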

Comment on lines +156 to +363
# if USE_G:
# g += bos * H + i_h
# p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
# b_g = tl.load(p_g, boundary_check=(0,))
# m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
# b_A = tl.where(m_A, b_A * exp(b_g[:, None] - b_g[None, :]), 0)
# else:
# m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
# b_A = tl.where(m_A, b_A, 0)

# # Load values for this Householder step
# p_v = tl.make_block_ptr(v+i_dp*H*V, (T, V), (H*V*num_householder, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# b_v = tl.load(p_v, boundary_check=(0, 1))
# b_do = tl.load(p_do, boundary_check=(0, 1))

# # Gradient w.r.t. values: dv += A^T @ do
# b_dv += tl.dot(tl.trans(b_A.to(b_v.dtype)), b_do)

# # Gradient w.r.t. attention scores: ds = do @ v^T
# b_ds += tl.dot(b_do, tl.trans(b_v))

# # Apply scale and gating to score gradients
# b_ds = b_ds * scale
# if USE_G:
# b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0)
# else:
# b_ds = tl.where(m_A, b_ds, 0)

# # Compute final gradients for each Householder step
# for i_dp in range(num_householder):
# for i_k in range(tl.cdiv(K, BK)):
# p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
# p_k = tl.make_block_ptr(k+i_dp*H*K, (K, T), (1, num_householder*H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
# p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
# p_dk = tl.make_block_ptr(dk+i_dp*H*K, (K, T), (1, num_householder*H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))

# b_q = tl.load(p_q, boundary_check=(0, 1))
# b_k = tl.load(p_k, boundary_check=(0, 1))

# # dq += ds @ k^T
# b_dq += tl.dot(b_ds, tl.trans(b_k))
# # dk += q^T @ ds
# b_dk = tl.dot(tl.trans(b_q), b_ds)

# tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
# tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

# # Store value gradients
# for i_dp in range(num_householder):
# p_dv = tl.make_block_ptr(dv+i_dp*H*V, (T, V), (H*V*num_householder, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
# tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))


# def chunk_gated_delta_product_bwd_o(
# q: torch.Tensor,
# k: torch.Tensor,
# v: torch.Tensor,
# h: torch.Tensor,
# g: Optional[torch.Tensor] = None,
# do: torch.Tensor = None,
# scale: Optional[float] = None,
# cu_seqlens: Optional[torch.LongTensor] = None,
# chunk_size: int = 64,
# num_householder: int = 1,
# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# assert q.shape[1] * num_householder == k.shape[1], "q.shape[1] * num_householder must be equal to k.shape[1]"
# B, T, H, K, V = *q.shape, v.shape[-1]
# BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
# chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
# NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

# dq = torch.zeros_like(q)
# dk = torch.zeros_like(k)
# dv = torch.zeros_like(v)
# dh = torch.zeros_like(h)

# def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
# chunk_gated_delta_product_bwd_kernel_o[grid](
# q,
# k,
# v,
# h,
# g,
# do,
# dq,
# dk,
# dv,
# dh,
# cu_seqlens,
# chunk_indices,
# scale,
# T=T,
# num_householder=num_householder,
# H=H,
# K=K,
# V=V,
# BT=BT,
# )
# return dq, dk, dv, dh
@coderabbitai coderabbitai bot Jul 14, 2025

🛠️ Refactor suggestion

Consider removing or documenting the purpose of the commented backward kernel code.

The PR objectives mention adding a "backward kernel skeleton", but having 200+ lines of commented code without clear documentation about its purpose or timeline for implementation can lead to confusion and maintenance issues. Consider:

  1. If this is a work-in-progress, add a clear TODO comment explaining the implementation timeline
  2. If this is example/reference code, move it to documentation
  3. If this is no longer needed, remove it entirely

Member

@phi-jkim as suggested by coderabbitai

coderabbitai bot

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

Comment on lines +111 to +197
# from fla.ops.gated_delta_product.chunk_deltaproduct_h import chunk_gated_delta_product_bwd_dhu
# from fla.ops.gated_delta_product.chunk_deltaproduct_o import chunk_gated_delta_product_bwd_o

# cu_seqlens_dp = cu_seqlens * num_householder if cu_seqlens is not None else None

# # compute gradients wrt q, k, v
# # dv_new_o is gradient w.r.t. v_new
# # might not be fully parallelizable due to gradient Q including previous states H_0 to H_T
# dq_o, dv_new_o, dk_o, dh = chunk_gated_delta_product_bwd_o(
# q=q,
# k=k,
# v=v_new, # v_new = U[i] - W[i]H[i]^T
# h=h,
# g=g, # forward_h takes in g instead of g_interleaved
# scale=scale,
# cu_seqlens=cu_seqlens,
# num_householder=num_householder,
# do=do, # gradient of the output
# )

# # recompute w, u from WY representation to compute gradient
# from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd as gdn_recompute_w_u_fwd
# from fla.ops.delta_rule.wy_fast import recompute_w_u_fwd as dn_recompute_w_u_fwd

# if g_interleaved is not None:
# w, u = gdn_recompute_w_u_fwd(
# k=k, v=v, beta=beta, A=A, g=g_interleaved, cu_seqlens=cu_seqlens_dp,
# )
# else:
# w, u = dn_recompute_w_u_fwd(
# k=k, v=v, beta=beta, A=A, cu_seqlens=cu_seqlens_dp,
# )

# # compute gradients with respect to u and w
# # but need to account for gradients used for sequential computation of hidden states of H_0 to H_T (sequential)
# dh0, du, dw = chunk_gated_delta_product_bwd_dhu(
# q=q,
# k=k,
# w=w,
# u=u,
# g=g_interleaved,
# initial_state=initial_state, #H_0
# cu_seqlens=cu_seqlens_dp,
# num_householder=num_householder,
# dht=dht, # gradient w.r.t to last hidden state
# dv=dv_new_o, # gradient w.r.t. v_new
# scale=scale,
# )

# # compute gradients w.r.t. WY representation (dk, dv, dbeta, dg)
# # This involves computing gradients through the Householder transformations
# from fla.ops.gated_delta_rule.wy_fast import prepare_wy_repr_bwd

# # gradient descent from W and U
# # g is used for computing W and U in the forward pass
# # dk2 accounts for the gradient of hidden state wrt to k
# # dv_final is the gradient of hidden state wrt to v
# # this can be fully parallelized
# dk2, dv, dbeta, dg = prepare_wy_repr_bwd(
# k=k,
# v=v,
# beta=beta,
# g=g_interleaved,
# A=A,
# dw=dw, # Use key gradients from output as weights gradients
# du=du, # Use value gradients from hidden state backward
# cu_seqlens=cu_seqlens_dp,
# )

# # accumulate gradients
# # dk_final = dk_o + dk2
# # dk_final = dk2 # should there be (Q[i] K[i]^T \cdot M)
# dv_final = dv
# dg_final = dg

# # process gating gradients with local cumsum (reverse)
# if g is not None:
# from fla.ops.utils import chunk_local_cumsum
# dg_final = chunk_local_cumsum(dg_final, chunk_size=64, reverse=True, cu_seqlens=cu_seqlens_dp)

# # Convert interleaved gating gradients back to original format
# dg_final = rearrange(dg_final, 'b (l n) h -> b l n h', n=num_householder)[:, :, 0].contiguous()
# else:
# dg_final = None

# return dq_o, dk, dv_final, dg_final, dbeta, dh0

coderabbitai bot

🛠️ Refactor suggestion

Remove commented alternative implementation.

The function contains 80+ lines of commented code that appears to be an alternative implementation approach. This makes the code harder to read and maintain. If this alternative approach is needed for reference, consider moving it to documentation or a separate development branch.


@yzhangcs (Member) commented

@phi-jkim very nice job! Could you please revert the changes on the legacy code and the simple GLA README, which I think are unrelated to this PR? Feel free to create another one if necessary.

from lm_eval.__main__ import cli_evaluate
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM

import fla # noqa

Member

We should format this part somewhere else rather than in this PR.

@yzhangcs (Member) commented

@phi-jkim Seems that the current GDP code can't pass our modeling & generation tests in the H100 CIs.
