diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index 486bf05bd86d5..c54a1fea9ad3a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -197,10 +197,22 @@ struct GroupQueryAttentionData { bool use_memory_efficient_attention = false; bool use_flash_attention_fast_decode = false; bool use_xqa = false; + // GQA-capable unfused fallback (issue #28195): used when Flash/MEA/XQA are all ineligible, + // e.g. fp16 head_size > 256 with past_key, or GQA on old GPUs without MEA/Flash support. + bool use_unfused = false; // XQA buffer void* xqa_buffer = nullptr; size_t xqa_buffer_bytes = 0; + + // Unfused fallback buffers (see LaunchGqaUnfusedAttention in gqa_unfused_attention.h): + // unfused_q_bnsh : [B, N_q, S_q, H] (Q transposed from BSNH to BNSH) + // unfused_y_bnsh : [B, N_q, S_q, H_v] (output BNSH, transposed to BSNH before leaving op) + // unfused_workspace: FP32 QK scratch + T softmax scratch (sized by + // GetGqaUnfusedAttentionWorkspaceSize) + T* unfused_q_bnsh = nullptr; + T* unfused_y_bnsh = nullptr; + void* unfused_workspace = nullptr; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.cu b/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.cu new file mode 100644 index 0000000000000..7ca8516dc617e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.cu @@ -0,0 +1,417 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// GQA-capable unfused CUDA attention kernel. See header for contract. + +#include +#include +#include +#include +#include "core/common/safeint.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_type_conversion.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/bert/gqa_unfused_attention.h" + +using onnxruntime::cuda::OrtToCudaType; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +namespace { + +constexpr size_t kAlign = 256; + +inline SafeInt AlignTo(SafeInt a, size_t b) { return ((a + (b - 1)) / b) * b; } + +// Device helper: convert T to float. Specialised for __half and __nv_bfloat16 +// to keep conversions consistent with the rest of the codebase. +template +__device__ __forceinline__ float ToFloat(T v); +template <> +__device__ __forceinline__ float ToFloat(float v) { return v; } +template <> +__device__ __forceinline__ float ToFloat<__half>(__half v) { return __half2float(v); } +template <> +__device__ __forceinline__ float ToFloat<__nv_bfloat16>(__nv_bfloat16 v) { return __bfloat162float(v); } + +inline size_t QkElementCount(int batch_size, int num_heads, int q_seq, int total_kv) { + return SafeInt(batch_size) * num_heads * q_seq * total_kv; +} + +// --------------------------------------------------------------------------- +// Softmax kernel: reads FP32 QK scores, writes T softmax output. +// +// Applies (in this order): +// 1. scale: x = scale * qk +// 2. softcap (if > 0): x = softcap * tanh(x / softcap) +// 3. attn_bias (if provided): x += bias +// 4. mask (causal + sliding window + per-batch seqlens_k) +// 5. stable softmax across [start, end) for each row +// +// Uses 3-pass strided reads to avoid shared memory size limits for large +// total_kv_length. Handles fully-masked rows by emitting zeros (no NaN). 
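+// Worked example of the masking window (illustrative numbers only, not from any
+// particular model): total_kv_length = 8, q_sequence_length = 2, is_causal,
+// local_window_size = 3, seqlens_k == nullptr, q_in_head = 1:
+//   past  = kv_end - q_sequence_length        = 8 - 2 = 6
+//   q_pos = past + q_in_head                  = 6 + 1 = 7
+//   end   = min(kv_end, q_pos + 1)            = min(8, 8) = 8
+//   start = max(0, q_pos - local_window_size) = max(0, 7 - 3) = 4
+// so only KV positions [4, 8) contribute to that row's softmax.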
+// --------------------------------------------------------------------------- +template +__global__ void GqaUnfusedSoftmaxKernel( + const int q_sequence_length, + const int total_kv_length, + const int num_heads, // N_q + const float* __restrict__ qk_in, + const T* __restrict__ attn_bias, + const bool has_bias, + const bool bcast_bias_dim_0, + const bool bcast_bias_dim_1, + const int* __restrict__ seqlens_k, + const bool is_causal, + const int local_window_size, + const float scale, + const float softcap, + T* __restrict__ softmax_out) { + // Grid: (N_q * S_q, B, 1). Block: (TPB, 1, 1). + const int q_in_head = blockIdx.x % q_sequence_length; + const int head = blockIdx.x / q_sequence_length; // 0..N_q-1 + const int batch = blockIdx.y; + + int kv_end = total_kv_length; + if (seqlens_k != nullptr) { + int v = seqlens_k[batch]; + if (v < kv_end) kv_end = v; + if (v < 0) kv_end = 0; + } + // past (number of KV positions before the current query tokens) must be + // per-batch when seqlens_k is provided, since different batches can have + // different amounts of valid past context. Using the global total_kv_length + // would over-estimate past for short batches and shift the sliding-window + // start past kv_end, producing an all-masked (zero) row. + const int past = kv_end - q_sequence_length; + const int q_pos = past + q_in_head; + + int end = kv_end; + if (is_causal) { + const int c = q_pos + 1; + if (c < end) end = c; + } + int start = 0; + if (local_window_size >= 0) { + const int s = q_pos - local_window_size; + if (s > start) start = s; + } + if (end < 0) end = 0; + if (start > end) start = end; + + // Row offsets + const int64_t row_idx = (static_cast(batch) * gridDim.x) + blockIdx.x; + const int64_t row_offset = row_idx * total_kv_length; + + int64_t bias_row_offset = 0; + if (has_bias) { + const int b_eff = bcast_bias_dim_0 ? 0 : batch; + const int n_stride = bcast_bias_dim_1 ? 1 : num_heads; + const int h_eff = bcast_bias_dim_1 ? 0 : head; + bias_row_offset = ((static_cast(b_eff) * n_stride + h_eff) * q_sequence_length + q_in_head) * + static_cast(total_kv_length); + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp_storage; + __shared__ float s_max; + __shared__ float s_inv_sum; + + // Pass 1: compute max of masked values. + float thread_max = -CUDART_INF_F; + for (int i = threadIdx.x; i < total_kv_length; i += TPB) { + if (i < start || i >= end) continue; + float x = qk_in[row_offset + i] * scale; + if (softcap > 0.f) { + x = softcap * tanhf(x / softcap); + } + if (has_bias) { + x += ToFloat(attn_bias[bias_row_offset + i]); + } + if (x > thread_max) thread_max = x; + } + float block_max = BlockReduce(tmp_storage).Reduce(thread_max, cub::Max()); + if (threadIdx.x == 0) s_max = block_max; + __syncthreads(); + + // If the row is fully masked, emit zeros (match existing mask-of-zeros behavior). + if (s_max == -CUDART_INF_F) { + for (int i = threadIdx.x; i < total_kv_length; i += TPB) { + softmax_out[row_offset + i] = T(0.f); + } + return; + } + + // Pass 2: compute sum of exp. + float thread_sum = 0.f; + for (int i = threadIdx.x; i < total_kv_length; i += TPB) { + if (i < start || i >= end) continue; + float x = qk_in[row_offset + i] * scale; + if (softcap > 0.f) { + x = softcap * tanhf(x / softcap); + } + if (has_bias) { + x += ToFloat(attn_bias[bias_row_offset + i]); + } + thread_sum += expf(x - s_max); + } + float block_sum = BlockReduce(tmp_storage).Reduce(thread_sum, cub::Sum()); + if (threadIdx.x == 0) s_inv_sum = (block_sum > 0.f) ? 
(1.f / block_sum) : 0.f; + __syncthreads(); + + // Pass 3: write softmax output in type T. + for (int i = threadIdx.x; i < total_kv_length; i += TPB) { + float y = 0.f; + if (i >= start && i < end) { + float x = qk_in[row_offset + i] * scale; + if (softcap > 0.f) { + x = softcap * tanhf(x / softcap); + } + if (has_bias) { + x += ToFloat(attn_bias[bias_row_offset + i]); + } + y = expf(x - s_max) * s_inv_sum; + } + softmax_out[row_offset + i] = T(y); + } +} + +template +void LaunchGqaUnfusedSoftmax( + cudaStream_t stream, + const GqaUnfusedAttentionParams& params, + const float* qk_in, + const T* attn_bias, + T* softmax_out) { + const dim3 grid(params.num_heads * params.q_sequence_length, params.batch_size, 1); + const bool has_bias = (attn_bias != nullptr); + constexpr int TPB = 256; + GqaUnfusedSoftmaxKernel<<>>( + params.q_sequence_length, + params.total_kv_length, + params.num_heads, + qk_in, + attn_bias, + has_bias, + params.broadcast_attn_bias_dim_0, + params.broadcast_attn_bias_dim_1, + params.seqlens_k, + params.is_causal, + params.local_window_size, + params.scale, + params.softcap, + softmax_out); +} + +// --------------------------------------------------------------------------- +// QK GEMM: FP32 accumulate into FP32 scratch (fixes #28195 fp16 overflow). +// +// Reshape-Q trick for GQA: batch_count = B * N_kv, and within each batch the +// Q matrix is the concatenation of `group_size` Q heads. No K/V replication. +// +// Per-batch matrices (row-major view): +// Q sub-matrix: [group_size * S_q, H] (contiguous in memory: BNSH layout +// with heads in the same KV group +// being contiguous). +// K sub-matrix: [S_kv, H] +// C sub-matrix: [group_size * S_q, S_kv] +// +// cuBLAS is column-major: the row-major (M, K) matrix is column-major (K, M). +// We issue: C = op_A(A) * op_B(B) where +// A = K (col-major (H, S_kv), op_A = T) → M_cublas = S_kv, K_cublas = H +// B = Q (col-major (H, group_size*S_q), op_B = N) +// → N_cublas = group_size*S_q +// C (col-major (S_kv, group_size*S_q)) == row-major (group_size*S_q, S_kv). +// --------------------------------------------------------------------------- +template +cudaDataType CudaTypeFor(); +template <> +cudaDataType CudaTypeFor<__half>() { return CUDA_R_16F; } +template <> +cudaDataType CudaTypeFor<__nv_bfloat16>() { return CUDA_R_16BF; } +template <> +cudaDataType CudaTypeFor() { return CUDA_R_32F; } + +template +common::Status LaunchQkGemmFp32( + const cudaDeviceProp& /*device_prop*/, + cublasHandle_t cublas, + const GqaUnfusedAttentionParams& params, + const T* query, + const T* key, + float* qk_out) { + const int B = params.batch_size; + const int N_kv = params.kv_num_heads; + const int group = params.num_heads / params.kv_num_heads; + const int S_q = params.q_sequence_length; + const int S_kv = params.total_kv_length; + const int H = params.head_size; + + const float alpha = 1.0f; + const float beta = 0.0f; + + // Strides between (b, n_kv) blocks: + // Q is BNSH with heads grouped: element (b, n_kv, g, s_q, h) at offset + // ((b * N_kv + n_kv) * group + g) * S_q * H + s_q * H + h + // so stride per (b, n_kv) = group * S_q * H. + // K is BNSH: per (b, n_kv) = max_kv_length * H. + // C (fp32 scratch) is packed [B, N_q, S_q, S_kv]; per (b, n_kv) block is + // group * S_q * S_kv. 
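+  // Concrete illustration of the strides (hypothetical shapes, for orientation
+  // only): B = 1, N_q = 4, N_kv = 2 (group = 2), S_q = 2, H = 256,
+  // max_kv_length = S_kv = 8 gives
+  //   stride_q = group * S_q * H    = 2 * 2 * 256 = 1024 elements
+  //   stride_k = max_kv_length * H  = 8 * 256     = 2048 elements
+  //   stride_c = group * S_q * S_kv = 2 * 2 * 8   = 32 elements
+  // and batch_count = B * N_kv = 2, i.e. two row-major GEMMs of
+  // [group*S_q = 4, H = 256] x [H = 256, S_kv = 8] -> [4, 8].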
+ const int64_t stride_q = static_cast(group) * S_q * H; + const int64_t stride_k = static_cast(params.max_kv_length) * H; + const int64_t stride_c = static_cast(group) * S_q * S_kv; + + cudaDataType ab_type = CudaTypeFor(); + // compute + scale type is FP32 → no fp16 overflow of raw QK even at + // head_size=512, scale=1.0 (direct fix for issue #28195). + cublasStatus_t status = cublasGemmStridedBatchedEx( + cublas, + CUBLAS_OP_T, CUBLAS_OP_N, + /*m=*/S_kv, /*n=*/group * S_q, /*k=*/H, + &alpha, + /*A=*/key, ab_type, /*lda=*/H, stride_k, + /*B=*/query, ab_type, /*ldb=*/H, stride_q, + &beta, + /*C=*/qk_out, CUDA_R_32F, /*ldc=*/S_kv, stride_c, + /*batch_count=*/B * N_kv, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT); + + if (status != CUBLAS_STATUS_SUCCESS) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GqaUnfusedAttention QK GEMM failed: ", status); + } + return common::Status::OK(); +} + +// --------------------------------------------------------------------------- +// Attn * V GEMM: C = P * V, where P is the softmax output in type T. +// +// Per-batch matrices (row-major view, batch_count = B * N_kv): +// P sub-matrix: [group_size * S_q, S_kv] (packed, leading dim = total_kv) +// V sub-matrix: [S_kv, H_v] +// Y sub-matrix: [group_size * S_q, H_v] +// +// cuBLAS column-major: issue C = A * B with +// A = V (col-major (H_v, S_kv), op_A = N) → M = H_v, K = S_kv +// B = P (col-major (S_kv, group*S_q), op_B = N) → N = group * S_q +// --------------------------------------------------------------------------- +template +common::Status LaunchAttnVGemm( + cublasHandle_t cublas, + const GqaUnfusedAttentionParams& params, + const T* softmax_out, + const T* value, + T* output) { + const int B = params.batch_size; + const int N_kv = params.kv_num_heads; + const int group = params.num_heads / params.kv_num_heads; + const int S_q = params.q_sequence_length; + const int S_kv = params.total_kv_length; + const int H_v = params.v_head_size; + + const float alpha = 1.0f; + const float beta = 0.0f; + + const int64_t stride_v = static_cast(params.max_kv_length) * H_v; + const int64_t stride_p = static_cast(group) * S_q * S_kv; + const int64_t stride_y = static_cast(group) * S_q * H_v; + + // Use cublasGemmStridedBatchedEx directly with FP32 alpha/beta + FP32 compute. + // The helper has no __nv_bfloat16 overload and __half overload depends on + // global HalfGemmOptions; going direct gives deterministic behavior. + cudaDataType ab_type = CudaTypeFor(); + cublasStatus_t status = cublasGemmStridedBatchedEx( + cublas, + CUBLAS_OP_N, CUBLAS_OP_N, + /*m=*/H_v, /*n=*/group * S_q, /*k=*/S_kv, + &alpha, + /*A=*/value, ab_type, /*lda=*/H_v, stride_v, + /*B=*/softmax_out, ab_type, /*ldb=*/S_kv, stride_p, + &beta, + /*C=*/output, ab_type, /*ldc=*/H_v, stride_y, + /*batch_count=*/B * N_kv, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT); + if (status != CUBLAS_STATUS_SUCCESS) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GqaUnfusedAttention AV GEMM failed: ", status); + } + return common::Status::OK(); +} + +} // namespace + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- +size_t GetGqaUnfusedAttentionWorkspaceSize(int batch_size, + int num_heads, + int q_sequence_length, + int total_kv_length) { + const size_t elems = QkElementCount(batch_size, num_heads, q_sequence_length, total_kv_length); + // FP32 QK scratch + T softmax scratch. 
We always allocate sizeof(float) per + // element for the T scratch too (upper bound); caller can cast appropriately. + const size_t qk_bytes = AlignTo(SafeInt(elems) * sizeof(float), kAlign); + const size_t softmax_bytes = AlignTo(SafeInt(elems) * sizeof(float), kAlign); + return SafeInt(qk_bytes) + softmax_bytes; +} + +template +common::Status LaunchGqaUnfusedAttention( + const cudaDeviceProp& device_prop, + cublasHandle_t cublas, + cudaStream_t stream, + const GqaUnfusedAttentionParams& params, + const T* query, + const T* key, + const T* value, + const T* attn_bias, + T* output, + void* workspace) { + ORT_RETURN_IF_NOT(params.batch_size > 0 && params.num_heads > 0 && params.kv_num_heads > 0 && + params.head_size > 0 && params.v_head_size > 0 && + params.q_sequence_length > 0 && params.total_kv_length > 0 && + params.max_kv_length >= params.total_kv_length, + "GqaUnfusedAttention: invalid params."); + ORT_RETURN_IF_NOT(params.num_heads % params.kv_num_heads == 0, + "GqaUnfusedAttention: num_heads (", params.num_heads, + ") must be a multiple of kv_num_heads (", params.kv_num_heads, ")."); + ORT_RETURN_IF(workspace == nullptr, "GqaUnfusedAttention: workspace is null."); + + const size_t elems = QkElementCount(params.batch_size, params.num_heads, + params.q_sequence_length, params.total_kv_length); + const size_t qk_bytes = AlignTo(SafeInt(elems) * sizeof(float), kAlign); + + auto* qk_fp32 = reinterpret_cast(workspace); + auto* softmax_T = reinterpret_cast(reinterpret_cast(workspace) + qk_bytes); + + ORT_RETURN_IF_ERROR((LaunchQkGemmFp32(device_prop, cublas, params, query, key, qk_fp32))); + + LaunchGqaUnfusedSoftmax(stream, params, qk_fp32, attn_bias, softmax_T); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + + ORT_RETURN_IF_ERROR((LaunchAttnVGemm(cublas, params, softmax_T, value, output))); + + return common::Status::OK(); +} + +// Explicit template instantiations. +template common::Status LaunchGqaUnfusedAttention<__half>( + const cudaDeviceProp&, cublasHandle_t, cudaStream_t, + const GqaUnfusedAttentionParams&, const __half*, const __half*, const __half*, + const __half*, __half*, void*); +template common::Status LaunchGqaUnfusedAttention<__nv_bfloat16>( + const cudaDeviceProp&, cublasHandle_t, cudaStream_t, + const GqaUnfusedAttentionParams&, const __nv_bfloat16*, const __nv_bfloat16*, + const __nv_bfloat16*, const __nv_bfloat16*, __nv_bfloat16*, void*); +template common::Status LaunchGqaUnfusedAttention( + const cudaDeviceProp&, cublasHandle_t, cudaStream_t, + const GqaUnfusedAttentionParams&, const float*, const float*, const float*, + const float*, float*, void*); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.h b/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.h new file mode 100644 index 0000000000000..84d645cd2b349 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.h @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include "core/common/status.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// ============================================================================ +// GQA Unfused Attention (CUDA fallback for large head_size / fp16 overflow) +// ============================================================================ +// +// Purpose: +// A numerically-stable unfused attention kernel that handles: +// - Group-Query Attention natively (num_heads != kv_num_heads) via a +// reshape-Q trick (no K/V head-replication) — works with MHA too. +// - head_size > 256 in fp16/bf16 (writes QK scores into a FP32 scratch so +// raw Q*K^T cannot overflow fp16 even when scale=1.0 — see issue #28195). +// - Different Q and K sequence lengths (prompt with/without past). +// - Causal mask, optional sliding-window mask, optional softcap, optional +// additive attention bias, per-batch variable k-sequence lengths. +// +// Input layout contract: +// Q : [B, N_q, S_q, H] BNSH, contiguous. N_q must be a multiple of +// N_kv; heads within a KV group must be contiguous +// (i.e. [B, N_kv, group_size, S_q, H]). +// K cache : [B, N_kv, max_S_kv, H] BNSH. Valid data is [..., 0:total_kv, :]. +// V cache : [B, N_kv, max_S_kv, H_v] BNSH. Valid data is [..., 0:total_kv, :]. +// Output : [B, N_q, S_q, H_v] BNSH, contiguous. +// +// Mask/softcap/scale semantics: +// - scale is applied to raw QK (before softcap / bias). +// - softcap (> 0) is applied after scale: x = softcap * tanh(x / softcap). +// - attn_bias (if non-null) is added after softcap (additive mask). +// - causal: k > (past + q) is -inf where past = total_kv - S_q. +// - local_window_size (>= 0): k < (past + q) - local_window_size is -inf. +// local_window_size == -1 disables the sliding-window mask. +// +// The new kernel is suitable only as a fallback when Flash / MEA are ineligible +// (head_size > 256, past_key present with mask, GQA with MHA-only unfused, etc). +// The QK GEMM runs with CUBLAS_COMPUTE_32F and writes a FP32 scratch to avoid +// fp16 overflow. +// +// ============================================================================ + +struct GqaUnfusedAttentionParams { + int batch_size = 0; + int num_heads = 0; // N_q + int kv_num_heads = 0; // N_kv (num_heads % kv_num_heads == 0) + int head_size = 0; // H + int v_head_size = 0; // H_v (usually == H) + + int q_sequence_length = 0; // S_q + int total_kv_length = 0; // total valid K/V positions (past + new) + int max_kv_length = 0; // K/V buffer allocated length for stride (>= total_kv_length) + + // attn_bias (optional): shape [B or 1, N_q or 1, S_q, total_kv_length] (row-major). + // When broadcast_dim_0 is true, batch axis is broadcast (shape[0]==1). + // When broadcast_dim_1 is true, head axis is broadcast (shape[1]==1). + bool broadcast_attn_bias_dim_0 = false; + bool broadcast_attn_bias_dim_1 = false; + + bool is_causal = false; + int local_window_size = -1; // -1 disables sliding window + float scale = 1.0f; + float softcap = 0.0f; // 0 disables + + // Per-batch K lengths (optional). When non-null, positions k >= seqlens_k[b] + // are masked out (useful for right-padded packed batches). + const int* seqlens_k = nullptr; +}; + +// Returns required scratch size in bytes. Caller must allocate +// GetGqaUnfusedAttentionWorkspaceSize(...) bytes and pass as workspace. 
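+// Rough sizing example (assumed shapes, illustration only): batch_size = 2,
+// num_heads = 4, q_sequence_length = 2, total_kv_length = 4 gives
+// elems = 2 * 4 * 2 * 4 = 64; with the current sizing (both the FP32 QK scratch
+// and the T softmax scratch counted at sizeof(float) per element, 256-byte
+// aligned) each scratch is 256 bytes, so the function returns 512 bytes.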
+size_t GetGqaUnfusedAttentionWorkspaceSize(int batch_size, + int num_heads, + int q_sequence_length, + int total_kv_length); + +// Compute: Y = softmax(scale * Q * K^T [softcap, causal, window, bias, seqlens_k]) * V. +// All pointers are on device. Q/K/V/output are in type T (fp16/bf16/float). +// attn_bias (if present) is in type T. +template +common::Status LaunchGqaUnfusedAttention( + const cudaDeviceProp& device_prop, + cublasHandle_t cublas, + cudaStream_t stream, + const GqaUnfusedAttentionParams& params, + const T* query, + const T* key, + const T* value, + const T* attn_bias, + T* output, + void* workspace); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 3b6b5f9079ebe..5f21f3cd34e8f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -4,6 +4,7 @@ #include #include #include +#include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cuda_type_conversion.h" #include "core/platform/env_var_utils.h" @@ -13,6 +14,7 @@ #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/bert/xqa/xqa_loader.h" +#include "contrib_ops/cuda/bert/gqa_unfused_attention.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" #include "contrib_ops/cpu/utils/debug_macros.h" @@ -507,6 +509,47 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons data.qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get()); } + // --------------------------------------------------------------------- + // GQA-capable unfused fallback (issue #28195). + // Activates when Flash / MEA / XQA are all ineligible and KV is not quantized. + // Supports any head_size (FP32 QK accumulation), GQA, sliding window, softcap. + // See LaunchGqaUnfusedAttention in contrib_ops/cuda/bert/gqa_unfused_attention.h. + // --------------------------------------------------------------------- + IAllocatorUniquePtr unfused_scratch; + if (!data.use_xqa && !data.use_flash_attention && !data.use_memory_efficient_attention && + !is_inputs_quantized && !parameters.use_smooth_softmax && head_sink == nullptr && + parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH) { + data.use_unfused = true; + + const size_t B = static_cast(parameters.batch_size); + const size_t N_q = static_cast(parameters.num_heads); + const size_t S_q = static_cast(parameters.sequence_length); + const size_t H = static_cast(parameters.head_size); + // GQA guarantees head_size == v_head_size; use H_v for the Y output buffer + // so the allocation stays correct if a distinct v_head_size is ever exposed. + const size_t H_v = (parameters.v_head_size > 0) + ? 
static_cast(parameters.v_head_size) + : H; + const size_t S_kv = static_cast(parameters.total_sequence_length); + + auto align = [](SafeInt v) -> SafeInt { + return ((v + SafeInt(255)) / SafeInt(256)) * SafeInt(256); + }; + const SafeInt q_bnsh_bytes = align(SafeInt(B) * N_q * S_q * H * sizeof(T)); + const SafeInt y_bnsh_bytes = align(SafeInt(B) * N_q * S_q * H_v * sizeof(T)); + const SafeInt ws_bytes = SafeInt( + onnxruntime::contrib::cuda::GetGqaUnfusedAttentionWorkspaceSize( + static_cast(B), static_cast(N_q), static_cast(S_q), static_cast(S_kv))); + const SafeInt workspace_offset = q_bnsh_bytes + y_bnsh_bytes; + + unfused_scratch = GetScratchBuffer(static_cast(q_bnsh_bytes + y_bnsh_bytes + ws_bytes), + GetComputeStream(context)); + auto* base = reinterpret_cast(unfused_scratch.get()); + data.unfused_q_bnsh = reinterpret_cast(base); + data.unfused_y_bnsh = reinterpret_cast(base + static_cast(q_bnsh_bytes)); + data.unfused_workspace = reinterpret_cast(base + static_cast(workspace_offset)); + } + if (kernel_options_->AllowDebugInfo()) { AttentionKernelDebugInfo debug_info; debug_info.use_flash_attention = data.use_flash_attention; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index c617de747ccf7..ebb6a0b0da215 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -38,6 +38,7 @@ limitations under the License. #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cuda/bert/gqa_unfused_attention.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cpu/bert/attention_common.h" #include "contrib_ops/cuda/bert/group_query_attention_qkv.cuh" @@ -1030,12 +1031,123 @@ Status EfficientAttention( } #endif +// ============================================================================ +// UnfusedGqaAttention: fallback path that handles GQA natively and fixes the +// fp16 head_size > 256 NaN (issue #28195). +// +// Dispatched when Flash / MEA / XQA are all ineligible. Supports: +// - Any head_size up to H (FP32 QK accumulation avoids fp16 overflow). +// - GQA (num_heads != kv_num_heads) via reshape-Q trick in the GEMM. +// - Different Q / K sequence lengths (first prompt or decode with past). +// - Causal, sliding window (local_window_size), softcap, per-batch seqlens. +// +// Not supported (caller falls through elsewhere): +// - Quantized KV cache (U != T): hit by the original NOT_IMPLEMENTED path. +// - attention_bias input: rejected by op-level ComputeInternal. +// - Smooth softmax / head_sink: Flash-only feature. 
+// ============================================================================ +template +Status UnfusedGqaAttention( + const cudaDeviceProp& device_prop, + cublasHandle_t cublas, + cudaStream_t stream, + GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { + static_assert(std::is_same::value, + "UnfusedGqaAttention requires non-quantized KV cache (T == U)."); + + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + const int max_kv = parameters.seqlen_present_kv_cache; + + ORT_RETURN_IF(data.unfused_q_bnsh == nullptr || data.unfused_y_bnsh == nullptr || + data.unfused_workspace == nullptr, + "Unfused GQA scratch buffers are not allocated."); + ORT_RETURN_IF(parameters.past_kv_format != AttentionQkvFormat::Q_K_V_BNSH, + "Unfused GQA fallback requires BNSH KV cache layout."); + + ORT_GQA_TRACE("UnfusedGqaAttention"); + + // Step 1: unpack Q (optionally RoPE), append new K/V into present_key/value (BNSH). + const T* q_prep = nullptr; + ORT_RETURN_IF_ERROR((PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep))); + + // Step 2: transpose Q from BSNH (PrepareQKV output) to BNSH. + // Transpose_BSNH_to_BNSH has overloads for half/BFloat16/float; bridge via reinterpret_cast. + // GQA only registers half and bf16 types; guard against accidental float instantiation. + static_assert(std::is_same::value || std::is_same::value, + "UnfusedGqaAttention transpose only supports __half and __nv_bfloat16."); + if constexpr (std::is_same::value) { + ORT_RETURN_IF_ERROR((Transpose_BSNH_to_BNSH(batch_size, sequence_length, num_heads, head_size, + reinterpret_cast(q_prep), + reinterpret_cast(data.unfused_q_bnsh), + stream, max_threads_per_block))); + } else if constexpr (std::is_same::value) { + ORT_RETURN_IF_ERROR((Transpose_BSNH_to_BNSH(batch_size, sequence_length, num_heads, head_size, + reinterpret_cast(q_prep), + reinterpret_cast(data.unfused_q_bnsh), + stream, max_threads_per_block))); + } + + // Step 3: run unfused attention with FP32 QK accumulation. + GqaUnfusedAttentionParams p; + p.batch_size = batch_size; + p.num_heads = num_heads; + p.kv_num_heads = kv_num_heads; + p.head_size = head_size; + ORT_ENFORCE(head_size == parameters.v_head_size || parameters.v_head_size == 0, + "UnfusedGqaAttention requires head_size == v_head_size"); + p.v_head_size = head_size; // GQA op has head_size == v_head_size + p.q_sequence_length = sequence_length; + // For the decode/prompt, data.total_seq_lens[b] <= seqlen_present_kv_cache. + // Use seqlen_present_kv_cache as the upper bound for the GEMM and pass per-batch + // seqlens to the softmax so positions beyond the valid length are masked. 
+ p.total_kv_length = parameters.total_sequence_length; + p.max_kv_length = max_kv; + p.broadcast_attn_bias_dim_0 = false; + p.broadcast_attn_bias_dim_1 = false; + p.is_causal = parameters.is_unidirectional; + p.local_window_size = parameters.local_window_size; // -1 disables + p.scale = scale; + p.softcap = parameters.softcap; + p.seqlens_k = data.total_seq_lens; + + ORT_RETURN_IF_ERROR((LaunchGqaUnfusedAttention( + device_prop, cublas, stream, p, + data.unfused_q_bnsh, + reinterpret_cast(data.present_key), + reinterpret_cast(data.present_value), + /*attn_bias=*/nullptr, + data.unfused_y_bnsh, + data.unfused_workspace))); + + // Step 4: transpose output BNSH → BSNH into data.output. + // Use p.v_head_size (== head_size per ORT_ENFORCE) for semantic correctness. + if constexpr (std::is_same::value) { + ORT_RETURN_IF_ERROR((Transpose_BNSH_to_BSNH(batch_size, sequence_length, num_heads, p.v_head_size, + reinterpret_cast(data.unfused_y_bnsh), + reinterpret_cast(data.output), + stream, max_threads_per_block))); + } else if constexpr (std::is_same::value) { + ORT_RETURN_IF_ERROR((Transpose_BNSH_to_BSNH(batch_size, sequence_length, num_heads, p.v_head_size, + reinterpret_cast(data.unfused_y_bnsh), + reinterpret_cast(data.output), + stream, max_threads_per_block))); + } + return Status::OK(); +} + ////////// API Functions template Status QkvToContext( const cudaDeviceProp& device_prop, - cublasHandle_t& /*cublas*/, + cublasHandle_t& cublas, Stream* ort_stream, GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data) { @@ -1069,6 +1181,15 @@ Status QkvToContext( } #endif + if (data.use_unfused) { + if constexpr (std::is_same::value) { + return UnfusedGqaAttention(device_prop, cublas, stream, parameters, data, scale); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "Unfused GQA fallback does not support quantized KV cache."); + } + } + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Group Query Attention not implemented yet."); } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh index 1fd5713f9407a..5fb36e094482b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh @@ -292,8 +292,13 @@ Status DispatchUnpackRoPEAppendHeadSize( packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + } else if (head_size <= 512) { + UnpackRoPEAppend<<>>( + packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, + num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size (", head_size, ") exceeds maximum supported MAX_HEAD_SIZE (256)."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size (", head_size, ") exceeds maximum supported MAX_HEAD_SIZE (512)."); } return CUDA_CALL(cudaGetLastError()); } diff --git a/onnxruntime/core/providers/cuda/llm/attention.cc b/onnxruntime/core/providers/cuda/llm/attention.cc index 0c2a646278e65..228729745b65b 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.cc +++ b/onnxruntime/core/providers/cuda/llm/attention.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. 
All rights reserved. // Licensed under the MIT License. +#include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cpu/llm/attention.h" #include "core/providers/cpu/llm/attention_helper.h" @@ -10,6 +11,7 @@ #include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/attention_kv_cache.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" +#include "contrib_ops/cuda/bert/gqa_unfused_attention.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "core/providers/cuda/cuda_type_conversion.h" @@ -1038,13 +1040,249 @@ Status Attention::RunUnfusedAttention( device_prop, cublas, cudnn, ort_stream.get(), contribop_parameters, data); } +// ============================================================================ +// RunGqaUnfusedAttention: GQA-capable unfused path + large-head fp16/bf16 fix +// ============================================================================ +// +// Routes to LaunchGqaUnfusedAttention from contrib_ops/cuda/bert/gqa_unfused_attention.h. +// +// Handles: +// - GQA natively (no K/V head replication; reshape-Q trick inside kernel). +// - fp16/bf16 with large head_size via FP32 QK scratch (fixes issue #28195: +// unfused attention producing NaN when head_dim > 256 at scale=1.0). +// - Different Q/K sequence lengths, past_key+past_value, nonpad_kv_seqlen. +// - attn_mask (bool/float, 2D/3D/4D), causal, softcap. +// +// Not supported here (caller rejects upstream): +// - output_qk: only MHA unfused emits QK, so this path requires output_qk==nullptr. +// ============================================================================ +template +Status Attention::RunGqaUnfusedAttention( + OpKernelContext* context, + const Tensor* Q, const Tensor* K, const Tensor* V, + const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, + const Tensor* nonpad_kv_seqlen, + Tensor* Y, Tensor* present_key, Tensor* present_value, + const attention_helper::AttentionParameters& parameters) const { + using NativeCudaT = typename onnxruntime::cuda::OrtToCudaType::type; + auto& device_prop = GetDeviceProp(); + auto cuda_stream = Stream(context); + const bool is_bsnh = parameters.transpose_output; + const int B = parameters.batch_size; + const int S_q = parameters.q_sequence_length; + const int N_q = parameters.q_num_heads; + const int N_kv = parameters.kv_num_heads; + const int H = parameters.head_size; + const int H_v = parameters.v_head_size; + const int total_kv = parameters.total_sequence_length; + const int max_threads = device_prop.maxThreadsPerBlock; + + // -------- Build BNSH Q (transpose if input was BSNH) ------------------------ + const NativeCudaT* q_bnsh = nullptr; + IAllocatorUniquePtr q_bnsh_buffer; + if (is_bsnh) { + const size_t q_bytes = SafeInt(B) * S_q * N_q * H * sizeof(T); + q_bnsh_buffer = GetScratchBuffer(q_bytes, GetComputeStream(context)); + ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(B, S_q, N_q, H, + Q->Data(), q_bnsh_buffer.get(), + cuda_stream, max_threads)); + q_bnsh = reinterpret_cast(q_bnsh_buffer.get()); + } else { + q_bnsh = reinterpret_cast(Q->Data()); + } + + // -------- Build BNSH K/V cache of length total_kv -------------------------- + // Three cases: + // (a) nonpad_kv_seqlen: K/V are the full cache (kv_seq == total_kv). + // (b) past_key + new K/V: concat via LaunchConcatNewToPastKV into present buffers. 
+ // (c) no past: K/V are the new tokens only (total_kv == kv_sequence_length). + // In cases (a) and (c) the cache is contiguous in the input tensors (subject + // to a BSNH->BNSH transpose). Case (b) writes into present_key/present_value. + const NativeCudaT* k_cache = nullptr; + const NativeCudaT* v_cache = nullptr; + IAllocatorUniquePtr k_bnsh_buffer; + IAllocatorUniquePtr v_bnsh_buffer; + bool present_already_populated = false; + + if (past_key != nullptr) { + ORT_ENFORCE(past_value != nullptr, "past_key requires past_value."); + ORT_ENFORCE(present_key != nullptr && present_value != nullptr, + "present_key/value outputs are required when past_key is provided."); + // LaunchConcatNewToPastKV uses a single head_size for both K and V caches. + ORT_RETURN_IF(H != H_v, + "RunGqaUnfusedAttention: past_key with H != H_v not supported"); + auto past_seqlens_buffer = GetScratchBuffer(B, GetComputeStream(context)); + ORT_RETURN_IF_ERROR(LaunchFillInt32(past_seqlens_buffer.get(), + parameters.past_sequence_length, B, + cuda_stream, max_threads)); + + // New K/V must be BSNH for the concat kernel; transpose if 4D BNSH input. + const T* k_new_bsnh = K->Data(); + const T* v_new_bsnh = V->Data(); + if (!is_bsnh) { + const size_t kn_bytes = SafeInt(B) * parameters.kv_sequence_length * N_kv * H * sizeof(T); + const size_t vn_bytes = SafeInt(B) * parameters.kv_sequence_length * N_kv * H_v * sizeof(T); + k_bnsh_buffer = GetScratchBuffer(kn_bytes, GetComputeStream(context)); + v_bnsh_buffer = GetScratchBuffer(vn_bytes, GetComputeStream(context)); + ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH(B, parameters.kv_sequence_length, N_kv, H, + K->Data(), k_bnsh_buffer.get(), + cuda_stream, max_threads)); + ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH(B, parameters.kv_sequence_length, N_kv, H_v, + V->Data(), v_bnsh_buffer.get(), + cuda_stream, max_threads)); + k_new_bsnh = static_cast(k_bnsh_buffer.get()); + v_new_bsnh = static_cast(v_bnsh_buffer.get()); + } + + ORT_RETURN_IF_ERROR(onnxruntime::contrib::cuda::LaunchConcatNewToPastKV( + B, N_kv, H, parameters.kv_sequence_length, parameters.past_sequence_length, total_kv, + /*is_bsnh=*/false, + past_seqlens_buffer.get(), /*total_seq_lens=*/nullptr, + reinterpret_cast(past_key->Data()), + reinterpret_cast(past_value->Data()), + reinterpret_cast(k_new_bsnh), + reinterpret_cast(v_new_bsnh), + reinterpret_cast(present_key->MutableData()), + reinterpret_cast(present_value->MutableData()), + cuda_stream, max_threads, /*past_only=*/false)); + k_cache = reinterpret_cast(present_key->MutableData()); + v_cache = reinterpret_cast(present_value->MutableData()); + present_already_populated = true; + } else if (is_bsnh) { + // BSNH K/V -> BNSH. total_kv == kv_sequence_length (no past). + // When present_key/present_value outputs exist, transpose directly into them + // to avoid a redundant copy later. 
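+    // For reference, the layout change performed here (illustrative index
+    // math): a BSNH element at ((b * S + s) * N + n) * H + h lands at BNSH
+    // offset ((b * N + n) * S + s) * H + h, so each (b, n) slice becomes a
+    // contiguous [S, H] matrix, which is the per-head K/V layout that
+    // LaunchGqaUnfusedAttention expects.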
+ if (present_key != nullptr && present_value != nullptr) { + ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(B, total_kv, N_kv, H, + K->Data(), present_key->MutableData(), + cuda_stream, max_threads)); + ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(B, total_kv, N_kv, H_v, + V->Data(), present_value->MutableData(), + cuda_stream, max_threads)); + k_cache = reinterpret_cast(present_key->Data()); + v_cache = reinterpret_cast(present_value->Data()); + present_already_populated = true; + } else { + const size_t k_bytes = SafeInt(B) * total_kv * N_kv * H * sizeof(T); + const size_t v_bytes = SafeInt(B) * total_kv * N_kv * H_v * sizeof(T); + k_bnsh_buffer = GetScratchBuffer(k_bytes, GetComputeStream(context)); + v_bnsh_buffer = GetScratchBuffer(v_bytes, GetComputeStream(context)); + ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(B, total_kv, N_kv, H, + K->Data(), k_bnsh_buffer.get(), + cuda_stream, max_threads)); + ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(B, total_kv, N_kv, H_v, + V->Data(), v_bnsh_buffer.get(), + cuda_stream, max_threads)); + k_cache = reinterpret_cast(k_bnsh_buffer.get()); + v_cache = reinterpret_cast(v_bnsh_buffer.get()); + } + } else { + // 4D BNSH input, no past: use directly. + k_cache = reinterpret_cast(K->Data()); + v_cache = reinterpret_cast(V->Data()); + } + + // -------- Build per-batch seqlens (for nonpad_kv_seqlen) -------------------- + const int* seqlens_k_ptr = nullptr; + IAllocatorUniquePtr seqlens_k_buffer; + if (nonpad_kv_seqlen != nullptr) { + seqlens_k_buffer = GetScratchBuffer(B, GetComputeStream(context)); + ORT_RETURN_IF_ERROR(LaunchConvertNonpadKvSeqlenToFlashSeqlensK( + nonpad_kv_seqlen->Data(), seqlens_k_buffer.get(), + B, total_kv, cuda_stream, max_threads)); + seqlens_k_ptr = seqlens_k_buffer.get(); + } + + // -------- Build attn_bias from attn_mask ------------------------------------ + IAllocatorUniquePtr mask_bias_buffer; + const NativeCudaT* attn_bias_data = nullptr; + bool bcast0 = false, bcast1 = false; + if (attn_mask != nullptr) { + const void* bias_void = nullptr; + ORT_RETURN_IF_ERROR(ConvertAttnMaskToBias(context, attn_mask, cuda_stream, max_threads, + mask_bias_buffer, bias_void, bcast0, bcast1)); + attn_bias_data = reinterpret_cast(bias_void); + } + + // -------- Allocate output BNSH scratch (if 3D BSNH output needed) ---------- + NativeCudaT* out_bnsh = reinterpret_cast(Y->MutableData()); + IAllocatorUniquePtr out_bnsh_buffer; + if (is_bsnh) { + const size_t out_bytes = SafeInt(B) * S_q * N_q * H_v * sizeof(T); + out_bnsh_buffer = GetScratchBuffer(out_bytes, GetComputeStream(context)); + out_bnsh = reinterpret_cast(out_bnsh_buffer.get()); + } + + // -------- Allocate kernel workspace ----------------------------------------- + const size_t ws_bytes = onnxruntime::contrib::cuda::GetGqaUnfusedAttentionWorkspaceSize( + B, N_q, S_q, total_kv); + auto ws_buffer = GetScratchBuffer(ws_bytes, GetComputeStream(context)); + + // -------- Call the kernel --------------------------------------------------- + onnxruntime::contrib::cuda::GqaUnfusedAttentionParams p; + p.batch_size = B; + p.num_heads = N_q; + p.kv_num_heads = N_kv; + p.head_size = H; + p.v_head_size = H_v; + p.q_sequence_length = S_q; + p.total_kv_length = total_kv; + p.max_kv_length = total_kv; // ONNX Attention caches are packed (no shared buffer). + p.broadcast_attn_bias_dim_0 = bcast0; + p.broadcast_attn_bias_dim_1 = bcast1; + p.is_causal = parameters.is_causal; + p.local_window_size = -1; // ONNX Attention (opset 23/24) does not expose sliding window. 
+ p.scale = parameters.scale; + p.softcap = parameters.softcap; + p.seqlens_k = seqlens_k_ptr; + + ORT_RETURN_IF_ERROR((onnxruntime::contrib::cuda::LaunchGqaUnfusedAttention( + device_prop, GetCublasHandle(context), cuda_stream, + p, q_bnsh, k_cache, v_cache, attn_bias_data, out_bnsh, ws_buffer.get()))); + + // -------- Transpose output BNSH -> BSNH if input was 3D -------------------- + if (is_bsnh && out_bnsh_buffer != nullptr) { + ORT_RETURN_IF_ERROR(TransposeBNSHtoBSNH(B, S_q, N_q, H_v, + out_bnsh_buffer.get(), Y->MutableData(), + cuda_stream, max_threads)); + } + + // -------- Populate present_key/present_value if requested ------------------ + if (!present_already_populated) { + if (present_key != nullptr) { + if (is_bsnh) { + ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(B, parameters.kv_sequence_length, N_kv, H, + K->Data(), present_key->MutableData(), + cuda_stream, max_threads)); + } else { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( + present_key->MutableData(), K->Data(), + K->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream)); + } + } + if (present_value != nullptr) { + if (is_bsnh) { + ORT_RETURN_IF_ERROR(TransposeBSNHtoBNSH(B, parameters.kv_sequence_length, N_kv, H_v, + V->Data(), present_value->MutableData(), + cuda_stream, max_threads)); + } else { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync( + present_value->MutableData(), V->Data(), + V->SizeInBytes(), cudaMemcpyDeviceToDevice, cuda_stream)); + } + } + } + + return Status::OK(); +} + // ============================================================================ // ComputeInternal: Dispatch to appropriate attention kernel // ============================================================================ // MHA path (q_num_heads == kv_num_heads): uses direct kernel dispatch cascade // flash → memory efficient → unfused -// GQA path (q_num_heads != kv_num_heads): uses flash (handles GQA natively) or MEA -// (with head expansion via LaunchUngroup). Unfused fallback not yet supported for GQA. +// GQA path (q_num_heads != kv_num_heads): uses flash (handles GQA natively), MEA +// (with head expansion via LaunchUngroup, fp16/bf16 only), or GQA unfused fallback. // ============================================================================ template Status Attention::ComputeInternal(OpKernelContext* context) const { @@ -1145,7 +1383,10 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { sm, std::is_same::value, std::is_same::value, parameters.head_size, parameters.v_head_size) && !has_output_qk && - past_key == nullptr; + past_key == nullptr && + // GQA+MEA requires LaunchUngroup which only has fp16/bf16 instantiations. + // FP32 GQA must fall through to the unfused path. + !(is_gqa && std::is_same::value); // Cutlass FMHA requires bias strides to satisfy minimum alignment even in the // "unaligned" kernel path. When an attention mask is present (with or without @@ -1173,15 +1414,6 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { } #endif - // Fallback: unfused attention - // Softcap is not implemented in the unfused path — it requires Flash or MEA. - if (parameters.softcap > 0.0f) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "softcap requires flash attention or memory efficient attention, " - "but neither is eligible for this configuration. Check dtype (fp16/bf16 required for Flash), " - "head_size constraints, and past_key compatibility."); - } - // TODO(titaiwang): Support additional output_qk modes beyond kNone and kQK. 
// Currently only unfused handles output_qk, and only kNone/kQK modes. if (qk_matmul_output_mode_ != attention_helper::QKMatMulOutputMode::kNone && @@ -1191,16 +1423,46 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { "in Attention op (CUDA)."); } + // GQA-capable unfused fallback (issue #28195). + // Routes through LaunchGqaUnfusedAttention when: + // - GQA (q_num_heads != kv_num_heads) — the MHA unfused runner cannot handle this. + // - fp16/bf16 with head_size > 128 — raw Q*K^T can overflow fp16 storage even + // though cuBLAS accumulates in FP32; the new kernel writes QK to an FP32 scratch. + // The overflow threshold depends on the distribution of Q/K values and scale. + // head_size=256 at scale=1/sqrt(256)=0.0625 is borderline; head_size=512 at + // scale=1.0 (Gemma 4) definitely overflows. We use 128 as a conservative + // threshold since all fused kernels already handle head_size <= 128 anyway. + // This kernel supports softcap. It does not support output_qk, so we only enter it + // when qk_matmul_output_mode_ == kNone. + const bool is_half_or_bf16 = std::is_same::value || std::is_same::value; + const bool needs_fp32_qk_scratch = is_half_or_bf16 && parameters.head_size > 128; + if ((is_gqa || needs_fp32_qk_scratch) && + qk_matmul_output_mode_ == attention_helper::QKMatMulOutputMode::kNone) { + LOGS_DEFAULT(VERBOSE) << "Attention: using GQA unfused fallback (is_gqa=" << is_gqa + << ", needs_fp32_qk_scratch=" << needs_fp32_qk_scratch + << ", head_size=" << parameters.head_size + << ", softcap=" << parameters.softcap << ")"; + return RunGqaUnfusedAttention(context, Q, K, V, attn_mask, past_key, past_value, + nonpad_kv_seqlen, Y, present_key, present_value, parameters); + } + if (is_gqa) { - // TODO(titaiwang): Support GQA in unfused attention path for fp32/old-GPU fallback. - // Currently blocked because QkvToContext allocates K/V workspace assuming - // num_heads == kv_num_heads. GQA needs a head expansion step (ExpandKVHeads kernel) - // to replicate kv_num_heads -> q_num_heads before unfused can process. - // Requires ~160 lines. See issue #27516. + // qk_matmul_output_mode != kNone reaches here; the unfused MHA runner cannot handle GQA. return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "ONNX Attention with GQA (q_num_heads != kv_num_heads) is not supported by the " - "unfused runner. Flash requires fp16/bf16, SM>=80, and attn_mask==nullptr; MEA " - "requires past_key==nullptr. See PR #27851 for MEA past_key support."); + "ONNX Attention with GQA (q_num_heads != kv_num_heads) and output_qk is not " + "supported by the unfused runner."); + } + + // Fallback: unfused MHA attention (legacy runner). + // Softcap is not implemented in the legacy unfused path — it requires Flash or MEA + // (or the new GQA unfused path above, which supports softcap for fp16/bf16/fp32). + // NOTE: keep this guard even if future PRs add softcap to more fused paths — this + // legacy unfused runner does NOT apply softcap and would silently produce wrong results. + if (parameters.softcap > 0.0f) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "softcap requires flash attention or memory efficient attention, " + "but neither is eligible for this configuration. 
Check dtype (fp16/bf16 required for Flash), " + "head_size constraints, and past_key compatibility."); } return RunUnfusedAttention(context, Q, K, V, attn_mask, past_key, past_value, diff --git a/onnxruntime/core/providers/cuda/llm/attention.h b/onnxruntime/core/providers/cuda/llm/attention.h index c53c5c80d61e2..2acbf3b2ed829 100644 --- a/onnxruntime/core/providers/cuda/llm/attention.h +++ b/onnxruntime/core/providers/cuda/llm/attention.h @@ -40,6 +40,20 @@ class Attention final : public CudaKernel { Tensor* output_qk, const attention_helper::AttentionParameters& parameters) const; + // GQA-capable unfused fallback. Handles: + // - GQA (q_num_heads != kv_num_heads) without K/V head replication. + // - fp16/bf16 with large head_size (FP32 QK accumulation, fixes #28195). + // - past_key+past_value, attn_mask (bool/float), nonpad_kv_seqlen. + // Does not support: output_qk + // (output_qk modes other than kNone are rejected upstream). + Status RunGqaUnfusedAttention( + OpKernelContext* context, + const Tensor* Q, const Tensor* K, const Tensor* V, + const Tensor* attn_mask, const Tensor* past_key, const Tensor* past_value, + const Tensor* nonpad_kv_seqlen, + Tensor* Y, Tensor* present_key, Tensor* present_value, + const attention_helper::AttentionParameters& parameters) const; + Status ConvertAttnMaskToBias( OpKernelContext* context, const Tensor* attn_mask, diff --git a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc index d287ce5da1504..0cf95141b7a6c 100644 --- a/onnxruntime/test/providers/cpu/llm/attention_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/attention_op_test.cc @@ -2308,5 +2308,370 @@ TEST(AttentionTest, Attention_NonPadKVSeqLen_WithFloatAttnMask_MultiBatch) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +// GQA unfused attention with FP32 QK accumulation for large head_size (> 128). +// This exercises the RunGqaUnfusedAttention path in attention.cc which uses +// an FP32 scratch buffer for QK matmul to prevent overflow in fp16. +TEST(AttentionTest, Attention_GqaUnfused_LargeHeadSize_FP16) { + if (!HasCudaEnvironment(530)) { + return; // fp16 requires SM 5.3+ + } + + OpTester test("Attention", 24, onnxruntime::kOnnxDomain); + + // head_size=256 > 128 triggers needs_fp32_qk_scratch in the CUDA Attention kernel. + // GQA: q_num_heads=4, kv_num_heads=2. + int batch_size = 2; + int q_num_heads = 4; + int kv_num_heads = 2; + int q_sequence_length = 2; + int kv_sequence_length = 4; + int head_size = 256; + + int q_elements = batch_size * q_num_heads * q_sequence_length * head_size; + int k_elements = batch_size * kv_num_heads * kv_sequence_length * head_size; + int v_elements = k_elements; + + // Use constant Q and K so softmax produces uniform weights. + std::vector q(q_elements, 0.01f); + std::vector k(k_elements, 0.01f); + // V: each KV position s gets value (s+1)*0.1 across all head dims. 
+ std::vector v(v_elements); + for (int b = 0; b < batch_size; b++) { + for (int n = 0; n < kv_num_heads; n++) { + for (int s = 0; s < kv_sequence_length; s++) { + float val = static_cast(s + 1) * 0.1f; + for (int h = 0; h < head_size; h++) { + v[(b * kv_num_heads * kv_sequence_length + n * kv_sequence_length + s) * head_size + h] = val; + } + } + } + } + + test.AddAttribute("kv_num_heads", kv_num_heads); + test.AddAttribute("q_num_heads", q_num_heads); + + test.AddInput("Q", {batch_size, q_num_heads, q_sequence_length, head_size}, ToFloat16(q)); + test.AddInput("K", {batch_size, kv_num_heads, kv_sequence_length, head_size}, ToFloat16(k)); + test.AddInput("V", {batch_size, kv_num_heads, kv_sequence_length, head_size}, ToFloat16(v)); + test.AddOptionalInputEdge(); // attn_mask + test.AddOptionalInputEdge(); // past_key + test.AddOptionalInputEdge(); // past_value + + // Uniform softmax over all 4 KV positions → output = mean of V. + // mean = (0.1 + 0.2 + 0.3 + 0.4) / 4 = 0.25 + int y_elements = batch_size * q_num_heads * q_sequence_length * head_size; + std::vector expected_y(y_elements, 0.25f); + test.AddOutput("Y", {batch_size, q_num_heads, q_sequence_length, head_size}, + ToFloat16(expected_y), false, 0, 0.02f); + test.AddOptionalOutputEdge(); // present_key + test.AddOptionalOutputEdge(); // present_value + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// GQA unfused attention with causal mask and large head_size. +// Verifies that is_causal works correctly in the unfused GQA path. +TEST(AttentionTest, Attention_GqaUnfused_LargeHeadSize_Causal_FP16) { + if (!HasCudaEnvironment(530)) { + return; // fp16 requires SM 5.3+ + } + + OpTester test("Attention", 24, onnxruntime::kOnnxDomain); + + // Self-attention: q_sequence_length == kv_sequence_length with causal mask. + int batch_size = 1; + int q_num_heads = 4; + int kv_num_heads = 2; + int sequence_length = 3; + int head_size = 256; + + int q_elements = batch_size * q_num_heads * sequence_length * head_size; + int k_elements = batch_size * kv_num_heads * sequence_length * head_size; + int v_elements = k_elements; + + // Constant Q and K → equal attention scores before causal masking. + std::vector q(q_elements, 0.01f); + std::vector k(k_elements, 0.01f); + // V: position s gets value (s+1)*0.1 + std::vector v(v_elements); + for (int n = 0; n < kv_num_heads; n++) { + for (int s = 0; s < sequence_length; s++) { + float val = static_cast(s + 1) * 0.1f; + for (int h = 0; h < head_size; h++) { + v[(n * sequence_length + s) * head_size + h] = val; + } + } + } + + test.AddAttribute("is_causal", static_cast(1)); + test.AddAttribute("kv_num_heads", kv_num_heads); + test.AddAttribute("q_num_heads", q_num_heads); + + test.AddInput("Q", {batch_size, q_num_heads, sequence_length, head_size}, ToFloat16(q)); + test.AddInput("K", {batch_size, kv_num_heads, sequence_length, head_size}, ToFloat16(k)); + test.AddInput("V", {batch_size, kv_num_heads, sequence_length, head_size}, ToFloat16(v)); + test.AddOptionalInputEdge(); // attn_mask + test.AddOptionalInputEdge(); // past_key + test.AddOptionalInputEdge(); // past_value + + // With causal mask and constant Q/K, each position attends uniformly to itself and prior positions. 
+ // Position 0: attends to [0] → output = V[0] = 0.1 + // Position 1: attends to [0,1] → output = mean(0.1, 0.2) = 0.15 + // Position 2: attends to [0,1,2] → output = mean(0.1, 0.2, 0.3) = 0.2 + int y_elements = batch_size * q_num_heads * sequence_length * head_size; + std::vector expected_y(y_elements); + float pos_values[] = {0.1f, 0.15f, 0.2f}; + for (int n = 0; n < q_num_heads; n++) { + for (int s = 0; s < sequence_length; s++) { + for (int h = 0; h < head_size; h++) { + expected_y[(n * sequence_length + s) * head_size + h] = pos_values[s]; + } + } + } + test.AddOutput("Y", {batch_size, q_num_heads, sequence_length, head_size}, + ToFloat16(expected_y), false, 0, 0.02f); + test.AddOptionalOutputEdge(); // present_key + test.AddOptionalOutputEdge(); // present_value + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +// GQA unfused with past_key + attn_mask: exercises concat + bias path together. +TEST(AttentionTest, Attention_GqaUnfused_PastKey_AttnMask_FP16) { + if (!HasCudaEnvironment(530)) { + return; + } + + OpTester test("Attention", 24, onnxruntime::kOnnxDomain); + + int batch_size = 1; + int q_num_heads = 4; + int kv_num_heads = 2; + int q_sequence_length = 1; // decode step + int kv_sequence_length = 1; // one new token + int past_sequence_length = 2; + int total_sequence_length = past_sequence_length + kv_sequence_length; // 3 + int head_size = 256; + + test.AddAttribute("kv_num_heads", kv_num_heads); + test.AddAttribute("q_num_heads", q_num_heads); + + // Constant Q, K → uniform attention scores before masking. + std::vector q(batch_size * q_num_heads * q_sequence_length * head_size, 0.01f); + std::vector k(batch_size * kv_num_heads * kv_sequence_length * head_size, 0.01f); + // V new token: position 2 value = 0.3 + std::vector v(batch_size * kv_num_heads * kv_sequence_length * head_size, 0.3f); + + std::vector past_key(batch_size * kv_num_heads * past_sequence_length * head_size, 0.01f); + // Past V: position 0 = 0.1, position 1 = 0.2 + std::vector past_value(batch_size * kv_num_heads * past_sequence_length * head_size); + for (int n = 0; n < kv_num_heads; n++) { + for (int s = 0; s < past_sequence_length; s++) { + float val = static_cast(s + 1) * 0.1f; + for (int h = 0; h < head_size; h++) { + past_value[(n * past_sequence_length + s) * head_size + h] = val; + } + } + } + + float neg_inf = -std::numeric_limits::infinity(); + test.AddInput("Q", {batch_size, q_num_heads, q_sequence_length, head_size}, ToFloat16(q)); + test.AddInput("K", {batch_size, kv_num_heads, kv_sequence_length, head_size}, ToFloat16(k)); + test.AddInput("V", {batch_size, kv_num_heads, kv_sequence_length, head_size}, ToFloat16(v)); + test.AddInput("attn_mask", {q_sequence_length, total_sequence_length}, + ToFloat16(std::vector{0.0f, neg_inf, 0.0f})); + test.AddInput("past_key", + {batch_size, kv_num_heads, past_sequence_length, head_size}, ToFloat16(past_key)); + test.AddInput("past_value", + {batch_size, kv_num_heads, past_sequence_length, head_size}, ToFloat16(past_value)); + + // Mask position 1 → uniform over positions 0, 2 → mean(0.1, 0.3) = 0.2 + std::vector expected_y(batch_size * q_num_heads * q_sequence_length * head_size, 0.2f); + test.AddOutput("Y", {batch_size, q_num_heads, q_sequence_length, head_size}, + ToFloat16(expected_y), false, 0, 0.02f); + + // present_key: all 0.01 + std::vector expected_pk(batch_size * kv_num_heads * 
+                                     total_sequence_length * head_size, 0.01f);
+  test.AddOutput<MLFloat16>("present_key",
+                            {batch_size, kv_num_heads, total_sequence_length, head_size},
+                            ToFloat16(expected_pk), false, 0, 0.01f);
+
+  // present_value: pos 0→0.1, pos 1→0.2, pos 2→0.3
+  std::vector<float> expected_pv(batch_size * kv_num_heads * total_sequence_length * head_size);
+  for (int n = 0; n < kv_num_heads; n++) {
+    for (int s = 0; s < total_sequence_length; s++) {
+      float val = static_cast<float>(s + 1) * 0.1f;
+      for (int h = 0; h < head_size; h++) {
+        expected_pv[(n * total_sequence_length + s) * head_size + h] = val;
+      }
+    }
+  }
+  test.AddOutput<MLFloat16>("present_value",
+                            {batch_size, kv_num_heads, total_sequence_length, head_size},
+                            ToFloat16(expected_pv), false, 0, 0.01f);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultCudaExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+// GQA unfused with softcap + attn_mask: verifies the softcap + bias interaction.
+TEST(AttentionTest, Attention_GqaUnfused_Softcap_AttnMask_FP16) {
+  if (!HasCudaEnvironment(530)) {
+    return;
+  }
+
+  OpTester test("Attention", 24, onnxruntime::kOnnxDomain);
+
+  int batch_size = 1;
+  int q_num_heads = 4;
+  int kv_num_heads = 2;
+  int q_sequence_length = 1;
+  int kv_sequence_length = 3;
+  int head_size = 256;
+
+  test.AddAttribute<int64_t>("kv_num_heads", kv_num_heads);
+  test.AddAttribute<int64_t>("q_num_heads", q_num_heads);
+  test.AddAttribute<float>("softcap", 50.0f);
+
+  std::vector<float> q(batch_size * q_num_heads * q_sequence_length * head_size, 0.01f);
+  std::vector<float> k(batch_size * kv_num_heads * kv_sequence_length * head_size, 0.01f);
+  std::vector<float> v(batch_size * kv_num_heads * kv_sequence_length * head_size);
+  for (int n = 0; n < kv_num_heads; n++) {
+    for (int s = 0; s < kv_sequence_length; s++) {
+      float val = static_cast<float>(s + 1) * 0.1f;
+      for (int h = 0; h < head_size; h++) {
+        v[(n * kv_sequence_length + s) * head_size + h] = val;
+      }
+    }
+  }
+
+  float neg_inf = -std::numeric_limits<float>::infinity();
+  test.AddInput<MLFloat16>("Q", {batch_size, q_num_heads, q_sequence_length, head_size}, ToFloat16(q));
+  test.AddInput<MLFloat16>("K", {batch_size, kv_num_heads, kv_sequence_length, head_size}, ToFloat16(k));
+  test.AddInput<MLFloat16>("V", {batch_size, kv_num_heads, kv_sequence_length, head_size}, ToFloat16(v));
+  test.AddInput<MLFloat16>("attn_mask", {q_sequence_length, kv_sequence_length},
+                           ToFloat16(std::vector<float>{0.0f, neg_inf, 0.0f}));
+  test.AddOptionalInputEdge<MLFloat16>();  // past_key
+  test.AddOptionalInputEdge<MLFloat16>();  // past_value
+
+  // softcap(50) with near-zero logits is ~identity → uniform over positions 0,2.
+  // mean(0.1, 0.3) = 0.2
+  std::vector<float> expected_y(batch_size * q_num_heads * q_sequence_length * head_size, 0.2f);
+  test.AddOutput<MLFloat16>("Y", {batch_size, q_num_heads, q_sequence_length, head_size},
+                            ToFloat16(expected_y), false, 0, 0.02f);
+  test.AddOptionalOutputEdge<MLFloat16>();  // present_key
+  test.AddOptionalOutputEdge<MLFloat16>();  // present_value
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultCudaExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+// GQA unfused with BSNH (3D) input: previous tests all use 4D BNSH input.
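+// For 3D inputs the heads are packed into the hidden dimension: Q is [B, S_q, N_q * H] and
+// element (b, s, n * H + h) corresponds to element (b, n, s, h) of the 4D BNSH layout, i.e.
+//   bsnh index = (b * S + s) * N * H + n * H + h   vs.   bnsh index = ((b * N + n) * S + s) * H + h.
+// The unfused path transposes Q to BNSH internally, so the expected output below matches the
+// 4D tests above (uniform attention over constant Q/K → mean of V per position).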
+TEST(AttentionTest, Attention_GqaUnfused_BSNH_FP16) {
+  if (!HasCudaEnvironment(530)) {
+    return;
+  }
+
+  OpTester test("Attention", 24, onnxruntime::kOnnxDomain);
+
+  int batch_size = 1;
+  int q_num_heads = 4;
+  int kv_num_heads = 2;
+  int q_sequence_length = 2;
+  int kv_sequence_length = 4;
+  int head_size = 256;
+  int q_hidden = q_num_heads * head_size;    // 1024
+  int kv_hidden = kv_num_heads * head_size;  // 512
+
+  test.AddAttribute<int64_t>("kv_num_heads", kv_num_heads);
+  test.AddAttribute<int64_t>("q_num_heads", q_num_heads);
+
+  std::vector<float> q(batch_size * q_sequence_length * q_hidden, 0.01f);
+  std::vector<float> k(batch_size * kv_sequence_length * kv_hidden, 0.01f);
+  // BSNH V: position s gets value (s+1)*0.1 across all head dims.
+  std::vector<float> v(batch_size * kv_sequence_length * kv_hidden);
+  for (int s = 0; s < kv_sequence_length; s++) {
+    float val = static_cast<float>(s + 1) * 0.1f;
+    for (int d = 0; d < kv_hidden; d++) {
+      v[s * kv_hidden + d] = val;
+    }
+  }
+
+  test.AddInput<MLFloat16>("Q", {batch_size, q_sequence_length, q_hidden}, ToFloat16(q));
+  test.AddInput<MLFloat16>("K", {batch_size, kv_sequence_length, kv_hidden}, ToFloat16(k));
+  test.AddInput<MLFloat16>("V", {batch_size, kv_sequence_length, kv_hidden}, ToFloat16(v));
+  test.AddOptionalInputEdge<MLFloat16>();  // attn_mask
+  test.AddOptionalInputEdge<MLFloat16>();  // past_key
+  test.AddOptionalInputEdge<MLFloat16>();  // past_value
+
+  // Uniform over 4 positions → mean(0.1, 0.2, 0.3, 0.4) = 0.25
+  std::vector<float> expected_y(batch_size * q_sequence_length * q_hidden, 0.25f);
+  test.AddOutput<MLFloat16>("Y", {batch_size, q_sequence_length, q_hidden},
+                            ToFloat16(expected_y), false, 0, 0.02f);
+  test.AddOptionalOutputEdge<MLFloat16>();  // present_key
+  test.AddOptionalOutputEdge<MLFloat16>();  // present_value
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultCudaExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
+// GQA unfused with fp32: exercises the float template instantiation.
+TEST(AttentionTest, Attention_GqaUnfused_FP32) {
+  if (!HasCudaEnvironment(0)) {
+    return;
+  }
+
+  OpTester test("Attention", 24, onnxruntime::kOnnxDomain);
+
+  // GQA triggers the unfused path for fp32 regardless of head_size.
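+  // (Flash attention only ships fp16/bf16 kernels, so fp32 can never take the fused route.)
+  // Every stage of the unfused path then stays in fp32, which is why the tolerance below can be
+  // 1e-4 instead of the 0.02 used by the fp16 tests; the expected value is the same mean-of-V 0.25.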
+  int batch_size = 1;
+  int q_num_heads = 4;
+  int kv_num_heads = 2;
+  int q_sequence_length = 2;
+  int kv_sequence_length = 4;
+  int head_size = 8;
+
+  test.AddAttribute<int64_t>("kv_num_heads", kv_num_heads);
+  test.AddAttribute<int64_t>("q_num_heads", q_num_heads);
+
+  std::vector<float> q(batch_size * q_num_heads * q_sequence_length * head_size, 0.1f);
+  std::vector<float> k(batch_size * kv_num_heads * kv_sequence_length * head_size, 0.1f);
+  std::vector<float> v(batch_size * kv_num_heads * kv_sequence_length * head_size);
+  for (int n = 0; n < kv_num_heads; n++) {
+    for (int s = 0; s < kv_sequence_length; s++) {
+      float val = static_cast<float>(s + 1) * 0.1f;
+      for (int h = 0; h < head_size; h++) {
+        v[(n * kv_sequence_length + s) * head_size + h] = val;
+      }
+    }
+  }
+
+  test.AddInput<float>("Q", {batch_size, q_num_heads, q_sequence_length, head_size}, q);
+  test.AddInput<float>("K", {batch_size, kv_num_heads, kv_sequence_length, head_size}, k);
+  test.AddInput<float>("V", {batch_size, kv_num_heads, kv_sequence_length, head_size}, v);
+  test.AddOptionalInputEdge<float>();  // attn_mask
+  test.AddOptionalInputEdge<float>();  // past_key
+  test.AddOptionalInputEdge<float>();  // past_value
+
+  // Uniform over 4 positions → mean(0.1, 0.2, 0.3, 0.4) = 0.25
+  int y_elements = batch_size * q_num_heads * q_sequence_length * head_size;
+  std::vector<float> expected_y(y_elements, 0.25f);
+  test.AddOutput<float>("Y", {batch_size, q_num_heads, q_sequence_length, head_size},
+                        expected_y, false, 0, 1e-4f);
+  test.AddOptionalOutputEdge<float>();  // present_key
+  test.AddOptionalOutputEdge<float>();  // present_value
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultCudaExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+}
+
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/python/transformers/benchmark_gqa.py b/onnxruntime/test/python/transformers/benchmark_gqa.py
index 3a835d0852a9d..10e7ea953a503 100644
--- a/onnxruntime/test/python/transformers/benchmark_gqa.py
+++ b/onnxruntime/test/python/transformers/benchmark_gqa.py
@@ -266,6 +266,13 @@ def run_performance_test(
         (32, 96, 32, 131072, None, "Phi-3-mini-128k"),
         (32, 128, 8, 131072, None, "Phi-3-small-128k"),  # Sparsity is not used in this test
         (40, 128, 10, 131072, None, "Phi-3-medium-128K"),
+        # Gemma 4 global attention layers: num_attention_heads=8,
+        # num_key_value_heads=4, head_dim=512. Head_dim > 256 is unsupported by
+        # Flash / Memory-Efficient Attention, so this exercises the GQA unfused
+        # fallback kernel (issue #28195). Listed twice: global (dense) and local
+        # (sliding window) variants.
+        (8, 512, 4, 32768, None, "Gemma4-global-h512"),
+        (8, 512, 4, 32768, 4096, "Gemma4-local-h512"),
     ]
 
     if fast:
diff --git a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py
index 7b05d364309d9..ff9f2edd9d002 100644
--- a/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py
+++ b/onnxruntime/test/python/transformers/test_cuda_plugin_ep.py
@@ -344,6 +344,22 @@ def run_operator_test(
 def run_provider_options_test(provider_options, expect_plugin_provider=True):
     require_cuda_plugin_ep()
+
+    # When we expect the plugin provider to work, verify that at least one plugin device is available.
+    # Device enumeration can fail in some CI environments even when the plugin library loads successfully.
+ if expect_plugin_provider: + try: + devices = onnxrt.get_ep_devices() + plugin_devices = [d for d in devices if d.ep_name == CUDA_PLUGIN_EP_NAME] + if not plugin_devices: + raise unittest.SkipTest( + f"{CUDA_PLUGIN_EP_NAME} registered but no plugin devices enumerated in this environment" + ) + except unittest.SkipTest: + raise + except Exception as exc: + raise unittest.SkipTest(f"Failed to enumerate plugin EP devices: {exc}") from exc + with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp: model_path = tmp.name try: diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index 23c47e84c1630..55c8b56ae027a 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -2550,6 +2550,123 @@ def test_gqa_fp8_fallback_unsupported_head_size(self): atol=5e-2, ) + # ------------------------------------------------------------------------ + # Gemma 4 global attention layers (issue #28195): num_attention_heads=8, + # num_key_value_heads=4, head_dim=512. The unfused CUDA runner produced + # NaN at head_dim=512, scale=1.0 because raw Q*K^T overflowed fp16 even + # though cuBLAS accumulated in FP32 (output C was fp16). The new GQA + # unfused kernel writes QK to an FP32 scratch and fixes this. + # ------------------------------------------------------------------------ + def _run_gemma4_gqa( + self, + torch_type, + ort_type, + q_sequence_length, + past_kv_sequence_length, + is_prompt, + local_window_size=-1, + softcap=0.0, + ): + if not has_cuda_provider(): + self.skipTest("CUDA required") + if torch_type == torch.bfloat16 and not torch.cuda.is_bf16_supported(): + self.skipTest("BFloat16 not supported on this device") + + # Force the unfused path: disable Flash (doesn't support head_size>256) + # and Memory-Efficient Attention (cutlass FMHA caps at head_size 256 too). 
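+        # head_size=512 already exceeds both kernels' 256 cap, but disabling them explicitly keeps
+        # the routing deterministic across GPU architectures, and the addCleanup calls below restore
+        # the environment so later tests see Flash/MEA again.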
+ os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + os.environ["ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION"] = "1" + self.addCleanup(os.environ.pop, "ORT_DISABLE_FLASH_ATTENTION", None) + self.addCleanup(os.environ.pop, "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION", None) + + config = GQAConfig( + batch_size=1, + num_heads=8, + kv_num_heads=4, + head_size=512, + q_sequence_length=q_sequence_length, + kv_sequence_length=q_sequence_length, + past_kv_sequence_length=past_kv_sequence_length, + buffer_sequence_length=q_sequence_length + past_kv_sequence_length + 8, + local_window_size=local_window_size, + rotary=False, + rotary_interleaved=False, + packed=False, + share_buffer=True, + softcap=softcap, + use_smooth_softmax=False, + has_head_sink=False, + has_position_ids=False, + ) + + dtype_key = "fp16" if ort_type == TensorProto.FLOAT16 else "bf16" + check = parity_check_gqa_prompt if is_prompt else parity_check_gqa_past + check( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=rtol[dtype_key], + atol=atol[dtype_key], + ) + + def test_gqa_gemma4_global_prompt_fp16(self): + """#28195 exact repro: fp16 prompt with head_dim=512, Gemma 4 head config.""" + self._run_gemma4_gqa( + torch.float16, TensorProto.FLOAT16, q_sequence_length=16, past_kv_sequence_length=0, is_prompt=True + ) + + def test_gqa_gemma4_global_decode_fp16(self): + """#28195: fp16 decode with past KV at head_dim=512.""" + self._run_gemma4_gqa( + torch.float16, TensorProto.FLOAT16, q_sequence_length=1, past_kv_sequence_length=64, is_prompt=False + ) + + def test_gqa_gemma4_global_decode_fp16_long(self): + """Gemma 4 global attention with longer past at head_dim=512.""" + self._run_gemma4_gqa( + torch.float16, TensorProto.FLOAT16, q_sequence_length=1, past_kv_sequence_length=2048, is_prompt=False + ) + + def test_gqa_gemma4_global_prompt_bf16(self): + """Gemma 4 global attention in bf16 prompt phase at head_dim=512.""" + self._run_gemma4_gqa( + torch.bfloat16, TensorProto.BFLOAT16, q_sequence_length=16, past_kv_sequence_length=0, is_prompt=True + ) + + def test_gqa_gemma4_global_decode_bf16(self): + """Gemma 4 global attention in bf16 decode phase at head_dim=512.""" + self._run_gemma4_gqa( + torch.bfloat16, TensorProto.BFLOAT16, q_sequence_length=1, past_kv_sequence_length=64, is_prompt=False + ) + + def test_gqa_gemma4_global_prompt_fp16_softcap(self): + """Gemma 4 global attention with softcap (Gemma family uses logit softcap).""" + self._run_gemma4_gqa( + torch.float16, + TensorProto.FLOAT16, + q_sequence_length=16, + past_kv_sequence_length=0, + is_prompt=True, + softcap=50.0, + ) + + def test_gqa_gemma4_local_window_decode_fp16(self): + """ + Gemma 4 has mixed global + sliding-window (local) attention layers. This + exercises the unfused kernel's sliding-window mask at head_dim=512. 
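+        With local_window_size=128 and a past of 256 tokens, only the most recent 128 keys
+        (plus the current token) remain visible to the query; everything older must be masked
+        out by the kernel's sliding-window start bound.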
+ """ + self._run_gemma4_gqa( + torch.float16, + TensorProto.FLOAT16, + q_sequence_length=1, + past_kv_sequence_length=256, + is_prompt=False, + local_window_size=128, + ) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py index ceca17d6fc155..c4e3c1b19e85e 100644 --- a/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_onnx_attention/test_gqa.py @@ -1343,5 +1343,194 @@ def test_gqa_prompt_float_mask_4d(self): numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol["fp16"], atol=atol["fp16"]) +# ################################################################################################# +# Large Head Size Unfused GQA Tests (head_size=512, fixes #28195) +# +# Flash Attention and Memory-Efficient Attention cap at head_size=256. For head_size=512 the +# op falls through to RunGqaUnfusedAttention which writes Q*K^T to an FP32 scratch buffer, +# eliminating fp16/bf16 overflow that caused NaNs (e.g. Gemma 4 global-attention layers). +# +# These tests deliberately disable both Flash and MEA to make the unfused fallback explicit +# and to guard against future changes that might inadvertently route large-head configs +# away from the FP32-scratch path. +# ################################################################################################# + + +def gqa_large_head_unfused_test_cases(): + """Test cases for GQA with head_size=512 (unfused FP32-QK path, fixes #28195).""" + # prompt phase + for b, sq in [(1, 16), (2, 64)]: + for softcap in [0.0, 50.0]: + config = AttentionConfig( + batch_size=b, + q_sequence_length=sq, + kv_sequence_length=sq, + past_kv_sequence_length=0, + q_num_heads=8, + kv_num_heads=4, + head_size=512, + is_causal=1, + softcap=softcap, + ) + yield f"prompt_b{b}_sq{sq}_sc{softcap}", config + + # decode phase (past KV cache) + for b, past in [(1, 32), (2, 128)]: + config = AttentionConfig( + batch_size=b, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=past, + q_num_heads=8, + kv_num_heads=4, + head_size=512, + is_causal=1, + softcap=0.0, + ) + yield f"decode_b{b}_past{past}", config + + # prompt with boolean attn_mask (exercises ConvertAttnMaskToBias + unfused bias path) + config = AttentionConfig( + batch_size=2, + q_sequence_length=32, + kv_sequence_length=32, + past_kv_sequence_length=0, + q_num_heads=8, + kv_num_heads=4, + head_size=512, + is_causal=1, + has_attn_mask=True, + ) + yield "prompt_attn_mask", config + + # prompt with nonpad_kv_seqlen (opset 24, exercises seqlens_k path in unfused kernel) + config = AttentionConfig( + batch_size=2, + q_sequence_length=32, + kv_sequence_length=32, + past_kv_sequence_length=0, + q_num_heads=8, + kv_num_heads=4, + head_size=512, + is_causal=1, + has_nonpad_kv_seqlen=True, + ) + yield "prompt_nonpad_seqlen", config + + +@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping large head unfused tests.") +@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1", "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION": "1"}) +class TestONNXAttentionGQALargeHeadUnfused(unittest.TestCase): + """ + Regression tests for GQA with head_size=512 via the unfused FP32-QK path (issue #28195). + + Flash Attention and MEA both cap at head_size=256. 
With both disabled the op routes + to RunGqaUnfusedAttention, which writes Q*K^T to an FP32 scratch buffer to avoid + fp16/bf16 overflow that produced NaNs for Gemma 4 global-attention layers. + + Validates: no NaNs, numerical parity vs. PyTorch SDPA reference, for fp16 and bf16. + """ + + @parameterized.expand(gqa_large_head_unfused_test_cases()) + def test_gqa_large_head_unfused_fp16(self, name, config): + func = parity_check_gqa_past if "decode" in name else parity_check_gqa_prompt + kwargs = dict( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=rtol["fp16"], + atol=atol["fp16"], + ) + func(**kwargs) + + @parameterized.expand(gqa_large_head_unfused_test_cases()) + def test_gqa_large_head_unfused_bf16(self, name, config): + if not torch.cuda.is_bf16_supported(): + self.skipTest("BFloat16 not supported on this device") + func = parity_check_gqa_past if "decode" in name else parity_check_gqa_prompt + kwargs = dict( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol["bf16"], + atol=atol["bf16"], + ) + func(**kwargs) + + def test_gqa_large_head_unfused_softcap_additive_mask_poison_fp16(self): + config = AttentionConfig( + batch_size=1, + q_sequence_length=1, + kv_sequence_length=3, + past_kv_sequence_length=0, + q_num_heads=8, + kv_num_heads=4, + head_size=512, + is_causal=0, + softcap=1.0, + has_attn_mask=True, + attn_mask_dims=4, + attn_mask_type="additive", + ) + + device = "cuda" + torch_type = torch.float16 + q = torch.zeros( + config.batch_size, + config.q_sequence_length, + config.q_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + k = torch.zeros( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + v = torch.full_like(k, 0.2) + v[:, 1, :, :] = 1000.0 + + attn_mask = torch.zeros( + config.batch_size, + config.q_num_heads, + config.q_sequence_length, + config.kv_sequence_length, + device=device, + dtype=torch_type, + ) + attn_mask[:, :, :, 1] = float("-inf") + + out_ort, _, _ = attention_prompt_func( + q=q, + k=k, + v=v, + config=config, + attn_mask=attn_mask, + ep="CUDAExecutionProvider", + device=device, + ort_type=TensorProto.FLOAT16, + ) + + out = out_ort.reshape( + config.batch_size, + config.q_sequence_length, + config.q_num_heads, + config.head_size, + ) + expected = torch.full_like(out, 0.2) + torch.testing.assert_close(out, expected, rtol=0, atol=2e-2) + self.assertLess(out.float().max().item(), 1.0) + + if __name__ == "__main__": unittest.main()
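# ------------------------------------------------------------------------------------------------
# Illustrative sketch (standalone NumPy, not taken from the patch): the overflow that the FP32 QK
# scratch buffer guards against. At head_dim=512 with scale=1.0, Q.K^T can exceed the fp16 maximum
# (65504) even when every Q/K element is modest, so an fp16 score buffer turns into inf and the
# softmax becomes NaN, while the same scores are easily representable in fp32.
import numpy as np

head_dim = 512
q = np.full((1, head_dim), 12.0, dtype=np.float32)  # every element well inside fp16 range
k = np.full((4, head_dim), 12.0, dtype=np.float32)

scores_fp32 = q @ k.T                         # 512 * 144 = 73728: fine in fp32
scores_fp16 = scores_fp32.astype(np.float16)  # overflows to inf (fp16 max is 65504)


def softmax(x):
    x = x.astype(np.float32)
    e = np.exp(x - x.max(axis=-1, keepdims=True))  # inf - inf = nan once a score is inf
    return e / e.sum(axis=-1, keepdims=True)


assert np.isnan(softmax(scores_fp16)).all()      # fp16 score buffer -> NaN attention weights
assert not np.isnan(softmax(scores_fp32)).any()  # fp32 scratch -> finite, uniform weights
# ------------------------------------------------------------------------------------------------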