Fix varlen ref_gdr chunk offsets after prepare_chunk_offsets return change #16

Open

ReinforcedKnowledge wants to merge 1 commit into QwenLM:main from ReinforcedKnowledge:codex-fix-ref-gdr-chunk-offsets

Conversation

@ReinforcedKnowledge

Summary

Hi! This is a small fix for the varlen reference implementation in tests/ref_gdr.py.

prepare_chunk_offsets now returns (chunk_offsets, num_chunks), and the kernel callers already destructure that return value. A few reference helpers in tests/ref_gdr.py were still passing the full tuple into pack / unpack, which expect a tensor.

This PR updates those call sites to destructure the return value and pass only chunk_offsets.
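
For illustration, the affected call sites follow the pattern below (the "before" line is taken from the traceback further down; the exact code in tests/ref_gdr.py may differ slightly):

# Before: the full (chunk_offsets, num_chunks) tuple is forwarded to pack.
h = pack(h, prepare_chunk_offsets(cu_seqlens, chunk_size))

# After: destructure the return value and pass only the offsets tensor,
# mirroring the existing kernel callers.
chunk_offsets, num_chunks = prepare_chunk_offsets(cu_seqlens, chunk_size)
h = pack(h, chunk_offsets)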

Reproduction

On current main, a varlen reference forward call can fail before any kernel comparison:

  File "/root/FlashQLA/tests/ref_gdr.py", line 206, in torch_chunk_gdr_fwd
    h = pack(h, prepare_chunk_offsets(cu_seqlens, chunk_size))
  File "/root/FlashQLA/flash_qla/utils/pack.py", line 27, in pack
    assert len(cu_seqlens.shape) == 1
               ^^^^^^^^^^^^^^^^
AttributeError: 'tuple' object has no attribute 'shape'

Minimal code to reproduce it (run from the tests/ directory so that ref_gdr is importable):

import json

import torch

from ref_gdr import chunk_gated_delta_rule_fwd

torch.manual_seed(123)
device = "cuda"

# 96 tokens packed into one batch row: q/k with 2 heads, v with 4, head_dim 128.
q = torch.randn((1, 96, 2, 128), device=device)
k = torch.randn((1, 96, 2, 128), device=device)
v = torch.randn((1, 96, 4, 128), device=device)
# Per-token gate (g) and beta values for the 4 value heads.
g = torch.zeros((1, 96, 4), device=device)
beta = torch.ones((1, 96, 4), device=device)
# Varlen layout: two packed sequences of lengths 31 and 65.
cu_seqlens = torch.tensor([0, 31, 96], device=device, dtype=torch.int32)

outputs = chunk_gated_delta_rule_fwd(
    q=q,
    k=k,
    v=v,
    g=g,
    beta=beta,
    scale=128 ** -0.5,
    initial_state=None,
    cu_seqlens=cu_seqlens,
)

print(json.dumps({
    "output_shapes": [tuple(x.shape) for x in outputs],
    "output_dtypes": [str(x.dtype) for x in outputs],
}))

Calling tests/ref_gdr.py::chunk_gated_delta_rule_fwd(..., cu_seqlens=cu_seqlens) hits the AttributeError above, because the full prepare_chunk_offsets tuple is passed into pack.

Verification

After this patch, the same varlen reference forward call completes and returns the expected reference tensors.
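
As a rough sanity check (reusing the variables from the reproduction script above), the patched call should now return tensors instead of raising:

# Rough sanity check on top of the reproduction script: the patched varlen
# reference call should return finite tensors instead of raising.
outputs = chunk_gated_delta_rule_fwd(
    q=q, k=k, v=v, g=g, beta=beta,
    scale=128 ** -0.5, initial_state=None, cu_seqlens=cu_seqlens,
)
assert all(torch.isfinite(x).all() for x in outputs)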

Note: I’m still getting familiar with this codebase, so I’d be very happy to adjust the fix if you prefer a different style. The change is intentionally small and follows the existing production caller pattern of destructuring prepare_chunk_offsets. I'd also like to say that I love the work you did here! Trying to learn from it 😄
