
[Feature Request] CUDA EP: Support head_dim > 256 in GroupQueryAttention kernel #28196

@justinchuby

Description

Feature Request

The CUDA EP GroupQueryAttention kernel enforces MAX_HEAD_SIZE = 256, rejecting any model with head_dim > 256. This forces such layers to fall back to the standard Attention op, whose unfused runner produces NaN for fp16 (see #28195).
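For context, here is a minimal sketch of how one might scan an exported model for GroupQueryAttention nodes that would hit this limit. The model path is a placeholder, and it assumes an unpacked query input (input 0) and static shapes after ONNX shape inference:

```python
import onnx
from onnx import shape_inference

MAX_HEAD_SIZE = 256  # limit currently enforced by the CUDA GQA kernel

# "model.onnx" is a placeholder path; shape inference fills in value_info shapes.
model = shape_inference.infer_shapes(onnx.load("model.onnx"))

# Map tensor names to static shapes (dim_value is 0 for dynamic/unknown dims).
shapes = {}
for vi in list(model.graph.value_info) + list(model.graph.input):
    shapes[vi.name] = [d.dim_value for d in vi.type.tensor_type.shape.dim]

for node in model.graph.node:
    if node.op_type != "GroupQueryAttention":
        continue
    num_heads = next(a.i for a in node.attribute if a.name == "num_heads")
    q_shape = shapes.get(node.input[0])  # input 0 is query; assumes unpacked QKV
    if not q_shape or q_shape[-1] == 0 or num_heads == 0:
        continue  # dynamic or missing shape information, skip
    head_dim = q_shape[-1] // num_heads
    if head_dim > MAX_HEAD_SIZE:
        print(f"{node.name or node.op_type}: head_dim={head_dim} > {MAX_HEAD_SIZE}")
```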

Motivation

Models like Gemma 4 use a hybrid attention architecture with two head dimensions:

  • Sliding-attention layers: head_dim=256 → GQA works perfectly ✓
  • Full-attention layers: head_dim=512 → GQA rejects, must use Attention → NaN on CUDA

This leaves the full-attention layers with no working path on the CUDA EP, because:

  1. GQA refuses head_dim=512
  2. The unfused Attention runner produces NaN in fp16 (#28195)
  3. Flash Attention also has a head_dim <= 256 limit

Requested behavior

Extend the GroupQueryAttention CUDA kernel to support head_dim > 256 (at least up to 512). This would allow models like Gemma 4 to use GQA for all layers, avoiding the broken unfused Attention runner entirely.

Workaround

Currently none for the CUDA EP with these layers in fp16. The CPU EP works correctly.
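The only interim mitigation is therefore to keep the affected model off the CUDA EP entirely. A sketch, with a placeholder model path:

```python
import onnxruntime as ort

# Run the whole model on the CPU EP, which handles head_dim > 256 correctly,
# at the cost of GPU acceleration. "model.onnx" is a placeholder path.
sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
```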

Related

  • CUDA EP: Unfused Attention runner produces NaN for fp16 with head_dim > 256 #28195
