[WebGPU] QKV and MLP layer fusions for Qwen3-style models #28280
hariharans29 wants to merge 32 commits into …
Conversation
…xruntime into hari/webgpu_perf_1
Pull request overview
This PR adds WebGPU-focused fused operators and optimizer passes for decoder-style MatMulNBits patterns (MLP gate/up and QKV projections), along with tests and a microbenchmark to evaluate decode performance/correctness.
Changes:
- Introduces new contrib ops MatMulNBitsMlp and MatMulNBitsQkv (schemas + WebGPU kernels + WGSL templates).
- Adds graph transformers MatMulNBitsMlpFusion / MatMulNBitsQkvFusion and corresponding optimizer tests.
- Improves WebGPU runtime support (graph-capture buffer manager activation, queue-idle wait helper, better shader compilation diagnostics) and adds a decode microbenchmark.
Reviewed changes
Copilot reviewed 33 out of 33 changed files in this pull request and generated 4 comments.
Summary per file:
| File | Description |
|---|---|
| onnxruntime/test/optimizer/matmul_nbits_qkv_fusion_test.cc | New unit tests validating QKV fusion and output contracts on WebGPU. |
| onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc | New unit tests validating MLP fusion (simplified/skip + passthrough) on WebGPU. |
| onnxruntime/test/optimizer/graph_transform_utils_test.cc | Minor formatting-only tweak (blank line). |
| onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc | New benchmark harness for fused/unfused decode paths on WebGPU. |
| onnxruntime/test/onnx/microbenchmark/main.cc | Adjusts benchmark env logging severity. |
| onnxruntime/core/session/ort_version_check.h | Makes version parsing consteval-friendly with a macro fallback. |
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.h | Tracks when graph-capture buffer manager is active. |
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc | Lazily creates/activates graph buffer manager for capture; allocator uses dynamic buffer manager getter. |
| onnxruntime/core/providers/webgpu/webgpu_context.h | Adds WaitForQueueIdle() declaration. |
| onnxruntime/core/providers/webgpu/webgpu_context.cc | Implements WaitForQueueIdle() using OnSubmittedWorkDone. |
| onnxruntime/core/providers/webgpu/program_manager.cc | Enhances pipeline build failures with shader compilation diagnostics. |
| onnxruntime/core/providers/webgpu/compute_context.h | Adds FlushAndWait() convenience for flushing + waiting on queue idle. |
| onnxruntime/core/providers/webgpu/allocator.h | Adds allocator ctor that accepts a buffer-manager getter function. |
| onnxruntime/core/providers/webgpu/allocator.cc | Implements getter-based allocator to support switching buffer managers. |
| onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.h | New transformer declaration for QKV fusion. |
| onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc | New transformer implementation for QKV fusion. |
| onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h | New transformer declaration for MLP fusion. |
| onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc | New transformer implementation for MLP fusion. |
| onnxruntime/core/optimizer/graph_transformer_utils.cc | Registers the new fusion transformers. |
| onnxruntime/core/graph/contrib_ops/contrib_defs.cc | Adds contrib operator schemas/docs for MatMulNBitsMlp and MatMulNBitsQkv. |
| onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc | Registers WebGPU kernels for the new fused ops. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template | New WGSL template implementing fused QKV decode kernel. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.h | New WebGPU kernel wrapper for MatMulNBitsQkv. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc | New WebGPU kernel implementation for MatMulNBitsQkv. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp_wide_tile_m1.wgsl.template | New WGSL template for an MLP wide-tile variant. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template | New WGSL template implementing fused MLP (optionally with norm/skip/passthrough). |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h | New WebGPU kernel wrapper for MatMulNBitsMlp. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc | New WebGPU kernel implementation for MatMulNBitsMlp. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h | Adds declarations for “would apply” dispatch-selection helpers and shared constants. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc | Implements the new dispatch-selection helpers. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc | Refactors path selection to use the new “would apply” helpers. |
| onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_mlp.wgsl.template | Adds WGSL template for DP4A MLP path. |
| cmake/onnxruntime_unittests.cmake | Wires the new WebGPU decode benchmark into the benchmark target sources. |
…shader diagnostics
These changes are kept on hari/webgpu_perf_1_full locally. The lazy buffer-manager fix is being submitted as a separate PR (branch hari/webgpu_graph_capture_buffer_fix) because it is an independent correctness fix for a pre-existing latent bug, exposed but not introduced by these fusions.
This template file was added speculatively but is not referenced by any kernel, include, or build rule. Removing to keep the PR clean.
…_transformer_utils
Pull request overview
Copilot reviewed 20 out of 20 changed files in this pull request and generated 5 comments.
The shared-EP path through TransformerTester triggers a SEH 0xC0000005 in CI when the EP outlives a per-session profiler whose pointer is still cached on the EP. A separate fix to the WebGPU EP's session_profiler_ lifetime is in flight; meanwhile, switch the 8 MatMulNBits MLP and QKV WebGPU fusion-vs-unfused tests to a small RunWebGpuFusionTransformerTest helper that creates a fresh execution provider per session via a factory lambda. Production code is unchanged.
qjia7 left a comment
File: onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc, lines 33-127
ApplySimplifiedLayerNorm (lines 33-75) and ApplySkipSimplifiedLayerNorm (lines 77-127) duplicate the dispatch logic from LayerNormProgram in core/providers/webgpu/nn/layer_norm.cc and SkipLayerNormProgram in contrib_ops/webgpu/bert/skip_layer_norm.cc respectively.
The split condition (norm_size % 512 == 0 && norm_count == 1), workgroup sizing (workgroup_size_x = 128), uniform variable layout, and component handling are all replicated. If any of these change in the original kernels (e.g., workgroup size tuning, a new split threshold, or additional uniform variables), the copies here will silently diverge.
Consider extracting these as reusable utility functions in skip_layer_norm.h / layer_norm.h and calling them from here, rather than duplicating the setup logic.
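To make the shared decision points concrete, here is a minimal sketch of the kind of helper such an extraction could expose; the names and scoping below are illustrative, not the actual ORT code:

```cpp
#include <cstdint>

// Illustrative only: captures the two replicated decisions called out above
// (split condition and workgroup sizing), not the real layer_norm.h helpers.
namespace layer_norm_dispatch {

constexpr uint32_t kWorkgroupSizeX = 128;  // the replicated workgroup sizing

// The replicated split condition: take the "split" variant only when there is
// a single norm row whose size is a multiple of 512.
inline bool UseSplitVariant(int64_t norm_size, int64_t norm_count) {
  return norm_size % 512 == 0 && norm_count == 1;
}

}  // namespace layer_norm_dispatch
```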
qjia7 left a comment
For MLP fusion, how do you think we extend the current MatMulNBits to support matmulnbits + activation + mul? Then for prefill, just call two matmulnbits.
And in the future, maybe we can optimize builder.py to directly generate the new fused MatMulNbitsMLP op in the ONNX model.
For QKV fusion, microsoft/onnxruntime-genai#2137 resolves the QKV packing issue in builder.py. Do we still need dynamic fusion in the code after this fix?
The GenAI model builder QKV fusion does: 3 MatMulNBits -> 1 large QKV MatMulNBits + 1 Split. This ORT dynamic fusion does: 1 SLN + 3 MatMulNBits -> 1 QKV MatMulNBits (and produces the skip values). I can update my fusion logic to understand both the "unfused" Qwen QKV (1 SLN + 3 MatMulNBits) and the new fused nodes (1 SLN + 1 MatMulNBits + 1 Split). Would that be okay?
I am fine with the current status. You can add 1 SLN + 1 MatMulNBits + 1 Split support in follow-up PRs. For the MLP one, can you also extend it to support Gemma models (gate+up, GELU-gated)? Maybe we also need to fuse Cast into it (which is used to improve quality, microsoft/onnxruntime-genai#1448), since I see the subgraph in Gemma3 is like below:
                              const Node& sigmoid,
                              const Node& silu_mul,
                              const Node& final_mul) {
  if (!IsMatMulNBitsWithoutZeroPointOrGroupIdx(gate_matmul) || !IsMatMulNBitsWithoutZeroPointOrGroupIdx(up_matmul) ||
fyi, with #28410, you may also need to support QuickGelu.
Addressed
Can I add support for the Gemma3 pattern in a future PR? In this PR, I added "extensibility" support for the activation (we can support multiple gated activations in the future), but this PR will only support SiLU (for the Qwen use-case). In a future PR, we can update the dynamic fusion to support the Gemma3 pattern. If you are okay, I will open 2 issues in the repo as follow-ups for me:
Re: "For MLP fusion, how do you think we extend the current matmulnits to support matmulnbits + activation + mul? Then for prefill, just call two matmulnbits. And in future, maybe we can optimize builder.py to directly generate the new fused MatMulNbitsMLP op in onnx model." Is this comment still valid ? Is this an approach you'd like me to explore ? |
Refactor in response to PR review feedback. The MatMulNBits MLP and QKV
fusion kernels previously each carried their own private copies of the
SimplifiedLayerNormalization and SkipSimplifiedLayerNormalization
program launchers (`GetOverrideShape` + `ApplySimplifiedLayerNorm` +
`ApplySkipSimplifiedLayerNorm`). Extract these into reusable helpers
exposed by the existing LayerNorm / SkipLayerNorm kernel sources so
fused kernels can drop the duplication.
* core/providers/webgpu/nn/layer_norm.{h,cc}:
- Expose `RunLayerNormProgram(...)` so other kernels can launch the
simplified layer-norm program with consistent uniforms / shape
overrides.
* contrib_ops/webgpu/bert/skip_layer_norm.{h,cc}:
- Expose `RunSkipLayerNormProgram(...)` mirroring the same shape for
the SkipSimplifiedLayerNormalization variant.
* contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc:
- Adopt the shared helpers and delete the local copies. No behavior
change; emitted WGSL and dispatch are byte-identical.
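A hedged sketch of roughly what the exposed launchers could look like; the actual parameter lists in layer_norm.h / skip_layer_norm.h are not shown in this PR excerpt, so the signatures below are assumptions (only the function names come from the commit):

```cpp
// Assumed signatures only -- the real declarations live in
// core/providers/webgpu/nn/layer_norm.h and contrib_ops/webgpu/bert/skip_layer_norm.h.
namespace onnxruntime {
class Tensor;                         // forward declarations standing in for the real ORT types
namespace common { class Status; }
namespace webgpu {
class ComputeContext;

// Launches the SimplifiedLayerNormalization program with the shared uniform
// layout and shape overrides (previously copied into the fused MLP/QKV kernels).
common::Status RunLayerNormProgram(ComputeContext& context,
                                   const Tensor& input,
                                   const Tensor& scale,
                                   Tensor& output,
                                   float epsilon);

// Same idea for SkipSimplifiedLayerNormalization, which additionally consumes
// the skip (and optional bias) inputs.
common::Status RunSkipLayerNormProgram(ComputeContext& context,
                                       const Tensor& input,
                                       const Tensor& skip,
                                       const Tensor& scale,
                                       Tensor& output,
                                       float epsilon);

}  // namespace webgpu
}  // namespace onnxruntime
```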
Two coupled cleanups to the MatMulNBitsMlp kernel, kept together because
they touch the same file:
1. Adopt the shared `RunLayerNormProgram` / `RunSkipLayerNormProgram`
helpers introduced in the prior commit. Deletes the local copies of
`GetOverrideShape`, `ApplySimplifiedLayerNorm`, and
`ApplySkipSimplifiedLayerNorm`. No behavior change.
2. Introduce a small `MlpActivationKind` enum so the kernel can later
gain GELU / GELU+Cast support (e.g. for Gemma-style MLPs) without
reshaping the call paths or schema. Today the enum has a single
value, `Silu = 0`, and the emitted WGSL is byte-identical to before.
* matmul_nbits_mlp.h:
- Add `MlpActivationKind` enum and `ParseMlpActivation()`. Kernel
stores the parsed kind in `activation_kind_`.
* matmul_nbits_mlp.cc:
- Thread `MlpActivationKind` through `MatMulNBitsMlpProgram`,
`MatMulNBitsMlpDecodeProgram`, and `ApplyUnfusedMlp`. Include
the kind in each program's CacheHint.
- Add `EmitGateActivationExpr()` so the inline kernel emits the
activation expression via a single helper; today returns the
SiLU expression.
- While here, collapse the four identical
`WGSL_TEMPLATE_APPLY` branches in
`MatMulNBitsMlpDecodeProgram::GenerateShaderCode` into one call.
Roughly 120 lines removed; emitted WGSL unchanged.
* matmul_nbits_mlp.wgsl.template:
- Add `#param activation_kind`. Wrap SiLU emission in
`#if activation_kind == 0 ... #endif` and produce the activated
value through a single `activated_value` binding so additional
activations can be added with a new `#elif` branch.
The schema already declares `activation` as a generic `STRING`, so no
schema change is required.
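A hedged sketch of the extensibility point this commit describes. The attribute spelling accepted by `ParseMlpActivation` and the exact WGSL string emitted for SiLU are assumptions; only the enum, function, and member names come from the commit message:

```cpp
#include <stdexcept>
#include <string>

// Sketch of the activation extensibility point; details may differ from
// matmul_nbits_mlp.{h,cc}.
enum class MlpActivationKind : int {
  Silu = 0,  // the only value supported in this PR (Qwen-style SwiGLU gate)
  // GELU / GELU+Cast variants (e.g. for Gemma-style MLPs) can be added later.
};

// Maps the schema's generic STRING `activation` attribute to the enum.
// The accepted spellings here are assumptions.
inline MlpActivationKind ParseMlpActivation(const std::string& activation) {
  if (activation.empty() || activation == "silu") return MlpActivationKind::Silu;
  throw std::invalid_argument("unsupported MLP activation: " + activation);
}

// Returns the WGSL expression applied to the gate projection; adding a new
// activation means adding a case here and an #elif branch in the template.
inline std::string EmitGateActivationExpr(MlpActivationKind kind, const std::string& gate) {
  switch (kind) {
    case MlpActivationKind::Silu:
      return gate + " * (1.0 / (1.0 + exp(-" + gate + ")))";  // SiLU: x * sigmoid(x)
  }
  return gate;  // unreachable while the enum has a single value
}
```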
After QuickGeluFusion is enabled for the WebGPU EP (upstream PR #28410), the SwiGLU gate subgraph `gate * Sigmoid(gate)` is collapsed into a single `com.microsoft::QuickGelu(gate, alpha=1.0)` node before MatMulNBitsMlpFusion runs. Without this change, the MLP fusion would silently stop firing for Qwen3 / Llama / Phi style WebGPU models.
* core/optimizer/matmul_nbits_mlp_fusion.cc:
  - Recognize the QuickGelu-decomposed shape gate_matmul -> com.microsoft::QuickGelu(alpha=1.0) -> final_mul in addition to the existing Sigmoid+Mul shape. Validates QuickGelu's `alpha == 1.0` (SiLU-equivalent).
  - Factor common pair validation into `ValidateMatMulNBitsPair` and keep shape-specific checks in `IsFuseCandidateSilu` and `IsFuseCandidateQuickGelu`.
  - Restructure the main matching loop to dispatch on which shape was found and track the intermediate nodes to remove in a small vector, so the node-removal block stays uniform across shapes.
* core/providers/webgpu/math/unary_elementwise_ops.h:
  - Fix the `QuickGeluImpl` WGSL shader for fp16 by wrapping `1.0`, `0.0`, and `uniforms.attr` in `x_element_t(...)` casts. Without this, pipeline creation fails on fp16 models with `Invalid ShaderModule "QuickGelu"`. Matches the fix in PR #28410 so the in-tree build can run QuickGelu on fp16 models immediately rather than waiting on that PR to land.
* test/optimizer/matmul_nbits_mlp_fusion_test.cc:
  - Add unit coverage mirroring the existing SiLU tests. Introduces an `ActivationShape` enum and parameterizes the existing test-pattern builder. The graph-shape checkers now also assert zero `com.microsoft.QuickGelu` nodes after fusion. Adds four tests:
    * Fusion only (Simplified-LN anchor)
    * Fusion only (Skip-Simplified-LN anchor)
    * Fused vs unfused correctness on WebGPU (Simplified-LN)
    * Fused vs unfused correctness on WebGPU (Skip-Simplified-LN)
  - Correctness tests use a slightly looser 5e-3 tolerance because the by-cases sigmoid in the QuickGelu shader produces marginally different fp16 rounding than the fused kernel's direct SiLU evaluation; the two are mathematically equivalent.
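For context on the `alpha == 1.0` check: QuickGelu computes `x * sigmoid(alpha * x)`, so alpha = 1 makes it exactly SiLU and the two matched shapes describe the same gate function:

```latex
\mathrm{QuickGelu}(x;\alpha) = x \cdot \sigma(\alpha x), \qquad
\mathrm{QuickGelu}(x;\,\alpha{=}1) = x \cdot \sigma(x) = \mathrm{SiLU}(x)
```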
Widen `QuickGeluFusion`'s compatible-EP set from `cpu_acl_cuda_dml_eps` to `cpu_acl_cuda_dml_js_webgpu_eps` so the `x * Sigmoid(x)` SwiGLU gate pattern is folded into a single `com.microsoft::QuickGelu` node on WebGPU and JSEP models. Without this, the QuickGelu match branch added to `MatMulNBitsMlpFusion` in the prior commit is unreachable on real WebGPU models, and the `QuickGelu` fp16 shader fix in `unary_elementwise_ops.h` cannot be exercised end-to-end. Mirrors upstream PR #28410 (registers `QuickGeluFusion` for WebGPU/JSEP and fixes the `QuickGelu` fp16 shader). This commit is expected to be redundant once #28410 lands; rebase will drop it cleanly.


Description
Summary
Adds two WebGPU-only graph fusions and the contrib ops they target, plus a small refactor of the existing MatMulNBits dispatch logic so the new fused kernels can share its predicates.
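For reference, a minimal sketch of the math the two fused ops cover, per the change list below; the weight/bias symbols are generic names (not identifiers from the PR) and the biases are optional:

```latex
% x_hat is the (Skip)SimplifiedLayerNorm output shared by both fused ops.
\hat{x} = \mathrm{(Skip)SimplifiedLayerNorm}(x)
% MatMulNBitsMlp: SwiGLU-style gate/up projections in a single dispatch.
y_{\mathrm{mlp}} = \mathrm{SiLU}\!\left(\hat{x} W_{\mathrm{gate}} + b_{\mathrm{gate}}\right) \odot \left(\hat{x} W_{\mathrm{up}} + b_{\mathrm{up}}\right)
% MatMulNBitsQkv: Q/K/V projections sharing the same normalized input.
\left[\,q,\; k,\; v\,\right] = \left[\,\hat{x} W_q,\; \hat{x} W_k,\; \hat{x} W_v\,\right]
```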
- MatMulNBitsMlp op + kernel: contrib_ops/webgpu/quantization/matmul_nbits_mlp.{cc,h}, *.wgsl.template (3). Fuses (Skip)SimplifiedLayerNormalization + two MatMulNBits projections (gate, up) + optional biases + Sigmoid/Mul (SiLU) + element-wise Mul. Single dispatch instead of 5–7.
- MatMulNBitsQkv op + kernel: contrib_ops/webgpu/quantization/matmul_nbits_qkv.{cc,h}, *.wgsl.template. Fuses (Skip)SimplifiedLayerNormalization + three MatMulNBits projections (Q, K, V) sharing the same input. Single dispatch instead of 4.
- core/graph/contrib_ops/contrib_defs.cc: MatMulNBitsMlp and MatMulNBitsQkv contrib op schemas (kMSDomain, opset 1).
- core/optimizer/matmul_nbits_{mlp,qkv}_fusion.{cc,h}: new fusion transformers, registered in graph_transformer_utils.cc.
- contrib_ops/webgpu/quantization/matmul_nbits_common.{cc,h} + matmul_nbits.cc: dispatch-selection predicates shared with the existing MatMulNBits path.
- test/optimizer/matmul_nbits_{mlp,qkv}_fusion_test.cc, graph_transform_utils_test.cc: new/updated optimizer tests.
Motivation and Context
~25–30% decode TPS (throughput) improvement on the WebGPU + D3D backend on Windows. GPU used: RTX 5060 Ti, model: Qwen3-1.7B.
BEFORE (95 decode TPS): main branch

AFTER (120+ decode TPS): PR branch
