
Fix output_qk corruption, softcap ordering, and test coverage for ONNX Attention op #27992

Open

titaiwangms wants to merge 4 commits into microsoft:main from titaiwangms:feature/mea-decode-support-v2

Conversation

@titaiwangms
Contributor

@titaiwangms titaiwangms commented Apr 6, 2026

Summary

Fixes multiple issues in the ONNX standard Attention op CUDA implementation:

Bug Fixes

  1. output_qk corruption: Moved CopyQK before the softcap/bias mutations in the unfused attention kernel (attention_impl.cu). Previously, output_qk contained softcapped values instead of the raw QK logits.
  2. Softcap ordering: Corrected the ApplySoftcap → AddAttentionBias ordering (it was reversed). The ONNX spec itself had this wrong (onnx/onnx#7865 "Softcap in Attention op"; fixed by onnx/onnx#7867 "Fix Attention op softcap ordering: apply before mask/bias"). A numpy sketch of the corrected ordering follows this list.
  3. MEA eligibility guard: Restored the GQA+FP32 exclusion from the MEA path (lost during the rebase onto #28198, "GQA unfused attention with FP32 QK accumulation (fixes #28195)"). FP32 GQA now correctly falls through to the unfused path.
  4. GQA test bugs from #28198 ("GQA unfused attention with FP32 QK accumulation"): Fixed a mask shape mismatch and a missing nonpad_kv_seqlen tensor in 4 GQA large-head unfused tests.
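
A minimal numpy sketch of the corrected unfused ordering from items 1 and 2 (copy raw QK, then softcap, then bias/mask, then softmax). Function and argument names here are illustrative only, not the kernel API:

import numpy as np

def unfused_scores_sketch(q, k, scale, softcap=0.0, attn_bias=None):
    # q, k: (batch, num_heads, seq, head_size); a host-side reference, not the CUDA code.
    qk = (q @ k.transpose(0, 1, 3, 2)) * scale        # raw QK logits
    output_qk = qk.copy()                             # 1. copy BEFORE any mutation
    if softcap > 0.0:
        qk = softcap * np.tanh(qk / softcap)          # 2. softcap the logits
    if attn_bias is not None:
        qk = qk + attn_bias                           # 3. additive bias/mask AFTER softcap
    probs = np.exp(qk - qk.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)        # 4. softmax last
    return output_qk, probs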

Safety Guards

Tests Added

  • 12 Python softcap tests (ordering guards with the poison-value technique, partial masking); a rough sketch of the technique follows this list
  • G1 regression test: Attention4DSoftCapOutputQkRawLogits (C++, runs on CPU+CUDA)
  • Verifies CopyQK happens before ApplySoftcap using analytical values
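
The poison-value technique, roughly: plant a large value in V at masked KV positions and assert that it cannot reach the output; with the old mask-before-softcap ordering, tanh(-inf/softcap) = -1 turns masked logits into a finite -softcap and the poison leaks through. run_attention and the shapes below are hypothetical stand-ins, not the actual test harness:

import numpy as np

def poison_value_ordering_check(run_attention, q, k, v, key_is_valid, softcap, poison=1000.0):
    # key_is_valid: boolean mask of shape (kv_seq,); run_attention is a hypothetical callable.
    v_poisoned = v.copy()
    v_poisoned[..., ~key_is_valid, :] = poison        # plant poison at masked KV positions
    out_clean = run_attention(q, k, v, key_is_valid, softcap)
    out_poisoned = run_attention(q, k, v_poisoned, key_is_valid, softcap)
    # Spec-correct ordering gives masked positions ~zero softmax weight, so the poison
    # must not change the output.
    np.testing.assert_allclose(out_poisoned, out_clean, rtol=1e-3, atol=1e-3)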

Infrastructure

  • CI test discovery fix: test_onnx_attention/ subdirectory was silently skipped by CMake GLOB. Added recursive copy in cmake/onnxruntime_python.cmake.
  • Relocated kCutlassSafeMaskFilterValue to memory_efficient_attention.h
  • Updated SKILL.md with current dispatch cascade and patterns
  • Fixed 7 stale comments referencing wrong softcap ordering
  • Removed dead mea_aligned_past_seq() function
  • Updated opset references from 23 to 23/24

Related Issues

How to Test

# C++ tests (52 attention tests)
cd build/Debug && ./bin/onnxruntime_test_all --gtest_filter='*Attention4D*:*Attention_Gqa*'

# Python tests (217 tests)
python -m pytest onnxruntime/test/python/transformers/test_onnx_attention/ -v

Contributor

Copilot AI left a comment


Pull request overview

This PR extends ONNX Runtime’s CUDA Attention operator to support softcap in the unfused CUDA path and adds Memory Efficient Attention (CUTLASS FMHA) decode support (past/present KV cache), along with expanded ONNX/Python/C++ tests and updated backend test filters.

Changes:

  • Add softcap plumbing (AttentionParameters.softcap) and apply softcap to unfused attention logits via a CUDA kernel before softmax.
  • Implement MEA decode by concatenating past+new KV into the present buffer (via LaunchConcatNewToPastKV) and updating kernel selection/verbosity logs.
  • Add/adjust Python and C++ tests for MEA decode, unfused softcap, bool/float masks, and asymmetric head-size fallback; update ONNX backend filters accordingly.
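
As a host-side analogue of the decode path above (past and new K/V concatenated into the present buffer via LaunchConcatNewToPastKV, then attention runs over the full present length), here is a small numpy sketch; the BNSH layout and names are assumptions, not the kernel signature:

import numpy as np

def decode_step_sketch(q_new, past_k, past_v, new_k, new_v, scale):
    # BNSH layout assumed: (batch, num_heads, seq, head_size).
    present_k = np.concatenate([past_k, new_k], axis=2)   # present = [past, new] on the seq axis
    present_v = np.concatenate([past_v, new_v], axis=2)
    logits = (q_new @ present_k.transpose(0, 1, 3, 2)) * scale
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    out = probs @ present_v                                # attend over the full present length
    return out, present_k, present_v                       # output plus updated KV cache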

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.

Changed files:

  • onnxruntime/core/providers/cuda/llm/attention.cc: Adds MEA decode path (concat KV into present), propagates softcap to unfused, and adds verbose kernel-selection logging.
  • onnxruntime/contrib_ops/cuda/bert/attention_impl.cu: Implements ApplySoftcap CUDA kernel and applies it in unfused attention before softmax.
  • onnxruntime/contrib_ops/cpu/bert/attention_parameters.h: Adds softcap to AttentionParameters.
  • onnxruntime/test/providers/cpu/llm/attention_op_test.cc: Enables CUDA softcap tests and adds a CUDA MEA decode regression test (forced via env var).
  • onnxruntime/test/python/transformers/test_onnx_attention/common.py: Adds v_head_size support to graph IO shapes/bindings; adds MEA alignment helper for decode+mask tests.
  • onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py: Adds MEA decode tests (fp16/fp32, bool/float masks), unfused softcap tests, and asymmetric head-size fallback regression test.
  • onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py: Re-enables/extends GQA MEA decode tests (including bf16) and adjusts padding-mask cases for MEA alignment.
  • onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc: Updates exclusions/comments for attention backend tests based on new softcap/decode behavior.


Comment thread onnxruntime/core/providers/cuda/llm/attention.cc Outdated
Comment thread onnxruntime/contrib_ops/cuda/bert/attention_impl.cu Outdated
titaiwangms added a commit to titaiwangms/onnxruntime that referenced this pull request Apr 6, 2026
1. attention.cc: Replace ORT_ENFORCE for present_key/present_value with
   scratch buffer allocation when outputs are nullptr. MEA decode now
   works even when present outputs are not requested. Use ORT_RETURN_IF_NOT
   for user-facing validation (past_value, nonpad_kv_seqlen, head_size).
2. attention_impl.cu: Replace ORT_ENFORCE(total_elements > 0) with
   early return for zero elements, since q_sequence_length=0 is valid.

Per Copilot review on PR microsoft#27992.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms changed the title from "Add Memory Efficient Attention decode support and tests for ONNX" to "ONNX Attention CUDA: Add MEA decode support and unfused softcap" on Apr 6, 2026
Contributor Author

@titaiwangms titaiwangms left a comment


Probably not, but check that we are not breaking the graph at any point. For example, #27484.

Also, are the nonpad_kv_seqlens paths totally unrelated to the gap table? I don't see it mentioned at all.

Comment thread onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
Comment thread onnxruntime/core/providers/cuda/llm/attention.cc
Comment thread onnxruntime/core/providers/cuda/llm/attention.cc Outdated
Comment thread onnxruntime/test/providers/cpu/llm/attention_op_test.cc
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.



Comment thread onnxruntime/contrib_ops/cuda/bert/attention_impl.cu Outdated
@titaiwangms titaiwangms changed the title from "ONNX Attention CUDA: Add MEA decode support and unfused softcap" to "ONNX Attention CUDA: MEA decode, unfused softcap, and spec-correct softcap ordering" on Apr 7, 2026
@titaiwangms titaiwangms changed the title from "ONNX Attention CUDA: MEA decode, unfused softcap, and spec-correct softcap ordering" to "ONNX Attention CUDA: MEA decode, unfused softcap, spec-correct ordering, and NaN fix" on Apr 7, 2026
@titaiwangms titaiwangms requested a review from Copilot April 7, 2026 23:25
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.



Comment thread onnxruntime/contrib_ops/cuda/bert/attention_impl.cu Outdated
Comment thread onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
Comment thread onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
Comment thread onnxruntime/contrib_ops/cuda/bert/attention_impl.cu Outdated
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.



Contributor

@tianleiwu tianleiwu left a comment


Thanks for the updates here. I re-checked the current head, and the earlier MEA optional-output, softcap zero-size, grid-stride, and mask-test concerns look addressed. I found one remaining correctness issue in the new unfused softcap path: the kQK output is copied after the logits have already been softcapped, and in the softcap+bias branch it is not copied at all.

Comment thread onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@justinchuby
Contributor

Will this fix #28196 and #28195?

@titaiwangms titaiwangms requested a review from tianleiwu April 23, 2026 17:44
@titaiwangms titaiwangms force-pushed the feature/mea-decode-support-v2 branch 2 times, most recently from c74bbaa to e5f74e9 on April 23, 2026 21:17
titaiwangms added a commit to titaiwangms/onnxruntime that referenced this pull request Apr 23, 2026
1. attention.cc: Replace ORT_ENFORCE for present_key/present_value with
   scratch buffer allocation when outputs are nullptr. MEA decode now
   works even when present outputs are not requested. Use ORT_RETURN_IF_NOT
   for user-facing validation (past_value, nonpad_kv_seqlen, head_size).
2. attention_impl.cu: Replace ORT_ENFORCE(total_elements > 0) with
   early return for zero elements, since q_sequence_length=0 is valid.

Per Copilot review on PR microsoft#27992.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms force-pushed the feature/mea-decode-support-v2 branch from e5f74e9 to 5afc98e on April 23, 2026 21:20
tianleiwu previously approved these changes Apr 23, 2026
Contributor

@tianleiwu tianleiwu left a comment


APPROVE

The implementation is sound, addresses the prior output_qk ordering concern (that thread is now resolved), and provides comprehensive test coverage across MEA decode, unfused softcap, the spec-correct softcap→mask ordering, and the NaN fix. No high-severity issues found.

Highlights

  • output_qk copy correctly moved before bias/softcap mutations — prior CHANGES_REQUESTED concern resolved.
  • MEA decode path has well-managed scratch buffer lifetimes; kv_is_bsnh tracking correctly propagates through GQA expansion.
  • mask_filter_value cap (-1e+30f) is correctly scoped to MEA only and the math is verified.
  • Test coverage is thorough: C++ and Python, MEA/unfused, prompt/decode, fp16/bf16/fp32, various mask dims.

Suggestions (non-blocking)

  1. Causal mask ordering with softcap: In the has_softcap && has_bias branch, ComputeSoftmax applies is_unidirectional causal masking after softcap. Strict ONNX spec folds causal into the mask before softcap. This matches the CPU reference and tests pass, but a TODO comment would help track future spec alignment.
  2. Defensive assertion: The has_softcap && has_bias branch bypasses use_raw_attention_mask handling. Safe today (ONNX domain converts all masks to data.attention_bias before unfused), but adding assert(!use_raw_attention_mask) guards against future contrib softcap use where mask_index could coexist.
  3. Bool mask test filters: test_attention_4d_attn_mask_bool_cuda and _4d_cuda are still excluded with "may work now" TODOs. Consider removing or filing a follow-up issue.
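
On suggestion 1, a rough numpy sketch of what folding the causal constraint into the additive attention bias could look like, so the bias already encodes causality when it is added to the logits; the shapes and the large-negative fill value are assumptions, not the ComputeSoftmax code:

import numpy as np

def fold_causal_into_bias(attn_bias, s_q, s_kv):
    # Mark "future" key positions, aligning the last query row with the last key column.
    future = np.triu(np.ones((s_q, s_kv), dtype=bool), k=s_kv - s_q + 1)
    bias = np.array(attn_bias, dtype=np.float32, copy=True)   # (..., s_q, s_kv)
    bias[..., future] = np.finfo(np.float32).min               # effectively -inf under softmax
    return bias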

Comment thread onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
@titaiwangms
Contributor Author

Non-blocking suggestions will be addressed in a follow-up.

Comment thread .agents/skills/cuda-attention-kernel-patterns/SKILL.md
@titaiwangms titaiwangms disabled auto-merge April 23, 2026 22:48
Comment thread onnxruntime/core/providers/cuda/llm/attention.cc Outdated
@titaiwangms
Contributor Author

Wait for #28198 to merge first.

@titaiwangms titaiwangms force-pushed the feature/mea-decode-support-v2 branch 3 times, most recently from 28e2ba9 to e8813fc on April 27, 2026 16:48
titaiwangms added a commit to titaiwangms/onnxruntime that referenced this pull request Apr 27, 2026
1. attention.cc: Replace ORT_ENFORCE for present_key/present_value with
   scratch buffer allocation when outputs are nullptr. MEA decode now
   works even when present outputs are not requested. Use ORT_RETURN_IF_NOT
   for user-facing validation (past_value, nonpad_kv_seqlen, head_size).
2. attention_impl.cu: Replace ORT_ENFORCE(total_elements > 0) with
   early return for zero elements, since q_sequence_length=0 is valid.

Per Copilot review on PR microsoft#27992.

Agent-signed-off: Developer (cbe67c8b) [claude-opus-4.6]
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms force-pushed the feature/mea-decode-support-v2 branch from e8813fc to be8557e on April 27, 2026 17:29
@titaiwangms titaiwangms changed the title from "ONNX Attention CUDA: MEA decode, unfused softcap, spec-correct ordering, and NaN fix" to "Fix output_qk corruption, softcap ordering, and test coverage for ONNX Attention op" on Apr 27, 2026
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 12 out of 12 changed files in this pull request and generated 4 comments.



Comment thread onnxruntime/core/providers/cuda/llm/attention.cc Outdated
Comment thread onnxruntime/core/providers/cuda/llm/attention.cc
Comment thread onnxruntime/core/providers/cuda/llm/attention.cc Outdated
Comment thread onnxruntime/contrib_ops/cuda/bert/attention_impl.cu Outdated
titaiwangms and others added 3 commits April 27, 2026 19:05
- Copy raw QK logits to output_qk BEFORE any softcap or bias mutations
  (CopyQK was after ApplySoftcap, corrupting output_qk with clamped values)
- Reorder unfused path: softcap → mask/bias → softmax (per onnx/onnx#7865)
  Old ordering (mask → softcap → softmax) leaked probability to masked positions
  because tanh(-inf/sc) = -sc (finite)
- Use gsl::narrow<int> for qk_size cast to catch overflow at runtime
- Add assert guard against raw attention masks in softcap+bias path
- Add softcap field to base AttentionParameters struct

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
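
A quick numeric illustration of the tanh(-inf/sc) point in the commit message above; the softcap value and logits are made up:

import numpy as np

softcap = 8.0
# Old (buggy) ordering: mask to -inf first, THEN softcap -> the masked logit becomes finite.
masked_old = softcap * np.tanh(-np.inf / softcap)            # tanh(-inf) = -1, so -8.0
# Corrected ordering: softcap first, THEN add the -inf mask -> stays fully masked.
masked_new = softcap * np.tanh(0.0 / softcap) + (-np.inf)    # -inf

real = softcap * np.tanh(20.0 / softcap)                     # an unmasked, softcapped logit
for masked in (masked_old, masked_new):
    logits = np.array([real, masked])
    p = np.exp(logits - logits.max())
    p /= p.sum()
    print(p[1])   # old ordering: small but nonzero leak; corrected ordering: exactly 0.0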
- Remove MEA rejection guard for softcap+mask: CUTLASS FMHA applies softcap
  before bias (fused in kernel tiles), matching corrected ONNX spec ordering
  (onnx/onnx#7865). No need to force unfused path for this case.
- Restore GQA+FP32 MEA guard lost during rebase: LaunchUngroup only has
  fp16/bf16 instantiations, so GQA with FP32 must skip MEA.
- Add #include <algorithm> for std::max
- Relocate kCutlassSafeMaskFilterValue constant to memory_efficient_attention.h
- Add MEA decode path support
- Fix stale softcap ordering comments throughout

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
C++ tests (attention_op_test.cc):
- Attention4DSoftCapOutputQkRawLogits: regression test for output_qk corruption
  using constant Q=K=1 with analytical expected value

Python tests (test_mha.py):
- 5 partial masking tests (fp32 prompt/decode, 2D/3D mask, longer seq)
- 2 ordering guard tests using poison-value technique (V=1000 at masked pos)
- Flash/MEA softcap coverage for MHA and GQA paths

Python tests (test_gqa.py):
- GQA unfused tests: large head size, causal, past key, softcap+mask, BSNH, FP32
- Flash GQA softcap tests with padding mask

Reference fix (common.py):
- Fix attention_ref() to apply softcap BEFORE bias/mask (per onnx/onnx#7865)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 12 out of 12 changed files in this pull request and generated 2 comments.



Comment thread .agents/skills/cuda-attention-kernel-patterns/SKILL.md Outdated
Comment thread .agents/skills/cuda-attention-kernel-patterns/SKILL.md Outdated
- Fix cmake/onnxruntime_python.cmake to recursively glob test_onnx_attention/
  subdirectory so pytest discovers tests in CI
- Add SKILL.md docs for cuda-attention-kernel-patterns and
  cuda-bfloat16-type-traits

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms force-pushed the feature/mea-decode-support-v2 branch from 36e52eb to 7b5ef7a on April 27, 2026 20:05
// Element-wise softcap: score = softcap * tanh(score / softcap)
// Applied after Q*K' GEMM, before softmax, to bound attention logits.
template <typename T>
__global__ void ApplySoftcapKernel(T* scores, float softcap_inv, float softcap, int64_t total_elements) {
Contributor

@tianleiwu tianleiwu Apr 27, 2026


Can we use the unfused kernel UnfusedGqaAttention instead of adding new kernels?
We can add output_qk support in that kernel if needed.

Contributor Author


It's used by the MHA route, though. Or are you suggesting we should generalize UnfusedGQA to cover MHA as well? (I don't want to extend the MHA kernel directly.)


Development

Successfully merging this pull request may close these issues.

  • ONNX Attention CUDA: Coverage Gaps in Runner Fallback Paths
  • Refactor all of Attention CUDA kernels to have separate kernels and shared functionality
