diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index fb6a4eb22a872..9d549ac5e1219 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -57,6 +57,8 @@ Do not modify directly.*
* com.microsoft.MatMulInteger16
* com.microsoft.MatMulIntegerToFloat
* com.microsoft.MatMulNBits
+ * com.microsoft.MatMulNBitsMlp
+ * com.microsoft.MatMulNBitsQkv
* com.microsoft.MaxpoolWithMask
* com.microsoft.MoE
* com.microsoft.MulInteger
@@ -3189,6 +3191,190 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.MatMulNBitsMlp**
+
+ MatMulNBitsMlp fuses two MatMulNBits projections that share the same input and computes
+
+ gate = MatMulNBits(A, gate_weight) + gate_bias
+ up = MatMulNBits(A, up_weight) + up_bias
+ Y = activation(gate) * up
+
+ It can also optionally fuse SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization before the
+ two projections:
+
+ A_norm = SimplifiedLayerNormalization(A, norm_scale, epsilon)
+ gate = MatMulNBits(A_norm, gate_weight) + gate_bias
+ up = MatMulNBits(A_norm, up_weight) + up_bias
+ Y = activation(gate) * up
+
+ A_norm = SkipSimplifiedLayerNormalization(A, skip, norm_scale, epsilon)
+ gate = MatMulNBits(A_norm, gate_weight) + gate_bias
+ up = MatMulNBits(A_norm, up_weight) + up_bias
+ Y = activation(gate) * up
+
+ This operator is intended for decoder MLP patterns such as Qwen-style gate and up projections, but it remains
+ semantically valid for both prefill and decode because the output shape is the standard MatMul result shape
+ derived from the runtime shape of A and the shared attributes K and N.
+
+ The operator contract includes a string attribute describing the fused gate activation.
+
+ When fused from SkipSimplifiedLayerNormalization, the optional residual-sum output may also be materialized:
+
+ A_norm, input_skip_bias_sum = SkipSimplifiedLayerNormalization(A, skip, norm_scale, epsilon)
+ gate = MatMulNBits(A_norm, gate_weight) + gate_bias
+ up = MatMulNBits(A_norm, up_weight) + up_bias
+ Y = activation(gate) * up
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- K : int (required)
+- Input feature dimension shared by both quantized weight matrices.
+- N : int (required)
+- Output feature dimension shared by both quantized weight matrices.
+- accuracy_level : int
+- The minimum accuracy level of input A. It follows the same semantics as MatMulNBits.
+- activation : string (required)
+- Activation applied to the gate projection.
+- bits : int
+- Bit-width used to quantize both weight matrices (valid range: 2~8)
+- block_size : int (required)
+- Size of each quantization block along the K dimension. Must be a power of two and >= 16.
+- epsilon : float
+- Epsilon used by the optional fused (Skip)SimplifiedLayerNormalization. Defaults to 1e-5.
+
+
+#### Inputs (8 - 9)
+
+
+- A : T1
+- The shared input tensor.
+- skip (optional) : T1
+- Optional skip input used by SkipSimplifiedLayerNormalization.
+- norm_scale (optional) : T1
+- Optional RMSNorm scale with shape [K] used by SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization.
+- gate_B : T2
+- Packed uint8 tensor for the gate projection weights.
+- gate_scales : T1
+- Per-block scaling factors for the gate projection.
+- gate_bias (optional) : T1
+- Optional bias for the gate projection with shape [N].
+- up_B : T2
+- Packed uint8 tensor for the up projection weights.
+- up_scales : T1
+- Per-block scaling factors for the up projection.
+- up_bias (optional) : T1
+- Optional bias for the up projection with shape [N].
+
+
+#### Outputs (1 - 2)
+
+
+- Y : T1
+- The fused gated MLP output tensor.
+- input_skip_bias_sum (optional) : T1
+- Optional residual-sum output for SkipSimplifiedLayerNormalization.
+
+
+#### Type Constraints
+
+
+- T1 : tensor(float), tensor(float16), tensor(bfloat16)
+- Constrain input and output types to float tensors.
+- T2 : tensor(uint8)
+- Constrain quantized weight types to uint8.
+
+
+
+### **com.microsoft.MatMulNBitsQkv**
+
+ MatMulNBitsQkv fuses either SimplifiedLayerNormalization (RMSNorm)
+ or SkipSimplifiedLayerNormalization with three MatMulNBits projections that share the
+ same normalized activation.
+
+ A_norm = SimplifiedLayerNormalization(A, norm_scale, epsilon)
+ Q = MatMulNBits(A_norm, q_weight)
+ K = MatMulNBits(A_norm, k_weight)
+ V = MatMulNBits(A_norm, v_weight)
+
+ If skip is provided, the operator computes the SkipSimplifiedLayerNormalization variant
+ and may also return the input+skip residual sum as output 3.
+
+ This operator is intended as a decode-oriented QKV fusion primitive.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- K : int (required)
+- Input feature dimension shared by the normalized input and all projection weights.
+- Nkv : int (required)
+- Output feature dimension shared by the K and V projections.
+- Nq : int (required)
+- Output feature dimension of the Q projection.
+- accuracy_level : int
+- The minimum accuracy level of input A. It follows the same semantics as MatMulNBits.
+- bits : int
+- Bit-width used to quantize all weight matrices (valid range: 2~8)
+- block_size : int (required)
+- Size of each quantization block along the K dimension. Must be a power of two and >= 16.
+- epsilon : float
+- Epsilon used by the simplified layer norm reduction.
+
+
+#### Inputs
+
+
+- A : T1
+- The shared input tensor.
+- skip (optional) : T1
+- Optional residual input for SkipSimplifiedLayerNormalization.
+- norm_scale : T1
+- Scale input for the simplified layer norm with shape [K].
+- q_B : T2
+- Packed uint8 tensor for the Q projection weights.
+- q_scales : T1
+- Per-block scaling factors for the Q projection.
+- k_B : T2
+- Packed uint8 tensor for the K projection weights.
+- k_scales : T1
+- Per-block scaling factors for the K projection.
+- v_B : T2
+- Packed uint8 tensor for the V projection weights.
+- v_scales : T1
+- Per-block scaling factors for the V projection.
+
+
+#### Outputs (3 - 4)
+
+
+- Q : T1
+- The Q projection output tensor.
+- K : T1
+- The K projection output tensor.
+- V : T1
+- The V projection output tensor.
+- input_skip_bias_sum (optional) : T1
+- Optional residual-sum output for SkipSimplifiedLayerNormalization.
+
+
+#### Type Constraints
+
+
+- T1 : tensor(float), tensor(float16), tensor(bfloat16)
+- Constrain input and output types to float tensors.
+- T2 : tensor(uint8)
+- Constrain quantized weight types to uint8.
+
+
+
### **com.microsoft.MaxpoolWithMask**
For internal use.
diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc
index 619aff1e806b8..473121edbf5dc 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc
@@ -154,8 +154,26 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo
auto* output = context.Output(0, x_shape);
auto* input_skip_bias_sum = context.Output(3, x_shape);
- int64_t data_size = x_shape.Size();
- if (data_size == 0) {
+ if (x_shape.Size() == 0) {
+ return Status::OK();
+ }
+
+ return RunSkipLayerNormProgram(context, x, skip, gamma, beta, bias, epsilon_, simplified,
+ output, input_skip_bias_sum);
+}
+
+Status RunSkipLayerNormProgram(ComputeContext& context,
+ const Tensor* x,
+ const Tensor* skip,
+ const Tensor* gamma,
+ const Tensor* beta,
+ const Tensor* bias,
+ float epsilon,
+ bool simplified,
+ Tensor* output,
+ Tensor* input_skip_bias_sum) {
+ const auto& x_shape = x->Shape();
+ if (x_shape.Size() == 0) {
return Status::OK();
}
@@ -165,18 +183,17 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo
const uint32_t norm_count = onnxruntime::narrow(x_shape.SizeToDimension(x_shape.NumDimensions() - 1));
const bool split_hidden_dim = hidden_size % 512 == 0 && norm_count == 1;
- const auto skip_shape = skip->Shape();
- const uint32_t skip_size = onnxruntime::narrow(skip_shape.Size());
+ const uint32_t skip_size = onnxruntime::narrow(skip->Shape().Size());
SkipLayerNormProgram program{
- beta != nullptr, bias != nullptr, epsilon_, hidden_size, has_input_skip_bias_sum, simplified, split_hidden_dim};
+ beta != nullptr, bias != nullptr, epsilon, hidden_size, has_input_skip_bias_sum, simplified, split_hidden_dim};
program
.CacheHint(simplified, beta != nullptr, bias != nullptr, has_input_skip_bias_sum, split_hidden_dim)
.AddInputs({{x, ProgramTensorMetadataDependency::Type, components}})
.AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}})
.AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}})
.AddOutputs({{output, ProgramTensorMetadataDependency::None, components}})
- .SetDispatchGroupSize(onnxruntime::narrow(ceil(1.0 * data_size / hidden_size)))
+ .SetDispatchGroupSize(onnxruntime::narrow(ceil(1.0 * x_shape.Size() / hidden_size)))
.AddUniformVariables({
{static_cast(components)},
})
@@ -184,7 +201,7 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo
{static_cast(hidden_size)},
})
.AddUniformVariables({
- {static_cast(epsilon_)},
+ {static_cast(epsilon)},
})
.AddUniformVariables({
{static_cast(skip_size)},
diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h
index bfaec1c3d0d79..0430074bb2ae0 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h
@@ -60,6 +60,21 @@ class SkipLayerNorm final : public WebGpuKernel {
float epsilon_;
};
+// Configures and dispatches a SkipLayerNormProgram. Centralizes program-setup logic
+// (uniform variables, components, split_hidden_dim heuristic, workgroup sizing) so callers
+// other than the SkipLayerNorm kernel (e.g. fused MatMulNBits ops) do not need to duplicate it.
+// `beta`, `bias` and `input_skip_bias_sum` may be nullptr.
+Status RunSkipLayerNormProgram(ComputeContext& context,
+ const Tensor* x,
+ const Tensor* skip,
+ const Tensor* gamma,
+ const Tensor* beta,
+ const Tensor* bias,
+ float epsilon,
+ bool simplified,
+ Tensor* output,
+ Tensor* input_skip_bias_sum);
+
} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
index b4e3991344089..e0a78aab1220b 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -18,10 +18,6 @@ namespace onnxruntime {
namespace contrib {
namespace webgpu {
-namespace {
-constexpr unsigned int kMinMForTileOptimization = 4;
-} // namespace
-
ONNX_OPERATOR_KERNEL_EX(
MatMulNBits,
kMSDomain,
@@ -226,29 +222,44 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
uint32_t zero_blocks_per_col = (n_blocks_per_col + zp_elements_per_byte - 1) / zp_elements_per_byte * zp_elements_per_byte;
#if !defined(__wasm__)
+ // apple|intel - Experimental dawn support for subgroup matrix matmul.
int32_t subgroup_matrix_config_index = -1;
- // Experimental dawn support for subgroup matrix matmul (vendor-agnostic).
- if ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) &&
- CanApplySubgroupMatrixMatMulNBits(context, accuracy_level, block_size, batch_count, N, K, static_cast(nbits), y->DataType() == DataTypeImpl::GetType(), subgroup_matrix_config_index)) {
+ if (WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(M,
+ N,
+ K,
+ batch_count,
+ block_size,
+ accuracy_level,
+ nbits,
+ context,
+ y,
+ has_weight_idx_indirect,
+ &subgroup_matrix_config_index,
+ override_M)) {
return ApplySubgroupMatrixMatMulNBits(a, b, scales, zero_points, bias, M, N, K, static_cast(nbits), zero_blocks_per_col, subgroup_matrix_config_index, context, y, weight_index, weight_index_indirect);
}
#endif
// On FP32 only GPUs and Qualcomm GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M.
// DP4A Q2 path now supports custom zero points via a 1024-entry LUT (4 zero-point sections × 256 byte values).
- if (((M >= kMinMForTileOptimization && !has_weight_idx_indirect) || y->DataType() == DataTypeImpl::GetType() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) &&
- CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)) {
+ if (WouldApplyDP4AMatMulNBitsInCurrentDispatch(M,
+ N,
+ K,
+ block_size,
+ accuracy_level,
+ context,
+ y,
+ has_weight_idx_indirect)) {
return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, batch_count, M, dispatch_M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast(nbits), context, y, weight_index, weight_index_indirect);
}
// WideTileProgram
// This program is optimized for Block32 prefill using Tile16x128.
- const bool use_wide_tile_program = !has_weight_idx_indirect &&
- block_size == 32 &&
- components_a == 4 &&
- components_b == 4 &&
- nbits != 2 &&
- M >= kMinMForTileOptimization;
+ const bool use_wide_tile_program = WouldApplyWideTileMatMulNBitsInCurrentDispatch(M,
+ K,
+ block_size,
+ nbits,
+ has_weight_idx_indirect);
if (use_wide_tile_program) {
// Enforce output components to 1.
@@ -308,7 +319,8 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
// Use tile_size_k_vec=32 by default for better K-dimension parallelism.
// Intel devices use 16 as they have different subgroup/cache characteristics.
- const uint32_t tile_size_k_vec = (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u;
+ const uint32_t tile_size_k_vec =
+ (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u;
constexpr uint32_t workgroup_size = 128;
constexpr uint32_t tile_size = 8;
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc
index b9eafeb43c7b6..488194a60a31a 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc
@@ -2,9 +2,16 @@
// Licensed under the MIT License.
#include "contrib_ops/webgpu/quantization/matmul_nbits_common.h"
+
#include
+
#include "core/common/common.h"
+#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h"
+#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h"
+#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/webgpu/webgpu_context.h"
+#include "core/providers/webgpu/webgpu_utils.h"
+#include "core/framework/tensor_shape.h"
namespace onnxruntime {
namespace contrib {
@@ -61,6 +68,160 @@ bool HasDP4ADeviceSupport(int context_id) {
ctx.AdapterInfo().vendor != std::string_view{"apple"};
}
+bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(uint32_t M,
+ [[maybe_unused]] uint32_t N,
+ [[maybe_unused]] uint32_t K,
+ [[maybe_unused]] uint32_t batch_count,
+ [[maybe_unused]] uint32_t block_size,
+ [[maybe_unused]] int64_t accuracy_level,
+ [[maybe_unused]] int64_t nbits,
+ [[maybe_unused]] onnxruntime::webgpu::ComputeContext& context,
+ [[maybe_unused]] Tensor* y,
+ [[maybe_unused]] bool has_weight_idx_indirect,
+ [[maybe_unused]] int32_t* subgroup_matrix_config_index,
+ [[maybe_unused]] uint32_t override_M) {
+ [[maybe_unused]] const uint32_t dispatch_M = override_M > 0 ? override_M : M;
+
+#if !defined(__wasm__)
+ int32_t local_subgroup_matrix_config_index = -1;
+ if (dispatch_M != M) {
+ return false;
+ }
+
+ return (M >= kMinMForTileOptimization && !has_weight_idx_indirect) &&
+ CanApplySubgroupMatrixMatMulNBits(context,
+ accuracy_level,
+ block_size,
+ batch_count,
+ N,
+ K,
+ static_cast(nbits),
+ y->DataType() == DataTypeImpl::GetType(),
+ subgroup_matrix_config_index != nullptr ? *subgroup_matrix_config_index : local_subgroup_matrix_config_index);
+#else
+ return false;
+#endif
+}
+
+bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(const Tensor* a,
+ int64_t K_op,
+ int64_t N_op,
+ int64_t block_size_op,
+ int64_t accuracy_level,
+ int64_t nbits,
+ onnxruntime::webgpu::ComputeContext& context,
+ Tensor* y,
+ bool has_weight_idx_indirect,
+ int32_t* subgroup_matrix_config_index,
+ uint32_t override_M) {
+ TensorShape b_shape({N_op, K_op});
+ MatMulComputeHelper helper;
+ if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) {
+ return false;
+ }
+
+ return WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(
+ onnxruntime::narrow(helper.M()),
+ onnxruntime::narrow(helper.N()),
+ onnxruntime::narrow(helper.K()),
+ onnxruntime::narrow(helper.OutputOffsets().size()),
+ onnxruntime::narrow(block_size_op),
+ accuracy_level,
+ nbits,
+ context,
+ y,
+ has_weight_idx_indirect,
+ subgroup_matrix_config_index,
+ override_M);
+}
+
+bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(uint32_t M,
+ uint32_t N,
+ uint32_t K,
+ uint32_t block_size,
+ int64_t accuracy_level,
+ onnxruntime::webgpu::ComputeContext& context,
+ Tensor* y,
+ bool has_weight_idx_indirect) {
+ const uint32_t components_a = GetMaxComponents(K);
+
+ return ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) ||
+ y->DataType() == DataTypeImpl::GetType() ||
+ context.AdapterInfo().vendor == std::string_view{"qualcomm"}) &&
+ CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a);
+}
+
+bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(const Tensor* a,
+ int64_t K_op,
+ int64_t N_op,
+ int64_t block_size_op,
+ int64_t accuracy_level,
+ onnxruntime::webgpu::ComputeContext& context,
+ Tensor* y,
+ bool has_weight_idx_indirect) {
+ TensorShape b_shape({N_op, K_op});
+ MatMulComputeHelper helper;
+ if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) {
+ return false;
+ }
+
+ return WouldApplyDP4AMatMulNBitsInCurrentDispatch(
+ onnxruntime::narrow(helper.M()),
+ onnxruntime::narrow(helper.N()),
+ onnxruntime::narrow(helper.K()),
+ onnxruntime::narrow(block_size_op),
+ accuracy_level,
+ context,
+ y,
+ has_weight_idx_indirect);
+}
+
+bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(uint32_t M,
+ uint32_t K,
+ uint32_t block_size,
+ int64_t nbits,
+ bool has_weight_idx_indirect) {
+ if (has_weight_idx_indirect) {
+ return false;
+ }
+
+ const uint32_t components_a = GetMaxComponents(K);
+ const uint32_t block_size_per_col = block_size;
+ const uint32_t blob_size = (block_size_per_col / 8) * static_cast(nbits);
+ const uint32_t blob_size_in_words = blob_size / 4;
+ const uint32_t components_b = GetMaxComponents(blob_size_in_words);
+
+ return block_size == 32 &&
+ components_a == 4 &&
+ components_b == 4 &&
+ nbits != 2 &&
+ M >= kMinMForTileOptimization;
+}
+
+bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(const Tensor* a,
+ int64_t K_op,
+ int64_t N_op,
+ int64_t block_size_op,
+ int64_t nbits,
+ bool has_weight_idx_indirect) {
+ if (has_weight_idx_indirect) {
+ return false;
+ }
+
+ TensorShape b_shape({N_op, K_op});
+ MatMulComputeHelper helper;
+ if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) {
+ return false;
+ }
+
+ return WouldApplyWideTileMatMulNBitsInCurrentDispatch(
+ onnxruntime::narrow(helper.M()),
+ onnxruntime::narrow(helper.K()),
+ onnxruntime::narrow(block_size_op),
+ nbits,
+ has_weight_idx_indirect);
+}
+
} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h
index 3db7c722b11eb..f3277e536ae62 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h
@@ -6,10 +6,20 @@
#include
#include
+namespace onnxruntime {
+class Tensor;
+
+namespace webgpu {
+class ComputeContext;
+} // namespace webgpu
+} // namespace onnxruntime
+
namespace onnxruntime {
namespace contrib {
namespace webgpu {
+inline constexpr uint32_t kMinMForTileOptimization = 4u;
+
/**
* Generates WebGPU shader code for reading zero points in quantized matrix multiplication
*
@@ -26,6 +36,65 @@ std::string GenerateZeroPointReadingCode(uint32_t nbits, bool has_zero_points,
/// \p context_id is the WebGpuContext slot (0 for the default context).
bool HasDP4ADeviceSupport(int context_id = 0);
+bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(const Tensor* a,
+ int64_t K_op,
+ int64_t N_op,
+ int64_t block_size_op,
+ int64_t accuracy_level,
+ int64_t nbits,
+ onnxruntime::webgpu::ComputeContext& context,
+ Tensor* y,
+ bool has_weight_idx_indirect = false,
+ int32_t* subgroup_matrix_config_index = nullptr,
+ uint32_t override_M = 0);
+
+// Precomputed-dims overload for callers (e.g., ApplyMatMulNBits, MatMulNBitsMlp,
+// MatMulNBitsQkv) that have already run MatMulComputeHelper and have M/N/K and
+// batch_count in scope. Avoids re-running shape inference per dispatch decision.
+bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(uint32_t M,
+ uint32_t N,
+ uint32_t K,
+ uint32_t batch_count,
+ uint32_t block_size,
+ int64_t accuracy_level,
+ int64_t nbits,
+ onnxruntime::webgpu::ComputeContext& context,
+ Tensor* y,
+ bool has_weight_idx_indirect = false,
+ int32_t* subgroup_matrix_config_index = nullptr,
+ uint32_t override_M = 0);
+
+bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(const Tensor* a,
+ int64_t K_op,
+ int64_t N_op,
+ int64_t block_size_op,
+ int64_t accuracy_level,
+ onnxruntime::webgpu::ComputeContext& context,
+ Tensor* y,
+ bool has_weight_idx_indirect = false);
+
+bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(uint32_t M,
+ uint32_t N,
+ uint32_t K,
+ uint32_t block_size,
+ int64_t accuracy_level,
+ onnxruntime::webgpu::ComputeContext& context,
+ Tensor* y,
+ bool has_weight_idx_indirect = false);
+
+bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(const Tensor* a,
+ int64_t K_op,
+ int64_t N_op,
+ int64_t block_size_op,
+ int64_t nbits,
+ bool has_weight_idx_indirect = false);
+
+bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(uint32_t M,
+ uint32_t K,
+ uint32_t block_size,
+ int64_t nbits,
+ bool has_weight_idx_indirect = false);
+
} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc
new file mode 100644
index 0000000000000..8dd454083c9b3
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc
@@ -0,0 +1,539 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/webgpu/quantization/matmul_nbits_mlp.h"
+
+#include
+
+#include "contrib_ops/webgpu/quantization/matmul_nbits.h"
+#include "contrib_ops/webgpu/quantization/matmul_nbits_common.h"
+#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h"
+#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h"
+#include "contrib_ops/webgpu/bert/skip_layer_norm.h"
+#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
+#include "core/providers/cpu/math/matmul_helper.h"
+#include "core/providers/webgpu/nn/layer_norm.h"
+#include "core/providers/webgpu/shader_helper.h"
+#include "core/providers/webgpu/webgpu_supported_types.h"
+#include "core/providers/webgpu/webgpu_utils.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace webgpu {
+
+Status ParseMlpActivation(std::string_view name, MlpActivationKind* out) {
+ if (name == "silu") {
+ *out = MlpActivationKind::Silu;
+ return Status::OK();
+ }
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "MatMulNBitsMlp: activation '", name, "' is not supported.");
+}
+
+namespace {
+
+constexpr uint32_t kFusedDecodeFastPathBits = 4u;
+constexpr uint32_t kFusedDecodeFastPathBlockSize = 32u;
+
+// Emits the WGSL expression that applies the gate activation. The result must use
+// the variable names produced by the inline kernel (`gate_value`) and the shader
+// template (`gate_output_value`), so callers pass the gate operand variable name.
+// Adding a new activation here is the kernel-side counterpart to extending the
+// MlpActivationKind enum.
+std::string EmitGateActivationExpr(MlpActivationKind kind, std::string_view gate_var) {
+ switch (kind) {
+ case MlpActivationKind::Silu:
+ // SiLU(x) = x * sigmoid(x)
+ return std::string{gate_var} + " * (one / (one + exp(-" + std::string{gate_var} + ")))";
+ }
+ ORT_THROW("MatMulNBitsMlp: unhandled MlpActivationKind ", static_cast(kind));
+}
+
+class MatMulNBitsMlpDecodeProgram final : public Program {
+ public:
+ MatMulNBitsMlpDecodeProgram(uint32_t tile_size,
+ bool has_gate_bias,
+ bool has_up_bias,
+ bool has_norm_input,
+ bool has_skip_input,
+ bool has_skip_output,
+ bool single_scale_weights,
+ uint32_t tile_size_k_vec,
+ uint32_t k_unroll_tiles,
+ MlpActivationKind activation_kind)
+ : Program{"MatMulNBitsMlpDecode"},
+ tile_size_(tile_size),
+ has_gate_bias_(has_gate_bias),
+ has_up_bias_(has_up_bias),
+ has_norm_input_(has_norm_input),
+ has_skip_input_(has_skip_input),
+ has_skip_output_(has_skip_output),
+ single_scale_weights_(single_scale_weights),
+ tile_size_k_vec_(tile_size_k_vec),
+ k_unroll_tiles_(k_unroll_tiles),
+ activation_kind_(activation_kind) {}
+
+ Status GenerateShaderCode(ShaderHelper& shader) const override {
+ const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+ const auto* skip = has_skip_input_ ? &shader.AddInput("skip", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias) : nullptr;
+ const auto* norm_scale = has_norm_input_ ? &shader.AddInput("norm_scale", ShaderUsage::UseValueTypeAlias) : nullptr;
+ const auto& gate_b = shader.AddInput("gate_b");
+ const auto& gate_scales_b = shader.AddInput("gate_scales_b");
+ const auto& up_b = shader.AddInput("up_b");
+ const auto& up_scales_b = shader.AddInput("up_scales_b");
+ if (has_gate_bias_) {
+ shader.AddInput("gate_bias", ShaderUsage::UseUniform);
+ }
+ if (has_up_bias_) {
+ shader.AddInput("up_bias", ShaderUsage::UseUniform);
+ }
+ const auto& output = shader.AddOutput("output",
+ ShaderUsage::UseElementTypeAlias);
+ const auto* input_skip_bias_sum = has_skip_output_
+ ? &shader.AddOutput("input_skip_bias_sum",
+ ShaderUsage::UseValueTypeAlias |
+ ShaderUsage::UseElementTypeAlias)
+ : nullptr;
+ const auto& skip_var = skip != nullptr ? *skip : a;
+ const auto& norm_scale_var = norm_scale != nullptr ? *norm_scale : a;
+ const auto& input_skip_bias_sum_var = input_skip_bias_sum != nullptr ? *input_skip_bias_sum : output;
+
+ const uint32_t components_a = a.NumComponents();
+ const uint32_t components_b = gate_b.NumComponents() / 4;
+ const uint32_t tile_size_k_vec = tile_size_k_vec_;
+ const uint32_t elements_in_value_b = components_b * 8u;
+ const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b;
+ const uint32_t a_length_per_tile = tile_size_k / components_a;
+ const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec;
+
+ // Template parameters are identical across the (has_skip_input, has_norm_input,
+ // has_skip_output) combinations; only the AddInput/AddOutput wiring upstream changes.
+ // The template's own #if directives select the appropriate code paths.
+ return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_mlp.wgsl.template",
+ WGSL_TEMPLATE_PARAMETER(a_length_per_tile, a_length_per_tile),
+ WGSL_TEMPLATE_PARAMETER(activation_kind, static_cast(activation_kind_)),
+ WGSL_TEMPLATE_PARAMETER(component_a, components_a),
+ WGSL_TEMPLATE_PARAMETER(component_b, components_b),
+ WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b),
+ WGSL_TEMPLATE_PARAMETER(has_gate_bias, has_gate_bias_),
+ WGSL_TEMPLATE_PARAMETER(has_norm_input, has_norm_input_),
+ WGSL_TEMPLATE_PARAMETER(has_skip_input, has_skip_input_),
+ WGSL_TEMPLATE_PARAMETER(has_skip_output, has_skip_output_),
+ WGSL_TEMPLATE_PARAMETER(has_up_bias, has_up_bias_),
+ WGSL_TEMPLATE_PARAMETER(k_unroll_tiles, k_unroll_tiles_),
+ WGSL_TEMPLATE_PARAMETER(single_scale_weights, single_scale_weights_),
+ WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count),
+ WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
+ WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k),
+ WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec),
+ WGSL_TEMPLATE_VARIABLE(a, a),
+ WGSL_TEMPLATE_VARIABLE(gate_b, gate_b),
+ WGSL_TEMPLATE_VARIABLE(gate_scales_b, gate_scales_b),
+ WGSL_TEMPLATE_VARIABLE(input_skip_bias_sum, input_skip_bias_sum_var),
+ WGSL_TEMPLATE_VARIABLE(norm_scale, norm_scale_var),
+ WGSL_TEMPLATE_VARIABLE(output, output),
+ WGSL_TEMPLATE_VARIABLE(skip, skip_var),
+ WGSL_TEMPLATE_VARIABLE(up_b, up_b),
+ WGSL_TEMPLATE_VARIABLE(up_scales_b, up_scales_b));
+ }
+
+ WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
+ {"N", ProgramUniformVariableDataType::Uint32},
+ {"K", ProgramUniformVariableDataType::Uint32},
+ {"K_of_a", ProgramUniformVariableDataType::Uint32},
+ {"K_of_b", ProgramUniformVariableDataType::Uint32},
+ {"block_size", ProgramUniformVariableDataType::Uint32},
+ {"blocks_per_col", ProgramUniformVariableDataType::Uint32},
+ {"num_N_tile", ProgramUniformVariableDataType::Uint32},
+ {"batch_count", ProgramUniformVariableDataType::Uint32},
+ {"skip_size", ProgramUniformVariableDataType::Uint32},
+ {"epsilon", ProgramUniformVariableDataType::Float32});
+
+ private:
+ uint32_t tile_size_;
+ bool has_gate_bias_;
+ bool has_up_bias_;
+ bool has_norm_input_;
+ bool has_skip_input_;
+ bool has_skip_output_;
+ bool single_scale_weights_;
+ uint32_t tile_size_k_vec_;
+ uint32_t k_unroll_tiles_;
+ MlpActivationKind activation_kind_;
+};
+
+class MatMulNBitsMlpProgram final : public Program {
+ public:
+ explicit MatMulNBitsMlpProgram(MlpActivationKind activation_kind)
+ : Program{"MatMulNBitsMlp"}, activation_kind_(activation_kind) {
+ CacheHint(static_cast(activation_kind_));
+ }
+
+ Status GenerateShaderCode(ShaderHelper& shader) const override {
+ const auto& gate = shader.AddInput("gate", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+ const auto& up = shader.AddInput("up", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+ const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+
+ shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size")
+ << "let gate_value = " << gate.GetByOffset("global_idx") << ";\n"
+ << "let up_value = " << up.GetByOffset("global_idx") << ";\n"
+ << "let one = output_value_t(1.0);\n"
+ << "let activated_value = " << EmitGateActivationExpr(activation_kind_, "gate_value") << ";\n"
+ << output.SetByOffset("global_idx", "activated_value * up_value");
+
+ return Status::OK();
+ }
+
+ WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32});
+
+ private:
+ MlpActivationKind activation_kind_;
+};
+
+Status ApplyUnfusedMlp(const Tensor* a,
+ const Tensor* gate_b,
+ const Tensor* gate_scales,
+ const Tensor* gate_bias,
+ const Tensor* up_b,
+ const Tensor* up_scales,
+ const Tensor* up_bias,
+ int64_t K,
+ int64_t N,
+ int64_t block_size,
+ int64_t accuracy_level,
+ int64_t bits,
+ MlpActivationKind activation_kind,
+ onnxruntime::webgpu::ComputeContext& context,
+ Tensor* y) {
+ MatMulComputeHelper helper;
+ TensorShape b_shape({N, K});
+ ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));
+ const auto output_shape = helper.OutputShape();
+
+ Tensor gate_output = context.CreateGPUTensor(a->DataType(), output_shape);
+ Tensor up_output = context.CreateGPUTensor(a->DataType(), output_shape);
+
+ ORT_RETURN_IF_ERROR(ApplyMatMulNBits(a, gate_b, gate_scales, nullptr, gate_bias, K, N, block_size, accuracy_level, bits, context, &gate_output));
+ ORT_RETURN_IF_ERROR(ApplyMatMulNBits(a, up_b, up_scales, nullptr, up_bias, K, N, block_size, accuracy_level, bits, context, &up_output));
+
+ const uint32_t data_size = onnxruntime::narrow(y->Shape().Size());
+ const uint32_t vec_size = (data_size + 3u) / 4u;
+ MatMulNBitsMlpProgram program{activation_kind};
+ program
+ .AddInputs({{&gate_output, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4},
+ {&up_output, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4}})
+ .AddOutput({y, ProgramTensorMetadataDependency::Type, {vec_size}, 4})
+ .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
+ .AddUniformVariables({vec_size});
+
+ return context.RunProgram(program);
+}
+
+} // namespace
+
+ONNX_OPERATOR_KERNEL_EX(
+ MatMulNBitsMlp,
+ kMSDomain,
+ 1,
+ kWebGpuExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T1", WebGpuSupportedFloatTypes())
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
+ MatMulNBitsMlp);
+
+Status MatMulNBitsMlp::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
+ const Tensor* a = context.Input(0);
+ const Tensor* skip = context.Input(1);
+ const Tensor* norm_scale = context.Input(2);
+ const Tensor* gate_b = context.Input(3);
+ const Tensor* gate_scales = context.Input(4);
+ const Tensor* gate_bias = context.Input(5);
+ const Tensor* up_b = context.Input(6);
+ const Tensor* up_scales = context.Input(7);
+ const Tensor* up_bias = context.Input(8);
+
+ ORT_ENFORCE(skip == nullptr || norm_scale != nullptr,
+ "MatMulNBitsMlp requires norm_scale when skip is present.");
+
+ MatMulComputeHelper helper;
+ TensorShape b_shape({N_, K_});
+ ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));
+ const auto output_shape = helper.OutputShape();
+ const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size());
+ const uint32_t M = onnxruntime::narrow(helper.M());
+ const uint32_t N = onnxruntime::narrow(helper.N());
+ const uint32_t K = onnxruntime::narrow(helper.K());
+ const uint32_t block_size = onnxruntime::narrow(block_size_);
+ const uint32_t components_a = GetMaxComponents(K);
+ const bool single_scale_weights = (block_size == K * N);
+ const uint32_t block_size_per_col = single_scale_weights ? K : block_size;
+ const uint32_t n_blocks_per_col = (K + block_size_per_col - 1) / block_size_per_col;
+ const uint32_t blob_size = (block_size_per_col / 8) * onnxruntime::narrow(bits_);
+ const uint32_t blob_size_in_words = blob_size / 4;
+ const uint32_t components_b = GetMaxComponents(blob_size_in_words);
+ constexpr uint32_t kU32Components = 4;
+ const uint32_t components_b_with_u32 = components_b * kU32Components;
+ const uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32;
+
+ Tensor* y = context.Output(0, output_shape);
+ Tensor* input_skip_bias_sum = (skip != nullptr && context.OutputCount() > 1)
+ ? context.Output(1, a->Shape())
+ : nullptr;
+ const uint32_t data_size = onnxruntime::narrow(y->Shape().Size());
+ if (data_size == 0) {
+ return Status::OK();
+ }
+
+ if (norm_scale != nullptr) {
+ ORT_ENFORCE(norm_scale->Shape().Size() == K_, "norm_scale must have shape [K].");
+ }
+
+ const bool has_skip_input = skip != nullptr;
+ const bool has_skip_output = input_skip_bias_sum != nullptr;
+
+ const bool is_decode_fast_path_candidate =
+ M == 1 &&
+ bits_ == kFusedDecodeFastPathBits &&
+ block_size == kFusedDecodeFastPathBlockSize;
+ const bool has_norm_input = norm_scale != nullptr;
+
+ const bool would_use_subgroup_unfused =
+ WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(M,
+ N,
+ K,
+ batch_count,
+ block_size,
+ accuracy_level_,
+ bits_,
+ context,
+ y);
+ const bool would_use_dp4a_unfused =
+ WouldApplyDP4AMatMulNBitsInCurrentDispatch(M,
+ N,
+ K,
+ block_size,
+ accuracy_level_,
+ context,
+ y);
+ const bool would_use_wide_tile_unfused =
+ WouldApplyWideTileMatMulNBitsInCurrentDispatch(M,
+ K,
+ block_size,
+ bits_);
+
+ const bool can_use_decode_fast_path =
+ is_decode_fast_path_candidate &&
+ !would_use_subgroup_unfused &&
+ !would_use_dp4a_unfused &&
+ !would_use_wide_tile_unfused;
+
+ if (can_use_decode_fast_path) {
+ ORT_ENFORCE(bits_ == kFusedDecodeFastPathBits,
+ "MatMulNBitsMlpDecodeProgram is specialized for 4-bit weights only.");
+ ORT_ENFORCE(block_size == kFusedDecodeFastPathBlockSize,
+ "MatMulNBitsMlpDecodeProgram is specialized for block_size=32 only.");
+
+ const bool has_gate_bias = gate_bias != nullptr;
+ const bool has_up_bias = up_bias != nullptr;
+
+ // The fully-fused MLP decode shader binds every weight/scale/bias plus the norm/skip
+ // tensors as storage buffers. Devices with a tight maxStorageBuffersPerShaderStage
+ // (notably macOS Metal at 10) cannot bind that many. For those devices we run the layer
+ // norm separately into a scratch tensor and then dispatch a no-norm variant of the
+ // decode program (which omits the norm_scale, skip, and skip-output bindings, dropping
+ // the storage-buffer count from up to 11 down to 8).
+ //
+ // Storage-buffer count: input_a + (skip?) + (norm_scale?) + 2 * (weight + scales)
+ // + output + (skip output?) + (gate_bias?) + (up_bias?)
+ const uint32_t required_storage_buffers =
+ 1u // input_a
+ + (has_skip_input ? 1u : 0u) // skip
+ + (has_norm_input ? 1u : 0u) // norm_scale
+ + 4u // gate/up weights + scales
+ + 1u // output
+ + (has_skip_output ? 1u : 0u) // skip output
+ + (has_gate_bias ? 1u : 0u) // gate bias
+ + (has_up_bias ? 1u : 0u); // up bias
+ const bool exceeds_storage_buffer_limit =
+ required_storage_buffers > context.DeviceLimits().maxStorageBuffersPerShaderStage;
+
+ // Optionally pre-normalize a into a scratch tensor and drop the norm/skip bindings
+ // from the decode program. The user-visible residual passthrough (input_skip_bias_sum)
+ // is produced by the skip-norm op directly in this path.
+ std::optional normalized_a_storage;
+ const Tensor* decode_a = a;
+ if (exceeds_storage_buffer_limit && has_norm_input) {
+ normalized_a_storage.emplace(context.CreateGPUTensor(a->DataType(), a->Shape()));
+ if (has_skip_input) {
+ ORT_RETURN_IF_ERROR(RunSkipLayerNormProgram(context, a, skip, norm_scale,
+ /*beta=*/nullptr,
+ /*bias=*/nullptr,
+ epsilon_, /*simplified=*/true,
+ &*normalized_a_storage,
+ input_skip_bias_sum));
+ } else {
+ const auto& a_shape = a->Shape();
+ const int64_t norm_size = a_shape[a_shape.NumDimensions() - 1];
+ const uint32_t norm_count = onnxruntime::narrow(a_shape.Size() / norm_size);
+ ORT_RETURN_IF_ERROR(onnxruntime::webgpu::RunLayerNormProgram(
+ context, a, norm_scale, /*bias=*/nullptr, epsilon_, norm_count, norm_size,
+ /*simplified=*/true, &*normalized_a_storage, /*mean=*/nullptr,
+ /*inv_std_dev=*/nullptr));
+ }
+ decode_a = &*normalized_a_storage;
+ }
+
+ // Decode-program-level norm/skip bindings: only used when the device has spare
+ // storage-buffer slots. Otherwise they are wired to the pre-normalized input above.
+ const bool decode_has_norm_input = has_norm_input && !exceeds_storage_buffer_limit;
+ const bool decode_has_skip_input = has_skip_input && !exceeds_storage_buffer_limit;
+ const bool decode_has_skip_output = has_skip_output && !exceeds_storage_buffer_limit;
+
+ uint32_t workgroup_size = 128;
+ uint32_t tile_size = 8;
+ uint32_t tile_size_k_vec =
+ (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u;
+
+ const uint32_t elements_in_value_b = components_b * (32u / onnxruntime::narrow(bits_));
+ const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b;
+ const uint32_t k_tile_iterations = K / tile_size_k;
+
+ uint32_t k_unroll_tiles = 1;
+ if ((K % tile_size_k) == 0) {
+ if (k_tile_iterations >= 8 && N <= 2048 && context.AdapterInfo().vendor != std::string_view{"intel"}) {
+ k_unroll_tiles = 4;
+ } else if (k_tile_iterations >= 4) {
+ k_unroll_tiles = 2;
+ }
+ }
+
+ const uint32_t num_N_tile = CeilDiv(N, tile_size);
+
+ MatMulNBitsMlpDecodeProgram program{tile_size,
+ has_gate_bias,
+ has_up_bias,
+ decode_has_norm_input,
+ decode_has_skip_input,
+ decode_has_skip_output,
+ single_scale_weights,
+ tile_size_k_vec,
+ k_unroll_tiles,
+ activation_kind_};
+ program.SetWorkgroupSize(workgroup_size);
+ program.SetDispatchGroupSize(num_N_tile, 1, batch_count);
+ program.AddInput({decode_a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)});
+ if (decode_has_skip_input) {
+ program.AddInput({skip, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)});
+ }
+ if (decode_has_norm_input) {
+ program.AddInput({norm_scale, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)});
+ }
+ program
+ .AddInputs({{gate_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)},
+ {gate_scales, ProgramTensorMetadataDependency::TypeAndRank},
+ {up_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)},
+ {up_scales, ProgramTensorMetadataDependency::TypeAndRank}})
+ .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank})
+ .AddUniformVariables({{N},
+ {K},
+ {K / components_a},
+ {K_of_b},
+ {block_size},
+ {n_blocks_per_col},
+ {num_N_tile},
+ {batch_count},
+ {decode_has_skip_input ? onnxruntime::narrow(skip->Shape().Size()) : 0u},
+ {epsilon_}})
+ .CacheHint(single_scale_weights,
+ has_gate_bias,
+ has_up_bias,
+ decode_has_norm_input,
+ decode_has_skip_input,
+ decode_has_skip_output,
+ tile_size_k_vec,
+ k_unroll_tiles,
+ static_cast(activation_kind_),
+ "decode_4bit");
+ if (decode_has_skip_output) {
+ program.AddOutput({input_skip_bias_sum,
+ ProgramTensorMetadataDependency::TypeAndRank,
+ static_cast(components_a)});
+ }
+ if (has_gate_bias) {
+ program.AddInput({gate_bias, ProgramTensorMetadataDependency::None});
+ }
+ if (has_up_bias) {
+ program.AddInput({up_bias, ProgramTensorMetadataDependency::None});
+ }
+
+ return context.RunProgram(program);
+ }
+
+ if (skip != nullptr) {
+ Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape());
+ ORT_RETURN_IF_ERROR(RunSkipLayerNormProgram(context, a, skip, norm_scale,
+ /*beta=*/nullptr, /*bias=*/nullptr,
+ epsilon_, /*simplified=*/true,
+ &normalized_a, input_skip_bias_sum));
+ return ApplyUnfusedMlp(&normalized_a,
+ gate_b,
+ gate_scales,
+ gate_bias,
+ up_b,
+ up_scales,
+ up_bias,
+ K_,
+ N_,
+ block_size_,
+ accuracy_level_,
+ bits_,
+ activation_kind_,
+ context,
+ y);
+ }
+
+ if (norm_scale != nullptr) {
+ Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape());
+ const auto& a_shape = a->Shape();
+ const int64_t norm_size = a_shape[a_shape.NumDimensions() - 1];
+ const uint32_t norm_count = onnxruntime::narrow(a_shape.Size() / norm_size);
+ ORT_RETURN_IF_ERROR(onnxruntime::webgpu::RunLayerNormProgram(
+ context, a, norm_scale, /*bias=*/nullptr, epsilon_, norm_count, norm_size,
+ /*simplified=*/true, &normalized_a, /*mean=*/nullptr, /*inv_std_dev=*/nullptr));
+ return ApplyUnfusedMlp(&normalized_a,
+ gate_b,
+ gate_scales,
+ gate_bias,
+ up_b,
+ up_scales,
+ up_bias,
+ K_,
+ N_,
+ block_size_,
+ accuracy_level_,
+ bits_,
+ activation_kind_,
+ context,
+ y);
+ }
+
+ return ApplyUnfusedMlp(a,
+ gate_b,
+ gate_scales,
+ gate_bias,
+ up_b,
+ up_scales,
+ up_bias,
+ K_,
+ N_,
+ block_size_,
+ accuracy_level_,
+ bits_,
+ activation_kind_,
+ context,
+ y);
+}
+
+} // namespace webgpu
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h
new file mode 100644
index 0000000000000..002ebca4c0e54
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h
@@ -0,0 +1,63 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include
+#include
+#include
+
+#include "core/common/status.h"
+#include "core/providers/webgpu/webgpu_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace webgpu {
+
+using namespace onnxruntime::webgpu;
+using onnxruntime::webgpu::ComputeContext;
+
+// Gate activation applied between the gate and up MatMulNBits projections.
+// Currently only SiLU is supported; future activations (e.g. GELU for Gemma-style
+// gated MLPs) can be added here and threaded through the kernel and shader template.
+enum class MlpActivationKind : uint32_t {
+ Silu = 0,
+};
+
+// Parses the `activation` attribute string into MlpActivationKind. Returns a non-OK
+// Status for unsupported activations so the kernel rejects unknown values up front.
+Status ParseMlpActivation(std::string_view name, MlpActivationKind* out);
+
+class MatMulNBitsMlp final : public WebGpuKernel {
+ public:
+ explicit MatMulNBitsMlp(const OpKernelInfo& info) : WebGpuKernel(info) {
+ K_ = info.GetAttr("K");
+ N_ = info.GetAttr("N");
+ block_size_ = info.GetAttr("block_size");
+ bits_ = info.GetAttr("bits");
+ accuracy_level_ = info.GetAttrOrDefault("accuracy_level", 4);
+ epsilon_ = info.GetAttrOrDefault("epsilon", 1e-5f);
+ std::string activation;
+ ORT_ENFORCE(info.GetAttr("activation", &activation).IsOK(),
+ "MatMulNBitsMlp requires the 'activation' attribute.");
+ ORT_ENFORCE(ParseMlpActivation(activation, &activation_kind_).IsOK(),
+ "MatMulNBitsMlp: unsupported activation '", activation, "'.");
+ ORT_ENFORCE(bits_ == 4 || bits_ == 8 || bits_ == 2,
+ "Only 4b/8b/2b quantization is supported for MatMulNBitsMlp op.");
+ }
+
+ Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override;
+
+ private:
+ int64_t K_;
+ int64_t N_;
+ int64_t block_size_;
+ int64_t accuracy_level_;
+ int64_t bits_;
+ float epsilon_;
+ MlpActivationKind activation_kind_;
+};
+
+} // namespace webgpu
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template
new file mode 100644
index 0000000000000..f64f0d38f24e2
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template
@@ -0,0 +1,266 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#param a_length_per_tile
+#param component_a
+#param component_b
+#param elements_in_value_b
+#param single_scale_weights
+#param sub_tile_count
+#param has_norm_input
+#param has_skip_input
+#param has_skip_output
+#param tile_size_k_vec
+#param tile_size_k
+#param tile_size
+#param has_gate_bias
+#param has_up_bias
+#param k_unroll_tiles
+// Gate activation applied between gate and up projections.
+// Mirrors MlpActivationKind in matmul_nbits_mlp.h:
+// 0 = SiLU (only value currently supported)
+// New activations are added by extending the enum, the EmitGateActivationExpr
+// helper, the fusion matcher, and the activation block below.
+#param activation_kind
+
+#use .getByOffset .setByOffset
+
+#if has_norm_input
+var sum_squared_shared : array;
+#endif
+var tile_A : array;
+var gate_inter_results : array, tile_size>;
+var up_inter_results : array, tile_size>;
+
+const default_zero_point = output_element_t(8);
+
+fn load_merged_input(input_offset: u32) -> input_a_value_t {
+#if has_skip_input
+ let skip_offset = input_offset % (uniforms.skip_size / component_a);
+ return a.getByOffset(input_offset) + input_a_value_t(skip.getByOffset(skip_offset));
+#else
+ return a.getByOffset(input_offset);
+#endif
+}
+
+fn loadSHMA(batch: u32, b_global_base: u32, kidx: u32, col: u32, inv_std: f32)
+{
+ let k_offset = kidx / component_a + col;
+ let input_offset = batch * uniforms.K_of_a + k_offset;
+ if (k_offset < uniforms.K_of_a) {
+ let merged_value = load_merged_input(input_offset);
+#if has_skip_output
+ if (b_global_base == 0u) {
+ input_skip_bias_sum.setByOffset(input_offset, input_skip_bias_sum_value_t(merged_value));
+ }
+#endif
+#if has_norm_input
+ tile_A[col] = merged_value * input_a_value_t(input_a_element_t(inv_std)) * norm_scale.getByOffset(k_offset);
+#else
+ tile_A[col] = merged_value;
+#endif
+ } else {
+ tile_A[col] = input_a_value_t(0);
+ }
+}
+
+fn compute_gate_up_sums(b_global: u32, kidx: u32, idx: u32, k_offset: u32) -> vec2 {
+#if single_scale_weights
+ let gate_scale_b = gate_scales_b.getByOffset(0);
+ let up_scale_b = up_scales_b.getByOffset(0);
+#else
+ let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size;
+ let gate_scale_b = gate_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx);
+ let up_scale_b = up_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx);
+#endif
+ let gate_b_value = gate_b.getByOffset(b_global * uniforms.K_of_b + k_offset);
+ let up_b_value = up_b.getByOffset(b_global * uniforms.K_of_b + k_offset);
+
+ var gate_sum = output_element_t(0);
+ var up_sum = output_element_t(0);
+ var a_offset = idx * (8 / component_a) * component_b;
+#if component_b == 1
+ let gate_b_value_lower = vec4(unpack4xU8(gate_b_value & 0x0F0F0F0Fu)) - vec4(default_zero_point);
+ let gate_b_value_upper = vec4(unpack4xU8((gate_b_value >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point);
+ let gate_b0 = vec4(gate_b_value_lower[0], gate_b_value_upper[0], gate_b_value_lower[1], gate_b_value_upper[1]) * gate_scale_b;
+ let gate_b1 = vec4(gate_b_value_lower[2], gate_b_value_upper[2], gate_b_value_lower[3], gate_b_value_upper[3]) * gate_scale_b;
+ let up_b_value_lower = vec4(unpack4xU8(up_b_value & 0x0F0F0F0Fu)) - vec4(default_zero_point);
+ let up_b_value_upper = vec4(unpack4xU8((up_b_value >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point);
+ let up_b0 = vec4(up_b_value_lower[0], up_b_value_upper[0], up_b_value_lower[1], up_b_value_upper[1]) * up_scale_b;
+ let up_b1 = vec4(up_b_value_lower[2], up_b_value_upper[2], up_b_value_lower[3], up_b_value_upper[3]) * up_scale_b;
+#if component_a == 1
+ let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]);
+ let a1 = vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]);
+ gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1);
+ up_sum += dot(a0, up_b0) + dot(a1, up_b1);
+#elif component_a == 2
+ let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1]);
+ let a1 = vec4(tile_A[a_offset + 2], tile_A[a_offset + 3]);
+ gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1);
+ up_sum += dot(a0, up_b0) + dot(a1, up_b1);
+#elif component_a == 4
+ let a0 = tile_A[a_offset];
+ let a1 = tile_A[a_offset + 1];
+ gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1);
+ up_sum += dot(a0, up_b0) + dot(a1, up_b1);
+#endif
+#else
+ for (var i = 0u; i < component_b; i++) {
+ let gate_b_value_lower = vec4(unpack4xU8(gate_b_value[i] & 0x0F0F0F0Fu)) - vec4(default_zero_point);
+ let gate_b_value_upper = vec4(unpack4xU8((gate_b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point);
+ let gate_b0 = vec4(gate_b_value_lower[0], gate_b_value_upper[0], gate_b_value_lower[1], gate_b_value_upper[1]) * gate_scale_b;
+ let gate_b1 = vec4(gate_b_value_lower[2], gate_b_value_upper[2], gate_b_value_lower[3], gate_b_value_upper[3]) * gate_scale_b;
+ let up_b_value_lower = vec4(unpack4xU8(up_b_value[i] & 0x0F0F0F0Fu)) - vec4(default_zero_point);
+ let up_b_value_upper = vec4(unpack4xU8((up_b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point);
+ let up_b0 = vec4(up_b_value_lower[0], up_b_value_upper[0], up_b_value_lower[1], up_b_value_upper[1]) * up_scale_b;
+ let up_b1 = vec4(up_b_value_lower[2], up_b_value_upper[2], up_b_value_lower[3], up_b_value_upper[3]) * up_scale_b;
+#if component_a == 1
+ let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]);
+ let a1 = vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]);
+ gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1);
+ up_sum += dot(a0, up_b0) + dot(a1, up_b1);
+ a_offset += 8;
+#elif component_a == 2
+ let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1]);
+ let a1 = vec4(tile_A[a_offset + 2], tile_A[a_offset + 3]);
+ gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1);
+ up_sum += dot(a0, up_b0) + dot(a1, up_b1);
+ a_offset += 4;
+#elif component_a == 4
+ let a0 = tile_A[a_offset];
+ let a1 = tile_A[a_offset + 1];
+ gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1);
+ up_sum += dot(a0, up_b0) + dot(a1, up_b1);
+ a_offset += 2;
+#endif
+ }
+#endif
+
+ return vec2(gate_sum, up_sum);
+}
+
+fn process_k_tile(batch: u32, b_global_base: u32, thread_idx: u32, idx: u32, idy: u32, kidx: u32, inv_std: f32) {
+ for (var id = thread_idx; id < a_length_per_tile; id += workgroup_size_x)
+ {
+ loadSHMA(batch, b_global_base, kidx, id, inv_std);
+ }
+ workgroupBarrier();
+
+ for (var local_row_offset = 0u; local_row_offset < tile_size; local_row_offset += sub_tile_count)
+ {
+ let b_global = b_global_base + local_row_offset + idy;
+ let k_offset = kidx / elements_in_value_b + idx;
+ if (b_global < uniforms.N && k_offset < uniforms.K_of_b) {
+ let sums = compute_gate_up_sums(b_global, kidx, idx, k_offset);
+ gate_inter_results[local_row_offset + idy][idx] += sums[0];
+ up_inter_results[local_row_offset + idy][idx] += sums[1];
+ }
+ }
+ workgroupBarrier();
+}
+
+$MAIN {
+ let batch = workgroup_idx / uniforms.num_N_tile;
+ let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size;
+
+ let idx = local_idx % tile_size_k_vec;
+ let idy = local_idx / tile_size_k_vec;
+
+ if (local_idx < tile_size) {
+ for (var b = 0u; b < tile_size_k_vec; b++) {
+ gate_inter_results[local_idx][b] = output_element_t(0);
+ up_inter_results[local_idx][b] = output_element_t(0);
+ }
+ }
+ workgroupBarrier();
+
+#if has_norm_input
+ var sum_squared_local = 0.0;
+ for (var a_idx = local_idx; a_idx < uniforms.K_of_a; a_idx += workgroup_size_x) {
+ let a_value = load_merged_input(batch * uniforms.K_of_a + a_idx);
+#if component_a == 1
+ let a_f32 = f32(a_value);
+ sum_squared_local += a_f32 * a_f32;
+#elif component_a == 2
+ let a_f32 = vec2(a_value);
+ sum_squared_local += dot(a_f32, a_f32);
+#elif component_a == 4
+ let a_f32 = vec4(a_value);
+ sum_squared_local += dot(a_f32, a_f32);
+#endif
+ }
+ sum_squared_shared[local_idx] = sum_squared_local;
+ workgroupBarrier();
+
+ var reduce_size : u32 = workgroup_size_x;
+ for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {
+ reduce_size = curr_size + (reduce_size & 1u);
+ if (local_idx < curr_size) {
+ sum_squared_shared[local_idx] += sum_squared_shared[local_idx + reduce_size];
+ }
+ workgroupBarrier();
+ }
+
+ let inv_std = inverseSqrt(sum_squared_shared[0] / f32(uniforms.K) + uniforms.epsilon);
+#else
+ let inv_std = 1.0;
+#endif
+
+#if k_unroll_tiles == 1
+ for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k) {
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std);
+ }
+#elif k_unroll_tiles == 2
+ let unrolled_k_step = tile_size_k * 2u;
+ let unrolled_k_limit = uniforms.K - (uniforms.K % unrolled_k_step);
+ for (var kidx = 0u; kidx < unrolled_k_limit; kidx += unrolled_k_step) {
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std);
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k, inv_std);
+ }
+ for (var kidx = unrolled_k_limit; kidx < uniforms.K; kidx += tile_size_k) {
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std);
+ }
+#elif k_unroll_tiles == 4
+ let unrolled_k_step = tile_size_k * 4u;
+ let unrolled_k_limit = uniforms.K - (uniforms.K % unrolled_k_step);
+ for (var kidx = 0u; kidx < unrolled_k_limit; kidx += unrolled_k_step) {
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std);
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k, inv_std);
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k * 2u, inv_std);
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k * 3u, inv_std);
+ }
+ for (var kidx = unrolled_k_limit; kidx < uniforms.K; kidx += tile_size_k) {
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std);
+ }
+#endif
+
+ if (batch >= uniforms.batch_count) {
+ return;
+ }
+
+ if (local_idx < tile_size) {
+ var gate_output_value = output_element_t(0);
+ var up_output_value = output_element_t(0);
+ for (var b = 0u; b < tile_size_k_vec; b++) {
+ gate_output_value += gate_inter_results[local_idx][b];
+ up_output_value += up_inter_results[local_idx][b];
+ }
+ let b_global = b_global_base + local_idx;
+ let output_idx = batch * uniforms.N + b_global;
+ if (b_global < uniforms.N) {
+#if has_gate_bias
+ gate_output_value += gate_bias[b_global];
+#endif
+#if has_up_bias
+ up_output_value += up_bias[b_global];
+#endif
+ let one = output_element_t(1.0);
+#if activation_kind == 0
+ // SiLU(x) = x * sigmoid(x). New activations are added with additional
+ // `#elif activation_kind == N` blocks (must match MlpActivationKind).
+ let activated_value = gate_output_value * (one / (one + exp(-gate_output_value)));
+#endif
+ output.setByOffset(output_idx, activated_value * up_output_value);
+ }
+ }
+} // MAIN
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc
new file mode 100644
index 0000000000000..4d46e03fba20c
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc
@@ -0,0 +1,482 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/webgpu/quantization/matmul_nbits_qkv.h"
+
+#include
+
+#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h"
+#include "contrib_ops/webgpu/quantization/matmul_nbits.h"
+#include "contrib_ops/webgpu/quantization/matmul_nbits_common.h"
+#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h"
+#include "contrib_ops/webgpu/bert/skip_layer_norm.h"
+#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
+#include "core/providers/cpu/math/matmul_helper.h"
+#include "core/providers/webgpu/nn/layer_norm.h"
+#include "core/providers/webgpu/shader_helper.h"
+#include "core/providers/webgpu/webgpu_supported_types.h"
+#include "core/providers/webgpu/webgpu_utils.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace webgpu {
+
+namespace {
+
+Status ApplyUnfusedQKVSimplifiedLayerNorm(const Tensor* a,
+ const Tensor* norm_scale,
+ const Tensor* q_b,
+ const Tensor* q_scales,
+ const Tensor* k_b,
+ const Tensor* k_scales,
+ const Tensor* v_b,
+ const Tensor* v_scales,
+ int64_t K,
+ int64_t Nq,
+ int64_t Nkv,
+ int64_t block_size,
+ int64_t accuracy_level,
+ int64_t bits,
+ float epsilon,
+ onnxruntime::webgpu::ComputeContext& context,
+ Tensor* q_output,
+ Tensor* k_output,
+ Tensor* v_output) {
+ Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape());
+ const auto& a_shape = a->Shape();
+ const int64_t norm_size = a_shape[a_shape.NumDimensions() - 1];
+ const uint32_t norm_count = onnxruntime::narrow(a_shape.Size() / norm_size);
+ ORT_RETURN_IF_ERROR(onnxruntime::webgpu::RunLayerNormProgram(
+ context, a, norm_scale, /*bias=*/nullptr, epsilon, norm_count, norm_size,
+ /*simplified=*/true, &normalized_a, /*mean=*/nullptr, /*inv_std_dev=*/nullptr));
+ ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, q_b, q_scales, nullptr, nullptr,
+ K, Nq, block_size, accuracy_level, bits, context, q_output));
+ ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, k_b, k_scales, nullptr, nullptr,
+ K, Nkv, block_size, accuracy_level, bits, context, k_output));
+ ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, v_b, v_scales, nullptr, nullptr,
+ K, Nkv, block_size, accuracy_level, bits, context, v_output));
+ return Status::OK();
+}
+
+Status ApplyUnfusedQKVSkipSimplifiedLayerNorm(const Tensor* a,
+ const Tensor* skip,
+ const Tensor* norm_scale,
+ const Tensor* q_b,
+ const Tensor* q_scales,
+ const Tensor* k_b,
+ const Tensor* k_scales,
+ const Tensor* v_b,
+ const Tensor* v_scales,
+ int64_t K,
+ int64_t Nq,
+ int64_t Nkv,
+ int64_t block_size,
+ int64_t accuracy_level,
+ int64_t bits,
+ float epsilon,
+ onnxruntime::webgpu::ComputeContext& context,
+ Tensor* q_output,
+ Tensor* k_output,
+ Tensor* v_output,
+ Tensor* input_skip_bias_sum) {
+ Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape());
+ ORT_RETURN_IF_ERROR(RunSkipLayerNormProgram(context, a, skip, norm_scale,
+ /*beta=*/nullptr, /*bias=*/nullptr,
+ epsilon, /*simplified=*/true,
+ &normalized_a, input_skip_bias_sum));
+ ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, q_b, q_scales, nullptr, nullptr,
+ K, Nq, block_size, accuracy_level, bits, context, q_output));
+ ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, k_b, k_scales, nullptr, nullptr,
+ K, Nkv, block_size, accuracy_level, bits, context, k_output));
+ ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, v_b, v_scales, nullptr, nullptr,
+ K, Nkv, block_size, accuracy_level, bits, context, v_output));
+ return Status::OK();
+}
+
+class MatMulNBitsQkvDecodeProgram final
+ : public Program {
+ public:
+ MatMulNBitsQkvDecodeProgram(uint32_t tile_size,
+ bool single_scale_weights,
+ uint32_t tile_size_k_vec,
+ uint32_t k_unroll_tiles,
+ bool has_norm,
+ bool has_skip_input,
+ bool has_skip_output)
+ : Program{"MatMulNBitsQkvDecode"},
+ tile_size_(tile_size),
+ single_scale_weights_(single_scale_weights),
+ tile_size_k_vec_(tile_size_k_vec),
+ k_unroll_tiles_(k_unroll_tiles),
+ has_norm_(has_norm),
+ has_skip_input_(has_skip_input),
+ has_skip_output_(has_skip_output) {
+ // The no-norm variant runs against an already-normalized input tensor and therefore
+ // never owns the residual skip path nor the residual passthrough output.
+ ORT_ENFORCE(has_norm_ || (!has_skip_input_ && !has_skip_output_),
+ "MatMulNBitsQkvDecodeProgram: skip input/output require has_norm=true.");
+ }
+
+ Status GenerateShaderCode(ShaderHelper& shader) const override {
+ const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+ const auto* skip = has_skip_input_ ? &shader.AddInput("skip", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias) : nullptr;
+ const auto* norm_scale_ptr = has_norm_ ? &shader.AddInput("norm_scale", ShaderUsage::UseValueTypeAlias) : nullptr;
+ const auto& q_b = shader.AddInput("q_b", ShaderUsage::UseValueTypeAlias);
+ const auto& q_scales_b = shader.AddInput("q_scales_b");
+ const auto& k_b = shader.AddInput("k_b");
+ const auto& k_scales_b = shader.AddInput("k_scales_b");
+ const auto& v_b = shader.AddInput("v_b");
+ const auto& v_scales_b = shader.AddInput("v_scales_b");
+ const auto& q_output = shader.AddOutput("q_output",
+ ShaderUsage::UseValueTypeAlias |
+ ShaderUsage::UseElementTypeAlias);
+ const auto& k_output = shader.AddOutput("k_output",
+ ShaderUsage::UseValueTypeAlias |
+ ShaderUsage::UseElementTypeAlias);
+ const auto& v_output = shader.AddOutput("v_output",
+ ShaderUsage::UseValueTypeAlias |
+ ShaderUsage::UseElementTypeAlias);
+ const auto* input_skip_bias_sum = has_skip_output_ ? &shader.AddOutput("input_skip_bias_sum", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias) : nullptr;
+ const auto& skip_var = skip != nullptr ? *skip : a;
+ const auto& norm_scale_var = norm_scale_ptr != nullptr ? *norm_scale_ptr : a;
+ const auto& input_skip_bias_sum_var = input_skip_bias_sum != nullptr ? *input_skip_bias_sum : q_output;
+
+ const uint32_t components_a = a.NumComponents();
+ const uint32_t components_b = q_b.NumComponents() / 4;
+ const uint32_t tile_size_k_vec = tile_size_k_vec_;
+ const uint32_t elements_in_value_b = components_b * 8u;
+ const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b;
+ const uint32_t a_length_per_tile = tile_size_k / components_a;
+ const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec;
+
+ return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_qkv.wgsl.template",
+ WGSL_TEMPLATE_PARAMETER(a_length_per_tile, a_length_per_tile),
+ WGSL_TEMPLATE_PARAMETER(component_a, components_a),
+ WGSL_TEMPLATE_PARAMETER(component_b, components_b),
+ WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b),
+ WGSL_TEMPLATE_PARAMETER(has_norm, has_norm_),
+ WGSL_TEMPLATE_PARAMETER(has_skip_input, has_skip_input_),
+ WGSL_TEMPLATE_PARAMETER(has_skip_output, has_skip_output_),
+ WGSL_TEMPLATE_PARAMETER(k_unroll_tiles, k_unroll_tiles_),
+ WGSL_TEMPLATE_PARAMETER(single_scale_weights, single_scale_weights_),
+ WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count),
+ WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
+ WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k),
+ WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec),
+ WGSL_TEMPLATE_VARIABLE(a, a),
+ WGSL_TEMPLATE_VARIABLE(input_skip_bias_sum, input_skip_bias_sum_var),
+ WGSL_TEMPLATE_VARIABLE(k_b, k_b),
+ WGSL_TEMPLATE_VARIABLE(k_output, k_output),
+ WGSL_TEMPLATE_VARIABLE(k_scales_b, k_scales_b),
+ WGSL_TEMPLATE_VARIABLE(norm_scale, norm_scale_var),
+ WGSL_TEMPLATE_VARIABLE(q_b, q_b),
+ WGSL_TEMPLATE_VARIABLE(q_output, q_output),
+ WGSL_TEMPLATE_VARIABLE(q_scales_b, q_scales_b),
+ WGSL_TEMPLATE_VARIABLE(skip, skip_var),
+ WGSL_TEMPLATE_VARIABLE(v_b, v_b),
+ WGSL_TEMPLATE_VARIABLE(v_output, v_output),
+ WGSL_TEMPLATE_VARIABLE(v_scales_b, v_scales_b));
+ }
+
+ WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
+ {"Nq", ProgramUniformVariableDataType::Uint32},
+ {"Nkv", ProgramUniformVariableDataType::Uint32},
+ {"K", ProgramUniformVariableDataType::Uint32},
+ {"K_of_a", ProgramUniformVariableDataType::Uint32},
+ {"K_of_b", ProgramUniformVariableDataType::Uint32},
+ {"block_size", ProgramUniformVariableDataType::Uint32},
+ {"blocks_per_col", ProgramUniformVariableDataType::Uint32},
+ {"num_N_tile", ProgramUniformVariableDataType::Uint32},
+ {"batch_count", ProgramUniformVariableDataType::Uint32},
+ {"skip_size", ProgramUniformVariableDataType::Uint32},
+ {"epsilon", ProgramUniformVariableDataType::Float32});
+
+ private:
+ uint32_t tile_size_;
+ bool single_scale_weights_;
+ uint32_t tile_size_k_vec_;
+ uint32_t k_unroll_tiles_;
+ bool has_norm_;
+ bool has_skip_input_;
+ bool has_skip_output_;
+};
+
+} // namespace
+
+ONNX_OPERATOR_KERNEL_EX(
+ MatMulNBitsQkv,
+ kMSDomain,
+ 1,
+ kWebGpuExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T1", WebGpuSupportedFloatTypes())
+ .TypeConstraint("T2", DataTypeImpl::GetTensorType()),
+ MatMulNBitsQkv);
+
+Status MatMulNBitsQkv::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
+ const Tensor* a = context.Input(0);
+ const Tensor* skip = context.Input(1);
+ const Tensor* norm_scale = context.Input(2);
+ const Tensor* q_b = context.Input(3);
+ const Tensor* q_scales = context.Input(4);
+ const Tensor* k_b = context.Input(5);
+ const Tensor* k_scales = context.Input(6);
+ const Tensor* v_b = context.Input(7);
+ const Tensor* v_scales = context.Input(8);
+
+ ORT_ENFORCE(bits_ == 4, "MatMulNBitsQkv currently supports 4-bit weights only.");
+ ORT_ENFORCE(block_size_ == 32, "MatMulNBitsQkv currently supports block_size=32 only.");
+
+ TensorShape q_b_shape({Nq_, K_});
+ MatMulComputeHelper helper;
+ ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), q_b_shape, false, true));
+
+ const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size());
+ const uint32_t M = onnxruntime::narrow(helper.M());
+ const uint32_t K = onnxruntime::narrow(helper.K());
+ const uint32_t Nq = onnxruntime::narrow(Nq_);
+ const uint32_t Nkv = onnxruntime::narrow(Nkv_);
+
+ auto q_shape = helper.OutputShape();
+ TensorShapeVector kv_dims(q_shape.GetDims().begin(), q_shape.GetDims().end());
+ kv_dims.back() = Nkv_;
+ TensorShape kv_shape(kv_dims);
+ Tensor* q_output = context.Output(0, q_shape);
+ Tensor* k_output = context.Output(1, kv_shape);
+ Tensor* v_output = context.Output(2, kv_shape);
+ Tensor* input_skip_bias_sum = (skip != nullptr && context.OutputCount() > 3)
+ ? context.Output(3, a->Shape())
+ : nullptr;
+ if (q_output->Shape().Size() == 0) {
+ return Status::OK();
+ }
+
+ ORT_ENFORCE(norm_scale->Shape().Size() == K_, "norm_scale must have shape [K].");
+
+ const uint32_t block_size = onnxruntime::narrow(block_size_);
+ const bool would_use_subgroup_unfused =
+ WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(M,
+ Nq,
+ K,
+ batch_count,
+ block_size,
+ accuracy_level_,
+ bits_,
+ context,
+ q_output);
+ const bool would_use_dp4a_unfused =
+ !would_use_subgroup_unfused &&
+ WouldApplyDP4AMatMulNBitsInCurrentDispatch(M,
+ Nq,
+ K,
+ block_size,
+ accuracy_level_,
+ context,
+ q_output);
+ const bool would_use_wide_tile_unfused =
+ !would_use_subgroup_unfused &&
+ !would_use_dp4a_unfused &&
+ WouldApplyWideTileMatMulNBitsInCurrentDispatch(M,
+ K,
+ block_size,
+ bits_);
+
+ // The fused MatMulNBitsQkv shader binds every Q/K/V weight + scales tensor and the
+ // norm/skip tensors as storage buffers. Devices with a tight maxStorageBuffersPerShaderStage
+ // (notably macOS Metal at 10) cannot bind that many. For those devices we run the layer
+ // norm separately into a scratch tensor and then dispatch a no-norm variant of the fused
+ // QKV decode program (which omits the norm_scale, skip, and skip-output bindings, dropping
+ // the storage-buffer count from up to 13 down to 10).
+ //
+ // Storage-buffer count: input_a + (skip?) + norm_scale + 3 * (weight + scales)
+ // + q/k/v outputs + (skip output?)
+ const uint32_t required_storage_buffers =
+ 1u // input_a
+ + (skip != nullptr ? 1u : 0u) // skip
+ + 1u // norm_scale
+ + 6u // q/k/v weights + scales
+ + 3u // q/k/v outputs
+ + (input_skip_bias_sum != nullptr ? 1u : 0u); // skip output
+ const bool exceeds_storage_buffer_limit =
+ required_storage_buffers > context.DeviceLimits().maxStorageBuffersPerShaderStage;
+
+ if (would_use_subgroup_unfused || would_use_dp4a_unfused || would_use_wide_tile_unfused ||
+ M != 1) {
+ if (skip != nullptr) {
+ return ApplyUnfusedQKVSkipSimplifiedLayerNorm(a,
+ skip,
+ norm_scale,
+ q_b,
+ q_scales,
+ k_b,
+ k_scales,
+ v_b,
+ v_scales,
+ K_,
+ Nq_,
+ Nkv_,
+ block_size_,
+ accuracy_level_,
+ bits_,
+ epsilon_,
+ context,
+ q_output,
+ k_output,
+ v_output,
+ input_skip_bias_sum);
+ }
+ return ApplyUnfusedQKVSimplifiedLayerNorm(a,
+ norm_scale,
+ q_b,
+ q_scales,
+ k_b,
+ k_scales,
+ v_b,
+ v_scales,
+ K_,
+ Nq_,
+ Nkv_,
+ block_size_,
+ accuracy_level_,
+ bits_,
+ epsilon_,
+ context,
+ q_output,
+ k_output,
+ v_output);
+ }
+
+ // For the partial-fuse path, run [Skip]SimplifiedLayerNormalization into a scratch tensor
+ // first, then point the decode program at the normalized tensor with the norm/skip bindings
+ // turned off. The user-visible residual passthrough (input_skip_bias_sum) is produced by the
+ // skip-norm op directly, so the decode program never needs to write it itself.
+ std::optional normalized_a_storage;
+ const Tensor* decode_a = a;
+ if (exceeds_storage_buffer_limit) {
+ normalized_a_storage.emplace(context.CreateGPUTensor(a->DataType(), a->Shape()));
+ if (skip != nullptr) {
+ ORT_RETURN_IF_ERROR(RunSkipLayerNormProgram(context, a, skip, norm_scale,
+ /*beta=*/nullptr, /*bias=*/nullptr,
+ epsilon_, /*simplified=*/true,
+ &*normalized_a_storage,
+ input_skip_bias_sum));
+ } else {
+ const auto& a_shape = a->Shape();
+ const int64_t norm_size = a_shape[a_shape.NumDimensions() - 1];
+ const uint32_t norm_count = onnxruntime::narrow(a_shape.Size() / norm_size);
+ ORT_RETURN_IF_ERROR(onnxruntime::webgpu::RunLayerNormProgram(
+ context, a, norm_scale, /*bias=*/nullptr, epsilon_, norm_count, norm_size,
+ /*simplified=*/true, &*normalized_a_storage, /*mean=*/nullptr,
+ /*inv_std_dev=*/nullptr));
+ }
+ decode_a = &*normalized_a_storage;
+ }
+
+ const uint32_t components_a = GetMaxComponents(K);
+ const uint32_t block_size_per_col = block_size;
+ const uint32_t n_blocks_per_col = (K + block_size_per_col - 1) / block_size_per_col;
+ const uint32_t blob_size = (block_size_per_col / 8) * static_cast(bits_);
+ const uint32_t blob_size_in_words = blob_size / 4;
+ const uint32_t components_b = GetMaxComponents(blob_size_in_words);
+ constexpr uint32_t kU32Components = 4;
+ const uint32_t components_b_with_u32 = components_b * kU32Components;
+ const uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32;
+ const bool single_scale_weights =
+ q_scales->Shape().Size() == 1 && k_scales->Shape().Size() == 1 && v_scales->Shape().Size() == 1;
+
+ uint32_t workgroup_size = 128;
+ uint32_t tile_size = 8;
+ uint32_t tile_size_k_vec = (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u;
+
+ const uint32_t elements_in_value_b = components_b * (32u / onnxruntime::narrow(bits_));
+ const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b;
+ const uint32_t k_tile_iterations = K / tile_size_k;
+
+ // Decode-program-level skip bindings: only used by the fully-fused path. The partial-fuse
+ // path has already merged the residual upstream and exposed the residual passthrough output
+ // through the layer-norm op, so the decode program runs with skip_input/skip_output disabled.
+ const bool decode_has_norm = !exceeds_storage_buffer_limit;
+ const bool decode_has_skip_input = !exceeds_storage_buffer_limit && skip != nullptr;
+ std::optional input_skip_bias_sum_scratch;
+ Tensor* decode_input_skip_bias_sum = nullptr;
+ if (decode_has_skip_input) {
+ decode_input_skip_bias_sum = input_skip_bias_sum;
+ if (decode_input_skip_bias_sum == nullptr) {
+ input_skip_bias_sum_scratch.emplace(context.CreateGPUTensor(a->DataType(), a->Shape()));
+ decode_input_skip_bias_sum = &*input_skip_bias_sum_scratch;
+ }
+ }
+ const bool decode_has_skip_output = decode_input_skip_bias_sum != nullptr;
+
+ uint32_t k_unroll_tiles = 1;
+ if ((K % tile_size_k) == 0) {
+ if (k_tile_iterations >= 8 && std::max(Nq, Nkv) <= 2048 &&
+ context.AdapterInfo().vendor != std::string_view{"intel"}) {
+ k_unroll_tiles = 4;
+ } else if (k_tile_iterations >= 4) {
+ k_unroll_tiles = 2;
+ }
+ }
+
+ const uint32_t num_N_tile = CeilDiv(std::max(Nq, Nkv), tile_size);
+ MatMulNBitsQkvDecodeProgram program{tile_size,
+ single_scale_weights,
+ tile_size_k_vec,
+ k_unroll_tiles,
+ decode_has_norm,
+ decode_has_skip_input,
+ decode_has_skip_output};
+ program.SetWorkgroupSize(workgroup_size);
+ program.SetDispatchGroupSize(num_N_tile, 1, batch_count);
+ program
+ .AddInput({decode_a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)});
+ if (decode_has_skip_input) {
+ program.AddInput({skip, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)});
+ }
+ if (decode_has_norm) {
+ program.AddInput({norm_scale, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)});
+ }
+ program
+ .AddInputs({{q_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)},
+ {q_scales, ProgramTensorMetadataDependency::TypeAndRank},
+ {k_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)},
+ {k_scales, ProgramTensorMetadataDependency::TypeAndRank},
+ {v_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)},
+ {v_scales, ProgramTensorMetadataDependency::TypeAndRank}})
+ .AddOutputs({{q_output, ProgramTensorMetadataDependency::TypeAndRank},
+ {k_output, ProgramTensorMetadataDependency::TypeAndRank},
+ {v_output, ProgramTensorMetadataDependency::TypeAndRank}})
+ .AddUniformVariables({{Nq},
+ {Nkv},
+ {K},
+ {K / components_a},
+ {K_of_b},
+ {block_size},
+ {n_blocks_per_col},
+ {num_N_tile},
+ {batch_count},
+ {decode_has_skip_input ? onnxruntime::narrow(skip->Shape().Size()) : 0u},
+ {epsilon_}})
+ .CacheHint(Nq,
+ Nkv,
+ K,
+ tile_size,
+ tile_size_k_vec,
+ k_unroll_tiles,
+ single_scale_weights,
+ decode_has_norm,
+ decode_has_skip_input,
+ decode_has_skip_output,
+ "decode_qkv_sln");
+ if (decode_has_skip_output) {
+ program.AddOutput({decode_input_skip_bias_sum,
+ ProgramTensorMetadataDependency::TypeAndRank,
+ static_cast(components_a)});
+ }
+
+ return context.RunProgram(program);
+}
+
+} // namespace webgpu
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.h
new file mode 100644
index 0000000000000..4d57ab5ac2b3c
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.h
@@ -0,0 +1,42 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/webgpu/webgpu_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace webgpu {
+
+using namespace onnxruntime::webgpu;
+
+class MatMulNBitsQkv final : public WebGpuKernel {
+ public:
+ explicit MatMulNBitsQkv(const OpKernelInfo& info) : WebGpuKernel(info) {
+ K_ = info.GetAttr("K");
+ Nq_ = info.GetAttr("Nq");
+ Nkv_ = info.GetAttr("Nkv");
+ block_size_ = info.GetAttr("block_size");
+ bits_ = info.GetAttr("bits");
+ accuracy_level_ = info.GetAttrOrDefault("accuracy_level", 4);
+ epsilon_ = info.GetAttrOrDefault("epsilon", 1e-6f);
+ ORT_ENFORCE(bits_ == 4,
+ "MatMulNBitsQkv currently supports 4-bit weights only.");
+ }
+
+ Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override;
+
+ private:
+ int64_t K_;
+ int64_t Nq_;
+ int64_t Nkv_;
+ int64_t block_size_;
+ int64_t accuracy_level_;
+ int64_t bits_;
+ float epsilon_;
+};
+
+} // namespace webgpu
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template
new file mode 100644
index 0000000000000..60f34e9ef2530
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template
@@ -0,0 +1,281 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#param a_length_per_tile
+#param component_a
+#param component_b
+#param elements_in_value_b
+#param has_norm
+#param has_skip_input
+#param has_skip_output
+#param k_unroll_tiles
+#param single_scale_weights
+#param sub_tile_count
+#param tile_size_k_vec
+#param tile_size_k
+#param tile_size
+
+#use .getByOffset .setByOffset
+
+#if has_norm
+var sum_squared_shared : array;
+#endif
+var tile_A : array;
+var q_inter_results : array, tile_size>;
+var k_inter_results : array, tile_size>;
+var v_inter_results : array, tile_size>;
+
+const default_zero_point = vec4(q_output_element_t(8));
+
+fn unpack_nibble_values(word: u32) -> vec4 {
+ let unpacked = unpack4xU8(word);
+ return vec4(q_output_element_t(unpacked[0]),
+ q_output_element_t(unpacked[1]),
+ q_output_element_t(unpacked[2]),
+ q_output_element_t(unpacked[3]));
+}
+
+fn load_merged_input(input_offset: u32) -> input_a_value_t {
+ var value = a.getByOffset(input_offset);
+#if has_skip_input
+ let skip_offset = input_offset % (uniforms.skip_size / component_a);
+ value += input_a_value_t(skip.getByOffset(skip_offset));
+#endif
+ return value;
+}
+
+#if component_a == 1
+fn load_a_vec4(a_offset: u32) -> vec4 {
+ return vec4(q_output_element_t(tile_A[a_offset]),
+ q_output_element_t(tile_A[a_offset + 1]),
+ q_output_element_t(tile_A[a_offset + 2]),
+ q_output_element_t(tile_A[a_offset + 3]));
+}
+#elif component_a == 2
+fn load_a_vec4(a_offset: u32) -> vec4 {
+ let a0 = tile_A[a_offset];
+ let a1 = tile_A[a_offset + 1];
+ return vec4(q_output_element_t(a0[0]),
+ q_output_element_t(a0[1]),
+ q_output_element_t(a1[0]),
+ q_output_element_t(a1[1]));
+}
+#elif component_a == 4
+fn load_a_vec4(a_offset: u32) -> vec4 {
+ let a = tile_A[a_offset];
+ return vec4(q_output_element_t(a[0]),
+ q_output_element_t(a[1]),
+ q_output_element_t(a[2]),
+ q_output_element_t(a[3]));
+}
+#endif
+
+fn loadSHMA(batch: u32, b_global_base: u32, kidx: u32, col: u32, inv_std: f32) {
+ let k_offset = kidx / component_a + col;
+ let input_offset = batch * uniforms.K_of_a + k_offset;
+ if (k_offset < uniforms.K_of_a) {
+#if has_norm
+ let merged_value = load_merged_input(input_offset);
+#if has_skip_output
+ if (b_global_base == 0u) {
+ input_skip_bias_sum.setByOffset(input_offset, input_skip_bias_sum_value_t(merged_value));
+ }
+#endif
+ tile_A[col] = merged_value * input_a_value_t(input_a_element_t(inv_std)) * norm_scale.getByOffset(k_offset);
+#else
+ // Layer norm has already been applied to `a` upstream; load the pre-normalized value directly.
+ _ = b_global_base;
+ _ = inv_std;
+ tile_A[col] = a.getByOffset(input_offset);
+#endif
+ } else {
+ tile_A[col] = input_a_value_t(0);
+ }
+}
+
+fn compute_projection_sum(weight: q_b_value_t,
+ scale: q_output_element_t,
+ idx: u32) -> q_output_element_t {
+ var sum = q_output_element_t(0);
+ var a_offset = idx * (8 / component_a) * component_b;
+#if component_b == 1
+ let weight_lower = unpack_nibble_values(weight & 0x0F0F0F0Fu) - default_zero_point;
+ let weight_upper = unpack_nibble_values((weight >> 4) & 0x0F0F0F0Fu) - default_zero_point;
+ let w0 = vec4(q_output_element_t(weight_lower[0]), q_output_element_t(weight_upper[0]), q_output_element_t(weight_lower[1]), q_output_element_t(weight_upper[1])) * scale;
+ let w1 = vec4(q_output_element_t(weight_lower[2]), q_output_element_t(weight_upper[2]), q_output_element_t(weight_lower[3]), q_output_element_t(weight_upper[3])) * scale;
+#if component_a == 1
+ let a0 = load_a_vec4(a_offset);
+ let a1 = load_a_vec4(a_offset + 4);
+ sum += dot(a0, w0) + dot(a1, w1);
+#elif component_a == 2
+ let a0 = load_a_vec4(a_offset);
+ let a1 = load_a_vec4(a_offset + 2);
+ sum += dot(a0, w0) + dot(a1, w1);
+#elif component_a == 4
+ let a0 = load_a_vec4(a_offset);
+ let a1 = load_a_vec4(a_offset + 1);
+ sum += dot(a0, w0) + dot(a1, w1);
+#endif
+#else
+ for (var i = 0u; i < component_b; i++) {
+ let weight_lower = unpack_nibble_values(weight[i] & 0x0F0F0F0Fu) - default_zero_point;
+ let weight_upper = unpack_nibble_values((weight[i] >> 4) & 0x0F0F0F0Fu) - default_zero_point;
+ let w0 = vec4(q_output_element_t(weight_lower[0]), q_output_element_t(weight_upper[0]), q_output_element_t(weight_lower[1]), q_output_element_t(weight_upper[1])) * scale;
+ let w1 = vec4(q_output_element_t(weight_lower[2]), q_output_element_t(weight_upper[2]), q_output_element_t(weight_lower[3]), q_output_element_t(weight_upper[3])) * scale;
+#if component_a == 1
+ let a0 = load_a_vec4(a_offset);
+ let a1 = load_a_vec4(a_offset + 4);
+ sum += dot(a0, w0) + dot(a1, w1);
+ a_offset += 8;
+#elif component_a == 2
+ let a0 = load_a_vec4(a_offset);
+ let a1 = load_a_vec4(a_offset + 2);
+ sum += dot(a0, w0) + dot(a1, w1);
+ a_offset += 4;
+#elif component_a == 4
+ let a0 = load_a_vec4(a_offset);
+ let a1 = load_a_vec4(a_offset + 1);
+ sum += dot(a0, w0) + dot(a1, w1);
+ a_offset += 2;
+#endif
+ }
+#endif
+ return sum;
+}
+
+fn process_k_tile(batch: u32, b_global_base: u32, thread_idx: u32, idx: u32, idy: u32, kidx: u32, inv_std: f32) {
+ for (var id = thread_idx; id < a_length_per_tile; id += workgroup_size_x) {
+ loadSHMA(batch, b_global_base, kidx, id, inv_std);
+ }
+ workgroupBarrier();
+
+#if single_scale_weights
+ let q_scale_b = q_output_element_t(q_scales_b.getByOffset(0));
+ let k_scale_b = q_output_element_t(k_scales_b.getByOffset(0));
+ let v_scale_b = q_output_element_t(v_scales_b.getByOffset(0));
+#endif
+
+ for (var local_row_offset = 0u; local_row_offset < tile_size; local_row_offset += sub_tile_count) {
+ let b_global = b_global_base + local_row_offset + idy;
+ let k_offset = kidx / elements_in_value_b + idx;
+ #if !single_scale_weights
+ let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size;
+ #endif
+ if (k_offset < uniforms.K_of_b) {
+ if (b_global < uniforms.Nq) {
+ #if !single_scale_weights
+ let q_scale_b = q_output_element_t(q_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx));
+ #endif
+ let q_weight = q_b.getByOffset(b_global * uniforms.K_of_b + k_offset);
+ q_inter_results[local_row_offset + idy][idx] += compute_projection_sum(q_weight, q_scale_b, idx);
+ }
+ if (b_global < uniforms.Nkv) {
+ #if !single_scale_weights
+ let k_scale_b = q_output_element_t(k_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx));
+ let v_scale_b = q_output_element_t(v_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx));
+ #endif
+ let k_weight = k_b.getByOffset(b_global * uniforms.K_of_b + k_offset);
+ let v_weight = v_b.getByOffset(b_global * uniforms.K_of_b + k_offset);
+ k_inter_results[local_row_offset + idy][idx] += compute_projection_sum(k_weight, k_scale_b, idx);
+ v_inter_results[local_row_offset + idy][idx] += compute_projection_sum(v_weight, v_scale_b, idx);
+ }
+ }
+ }
+ workgroupBarrier();
+ }
+
+$MAIN {
+ let batch = workgroup_idx / uniforms.num_N_tile;
+ let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size;
+
+ let idx = local_idx % tile_size_k_vec;
+ let idy = local_idx / tile_size_k_vec;
+
+ if (local_idx < tile_size) {
+ for (var b = 0u; b < tile_size_k_vec; b++) {
+ q_inter_results[local_idx][b] = q_output_element_t(0);
+ k_inter_results[local_idx][b] = q_output_element_t(0);
+ v_inter_results[local_idx][b] = q_output_element_t(0);
+ }
+ }
+
+#if has_norm
+ var sum_squared_local = 0.0;
+ for (var a_idx = local_idx; a_idx < uniforms.K_of_a; a_idx += workgroup_size_x) {
+ let a_value = load_merged_input(batch * uniforms.K_of_a + a_idx);
+#if component_a == 1
+ let a_f32 = f32(a_value);
+ sum_squared_local += a_f32 * a_f32;
+#elif component_a == 2
+ let a_f32 = vec2(a_value);
+ sum_squared_local += dot(a_f32, a_f32);
+#elif component_a == 4
+ let a_f32 = vec4(a_value);
+ sum_squared_local += dot(a_f32, a_f32);
+#endif
+ }
+ sum_squared_shared[local_idx] = sum_squared_local;
+ workgroupBarrier();
+
+ var reduce_size : u32 = workgroup_size_x;
+ for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {
+ reduce_size = curr_size + (reduce_size & 1u);
+ if (local_idx < curr_size) {
+ sum_squared_shared[local_idx] += sum_squared_shared[local_idx + reduce_size];
+ }
+ workgroupBarrier();
+ }
+
+ let inv_std = inverseSqrt(sum_squared_shared[0] / f32(uniforms.K) + uniforms.epsilon);
+#else
+ // Layer norm already applied upstream; inv_std is unused but kept in the loadSHMA signature.
+ let inv_std = 1.0;
+#endif
+
+#if k_unroll_tiles == 1
+ for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k) {
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std);
+ }
+#elif k_unroll_tiles == 2
+ let unrolled_k_step = tile_size_k * 2u;
+ let unrolled_k_limit = uniforms.K - (uniforms.K % unrolled_k_step);
+ for (var kidx = 0u; kidx < unrolled_k_limit; kidx += unrolled_k_step) {
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std);
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k, inv_std);
+ }
+ for (var kidx = unrolled_k_limit; kidx < uniforms.K; kidx += tile_size_k) {
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std);
+ }
+#elif k_unroll_tiles == 4
+ let unrolled_k_step = tile_size_k * 4u;
+ let unrolled_k_limit = uniforms.K - (uniforms.K % unrolled_k_step);
+ for (var kidx = 0u; kidx < unrolled_k_limit; kidx += unrolled_k_step) {
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std);
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k, inv_std);
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k * 2u, inv_std);
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k * 3u, inv_std);
+ }
+ for (var kidx = unrolled_k_limit; kidx < uniforms.K; kidx += tile_size_k) {
+ process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std);
+ }
+#endif
+
+ if (local_idx < tile_size) {
+ let b_global = b_global_base + local_idx;
+ var q_output_value = q_output_element_t(0);
+ var k_output_value = q_output_element_t(0);
+ var v_output_value = q_output_element_t(0);
+ for (var b = 0u; b < tile_size_k_vec; b++) {
+ q_output_value += q_inter_results[local_idx][b];
+ k_output_value += k_inter_results[local_idx][b];
+ v_output_value += v_inter_results[local_idx][b];
+ }
+ if (b_global < uniforms.Nq) {
+ q_output.setByOffset(batch * uniforms.Nq + b_global, q_output_value_t(q_output_value));
+ }
+ if (b_global < uniforms.Nkv) {
+ k_output.setByOffset(batch * uniforms.Nkv + b_global, k_output_value_t(k_output_value));
+ v_output.setByOffset(batch * uniforms.Nkv + b_global, v_output_value_t(v_output_value));
+ }
+ }
+}
diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
index 41e2b52016cfd..32368287f6347 100644
--- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
@@ -31,6 +31,8 @@ static const BuildKernelCreateInfoFn build_kernel_create_info_function_table[] =
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index a5537c7d58b05..8ce5527d3730f 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3619,6 +3619,210 @@ For example, for 4 bits, the first 4 bits are stored in the lower 4 bits of a by
}
});
+ static const char* MatMulNBitsMlp_ver1_doc = R"DOC(
+MatMulNBitsMlp fuses two MatMulNBits projections that share the same input and computes
+
+ gate = MatMulNBits(A, gate_weight) + gate_bias
+ up = MatMulNBits(A, up_weight) + up_bias
+ Y = activation(gate) * up
+
+It can also optionally fuse SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization before the
+two projections:
+
+ A_norm = SimplifiedLayerNormalization(A, norm_scale, epsilon)
+ gate = MatMulNBits(A_norm, gate_weight) + gate_bias
+ up = MatMulNBits(A_norm, up_weight) + up_bias
+ Y = activation(gate) * up
+
+ A_norm = SkipSimplifiedLayerNormalization(A, skip, norm_scale, epsilon)
+ gate = MatMulNBits(A_norm, gate_weight) + gate_bias
+ up = MatMulNBits(A_norm, up_weight) + up_bias
+ Y = activation(gate) * up
+
+This operator is intended for decoder MLP patterns such as Qwen-style gate and up projections, but it remains
+semantically valid for both prefill and decode because the output shape is the standard MatMul result shape
+derived from the runtime shape of A and the shared attributes K and N.
+
+The operator contract includes a string attribute describing the fused gate activation.
+
+When fused from SkipSimplifiedLayerNormalization, the optional residual-sum output may also be materialized:
+
+ A_norm, input_skip_bias_sum = SkipSimplifiedLayerNormalization(A, skip, norm_scale, epsilon)
+ gate = MatMulNBits(A_norm, gate_weight) + gate_bias
+ up = MatMulNBits(A_norm, up_weight) + up_bias
+ Y = activation(gate) * up
+)DOC";
+
+ ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBitsMlp)
+ .SetDomain(kMSDomain)
+ .SinceVersion(1)
+ .SetDoc(MatMulNBitsMlp_ver1_doc)
+ .Attr("K", "Input feature dimension shared by both quantized weight matrices.", AttributeProto::INT)
+ .Attr("N", "Output feature dimension shared by both quantized weight matrices.", AttributeProto::INT)
+ .Attr("bits", "Bit-width used to quantize both weight matrices (valid range: 2~8)", AttributeProto::INT, static_cast(4))
+ .Attr("block_size",
+ "Size of each quantization block along the K dimension. Must be a power of two and >= 16.",
+ AttributeProto::INT)
+ .Attr("accuracy_level",
+ "The minimum accuracy level of input A. It follows the same semantics as MatMulNBits.",
+ AttributeProto::INT, static_cast(0))
+ .Attr("activation",
+ "Activation applied to the gate projection.",
+ AttributeProto::STRING)
+ .Attr("epsilon",
+ "Epsilon used by the optional fused (Skip)SimplifiedLayerNormalization. Defaults to 1e-5.",
+ AttributeProto::FLOAT, 1e-5f)
+ .Input(0, "A", "The shared input tensor.", "T1")
+ .Input(1, "skip", "Optional skip input used by SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional)
+ .Input(2, "norm_scale", "Optional RMSNorm scale with shape [K] used by SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional)
+ .Input(3, "gate_B", "Packed uint8 tensor for the gate projection weights.", "T2")
+ .Input(4, "gate_scales", "Per-block scaling factors for the gate projection.", "T1")
+ .Input(5, "gate_bias", "Optional bias for the gate projection with shape [N].", "T1", OpSchema::Optional)
+ .Input(6, "up_B", "Packed uint8 tensor for the up projection weights.", "T2")
+ .Input(7, "up_scales", "Per-block scaling factors for the up projection.", "T1")
+ .Input(8, "up_bias", "Optional bias for the up projection with shape [N].", "T1", OpSchema::Optional)
+ .Output(0, "Y", "The fused gated MLP output tensor.", "T1")
+ .Output(1, "input_skip_bias_sum", "Optional residual-sum output for SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional)
+ .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"},
+ "Constrain input and output types to float tensors.")
+ .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.")
+ .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+ propagateElemTypeFromInputToOutput(ctx, 0, 0);
+ if (ctx.getNumOutputs() > 1) {
+ propagateElemTypeFromInputToOutput(ctx, 0, 1);
+ }
+
+ const int64_t in_features = getAttribute(ctx, "K", -1);
+ const int64_t out_features = getAttribute(ctx, "N", -1);
+ MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, true);
+
+ if (ctx.hasInput(1) && !ctx.hasInput(2)) {
+ fail_shape_inference("norm_scale input must be present when skip input is provided");
+ }
+
+ if (ctx.hasOutput(1)) {
+ if (!ctx.hasInput(1)) {
+ fail_shape_inference("skip input must be present when input_skip_bias_sum output is requested");
+ }
+
+ if (!hasInputShape(ctx, 0)) {
+ return;
+ }
+
+ auto* skip_sum_shape = getOutputShape(ctx, 1);
+ *skip_sum_shape = getInputShape(ctx, 0);
+ }
+
+ if (ctx.hasInput(2)) {
+ if (!hasInputShape(ctx, 2)) {
+ fail_shape_inference("norm_scale shape must be known");
+ }
+
+ const auto& norm_scale_shape = getInputShape(ctx, 2);
+ if (norm_scale_shape.dim_size() != 1 ||
+ !norm_scale_shape.dim(0).has_dim_value() ||
+ norm_scale_shape.dim(0).dim_value() != in_features) {
+ fail_shape_inference("norm_scale shape must be [K] where K = ", in_features);
+ }
+ }
+
+ for (size_t bias_input_index : {5U, 8U}) {
+ if (!ctx.hasInput(static_cast(bias_input_index))) {
+ continue;
+ }
+
+ if (!hasInputShape(ctx, static_cast(bias_input_index))) {
+ fail_shape_inference("bias shape must be known");
+ }
+
+ const auto& bias_shape = getInputShape(ctx, static_cast(bias_input_index));
+ if (bias_shape.dim_size() != 1 ||
+ !bias_shape.dim(0).has_dim_value() ||
+ bias_shape.dim(0).dim_value() != out_features) {
+ fail_shape_inference("bias shape must be [N] where N = ", out_features);
+ }
+ }
+ });
+
+ static const char* MatMulNBitsQkv_ver1_doc = R"DOC(
+MatMulNBitsQkv fuses either SimplifiedLayerNormalization (RMSNorm)
+or SkipSimplifiedLayerNormalization with three MatMulNBits projections that share the
+same normalized activation.
+
+ A_norm = SimplifiedLayerNormalization(A, norm_scale, epsilon)
+ Q = MatMulNBits(A_norm, q_weight)
+ K = MatMulNBits(A_norm, k_weight)
+ V = MatMulNBits(A_norm, v_weight)
+
+If skip is provided, the operator computes the SkipSimplifiedLayerNormalization variant
+and may also return the input+skip residual sum as output 3.
+
+This operator is intended as a decode-oriented QKV fusion primitive.
+)DOC";
+
+ ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBitsQkv)
+ .SetDomain(kMSDomain)
+ .SinceVersion(1)
+ .SetDoc(MatMulNBitsQkv_ver1_doc)
+ .Attr("K", "Input feature dimension shared by the normalized input and all projection weights.", AttributeProto::INT)
+ .Attr("Nq", "Output feature dimension of the Q projection.", AttributeProto::INT)
+ .Attr("Nkv", "Output feature dimension shared by the K and V projections.", AttributeProto::INT)
+ .Attr("bits", "Bit-width used to quantize all weight matrices (valid range: 2~8)", AttributeProto::INT, static_cast(4))
+ .Attr("block_size",
+ "Size of each quantization block along the K dimension. Must be a power of two and >= 16.",
+ AttributeProto::INT)
+ .Attr("accuracy_level",
+ "The minimum accuracy level of input A. It follows the same semantics as MatMulNBits.",
+ AttributeProto::INT, static_cast(0))
+ .Attr("epsilon", "Epsilon used by the simplified layer norm reduction.", AttributeProto::FLOAT, 1e-6f)
+ .Input(0, "A", "The shared input tensor.", "T1")
+ .Input(1, "skip", "Optional residual input for SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional)
+ .Input(2, "norm_scale", "Scale input for the simplified layer norm with shape [K].", "T1")
+ .Input(3, "q_B", "Packed uint8 tensor for the Q projection weights.", "T2")
+ .Input(4, "q_scales", "Per-block scaling factors for the Q projection.", "T1")
+ .Input(5, "k_B", "Packed uint8 tensor for the K projection weights.", "T2")
+ .Input(6, "k_scales", "Per-block scaling factors for the K projection.", "T1")
+ .Input(7, "v_B", "Packed uint8 tensor for the V projection weights.", "T2")
+ .Input(8, "v_scales", "Per-block scaling factors for the V projection.", "T1")
+ .Output(0, "Q", "The Q projection output tensor.", "T1")
+ .Output(1, "K", "The K projection output tensor.", "T1")
+ .Output(2, "V", "The V projection output tensor.", "T1")
+ .Output(3, "input_skip_bias_sum", "Optional residual-sum output for SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional)
+ .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"},
+ "Constrain input and output types to float tensors.")
+ .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.")
+ .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+ for (size_t output_index = 0; output_index < ctx.getNumOutputs(); ++output_index) {
+ propagateElemTypeFromInputToOutput(ctx, 0, output_index);
+ }
+
+ if (!hasInputShape(ctx, 0)) {
+ return;
+ }
+
+ const auto& input_shape = getInputShape(ctx, 0);
+ if (input_shape.dim_size() == 0) {
+ fail_shape_inference("A must have rank >= 1");
+ }
+
+ const int64_t q_out_features = getAttribute(ctx, "Nq", -1);
+ const int64_t kv_out_features = getAttribute(ctx, "Nkv", -1);
+
+ auto set_output_shape = [&](int output_index, int64_t out_features) {
+ auto* output_shape = getOutputShape(ctx, output_index);
+ *output_shape = input_shape;
+ output_shape->mutable_dim(output_shape->dim_size() - 1)->set_dim_value(out_features);
+ };
+
+ set_output_shape(0, q_out_features);
+ set_output_shape(1, kv_out_features);
+ set_output_shape(2, kv_out_features);
+ if (ctx.getNumOutputs() > 3) {
+ auto* output_shape = getOutputShape(ctx, 3);
+ *output_shape = input_shape;
+ }
+ });
+
static const char* MatMulBnb4_ver1_doc = R"DOC(
MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences:
1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'.
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index 282261cab0e58..a4311969a6d73 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -55,6 +55,8 @@
#include "core/optimizer/layer_norm_fusion.h"
#include "core/optimizer/matmul_activation_fusion.h"
#include "core/optimizer/matmul_add_fusion.h"
+#include "core/optimizer/matmul_nbits_qkv_fusion.h"
+#include "core/optimizer/matmul_nbits_mlp_fusion.h"
#include "core/optimizer/matmul_bn_fusion.h"
#include "core/optimizer/matmul_integer_to_float.h"
#include "core/optimizer/matmul_scale_fusion.h"
@@ -446,6 +448,10 @@ InlinedVector> GenerateTransformers(
#endif
transformers.emplace_back(std::make_unique(cpu_ep));
+ transformers.emplace_back(std::make_unique(
+ InlinedHashSet{onnxruntime::kWebGpuExecutionProvider}));
+ transformers.emplace_back(std::make_unique(
+ InlinedHashSet{onnxruntime::kWebGpuExecutionProvider}));
#endif // !defined(DISABLE_CONTRIB_OPS)
// The QDQFinalCleanupTransformer must run AFTER other transformers that fuse Q/DQ nodes. Otherwise, their
diff --git a/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc
new file mode 100644
index 0000000000000..522e5f9e495cf
--- /dev/null
+++ b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc
@@ -0,0 +1,494 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/optimizer/matmul_nbits_mlp_fusion.h"
+
+#include
+
+#include "core/graph/graph_utils.h"
+#include "core/graph/node_attr_utils.h"
+#include "core/optimizer/utils.h"
+
+namespace onnxruntime {
+
+namespace {
+
+constexpr const char* kActivationAttrName = "activation";
+// The transformer name is generic for future expansion, but the current fused
+// pattern and emitted op only support gate activation = "silu". To add another
+// gate activation (e.g. GELU for Gemma-style MLPs), extend the pattern matcher
+// below to recognize the new activation subgraph (or a unary node like `Gelu`),
+// add the new value to `MlpActivationKind` in matmul_nbits_mlp.h, and update
+// `EmitGateActivationExpr` plus the `#if activation_kind` block in the WGSL
+// template.
+constexpr const char* kSupportedActivation = "silu";
+
+bool HasInput(const Node& node, size_t index) {
+ return index < node.InputDefs().size() && node.InputDefs()[index] != nullptr && !node.InputDefs()[index]->Name().empty();
+}
+
+const Node* GetInputNode(const Graph& graph, const Node& node, size_t input_index) {
+ const auto* edge = graph_utils::GetInputEdge(node, static_cast(input_index));
+ return edge == nullptr ? nullptr : graph.GetNode(edge->GetNode().Index());
+}
+
+bool IsSupportedMul(const Node& node) {
+ return graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14});
+}
+
+bool IsSupportedSigmoid(const Node& node) {
+ return graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13});
+}
+
+bool IsSupportedSimplifiedLayerNormalization(const Node& node) {
+ return graph_utils::IsSupportedOptypeVersionAndDomain(node, "SimplifiedLayerNormalization", {1});
+}
+
+bool IsSupportedSkipSimplifiedLayerNormalization(const Node& node) {
+ return graph_utils::IsSupportedOptypeVersionAndDomain(node, "SkipSimplifiedLayerNormalization", {1}, kMSDomain);
+}
+
+bool IsSupportedMlpNormAnchor(const Node& node) {
+ return IsSupportedSimplifiedLayerNormalization(node) || IsSupportedSkipSimplifiedLayerNormalization(node);
+}
+
+bool HasProducedOutput(const Node& node, size_t index) {
+ return index < node.OutputDefs().size() && node.OutputDefs()[index] != nullptr && !node.OutputDefs()[index]->Name().empty();
+}
+
+bool ProducesOnlyOptionalSkipOutputAsGraphOutput(const Graph& graph, const Node& node) {
+ const auto graph_outputs = graph.GetNodeOutputsInGraphOutputs(node);
+ return std::all_of(graph_outputs.begin(), graph_outputs.end(), [](int output_idx) { return output_idx == 3; });
+}
+
+size_t ExpectedNormConsumerEdgeCount(const Node& node) {
+ return 2u + ((IsSupportedSkipSimplifiedLayerNormalization(node) && HasProducedOutput(node, 3)) ? 1u : 0u);
+}
+
+bool HasExpectedNormConsumers(const Graph& graph, const Node& node) {
+ const auto graph_outputs = graph.GetNodeOutputsInGraphOutputs(node);
+ const size_t expected_output_edges = ExpectedNormConsumerEdgeCount(node) - graph_outputs.size();
+ if (node.GetOutputEdgesCount() != expected_output_edges) {
+ return false;
+ }
+
+ for (auto output_edge_it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); output_edge_it != end; ++output_edge_it) {
+ const auto& output_node = output_edge_it->GetNode();
+ const auto output_node_input_arg_idx = static_cast(output_edge_it->GetDstArgIndex());
+ const bool is_implicit_input_to_output_node = output_node_input_arg_idx >= output_node.InputDefs().size();
+ if (is_implicit_input_to_output_node) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool IsMatMulNBitsWithoutZeroPointOrGroupIdx(const Node& node) {
+ return graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMulNBits", {1}, kMSDomain) &&
+ !HasInput(node, 3) && !HasInput(node, 4);
+}
+
+int64_t GetIntAttr(const Node& node, const char* name, int64_t default_value, bool required = false) {
+ const auto* attr = graph_utils::GetNodeAttribute(node, name);
+ if (attr == nullptr) {
+ ORT_ENFORCE(!required, "Missing required attribute ", name, " on node ", node.Name());
+ return default_value;
+ }
+
+ return attr->i();
+}
+
+float GetFloatAttr(const Node& node, const char* name, float default_value) {
+ const auto* attr = graph_utils::GetNodeAttribute(node, name);
+ return attr == nullptr ? default_value : attr->f();
+}
+
+bool IsSupportedQuickGelu(const Node& node) {
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "QuickGelu", {1}, kMSDomain)) {
+ return false;
+ }
+ // SiLU is equivalent to QuickGelu(x, alpha=1.0). Any other alpha is a valid
+ // QuickGelu activation but is not the SiLU function that the fused kernel
+ // implements, so we conservatively reject it here.
+ return GetFloatAttr(node, "alpha", 1.0f) == 1.0f;
+}
+
+bool HasSingleNonGraphConsumer(const Graph& graph, const Node& node) {
+ return !graph.NodeProducesGraphOutput(node) && optimizer_utils::CheckOutputEdges(graph, node, 1);
+}
+
+const Node* GetNormProducer(const Graph& graph,
+ const Node& gate_matmul,
+ const Node& up_matmul) {
+ if (gate_matmul.InputDefs().empty() || up_matmul.InputDefs().empty() ||
+ gate_matmul.InputDefs()[0] != up_matmul.InputDefs()[0]) {
+ return nullptr;
+ }
+
+ const Node* gate_input = GetInputNode(graph, gate_matmul, 0);
+ const Node* up_input = GetInputNode(graph, up_matmul, 0);
+ if (gate_input == nullptr || gate_input != up_input || !IsSupportedMlpNormAnchor(*gate_input)) {
+ return nullptr;
+ }
+
+ if (!HasProducedOutput(*gate_input, 0)) {
+ return nullptr;
+ }
+
+ if (graph.NodeProducesGraphOutput(*gate_input) && !ProducesOnlyOptionalSkipOutputAsGraphOutput(graph, *gate_input)) {
+ return nullptr;
+ }
+
+ if (!HasExpectedNormConsumers(graph, *gate_input)) {
+ return nullptr;
+ }
+
+ const size_t min_norm_inputs = IsSupportedSkipSimplifiedLayerNormalization(*gate_input) ? 3u : 2u;
+ if (gate_input->InputDefs().size() < min_norm_inputs) {
+ return nullptr;
+ }
+
+ return gate_input;
+}
+
+bool ValidateMatMulNBitsPair(const Graph& graph,
+ const Node& gate_matmul,
+ const Node& up_matmul,
+ size_t expected_gate_fanout) {
+ if (!IsMatMulNBitsWithoutZeroPointOrGroupIdx(gate_matmul) || !IsMatMulNBitsWithoutZeroPointOrGroupIdx(up_matmul)) {
+ return false;
+ }
+
+ if (!HasSingleNonGraphConsumer(graph, up_matmul)) {
+ return false;
+ }
+
+ if (graph.NodeProducesGraphOutput(gate_matmul) || gate_matmul.GetOutputEdgesCount() != expected_gate_fanout) {
+ return false;
+ }
+
+ if (gate_matmul.InputDefs().empty() || up_matmul.InputDefs().empty() ||
+ gate_matmul.InputDefs()[0] != up_matmul.InputDefs()[0]) {
+ return false;
+ }
+
+ const int64_t gate_k = GetIntAttr(gate_matmul, "K", -1, true);
+ const int64_t up_k = GetIntAttr(up_matmul, "K", -1, true);
+ const int64_t gate_n = GetIntAttr(gate_matmul, "N", -1, true);
+ const int64_t up_n = GetIntAttr(up_matmul, "N", -1, true);
+ const int64_t gate_bits = GetIntAttr(gate_matmul, "bits", 4);
+ const int64_t up_bits = GetIntAttr(up_matmul, "bits", 4);
+ const int64_t gate_block_size = GetIntAttr(gate_matmul, "block_size", -1, true);
+ const int64_t up_block_size = GetIntAttr(up_matmul, "block_size", -1, true);
+ const int64_t gate_accuracy_level = GetIntAttr(gate_matmul, "accuracy_level", 0);
+ const int64_t up_accuracy_level = GetIntAttr(up_matmul, "accuracy_level", 0);
+
+ return gate_k == up_k && gate_n == up_n &&
+ gate_bits == up_bits && gate_bits == 4 &&
+ gate_block_size == up_block_size && gate_block_size == 32 &&
+ gate_accuracy_level == up_accuracy_level;
+}
+
+// Validates the SiLU-decomposed activation shape:
+// gate_matmul -> Sigmoid -+
+// gate_matmul ------------+-> silu_mul -> final_mul <- up_matmul
+bool IsFuseCandidateSilu(const Graph& graph,
+ const Node& gate_matmul,
+ const Node& up_matmul,
+ const Node& sigmoid,
+ const Node& silu_mul,
+ const Node& final_mul) {
+ if (!IsSupportedSigmoid(sigmoid) || !IsSupportedMul(silu_mul) || !IsSupportedMul(final_mul)) {
+ return false;
+ }
+
+ if (!HasSingleNonGraphConsumer(graph, sigmoid) || !HasSingleNonGraphConsumer(graph, silu_mul)) {
+ return false;
+ }
+
+ if (!ValidateMatMulNBitsPair(graph, gate_matmul, up_matmul, /*expected_gate_fanout=*/2)) {
+ return false;
+ }
+
+ if (sigmoid.InputDefs()[0] != gate_matmul.OutputDefs()[0]) {
+ return false;
+ }
+
+ const bool silu_mul_matches =
+ (silu_mul.InputDefs()[0] == gate_matmul.OutputDefs()[0] && silu_mul.InputDefs()[1] == sigmoid.OutputDefs()[0]) ||
+ (silu_mul.InputDefs()[1] == gate_matmul.OutputDefs()[0] && silu_mul.InputDefs()[0] == sigmoid.OutputDefs()[0]);
+ if (!silu_mul_matches) {
+ return false;
+ }
+
+ const bool final_mul_matches =
+ (final_mul.InputDefs()[0] == silu_mul.OutputDefs()[0] && final_mul.InputDefs()[1] == up_matmul.OutputDefs()[0]) ||
+ (final_mul.InputDefs()[1] == silu_mul.OutputDefs()[0] && final_mul.InputDefs()[0] == up_matmul.OutputDefs()[0]);
+ return final_mul_matches;
+}
+
+// Validates the fused-QuickGelu activation shape produced by QuickGeluFusion:
+// gate_matmul -> QuickGelu(alpha=1.0) -> final_mul <- up_matmul
+bool IsFuseCandidateQuickGelu(const Graph& graph,
+ const Node& gate_matmul,
+ const Node& up_matmul,
+ const Node& quick_gelu,
+ const Node& final_mul) {
+ if (!IsSupportedQuickGelu(quick_gelu) || !IsSupportedMul(final_mul)) {
+ return false;
+ }
+
+ if (!HasSingleNonGraphConsumer(graph, quick_gelu)) {
+ return false;
+ }
+
+ if (!ValidateMatMulNBitsPair(graph, gate_matmul, up_matmul, /*expected_gate_fanout=*/1)) {
+ return false;
+ }
+
+ if (quick_gelu.InputDefs()[0] != gate_matmul.OutputDefs()[0]) {
+ return false;
+ }
+
+ const bool final_mul_matches =
+ (final_mul.InputDefs()[0] == quick_gelu.OutputDefs()[0] && final_mul.InputDefs()[1] == up_matmul.OutputDefs()[0]) ||
+ (final_mul.InputDefs()[1] == quick_gelu.OutputDefs()[0] && final_mul.InputDefs()[0] == up_matmul.OutputDefs()[0]);
+ return final_mul_matches;
+}
+
+} // namespace
+
+Status MatMulNBitsMlpFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
+ const logging::Logger& logger) const {
+ GraphViewer graph_viewer(graph);
+ const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
+
+ for (auto node_index : node_topology_list) {
+ auto* node_ptr = graph.GetNode(node_index);
+ if (node_ptr == nullptr) {
+ continue;
+ }
+
+ auto& node = *node_ptr;
+ ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
+
+ if (!IsSupportedMul(node)) {
+ continue;
+ }
+
+ const auto& node_ep = node.GetExecutionProviderType();
+ if (!node_ep.empty() && node_ep != kWebGpuExecutionProvider) {
+ continue;
+ }
+
+ const Node* input0 = GetInputNode(graph, node, 0);
+ const Node* input1 = GetInputNode(graph, node, 1);
+ if (input0 == nullptr || input1 == nullptr) {
+ continue;
+ }
+
+ const Node* activation_root = nullptr;
+ const Node* up_matmul = nullptr;
+ if (IsMatMulNBitsWithoutZeroPointOrGroupIdx(*input1) &&
+ (IsSupportedMul(*input0) || IsSupportedQuickGelu(*input0))) {
+ activation_root = input0;
+ up_matmul = input1;
+ } else if (IsMatMulNBitsWithoutZeroPointOrGroupIdx(*input0) &&
+ (IsSupportedMul(*input1) || IsSupportedQuickGelu(*input1))) {
+ activation_root = input1;
+ up_matmul = input0;
+ } else {
+ continue;
+ }
+
+ // The gate-side subgraph between `gate_matmul` and the outer Mul `node`
+ // takes one of two shapes:
+ // 1) SiLU decomposed: gate -> Sigmoid -+
+ // gate ------------+-> silu_mul -> node
+ // `activation_root` is the inner Mul (silu_mul); 2 intermediates.
+ // 2) Fused QuickGelu (post QuickGeluFusion): gate -> QuickGelu -> node
+ // `activation_root` is the QuickGelu node; 1 intermediate.
+ const Node* gate_matmul = nullptr;
+ InlinedVector activation_intermediates;
+ const char* matched_shape = nullptr;
+
+ if (IsSupportedQuickGelu(*activation_root)) {
+ const Node* qg_input = GetInputNode(graph, *activation_root, 0);
+ if (qg_input == nullptr || !IsMatMulNBitsWithoutZeroPointOrGroupIdx(*qg_input)) {
+ continue;
+ }
+ gate_matmul = qg_input;
+ if (!IsFuseCandidateQuickGelu(graph, *gate_matmul, *up_matmul, *activation_root, node)) {
+ continue;
+ }
+ activation_intermediates.push_back(activation_root);
+ matched_shape = "quick_gelu";
+ } else {
+ const Node* silu_input0 = GetInputNode(graph, *activation_root, 0);
+ const Node* silu_input1 = GetInputNode(graph, *activation_root, 1);
+ if (silu_input0 == nullptr || silu_input1 == nullptr) {
+ continue;
+ }
+
+ const Node* sigmoid = nullptr;
+ if (IsMatMulNBitsWithoutZeroPointOrGroupIdx(*silu_input0) && IsSupportedSigmoid(*silu_input1)) {
+ gate_matmul = silu_input0;
+ sigmoid = silu_input1;
+ } else if (IsMatMulNBitsWithoutZeroPointOrGroupIdx(*silu_input1) && IsSupportedSigmoid(*silu_input0)) {
+ gate_matmul = silu_input1;
+ sigmoid = silu_input0;
+ } else {
+ continue;
+ }
+
+ if (!IsFuseCandidateSilu(graph, *gate_matmul, *up_matmul, *sigmoid, *activation_root, node)) {
+ continue;
+ }
+ activation_intermediates.push_back(sigmoid);
+ activation_intermediates.push_back(activation_root);
+ matched_shape = "silu";
+ }
+
+ LOGS(logger, VERBOSE) << "MatMulNBitsMlpFusion: matched candidate shape='" << matched_shape
+ << "' output_mul='" << node.Name()
+ << "' gate='" << gate_matmul->Name() << "' up='" << up_matmul->Name()
+ << "' attrs={K=" << GetIntAttr(*gate_matmul, "K", -1, true)
+ << ", N=" << GetIntAttr(*gate_matmul, "N", -1, true)
+ << ", bits=" << GetIntAttr(*gate_matmul, "bits", 4)
+ << ", block_size=" << GetIntAttr(*gate_matmul, "block_size", -1, true)
+ << ", accuracy_level=" << GetIntAttr(*gate_matmul, "accuracy_level", 0)
+ << "}";
+
+ bool intermediates_on_supported_ep = true;
+ for (const Node* intermediate : activation_intermediates) {
+ const auto& ep = intermediate->GetExecutionProviderType();
+ if (!ep.empty() && ep != kWebGpuExecutionProvider) {
+ intermediates_on_supported_ep = false;
+ break;
+ }
+ }
+ if ((!gate_matmul->GetExecutionProviderType().empty() && gate_matmul->GetExecutionProviderType() != kWebGpuExecutionProvider) ||
+ (!up_matmul->GetExecutionProviderType().empty() && up_matmul->GetExecutionProviderType() != kWebGpuExecutionProvider) ||
+ !intermediates_on_supported_ep) {
+ LOGS(logger, VERBOSE) << "MatMulNBitsMlpFusion: skipping candidate due to non-WebGPU EP assignment.";
+ continue;
+ }
+
+ const Node* norm = GetNormProducer(graph, *gate_matmul, *up_matmul);
+ if (norm == nullptr) {
+ continue;
+ }
+
+ if (!norm->GetExecutionProviderType().empty() && norm->GetExecutionProviderType() != kWebGpuExecutionProvider) {
+ continue;
+ }
+
+ NodeAttributes attrs;
+ utils::SetNodeAttribute(utils::MakeAttribute("K", GetIntAttr(*gate_matmul, "K", -1, true)), attrs);
+ utils::SetNodeAttribute(utils::MakeAttribute("N", GetIntAttr(*gate_matmul, "N", -1, true)), attrs);
+ utils::SetNodeAttribute(utils::MakeAttribute("bits", GetIntAttr(*gate_matmul, "bits", 4)), attrs);
+ utils::SetNodeAttribute(utils::MakeAttribute("block_size", GetIntAttr(*gate_matmul, "block_size", -1, true)), attrs);
+ utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", GetIntAttr(*gate_matmul, "accuracy_level", 0)), attrs);
+ utils::SetNodeAttribute(utils::MakeAttribute(kActivationAttrName, std::string{kSupportedActivation}), attrs);
+ utils::SetNodeAttribute(utils::MakeAttribute("epsilon", GetFloatAttr(*norm, "epsilon", 1e-5f)), attrs);
+
+ NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr);
+ const bool is_skip_sln = norm != nullptr && IsSupportedSkipSimplifiedLayerNormalization(*norm);
+
+ InlinedVector fused_inputs{
+ const_cast(norm->InputDefs()[0]),
+ is_skip_sln ? const_cast(norm->InputDefs()[1]) : &empty_arg,
+ const_cast(norm->InputDefs()[is_skip_sln ? 2 : 1]),
+ const_cast(gate_matmul->InputDefs()[1]),
+ const_cast(gate_matmul->InputDefs()[2]),
+ HasInput(*gate_matmul, 5) ? const_cast(gate_matmul->InputDefs()[5]) : &empty_arg,
+ const_cast(up_matmul->InputDefs()[1]),
+ const_cast(up_matmul->InputDefs()[2]),
+ HasInput(*up_matmul, 5) ? const_cast(up_matmul->InputDefs()[5]) : &empty_arg,
+ };
+
+ InlinedVector fused_outputs{const_cast(node.OutputDefs()[0])};
+ const bool preserve_skip_output = is_skip_sln && norm != nullptr && HasProducedOutput(*norm, 3);
+ if (preserve_skip_output) {
+ fused_outputs.push_back(const_cast(norm->OutputDefs()[3]));
+ }
+
+ const auto norm_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*norm);
+ const auto gate_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*gate_matmul);
+ const auto up_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*up_matmul);
+ const auto final_mul_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node);
+ const auto norm_output_edges = preserve_skip_output ? graph_utils::GraphEdge::GetNodeOutputEdges(*norm)
+ : std::vector{};
+
+ const std::string output_mul_name = node.Name();
+
+ graph_utils::RemoveNodeOutputEdges(graph, const_cast(*norm));
+ graph.RemoveNode(norm->Index());
+ graph_utils::RemoveNodeOutputEdges(graph, const_cast(*gate_matmul));
+ graph.RemoveNode(gate_matmul->Index());
+ graph_utils::RemoveNodeOutputEdges(graph, const_cast(*up_matmul));
+ graph.RemoveNode(up_matmul->Index());
+ for (const Node* intermediate : activation_intermediates) {
+ graph_utils::RemoveNodeOutputEdges(graph, const_cast(*intermediate));
+ graph.RemoveNode(intermediate->Index());
+ }
+ graph_utils::RemoveNodeOutputEdges(graph, node);
+ graph.RemoveNode(node.Index());
+
+ Node& fused_node = graph.AddNode(graph.GenerateNodeName("MatMulNBitsMlp"),
+ "MatMulNBitsMlp",
+ "fused MatMulNBits gated MLP projections",
+ fused_inputs,
+ fused_outputs,
+ &attrs,
+ kMSDomain);
+ fused_node.SetExecutionProviderType(kWebGpuExecutionProvider);
+
+ LOGS(logger, VERBOSE) << "MatMulNBitsMlpFusion: created fused node '" << fused_node.Name()
+ << "' from output_mul='" << output_mul_name << "'";
+
+ for (const auto& input_edge : norm_input_edges) {
+ int fused_input_index = input_edge.dst_arg_index;
+ if (!is_skip_sln && input_edge.dst_arg_index == 1) {
+ fused_input_index = 2;
+ }
+
+ graph.AddEdge(input_edge.src_node, fused_node.Index(), input_edge.src_arg_index, fused_input_index);
+ }
+
+ auto add_input_edge_if_present = [&](const std::vector& edges,
+ int source_input_index,
+ int fused_input_index) {
+ for (const auto& input_edge : edges) {
+ if (input_edge.dst_arg_index == source_input_index) {
+ graph.AddEdge(input_edge.src_node, fused_node.Index(), input_edge.src_arg_index, fused_input_index);
+ }
+ }
+ };
+
+ add_input_edge_if_present(gate_input_edges, 1, 3);
+ add_input_edge_if_present(gate_input_edges, 2, 4);
+ add_input_edge_if_present(gate_input_edges, 5, 5);
+ add_input_edge_if_present(up_input_edges, 1, 6);
+ add_input_edge_if_present(up_input_edges, 2, 7);
+ add_input_edge_if_present(up_input_edges, 5, 8);
+
+ for (const auto& output_edge : final_mul_output_edges) {
+ graph.AddEdge(fused_node.Index(), output_edge.dst_node, 0, output_edge.dst_arg_index);
+ }
+ if (preserve_skip_output) {
+ for (const auto& output_edge : norm_output_edges) {
+ if (output_edge.src_arg_index == 3) {
+ graph.AddEdge(fused_node.Index(), output_edge.dst_node, 1, output_edge.dst_arg_index);
+ }
+ }
+ }
+
+ modified = true;
+ }
+
+ return Status::OK();
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h
new file mode 100644
index 0000000000000..007d21027dca0
--- /dev/null
+++ b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/optimizer/graph_transformer.h"
+
+namespace onnxruntime {
+
+// Fuses the SwiGLU gated-activation subgraph (gate / up MatMulNBits projections around a
+// SimplifiedLayerNormalization anchor) into a single MatMulNBitsMlp contrib op:
+//
+// ... -> [Skip]SimplifiedLayerNormalization -+-> MatMulNBits (gate) -+-> Sigmoid -+
+// | | | v
+// | | +----------> Mul (silu) -+
+// | +-> MatMulNBits (up) ---------------------------+--> Mul -> out
+// +--> (optional) skip residual passthrough --> downstream consumers
+//
+// becomes
+//
+// ... -> MatMulNBitsMlp(activation="silu") -+-> out
+// +-> (optional) residual passthrough
+//
+// The downstream "down" projection (a third MatMulNBits that follows the gated-activation
+// output) is intentionally NOT part of this fusion -- it remains a separate MatMulNBits node
+// in the resulting graph.
+//
+// Only activation="silu" (i.e. x * Sigmoid(x)) is matched / emitted, and the fusion is restricted
+// to the WebGPU EP because MatMulNBitsMlp is a WebGPU-only contrib op.
+class MatMulNBitsMlpFusion : public GraphTransformer {
+ public:
+ explicit MatMulNBitsMlpFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept
+ : GraphTransformer("MatMulNBitsMlpFusion", compatible_execution_providers) {}
+
+ private:
+ Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
+};
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc
new file mode 100644
index 0000000000000..05cd234dba577
--- /dev/null
+++ b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc
@@ -0,0 +1,372 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/optimizer/matmul_nbits_qkv_fusion.h"
+
+#include
+#include
+#include
+
+#include "core/graph/graph_utils.h"
+#include "core/graph/node_attr_utils.h"
+#include "core/optimizer/utils.h"
+
+namespace onnxruntime {
+
+namespace {
+
+bool HasInput(const Node& node, size_t index) {
+ return index < node.InputDefs().size() && node.InputDefs()[index] != nullptr && !node.InputDefs()[index]->Name().empty();
+}
+
+bool IsSupportedSimplifiedLayerNormalization(const Node& node) {
+ return graph_utils::IsSupportedOptypeVersionAndDomain(node, "SimplifiedLayerNormalization", {1});
+}
+
+bool IsSupportedSkipSimplifiedLayerNormalization(const Node& node) {
+ return graph_utils::IsSupportedOptypeVersionAndDomain(node, "SkipSimplifiedLayerNormalization", {1}, kMSDomain);
+}
+
+bool IsSupportedNormForFusion(const Node& node) {
+ return IsSupportedSimplifiedLayerNormalization(node) || IsSupportedSkipSimplifiedLayerNormalization(node);
+}
+
+bool HasProducedOutput(const Node& node, size_t index) {
+ return index < node.OutputDefs().size() && node.OutputDefs()[index] != nullptr && !node.OutputDefs()[index]->Name().empty();
+}
+
+bool IsMatMulNBitsWithoutOptionalInputs(const Node& node) {
+ return graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMulNBits", {1}, kMSDomain) &&
+ !HasInput(node, 3) && !HasInput(node, 4) && !HasInput(node, 5);
+}
+
+int64_t GetIntAttr(const Node& node, const char* name, int64_t default_value, bool required = false) {
+ const auto* attr = graph_utils::GetNodeAttribute(node, name);
+ if (attr == nullptr) {
+ ORT_ENFORCE(!required, "Missing required attribute ", name, " on node ", node.Name());
+ return default_value;
+ }
+
+ return attr->i();
+}
+
+float GetFloatAttr(const Node& node, const char* name, float default_value) {
+ const auto* attr = graph_utils::GetNodeAttribute(node, name);
+ return attr == nullptr ? default_value : attr->f();
+}
+
+struct QkvNodes {
+ const Node* q = nullptr;
+ const Node* k = nullptr;
+ const Node* v = nullptr;
+};
+
+bool IsGraphOutput(const Graph& graph, const Node& node, size_t index) {
+ if (!HasProducedOutput(node, index)) {
+ return false;
+ }
+ const auto& output_name = node.OutputDefs()[index]->Name();
+ for (const auto* graph_output : graph.GetOutputs()) {
+ if (graph_output != nullptr && graph_output->Name() == output_name) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool HasOutputConsumers(const Node& node, size_t index) {
+ if (!HasProducedOutput(node, index)) {
+ return false;
+ }
+ for (auto edge_it = node.OutputEdgesBegin(); edge_it != node.OutputEdgesEnd(); ++edge_it) {
+ if (static_cast(edge_it->GetSrcArgIndex()) == index) {
+ return true;
+ }
+ }
+ return false;
+}
+
+// Output 0 of the norm is consumed by the fused op, so it must not be a graph output.
+// For SkipSimplifiedLayerNormalization the optional residual sum at output 3 is
+// preserved by the fused MatMulNBitsQkv op, so it is allowed to remain a graph output.
+// Outputs 1 and 2 (mean / inv_std_var) are not exposed by the fused op and must not
+// be graph outputs or feed any downstream nodes.
+bool IsSupportedNormGraphOutputsForFusion(const Graph& graph, const Node& norm) {
+ if (IsGraphOutput(graph, norm, 0)) {
+ return false;
+ }
+ for (size_t i = 1; i < norm.OutputDefs().size(); ++i) {
+ if (i == 1 || i == 2) {
+ if (IsGraphOutput(graph, norm, i) || HasOutputConsumers(norm, i)) {
+ return false;
+ }
+ continue;
+ }
+ if (!IsGraphOutput(graph, norm, i)) {
+ continue;
+ }
+ if (!(IsSupportedSkipSimplifiedLayerNormalization(norm) && i == 3)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+std::optional GetQkvNodes(const Graph& graph, const Node& norm) {
+ if (!HasProducedOutput(norm, 0) || !IsSupportedNormGraphOutputsForFusion(graph, norm)) {
+ return std::nullopt;
+ }
+
+ std::array consumers{};
+ size_t consumer_index = 0;
+ for (auto edge_it = norm.OutputEdgesBegin(); edge_it != norm.OutputEdgesEnd(); ++edge_it) {
+ if (edge_it->GetSrcArgIndex() != 0) {
+ continue;
+ }
+
+ if (consumer_index >= consumers.size()) {
+ return std::nullopt;
+ }
+
+ if (edge_it->GetDstArgIndex() != 0) {
+ return std::nullopt;
+ }
+
+ const Node* consumer = graph.GetNode(edge_it->GetNode().Index());
+ if (consumer == nullptr || !IsMatMulNBitsWithoutOptionalInputs(*consumer)) {
+ return std::nullopt;
+ }
+
+ consumers[consumer_index++] = consumer;
+ }
+
+ if (consumer_index != consumers.size()) {
+ return std::nullopt;
+ }
+
+ const int64_t n0 = GetIntAttr(*consumers[0], "N", -1, true);
+ const int64_t n1 = GetIntAttr(*consumers[1], "N", -1, true);
+ const int64_t n2 = GetIntAttr(*consumers[2], "N", -1, true);
+
+ QkvNodes qkv;
+ if (n0 != n1 && n1 == n2) {
+ qkv = {consumers[0], consumers[1], consumers[2]};
+ } else if (n1 != n0 && n0 == n2) {
+ qkv = {consumers[1], consumers[0], consumers[2]};
+ } else if (n2 != n0 && n0 == n1) {
+ qkv = {consumers[2], consumers[0], consumers[1]};
+ } else {
+ return std::nullopt;
+ }
+
+ return qkv;
+}
+
+bool HasSupportedExecutionProvider(const Node& node) {
+ const auto& node_ep = node.GetExecutionProviderType();
+ return node_ep.empty() || node_ep == kWebGpuExecutionProvider;
+}
+
+bool IsFuseCandidate(const Node& norm, const QkvNodes& qkv) {
+ if (!IsSupportedNormForFusion(norm) || qkv.q == nullptr || qkv.k == nullptr || qkv.v == nullptr) {
+ return false;
+ }
+
+ if (!HasSupportedExecutionProvider(norm) || !HasSupportedExecutionProvider(*qkv.q) ||
+ !HasSupportedExecutionProvider(*qkv.k) || !HasSupportedExecutionProvider(*qkv.v)) {
+ return false;
+ }
+
+ const size_t min_norm_inputs = IsSupportedSkipSimplifiedLayerNormalization(norm) ? 3u : 2u;
+ if (norm.InputDefs().size() < min_norm_inputs || qkv.q->InputDefs().empty() || qkv.k->InputDefs().empty() || qkv.v->InputDefs().empty()) {
+ return false;
+ }
+
+ if (qkv.q->InputDefs()[0] != norm.OutputDefs()[0] || qkv.k->InputDefs()[0] != norm.OutputDefs()[0] ||
+ qkv.v->InputDefs()[0] != norm.OutputDefs()[0]) {
+ return false;
+ }
+
+ const int64_t q_k = GetIntAttr(*qkv.q, "K", -1, true);
+ const int64_t k_k = GetIntAttr(*qkv.k, "K", -1, true);
+ const int64_t v_k = GetIntAttr(*qkv.v, "K", -1, true);
+ const int64_t q_bits = GetIntAttr(*qkv.q, "bits", 4);
+ const int64_t k_bits = GetIntAttr(*qkv.k, "bits", 4);
+ const int64_t v_bits = GetIntAttr(*qkv.v, "bits", 4);
+ const int64_t q_block_size = GetIntAttr(*qkv.q, "block_size", -1, true);
+ const int64_t k_block_size = GetIntAttr(*qkv.k, "block_size", -1, true);
+ const int64_t v_block_size = GetIntAttr(*qkv.v, "block_size", -1, true);
+ const int64_t q_accuracy_level = GetIntAttr(*qkv.q, "accuracy_level", 0);
+ const int64_t k_accuracy_level = GetIntAttr(*qkv.k, "accuracy_level", 0);
+ const int64_t v_accuracy_level = GetIntAttr(*qkv.v, "accuracy_level", 0);
+
+ return q_k == k_k && q_k == v_k &&
+ q_bits == k_bits && q_bits == v_bits && q_bits == 4 &&
+ q_block_size == k_block_size && q_block_size == v_block_size && q_block_size == 32 &&
+ q_accuracy_level == k_accuracy_level && q_accuracy_level == v_accuracy_level;
+}
+
+} // namespace
+
+Status MatMulNBitsQkvFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
+ const logging::Logger& logger) const {
+ GraphViewer graph_viewer(graph);
+ const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
+
+ for (auto node_index : node_topology_list) {
+ auto* node_ptr = graph.GetNode(node_index);
+ if (node_ptr == nullptr) {
+ continue;
+ }
+
+ auto& node = *node_ptr;
+ ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
+
+ if (!IsSupportedNormForFusion(node)) {
+ continue;
+ }
+
+ const auto qkv_nodes = GetQkvNodes(graph, node);
+ if (!qkv_nodes || !IsFuseCandidate(node, *qkv_nodes)) {
+ continue;
+ }
+
+ const int64_t K = GetIntAttr(*qkv_nodes->q, "K", -1, true);
+ const int64_t Nq = GetIntAttr(*qkv_nodes->q, "N", -1, true);
+ const int64_t Nkv = GetIntAttr(*qkv_nodes->k, "N", -1, true);
+ const int64_t bits = GetIntAttr(*qkv_nodes->q, "bits", 4);
+ const int64_t block_size = GetIntAttr(*qkv_nodes->q, "block_size", -1, true);
+ const int64_t accuracy_level = GetIntAttr(*qkv_nodes->q, "accuracy_level", 0);
+ const float epsilon = GetFloatAttr(node, "epsilon", 1e-6f);
+
+ const bool is_skip_sln = IsSupportedSkipSimplifiedLayerNormalization(node);
+
+ LOGS(logger, VERBOSE) << "MatMulNBitsQkvFusion: matched norm='" << node.Name()
+ << "' q='" << qkv_nodes->q->Name() << "' k='" << qkv_nodes->k->Name()
+ << "' v='" << qkv_nodes->v->Name() << "' attrs={K=" << K
+ << ", Nq=" << Nq << ", Nkv=" << Nkv << ", bits=" << bits
+ << ", block_size=" << block_size << ", accuracy_level=" << accuracy_level
+ << ", epsilon=" << epsilon << ", skip_sln=" << is_skip_sln << "}";
+
+ NodeAttributes attrs;
+ utils::SetNodeAttribute(utils::MakeAttribute("K", K), attrs);
+ utils::SetNodeAttribute(utils::MakeAttribute("Nq", Nq), attrs);
+ utils::SetNodeAttribute(utils::MakeAttribute("Nkv", Nkv), attrs);
+ utils::SetNodeAttribute(utils::MakeAttribute("bits", bits), attrs);
+ utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs);
+ utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level), attrs);
+ utils::SetNodeAttribute(utils::MakeAttribute("epsilon", epsilon), attrs);
+
+ NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr);
+
+ InlinedVector fused_inputs{
+ const_cast(node.InputDefs()[0]),
+ is_skip_sln ? const_cast(node.InputDefs()[1]) : &empty_arg,
+ const_cast(node.InputDefs()[is_skip_sln ? 2 : 1]),
+ const_cast(qkv_nodes->q->InputDefs()[1]),
+ const_cast(qkv_nodes->q->InputDefs()[2]),
+ const_cast(qkv_nodes->k->InputDefs()[1]),
+ const_cast(qkv_nodes->k->InputDefs()[2]),
+ const_cast(qkv_nodes->v->InputDefs()[1]),
+ const_cast(qkv_nodes->v->InputDefs()[2]),
+ };
+
+ InlinedVector fused_outputs{
+ const_cast(qkv_nodes->q->OutputDefs()[0]),
+ const_cast(qkv_nodes->k->OutputDefs()[0]),
+ const_cast(qkv_nodes->v->OutputDefs()[0]),
+ };
+ if (is_skip_sln && HasProducedOutput(node, 3)) {
+ fused_outputs.push_back(const_cast(node.OutputDefs()[3]));
+ }
+
+ const bool has_residual_output = is_skip_sln && HasProducedOutput(node, 3);
+ const std::string norm_name = node.Name();
+ const std::string q_name = qkv_nodes->q->Name();
+ const std::string k_name = qkv_nodes->k->Name();
+ const std::string v_name = qkv_nodes->v->Name();
+
+ const auto norm_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(node);
+ const auto q_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*qkv_nodes->q);
+ const auto k_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*qkv_nodes->k);
+ const auto v_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*qkv_nodes->v);
+ const auto q_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(*qkv_nodes->q);
+ const auto k_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(*qkv_nodes->k);
+ const auto v_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(*qkv_nodes->v);
+ const auto norm_output_edges = has_residual_output
+ ? graph_utils::GraphEdge::GetNodeOutputEdges(node)
+ : std::vector{};
+ graph_utils::RemoveNodeOutputEdges(graph, const_cast(*qkv_nodes->q));
+ graph.RemoveNode(qkv_nodes->q->Index());
+ graph_utils::RemoveNodeOutputEdges(graph, const_cast(*qkv_nodes->k));
+ graph.RemoveNode(qkv_nodes->k->Index());
+ graph_utils::RemoveNodeOutputEdges(graph, const_cast(*qkv_nodes->v));
+ graph.RemoveNode(qkv_nodes->v->Index());
+ graph_utils::RemoveNodeOutputEdges(graph, node);
+ graph.RemoveNode(node.Index());
+
+ Node& fused_node = graph.AddNode(graph.GenerateNodeName("MatMulNBitsQkv"),
+ "MatMulNBitsQkv",
+ "fused SimplifiedLayerNormalization with Q/K/V MatMulNBits projections",
+ fused_inputs,
+ fused_outputs,
+ &attrs,
+ kMSDomain);
+ fused_node.SetExecutionProviderType(kWebGpuExecutionProvider);
+
+ LOGS(logger, VERBOSE) << "MatMulNBitsQkvFusion: created fused node '" << fused_node.Name()
+ << "' from norm='" << norm_name << "' q='" << q_name
+ << "' k='" << k_name << "' v='" << v_name << "'";
+
+ for (const auto& input_edge : norm_input_edges) {
+ int fused_input_index = input_edge.dst_arg_index;
+ if (!is_skip_sln && input_edge.dst_arg_index == 1) {
+ fused_input_index = 2;
+ }
+
+ graph.AddEdge(input_edge.src_node, fused_node.Index(), input_edge.src_arg_index, fused_input_index);
+ }
+
+ // Q/K/V weight + scale tensors are usually initializers, but if any of them is produced
+ // by an upstream node we must rewire that producer edge to the fused input slot.
+ auto add_input_edge_if_present = [&](const std::vector& edges,
+ int source_input_index,
+ int fused_input_index) {
+ for (const auto& input_edge : edges) {
+ if (input_edge.dst_arg_index == source_input_index) {
+ graph.AddEdge(input_edge.src_node, fused_node.Index(), input_edge.src_arg_index, fused_input_index);
+ }
+ }
+ };
+
+ add_input_edge_if_present(q_input_edges, 1, 3); // q_weight
+ add_input_edge_if_present(q_input_edges, 2, 4); // q_scales
+ add_input_edge_if_present(k_input_edges, 1, 5); // k_weight
+ add_input_edge_if_present(k_input_edges, 2, 6); // k_scales
+ add_input_edge_if_present(v_input_edges, 1, 7); // v_weight
+ add_input_edge_if_present(v_input_edges, 2, 8); // v_scales
+
+ for (const auto& output_edge : q_output_edges) {
+ graph.AddEdge(fused_node.Index(), output_edge.dst_node, 0, output_edge.dst_arg_index);
+ }
+ for (const auto& output_edge : k_output_edges) {
+ graph.AddEdge(fused_node.Index(), output_edge.dst_node, 1, output_edge.dst_arg_index);
+ }
+ for (const auto& output_edge : v_output_edges) {
+ graph.AddEdge(fused_node.Index(), output_edge.dst_node, 2, output_edge.dst_arg_index);
+ }
+ if (has_residual_output) {
+ for (const auto& output_edge : norm_output_edges) {
+ if (output_edge.src_arg_index == 3) {
+ graph.AddEdge(fused_node.Index(), output_edge.dst_node, 3, output_edge.dst_arg_index);
+ }
+ }
+ }
+
+ modified = true;
+ }
+
+ return Status::OK();
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.h b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.h
new file mode 100644
index 0000000000000..fcbbb78457f52
--- /dev/null
+++ b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.h
@@ -0,0 +1,36 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/optimizer/graph_transformer.h"
+
+namespace onnxruntime {
+
+// Fuses three sibling MatMulNBits Q/K/V projections that share a SimplifiedLayerNormalization
+// (or SkipSimplifiedLayerNormalization) anchor into a single MatMulNBitsQkv contrib op:
+//
+// ... -> [Skip]SimplifiedLayerNormalization -+-> MatMulNBits (Q proj) -+
+// | +-> MatMulNBits (K proj) -+--> downstream consumers
+// | +-> MatMulNBits (V proj) -+
+// +--> (optional) skip residual passthrough --> downstream consumers
+//
+// becomes
+//
+// ... -> [Skip]SimplifiedLayerNormalization --> MatMulNBitsQkv -+-> Q out
+// +-> K out
+// +-> V out
+// +-> (optional) residual passthrough
+//
+// The fusion is restricted to the WebGPU EP because MatMulNBitsQkv is a WebGPU-only contrib op.
+class MatMulNBitsQkvFusion : public GraphTransformer {
+ public:
+ explicit MatMulNBitsQkvFusion(
+ const InlinedHashSet