Fix output_qk corruption, softcap ordering, and test coverage for ONNX Attention op #27992
titaiwangms wants to merge 4 commits into microsoft:main from
Conversation
Pull request overview
This PR extends ONNX Runtime’s CUDA Attention operator to support softcap in the unfused CUDA path and adds Memory Efficient Attention (CUTLASS FMHA) decode support (past/present KV cache), along with expanded ONNX/Python/C++ tests and updated backend test filters.
Changes:
- Add softcap plumbing (AttentionParameters.softcap) and apply softcap to unfused attention logits via a CUDA kernel before softmax (see the note after this list).
- Implement MEA decode by concatenating past+new KV into the present buffer (via LaunchConcatNewToPastKV) and update kernel selection/verbosity logging.
- Add/adjust Python and C++ tests for MEA decode, unfused softcap, bool/float masks, and asymmetric head-size fallback; update ONNX backend filters accordingly.
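For reference, the transform mentioned in the first bullet (and spelled out later in the quoted kernel comment) bounds every logit before softmax:

```latex
% Softcap as defined in the kernel comment quoted further down: s' = c * tanh(s / c), with c = softcap.
% Since tanh maps into (-1, 1), every capped logit satisfies |s'| < c, while small logits
% are left nearly unchanged because tanh(x) ≈ x for |x| << 1.
\[
  s' = c\,\tanh\!\left(\frac{s}{c}\right), \qquad |s'| < c, \qquad s' \approx s \ \text{for } |s| \ll c .
\]
```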
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| onnxruntime/core/providers/cuda/llm/attention.cc | Adds MEA decode path (concat KV into present), propagates softcap to unfused, and adds verbose kernel-selection logging. |
| onnxruntime/contrib_ops/cuda/bert/attention_impl.cu | Implements ApplySoftcap CUDA kernel and applies it in unfused attention before softmax. |
| onnxruntime/contrib_ops/cpu/bert/attention_parameters.h | Adds softcap to AttentionParameters. |
| onnxruntime/test/providers/cpu/llm/attention_op_test.cc | Enables CUDA softcap tests and adds a CUDA MEA decode regression test (forced via env var). |
| onnxruntime/test/python/transformers/test_onnx_attention/common.py | Adds v_head_size support to graph IO shapes/bindings; adds MEA alignment helper for decode+mask tests. |
| onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py | Adds MEA decode tests (fp16/fp32, bool/float masks), unfused softcap tests, and asymmetric head-size fallback regression test. |
| onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py | Re-enables/extends GQA MEA decode tests (including bf16) and adjusts padding-mask cases for MEA alignment. |
| onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc | Updates exclusions/comments for attention backend tests based on new softcap/decode behavior. |
1. attention.cc: Replace ORT_ENFORCE for present_key/present_value with scratch buffer allocation when outputs are nullptr. MEA decode now works even when present outputs are not requested. Use ORT_RETURN_IF_NOT for user-facing validation (past_value, nonpad_kv_seqlen, head_size).
2. attention_impl.cu: Replace ORT_ENFORCE(total_elements > 0) with an early return for zero elements, since q_sequence_length=0 is valid.

Per Copilot review on PR microsoft#27992.
Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
titaiwangms left a comment
Probably not, but check that we are not breaking the graph at any point. For example, #27484.
Also, are the nonpad_kv_seqlens paths totally unrelated to the gap table? I don't see it mentioned at all.
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.
tianleiwu left a comment
Thanks for the updates here. I re-checked the current head and the earlier MEA optional-output, softcap zero-size, grid-stride, and mask-test concerns look addressed. I found one remaining correctness issue in the new unfused softcap path: kQK output is copied after the logits have already been softcapped, and in the softcap+bias branch it is not copied at all.
Force-pushed from c74bbaa to e5f74e9.
Force-pushed from e5f74e9 to 5afc98e.
tianleiwu left a comment
APPROVE
The implementation is sound, addresses the prior output_qk ordering concern (resolved that thread), and provides comprehensive test coverage across MEA decode, unfused softcap, spec-correct mask→softcap ordering, and the NaN fix. No high-severity issues found.
Highlights
- output_qk copy correctly moved before bias/softcap mutations — prior CHANGES_REQUESTED concern resolved.
- MEA decode path has well-managed scratch buffer lifetimes; kv_is_bsnh tracking correctly propagates through GQA expansion.
- mask_filter_value cap (-1e+30f) is correctly scoped to MEA only and the math is verified (see the note after this list).
- Test coverage is thorough: C++ and Python, MEA/unfused, prompt/decode, fp16/bf16/fp32, various mask dims.
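On the mask_filter_value point above, a short illustration of why a large finite negative value is preferred over -inf inside a fused kernel. This rationale is inferred, not stated in the PR:

```latex
% Inferred rationale (assumption): with a -inf mask value, intermediate arithmetic such as
% (-inf) - (-inf) or 0 * (-inf) yields NaN inside a fused kernel, whereas a large finite
% value like -1e30 keeps the math finite and still underflows to zero probability after
% the max-subtracted softmax:
\[
  \operatorname{softmax}(s)_i \;=\; \frac{e^{\,s_i - \max_j s_j}}{\sum_k e^{\,s_k - \max_j s_j}},
  \qquad
  e^{\,-10^{30} - \max_j s_j} = 0 \ \text{in fp32 (underflow)} .
\]
```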
Suggestions (non-blocking)
- Causal mask ordering with softcap: In the has_softcap && has_bias branch, ComputeSoftmax applies is_unidirectional causal masking after softcap. The strict ONNX spec folds causal into the mask before softcap. This matches the CPU reference and tests pass, but a TODO comment would help track future spec alignment.
- Defensive assertion: The has_softcap && has_bias branch bypasses use_raw_attention_mask handling. Safe today (the ONNX domain converts all masks to data.attention_bias before the unfused path), but adding assert(!use_raw_attention_mask) guards against future contrib softcap use where mask_index could coexist.
- Bool mask test filters: test_attention_4d_attn_mask_bool_cuda and _4d_cuda are still excluded with "may work now" TODOs. Consider removing them or filing a follow-up issue.
Nonblocking suggestions will be addressed in a follow-up.
Force-pushed from 367c23e to d9cc2d8.
Wait for #28198 to merge first.
Force-pushed from 28e2ba9 to e8813fc.
Force-pushed from e8813fc to be8557e.
Pull request overview
Copilot reviewed 12 out of 12 changed files in this pull request and generated 4 comments.
- Copy raw QK logits to output_qk BEFORE any softcap or bias mutations (CopyQK was after ApplySoftcap, corrupting output_qk with clamped values).
- Reorder the unfused path: softcap → mask/bias → softmax (per onnx/onnx#7865). The old ordering (mask → softcap → softmax) leaked probability to masked positions because softcap · tanh(-inf / softcap) = -softcap, a finite value (see the note below).
- Use gsl::narrow<int> for the qk_size cast to catch overflow at runtime.
- Add an assert guard against raw attention masks in the softcap+bias path.
- Add a softcap field to the base AttentionParameters struct.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
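To make the leak described in the second bullet concrete (using c for the softcap value and an additive mask of -inf at masked positions; notation mine):

```latex
% Old ordering (mask -> softcap -> softmax): a masked logit of -inf is capped first,
\[
  c\,\tanh\!\left(\frac{-\infty}{c}\right) = -c \quad\text{(finite)},
\]
% so after softmax that position receives weight on the order of e^{-c}, comparable to
% the unmasked logits, which all lie in (-c, c) after the cap.
%
% Corrected ordering (softcap -> mask -> softmax): the cap is applied first and the -inf
% mask is added afterwards, so the masked logit stays at -inf and softmax assigns it
% exactly zero probability.
\[
  \operatorname{softmax}\big(c\,\tanh(s/c) + m\big)_i = 0 \quad \text{when } m_i = -\infty .
\]
```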
- Remove the MEA rejection guard for softcap+mask: CUTLASS FMHA applies softcap before bias (fused in kernel tiles), matching the corrected ONNX spec ordering (onnx/onnx#7865). No need to force the unfused path for this case.
- Restore the GQA+FP32 MEA guard lost during rebase: LaunchUngroup only has fp16/bf16 instantiations, so GQA with FP32 must skip MEA.
- Add #include <algorithm> for std::max.
- Relocate the kCutlassSafeMaskFilterValue constant to memory_efficient_attention.h.
- Add MEA decode path support.
- Fix stale softcap ordering comments throughout.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
C++ tests (attention_op_test.cc):
- Attention4DSoftCapOutputQkRawLogits: regression test for output_qk corruption using constant Q=K=1 with an analytical expected value (see the note below).

Python tests (test_mha.py):
- 5 partial masking tests (fp32 prompt/decode, 2D/3D mask, longer seq).
- 2 ordering guard tests using the poison-value technique (V=1000 at masked positions).
- Flash/MEA softcap coverage for MHA and GQA paths.

Python tests (test_gqa.py):
- GQA unfused tests: large head size, causal, past key, softcap+mask, BSNH, FP32.
- Flash GQA softcap tests with padding mask.

Reference fix (common.py):
- Fix attention_ref() to apply softcap BEFORE bias/mask (per onnx/onnx#7865).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
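A back-of-the-envelope note on why the constant Q=K=1 construction in the C++ regression test can expose the corruption. The exact scale convention of the op's QK output is an assumption here, not taken from the test:

```latex
% With Q = K = all-ones and head size d, every raw Q*K' entry equals d; assuming the
% default scale 1/sqrt(d) is included, the scaled logit is sqrt(d). If output_qk were
% (incorrectly) captured after softcap, each entry would instead be c * tanh(sqrt(d)/c),
% which is strictly less than c. Choosing d and c so that sqrt(d) > c makes the correct
% and corrupted values analytically distinguishable.
\[
  \text{raw: } \frac{d}{\sqrt{d}} = \sqrt{d}, \qquad
  \text{softcapped: } c\,\tanh\!\left(\frac{\sqrt{d}}{c}\right) < c .
\]
```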
Force-pushed from be8557e to 36e52eb.
Pull request overview
Copilot reviewed 12 out of 12 changed files in this pull request and generated 2 comments.
- Fix cmake/onnxruntime_python.cmake to recursively glob the test_onnx_attention/ subdirectory so pytest discovers the tests in CI.
- Add SKILL.md docs for cuda-attention-kernel-patterns and cuda-bfloat16-type-traits.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Force-pushed from 36e52eb to 7b5ef7a.
```cuda
// Element-wise softcap: score = softcap * tanh(score / softcap)
// Applied after Q*K' GEMM, before softmax, to bound attention logits.
template <typename T>
__global__ void ApplySoftcapKernel(T* scores, float softcap_inv, float softcap, int64_t total_elements)
```
Can we use the unfused kernel UnfusedGqaAttention instead of adding new kernels?
We can add output_qk support in that kernel if needed.
It's used by the MHA route though. Or are you suggesting that we should generalize UnfusedGQA to cover MHA as well? (I don't want to extend the MHA kernel directly.)
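For orientation, a minimal grid-stride sketch of the element-wise pattern that the quoted ApplySoftcapKernel signature suggests, including the zero-element early return described in the commit message above. This is an illustration only (names and launch parameters are mine); the PR's actual kernel and launcher may differ:

```cuda
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>

// Illustrative sketch only: element-wise softcap score = softcap * tanh(score / softcap),
// applied in place over all logits with a grid-stride loop so any grid size is safe.
// Shown for float; half/bf16 instantiations would need their conversion headers.
template <typename T>
__global__ void ApplySoftcapKernelSketch(T* scores, float softcap_inv, float softcap, int64_t total_elements) {
  for (int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
       i < total_elements;
       i += static_cast<int64_t>(blockDim.x) * gridDim.x) {
    float v = static_cast<float>(scores[i]);
    scores[i] = static_cast<T>(softcap * tanhf(v * softcap_inv));
  }
}

// Host-side launcher sketch: empty tensors (q_sequence_length == 0) are a no-op
// rather than an error, matching the early-return change mentioned earlier.
template <typename T>
void LaunchApplySoftcapSketch(cudaStream_t stream, T* scores, float softcap, int64_t total_elements) {
  if (total_elements <= 0 || softcap == 0.0f) {
    return;  // nothing to do for empty or softcap-disabled inputs
  }
  constexpr int kThreadsPerBlock = 256;
  int blocks = static_cast<int>(std::min<int64_t>(
      (total_elements + kThreadsPerBlock - 1) / kThreadsPerBlock, 4096));
  ApplySoftcapKernelSketch<T><<<blocks, kThreadsPerBlock, 0, stream>>>(
      scores, 1.0f / softcap, softcap, total_elements);
}
```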
Summary
Fixes multiple issues in the ONNX standard Attention op CUDA implementation:
Bug Fixes
- output_qk corruption: raw QK logits are now copied to output_qk before any softcap or bias mutation (attention_impl.cu). Previously, output_qk contained softcapped values instead of raw QK logits.

Safety Guards
- gsl::narrow<int>() for the CopyQK size (overflow protection)
- assert(!use_raw_attention_mask) defensive guard

Tests Added
- Attention4DSoftCapOutputQkRawLogits (C++, runs on CPU+CUDA)

Infrastructure
- The test_onnx_attention/ subdirectory was silently skipped by the CMake GLOB. Added a recursive copy in cmake/onnxruntime_python.cmake.
- Relocated kCutlassSafeMaskFilterValue to memory_efficient_attention.h.
- Added the mea_aligned_past_seq() MEA alignment helper.

Related Issues
How to Test