diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index fb6a4eb22a872..9d549ac5e1219 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -57,6 +57,8 @@ Do not modify directly.* * com.microsoft.MatMulInteger16 * com.microsoft.MatMulIntegerToFloat * com.microsoft.MatMulNBits + * com.microsoft.MatMulNBitsMlp + * com.microsoft.MatMulNBitsQkv * com.microsoft.MaxpoolWithMask * com.microsoft.MoE * com.microsoft.MulInteger @@ -3189,6 +3191,190 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.MatMulNBitsMlp** + + MatMulNBitsMlp fuses two MatMulNBits projections that share the same input and computes + + gate = MatMulNBits(A, gate_weight) + gate_bias + up = MatMulNBits(A, up_weight) + up_bias + Y = activation(gate) * up + + It can also optionally fuse SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization before the + two projections: + + A_norm = SimplifiedLayerNormalization(A, norm_scale, epsilon) + gate = MatMulNBits(A_norm, gate_weight) + gate_bias + up = MatMulNBits(A_norm, up_weight) + up_bias + Y = activation(gate) * up + + A_norm = SkipSimplifiedLayerNormalization(A, skip, norm_scale, epsilon) + gate = MatMulNBits(A_norm, gate_weight) + gate_bias + up = MatMulNBits(A_norm, up_weight) + up_bias + Y = activation(gate) * up + + This operator is intended for decoder MLP patterns such as Qwen-style gate and up projections, but it remains + semantically valid for both prefill and decode because the output shape is the standard MatMul result shape + derived from the runtime shape of A and the shared attributes K and N. + + The operator contract includes a string attribute describing the fused gate activation. 
+ + When fused from SkipSimplifiedLayerNormalization, the optional residual-sum output may also be materialized: + + A_norm, input_skip_bias_sum = SkipSimplifiedLayerNormalization(A, skip, norm_scale, epsilon) + gate = MatMulNBits(A_norm, gate_weight) + gate_bias + up = MatMulNBits(A_norm, up_weight) + up_bias + Y = activation(gate) * up + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
K : int (required)
+
Input feature dimension shared by both quantized weight matrices.
+
N : int (required)
+
Output feature dimension shared by both quantized weight matrices.
+
accuracy_level : int
+
The minimum accuracy level of input A. It follows the same semantics as the accuracy_level attribute of MatMulNBits.
+
activation : string (required)
+
Activation applied to the gate projection.
+
bits : int
+
Bit-width used to quantize both weight matrices (valid range: 2 to 8).
+
block_size : int (required)
+
Size of each quantization block along the K dimension. Must be a power of two and >= 16.
+
epsilon : float
+
Epsilon used by the optional fused (Skip)SimplifiedLayerNormalization. Defaults to 1e-5.
+
+ +#### Inputs (8 - 9) + +
+
A : T1
+
The shared input tensor.
+
skip (optional) : T1
+
Optional skip input used by SkipSimplifiedLayerNormalization.
+
norm_scale (optional) : T1
+
Optional RMSNorm scale with shape [K] used by SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization.
+
gate_B : T2
+
Packed uint8 tensor for the gate projection weights.
+
gate_scales : T1
+
Per-block scaling factors for the gate projection.
+
gate_bias (optional) : T1
+
Optional bias for the gate projection with shape [N].
+
up_B : T2
+
Packed uint8 tensor for the up projection weights.
+
up_scales : T1
+
Per-block scaling factors for the up projection.
+
up_bias (optional) : T1
+
Optional bias for the up projection with shape [N].
+
+ +#### Outputs (1 - 2) + +
+
Y : T1
+
The fused gated MLP output tensor.
+
input_skip_bias_sum (optional) : T1
+
Optional residual-sum output (input + skip) that may be materialized when SkipSimplifiedLayerNormalization is fused.
+
+ +#### Type Constraints + +
+
T1 : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain input and output types to float tensors.
+
T2 : tensor(uint8)
+
Constrain quantized weight types to uint8.
+
+ + +### **com.microsoft.MatMulNBitsQkv** + + MatMulNBitsQkv fuses either SimplifiedLayerNormalization (RMSNorm) + or SkipSimplifiedLayerNormalization with three MatMulNBits projections that share the + same normalized activation. + + A_norm = SimplifiedLayerNormalization(A, norm_scale, epsilon) + Q = MatMulNBits(A_norm, q_weight) + K = MatMulNBits(A_norm, k_weight) + V = MatMulNBits(A_norm, v_weight) + + If skip is provided, the operator computes the SkipSimplifiedLayerNormalization variant + and may also return the input+skip residual sum as output 3. + + This operator is intended as a decode-oriented QKV fusion primitive. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
K : int (required)
+
Input feature dimension shared by the normalized input and all projection weights.
+
Nkv : int (required)
+
Output feature dimension shared by the K and V projections.
+
Nq : int (required)
+
Output feature dimension of the Q projection.
+
accuracy_level : int
+
The minimum accuracy level of input A. It follows the same semantics as the accuracy_level attribute of MatMulNBits.
+
bits : int
+
Bit-width used to quantize all weight matrices (valid range: 2 to 8).
+
block_size : int (required)
+
Size of each quantization block along the K dimension. Must be a power of two and >= 16.
+
epsilon : float
+
Epsilon used by the fused SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization.
+
+ +#### Inputs + +
+
A : T1
+
The shared input tensor.
+
skip (optional) : T1
+
Optional residual input for SkipSimplifiedLayerNormalization.
+
norm_scale : T1
+
Scale input for the simplified layer norm with shape [K].
+
q_B : T2
+
Packed uint8 tensor for the Q projection weights.
+
q_scales : T1
+
Per-block scaling factors for the Q projection.
+
k_B : T2
+
Packed uint8 tensor for the K projection weights.
+
k_scales : T1
+
Per-block scaling factors for the K projection.
+
v_B : T2
+
Packed uint8 tensor for the V projection weights.
+
v_scales : T1
+
Per-block scaling factors for the V projection.
+
+ +#### Outputs (3 - 4) + +
+
Q : T1
+
The Q projection output tensor.
+
K : T1
+
The K projection output tensor.
+
V : T1
+
The V projection output tensor.
+
input_skip_bias_sum (optional) : T1
+
Optional residual-sum output (input + skip), returned as output 3 when skip is provided and the SkipSimplifiedLayerNormalization variant is computed.
+
+ +#### Type Constraints + +
+
T1 : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain input and output types to float tensors.
+
T2 : tensor(uint8)
+
Constrain quantized weight types to uint8.
+
+ + ### **com.microsoft.MaxpoolWithMask** For internal use. diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc index 619aff1e806b8..473121edbf5dc 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc @@ -154,8 +154,26 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo auto* output = context.Output(0, x_shape); auto* input_skip_bias_sum = context.Output(3, x_shape); - int64_t data_size = x_shape.Size(); - if (data_size == 0) { + if (x_shape.Size() == 0) { + return Status::OK(); + } + + return RunSkipLayerNormProgram(context, x, skip, gamma, beta, bias, epsilon_, simplified, + output, input_skip_bias_sum); +} + +Status RunSkipLayerNormProgram(ComputeContext& context, + const Tensor* x, + const Tensor* skip, + const Tensor* gamma, + const Tensor* beta, + const Tensor* bias, + float epsilon, + bool simplified, + Tensor* output, + Tensor* input_skip_bias_sum) { + const auto& x_shape = x->Shape(); + if (x_shape.Size() == 0) { return Status::OK(); } @@ -165,18 +183,17 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo const uint32_t norm_count = onnxruntime::narrow(x_shape.SizeToDimension(x_shape.NumDimensions() - 1)); const bool split_hidden_dim = hidden_size % 512 == 0 && norm_count == 1; - const auto skip_shape = skip->Shape(); - const uint32_t skip_size = onnxruntime::narrow(skip_shape.Size()); + const uint32_t skip_size = onnxruntime::narrow(skip->Shape().Size()); SkipLayerNormProgram program{ - beta != nullptr, bias != nullptr, epsilon_, hidden_size, has_input_skip_bias_sum, simplified, split_hidden_dim}; + beta != nullptr, bias != nullptr, epsilon, hidden_size, has_input_skip_bias_sum, simplified, split_hidden_dim}; program .CacheHint(simplified, beta != nullptr, bias != nullptr, has_input_skip_bias_sum, split_hidden_dim) .AddInputs({{x, ProgramTensorMetadataDependency::Type, 
components}}) .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}}) .AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}}) .AddOutputs({{output, ProgramTensorMetadataDependency::None, components}}) - .SetDispatchGroupSize(onnxruntime::narrow(ceil(1.0 * data_size / hidden_size))) + .SetDispatchGroupSize(onnxruntime::narrow(ceil(1.0 * x_shape.Size() / hidden_size))) .AddUniformVariables({ {static_cast(components)}, }) @@ -184,7 +201,7 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo {static_cast(hidden_size)}, }) .AddUniformVariables({ - {static_cast(epsilon_)}, + {static_cast(epsilon)}, }) .AddUniformVariables({ {static_cast(skip_size)}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h index bfaec1c3d0d79..0430074bb2ae0 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h @@ -60,6 +60,21 @@ class SkipLayerNorm final : public WebGpuKernel { float epsilon_; }; +// Configures and dispatches a SkipLayerNormProgram. Centralizes program-setup logic +// (uniform variables, components, split_hidden_dim heuristic, workgroup sizing) so callers +// other than the SkipLayerNorm kernel (e.g. fused MatMulNBits ops) do not need to duplicate it. +// `beta`, `bias` and `input_skip_bias_sum` may be nullptr. 
+Status RunSkipLayerNormProgram(ComputeContext& context, + const Tensor* x, + const Tensor* skip, + const Tensor* gamma, + const Tensor* beta, + const Tensor* bias, + float epsilon, + bool simplified, + Tensor* output, + Tensor* input_skip_bias_sum); + } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index b4e3991344089..e0a78aab1220b 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -18,10 +18,6 @@ namespace onnxruntime { namespace contrib { namespace webgpu { -namespace { -constexpr unsigned int kMinMForTileOptimization = 4; -} // namespace - ONNX_OPERATOR_KERNEL_EX( MatMulNBits, kMSDomain, @@ -226,29 +222,44 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, uint32_t zero_blocks_per_col = (n_blocks_per_col + zp_elements_per_byte - 1) / zp_elements_per_byte * zp_elements_per_byte; #if !defined(__wasm__) + // apple|intel - Experimental dawn support for subgroup matrix matmul. int32_t subgroup_matrix_config_index = -1; - // Experimental dawn support for subgroup matrix matmul (vendor-agnostic). 
- if ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) && - CanApplySubgroupMatrixMatMulNBits(context, accuracy_level, block_size, batch_count, N, K, static_cast(nbits), y->DataType() == DataTypeImpl::GetType(), subgroup_matrix_config_index)) { + if (WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(M, + N, + K, + batch_count, + block_size, + accuracy_level, + nbits, + context, + y, + has_weight_idx_indirect, + &subgroup_matrix_config_index, + override_M)) { return ApplySubgroupMatrixMatMulNBits(a, b, scales, zero_points, bias, M, N, K, static_cast(nbits), zero_blocks_per_col, subgroup_matrix_config_index, context, y, weight_index, weight_index_indirect); } #endif // On FP32 only GPUs and Qualcomm GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M. // DP4A Q2 path now supports custom zero points via a 1024-entry LUT (4 zero-point sections × 256 byte values). - if (((M >= kMinMForTileOptimization && !has_weight_idx_indirect) || y->DataType() == DataTypeImpl::GetType() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && - CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)) { + if (WouldApplyDP4AMatMulNBitsInCurrentDispatch(M, + N, + K, + block_size, + accuracy_level, + context, + y, + has_weight_idx_indirect)) { return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, batch_count, M, dispatch_M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast(nbits), context, y, weight_index, weight_index_indirect); } // WideTileProgram // This program is optimized for Block32 prefill using Tile16x128. 
- const bool use_wide_tile_program = !has_weight_idx_indirect && - block_size == 32 && - components_a == 4 && - components_b == 4 && - nbits != 2 && - M >= kMinMForTileOptimization; + const bool use_wide_tile_program = WouldApplyWideTileMatMulNBitsInCurrentDispatch(M, + K, + block_size, + nbits, + has_weight_idx_indirect); if (use_wide_tile_program) { // Enforce output components to 1. @@ -308,7 +319,8 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, // Use tile_size_k_vec=32 by default for better K-dimension parallelism. // Intel devices use 16 as they have different subgroup/cache characteristics. - const uint32_t tile_size_k_vec = (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u; + const uint32_t tile_size_k_vec = + (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u; constexpr uint32_t workgroup_size = 128; constexpr uint32_t tile_size = 8; diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc index b9eafeb43c7b6..488194a60a31a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc @@ -2,9 +2,16 @@ // Licensed under the MIT License. 
#include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" + #include + #include "core/common/common.h" +#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" +#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" +#include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/webgpu_utils.h" +#include "core/framework/tensor_shape.h" namespace onnxruntime { namespace contrib { @@ -61,6 +68,160 @@ bool HasDP4ADeviceSupport(int context_id) { ctx.AdapterInfo().vendor != std::string_view{"apple"}; } +bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(uint32_t M, + [[maybe_unused]] uint32_t N, + [[maybe_unused]] uint32_t K, + [[maybe_unused]] uint32_t batch_count, + [[maybe_unused]] uint32_t block_size, + [[maybe_unused]] int64_t accuracy_level, + [[maybe_unused]] int64_t nbits, + [[maybe_unused]] onnxruntime::webgpu::ComputeContext& context, + [[maybe_unused]] Tensor* y, + [[maybe_unused]] bool has_weight_idx_indirect, + [[maybe_unused]] int32_t* subgroup_matrix_config_index, + [[maybe_unused]] uint32_t override_M) { + [[maybe_unused]] const uint32_t dispatch_M = override_M > 0 ? override_M : M; + +#if !defined(__wasm__) + int32_t local_subgroup_matrix_config_index = -1; + if (dispatch_M != M) { + return false; + } + + return (M >= kMinMForTileOptimization && !has_weight_idx_indirect) && + CanApplySubgroupMatrixMatMulNBits(context, + accuracy_level, + block_size, + batch_count, + N, + K, + static_cast(nbits), + y->DataType() == DataTypeImpl::GetType(), + subgroup_matrix_config_index != nullptr ? 
*subgroup_matrix_config_index : local_subgroup_matrix_config_index); +#else + return false; +#endif +} + +bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t accuracy_level, + int64_t nbits, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + bool has_weight_idx_indirect, + int32_t* subgroup_matrix_config_index, + uint32_t override_M) { + TensorShape b_shape({N_op, K_op}); + MatMulComputeHelper helper; + if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { + return false; + } + + return WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch( + onnxruntime::narrow(helper.M()), + onnxruntime::narrow(helper.N()), + onnxruntime::narrow(helper.K()), + onnxruntime::narrow(helper.OutputOffsets().size()), + onnxruntime::narrow(block_size_op), + accuracy_level, + nbits, + context, + y, + has_weight_idx_indirect, + subgroup_matrix_config_index, + override_M); +} + +bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(uint32_t M, + uint32_t N, + uint32_t K, + uint32_t block_size, + int64_t accuracy_level, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + bool has_weight_idx_indirect) { + const uint32_t components_a = GetMaxComponents(K); + + return ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) || + y->DataType() == DataTypeImpl::GetType() || + context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && + CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a); +} + +bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t accuracy_level, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + bool has_weight_idx_indirect) { + TensorShape b_shape({N_op, K_op}); + MatMulComputeHelper helper; + if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { + return false; + } + + return WouldApplyDP4AMatMulNBitsInCurrentDispatch( + 
onnxruntime::narrow(helper.M()), + onnxruntime::narrow(helper.N()), + onnxruntime::narrow(helper.K()), + onnxruntime::narrow(block_size_op), + accuracy_level, + context, + y, + has_weight_idx_indirect); +} + +bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(uint32_t M, + uint32_t K, + uint32_t block_size, + int64_t nbits, + bool has_weight_idx_indirect) { + if (has_weight_idx_indirect) { + return false; + } + + const uint32_t components_a = GetMaxComponents(K); + const uint32_t block_size_per_col = block_size; + const uint32_t blob_size = (block_size_per_col / 8) * static_cast(nbits); + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_b = GetMaxComponents(blob_size_in_words); + + return block_size == 32 && + components_a == 4 && + components_b == 4 && + nbits != 2 && + M >= kMinMForTileOptimization; +} + +bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t nbits, + bool has_weight_idx_indirect) { + if (has_weight_idx_indirect) { + return false; + } + + TensorShape b_shape({N_op, K_op}); + MatMulComputeHelper helper; + if (!helper.Compute(a->Shape(), b_shape, false, true).IsOK()) { + return false; + } + + return WouldApplyWideTileMatMulNBitsInCurrentDispatch( + onnxruntime::narrow(helper.M()), + onnxruntime::narrow(helper.K()), + onnxruntime::narrow(block_size_op), + nbits, + has_weight_idx_indirect); +} + } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h index 3db7c722b11eb..f3277e536ae62 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h @@ -6,10 +6,20 @@ #include #include +namespace onnxruntime { +class Tensor; + +namespace webgpu { +class ComputeContext; +} // namespace webgpu +} // 
namespace onnxruntime + namespace onnxruntime { namespace contrib { namespace webgpu { +inline constexpr uint32_t kMinMForTileOptimization = 4u; + /** * Generates WebGPU shader code for reading zero points in quantized matrix multiplication * @@ -26,6 +36,65 @@ std::string GenerateZeroPointReadingCode(uint32_t nbits, bool has_zero_points, /// \p context_id is the WebGpuContext slot (0 for the default context). bool HasDP4ADeviceSupport(int context_id = 0); +bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t accuracy_level, + int64_t nbits, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + bool has_weight_idx_indirect = false, + int32_t* subgroup_matrix_config_index = nullptr, + uint32_t override_M = 0); + +// Precomputed-dims overload for callers (e.g., ApplyMatMulNBits, MatMulNBitsMlp, +// MatMulNBitsQkv) that have already run MatMulComputeHelper and have M/N/K and +// batch_count in scope. Avoids re-running shape inference per dispatch decision. 
+bool WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(uint32_t M, + uint32_t N, + uint32_t K, + uint32_t batch_count, + uint32_t block_size, + int64_t accuracy_level, + int64_t nbits, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + bool has_weight_idx_indirect = false, + int32_t* subgroup_matrix_config_index = nullptr, + uint32_t override_M = 0); + +bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t accuracy_level, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + bool has_weight_idx_indirect = false); + +bool WouldApplyDP4AMatMulNBitsInCurrentDispatch(uint32_t M, + uint32_t N, + uint32_t K, + uint32_t block_size, + int64_t accuracy_level, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y, + bool has_weight_idx_indirect = false); + +bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(const Tensor* a, + int64_t K_op, + int64_t N_op, + int64_t block_size_op, + int64_t nbits, + bool has_weight_idx_indirect = false); + +bool WouldApplyWideTileMatMulNBitsInCurrentDispatch(uint32_t M, + uint32_t K, + uint32_t block_size, + int64_t nbits, + bool has_weight_idx_indirect = false); + } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc new file mode 100644 index 0000000000000..8dd454083c9b3 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc @@ -0,0 +1,539 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/webgpu/quantization/matmul_nbits_mlp.h" + +#include + +#include "contrib_ops/webgpu/quantization/matmul_nbits.h" +#include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" +#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" +#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" +#include "contrib_ops/webgpu/bert/skip_layer_norm.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/webgpu/nn/layer_norm.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +Status ParseMlpActivation(std::string_view name, MlpActivationKind* out) { + if (name == "silu") { + *out = MlpActivationKind::Silu; + return Status::OK(); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "MatMulNBitsMlp: activation '", name, "' is not supported."); +} + +namespace { + +constexpr uint32_t kFusedDecodeFastPathBits = 4u; +constexpr uint32_t kFusedDecodeFastPathBlockSize = 32u; + +// Emits the WGSL expression that applies the gate activation. The result must use +// the variable names produced by the inline kernel (`gate_value`) and the shader +// template (`gate_output_value`), so callers pass the gate operand variable name. +// Adding a new activation here is the kernel-side counterpart to extending the +// MlpActivationKind enum. 
+std::string EmitGateActivationExpr(MlpActivationKind kind, std::string_view gate_var) { + switch (kind) { + case MlpActivationKind::Silu: + // SiLU(x) = x * sigmoid(x) + return std::string{gate_var} + " * (one / (one + exp(-" + std::string{gate_var} + ")))"; + } + ORT_THROW("MatMulNBitsMlp: unhandled MlpActivationKind ", static_cast(kind)); +} + +class MatMulNBitsMlpDecodeProgram final : public Program { + public: + MatMulNBitsMlpDecodeProgram(uint32_t tile_size, + bool has_gate_bias, + bool has_up_bias, + bool has_norm_input, + bool has_skip_input, + bool has_skip_output, + bool single_scale_weights, + uint32_t tile_size_k_vec, + uint32_t k_unroll_tiles, + MlpActivationKind activation_kind) + : Program{"MatMulNBitsMlpDecode"}, + tile_size_(tile_size), + has_gate_bias_(has_gate_bias), + has_up_bias_(has_up_bias), + has_norm_input_(has_norm_input), + has_skip_input_(has_skip_input), + has_skip_output_(has_skip_output), + single_scale_weights_(single_scale_weights), + tile_size_k_vec_(tile_size_k_vec), + k_unroll_tiles_(k_unroll_tiles), + activation_kind_(activation_kind) {} + + Status GenerateShaderCode(ShaderHelper& shader) const override { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto* skip = has_skip_input_ ? &shader.AddInput("skip", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias) : nullptr; + const auto* norm_scale = has_norm_input_ ? 
&shader.AddInput("norm_scale", ShaderUsage::UseValueTypeAlias) : nullptr; + const auto& gate_b = shader.AddInput("gate_b"); + const auto& gate_scales_b = shader.AddInput("gate_scales_b"); + const auto& up_b = shader.AddInput("up_b"); + const auto& up_scales_b = shader.AddInput("up_scales_b"); + if (has_gate_bias_) { + shader.AddInput("gate_bias", ShaderUsage::UseUniform); + } + if (has_up_bias_) { + shader.AddInput("up_bias", ShaderUsage::UseUniform); + } + const auto& output = shader.AddOutput("output", + ShaderUsage::UseElementTypeAlias); + const auto* input_skip_bias_sum = has_skip_output_ + ? &shader.AddOutput("input_skip_bias_sum", + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias) + : nullptr; + const auto& skip_var = skip != nullptr ? *skip : a; + const auto& norm_scale_var = norm_scale != nullptr ? *norm_scale : a; + const auto& input_skip_bias_sum_var = input_skip_bias_sum != nullptr ? *input_skip_bias_sum : output; + + const uint32_t components_a = a.NumComponents(); + const uint32_t components_b = gate_b.NumComponents() / 4; + const uint32_t tile_size_k_vec = tile_size_k_vec_; + const uint32_t elements_in_value_b = components_b * 8u; + const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b; + const uint32_t a_length_per_tile = tile_size_k / components_a; + const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec; + + // Template parameters are identical across the (has_skip_input, has_norm_input, + // has_skip_output) combinations; only the AddInput/AddOutput wiring upstream changes. + // The template's own #if directives select the appropriate code paths. 
+ return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_mlp.wgsl.template", + WGSL_TEMPLATE_PARAMETER(a_length_per_tile, a_length_per_tile), + WGSL_TEMPLATE_PARAMETER(activation_kind, static_cast(activation_kind_)), + WGSL_TEMPLATE_PARAMETER(component_a, components_a), + WGSL_TEMPLATE_PARAMETER(component_b, components_b), + WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b), + WGSL_TEMPLATE_PARAMETER(has_gate_bias, has_gate_bias_), + WGSL_TEMPLATE_PARAMETER(has_norm_input, has_norm_input_), + WGSL_TEMPLATE_PARAMETER(has_skip_input, has_skip_input_), + WGSL_TEMPLATE_PARAMETER(has_skip_output, has_skip_output_), + WGSL_TEMPLATE_PARAMETER(has_up_bias, has_up_bias_), + WGSL_TEMPLATE_PARAMETER(k_unroll_tiles, k_unroll_tiles_), + WGSL_TEMPLATE_PARAMETER(single_scale_weights, single_scale_weights_), + WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count), + WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), + WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k), + WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), + WGSL_TEMPLATE_VARIABLE(a, a), + WGSL_TEMPLATE_VARIABLE(gate_b, gate_b), + WGSL_TEMPLATE_VARIABLE(gate_scales_b, gate_scales_b), + WGSL_TEMPLATE_VARIABLE(input_skip_bias_sum, input_skip_bias_sum_var), + WGSL_TEMPLATE_VARIABLE(norm_scale, norm_scale_var), + WGSL_TEMPLATE_VARIABLE(output, output), + WGSL_TEMPLATE_VARIABLE(skip, skip_var), + WGSL_TEMPLATE_VARIABLE(up_b, up_b), + WGSL_TEMPLATE_VARIABLE(up_scales_b, up_scales_b)); + } + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"N", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"K_of_a", ProgramUniformVariableDataType::Uint32}, + {"K_of_b", ProgramUniformVariableDataType::Uint32}, + {"block_size", ProgramUniformVariableDataType::Uint32}, + {"blocks_per_col", ProgramUniformVariableDataType::Uint32}, + {"num_N_tile", ProgramUniformVariableDataType::Uint32}, + {"batch_count", ProgramUniformVariableDataType::Uint32}, + {"skip_size", 
ProgramUniformVariableDataType::Uint32}, + {"epsilon", ProgramUniformVariableDataType::Float32}); + + private: + uint32_t tile_size_; + bool has_gate_bias_; + bool has_up_bias_; + bool has_norm_input_; + bool has_skip_input_; + bool has_skip_output_; + bool single_scale_weights_; + uint32_t tile_size_k_vec_; + uint32_t k_unroll_tiles_; + MlpActivationKind activation_kind_; +}; + +class MatMulNBitsMlpProgram final : public Program { + public: + explicit MatMulNBitsMlpProgram(MlpActivationKind activation_kind) + : Program{"MatMulNBitsMlp"}, activation_kind_(activation_kind) { + CacheHint(static_cast(activation_kind_)); + } + + Status GenerateShaderCode(ShaderHelper& shader) const override { + const auto& gate = shader.AddInput("gate", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& up = shader.AddInput("up", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << "let gate_value = " << gate.GetByOffset("global_idx") << ";\n" + << "let up_value = " << up.GetByOffset("global_idx") << ";\n" + << "let one = output_value_t(1.0);\n" + << "let activated_value = " << EmitGateActivationExpr(activation_kind_, "gate_value") << ";\n" + << output.SetByOffset("global_idx", "activated_value * up_value"); + + return Status::OK(); + } + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + MlpActivationKind activation_kind_; +}; + +Status ApplyUnfusedMlp(const Tensor* a, + const Tensor* gate_b, + const Tensor* gate_scales, + const Tensor* gate_bias, + const Tensor* up_b, + const Tensor* up_scales, + const Tensor* up_bias, + int64_t K, + int64_t N, + int64_t block_size, + int64_t accuracy_level, + int64_t bits, + 
MlpActivationKind activation_kind, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y) { + MatMulComputeHelper helper; + TensorShape b_shape({N, K}); + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + const auto output_shape = helper.OutputShape(); + + Tensor gate_output = context.CreateGPUTensor(a->DataType(), output_shape); + Tensor up_output = context.CreateGPUTensor(a->DataType(), output_shape); + + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(a, gate_b, gate_scales, nullptr, gate_bias, K, N, block_size, accuracy_level, bits, context, &gate_output)); + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(a, up_b, up_scales, nullptr, up_bias, K, N, block_size, accuracy_level, bits, context, &up_output)); + + const uint32_t data_size = onnxruntime::narrow(y->Shape().Size()); + const uint32_t vec_size = (data_size + 3u) / 4u; + MatMulNBitsMlpProgram program{activation_kind}; + program + .AddInputs({{&gate_output, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4}, + {&up_output, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, 4}}) + .AddOutput({y, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({vec_size}); + + return context.RunProgram(program); +} + +} // namespace + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBitsMlp, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", WebGpuSupportedFloatTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBitsMlp); + +Status MatMulNBitsMlp::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* a = context.Input(0); + const Tensor* skip = context.Input(1); + const Tensor* norm_scale = context.Input(2); + const Tensor* gate_b = context.Input(3); + const Tensor* gate_scales = context.Input(4); + const Tensor* gate_bias = context.Input(5); + const Tensor* up_b = context.Input(6); + const Tensor* 
up_scales = context.Input(7); + const Tensor* up_bias = context.Input(8); + + ORT_ENFORCE(skip == nullptr || norm_scale != nullptr, + "MatMulNBitsMlp requires norm_scale when skip is present."); + + MatMulComputeHelper helper; + TensorShape b_shape({N_, K_}); + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + const auto output_shape = helper.OutputShape(); + const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size()); + const uint32_t M = onnxruntime::narrow(helper.M()); + const uint32_t N = onnxruntime::narrow(helper.N()); + const uint32_t K = onnxruntime::narrow(helper.K()); + const uint32_t block_size = onnxruntime::narrow(block_size_); + const uint32_t components_a = GetMaxComponents(K); + const bool single_scale_weights = (block_size == K * N); + const uint32_t block_size_per_col = single_scale_weights ? K : block_size; + const uint32_t n_blocks_per_col = (K + block_size_per_col - 1) / block_size_per_col; + const uint32_t blob_size = (block_size_per_col / 8) * onnxruntime::narrow(bits_); + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_b = GetMaxComponents(blob_size_in_words); + constexpr uint32_t kU32Components = 4; + const uint32_t components_b_with_u32 = components_b * kU32Components; + const uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; + + Tensor* y = context.Output(0, output_shape); + Tensor* input_skip_bias_sum = (skip != nullptr && context.OutputCount() > 1) + ? 
context.Output(1, a->Shape()) + : nullptr; + const uint32_t data_size = onnxruntime::narrow(y->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + if (norm_scale != nullptr) { + ORT_ENFORCE(norm_scale->Shape().Size() == K_, "norm_scale must have shape [K]."); + } + + const bool has_skip_input = skip != nullptr; + const bool has_skip_output = input_skip_bias_sum != nullptr; + + const bool is_decode_fast_path_candidate = + M == 1 && + bits_ == kFusedDecodeFastPathBits && + block_size == kFusedDecodeFastPathBlockSize; + const bool has_norm_input = norm_scale != nullptr; + + const bool would_use_subgroup_unfused = + WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(M, + N, + K, + batch_count, + block_size, + accuracy_level_, + bits_, + context, + y); + const bool would_use_dp4a_unfused = + WouldApplyDP4AMatMulNBitsInCurrentDispatch(M, + N, + K, + block_size, + accuracy_level_, + context, + y); + const bool would_use_wide_tile_unfused = + WouldApplyWideTileMatMulNBitsInCurrentDispatch(M, + K, + block_size, + bits_); + + const bool can_use_decode_fast_path = + is_decode_fast_path_candidate && + !would_use_subgroup_unfused && + !would_use_dp4a_unfused && + !would_use_wide_tile_unfused; + + if (can_use_decode_fast_path) { + ORT_ENFORCE(bits_ == kFusedDecodeFastPathBits, + "MatMulNBitsMlpDecodeProgram is specialized for 4-bit weights only."); + ORT_ENFORCE(block_size == kFusedDecodeFastPathBlockSize, + "MatMulNBitsMlpDecodeProgram is specialized for block_size=32 only."); + + const bool has_gate_bias = gate_bias != nullptr; + const bool has_up_bias = up_bias != nullptr; + + // The fully-fused MLP decode shader binds every weight/scale/bias plus the norm/skip + // tensors as storage buffers. Devices with a tight maxStorageBuffersPerShaderStage + // (notably macOS Metal at 10) cannot bind that many. 
For those devices we run the layer + // norm separately into a scratch tensor and then dispatch a no-norm variant of the + // decode program (which omits the norm_scale, skip, and skip-output bindings, dropping + // the storage-buffer count from up to 11 down to 8). + // + // Storage-buffer count: input_a + (skip?) + (norm_scale?) + 2 * (weight + scales) + // + output + (skip output?) + (gate_bias?) + (up_bias?) + const uint32_t required_storage_buffers = + 1u // input_a + + (has_skip_input ? 1u : 0u) // skip + + (has_norm_input ? 1u : 0u) // norm_scale + + 4u // gate/up weights + scales + + 1u // output + + (has_skip_output ? 1u : 0u) // skip output + + (has_gate_bias ? 1u : 0u) // gate bias + + (has_up_bias ? 1u : 0u); // up bias + const bool exceeds_storage_buffer_limit = + required_storage_buffers > context.DeviceLimits().maxStorageBuffersPerShaderStage; + + // Optionally pre-normalize a into a scratch tensor and drop the norm/skip bindings + // from the decode program. The user-visible residual passthrough (input_skip_bias_sum) + // is produced by the skip-norm op directly in this path. 
+ std::optional normalized_a_storage; + const Tensor* decode_a = a; + if (exceeds_storage_buffer_limit && has_norm_input) { + normalized_a_storage.emplace(context.CreateGPUTensor(a->DataType(), a->Shape())); + if (has_skip_input) { + ORT_RETURN_IF_ERROR(RunSkipLayerNormProgram(context, a, skip, norm_scale, + /*beta=*/nullptr, + /*bias=*/nullptr, + epsilon_, /*simplified=*/true, + &*normalized_a_storage, + input_skip_bias_sum)); + } else { + const auto& a_shape = a->Shape(); + const int64_t norm_size = a_shape[a_shape.NumDimensions() - 1]; + const uint32_t norm_count = onnxruntime::narrow(a_shape.Size() / norm_size); + ORT_RETURN_IF_ERROR(onnxruntime::webgpu::RunLayerNormProgram( + context, a, norm_scale, /*bias=*/nullptr, epsilon_, norm_count, norm_size, + /*simplified=*/true, &*normalized_a_storage, /*mean=*/nullptr, + /*inv_std_dev=*/nullptr)); + } + decode_a = &*normalized_a_storage; + } + + // Decode-program-level norm/skip bindings: only used when the device has spare + // storage-buffer slots. Otherwise they are wired to the pre-normalized input above. + const bool decode_has_norm_input = has_norm_input && !exceeds_storage_buffer_limit; + const bool decode_has_skip_input = has_skip_input && !exceeds_storage_buffer_limit; + const bool decode_has_skip_output = has_skip_output && !exceeds_storage_buffer_limit; + + uint32_t workgroup_size = 128; + uint32_t tile_size = 8; + uint32_t tile_size_k_vec = + (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 
16u : 32u; + + const uint32_t elements_in_value_b = components_b * (32u / onnxruntime::narrow(bits_)); + const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b; + const uint32_t k_tile_iterations = K / tile_size_k; + + uint32_t k_unroll_tiles = 1; + if ((K % tile_size_k) == 0) { + if (k_tile_iterations >= 8 && N <= 2048 && context.AdapterInfo().vendor != std::string_view{"intel"}) { + k_unroll_tiles = 4; + } else if (k_tile_iterations >= 4) { + k_unroll_tiles = 2; + } + } + + const uint32_t num_N_tile = CeilDiv(N, tile_size); + + MatMulNBitsMlpDecodeProgram program{tile_size, + has_gate_bias, + has_up_bias, + decode_has_norm_input, + decode_has_skip_input, + decode_has_skip_output, + single_scale_weights, + tile_size_k_vec, + k_unroll_tiles, + activation_kind_}; + program.SetWorkgroupSize(workgroup_size); + program.SetDispatchGroupSize(num_N_tile, 1, batch_count); + program.AddInput({decode_a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); + if (decode_has_skip_input) { + program.AddInput({skip, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); + } + if (decode_has_norm_input) { + program.AddInput({norm_scale, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); + } + program + .AddInputs({{gate_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + {gate_scales, ProgramTensorMetadataDependency::TypeAndRank}, + {up_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + {up_scales, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank}) + .AddUniformVariables({{N}, + {K}, + {K / components_a}, + {K_of_b}, + {block_size}, + {n_blocks_per_col}, + {num_N_tile}, + {batch_count}, + {decode_has_skip_input ? 
onnxruntime::narrow(skip->Shape().Size()) : 0u}, + {epsilon_}}) + .CacheHint(single_scale_weights, + has_gate_bias, + has_up_bias, + decode_has_norm_input, + decode_has_skip_input, + decode_has_skip_output, + tile_size_k_vec, + k_unroll_tiles, + static_cast(activation_kind_), + "decode_4bit"); + if (decode_has_skip_output) { + program.AddOutput({input_skip_bias_sum, + ProgramTensorMetadataDependency::TypeAndRank, + static_cast(components_a)}); + } + if (has_gate_bias) { + program.AddInput({gate_bias, ProgramTensorMetadataDependency::None}); + } + if (has_up_bias) { + program.AddInput({up_bias, ProgramTensorMetadataDependency::None}); + } + + return context.RunProgram(program); + } + + if (skip != nullptr) { + Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape()); + ORT_RETURN_IF_ERROR(RunSkipLayerNormProgram(context, a, skip, norm_scale, + /*beta=*/nullptr, /*bias=*/nullptr, + epsilon_, /*simplified=*/true, + &normalized_a, input_skip_bias_sum)); + return ApplyUnfusedMlp(&normalized_a, + gate_b, + gate_scales, + gate_bias, + up_b, + up_scales, + up_bias, + K_, + N_, + block_size_, + accuracy_level_, + bits_, + activation_kind_, + context, + y); + } + + if (norm_scale != nullptr) { + Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape()); + const auto& a_shape = a->Shape(); + const int64_t norm_size = a_shape[a_shape.NumDimensions() - 1]; + const uint32_t norm_count = onnxruntime::narrow(a_shape.Size() / norm_size); + ORT_RETURN_IF_ERROR(onnxruntime::webgpu::RunLayerNormProgram( + context, a, norm_scale, /*bias=*/nullptr, epsilon_, norm_count, norm_size, + /*simplified=*/true, &normalized_a, /*mean=*/nullptr, /*inv_std_dev=*/nullptr)); + return ApplyUnfusedMlp(&normalized_a, + gate_b, + gate_scales, + gate_bias, + up_b, + up_scales, + up_bias, + K_, + N_, + block_size_, + accuracy_level_, + bits_, + activation_kind_, + context, + y); + } + + return ApplyUnfusedMlp(a, + gate_b, + gate_scales, + gate_bias, + up_b, + up_scales, + 
up_bias, + K_, + N_, + block_size_, + accuracy_level_, + bits_, + activation_kind_, + context, + y); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h new file mode 100644 index 0000000000000..002ebca4c0e54 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/status.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +// Gate activation applied between the gate and up MatMulNBits projections. +// Currently only SiLU is supported; future activations (e.g. GELU for Gemma-style +// gated MLPs) can be added here and threaded through the kernel and shader template. +enum class MlpActivationKind : uint32_t { + Silu = 0, +}; + +// Parses the `activation` attribute string into MlpActivationKind. Returns a non-OK +// Status for unsupported activations so the kernel rejects unknown values up front. 
+Status ParseMlpActivation(std::string_view name, MlpActivationKind* out); + +class MatMulNBitsMlp final : public WebGpuKernel { + public: + explicit MatMulNBitsMlp(const OpKernelInfo& info) : WebGpuKernel(info) { + K_ = info.GetAttr("K"); + N_ = info.GetAttr("N"); + block_size_ = info.GetAttr("block_size"); + bits_ = info.GetAttr("bits"); + accuracy_level_ = info.GetAttrOrDefault("accuracy_level", 4); + epsilon_ = info.GetAttrOrDefault("epsilon", 1e-5f); + std::string activation; + ORT_ENFORCE(info.GetAttr("activation", &activation).IsOK(), + "MatMulNBitsMlp requires the 'activation' attribute."); + ORT_ENFORCE(ParseMlpActivation(activation, &activation_kind_).IsOK(), + "MatMulNBitsMlp: unsupported activation '", activation, "'."); + ORT_ENFORCE(bits_ == 4 || bits_ == 8 || bits_ == 2, + "Only 4b/8b/2b quantization is supported for MatMulNBitsMlp op."); + } + + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t accuracy_level_; + int64_t bits_; + float epsilon_; + MlpActivationKind activation_kind_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template new file mode 100644 index 0000000000000..f64f0d38f24e2 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#param a_length_per_tile +#param component_a +#param component_b +#param elements_in_value_b +#param single_scale_weights +#param sub_tile_count +#param has_norm_input +#param has_skip_input +#param has_skip_output +#param tile_size_k_vec +#param tile_size_k +#param tile_size +#param has_gate_bias +#param has_up_bias +#param k_unroll_tiles +// Gate activation applied between gate and up projections. +// Mirrors MlpActivationKind in matmul_nbits_mlp.h: +// 0 = SiLU (only value currently supported) +// New activations are added by extending the enum, the EmitGateActivationExpr +// helper, the fusion matcher, and the activation block below. +#param activation_kind + +#use .getByOffset .setByOffset + +#if has_norm_input +var sum_squared_shared : array; +#endif +var tile_A : array; +var gate_inter_results : array, tile_size>; +var up_inter_results : array, tile_size>; + +const default_zero_point = output_element_t(8); + +fn load_merged_input(input_offset: u32) -> input_a_value_t { +#if has_skip_input + let skip_offset = input_offset % (uniforms.skip_size / component_a); + return a.getByOffset(input_offset) + input_a_value_t(skip.getByOffset(skip_offset)); +#else + return a.getByOffset(input_offset); +#endif +} + +fn loadSHMA(batch: u32, b_global_base: u32, kidx: u32, col: u32, inv_std: f32) +{ + let k_offset = kidx / component_a + col; + let input_offset = batch * uniforms.K_of_a + k_offset; + if (k_offset < uniforms.K_of_a) { + let merged_value = load_merged_input(input_offset); +#if has_skip_output + if (b_global_base == 0u) { + input_skip_bias_sum.setByOffset(input_offset, input_skip_bias_sum_value_t(merged_value)); + } +#endif +#if has_norm_input + tile_A[col] = merged_value * input_a_value_t(input_a_element_t(inv_std)) * norm_scale.getByOffset(k_offset); +#else + tile_A[col] = merged_value; +#endif + } else { + tile_A[col] = input_a_value_t(0); + } +} + +fn compute_gate_up_sums(b_global: u32, kidx: u32, idx: u32, k_offset: u32) -> vec2 { +#if 
single_scale_weights + let gate_scale_b = gate_scales_b.getByOffset(0); + let up_scale_b = up_scales_b.getByOffset(0); +#else + let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size; + let gate_scale_b = gate_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx); + let up_scale_b = up_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx); +#endif + let gate_b_value = gate_b.getByOffset(b_global * uniforms.K_of_b + k_offset); + let up_b_value = up_b.getByOffset(b_global * uniforms.K_of_b + k_offset); + + var gate_sum = output_element_t(0); + var up_sum = output_element_t(0); + var a_offset = idx * (8 / component_a) * component_b; +#if component_b == 1 + let gate_b_value_lower = vec4(unpack4xU8(gate_b_value & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let gate_b_value_upper = vec4(unpack4xU8((gate_b_value >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let gate_b0 = vec4(gate_b_value_lower[0], gate_b_value_upper[0], gate_b_value_lower[1], gate_b_value_upper[1]) * gate_scale_b; + let gate_b1 = vec4(gate_b_value_lower[2], gate_b_value_upper[2], gate_b_value_lower[3], gate_b_value_upper[3]) * gate_scale_b; + let up_b_value_lower = vec4(unpack4xU8(up_b_value & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let up_b_value_upper = vec4(unpack4xU8((up_b_value >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let up_b0 = vec4(up_b_value_lower[0], up_b_value_upper[0], up_b_value_lower[1], up_b_value_upper[1]) * up_scale_b; + let up_b1 = vec4(up_b_value_lower[2], up_b_value_upper[2], up_b_value_lower[3], up_b_value_upper[3]) * up_scale_b; +#if component_a == 1 + let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]); + let a1 = vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]); + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); +#elif component_a == 2 + let a0 = vec4(tile_A[a_offset], 
tile_A[a_offset + 1]); + let a1 = vec4(tile_A[a_offset + 2], tile_A[a_offset + 3]); + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); +#elif component_a == 4 + let a0 = tile_A[a_offset]; + let a1 = tile_A[a_offset + 1]; + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); +#endif +#else + for (var i = 0u; i < component_b; i++) { + let gate_b_value_lower = vec4(unpack4xU8(gate_b_value[i] & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let gate_b_value_upper = vec4(unpack4xU8((gate_b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let gate_b0 = vec4(gate_b_value_lower[0], gate_b_value_upper[0], gate_b_value_lower[1], gate_b_value_upper[1]) * gate_scale_b; + let gate_b1 = vec4(gate_b_value_lower[2], gate_b_value_upper[2], gate_b_value_lower[3], gate_b_value_upper[3]) * gate_scale_b; + let up_b_value_lower = vec4(unpack4xU8(up_b_value[i] & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let up_b_value_upper = vec4(unpack4xU8((up_b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4(default_zero_point); + let up_b0 = vec4(up_b_value_lower[0], up_b_value_upper[0], up_b_value_lower[1], up_b_value_upper[1]) * up_scale_b; + let up_b1 = vec4(up_b_value_lower[2], up_b_value_upper[2], up_b_value_lower[3], up_b_value_upper[3]) * up_scale_b; +#if component_a == 1 + let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]); + let a1 = vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]); + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); + a_offset += 8; +#elif component_a == 2 + let a0 = vec4(tile_A[a_offset], tile_A[a_offset + 1]); + let a1 = vec4(tile_A[a_offset + 2], tile_A[a_offset + 3]); + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); + a_offset += 4; +#elif component_a == 4 + let a0 = tile_A[a_offset]; + let a1 = 
tile_A[a_offset + 1]; + gate_sum += dot(a0, gate_b0) + dot(a1, gate_b1); + up_sum += dot(a0, up_b0) + dot(a1, up_b1); + a_offset += 2; +#endif + } +#endif + + return vec2(gate_sum, up_sum); +} + +fn process_k_tile(batch: u32, b_global_base: u32, thread_idx: u32, idx: u32, idy: u32, kidx: u32, inv_std: f32) { + for (var id = thread_idx; id < a_length_per_tile; id += workgroup_size_x) + { + loadSHMA(batch, b_global_base, kidx, id, inv_std); + } + workgroupBarrier(); + + for (var local_row_offset = 0u; local_row_offset < tile_size; local_row_offset += sub_tile_count) + { + let b_global = b_global_base + local_row_offset + idy; + let k_offset = kidx / elements_in_value_b + idx; + if (b_global < uniforms.N && k_offset < uniforms.K_of_b) { + let sums = compute_gate_up_sums(b_global, kidx, idx, k_offset); + gate_inter_results[local_row_offset + idy][idx] += sums[0]; + up_inter_results[local_row_offset + idy][idx] += sums[1]; + } + } + workgroupBarrier(); +} + +$MAIN { + let batch = workgroup_idx / uniforms.num_N_tile; + let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size; + + let idx = local_idx % tile_size_k_vec; + let idy = local_idx / tile_size_k_vec; + + if (local_idx < tile_size) { + for (var b = 0u; b < tile_size_k_vec; b++) { + gate_inter_results[local_idx][b] = output_element_t(0); + up_inter_results[local_idx][b] = output_element_t(0); + } + } + workgroupBarrier(); + +#if has_norm_input + var sum_squared_local = 0.0; + for (var a_idx = local_idx; a_idx < uniforms.K_of_a; a_idx += workgroup_size_x) { + let a_value = load_merged_input(batch * uniforms.K_of_a + a_idx); +#if component_a == 1 + let a_f32 = f32(a_value); + sum_squared_local += a_f32 * a_f32; +#elif component_a == 2 + let a_f32 = vec2(a_value); + sum_squared_local += dot(a_f32, a_f32); +#elif component_a == 4 + let a_f32 = vec4(a_value); + sum_squared_local += dot(a_f32, a_f32); +#endif + } + sum_squared_shared[local_idx] = sum_squared_local; + workgroupBarrier(); + + var reduce_size : 
u32 = workgroup_size_x; + for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) { + reduce_size = curr_size + (reduce_size & 1u); + if (local_idx < curr_size) { + sum_squared_shared[local_idx] += sum_squared_shared[local_idx + reduce_size]; + } + workgroupBarrier(); + } + + let inv_std = inverseSqrt(sum_squared_shared[0] / f32(uniforms.K) + uniforms.epsilon); +#else + let inv_std = 1.0; +#endif + +#if k_unroll_tiles == 1 + for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); + } +#elif k_unroll_tiles == 2 + let unrolled_k_step = tile_size_k * 2u; + let unrolled_k_limit = uniforms.K - (uniforms.K % unrolled_k_step); + for (var kidx = 0u; kidx < unrolled_k_limit; kidx += unrolled_k_step) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k, inv_std); + } + for (var kidx = unrolled_k_limit; kidx < uniforms.K; kidx += tile_size_k) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); + } +#elif k_unroll_tiles == 4 + let unrolled_k_step = tile_size_k * 4u; + let unrolled_k_limit = uniforms.K - (uniforms.K % unrolled_k_step); + for (var kidx = 0u; kidx < unrolled_k_limit; kidx += unrolled_k_step) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k, inv_std); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k * 2u, inv_std); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k * 3u, inv_std); + } + for (var kidx = unrolled_k_limit; kidx < uniforms.K; kidx += tile_size_k) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); + } +#endif + + if (batch >= uniforms.batch_count) { + return; + } + + if (local_idx < tile_size) { + var gate_output_value = 
output_element_t(0); + var up_output_value = output_element_t(0); + for (var b = 0u; b < tile_size_k_vec; b++) { + gate_output_value += gate_inter_results[local_idx][b]; + up_output_value += up_inter_results[local_idx][b]; + } + let b_global = b_global_base + local_idx; + let output_idx = batch * uniforms.N + b_global; + if (b_global < uniforms.N) { +#if has_gate_bias + gate_output_value += gate_bias[b_global]; +#endif +#if has_up_bias + up_output_value += up_bias[b_global]; +#endif + let one = output_element_t(1.0); +#if activation_kind == 0 + // SiLU(x) = x * sigmoid(x). New activations are added with additional + // `#elif activation_kind == N` blocks (must match MlpActivationKind). + let activated_value = gate_output_value * (one / (one + exp(-gate_output_value))); +#endif + output.setByOffset(output_idx, activated_value * up_output_value); + } + } +} // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc new file mode 100644 index 0000000000000..4d46e03fba20c --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc @@ -0,0 +1,482 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/webgpu/quantization/matmul_nbits_qkv.h" + +#include + +#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" +#include "contrib_ops/webgpu/quantization/matmul_nbits.h" +#include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" +#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" +#include "contrib_ops/webgpu/bert/skip_layer_norm.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/webgpu/nn/layer_norm.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +namespace { + +Status ApplyUnfusedQKVSimplifiedLayerNorm(const Tensor* a, + const Tensor* norm_scale, + const Tensor* q_b, + const Tensor* q_scales, + const Tensor* k_b, + const Tensor* k_scales, + const Tensor* v_b, + const Tensor* v_scales, + int64_t K, + int64_t Nq, + int64_t Nkv, + int64_t block_size, + int64_t accuracy_level, + int64_t bits, + float epsilon, + onnxruntime::webgpu::ComputeContext& context, + Tensor* q_output, + Tensor* k_output, + Tensor* v_output) { + Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape()); + const auto& a_shape = a->Shape(); + const int64_t norm_size = a_shape[a_shape.NumDimensions() - 1]; + const uint32_t norm_count = onnxruntime::narrow(a_shape.Size() / norm_size); + ORT_RETURN_IF_ERROR(onnxruntime::webgpu::RunLayerNormProgram( + context, a, norm_scale, /*bias=*/nullptr, epsilon, norm_count, norm_size, + /*simplified=*/true, &normalized_a, /*mean=*/nullptr, /*inv_std_dev=*/nullptr)); + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, q_b, q_scales, nullptr, nullptr, + K, Nq, block_size, accuracy_level, bits, context, q_output)); + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, k_b, k_scales, nullptr, nullptr, + K, Nkv, block_size, 
accuracy_level, bits, context, k_output)); + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, v_b, v_scales, nullptr, nullptr, + K, Nkv, block_size, accuracy_level, bits, context, v_output)); + return Status::OK(); +} + +Status ApplyUnfusedQKVSkipSimplifiedLayerNorm(const Tensor* a, + const Tensor* skip, + const Tensor* norm_scale, + const Tensor* q_b, + const Tensor* q_scales, + const Tensor* k_b, + const Tensor* k_scales, + const Tensor* v_b, + const Tensor* v_scales, + int64_t K, + int64_t Nq, + int64_t Nkv, + int64_t block_size, + int64_t accuracy_level, + int64_t bits, + float epsilon, + onnxruntime::webgpu::ComputeContext& context, + Tensor* q_output, + Tensor* k_output, + Tensor* v_output, + Tensor* input_skip_bias_sum) { + Tensor normalized_a = context.CreateGPUTensor(a->DataType(), a->Shape()); + ORT_RETURN_IF_ERROR(RunSkipLayerNormProgram(context, a, skip, norm_scale, + /*beta=*/nullptr, /*bias=*/nullptr, + epsilon, /*simplified=*/true, + &normalized_a, input_skip_bias_sum)); + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, q_b, q_scales, nullptr, nullptr, + K, Nq, block_size, accuracy_level, bits, context, q_output)); + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, k_b, k_scales, nullptr, nullptr, + K, Nkv, block_size, accuracy_level, bits, context, k_output)); + ORT_RETURN_IF_ERROR(ApplyMatMulNBits(&normalized_a, v_b, v_scales, nullptr, nullptr, + K, Nkv, block_size, accuracy_level, bits, context, v_output)); + return Status::OK(); +} + +class MatMulNBitsQkvDecodeProgram final + : public Program { + public: + MatMulNBitsQkvDecodeProgram(uint32_t tile_size, + bool single_scale_weights, + uint32_t tile_size_k_vec, + uint32_t k_unroll_tiles, + bool has_norm, + bool has_skip_input, + bool has_skip_output) + : Program{"MatMulNBitsQkvDecode"}, + tile_size_(tile_size), + single_scale_weights_(single_scale_weights), + tile_size_k_vec_(tile_size_k_vec), + k_unroll_tiles_(k_unroll_tiles), + has_norm_(has_norm), + has_skip_input_(has_skip_input), + 
has_skip_output_(has_skip_output) { + // The no-norm variant runs against an already-normalized input tensor and therefore + // never owns the residual skip path nor the residual passthrough output. + ORT_ENFORCE(has_norm_ || (!has_skip_input_ && !has_skip_output_), + "MatMulNBitsQkvDecodeProgram: skip input/output require has_norm=true."); + } + + Status GenerateShaderCode(ShaderHelper& shader) const override { + const auto& a = shader.AddInput("input_a", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto* skip = has_skip_input_ ? &shader.AddInput("skip", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias) : nullptr; + const auto* norm_scale_ptr = has_norm_ ? &shader.AddInput("norm_scale", ShaderUsage::UseValueTypeAlias) : nullptr; + const auto& q_b = shader.AddInput("q_b", ShaderUsage::UseValueTypeAlias); + const auto& q_scales_b = shader.AddInput("q_scales_b"); + const auto& k_b = shader.AddInput("k_b"); + const auto& k_scales_b = shader.AddInput("k_scales_b"); + const auto& v_b = shader.AddInput("v_b"); + const auto& v_scales_b = shader.AddInput("v_scales_b"); + const auto& q_output = shader.AddOutput("q_output", + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias); + const auto& k_output = shader.AddOutput("k_output", + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias); + const auto& v_output = shader.AddOutput("v_output", + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias); + const auto* input_skip_bias_sum = has_skip_output_ ? &shader.AddOutput("input_skip_bias_sum", ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias) : nullptr; + const auto& skip_var = skip != nullptr ? *skip : a; + const auto& norm_scale_var = norm_scale_ptr != nullptr ? *norm_scale_ptr : a; + const auto& input_skip_bias_sum_var = input_skip_bias_sum != nullptr ? 
*input_skip_bias_sum : q_output; + + const uint32_t components_a = a.NumComponents(); + const uint32_t components_b = q_b.NumComponents() / 4; + const uint32_t tile_size_k_vec = tile_size_k_vec_; + const uint32_t elements_in_value_b = components_b * 8u; + const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b; + const uint32_t a_length_per_tile = tile_size_k / components_a; + const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec; + + return WGSL_TEMPLATE_APPLY(shader, "quantization/matmul_nbits_qkv.wgsl.template", + WGSL_TEMPLATE_PARAMETER(a_length_per_tile, a_length_per_tile), + WGSL_TEMPLATE_PARAMETER(component_a, components_a), + WGSL_TEMPLATE_PARAMETER(component_b, components_b), + WGSL_TEMPLATE_PARAMETER(elements_in_value_b, elements_in_value_b), + WGSL_TEMPLATE_PARAMETER(has_norm, has_norm_), + WGSL_TEMPLATE_PARAMETER(has_skip_input, has_skip_input_), + WGSL_TEMPLATE_PARAMETER(has_skip_output, has_skip_output_), + WGSL_TEMPLATE_PARAMETER(k_unroll_tiles, k_unroll_tiles_), + WGSL_TEMPLATE_PARAMETER(single_scale_weights, single_scale_weights_), + WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count), + WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), + WGSL_TEMPLATE_PARAMETER(tile_size_k, tile_size_k), + WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), + WGSL_TEMPLATE_VARIABLE(a, a), + WGSL_TEMPLATE_VARIABLE(input_skip_bias_sum, input_skip_bias_sum_var), + WGSL_TEMPLATE_VARIABLE(k_b, k_b), + WGSL_TEMPLATE_VARIABLE(k_output, k_output), + WGSL_TEMPLATE_VARIABLE(k_scales_b, k_scales_b), + WGSL_TEMPLATE_VARIABLE(norm_scale, norm_scale_var), + WGSL_TEMPLATE_VARIABLE(q_b, q_b), + WGSL_TEMPLATE_VARIABLE(q_output, q_output), + WGSL_TEMPLATE_VARIABLE(q_scales_b, q_scales_b), + WGSL_TEMPLATE_VARIABLE(skip, skip_var), + WGSL_TEMPLATE_VARIABLE(v_b, v_b), + WGSL_TEMPLATE_VARIABLE(v_output, v_output), + WGSL_TEMPLATE_VARIABLE(v_scales_b, v_scales_b)); + } + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"Nq", 
ProgramUniformVariableDataType::Uint32}, + {"Nkv", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"K_of_a", ProgramUniformVariableDataType::Uint32}, + {"K_of_b", ProgramUniformVariableDataType::Uint32}, + {"block_size", ProgramUniformVariableDataType::Uint32}, + {"blocks_per_col", ProgramUniformVariableDataType::Uint32}, + {"num_N_tile", ProgramUniformVariableDataType::Uint32}, + {"batch_count", ProgramUniformVariableDataType::Uint32}, + {"skip_size", ProgramUniformVariableDataType::Uint32}, + {"epsilon", ProgramUniformVariableDataType::Float32}); + + private: + uint32_t tile_size_; + bool single_scale_weights_; + uint32_t tile_size_k_vec_; + uint32_t k_unroll_tiles_; + bool has_norm_; + bool has_skip_input_; + bool has_skip_output_; +}; + +} // namespace + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBitsQkv, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", WebGpuSupportedFloatTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBitsQkv); + +Status MatMulNBitsQkv::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* a = context.Input(0); + const Tensor* skip = context.Input(1); + const Tensor* norm_scale = context.Input(2); + const Tensor* q_b = context.Input(3); + const Tensor* q_scales = context.Input(4); + const Tensor* k_b = context.Input(5); + const Tensor* k_scales = context.Input(6); + const Tensor* v_b = context.Input(7); + const Tensor* v_scales = context.Input(8); + + ORT_ENFORCE(bits_ == 4, "MatMulNBitsQkv currently supports 4-bit weights only."); + ORT_ENFORCE(block_size_ == 32, "MatMulNBitsQkv currently supports block_size=32 only."); + + TensorShape q_b_shape({Nq_, K_}); + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), q_b_shape, false, true)); + + const uint32_t batch_count = onnxruntime::narrow(helper.OutputOffsets().size()); + const uint32_t M = onnxruntime::narrow(helper.M()); 
+ const uint32_t K = onnxruntime::narrow(helper.K()); + const uint32_t Nq = onnxruntime::narrow(Nq_); + const uint32_t Nkv = onnxruntime::narrow(Nkv_); + + auto q_shape = helper.OutputShape(); + TensorShapeVector kv_dims(q_shape.GetDims().begin(), q_shape.GetDims().end()); + kv_dims.back() = Nkv_; + TensorShape kv_shape(kv_dims); + Tensor* q_output = context.Output(0, q_shape); + Tensor* k_output = context.Output(1, kv_shape); + Tensor* v_output = context.Output(2, kv_shape); + Tensor* input_skip_bias_sum = (skip != nullptr && context.OutputCount() > 3) + ? context.Output(3, a->Shape()) + : nullptr; + if (q_output->Shape().Size() == 0) { + return Status::OK(); + } + + ORT_ENFORCE(norm_scale->Shape().Size() == K_, "norm_scale must have shape [K]."); + + const uint32_t block_size = onnxruntime::narrow(block_size_); + const bool would_use_subgroup_unfused = + WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(M, + Nq, + K, + batch_count, + block_size, + accuracy_level_, + bits_, + context, + q_output); + const bool would_use_dp4a_unfused = + !would_use_subgroup_unfused && + WouldApplyDP4AMatMulNBitsInCurrentDispatch(M, + Nq, + K, + block_size, + accuracy_level_, + context, + q_output); + const bool would_use_wide_tile_unfused = + !would_use_subgroup_unfused && + !would_use_dp4a_unfused && + WouldApplyWideTileMatMulNBitsInCurrentDispatch(M, + K, + block_size, + bits_); + + // The fused MatMulNBitsQkv shader binds every Q/K/V weight + scales tensor and the + // norm/skip tensors as storage buffers. Devices with a tight maxStorageBuffersPerShaderStage + // (notably macOS Metal at 10) cannot bind that many. For those devices we run the layer + // norm separately into a scratch tensor and then dispatch a no-norm variant of the fused + // QKV decode program (which omits the norm_scale, skip, and skip-output bindings, dropping + // the storage-buffer count from up to 13 down to 10). + // + // Storage-buffer count: input_a + (skip?) 
+ norm_scale + 3 * (weight + scales) + // + q/k/v outputs + (skip output?) + const uint32_t required_storage_buffers = + 1u // input_a + + (skip != nullptr ? 1u : 0u) // skip + + 1u // norm_scale + + 6u // q/k/v weights + scales + + 3u // q/k/v outputs + + (input_skip_bias_sum != nullptr ? 1u : 0u); // skip output + const bool exceeds_storage_buffer_limit = + required_storage_buffers > context.DeviceLimits().maxStorageBuffersPerShaderStage; + + if (would_use_subgroup_unfused || would_use_dp4a_unfused || would_use_wide_tile_unfused || + M != 1) { + if (skip != nullptr) { + return ApplyUnfusedQKVSkipSimplifiedLayerNorm(a, + skip, + norm_scale, + q_b, + q_scales, + k_b, + k_scales, + v_b, + v_scales, + K_, + Nq_, + Nkv_, + block_size_, + accuracy_level_, + bits_, + epsilon_, + context, + q_output, + k_output, + v_output, + input_skip_bias_sum); + } + return ApplyUnfusedQKVSimplifiedLayerNorm(a, + norm_scale, + q_b, + q_scales, + k_b, + k_scales, + v_b, + v_scales, + K_, + Nq_, + Nkv_, + block_size_, + accuracy_level_, + bits_, + epsilon_, + context, + q_output, + k_output, + v_output); + } + + // For the partial-fuse path, run [Skip]SimplifiedLayerNormalization into a scratch tensor + // first, then point the decode program at the normalized tensor with the norm/skip bindings + // turned off. The user-visible residual passthrough (input_skip_bias_sum) is produced by the + // skip-norm op directly, so the decode program never needs to write it itself. 
+ std::optional normalized_a_storage; + const Tensor* decode_a = a; + if (exceeds_storage_buffer_limit) { + normalized_a_storage.emplace(context.CreateGPUTensor(a->DataType(), a->Shape())); + if (skip != nullptr) { + ORT_RETURN_IF_ERROR(RunSkipLayerNormProgram(context, a, skip, norm_scale, + /*beta=*/nullptr, /*bias=*/nullptr, + epsilon_, /*simplified=*/true, + &*normalized_a_storage, + input_skip_bias_sum)); + } else { + const auto& a_shape = a->Shape(); + const int64_t norm_size = a_shape[a_shape.NumDimensions() - 1]; + const uint32_t norm_count = onnxruntime::narrow(a_shape.Size() / norm_size); + ORT_RETURN_IF_ERROR(onnxruntime::webgpu::RunLayerNormProgram( + context, a, norm_scale, /*bias=*/nullptr, epsilon_, norm_count, norm_size, + /*simplified=*/true, &*normalized_a_storage, /*mean=*/nullptr, + /*inv_std_dev=*/nullptr)); + } + decode_a = &*normalized_a_storage; + } + + const uint32_t components_a = GetMaxComponents(K); + const uint32_t block_size_per_col = block_size; + const uint32_t n_blocks_per_col = (K + block_size_per_col - 1) / block_size_per_col; + const uint32_t blob_size = (block_size_per_col / 8) * static_cast(bits_); + const uint32_t blob_size_in_words = blob_size / 4; + const uint32_t components_b = GetMaxComponents(blob_size_in_words); + constexpr uint32_t kU32Components = 4; + const uint32_t components_b_with_u32 = components_b * kU32Components; + const uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; + const bool single_scale_weights = + q_scales->Shape().Size() == 1 && k_scales->Shape().Size() == 1 && v_scales->Shape().Size() == 1; + + uint32_t workgroup_size = 128; + uint32_t tile_size = 8; + uint32_t tile_size_k_vec = (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 
16u : 32u; + + const uint32_t elements_in_value_b = components_b * (32u / onnxruntime::narrow(bits_)); + const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b; + const uint32_t k_tile_iterations = K / tile_size_k; + + // Decode-program-level skip bindings: only used by the fully-fused path. The partial-fuse + // path has already merged the residual upstream and exposed the residual passthrough output + // through the layer-norm op, so the decode program runs with skip_input/skip_output disabled. + const bool decode_has_norm = !exceeds_storage_buffer_limit; + const bool decode_has_skip_input = !exceeds_storage_buffer_limit && skip != nullptr; + std::optional input_skip_bias_sum_scratch; + Tensor* decode_input_skip_bias_sum = nullptr; + if (decode_has_skip_input) { + decode_input_skip_bias_sum = input_skip_bias_sum; + if (decode_input_skip_bias_sum == nullptr) { + input_skip_bias_sum_scratch.emplace(context.CreateGPUTensor(a->DataType(), a->Shape())); + decode_input_skip_bias_sum = &*input_skip_bias_sum_scratch; + } + } + const bool decode_has_skip_output = decode_input_skip_bias_sum != nullptr; + + uint32_t k_unroll_tiles = 1; + if ((K % tile_size_k) == 0) { + if (k_tile_iterations >= 8 && std::max(Nq, Nkv) <= 2048 && + context.AdapterInfo().vendor != std::string_view{"intel"}) { + k_unroll_tiles = 4; + } else if (k_tile_iterations >= 4) { + k_unroll_tiles = 2; + } + } + + const uint32_t num_N_tile = CeilDiv(std::max(Nq, Nkv), tile_size); + MatMulNBitsQkvDecodeProgram program{tile_size, + single_scale_weights, + tile_size_k_vec, + k_unroll_tiles, + decode_has_norm, + decode_has_skip_input, + decode_has_skip_output}; + program.SetWorkgroupSize(workgroup_size); + program.SetDispatchGroupSize(num_N_tile, 1, batch_count); + program + .AddInput({decode_a, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); + if (decode_has_skip_input) { + program.AddInput({skip, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); 
+ } + if (decode_has_norm) { + program.AddInput({norm_scale, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_a)}); + } + program + .AddInputs({{q_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + {q_scales, ProgramTensorMetadataDependency::TypeAndRank}, + {k_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + {k_scales, ProgramTensorMetadataDependency::TypeAndRank}, + {v_b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(components_b_with_u32)}, + {v_scales, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({{q_output, ProgramTensorMetadataDependency::TypeAndRank}, + {k_output, ProgramTensorMetadataDependency::TypeAndRank}, + {v_output, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddUniformVariables({{Nq}, + {Nkv}, + {K}, + {K / components_a}, + {K_of_b}, + {block_size}, + {n_blocks_per_col}, + {num_N_tile}, + {batch_count}, + {decode_has_skip_input ? onnxruntime::narrow(skip->Shape().Size()) : 0u}, + {epsilon_}}) + .CacheHint(Nq, + Nkv, + K, + tile_size, + tile_size_k_vec, + k_unroll_tiles, + single_scale_weights, + decode_has_norm, + decode_has_skip_input, + decode_has_skip_output, + "decode_qkv_sln"); + if (decode_has_skip_output) { + program.AddOutput({decode_input_skip_bias_sum, + ProgramTensorMetadataDependency::TypeAndRank, + static_cast(components_a)}); + } + + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.h new file mode 100644 index 0000000000000..4d57ab5ac2b3c --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class MatMulNBitsQkv final : public WebGpuKernel { + public: + explicit MatMulNBitsQkv(const OpKernelInfo& info) : WebGpuKernel(info) { + K_ = info.GetAttr("K"); + Nq_ = info.GetAttr("Nq"); + Nkv_ = info.GetAttr("Nkv"); + block_size_ = info.GetAttr("block_size"); + bits_ = info.GetAttr("bits"); + accuracy_level_ = info.GetAttrOrDefault("accuracy_level", 4); + epsilon_ = info.GetAttrOrDefault("epsilon", 1e-6f); + ORT_ENFORCE(bits_ == 4, + "MatMulNBitsQkv currently supports 4-bit weights only."); + } + + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; + + private: + int64_t K_; + int64_t Nq_; + int64_t Nkv_; + int64_t block_size_; + int64_t accuracy_level_; + int64_t bits_; + float epsilon_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template new file mode 100644 index 0000000000000..60f34e9ef2530 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template @@ -0,0 +1,281 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#param a_length_per_tile +#param component_a +#param component_b +#param elements_in_value_b +#param has_norm +#param has_skip_input +#param has_skip_output +#param k_unroll_tiles +#param single_scale_weights +#param sub_tile_count +#param tile_size_k_vec +#param tile_size_k +#param tile_size + +#use .getByOffset .setByOffset + +#if has_norm +var sum_squared_shared : array; +#endif +var tile_A : array; +var q_inter_results : array, tile_size>; +var k_inter_results : array, tile_size>; +var v_inter_results : array, tile_size>; + +const default_zero_point = vec4(q_output_element_t(8)); + +fn unpack_nibble_values(word: u32) -> vec4 { + let unpacked = unpack4xU8(word); + return vec4(q_output_element_t(unpacked[0]), + q_output_element_t(unpacked[1]), + q_output_element_t(unpacked[2]), + q_output_element_t(unpacked[3])); +} + +fn load_merged_input(input_offset: u32) -> input_a_value_t { + var value = a.getByOffset(input_offset); +#if has_skip_input + let skip_offset = input_offset % (uniforms.skip_size / component_a); + value += input_a_value_t(skip.getByOffset(skip_offset)); +#endif + return value; +} + +#if component_a == 1 +fn load_a_vec4(a_offset: u32) -> vec4 { + return vec4(q_output_element_t(tile_A[a_offset]), + q_output_element_t(tile_A[a_offset + 1]), + q_output_element_t(tile_A[a_offset + 2]), + q_output_element_t(tile_A[a_offset + 3])); +} +#elif component_a == 2 +fn load_a_vec4(a_offset: u32) -> vec4 { + let a0 = tile_A[a_offset]; + let a1 = tile_A[a_offset + 1]; + return vec4(q_output_element_t(a0[0]), + q_output_element_t(a0[1]), + q_output_element_t(a1[0]), + q_output_element_t(a1[1])); +} +#elif component_a == 4 +fn load_a_vec4(a_offset: u32) -> vec4 { + let a = tile_A[a_offset]; + return vec4(q_output_element_t(a[0]), + q_output_element_t(a[1]), + q_output_element_t(a[2]), + q_output_element_t(a[3])); +} +#endif + +fn loadSHMA(batch: u32, b_global_base: u32, kidx: u32, col: u32, inv_std: f32) { + let k_offset = kidx / component_a + col; + let 
input_offset = batch * uniforms.K_of_a + k_offset; + if (k_offset < uniforms.K_of_a) { +#if has_norm + let merged_value = load_merged_input(input_offset); +#if has_skip_output + if (b_global_base == 0u) { + input_skip_bias_sum.setByOffset(input_offset, input_skip_bias_sum_value_t(merged_value)); + } +#endif + tile_A[col] = merged_value * input_a_value_t(input_a_element_t(inv_std)) * norm_scale.getByOffset(k_offset); +#else + // Layer norm has already been applied to `a` upstream; load the pre-normalized value directly. + _ = b_global_base; + _ = inv_std; + tile_A[col] = a.getByOffset(input_offset); +#endif + } else { + tile_A[col] = input_a_value_t(0); + } +} + +fn compute_projection_sum(weight: q_b_value_t, + scale: q_output_element_t, + idx: u32) -> q_output_element_t { + var sum = q_output_element_t(0); + var a_offset = idx * (8 / component_a) * component_b; +#if component_b == 1 + let weight_lower = unpack_nibble_values(weight & 0x0F0F0F0Fu) - default_zero_point; + let weight_upper = unpack_nibble_values((weight >> 4) & 0x0F0F0F0Fu) - default_zero_point; + let w0 = vec4(q_output_element_t(weight_lower[0]), q_output_element_t(weight_upper[0]), q_output_element_t(weight_lower[1]), q_output_element_t(weight_upper[1])) * scale; + let w1 = vec4(q_output_element_t(weight_lower[2]), q_output_element_t(weight_upper[2]), q_output_element_t(weight_lower[3]), q_output_element_t(weight_upper[3])) * scale; +#if component_a == 1 + let a0 = load_a_vec4(a_offset); + let a1 = load_a_vec4(a_offset + 4); + sum += dot(a0, w0) + dot(a1, w1); +#elif component_a == 2 + let a0 = load_a_vec4(a_offset); + let a1 = load_a_vec4(a_offset + 2); + sum += dot(a0, w0) + dot(a1, w1); +#elif component_a == 4 + let a0 = load_a_vec4(a_offset); + let a1 = load_a_vec4(a_offset + 1); + sum += dot(a0, w0) + dot(a1, w1); +#endif +#else + for (var i = 0u; i < component_b; i++) { + let weight_lower = unpack_nibble_values(weight[i] & 0x0F0F0F0Fu) - default_zero_point; + let weight_upper = 
unpack_nibble_values((weight[i] >> 4) & 0x0F0F0F0Fu) - default_zero_point; + let w0 = vec4(q_output_element_t(weight_lower[0]), q_output_element_t(weight_upper[0]), q_output_element_t(weight_lower[1]), q_output_element_t(weight_upper[1])) * scale; + let w1 = vec4(q_output_element_t(weight_lower[2]), q_output_element_t(weight_upper[2]), q_output_element_t(weight_lower[3]), q_output_element_t(weight_upper[3])) * scale; +#if component_a == 1 + let a0 = load_a_vec4(a_offset); + let a1 = load_a_vec4(a_offset + 4); + sum += dot(a0, w0) + dot(a1, w1); + a_offset += 8; +#elif component_a == 2 + let a0 = load_a_vec4(a_offset); + let a1 = load_a_vec4(a_offset + 2); + sum += dot(a0, w0) + dot(a1, w1); + a_offset += 4; +#elif component_a == 4 + let a0 = load_a_vec4(a_offset); + let a1 = load_a_vec4(a_offset + 1); + sum += dot(a0, w0) + dot(a1, w1); + a_offset += 2; +#endif + } +#endif + return sum; +} + +fn process_k_tile(batch: u32, b_global_base: u32, thread_idx: u32, idx: u32, idy: u32, kidx: u32, inv_std: f32) { + for (var id = thread_idx; id < a_length_per_tile; id += workgroup_size_x) { + loadSHMA(batch, b_global_base, kidx, id, inv_std); + } + workgroupBarrier(); + +#if single_scale_weights + let q_scale_b = q_output_element_t(q_scales_b.getByOffset(0)); + let k_scale_b = q_output_element_t(k_scales_b.getByOffset(0)); + let v_scale_b = q_output_element_t(v_scales_b.getByOffset(0)); +#endif + + for (var local_row_offset = 0u; local_row_offset < tile_size; local_row_offset += sub_tile_count) { + let b_global = b_global_base + local_row_offset + idy; + let k_offset = kidx / elements_in_value_b + idx; + #if !single_scale_weights + let block_idx = (kidx + idx * elements_in_value_b) / uniforms.block_size; + #endif + if (k_offset < uniforms.K_of_b) { + if (b_global < uniforms.Nq) { + #if !single_scale_weights + let q_scale_b = q_output_element_t(q_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx)); + #endif + let q_weight = q_b.getByOffset(b_global * 
uniforms.K_of_b + k_offset); + q_inter_results[local_row_offset + idy][idx] += compute_projection_sum(q_weight, q_scale_b, idx); + } + if (b_global < uniforms.Nkv) { + #if !single_scale_weights + let k_scale_b = q_output_element_t(k_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx)); + let v_scale_b = q_output_element_t(v_scales_b.getByOffset(b_global * uniforms.blocks_per_col + block_idx)); + #endif + let k_weight = k_b.getByOffset(b_global * uniforms.K_of_b + k_offset); + let v_weight = v_b.getByOffset(b_global * uniforms.K_of_b + k_offset); + k_inter_results[local_row_offset + idy][idx] += compute_projection_sum(k_weight, k_scale_b, idx); + v_inter_results[local_row_offset + idy][idx] += compute_projection_sum(v_weight, v_scale_b, idx); + } + } + } + workgroupBarrier(); + } + +$MAIN { + let batch = workgroup_idx / uniforms.num_N_tile; + let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size; + + let idx = local_idx % tile_size_k_vec; + let idy = local_idx / tile_size_k_vec; + + if (local_idx < tile_size) { + for (var b = 0u; b < tile_size_k_vec; b++) { + q_inter_results[local_idx][b] = q_output_element_t(0); + k_inter_results[local_idx][b] = q_output_element_t(0); + v_inter_results[local_idx][b] = q_output_element_t(0); + } + } + +#if has_norm + var sum_squared_local = 0.0; + for (var a_idx = local_idx; a_idx < uniforms.K_of_a; a_idx += workgroup_size_x) { + let a_value = load_merged_input(batch * uniforms.K_of_a + a_idx); +#if component_a == 1 + let a_f32 = f32(a_value); + sum_squared_local += a_f32 * a_f32; +#elif component_a == 2 + let a_f32 = vec2(a_value); + sum_squared_local += dot(a_f32, a_f32); +#elif component_a == 4 + let a_f32 = vec4(a_value); + sum_squared_local += dot(a_f32, a_f32); +#endif + } + sum_squared_shared[local_idx] = sum_squared_local; + workgroupBarrier(); + + var reduce_size : u32 = workgroup_size_x; + for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) { + reduce_size = 
curr_size + (reduce_size & 1u); + if (local_idx < curr_size) { + sum_squared_shared[local_idx] += sum_squared_shared[local_idx + reduce_size]; + } + workgroupBarrier(); + } + + let inv_std = inverseSqrt(sum_squared_shared[0] / f32(uniforms.K) + uniforms.epsilon); +#else + // Layer norm already applied upstream; inv_std is unused but kept in the loadSHMA signature. + let inv_std = 1.0; +#endif + +#if k_unroll_tiles == 1 + for (var kidx = 0u; kidx < uniforms.K; kidx += tile_size_k) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); + } +#elif k_unroll_tiles == 2 + let unrolled_k_step = tile_size_k * 2u; + let unrolled_k_limit = uniforms.K - (uniforms.K % unrolled_k_step); + for (var kidx = 0u; kidx < unrolled_k_limit; kidx += unrolled_k_step) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k, inv_std); + } + for (var kidx = unrolled_k_limit; kidx < uniforms.K; kidx += tile_size_k) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); + } +#elif k_unroll_tiles == 4 + let unrolled_k_step = tile_size_k * 4u; + let unrolled_k_limit = uniforms.K - (uniforms.K % unrolled_k_step); + for (var kidx = 0u; kidx < unrolled_k_limit; kidx += unrolled_k_step) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k, inv_std); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k * 2u, inv_std); + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx + tile_size_k * 3u, inv_std); + } + for (var kidx = unrolled_k_limit; kidx < uniforms.K; kidx += tile_size_k) { + process_k_tile(batch, b_global_base, local_idx, idx, idy, kidx, inv_std); + } +#endif + + if (local_idx < tile_size) { + let b_global = b_global_base + local_idx; + var q_output_value = q_output_element_t(0); + var k_output_value = 
q_output_element_t(0); + var v_output_value = q_output_element_t(0); + for (var b = 0u; b < tile_size_k_vec; b++) { + q_output_value += q_inter_results[local_idx][b]; + k_output_value += k_inter_results[local_idx][b]; + v_output_value += v_inter_results[local_idx][b]; + } + if (b_global < uniforms.Nq) { + q_output.setByOffset(batch * uniforms.Nq + b_global, q_output_value_t(q_output_value)); + } + if (b_global < uniforms.Nkv) { + k_output.setByOffset(batch * uniforms.Nkv + b_global, k_output_value_t(k_output_value)); + v_output.setByOffset(batch * uniforms.Nkv + b_global, v_output_value_t(v_output_value)); + } + } +} diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 41e2b52016cfd..32368287f6347 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -31,6 +31,8 @@ static const BuildKernelCreateInfoFn build_kernel_create_info_function_table[] = BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index a5537c7d58b05..8ce5527d3730f 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3619,6 +3619,210 @@ For example, for 4 bits, the first 4 bits are stored in the lower 4 bits of a by } }); + static const char* MatMulNBitsMlp_ver1_doc = R"DOC( +MatMulNBitsMlp fuses two MatMulNBits projections that share the same input and computes + + gate = MatMulNBits(A, gate_weight) + gate_bias + up = MatMulNBits(A, up_weight) + up_bias + Y = activation(gate) * up + +It can also optionally fuse SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization before the +two projections: + + A_norm = 
SimplifiedLayerNormalization(A, norm_scale, epsilon) + gate = MatMulNBits(A_norm, gate_weight) + gate_bias + up = MatMulNBits(A_norm, up_weight) + up_bias + Y = activation(gate) * up + + A_norm = SkipSimplifiedLayerNormalization(A, skip, norm_scale, epsilon) + gate = MatMulNBits(A_norm, gate_weight) + gate_bias + up = MatMulNBits(A_norm, up_weight) + up_bias + Y = activation(gate) * up + +This operator is intended for decoder MLP patterns such as Qwen-style gate and up projections, but it remains +semantically valid for both prefill and decode because the output shape is the standard MatMul result shape +derived from the runtime shape of A and the shared attributes K and N. + +The operator contract includes a string attribute describing the fused gate activation. + +When fused from SkipSimplifiedLayerNormalization, the optional residual-sum output may also be materialized: + + A_norm, input_skip_bias_sum = SkipSimplifiedLayerNormalization(A, skip, norm_scale, epsilon) + gate = MatMulNBits(A_norm, gate_weight) + gate_bias + up = MatMulNBits(A_norm, up_weight) + up_bias + Y = activation(gate) * up +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBitsMlp) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(MatMulNBitsMlp_ver1_doc) + .Attr("K", "Input feature dimension shared by both quantized weight matrices.", AttributeProto::INT) + .Attr("N", "Output feature dimension shared by both quantized weight matrices.", AttributeProto::INT) + .Attr("bits", "Bit-width used to quantize both weight matrices (valid range: 2~8)", AttributeProto::INT, static_cast(4)) + .Attr("block_size", + "Size of each quantization block along the K dimension. Must be a power of two and >= 16.", + AttributeProto::INT) + .Attr("accuracy_level", + "The minimum accuracy level of input A. 
It follows the same semantics as MatMulNBits.", + AttributeProto::INT, static_cast(0)) + .Attr("activation", + "Activation applied to the gate projection.", + AttributeProto::STRING) + .Attr("epsilon", + "Epsilon used by the optional fused (Skip)SimplifiedLayerNormalization. Defaults to 1e-5.", + AttributeProto::FLOAT, 1e-5f) + .Input(0, "A", "The shared input tensor.", "T1") + .Input(1, "skip", "Optional skip input used by SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional) + .Input(2, "norm_scale", "Optional RMSNorm scale with shape [K] used by SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional) + .Input(3, "gate_B", "Packed uint8 tensor for the gate projection weights.", "T2") + .Input(4, "gate_scales", "Per-block scaling factors for the gate projection.", "T1") + .Input(5, "gate_bias", "Optional bias for the gate projection with shape [N].", "T1", OpSchema::Optional) + .Input(6, "up_B", "Packed uint8 tensor for the up projection weights.", "T2") + .Input(7, "up_scales", "Per-block scaling factors for the up projection.", "T1") + .Input(8, "up_bias", "Optional bias for the up projection with shape [N].", "T1", OpSchema::Optional) + .Output(0, "Y", "The fused gated MLP output tensor.", "T1") + .Output(1, "input_skip_bias_sum", "Optional residual-sum output for SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional) + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") + .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (ctx.getNumOutputs() > 1) { + propagateElemTypeFromInputToOutput(ctx, 0, 1); + } + + const int64_t in_features = getAttribute(ctx, "K", -1); + const int64_t out_features = getAttribute(ctx, "N", -1); + MatmulWithQuantWeightShapeInference(ctx, 
in_features, out_features, true); + + if (ctx.hasInput(1) && !ctx.hasInput(2)) { + fail_shape_inference("norm_scale input must be present when skip input is provided"); + } + + if (ctx.hasOutput(1)) { + if (!ctx.hasInput(1)) { + fail_shape_inference("skip input must be present when input_skip_bias_sum output is requested"); + } + + if (!hasInputShape(ctx, 0)) { + return; + } + + auto* skip_sum_shape = getOutputShape(ctx, 1); + *skip_sum_shape = getInputShape(ctx, 0); + } + + if (ctx.hasInput(2)) { + if (!hasInputShape(ctx, 2)) { + fail_shape_inference("norm_scale shape must be known"); + } + + const auto& norm_scale_shape = getInputShape(ctx, 2); + if (norm_scale_shape.dim_size() != 1 || + !norm_scale_shape.dim(0).has_dim_value() || + norm_scale_shape.dim(0).dim_value() != in_features) { + fail_shape_inference("norm_scale shape must be [K] where K = ", in_features); + } + } + + for (size_t bias_input_index : {5U, 8U}) { + if (!ctx.hasInput(static_cast(bias_input_index))) { + continue; + } + + if (!hasInputShape(ctx, static_cast(bias_input_index))) { + fail_shape_inference("bias shape must be known"); + } + + const auto& bias_shape = getInputShape(ctx, static_cast(bias_input_index)); + if (bias_shape.dim_size() != 1 || + !bias_shape.dim(0).has_dim_value() || + bias_shape.dim(0).dim_value() != out_features) { + fail_shape_inference("bias shape must be [N] where N = ", out_features); + } + } + }); + + static const char* MatMulNBitsQkv_ver1_doc = R"DOC( +MatMulNBitsQkv fuses either SimplifiedLayerNormalization (RMSNorm) +or SkipSimplifiedLayerNormalization with three MatMulNBits projections that share the +same normalized activation. + + A_norm = SimplifiedLayerNormalization(A, norm_scale, epsilon) + Q = MatMulNBits(A_norm, q_weight) + K = MatMulNBits(A_norm, k_weight) + V = MatMulNBits(A_norm, v_weight) + +If skip is provided, the operator computes the SkipSimplifiedLayerNormalization variant +and may also return the input+skip residual sum as output 3. 
+ +This operator is intended as a decode-oriented QKV fusion primitive. +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBitsQkv) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(MatMulNBitsQkv_ver1_doc) + .Attr("K", "Input feature dimension shared by the normalized input and all projection weights.", AttributeProto::INT) + .Attr("Nq", "Output feature dimension of the Q projection.", AttributeProto::INT) + .Attr("Nkv", "Output feature dimension shared by the K and V projections.", AttributeProto::INT) + .Attr("bits", "Bit-width used to quantize all weight matrices (valid range: 2~8)", AttributeProto::INT, static_cast(4)) + .Attr("block_size", + "Size of each quantization block along the K dimension. Must be a power of two and >= 16.", + AttributeProto::INT) + .Attr("accuracy_level", + "The minimum accuracy level of input A. It follows the same semantics as MatMulNBits.", + AttributeProto::INT, static_cast(0)) + .Attr("epsilon", "Epsilon used by the simplified layer norm reduction.", AttributeProto::FLOAT, 1e-6f) + .Input(0, "A", "The shared input tensor.", "T1") + .Input(1, "skip", "Optional residual input for SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional) + .Input(2, "norm_scale", "Scale input for the simplified layer norm with shape [K].", "T1") + .Input(3, "q_B", "Packed uint8 tensor for the Q projection weights.", "T2") + .Input(4, "q_scales", "Per-block scaling factors for the Q projection.", "T1") + .Input(5, "k_B", "Packed uint8 tensor for the K projection weights.", "T2") + .Input(6, "k_scales", "Per-block scaling factors for the K projection.", "T1") + .Input(7, "v_B", "Packed uint8 tensor for the V projection weights.", "T2") + .Input(8, "v_scales", "Per-block scaling factors for the V projection.", "T1") + .Output(0, "Q", "The Q projection output tensor.", "T1") + .Output(1, "K", "The K projection output tensor.", "T1") + .Output(2, "V", "The V projection output tensor.", "T1") + .Output(3, "input_skip_bias_sum", "Optional residual-sum 
output for SkipSimplifiedLayerNormalization.", "T1", OpSchema::Optional) + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") + .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + for (size_t output_index = 0; output_index < ctx.getNumOutputs(); ++output_index) { + propagateElemTypeFromInputToOutput(ctx, 0, output_index); + } + + if (!hasInputShape(ctx, 0)) { + return; + } + + const auto& input_shape = getInputShape(ctx, 0); + if (input_shape.dim_size() == 0) { + fail_shape_inference("A must have rank >= 1"); + } + + const int64_t q_out_features = getAttribute(ctx, "Nq", -1); + const int64_t kv_out_features = getAttribute(ctx, "Nkv", -1); + + auto set_output_shape = [&](int output_index, int64_t out_features) { + auto* output_shape = getOutputShape(ctx, output_index); + *output_shape = input_shape; + output_shape->mutable_dim(output_shape->dim_size() - 1)->set_dim_value(out_features); + }; + + set_output_shape(0, q_out_features); + set_output_shape(1, kv_out_features); + set_output_shape(2, kv_out_features); + if (ctx.getNumOutputs() > 3) { + auto* output_shape = getOutputShape(ctx, 3); + *output_shape = input_shape; + } + }); + static const char* MatMulBnb4_ver1_doc = R"DOC( MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. 
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 282261cab0e58..a4311969a6d73 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -55,6 +55,8 @@ #include "core/optimizer/layer_norm_fusion.h" #include "core/optimizer/matmul_activation_fusion.h" #include "core/optimizer/matmul_add_fusion.h" +#include "core/optimizer/matmul_nbits_qkv_fusion.h" +#include "core/optimizer/matmul_nbits_mlp_fusion.h" #include "core/optimizer/matmul_bn_fusion.h" #include "core/optimizer/matmul_integer_to_float.h" #include "core/optimizer/matmul_scale_fusion.h" @@ -446,6 +448,10 @@ InlinedVector> GenerateTransformers( #endif transformers.emplace_back(std::make_unique(cpu_ep)); + transformers.emplace_back(std::make_unique( + InlinedHashSet{onnxruntime::kWebGpuExecutionProvider})); + transformers.emplace_back(std::make_unique( + InlinedHashSet{onnxruntime::kWebGpuExecutionProvider})); #endif // !defined(DISABLE_CONTRIB_OPS) // The QDQFinalCleanupTransformer must run AFTER other transformers that fuse Q/DQ nodes. Otherwise, their diff --git a/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc new file mode 100644 index 0000000000000..522e5f9e495cf --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc @@ -0,0 +1,494 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/matmul_nbits_mlp_fusion.h" + +#include + +#include "core/graph/graph_utils.h" +#include "core/graph/node_attr_utils.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +namespace { + +constexpr const char* kActivationAttrName = "activation"; +// The transformer name is generic for future expansion, but the current fused +// pattern and emitted op only support gate activation = "silu". 
To add another +// gate activation (e.g. GELU for Gemma-style MLPs), extend the pattern matcher +// below to recognize the new activation subgraph (or a unary node like `Gelu`), +// add the new value to `MlpActivationKind` in matmul_nbits_mlp.h, and update +// `EmitGateActivationExpr` plus the `#if activation_kind` block in the WGSL +// template. +constexpr const char* kSupportedActivation = "silu"; + +bool HasInput(const Node& node, size_t index) { + return index < node.InputDefs().size() && node.InputDefs()[index] != nullptr && !node.InputDefs()[index]->Name().empty(); +} + +const Node* GetInputNode(const Graph& graph, const Node& node, size_t input_index) { + const auto* edge = graph_utils::GetInputEdge(node, static_cast(input_index)); + return edge == nullptr ? nullptr : graph.GetNode(edge->GetNode().Index()); +} + +bool IsSupportedMul(const Node& node) { + return graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14}); +} + +bool IsSupportedSigmoid(const Node& node) { + return graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6, 13}); +} + +bool IsSupportedSimplifiedLayerNormalization(const Node& node) { + return graph_utils::IsSupportedOptypeVersionAndDomain(node, "SimplifiedLayerNormalization", {1}); +} + +bool IsSupportedSkipSimplifiedLayerNormalization(const Node& node) { + return graph_utils::IsSupportedOptypeVersionAndDomain(node, "SkipSimplifiedLayerNormalization", {1}, kMSDomain); +} + +bool IsSupportedMlpNormAnchor(const Node& node) { + return IsSupportedSimplifiedLayerNormalization(node) || IsSupportedSkipSimplifiedLayerNormalization(node); +} + +bool HasProducedOutput(const Node& node, size_t index) { + return index < node.OutputDefs().size() && node.OutputDefs()[index] != nullptr && !node.OutputDefs()[index]->Name().empty(); +} + +bool ProducesOnlyOptionalSkipOutputAsGraphOutput(const Graph& graph, const Node& node) { + const auto graph_outputs = graph.GetNodeOutputsInGraphOutputs(node); + return 
std::all_of(graph_outputs.begin(), graph_outputs.end(), [](int output_idx) { return output_idx == 3; }); +} + +size_t ExpectedNormConsumerEdgeCount(const Node& node) { + return 2u + ((IsSupportedSkipSimplifiedLayerNormalization(node) && HasProducedOutput(node, 3)) ? 1u : 0u); +} + +bool HasExpectedNormConsumers(const Graph& graph, const Node& node) { + const auto graph_outputs = graph.GetNodeOutputsInGraphOutputs(node); + const size_t expected_output_edges = ExpectedNormConsumerEdgeCount(node) - graph_outputs.size(); + if (node.GetOutputEdgesCount() != expected_output_edges) { + return false; + } + + for (auto output_edge_it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); output_edge_it != end; ++output_edge_it) { + const auto& output_node = output_edge_it->GetNode(); + const auto output_node_input_arg_idx = static_cast(output_edge_it->GetDstArgIndex()); + const bool is_implicit_input_to_output_node = output_node_input_arg_idx >= output_node.InputDefs().size(); + if (is_implicit_input_to_output_node) { + return false; + } + } + + return true; +} + +bool IsMatMulNBitsWithoutZeroPointOrGroupIdx(const Node& node) { + return graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMulNBits", {1}, kMSDomain) && + !HasInput(node, 3) && !HasInput(node, 4); +} + +int64_t GetIntAttr(const Node& node, const char* name, int64_t default_value, bool required = false) { + const auto* attr = graph_utils::GetNodeAttribute(node, name); + if (attr == nullptr) { + ORT_ENFORCE(!required, "Missing required attribute ", name, " on node ", node.Name()); + return default_value; + } + + return attr->i(); +} + +float GetFloatAttr(const Node& node, const char* name, float default_value) { + const auto* attr = graph_utils::GetNodeAttribute(node, name); + return attr == nullptr ? 
default_value : attr->f(); +} + +bool IsSupportedQuickGelu(const Node& node) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "QuickGelu", {1}, kMSDomain)) { + return false; + } + // SiLU is equivalent to QuickGelu(x, alpha=1.0). Any other alpha is a valid + // QuickGelu activation but is not the SiLU function that the fused kernel + // implements, so we conservatively reject it here. + return GetFloatAttr(node, "alpha", 1.0f) == 1.0f; +} + +bool HasSingleNonGraphConsumer(const Graph& graph, const Node& node) { + return !graph.NodeProducesGraphOutput(node) && optimizer_utils::CheckOutputEdges(graph, node, 1); +} + +const Node* GetNormProducer(const Graph& graph, + const Node& gate_matmul, + const Node& up_matmul) { + if (gate_matmul.InputDefs().empty() || up_matmul.InputDefs().empty() || + gate_matmul.InputDefs()[0] != up_matmul.InputDefs()[0]) { + return nullptr; + } + + const Node* gate_input = GetInputNode(graph, gate_matmul, 0); + const Node* up_input = GetInputNode(graph, up_matmul, 0); + if (gate_input == nullptr || gate_input != up_input || !IsSupportedMlpNormAnchor(*gate_input)) { + return nullptr; + } + + if (!HasProducedOutput(*gate_input, 0)) { + return nullptr; + } + + if (graph.NodeProducesGraphOutput(*gate_input) && !ProducesOnlyOptionalSkipOutputAsGraphOutput(graph, *gate_input)) { + return nullptr; + } + + if (!HasExpectedNormConsumers(graph, *gate_input)) { + return nullptr; + } + + const size_t min_norm_inputs = IsSupportedSkipSimplifiedLayerNormalization(*gate_input) ? 
3u : 2u; + if (gate_input->InputDefs().size() < min_norm_inputs) { + return nullptr; + } + + return gate_input; +} + +bool ValidateMatMulNBitsPair(const Graph& graph, + const Node& gate_matmul, + const Node& up_matmul, + size_t expected_gate_fanout) { + if (!IsMatMulNBitsWithoutZeroPointOrGroupIdx(gate_matmul) || !IsMatMulNBitsWithoutZeroPointOrGroupIdx(up_matmul)) { + return false; + } + + if (!HasSingleNonGraphConsumer(graph, up_matmul)) { + return false; + } + + if (graph.NodeProducesGraphOutput(gate_matmul) || gate_matmul.GetOutputEdgesCount() != expected_gate_fanout) { + return false; + } + + if (gate_matmul.InputDefs().empty() || up_matmul.InputDefs().empty() || + gate_matmul.InputDefs()[0] != up_matmul.InputDefs()[0]) { + return false; + } + + const int64_t gate_k = GetIntAttr(gate_matmul, "K", -1, true); + const int64_t up_k = GetIntAttr(up_matmul, "K", -1, true); + const int64_t gate_n = GetIntAttr(gate_matmul, "N", -1, true); + const int64_t up_n = GetIntAttr(up_matmul, "N", -1, true); + const int64_t gate_bits = GetIntAttr(gate_matmul, "bits", 4); + const int64_t up_bits = GetIntAttr(up_matmul, "bits", 4); + const int64_t gate_block_size = GetIntAttr(gate_matmul, "block_size", -1, true); + const int64_t up_block_size = GetIntAttr(up_matmul, "block_size", -1, true); + const int64_t gate_accuracy_level = GetIntAttr(gate_matmul, "accuracy_level", 0); + const int64_t up_accuracy_level = GetIntAttr(up_matmul, "accuracy_level", 0); + + return gate_k == up_k && gate_n == up_n && + gate_bits == up_bits && gate_bits == 4 && + gate_block_size == up_block_size && gate_block_size == 32 && + gate_accuracy_level == up_accuracy_level; +} + +// Validates the SiLU-decomposed activation shape: +// gate_matmul -> Sigmoid -+ +// gate_matmul ------------+-> silu_mul -> final_mul <- up_matmul +bool IsFuseCandidateSilu(const Graph& graph, + const Node& gate_matmul, + const Node& up_matmul, + const Node& sigmoid, + const Node& silu_mul, + const Node& final_mul) { + if 
(!IsSupportedSigmoid(sigmoid) || !IsSupportedMul(silu_mul) || !IsSupportedMul(final_mul)) { + return false; + } + + if (!HasSingleNonGraphConsumer(graph, sigmoid) || !HasSingleNonGraphConsumer(graph, silu_mul)) { + return false; + } + + if (!ValidateMatMulNBitsPair(graph, gate_matmul, up_matmul, /*expected_gate_fanout=*/2)) { + return false; + } + + if (sigmoid.InputDefs()[0] != gate_matmul.OutputDefs()[0]) { + return false; + } + + const bool silu_mul_matches = + (silu_mul.InputDefs()[0] == gate_matmul.OutputDefs()[0] && silu_mul.InputDefs()[1] == sigmoid.OutputDefs()[0]) || + (silu_mul.InputDefs()[1] == gate_matmul.OutputDefs()[0] && silu_mul.InputDefs()[0] == sigmoid.OutputDefs()[0]); + if (!silu_mul_matches) { + return false; + } + + const bool final_mul_matches = + (final_mul.InputDefs()[0] == silu_mul.OutputDefs()[0] && final_mul.InputDefs()[1] == up_matmul.OutputDefs()[0]) || + (final_mul.InputDefs()[1] == silu_mul.OutputDefs()[0] && final_mul.InputDefs()[0] == up_matmul.OutputDefs()[0]); + return final_mul_matches; +} + +// Validates the fused-QuickGelu activation shape produced by QuickGeluFusion: +// gate_matmul -> QuickGelu(alpha=1.0) -> final_mul <- up_matmul +bool IsFuseCandidateQuickGelu(const Graph& graph, + const Node& gate_matmul, + const Node& up_matmul, + const Node& quick_gelu, + const Node& final_mul) { + if (!IsSupportedQuickGelu(quick_gelu) || !IsSupportedMul(final_mul)) { + return false; + } + + if (!HasSingleNonGraphConsumer(graph, quick_gelu)) { + return false; + } + + if (!ValidateMatMulNBitsPair(graph, gate_matmul, up_matmul, /*expected_gate_fanout=*/1)) { + return false; + } + + if (quick_gelu.InputDefs()[0] != gate_matmul.OutputDefs()[0]) { + return false; + } + + const bool final_mul_matches = + (final_mul.InputDefs()[0] == quick_gelu.OutputDefs()[0] && final_mul.InputDefs()[1] == up_matmul.OutputDefs()[0]) || + (final_mul.InputDefs()[1] == quick_gelu.OutputDefs()[0] && final_mul.InputDefs()[0] == up_matmul.OutputDefs()[0]); + return 
final_mul_matches; +} + +} // namespace + +Status MatMulNBitsMlpFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + auto* node_ptr = graph.GetNode(node_index); + if (node_ptr == nullptr) { + continue; + } + + auto& node = *node_ptr; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (!IsSupportedMul(node)) { + continue; + } + + const auto& node_ep = node.GetExecutionProviderType(); + if (!node_ep.empty() && node_ep != kWebGpuExecutionProvider) { + continue; + } + + const Node* input0 = GetInputNode(graph, node, 0); + const Node* input1 = GetInputNode(graph, node, 1); + if (input0 == nullptr || input1 == nullptr) { + continue; + } + + const Node* activation_root = nullptr; + const Node* up_matmul = nullptr; + if (IsMatMulNBitsWithoutZeroPointOrGroupIdx(*input1) && + (IsSupportedMul(*input0) || IsSupportedQuickGelu(*input0))) { + activation_root = input0; + up_matmul = input1; + } else if (IsMatMulNBitsWithoutZeroPointOrGroupIdx(*input0) && + (IsSupportedMul(*input1) || IsSupportedQuickGelu(*input1))) { + activation_root = input1; + up_matmul = input0; + } else { + continue; + } + + // The gate-side subgraph between `gate_matmul` and the outer Mul `node` + // takes one of two shapes: + // 1) SiLU decomposed: gate -> Sigmoid -+ + // gate ------------+-> silu_mul -> node + // `activation_root` is the inner Mul (silu_mul); 2 intermediates. + // 2) Fused QuickGelu (post QuickGeluFusion): gate -> QuickGelu -> node + // `activation_root` is the QuickGelu node; 1 intermediate. 
+ const Node* gate_matmul = nullptr; + InlinedVector activation_intermediates; + const char* matched_shape = nullptr; + + if (IsSupportedQuickGelu(*activation_root)) { + const Node* qg_input = GetInputNode(graph, *activation_root, 0); + if (qg_input == nullptr || !IsMatMulNBitsWithoutZeroPointOrGroupIdx(*qg_input)) { + continue; + } + gate_matmul = qg_input; + if (!IsFuseCandidateQuickGelu(graph, *gate_matmul, *up_matmul, *activation_root, node)) { + continue; + } + activation_intermediates.push_back(activation_root); + matched_shape = "quick_gelu"; + } else { + const Node* silu_input0 = GetInputNode(graph, *activation_root, 0); + const Node* silu_input1 = GetInputNode(graph, *activation_root, 1); + if (silu_input0 == nullptr || silu_input1 == nullptr) { + continue; + } + + const Node* sigmoid = nullptr; + if (IsMatMulNBitsWithoutZeroPointOrGroupIdx(*silu_input0) && IsSupportedSigmoid(*silu_input1)) { + gate_matmul = silu_input0; + sigmoid = silu_input1; + } else if (IsMatMulNBitsWithoutZeroPointOrGroupIdx(*silu_input1) && IsSupportedSigmoid(*silu_input0)) { + gate_matmul = silu_input1; + sigmoid = silu_input0; + } else { + continue; + } + + if (!IsFuseCandidateSilu(graph, *gate_matmul, *up_matmul, *sigmoid, *activation_root, node)) { + continue; + } + activation_intermediates.push_back(sigmoid); + activation_intermediates.push_back(activation_root); + matched_shape = "silu"; + } + + LOGS(logger, VERBOSE) << "MatMulNBitsMlpFusion: matched candidate shape='" << matched_shape + << "' output_mul='" << node.Name() + << "' gate='" << gate_matmul->Name() << "' up='" << up_matmul->Name() + << "' attrs={K=" << GetIntAttr(*gate_matmul, "K", -1, true) + << ", N=" << GetIntAttr(*gate_matmul, "N", -1, true) + << ", bits=" << GetIntAttr(*gate_matmul, "bits", 4) + << ", block_size=" << GetIntAttr(*gate_matmul, "block_size", -1, true) + << ", accuracy_level=" << GetIntAttr(*gate_matmul, "accuracy_level", 0) + << "}"; + + bool intermediates_on_supported_ep = true; + for (const 
Node* intermediate : activation_intermediates) { + const auto& ep = intermediate->GetExecutionProviderType(); + if (!ep.empty() && ep != kWebGpuExecutionProvider) { + intermediates_on_supported_ep = false; + break; + } + } + if ((!gate_matmul->GetExecutionProviderType().empty() && gate_matmul->GetExecutionProviderType() != kWebGpuExecutionProvider) || + (!up_matmul->GetExecutionProviderType().empty() && up_matmul->GetExecutionProviderType() != kWebGpuExecutionProvider) || + !intermediates_on_supported_ep) { + LOGS(logger, VERBOSE) << "MatMulNBitsMlpFusion: skipping candidate due to non-WebGPU EP assignment."; + continue; + } + + const Node* norm = GetNormProducer(graph, *gate_matmul, *up_matmul); + if (norm == nullptr) { + continue; + } + + if (!norm->GetExecutionProviderType().empty() && norm->GetExecutionProviderType() != kWebGpuExecutionProvider) { + continue; + } + + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("K", GetIntAttr(*gate_matmul, "K", -1, true)), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("N", GetIntAttr(*gate_matmul, "N", -1, true)), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", GetIntAttr(*gate_matmul, "bits", 4)), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", GetIntAttr(*gate_matmul, "block_size", -1, true)), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", GetIntAttr(*gate_matmul, "accuracy_level", 0)), attrs); + utils::SetNodeAttribute(utils::MakeAttribute(kActivationAttrName, std::string{kSupportedActivation}), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("epsilon", GetFloatAttr(*norm, "epsilon", 1e-5f)), attrs); + + NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr); + const bool is_skip_sln = norm != nullptr && IsSupportedSkipSimplifiedLayerNormalization(*norm); + + InlinedVector fused_inputs{ + const_cast(norm->InputDefs()[0]), + is_skip_sln ? 
const_cast(norm->InputDefs()[1]) : &empty_arg, + const_cast(norm->InputDefs()[is_skip_sln ? 2 : 1]), + const_cast(gate_matmul->InputDefs()[1]), + const_cast(gate_matmul->InputDefs()[2]), + HasInput(*gate_matmul, 5) ? const_cast(gate_matmul->InputDefs()[5]) : &empty_arg, + const_cast(up_matmul->InputDefs()[1]), + const_cast(up_matmul->InputDefs()[2]), + HasInput(*up_matmul, 5) ? const_cast(up_matmul->InputDefs()[5]) : &empty_arg, + }; + + InlinedVector fused_outputs{const_cast(node.OutputDefs()[0])}; + const bool preserve_skip_output = is_skip_sln && norm != nullptr && HasProducedOutput(*norm, 3); + if (preserve_skip_output) { + fused_outputs.push_back(const_cast(norm->OutputDefs()[3])); + } + + const auto norm_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*norm); + const auto gate_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*gate_matmul); + const auto up_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*up_matmul); + const auto final_mul_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node); + const auto norm_output_edges = preserve_skip_output ? 
graph_utils::GraphEdge::GetNodeOutputEdges(*norm) + : std::vector{}; + + const std::string output_mul_name = node.Name(); + + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*norm)); + graph.RemoveNode(norm->Index()); + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*gate_matmul)); + graph.RemoveNode(gate_matmul->Index()); + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*up_matmul)); + graph.RemoveNode(up_matmul->Index()); + for (const Node* intermediate : activation_intermediates) { + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*intermediate)); + graph.RemoveNode(intermediate->Index()); + } + graph_utils::RemoveNodeOutputEdges(graph, node); + graph.RemoveNode(node.Index()); + + Node& fused_node = graph.AddNode(graph.GenerateNodeName("MatMulNBitsMlp"), + "MatMulNBitsMlp", + "fused MatMulNBits gated MLP projections", + fused_inputs, + fused_outputs, + &attrs, + kMSDomain); + fused_node.SetExecutionProviderType(kWebGpuExecutionProvider); + + LOGS(logger, VERBOSE) << "MatMulNBitsMlpFusion: created fused node '" << fused_node.Name() + << "' from output_mul='" << output_mul_name << "'"; + + for (const auto& input_edge : norm_input_edges) { + int fused_input_index = input_edge.dst_arg_index; + if (!is_skip_sln && input_edge.dst_arg_index == 1) { + fused_input_index = 2; + } + + graph.AddEdge(input_edge.src_node, fused_node.Index(), input_edge.src_arg_index, fused_input_index); + } + + auto add_input_edge_if_present = [&](const std::vector& edges, + int source_input_index, + int fused_input_index) { + for (const auto& input_edge : edges) { + if (input_edge.dst_arg_index == source_input_index) { + graph.AddEdge(input_edge.src_node, fused_node.Index(), input_edge.src_arg_index, fused_input_index); + } + } + }; + + add_input_edge_if_present(gate_input_edges, 1, 3); + add_input_edge_if_present(gate_input_edges, 2, 4); + add_input_edge_if_present(gate_input_edges, 5, 5); + add_input_edge_if_present(up_input_edges, 1, 6); + 
add_input_edge_if_present(up_input_edges, 2, 7); + add_input_edge_if_present(up_input_edges, 5, 8); + + for (const auto& output_edge : final_mul_output_edges) { + graph.AddEdge(fused_node.Index(), output_edge.dst_node, 0, output_edge.dst_arg_index); + } + if (preserve_skip_output) { + for (const auto& output_edge : norm_output_edges) { + if (output_edge.src_arg_index == 3) { + graph.AddEdge(fused_node.Index(), output_edge.dst_node, 1, output_edge.dst_arg_index); + } + } + } + + modified = true; + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h new file mode 100644 index 0000000000000..007d21027dca0 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +// Fuses the SwiGLU gated-activation subgraph (gate / up MatMulNBits projections around a +// SimplifiedLayerNormalization anchor) into a single MatMulNBitsMlp contrib op: +// +// ... -> [Skip]SimplifiedLayerNormalization -+-> MatMulNBits (gate) -+-> Sigmoid -+ +// | | | v +// | | +----------> Mul (silu) -+ +// | +-> MatMulNBits (up) ---------------------------+--> Mul -> out +// +--> (optional) skip residual passthrough --> downstream consumers +// +// becomes +// +// ... -> MatMulNBitsMlp(activation="silu") -+-> out +// +-> (optional) residual passthrough +// +// The downstream "down" projection (a third MatMulNBits that follows the gated-activation +// output) is intentionally NOT part of this fusion -- it remains a separate MatMulNBits node +// in the resulting graph. +// +// Only activation="silu" (i.e. x * Sigmoid(x)) is matched / emitted, and the fusion is restricted +// to the WebGPU EP because MatMulNBitsMlp is a WebGPU-only contrib op. 
+class MatMulNBitsMlpFusion : public GraphTransformer { + public: + explicit MatMulNBitsMlpFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("MatMulNBitsMlpFusion", compatible_execution_providers) {} + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc new file mode 100644 index 0000000000000..05cd234dba577 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc @@ -0,0 +1,372 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/matmul_nbits_qkv_fusion.h" + +#include +#include +#include + +#include "core/graph/graph_utils.h" +#include "core/graph/node_attr_utils.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +namespace { + +bool HasInput(const Node& node, size_t index) { + return index < node.InputDefs().size() && node.InputDefs()[index] != nullptr && !node.InputDefs()[index]->Name().empty(); +} + +bool IsSupportedSimplifiedLayerNormalization(const Node& node) { + return graph_utils::IsSupportedOptypeVersionAndDomain(node, "SimplifiedLayerNormalization", {1}); +} + +bool IsSupportedSkipSimplifiedLayerNormalization(const Node& node) { + return graph_utils::IsSupportedOptypeVersionAndDomain(node, "SkipSimplifiedLayerNormalization", {1}, kMSDomain); +} + +bool IsSupportedNormForFusion(const Node& node) { + return IsSupportedSimplifiedLayerNormalization(node) || IsSupportedSkipSimplifiedLayerNormalization(node); +} + +bool HasProducedOutput(const Node& node, size_t index) { + return index < node.OutputDefs().size() && node.OutputDefs()[index] != nullptr && !node.OutputDefs()[index]->Name().empty(); +} + +bool IsMatMulNBitsWithoutOptionalInputs(const Node& node) { + return 
graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMulNBits", {1}, kMSDomain) && + !HasInput(node, 3) && !HasInput(node, 4) && !HasInput(node, 5); +} + +int64_t GetIntAttr(const Node& node, const char* name, int64_t default_value, bool required = false) { + const auto* attr = graph_utils::GetNodeAttribute(node, name); + if (attr == nullptr) { + ORT_ENFORCE(!required, "Missing required attribute ", name, " on node ", node.Name()); + return default_value; + } + + return attr->i(); +} + +float GetFloatAttr(const Node& node, const char* name, float default_value) { + const auto* attr = graph_utils::GetNodeAttribute(node, name); + return attr == nullptr ? default_value : attr->f(); +} + +struct QkvNodes { + const Node* q = nullptr; + const Node* k = nullptr; + const Node* v = nullptr; +}; + +bool IsGraphOutput(const Graph& graph, const Node& node, size_t index) { + if (!HasProducedOutput(node, index)) { + return false; + } + const auto& output_name = node.OutputDefs()[index]->Name(); + for (const auto* graph_output : graph.GetOutputs()) { + if (graph_output != nullptr && graph_output->Name() == output_name) { + return true; + } + } + return false; +} + +bool HasOutputConsumers(const Node& node, size_t index) { + if (!HasProducedOutput(node, index)) { + return false; + } + for (auto edge_it = node.OutputEdgesBegin(); edge_it != node.OutputEdgesEnd(); ++edge_it) { + if (static_cast(edge_it->GetSrcArgIndex()) == index) { + return true; + } + } + return false; +} + +// Output 0 of the norm is consumed by the fused op, so it must not be a graph output. +// For SkipSimplifiedLayerNormalization the optional residual sum at output 3 is +// preserved by the fused MatMulNBitsQkv op, so it is allowed to remain a graph output. +// Outputs 1 and 2 (mean / inv_std_var) are not exposed by the fused op and must not +// be graph outputs or feed any downstream nodes. 
+bool IsSupportedNormGraphOutputsForFusion(const Graph& graph, const Node& norm) { + if (IsGraphOutput(graph, norm, 0)) { + return false; + } + for (size_t i = 1; i < norm.OutputDefs().size(); ++i) { + if (i == 1 || i == 2) { + if (IsGraphOutput(graph, norm, i) || HasOutputConsumers(norm, i)) { + return false; + } + continue; + } + if (!IsGraphOutput(graph, norm, i)) { + continue; + } + if (!(IsSupportedSkipSimplifiedLayerNormalization(norm) && i == 3)) { + return false; + } + } + return true; +} + +std::optional GetQkvNodes(const Graph& graph, const Node& norm) { + if (!HasProducedOutput(norm, 0) || !IsSupportedNormGraphOutputsForFusion(graph, norm)) { + return std::nullopt; + } + + std::array consumers{}; + size_t consumer_index = 0; + for (auto edge_it = norm.OutputEdgesBegin(); edge_it != norm.OutputEdgesEnd(); ++edge_it) { + if (edge_it->GetSrcArgIndex() != 0) { + continue; + } + + if (consumer_index >= consumers.size()) { + return std::nullopt; + } + + if (edge_it->GetDstArgIndex() != 0) { + return std::nullopt; + } + + const Node* consumer = graph.GetNode(edge_it->GetNode().Index()); + if (consumer == nullptr || !IsMatMulNBitsWithoutOptionalInputs(*consumer)) { + return std::nullopt; + } + + consumers[consumer_index++] = consumer; + } + + if (consumer_index != consumers.size()) { + return std::nullopt; + } + + const int64_t n0 = GetIntAttr(*consumers[0], "N", -1, true); + const int64_t n1 = GetIntAttr(*consumers[1], "N", -1, true); + const int64_t n2 = GetIntAttr(*consumers[2], "N", -1, true); + + QkvNodes qkv; + if (n0 != n1 && n1 == n2) { + qkv = {consumers[0], consumers[1], consumers[2]}; + } else if (n1 != n0 && n0 == n2) { + qkv = {consumers[1], consumers[0], consumers[2]}; + } else if (n2 != n0 && n0 == n1) { + qkv = {consumers[2], consumers[0], consumers[1]}; + } else { + return std::nullopt; + } + + return qkv; +} + +bool HasSupportedExecutionProvider(const Node& node) { + const auto& node_ep = node.GetExecutionProviderType(); + return 
node_ep.empty() || node_ep == kWebGpuExecutionProvider; +} + +bool IsFuseCandidate(const Node& norm, const QkvNodes& qkv) { + if (!IsSupportedNormForFusion(norm) || qkv.q == nullptr || qkv.k == nullptr || qkv.v == nullptr) { + return false; + } + + if (!HasSupportedExecutionProvider(norm) || !HasSupportedExecutionProvider(*qkv.q) || + !HasSupportedExecutionProvider(*qkv.k) || !HasSupportedExecutionProvider(*qkv.v)) { + return false; + } + + const size_t min_norm_inputs = IsSupportedSkipSimplifiedLayerNormalization(norm) ? 3u : 2u; + if (norm.InputDefs().size() < min_norm_inputs || qkv.q->InputDefs().empty() || qkv.k->InputDefs().empty() || qkv.v->InputDefs().empty()) { + return false; + } + + if (qkv.q->InputDefs()[0] != norm.OutputDefs()[0] || qkv.k->InputDefs()[0] != norm.OutputDefs()[0] || + qkv.v->InputDefs()[0] != norm.OutputDefs()[0]) { + return false; + } + + const int64_t q_k = GetIntAttr(*qkv.q, "K", -1, true); + const int64_t k_k = GetIntAttr(*qkv.k, "K", -1, true); + const int64_t v_k = GetIntAttr(*qkv.v, "K", -1, true); + const int64_t q_bits = GetIntAttr(*qkv.q, "bits", 4); + const int64_t k_bits = GetIntAttr(*qkv.k, "bits", 4); + const int64_t v_bits = GetIntAttr(*qkv.v, "bits", 4); + const int64_t q_block_size = GetIntAttr(*qkv.q, "block_size", -1, true); + const int64_t k_block_size = GetIntAttr(*qkv.k, "block_size", -1, true); + const int64_t v_block_size = GetIntAttr(*qkv.v, "block_size", -1, true); + const int64_t q_accuracy_level = GetIntAttr(*qkv.q, "accuracy_level", 0); + const int64_t k_accuracy_level = GetIntAttr(*qkv.k, "accuracy_level", 0); + const int64_t v_accuracy_level = GetIntAttr(*qkv.v, "accuracy_level", 0); + + return q_k == k_k && q_k == v_k && + q_bits == k_bits && q_bits == v_bits && q_bits == 4 && + q_block_size == k_block_size && q_block_size == v_block_size && q_block_size == 32 && + q_accuracy_level == k_accuracy_level && q_accuracy_level == v_accuracy_level; +} + +} // namespace + +Status 
MatMulNBitsQkvFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + auto* node_ptr = graph.GetNode(node_index); + if (node_ptr == nullptr) { + continue; + } + + auto& node = *node_ptr; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (!IsSupportedNormForFusion(node)) { + continue; + } + + const auto qkv_nodes = GetQkvNodes(graph, node); + if (!qkv_nodes || !IsFuseCandidate(node, *qkv_nodes)) { + continue; + } + + const int64_t K = GetIntAttr(*qkv_nodes->q, "K", -1, true); + const int64_t Nq = GetIntAttr(*qkv_nodes->q, "N", -1, true); + const int64_t Nkv = GetIntAttr(*qkv_nodes->k, "N", -1, true); + const int64_t bits = GetIntAttr(*qkv_nodes->q, "bits", 4); + const int64_t block_size = GetIntAttr(*qkv_nodes->q, "block_size", -1, true); + const int64_t accuracy_level = GetIntAttr(*qkv_nodes->q, "accuracy_level", 0); + const float epsilon = GetFloatAttr(node, "epsilon", 1e-6f); + + const bool is_skip_sln = IsSupportedSkipSimplifiedLayerNormalization(node); + + LOGS(logger, VERBOSE) << "MatMulNBitsQkvFusion: matched norm='" << node.Name() + << "' q='" << qkv_nodes->q->Name() << "' k='" << qkv_nodes->k->Name() + << "' v='" << qkv_nodes->v->Name() << "' attrs={K=" << K + << ", Nq=" << Nq << ", Nkv=" << Nkv << ", bits=" << bits + << ", block_size=" << block_size << ", accuracy_level=" << accuracy_level + << ", epsilon=" << epsilon << ", skip_sln=" << is_skip_sln << "}"; + + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("K", K), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("Nq", Nq), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("Nkv", Nkv), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", bits), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), 
attrs); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("epsilon", epsilon), attrs); + + NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr); + + InlinedVector fused_inputs{ + const_cast(node.InputDefs()[0]), + is_skip_sln ? const_cast(node.InputDefs()[1]) : &empty_arg, + const_cast(node.InputDefs()[is_skip_sln ? 2 : 1]), + const_cast(qkv_nodes->q->InputDefs()[1]), + const_cast(qkv_nodes->q->InputDefs()[2]), + const_cast(qkv_nodes->k->InputDefs()[1]), + const_cast(qkv_nodes->k->InputDefs()[2]), + const_cast(qkv_nodes->v->InputDefs()[1]), + const_cast(qkv_nodes->v->InputDefs()[2]), + }; + + InlinedVector fused_outputs{ + const_cast(qkv_nodes->q->OutputDefs()[0]), + const_cast(qkv_nodes->k->OutputDefs()[0]), + const_cast(qkv_nodes->v->OutputDefs()[0]), + }; + if (is_skip_sln && HasProducedOutput(node, 3)) { + fused_outputs.push_back(const_cast(node.OutputDefs()[3])); + } + + const bool has_residual_output = is_skip_sln && HasProducedOutput(node, 3); + const std::string norm_name = node.Name(); + const std::string q_name = qkv_nodes->q->Name(); + const std::string k_name = qkv_nodes->k->Name(); + const std::string v_name = qkv_nodes->v->Name(); + + const auto norm_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(node); + const auto q_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*qkv_nodes->q); + const auto k_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*qkv_nodes->k); + const auto v_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(*qkv_nodes->v); + const auto q_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(*qkv_nodes->q); + const auto k_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(*qkv_nodes->k); + const auto v_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(*qkv_nodes->v); + const auto norm_output_edges = has_residual_output + ? 
graph_utils::GraphEdge::GetNodeOutputEdges(node) + : std::vector{}; + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*qkv_nodes->q)); + graph.RemoveNode(qkv_nodes->q->Index()); + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*qkv_nodes->k)); + graph.RemoveNode(qkv_nodes->k->Index()); + graph_utils::RemoveNodeOutputEdges(graph, const_cast(*qkv_nodes->v)); + graph.RemoveNode(qkv_nodes->v->Index()); + graph_utils::RemoveNodeOutputEdges(graph, node); + graph.RemoveNode(node.Index()); + + Node& fused_node = graph.AddNode(graph.GenerateNodeName("MatMulNBitsQkv"), + "MatMulNBitsQkv", + "fused SimplifiedLayerNormalization with Q/K/V MatMulNBits projections", + fused_inputs, + fused_outputs, + &attrs, + kMSDomain); + fused_node.SetExecutionProviderType(kWebGpuExecutionProvider); + + LOGS(logger, VERBOSE) << "MatMulNBitsQkvFusion: created fused node '" << fused_node.Name() + << "' from norm='" << norm_name << "' q='" << q_name + << "' k='" << k_name << "' v='" << v_name << "'"; + + for (const auto& input_edge : norm_input_edges) { + int fused_input_index = input_edge.dst_arg_index; + if (!is_skip_sln && input_edge.dst_arg_index == 1) { + fused_input_index = 2; + } + + graph.AddEdge(input_edge.src_node, fused_node.Index(), input_edge.src_arg_index, fused_input_index); + } + + // Q/K/V weight + scale tensors are usually initializers, but if any of them is produced + // by an upstream node we must rewire that producer edge to the fused input slot. 
+ auto add_input_edge_if_present = [&](const std::vector& edges, + int source_input_index, + int fused_input_index) { + for (const auto& input_edge : edges) { + if (input_edge.dst_arg_index == source_input_index) { + graph.AddEdge(input_edge.src_node, fused_node.Index(), input_edge.src_arg_index, fused_input_index); + } + } + }; + + add_input_edge_if_present(q_input_edges, 1, 3); // q_weight + add_input_edge_if_present(q_input_edges, 2, 4); // q_scales + add_input_edge_if_present(k_input_edges, 1, 5); // k_weight + add_input_edge_if_present(k_input_edges, 2, 6); // k_scales + add_input_edge_if_present(v_input_edges, 1, 7); // v_weight + add_input_edge_if_present(v_input_edges, 2, 8); // v_scales + + for (const auto& output_edge : q_output_edges) { + graph.AddEdge(fused_node.Index(), output_edge.dst_node, 0, output_edge.dst_arg_index); + } + for (const auto& output_edge : k_output_edges) { + graph.AddEdge(fused_node.Index(), output_edge.dst_node, 1, output_edge.dst_arg_index); + } + for (const auto& output_edge : v_output_edges) { + graph.AddEdge(fused_node.Index(), output_edge.dst_node, 2, output_edge.dst_arg_index); + } + if (has_residual_output) { + for (const auto& output_edge : norm_output_edges) { + if (output_edge.src_arg_index == 3) { + graph.AddEdge(fused_node.Index(), output_edge.dst_node, 3, output_edge.dst_arg_index); + } + } + } + + modified = true; + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.h b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.h new file mode 100644 index 0000000000000..fcbbb78457f52 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +// Fuses three sibling MatMulNBits Q/K/V projections, together with the SimplifiedLayerNormalization +// (or SkipSimplifiedLayerNormalization) anchor that feeds them, into a single MatMulNBitsQkv contrib op: +// +// ... -> [Skip]SimplifiedLayerNormalization -+-> MatMulNBits (Q proj) -+ +// | +-> MatMulNBits (K proj) -+--> downstream consumers +// | +-> MatMulNBits (V proj) -+ +// +--> (optional) skip residual passthrough --> downstream consumers +// +// becomes +// +// ... -> MatMulNBitsQkv (normalization folded in) -+-> Q out +// +-> K out +// +-> V out +// +-> (optional) residual passthrough +// +// The fusion is restricted to the WebGPU EP because MatMulNBitsQkv is a WebGPU-only contrib op. +class MatMulNBitsQkvFusion : public GraphTransformer { + public: + explicit MatMulNBitsQkvFusion( + const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("MatMulNBitsQkvFusion", compatible_execution_providers) {} + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc index 7d4ae8c2197ff..3afedea30adaf 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -162,8 +162,6 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions()); const uint32_t norm_count = onnxruntime::narrow(x_shape.SizeToDimension(axis)); const int64_t norm_size = x_shape.SizeFromDimension(axis); - const int components = GetMaxComponents(norm_size); - const uint32_t norm_size_vectorized = onnxruntime::narrow((norm_size + components - 1) / components); const auto scale_size = scale->Shape().Size(); const auto bias_size = (bias) 
? bias->Shape().Size() : 0; @@ -192,6 +190,28 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex return Status::OK(); } + return RunLayerNormProgram(context, x, scale, bias, epsilon_, norm_count, norm_size, + simplified, y, mean, inv_std_dev); +} + +Status RunLayerNormProgram(ComputeContext& context, + const Tensor* x, + const Tensor* scale, + const Tensor* bias, + float epsilon, + uint32_t norm_count, + int64_t norm_size, + bool simplified, + Tensor* y, + Tensor* mean, + Tensor* inv_std_dev) { + if (x->Shape().Size() == 0) { + return Status::OK(); + } + + const int components = GetMaxComponents(norm_size); + const uint32_t norm_size_vectorized = onnxruntime::narrow((norm_size + components - 1) / components); + // Check if we should use split norm dimension optimization const bool split_norm_dim = norm_size % 512 == 0 && norm_count == 1; @@ -215,7 +235,7 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex {static_cast(norm_size_vectorized)}, }) .AddUniformVariables({ - {static_cast(epsilon_)}, + {static_cast(epsilon)}, }); if (split_norm_dim) { diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.h b/onnxruntime/core/providers/webgpu/nn/layer_norm.h index 112b152d37130..a6323dc7721d4 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.h +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.h @@ -56,5 +56,21 @@ class LayerNorm final : public WebGpuKernel { int64_t stash_type_; }; +// Configures and dispatches a LayerNormProgram. Centralizes the program-setup logic +// (uniform variables, components, split_norm_dim heuristic, workgroup sizing) so callers +// other than the LayerNorm kernel (e.g. fused MatMulNBits ops) do not need to duplicate it. +// `bias`, `mean` and `inv_std_dev` may be nullptr. 
+Status RunLayerNormProgram(ComputeContext& context, + const Tensor* x, + const Tensor* scale, + const Tensor* bias, + float epsilon, + uint32_t norm_count, + int64_t norm_size, + bool simplified, + Tensor* y, + Tensor* mean, + Tensor* inv_std_dev); + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_utils_test.cc b/onnxruntime/test/optimizer/graph_transform_utils_test.cc index f7bfa3055f96d..302768b9fbdc7 100644 --- a/onnxruntime/test/optimizer/graph_transform_utils_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_utils_test.cc @@ -70,6 +70,7 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); #endif } + TEST(GraphTransformerUtilsTests, TestDQMatMulNBitsFusionConfigWithContribGating) { SessionOptions session_options; const auto status = session_options.config_options.AddConfigEntry( diff --git a/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc b/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc new file mode 100644 index 0000000000000..038876b6c5777 --- /dev/null +++ b/onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc @@ -0,0 +1,550 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/graph/graph_utils.h" +#include "core/graph/node_attr_utils.h" +#include "core/optimizer/graph_transformer_mgr.h" +#include "core/optimizer/matmul_nbits_mlp_fusion.h" +#include "core/optimizer/utils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +#include "test/util/include/asserts.h" +#include "test/util/include/default_providers.h" +#include "test/unittest_util/framework_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" +#include "test/optimizer/graph_transform_test_fixture.h" +#include "test/optimizer/webgpu_fusion_test_util.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +#if !defined(DISABLE_CONTRIB_OPS) + +namespace { + +constexpr const char* kExpectedActivation = "silu"; + +enum class NormAnchorKind { + kSimplified, + kSkipSimplified, +}; + +enum class SkipOutputKind { + kNone, + kGraphOutput, +}; + +enum class BiasKind { + kWithBias, + kNoBias, +}; + +// Selects the gate-activation subgraph emitted by the test builder. +// kSilu : gate -> Sigmoid -+ +// gate ------------+-> Mul -> final_mul +// kQuickGelu : gate -> com.microsoft::QuickGelu(alpha=1.0) -> final_mul +// (the shape QuickGeluFusion produces after PR #28410.) 
+enum class ActivationShape { + kSilu, + kQuickGelu, +}; + +void SetWebGpuProvider(Node& node) { + node.SetExecutionProviderType(kWebGpuExecutionProvider); +} + +NodeAttributes MakeMatMulNBitsAttrs(int64_t k, int64_t n, int64_t block_size, int64_t bits, int64_t accuracy_level) { + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("K", k), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("N", n), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", bits), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level), attrs); + return attrs; +} + +Status CheckMatMulNBitsMlpFusedGraphImpl(const Graph& graph, NormAnchorKind norm_anchor_kind) { + const auto op_to_count = CountOpsInGraph(graph); + if (OpCount(op_to_count, "com.microsoft.MatMulNBitsMlp") != 1 || + OpCount(op_to_count, "com.microsoft.MatMulNBits") != 0 || + OpCount(op_to_count, "SimplifiedLayerNormalization") != 0 || + OpCount(op_to_count, "com.microsoft.SkipSimplifiedLayerNormalization") != 0 || + OpCount(op_to_count, "Sigmoid") != 0 || + OpCount(op_to_count, "com.microsoft.QuickGelu") != 0 || + OpCount(op_to_count, "Mul") != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unexpected operator counts after MatMulNBitsMlpFusion."); + } + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBitsMlp") { + ORT_RETURN_IF_NOT(node.Domain() == kMSDomain, "Fused node must be in com.microsoft domain."); + ORT_RETURN_IF_NOT(node.GetExecutionProviderType() == kWebGpuExecutionProvider, + "Fused node must be assigned to WebGPU EP."); + ORT_RETURN_IF_NOT(node.InputDefs().size() == 9u, "Fused node must have 9 inputs."); + const bool has_skip = node.InputDefs()[1] != nullptr && !node.InputDefs()[1]->Name().empty(); + const bool has_norm_scale = node.InputDefs()[2] != nullptr && !node.InputDefs()[2]->Name().empty(); + ORT_RETURN_IF_NOT(has_skip == 
(norm_anchor_kind == NormAnchorKind::kSkipSimplified), + "Unexpected skip input presence on fused node."); + ORT_RETURN_IF_NOT(has_norm_scale, + "Expected norm_scale input on fused node."); + ORT_RETURN_IF_NOT(node.OutputDefs().size() == 1u, + "Non-passthrough fusion should expose only the Y output."); + + const auto* activation_attr = graph_utils::GetNodeAttribute(node, "activation"); + ORT_RETURN_IF_NOT(activation_attr != nullptr && activation_attr->s() == kExpectedActivation, + "Fused node must carry activation='silu'."); + } + } + + return Status::OK(); +} + +Status CheckMatMulNBitsMlpSimplifiedFusedGraph(const Graph& graph) { + return CheckMatMulNBitsMlpFusedGraphImpl(graph, NormAnchorKind::kSimplified); +} + +Status CheckMatMulNBitsMlpSkipFusedGraph(const Graph& graph) { + return CheckMatMulNBitsMlpFusedGraphImpl(graph, NormAnchorKind::kSkipSimplified); +} + +Status CheckMatMulNBitsMlpSkipOutputPassthroughFusedGraph(const Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + if (OpCount(op_to_count, "com.microsoft.MatMulNBitsMlp") != 1 || + OpCount(op_to_count, "com.microsoft.MatMulNBits") != 0 || + OpCount(op_to_count, "SimplifiedLayerNormalization") != 0 || + OpCount(op_to_count, "com.microsoft.SkipSimplifiedLayerNormalization") != 0 || + OpCount(op_to_count, "Sigmoid") != 0 || + OpCount(op_to_count, "com.microsoft.QuickGelu") != 0 || + OpCount(op_to_count, "Mul") != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Unexpected operator counts after MatMulNBitsMlpFusion with skip output passthrough."); + } + + bool found_fused_node = false; + for (const auto& node : graph.Nodes()) { + if (node.OpType() != "MatMulNBitsMlp") { + continue; + } + + found_fused_node = true; + ORT_RETURN_IF_NOT(node.Domain() == kMSDomain, "Fused node must be in com.microsoft domain."); + ORT_RETURN_IF_NOT(node.GetExecutionProviderType() == kWebGpuExecutionProvider, + "Fused node must be assigned to WebGPU EP."); + ORT_RETURN_IF_NOT(node.InputDefs().size() == 9u, 
"Fused node must have 9 inputs."); + ORT_RETURN_IF_NOT(node.OutputDefs().size() == 2u, + "Fused node must expose Y and the passthrough residual output."); + const bool has_skip = node.InputDefs()[1] != nullptr && !node.InputDefs()[1]->Name().empty(); + const bool has_norm_scale = node.InputDefs()[2] != nullptr && !node.InputDefs()[2]->Name().empty(); + ORT_RETURN_IF_NOT(has_skip && has_norm_scale, + "Skip output passthrough should remain fused into MatMulNBitsMlp."); + ORT_RETURN_IF_NOT(node.OutputDefs()[1] != nullptr && !node.OutputDefs()[1]->Name().empty(), + "Expected fused node to preserve the residual passthrough output."); + + const auto* activation_attr = graph_utils::GetNodeAttribute(node, "activation"); + ORT_RETURN_IF_NOT(activation_attr != nullptr && activation_attr->s() == kExpectedActivation, + "Fused node must carry activation='silu'."); + } + + ORT_RETURN_IF_NOT(found_fused_node, "Expected a MatMulNBitsMlp node in the transformed graph."); + return Status::OK(); +} + +void BuildMatMulNBitsMlpWebGpuPatternImpl(ModelTestBuilder& builder, + NormAnchorKind norm_anchor_kind, + SkipOutputKind skip_output_kind = SkipOutputKind::kNone, + BiasKind bias_kind = BiasKind::kWithBias, + ActivationShape activation_shape = ActivationShape::kSilu) { + constexpr int64_t k = 32; + constexpr int64_t n = 8; + constexpr int64_t block_size = 32; + constexpr int64_t bits = 4; + constexpr int64_t accuracy_level = 4; + constexpr int64_t blob_size = block_size * bits / 8; + + NodeArg* input = builder.MakeInput( + std::vector{1, k}, + std::vector{ + MLFloat16(-1.0f), MLFloat16(-0.875f), MLFloat16(-0.75f), MLFloat16(-0.625f), + MLFloat16(-0.5f), MLFloat16(-0.375f), MLFloat16(-0.25f), MLFloat16(-0.125f), + MLFloat16(0.125f), MLFloat16(0.25f), MLFloat16(0.375f), MLFloat16(0.5f), + MLFloat16(0.625f), MLFloat16(0.75f), MLFloat16(0.875f), MLFloat16(1.0f), + MLFloat16(-1.0f), MLFloat16(-0.875f), MLFloat16(-0.75f), MLFloat16(-0.625f), + MLFloat16(-0.5f), MLFloat16(-0.375f), 
MLFloat16(-0.25f), MLFloat16(-0.125f), + MLFloat16(0.125f), MLFloat16(0.25f), MLFloat16(0.375f), MLFloat16(0.5f), + MLFloat16(0.625f), MLFloat16(0.75f), MLFloat16(0.875f), MLFloat16(1.0f)}); + NodeArg* optional_tensor = builder.MakeOptionalTensor(); + + NodeArg* gate_weight = builder.MakeInitializer({n, 1, blob_size}, uint8_t{0}, uint8_t{15}); + NodeArg* gate_scale = builder.MakeInitializer({n, 1}, MLFloat16(1.0f), MLFloat16(1.0f)); + NodeArg* gate_bias = (bias_kind == BiasKind::kWithBias) + ? builder.MakeInitializer({n}, MLFloat16(0.0f), MLFloat16(0.0f)) + : optional_tensor; + NodeArg* up_weight = builder.MakeInitializer({n, 1, blob_size}, uint8_t{0}, uint8_t{15}); + NodeArg* up_scale = builder.MakeInitializer({n, 1}, MLFloat16(1.0f), MLFloat16(1.0f)); + NodeArg* up_bias = (bias_kind == BiasKind::kWithBias) + ? builder.MakeInitializer({n}, MLFloat16(0.0f), MLFloat16(0.0f)) + : optional_tensor; + + NodeArg* normalized_input = builder.MakeIntermediate(std::vector{1, k}); + NodeArg* gate_out = builder.MakeIntermediate(std::vector{1, n}); + NodeArg* up_out = builder.MakeIntermediate(std::vector{1, n}); + NodeArg* activated_out = builder.MakeIntermediate(std::vector{1, n}); + NodeArg* output = builder.MakeOutput(std::vector{1, n}); + + NodeAttributes matmul_attrs = MakeMatMulNBitsAttrs(k, n, block_size, bits, accuracy_level); + Node* norm = nullptr; + if (norm_anchor_kind == NormAnchorKind::kSkipSimplified) { + NodeArg* skip_input = builder.MakeInput( + std::vector{1, k}, + std::vector(static_cast(k), MLFloat16(0.25f))); + NodeArg* norm_scale = builder.MakeInitializer({k}, MLFloat16(1.0f), MLFloat16(1.0f)); + NodeArg* optional_norm_output_1 = builder.MakeOptionalTensor(); + NodeArg* optional_norm_output_2 = builder.MakeOptionalTensor(); + std::vector norm_outputs{normalized_input}; + if (skip_output_kind == SkipOutputKind::kGraphOutput) { + NodeArg* residual_output = builder.MakeOutput(std::vector{1, k}); + norm_outputs.push_back(optional_norm_output_1); + 
norm_outputs.push_back(optional_norm_output_2); + norm_outputs.push_back(residual_output); + } + norm = &builder.AddNode("SkipSimplifiedLayerNormalization", {input, skip_input, norm_scale}, norm_outputs, + kMSDomain); + } else { + NodeArg* norm_scale = builder.MakeInitializer({k}, MLFloat16(1.0f), MLFloat16(1.0f)); + norm = &builder.AddNode("SimplifiedLayerNormalization", {input, norm_scale}, {normalized_input}); + } + + Node& gate_matmul = builder.AddNode("MatMulNBits", + {normalized_input, gate_weight, gate_scale, optional_tensor, optional_tensor, + gate_bias}, + {gate_out}, kMSDomain, &matmul_attrs); + Node& up_matmul = builder.AddNode("MatMulNBits", + {normalized_input, up_weight, up_scale, optional_tensor, optional_tensor, + up_bias}, + {up_out}, kMSDomain, &matmul_attrs); + + Node* sigmoid = nullptr; + Node* silu_mul = nullptr; + Node* quick_gelu = nullptr; + if (activation_shape == ActivationShape::kSilu) { + NodeArg* sigmoid_out = builder.MakeIntermediate(std::vector{1, n}); + sigmoid = &builder.AddNode("Sigmoid", {gate_out}, {sigmoid_out}); + silu_mul = &builder.AddNode("Mul", {gate_out, sigmoid_out}, {activated_out}); + } else { + NodeAttributes quick_gelu_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("alpha", 1.0f), quick_gelu_attrs); + quick_gelu = &builder.AddNode("QuickGelu", {gate_out}, {activated_out}, kMSDomain, &quick_gelu_attrs); + } + Node& final_mul = builder.AddNode("Mul", {activated_out, up_out}, {output}); + + if (norm != nullptr) { + SetWebGpuProvider(*norm); + } + SetWebGpuProvider(gate_matmul); + SetWebGpuProvider(up_matmul); + if (sigmoid != nullptr) { + SetWebGpuProvider(*sigmoid); + } + if (silu_mul != nullptr) { + SetWebGpuProvider(*silu_mul); + } + if (quick_gelu != nullptr) { + SetWebGpuProvider(*quick_gelu); + } + SetWebGpuProvider(final_mul); +} + +void BuildMatMulNBitsMlpSimplifiedWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsMlpWebGpuPatternImpl(builder, NormAnchorKind::kSimplified); +} + +void 
BuildMatMulNBitsMlpSkipWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsMlpWebGpuPatternImpl(builder, NormAnchorKind::kSkipSimplified); +} + +void BuildMatMulNBitsMlpSkipOutputPassthroughWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsMlpWebGpuPatternImpl(builder, NormAnchorKind::kSkipSimplified, SkipOutputKind::kGraphOutput); +} + +void BuildMatMulNBitsMlpSimplifiedWebGpuPatternNoBias(ModelTestBuilder& builder) { + BuildMatMulNBitsMlpWebGpuPatternImpl(builder, NormAnchorKind::kSimplified, SkipOutputKind::kNone, + BiasKind::kNoBias); +} + +void BuildMatMulNBitsMlpSkipWebGpuPatternNoBias(ModelTestBuilder& builder) { + BuildMatMulNBitsMlpWebGpuPatternImpl(builder, NormAnchorKind::kSkipSimplified, SkipOutputKind::kNone, + BiasKind::kNoBias); +} + +void BuildMatMulNBitsMlpSkipOutputPassthroughWebGpuPatternNoBias(ModelTestBuilder& builder) { + BuildMatMulNBitsMlpWebGpuPatternImpl(builder, NormAnchorKind::kSkipSimplified, SkipOutputKind::kGraphOutput, + BiasKind::kNoBias); +} + +void BuildMatMulNBitsMlpSimplifiedQuickGeluWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsMlpWebGpuPatternImpl(builder, NormAnchorKind::kSimplified, SkipOutputKind::kNone, + BiasKind::kWithBias, ActivationShape::kQuickGelu); +} + +void BuildMatMulNBitsMlpSkipQuickGeluWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsMlpWebGpuPatternImpl(builder, NormAnchorKind::kSkipSimplified, SkipOutputKind::kNone, + BiasKind::kWithBias, ActivationShape::kQuickGelu); +} + +} // namespace + +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionFusesSimplifiedWebGpuPattern) { + ASSERT_STATUS_OK(TestGraphTransformer( + BuildMatMulNBitsMlpSimplifiedWebGpuPattern, + 21, + *logger_, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + TransformerLevel::Level2, + 1, + nullptr, + CheckMatMulNBitsMlpSimplifiedFusedGraph)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionFusesSkipWebGpuPattern) { + ASSERT_STATUS_OK(TestGraphTransformer( + 
BuildMatMulNBitsMlpSkipWebGpuPattern, + 21, + *logger_, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + TransformerLevel::Level2, + 1, + nullptr, + CheckMatMulNBitsMlpSkipFusedGraph)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionFusesSkipWebGpuPatternWithResidualOutputPassthrough) { + ASSERT_STATUS_OK(TestGraphTransformer( + BuildMatMulNBitsMlpSkipOutputPassthroughWebGpuPattern, + 21, + *logger_, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + TransformerLevel::Level2, + 1, + nullptr, + CheckMatMulNBitsMlpSkipOutputPassthroughFusedGraph)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSimplifiedWebGpuResults) { + if (!DefaultWebGpuExecutionProvider()) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsMlpSimplifiedFusedGraph(session.GetGraph())); + }; + + RunWebGpuFusionTransformerTest( + BuildMatMulNBitsMlpSimplifiedWebGpuPattern, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + []() { return DefaultWebGpuExecutionProvider(); }); +} + +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuResults) { + if (!DefaultWebGpuExecutionProvider()) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsMlpSkipFusedGraph(session.GetGraph())); + }; + + RunWebGpuFusionTransformerTest( + BuildMatMulNBitsMlpSkipWebGpuPattern, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + []() { return DefaultWebGpuExecutionProvider(); }); +} + +TEST_F(GraphTransformationTests, 
MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuResultsWithResidualOutputPassthrough) { + if (!DefaultWebGpuExecutionProvider()) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto add_session_options = [](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableSpecifiedOptimizers, + "EliminateIdentity")); + }; + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsMlpSkipOutputPassthroughFusedGraph(session.GetGraph())); + }; + + RunWebGpuFusionTransformerTest( + BuildMatMulNBitsMlpSkipOutputPassthroughWebGpuPattern, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + []() { return DefaultWebGpuExecutionProvider(); }, + add_session_options); +} + +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSimplifiedWebGpuResultsNoBias) { + if (!DefaultWebGpuExecutionProvider()) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsMlpSimplifiedFusedGraph(session.GetGraph())); + }; + + RunWebGpuFusionTransformerTest( + BuildMatMulNBitsMlpSimplifiedWebGpuPatternNoBias, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + []() { return DefaultWebGpuExecutionProvider(); }); +} + +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuResultsNoBias) { + if (!DefaultWebGpuExecutionProvider()) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsMlpSkipFusedGraph(session.GetGraph())); + }; + + RunWebGpuFusionTransformerTest( + 
BuildMatMulNBitsMlpSkipWebGpuPatternNoBias, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + []() { return DefaultWebGpuExecutionProvider(); }); +} + +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipWebGpuResultsWithResidualOutputPassthroughNoBias) { + if (!DefaultWebGpuExecutionProvider()) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto add_session_options = [](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableSpecifiedOptimizers, + "EliminateIdentity")); + }; + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsMlpSkipOutputPassthroughFusedGraph(session.GetGraph())); + }; + + RunWebGpuFusionTransformerTest( + BuildMatMulNBitsMlpSkipOutputPassthroughWebGpuPatternNoBias, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + []() { return DefaultWebGpuExecutionProvider(); }, + add_session_options); +} + +// QuickGelu-shape tests: after PR #28410, QuickGeluFusion collapses the +// Sigmoid+Mul subgraph in SwiGLU MLPs into a single com.microsoft::QuickGelu +// node (with alpha=1.0). MatMulNBitsMlpFusion must still recognize this shape +// so the fused MLP kernel keeps firing on WebGPU models. 
+TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionFusesSimplifiedQuickGeluWebGpuPattern) { + ASSERT_STATUS_OK(TestGraphTransformer( + BuildMatMulNBitsMlpSimplifiedQuickGeluWebGpuPattern, + 21, + *logger_, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + TransformerLevel::Level2, + 1, + nullptr, + CheckMatMulNBitsMlpSimplifiedFusedGraph)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionFusesSkipQuickGeluWebGpuPattern) { + ASSERT_STATUS_OK(TestGraphTransformer( + BuildMatMulNBitsMlpSkipQuickGeluWebGpuPattern, + 21, + *logger_, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + TransformerLevel::Level2, + 1, + nullptr, + CheckMatMulNBitsMlpSkipFusedGraph)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSimplifiedQuickGeluWebGpuResults) { + if (!DefaultWebGpuExecutionProvider()) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsMlpSimplifiedFusedGraph(session.GetGraph())); + }; + + // The unfused baseline runs the WebGPU `QuickGelu` kernel (a branchy + // sigmoid-by-cases implementation), while the fused kernel evaluates SiLU + // directly via `1 / (1 + exp(-x))`. The two decompositions are + // mathematically equivalent but produce slightly different fp16 rounding + // around the SiLU midpoint, so we use a marginally looser tolerance here. 
+ RunWebGpuFusionTransformerTest( + BuildMatMulNBitsMlpSimplifiedQuickGeluWebGpuPattern, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1.5e-2, + 1.5e-2, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + []() { return DefaultWebGpuExecutionProvider(); }); +} + +TEST_F(GraphTransformationTests, MatMulNBitsMlpFusionMatchesUnfusedSkipQuickGeluWebGpuResults) { + if (!DefaultWebGpuExecutionProvider()) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsMlpSkipFusedGraph(session.GetGraph())); + }; + + RunWebGpuFusionTransformerTest( + BuildMatMulNBitsMlpSkipQuickGeluWebGpuPattern, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 5e-3, + 5e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + []() { return DefaultWebGpuExecutionProvider(); }); +} + +#endif // !defined(DISABLE_CONTRIB_OPS) + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/matmul_nbits_qkv_fusion_test.cc b/onnxruntime/test/optimizer/matmul_nbits_qkv_fusion_test.cc new file mode 100644 index 0000000000000..754c23155bf47 --- /dev/null +++ b/onnxruntime/test/optimizer/matmul_nbits_qkv_fusion_test.cc @@ -0,0 +1,287 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/graph/node_attr_utils.h" +#include "core/optimizer/graph_transformer_mgr.h" +#include "core/optimizer/matmul_nbits_qkv_fusion.h" +#include "core/optimizer/utils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +#include "test/util/include/asserts.h" +#include "test/util/include/default_providers.h" +#include "test/unittest_util/framework_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" +#include "test/optimizer/graph_transform_test_fixture.h" +#include "test/optimizer/webgpu_fusion_test_util.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +#if !defined(DISABLE_CONTRIB_OPS) + +namespace { + +void SetWebGpuProvider(Node& node) { + node.SetExecutionProviderType(kWebGpuExecutionProvider); +} + +NodeAttributes MakeMatMulNBitsAttrs(int64_t k, int64_t n, int64_t block_size, int64_t bits, int64_t accuracy_level) { + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("K", k), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("N", n), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("bits", bits), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level), attrs); + return attrs; +} + +Status CheckMatMulNBitsQkvFusedGraphImpl(const Graph& graph, bool expect_skip_sln_output, bool expect_skip_input) { + const auto op_to_count = CountOpsInGraph(graph); + if (OpCount(op_to_count, "com.microsoft.MatMulNBitsQkv") != 1 || + OpCount(op_to_count, "SimplifiedLayerNormalization") != 0 || + OpCount(op_to_count, "com.microsoft.SkipSimplifiedLayerNormalization") != 0 || + OpCount(op_to_count, "com.microsoft.MatMulNBits") != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Unexpected operator counts after MatMulNBitsQkvFusion."); + } + + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBitsQkv") { + 
ORT_RETURN_IF_NOT(node.Domain() == kMSDomain, "Fused node must be in com.microsoft domain."); + ORT_RETURN_IF_NOT(node.GetExecutionProviderType() == kWebGpuExecutionProvider, + "Fused node must be assigned to WebGPU EP."); + ORT_RETURN_IF_NOT(node.InputDefs().size() == 9, "Fused node must expose the 9-input contract."); + ORT_RETURN_IF_NOT(node.OutputDefs().size() == (expect_skip_sln_output ? 4u : 3u), + "Fused node outputs did not match the expected simplified vs skip-simplified contract."); + // skip is at input index 1; for the SkipSimplifiedLayerNormalization-anchored pattern it + // must be wired to a real NodeArg, otherwise it must be the empty optional. + const auto* skip_def = node.InputDefs()[1]; + const bool skip_present = skip_def != nullptr && skip_def->Exists(); + ORT_RETURN_IF_NOT(skip_present == expect_skip_input, + "Fused node skip-input presence did not match the expected pattern variant."); + } + } + + return Status::OK(); +} + +Status CheckMatMulNBitsQkvFusedGraph(Graph& graph) { + return CheckMatMulNBitsQkvFusedGraphImpl(static_cast(graph), + /*expect_skip_sln_output=*/false, + /*expect_skip_input=*/false); +} + +Status CheckMatMulNBitsQkvSkipFusedGraph(Graph& graph) { + return CheckMatMulNBitsQkvFusedGraphImpl(static_cast(graph), + /*expect_skip_sln_output=*/false, + /*expect_skip_input=*/true); +} + +Status CheckMatMulNBitsQkvSkipOutputPassthroughFusedGraph(Graph& graph) { + return CheckMatMulNBitsQkvFusedGraphImpl(static_cast(graph), + /*expect_skip_sln_output=*/true, + /*expect_skip_input=*/true); +} + +void BuildMatMulNBitsQkvWebGpuPatternImpl(ModelTestBuilder& builder, bool with_skip_input, bool with_skip_output) { + constexpr int64_t k = 16; + constexpr int64_t q_n = 8; + constexpr int64_t kv_n = 4; + constexpr int64_t block_size = 32; + constexpr int64_t bits = 4; + constexpr int64_t accuracy_level = 4; + constexpr int64_t blob_size = block_size * bits / 8; + + NodeArg* input = builder.MakeInput( + std::vector{1, k}, + std::vector{ + 
MLFloat16(-1.0f), MLFloat16(-0.875f), MLFloat16(-0.75f), MLFloat16(-0.625f), + MLFloat16(-0.5f), MLFloat16(-0.375f), MLFloat16(-0.25f), MLFloat16(-0.125f), + MLFloat16(0.125f), MLFloat16(0.25f), MLFloat16(0.375f), MLFloat16(0.5f), + MLFloat16(0.625f), MLFloat16(0.75f), MLFloat16(0.875f), MLFloat16(1.0f)}); + NodeArg* skip_input = with_skip_input + ? builder.MakeInput( + std::vector{1, k}, + std::vector{ + MLFloat16(1.0f), MLFloat16(0.875f), MLFloat16(0.75f), MLFloat16(0.625f), + MLFloat16(0.5f), MLFloat16(0.375f), MLFloat16(0.25f), MLFloat16(0.125f), + MLFloat16(-0.125f), MLFloat16(-0.25f), MLFloat16(-0.375f), MLFloat16(-0.5f), + MLFloat16(-0.625f), MLFloat16(-0.75f), MLFloat16(-0.875f), MLFloat16(-1.0f)}) + : nullptr; + + NodeArg* norm_scale = builder.MakeInitializer({k}, MLFloat16(1.0f), MLFloat16(1.0f)); + NodeArg* q_weight = builder.MakeInitializer({q_n, 1, blob_size}, uint8_t{0}, uint8_t{15}); + NodeArg* q_scale = builder.MakeInitializer({q_n, 1}, MLFloat16(1.0f), MLFloat16(1.0f)); + NodeArg* k_weight = builder.MakeInitializer({kv_n, 1, blob_size}, uint8_t{0}, uint8_t{15}); + NodeArg* k_scale = builder.MakeInitializer({kv_n, 1}, MLFloat16(1.0f), MLFloat16(1.0f)); + NodeArg* v_weight = builder.MakeInitializer({kv_n, 1, blob_size}, uint8_t{0}, uint8_t{15}); + NodeArg* v_scale = builder.MakeInitializer({kv_n, 1}, MLFloat16(1.0f), MLFloat16(1.0f)); + NodeArg* optional_tensor = builder.MakeOptionalTensor(); + + NodeArg* norm_out = builder.MakeIntermediate(std::vector{1, k}); + NodeArg* optional_norm_output_1 = builder.MakeOptionalTensor(); + NodeArg* optional_norm_output_2 = builder.MakeOptionalTensor(); + NodeArg* residual_out = (with_skip_input && with_skip_output) ? 
builder.MakeIntermediate(std::vector{1, k}) : nullptr; + NodeArg* q_output = builder.MakeOutput(std::vector{1, q_n}); + NodeArg* k_output = builder.MakeOutput(std::vector{1, kv_n}); + NodeArg* v_output = builder.MakeOutput(std::vector{1, kv_n}); + NodeArg* residual_passthrough = (with_skip_input && with_skip_output) ? builder.MakeOutput(std::vector{1, k}) : nullptr; + + NodeAttributes q_attrs = MakeMatMulNBitsAttrs(k, q_n, block_size, bits, accuracy_level); + NodeAttributes kv_attrs = MakeMatMulNBitsAttrs(k, kv_n, block_size, bits, accuracy_level); + + Node& norm = with_skip_input + ? builder.AddNode("SkipSimplifiedLayerNormalization", + {input, skip_input, norm_scale}, + with_skip_output ? std::vector{norm_out, optional_norm_output_1, optional_norm_output_2, residual_out} + : std::vector{norm_out}, + kMSDomain) + : builder.AddNode("SimplifiedLayerNormalization", {input, norm_scale}, {norm_out}); + norm.AddAttribute("epsilon", 1e-6f); + + Node& q_matmul = builder.AddNode("MatMulNBits", {norm_out, q_weight, q_scale, optional_tensor, optional_tensor, optional_tensor}, {q_output}, kMSDomain, &q_attrs); + Node& k_matmul = builder.AddNode("MatMulNBits", {norm_out, k_weight, k_scale, optional_tensor, optional_tensor, optional_tensor}, {k_output}, kMSDomain, &kv_attrs); + Node& v_matmul = builder.AddNode("MatMulNBits", {norm_out, v_weight, v_scale, optional_tensor, optional_tensor, optional_tensor}, {v_output}, kMSDomain, &kv_attrs); + + SetWebGpuProvider(norm); + SetWebGpuProvider(q_matmul); + SetWebGpuProvider(k_matmul); + SetWebGpuProvider(v_matmul); + + if (with_skip_output) { + Node& residual_identity = builder.AddNode("Identity", {residual_out}, {residual_passthrough}); + SetWebGpuProvider(residual_identity); + } +} + +void BuildMatMulNBitsQkvWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsQkvWebGpuPatternImpl(builder, false, false); +} + +void BuildMatMulNBitsQkvSkipWebGpuPattern(ModelTestBuilder& builder) { + 
BuildMatMulNBitsQkvWebGpuPatternImpl(builder, true, false); +} + +void BuildMatMulNBitsQkvSkipOutputPassthroughWebGpuPattern(ModelTestBuilder& builder) { + BuildMatMulNBitsQkvWebGpuPatternImpl(builder, true, true); +} + +} // namespace + +TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionFusesWebGpuPattern) { + ASSERT_STATUS_OK(TestGraphTransformer( + BuildMatMulNBitsQkvWebGpuPattern, + 21, + *logger_, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + TransformerLevel::Level2, + 1, + nullptr, + CheckMatMulNBitsQkvFusedGraph)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedWebGpuResults) { + if (!DefaultWebGpuExecutionProvider()) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsQkvFusedGraphImpl(session.GetGraph(), + /*expect_skip_sln_output=*/false, + /*expect_skip_input=*/false)); + }; + + RunWebGpuFusionTransformerTest( + BuildMatMulNBitsQkvWebGpuPattern, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + []() { return DefaultWebGpuExecutionProvider(); }); +} + +TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionFusesSkipWebGpuPattern) { + ASSERT_STATUS_OK(TestGraphTransformer( + BuildMatMulNBitsQkvSkipWebGpuPattern, + 21, + *logger_, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + TransformerLevel::Level2, + 1, + nullptr, + CheckMatMulNBitsQkvSkipFusedGraph)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedSkipWebGpuResults) { + if (!DefaultWebGpuExecutionProvider()) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsQkvFusedGraphImpl(session.GetGraph(), + /*expect_skip_sln_output=*/false, + 
/*expect_skip_input=*/true)); + }; + + RunWebGpuFusionTransformerTest( + BuildMatMulNBitsQkvSkipWebGpuPattern, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + []() { return DefaultWebGpuExecutionProvider(); }); +} + +TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionFusesSkipWebGpuPatternWithResidualOutputPassthrough) { + ASSERT_STATUS_OK(TestGraphTransformer( + BuildMatMulNBitsQkvSkipOutputPassthroughWebGpuPattern, + 21, + *logger_, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + TransformerLevel::Level2, + 1, + nullptr, + CheckMatMulNBitsQkvSkipOutputPassthroughFusedGraph)); +} + +TEST_F(GraphTransformationTests, MatMulNBitsQkvFusionMatchesUnfusedSkipWebGpuResultsWithResidualOutputPassthrough) { + if (!DefaultWebGpuExecutionProvider()) { + GTEST_SKIP() << "WebGPU EP unavailable in this build."; + } + + auto add_session_options = [](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableSpecifiedOptimizers, + "EliminateIdentity")); + }; + + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + ASSERT_STATUS_OK(CheckMatMulNBitsQkvFusedGraphImpl(session.GetGraph(), + /*expect_skip_sln_output=*/true, + /*expect_skip_input=*/true)); + }; + + RunWebGpuFusionTransformerTest( + BuildMatMulNBitsQkvSkipOutputPassthroughWebGpuPattern, + check_transformed_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21, + 1e-3, + 1e-3, + std::make_unique(InlinedHashSet{kWebGpuExecutionProvider}), + []() { return DefaultWebGpuExecutionProvider(); }, + add_session_options); +} + +#endif // !defined(DISABLE_CONTRIB_OPS) + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/webgpu_fusion_test_util.h b/onnxruntime/test/optimizer/webgpu_fusion_test_util.h new file mode 100644 index 0000000000000..2fb3344bb9313 --- /dev/null +++ 
b/onnxruntime/test/optimizer/webgpu_fusion_test_util.h @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "core/framework/execution_provider.h" +#include "core/framework/session_options.h" +#include "core/graph/constants.h" +#include "core/graph/model.h" +#include "core/optimizer/graph_transformer.h" +#include "core/session/inference_session.h" +#include "test/compare_ortvalue.h" +#include "test/test_environment.h" +#include "test/unittest_util/graph_transform_test_builder.h" +#include "test/util/include/asserts.h" +#include "test/util/include/inference_session_wrapper.h" + +namespace onnxruntime { +namespace test { + +// Variant of TransformerTester for WebGPU fusion tests that creates a fresh execution provider +// per session via the provided factory, instead of sharing one EP across the baseline and target +// sessions. Sharing a single WebGPU EP across multiple InferenceSessions in series can leave the +// EP holding a dangling pointer to a destroyed session-level profiler; a separate fix to the EP +// addresses that, but using a fresh EP per session also avoids the issue and keeps the fusion PR +// independent of profiler-lifetime changes. 
+inline void RunWebGpuFusionTransformerTest( + const std::function& build_test_case, + const std::function& check_transformed_graph, + TransformerLevel baseline_level, + TransformerLevel target_level, + int opset_version, + double per_sample_tolerance, + double relative_per_sample_tolerance, + std::unique_ptr transformer, + const std::function()>& ep_factory, + const std::function& add_session_options = {}) { + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = opset_version; + domain_to_version[kMSDomain] = 1; + Model model("WebGpuFusionTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, {}, DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + ASSERT_TRUE(build_test_case); + build_test_case(helper); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + auto run_model = [&](TransformerLevel level, std::vector& fetches, + std::unique_ptr level_transformer) { + SessionOptions session_options; + session_options.graph_optimization_level = level_transformer ? 
baseline_level : level; + if (add_session_options) { + add_session_options(session_options); + } + + InferenceSessionWrapper session{session_options, GetEnvironment()}; + auto ep = ep_factory(); + ASSERT_TRUE(ep != nullptr) << "ep_factory() returned nullptr"; + ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(ep))); + + ASSERT_STATUS_OK(session.Load(model_data.data(), static_cast(model_data.size()))); + if (level_transformer) { + ASSERT_STATUS_OK(session.RegisterGraphTransformer(std::move(level_transformer), level)); + } + + ASSERT_STATUS_OK(session.Initialize()); + + RunOptions run_options; + ASSERT_STATUS_OK(session.Run(run_options, helper.feeds_, helper.output_names_, &fetches)); + + if (level == target_level && check_transformed_graph) { + check_transformed_graph(session); + } + }; + + std::vector baseline_fetches; + ASSERT_NO_FATAL_FAILURE(run_model(baseline_level, baseline_fetches, /*level_transformer=*/nullptr)); + + std::vector target_fetches; + ASSERT_NO_FATAL_FAILURE(run_model(target_level, target_fetches, std::move(transformer))); + + const size_t num_outputs = baseline_fetches.size(); + ASSERT_EQ(num_outputs, target_fetches.size()); + for (size_t i = 0; i < num_outputs; ++i) { + auto ret = CompareOrtValue(target_fetches[i], baseline_fetches[i], + per_sample_tolerance, relative_per_sample_tolerance, false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + } +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 19311c99c6f6f..382cf534193d1 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -318,10 +318,14 @@ std::unique_ptr DefaultWebGpuExecutionProvider(bool is_nhwc) std::unique_ptr WebGpuExecutionProviderWithOptions(const ConfigOptions& config_options) { #if defined(USE_WEBGPU) #if defined(ORT_USE_EP_API_ADAPTERS) + // Return nullptr (rather than throwing) when the 
dynamic plugin EP is either uninitialized + // or initialized as a different EP. Tests interpret nullptr as "WebGPU EP unavailable" and + // skip themselves, which matches the behavior of the non-plugin code path below when + // USE_WEBGPU is undefined. auto ep_name = dynamic_plugin_ep_infra::GetEpName(); - ORT_ENFORCE(ep_name == kWebGpuExecutionProvider, - "Dynamic plugin EP is not the WebGPU EP. Expected \"", kWebGpuExecutionProvider, - "\", got \"", ep_name.value_or(""), "\""); + if (ep_name != kWebGpuExecutionProvider) { + return nullptr; + } return dynamic_plugin_ep_infra::MakeEp(nullptr, &config_options); #else return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider();