Complete chunkwise GatedDeltaNet #91
Conversation
asobczyk left a comment:
I was able to look at about 20 of the 113 files and left some comments/suggestions. In general it looks good, but I have some general remarks:

- Must-change: the changes under `csrc/kernel/kernel_tri_inv_rec_unroll.cpp` should be thoroughly examined and tested in a separate, isolated PR, with dedicated unit tests.
- Nice-to-have: I would avoid special characters in the source code files, such as arrows, "\mathbb{R}", or Greek letters. It is better to be consistent with the variable names used by the functions.
- Nice-to-have: Doxygen-style docstrings are missing; the current descriptions/docstrings could be translated to Doxygen style.
- Nice-to-have: ideally, the main kernels that are used should be ported to `csrc/kernels`, one PR per kernel, with source code, torch integration, and unit tests. I know that this is tedious work, so for now I do not mind if we do it in a separate PR.
```diff
 AICORE inline void CopyOddOrEvenBlocksL1ToL0(SrcL1TileT src, DstL0TileT dst,
-                                             uint32_t block_size) {
+                                             uint32_t block_size,
+                                             bool swap_parity = false) {
```
The changes in this file (and in any `csrc/` file) should go to a separate MR with unit tests.
I also believe that if `swap_parity` is only used to decide between upper/lower triangular, we can implement it in a much more seamless way just by reading in a row-major vs. column-major manner.
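A toy Python sketch of that idea (my own illustration, not the kernel code): the loop that walks the lower triangle also walks the upper triangle if you simply transpose the index order of the read, so no parity flag is needed in the traversal itself.

```python
def sum_triangle(a, upper=False):
    """Visit index pairs (r, c) with c <= r over a square matrix `a`.

    Reading a[r][c] (row-major order) covers the lower triangle;
    reading a[c][r] (the transposed, column-major view) covers the
    upper triangle with the exact same loop structure.
    """
    n = len(a)
    total = 0
    for r in range(n):
        for c in range(r + 1):
            total += a[c][r] if upper else a[r][c]
    return total
```

The same trick applies to a block copy: swapping the roles of the row and column block indices selects the other triangle without a boolean parameter threaded through the call chain.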
```cpp
// For left: copy even blocks 0, 2, 4, ... (starting_block=0)
// For right: copy odd blocks 1, 3, 5, ... (starting_block=1)
const uint32_t starting_block_index = is_left ? 0 : 1;
// Default: left→even(0), right→odd(1). swap_parity flips this.
```
It might be better to avoid special characters such as `→`.
```diff
@@ -0,0 +1,263 @@
+#!/usr/bin/env python3
+"""
+Benchmark dynamic BSND PTO kernels (bisheng-compiled, ctypes) for chunk GDN.
```
It would be helpful to expand the description here about what is being benchmarked, and how (a bird's-eye view).
```cpp
// stream = NPU stream for async execution (like CUDA stream)
// rtGetC2cCtrlAddr: gets the FFTS control address for cross-core sync
// <<<block_dim, nullptr, stream>>>: NPU kernel launch syntax (like CUDA <<<>>>)
extern "C" void call_kernel(
```
I am not a big fan of the `call_kernel` name, especially when it becomes an `extern "C"` name. In all our kernels we use a descriptive name, e.g. in this case something like `chunk_cumsum_fp32`.
```cpp
    batch_size, seq_len, total_tokens, ffts_addr);
}

extern "C" void call_kernel(
```
Suggestion for name change: `chunk_h_fp16`.

`call_kernel_chunk_h_fp16`?
```cpp
    batch_size, seq_len, total_tokens, ffts_addr);
}

// ── Host-side launcher ──────────────────────────────────────────────────
```
It might be better to use a Doxygen-style docstring here.
```python
if _HERE not in sys.path:
    sys.path.insert(0, _HERE)

import numpy as np
```
ruff complains here; just make sure to run pre-commit to silence those warnings.
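For reference, a minimal `.pre-commit-config.yaml` fragment that wires ruff into pre-commit (the `rev` below is a placeholder; pin to whatever version this repo already uses):

```yaml
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.6.9  # placeholder version, not from this PR
    hooks:
      - id: ruff         # lint, with autofix
        args: [--fix]
      - id: ruff-format  # formatter
```

With that in place, `pre-commit run --all-files` applies the fixes locally before pushing.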
```python
if __name__ == "__main__":
    main()
```
These are very useful files; we should eventually adapt them as unit tests under `tests/`.
```diff
@@ -0,0 +1,111 @@
+#include <pto/pto-inst.hpp>
```
The name of this folder is `_old`. If it is old and deprecated, maybe we can remove it completely?

> The name of this folder is `_old`. If it is old and deprecated, maybe we can remove it completely?

Yes, I do not intend to merge this PR to main, but will instead extract useful pieces out as cleaner PRs.
```diff
@@ -0,0 +1,145 @@
+#!/usr/bin/env python3
+"""
+Benchmark mega-kernel vs aggregated per-stage PTO kernels.
```
It would be helpful to write 1-2 sentences about what is being benchmarked (a brief overview).
gioelegott left a comment:
Tested the mega-kernel and all tests pass
```python
def run_mega_kernel(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g_in: torch.Tensor,
    beta: torch.Tensor,
    cu_seqlens: torch.Tensor,
    *,
    chunk_size: int = 128,
    scale: float = 1.0,
    block_dim: int | None = None,
) -> torch.Tensor:
```
The interface is somewhat different from sgl-kernel-npu, but still compatible:

```python
def chunk_gated_delta_rule_npu(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
):
```
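A thin adapter could bridge the two signatures. This is a sketch under assumptions: `run_mega_kernel` behaves as quoted above, options the mega-kernel does not expose should raise rather than be ignored, and the FLA-style default `scale = DK ** -0.5` for `scale=None` is my guess, not confirmed by the PR. The `kernel` parameter is a hypothetical injection point for testability.

```python
from typing import Optional


def chunk_gated_delta_rule_compat(
    q, k, v, g, beta,
    scale: Optional[float] = None,
    initial_state=None,
    output_final_state: bool = True,
    cu_seqlens=None,
    head_first: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    *,
    kernel=None,  # hypothetical: pass run_mega_kernel here
):
    # Fail loudly on options the mega-kernel does not support yet,
    # instead of silently ignoring them.
    if initial_state is not None or head_first or use_qk_l2norm_in_kernel:
        raise NotImplementedError("option not supported by the mega-kernel")
    if cu_seqlens is None:
        raise ValueError("mega-kernel expects packed varlen input (cu_seqlens)")
    if scale is None:
        scale = q.shape[-1] ** -0.5  # assumed FLA convention for scale=None
    o = kernel(q, k, v, g, beta, cu_seqlens, scale=scale)
    # The mega-kernel returns only the attention output; no final recurrent
    # state is available, so return None in its place.
    return (o, None) if output_final_state else o
```

The point is only that nothing in the sgl-kernel-npu signature is fundamentally incompatible; the extra options just need explicit handling.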
* wip
* push cpp code
* use backend='pto'
* unit test varlen
* dump varlen source code with head 32 and 48 variants
* fix comment
* standalone PTO demo ported from tilelang

---------

Co-authored-by: Anastasios Zouzias <anastasios.zouzias@huawei.com>
Co-authored-by: learning-chip <jiawei.zhuang@outlook.com>
Finishes all the remaining parts of #88 to support the full Qwen3.5 GDN layer.

Reproduce: compiles and runs with the pto-isa commit of April 03. I used this modified vllm-ascend docker image, with triton-ascend pre-installed, so it is easier to compare against the triton baseline in vllm.
Performance

Shape: (N_seq=16, L_seg=16384, H=16, DK=DV=128, C=128), packed varlen BSND with T=262144.

Reproduced by chunk_gdn/dynamic_bsnd vs chunk_gdn/triton_baseline.
Accuracy evaluation
Reproduced by chunk_gdn/pto_e2e_measure
Feature list

- Static-shape baseline (chunk_gdn/static_baseline/gdn_chain_e2e_static.py)
- Dynamic packed-varlen BSND variant (chunk_gdn/dynamic_bsnd)
- Grouped Value Attention where num_key_head < num_value_head (required by larger Qwen)
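For readers unfamiliar with the recurrence these kernels chunk, here is a naive per-token NumPy reference (my sketch of the standard gated delta rule formulation, S_t = a_t (I - b_t k_t k_t^T) S_{t-1} + b_t k_t v_t^T with a_t = exp(g_t), b_t = beta_t, and o_t = scale * q_t^T S_t; the chunkwise kernels compute the same outputs blockwise with chunk size C):

```python
import numpy as np


def gated_delta_rule_ref(q, k, v, g, beta, scale=1.0):
    """Naive O(T * DK * DV) per-token reference for the gated delta rule.

    q, k: (T, DK); v: (T, DV); g, beta: (T,). Returns o: (T, DV).
    """
    T, DK = k.shape
    DV = v.shape[1]
    S = np.zeros((DK, DV))       # recurrent state
    o = np.empty((T, DV))
    for t in range(T):
        S *= np.exp(g[t])                                # apply decay gate a_t
        S += beta[t] * np.outer(k[t], v[t] - k[t] @ S)   # delta-rule update
        o[t] = scale * (q[t] @ S)                        # read out
    return o
```

A reference like this is what the accuracy evaluation above would compare the kernel outputs against, token by token, before trusting the chunked path.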