
[ONNX Attention] Follow-up: causal+softcap ordering, defensive assertion, and backend test filter cleanup #28215


Context

These are non-blocking suggestions from @tianleiwu's approval review on PR #27992. Tracked as follow-up work under the parent issue #27516.

Note: PR #28198 adds a new GQA unfused attention path with FP32 QK accumulation (gqa_unfused_attention.cu). Items 1 and 2 below may also need to be applied to that new path, which has its own softmax kernel handling causal masking and softcap.


Item 1 — Causal mask ordering with softcap

File: onnxruntime/contrib_ops/cuda/bert/attention_impl.cu (~line 844)
Review comment: #27992 (comment)
Also consider: The GQA unfused kernel in PR #28198 (gqa_unfused_attention.cu) has its own softmax implementation that applies causal masking. Verify whether the same ordering concern applies there.

In the has_softcap && has_bias branch, ComputeSoftmax applies is_unidirectional causal masking after softcap. Under a strict reading of the ONNX spec, the causal mask should be folded into the bias before softcap: a masked score becomes -inf, and softcap(-inf) = sc * tanh(-inf / sc) = -sc, which survives softmax with a small but nonzero probability. The current code instead computes softcap(QK) and then applies -inf, giving masked positions exactly zero probability.
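To make the ordering difference concrete, here is a minimal standalone sketch; the softcap helper and the constants are illustrative, not the kernel code:

```cpp
// Contrast the two orderings for a single raw QK score x at a causally
// masked position, with softcap value s > 0.
#include <cmath>
#include <cstdio>
#include <limits>

// softcap(x) = s * tanh(x / s), matching the definition under discussion.
float softcap(float x, float s) { return s * std::tanh(x / s); }

int main() {
  const float s = 30.0f;  // illustrative softcap value
  const float neg_inf = -std::numeric_limits<float>::infinity();
  const float x = 5.0f;   // raw QK score at a causally masked position

  // Strict spec reading: fold the causal mask into the bias BEFORE softcap.
  // tanh(-inf / s) == -1, so the masked score lands at exactly -s and keeps
  // a small but nonzero probability after softmax.
  const float spec_order = softcap(x + neg_inf, s);     // == -s

  // Current kernel behavior: softcap first, THEN mask to -inf, so the
  // masked position gets exactly zero probability after softmax.
  const float current_order = softcap(x, s) + neg_inf;  // == -inf

  std::printf("spec order: %f, current order: %f\n", spec_order, current_order);
  return 0;
}
```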

Current behavior matches attention_ref in the Python tests (which also applies the causal mask after softcap), so cross-EP consistency is maintained and all tests pass.

Action: Add a TODO comment for future spec alignment if this distinction matters for model accuracy.


Item 2 — Defensive assertion for raw attention mask

File: onnxruntime/contrib_ops/cuda/bert/attention_impl.cu (~line 830)
Review comment: #27992 (comment)
Also consider: The GQA unfused path in PR #28198 routes through RunGqaUnfusedAttention in attention.cc and UnfusedGqaAttention in group_query_attention_impl.cu. Verify that these paths also cannot receive raw attention masks unexpectedly.

The has_softcap && has_bias branch bypasses the use_raw_attention_mask and ComputeSoftmaxWithMask1D paths entirely. This is safe today because the ONNX domain converts all masks to data.attention_bias before calling QkvToContext. However, if the contrib MultiHeadAttention op ever adds softcap support (its mask_index can be set independently of attention_bias), this branch would silently drop the raw mask.

Action: Add assert(!use_raw_attention_mask) as a defensive guard against future misuse.
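A sketch of where such a guard could sit; the function below and its control flow are simplified stand-ins for illustration, not the actual QkvToContext code:

```cpp
#include <cassert>

// Hypothetical excerpt: the flag names mirror attention_impl.cu, but the
// function and its structure are simplified for illustration.
void DispatchSoftmax(bool has_softcap, bool has_bias, bool use_raw_attention_mask) {
  if (has_softcap && has_bias) {
    // This branch never consults the raw mask, so a raw attention mask that
    // reached it would be silently dropped. Fail loudly in debug builds.
    assert(!use_raw_attention_mask);
    // ... softcap + attention_bias softmax path ...
  } else if (use_raw_attention_mask) {
    // ... ComputeSoftmaxWithMask1D and related raw-mask paths ...
  } else {
    // ... remaining paths ...
  }
}
```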


Item 3 — Backend test filter cleanup

File: onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc (~line 50)
Review comment: #27992 (comment)

Bool mask tests (test_attention_4d_attn_mask_bool_cuda and _4d_cuda) are still excluded with "may work now" TODOs. With PR #27992's changes (bool masks now work via ConvertAttnMaskToBias in both the MEA and unfused paths), these exclusions may no longer be necessary.

Action: Investigate whether these tests pass with current changes. Remove exclusions if they pass, or document specific reasoning for keeping them.


Item 4 — Post-merge audit: comprehensive ONNX Attention CUDA coverage review

Trigger: After both PR #27992 and PR #28198 are merged to main.

Once both PRs land, the ONNX standard Attention op's CUDA implementation will have four dispatch paths: Flash → MEA → GQA Unfused → MHA Unfused. Conduct a comprehensive audit to catalog exactly what is supported, what is not, and what is tested.
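For orientation, a hypothetical sketch of the resulting cascade; the selector, eligibility predicates, and parameter struct below are illustrative stand-ins, not the actual attention.cc code:

```cpp
// Minimal stand-in for the real parameter struct; only the fields this
// sketch needs are shown.
struct AttentionParameters {
  int num_heads;
  int kv_num_heads;
  // ... dtype, head_size, SM arch, past_key presence, etc.
};

// Illustrative eligibility stubs; the real conditions are exactly what the
// audit should document per runner.
bool CanUseFlashAttention(const AttentionParameters&) { return false; }
bool CanUseMemoryEfficient(const AttentionParameters&) { return false; }
bool CanUseGqaUnfused(const AttentionParameters&) { return false; }

enum class AttentionRunner { kFlash, kMemoryEfficient, kGqaUnfused, kMhaUnfused };

AttentionRunner SelectRunner(const AttentionParameters& p) {
  if (CanUseFlashAttention(p)) return AttentionRunner::kFlash;   // fastest, most restrictive
  if (CanUseMemoryEfficient(p)) return AttentionRunner::kMemoryEfficient;
  if (p.num_heads != p.kv_num_heads && CanUseGqaUnfused(p))
    return AttentionRunner::kGqaUnfused;                         // new path from PR #28198
  return AttentionRunner::kMhaUnfused;                           // most general fallback
}
```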

Audit scope

  1. Dispatch path inventory: For each of the 4 runners (Flash, MEA, GQA Unfused, MHA Unfused), document:

    • Eligibility conditions (dtype, head_size, SM arch, GQA vs MHA, past_key presence, etc.)
    • Which feature combinations are supported (softcap, attn_mask types, causal, output_qk modes, nonpad_kv_seqlen, sliding window)
    • Which combinations are explicitly rejected and why
  2. Feature × path coverage matrix: Build a matrix of (feature combination) × (dispatch path) showing supported/unsupported/untested; one possible encoding is sketched after this list. Key features:

    • Dtypes: FP16, BF16, FP32
    • Head configs: MHA (q_heads == kv_heads), GQA (q_heads != kv_heads), asymmetric head_size != v_head_size
    • Sequence modes: prompt (no past), decode (with past_key/past_value), nonpad_kv_seqlen
    • Masks: no mask, 2D bool, 2D float, 3D, 4D
    • Softcap: off, on, on + mask, on + causal
    • Output modes: output only, output_qk (kNone, kQK, kQKSoftCap)
    • Head sizes: ≤128, 129-256, 257-512, >512
  3. Test coverage gap analysis: For each supported (feature × path) cell, verify there is at least one test covering it. Flag cells that are supported in code but have no test coverage.

  4. Update coverage tracking: Update the gap table in issue #27880 (ONNX Attention CUDA: Coverage Gaps in Runner Fallback Paths) with the post-merge state.
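One possible encoding for the matrix in item 2 above; these names are hypothetical and exist nowhere in the codebase, they only illustrate the shape of the data the audit should produce:

```cpp
#include <map>
#include <string>

enum class Status { kSupported, kUnsupported, kUntested };

struct Cell {
  Status code_status;  // what the runner's dispatch conditions accept
  bool has_test;       // whether at least one test exercises this cell
};

// matrix[feature_combination][dispatch_path] -> Cell
using CoverageMatrix = std::map<std::string, std::map<std::string, Cell>>;

// Example rows:
//   matrix["fp16 + softcap + causal"]["MEA"]         = {Status::kSupported, true};
//   matrix["fp16 + softcap + causal"]["GQA Unfused"] = {Status::kSupported, false};  // gap to flag
```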

References


Parent issue: #27516
Source PR: #27992
Related PR: #28198 (GQA unfused attention with FP32 QK accumulation)
