32 commits
7c0c1b9
Initial commit
hariharans29 Apr 8, 2026
c42879d
Merge remote-tracking branch 'origin' into hari/webgpu_perf_1
hariharans29 Apr 8, 2026
a0550b6
More changes
hariharans29 Apr 9, 2026
c55adfe
Merge branch 'hari/webgpu_perf_1' of https://github.com/microsoft/onn…
hariharans29 Apr 9, 2026
ee09d8e
Stage
hariharans29 Apr 13, 2026
aa357ee
More changes
hariharans29 Apr 15, 2026
318b26b
Stage
hariharans29 Apr 20, 2026
ad53b3d
Work and good perf
hariharans29 Apr 22, 2026
b67ae81
Skip + MatmulNBitsSilu fusion - works and good perf
hariharans29 Apr 23, 2026
01671d9
Cleanup
hariharans29 Apr 30, 2026
30485dd
Move back to workgroup/tile_size default
hariharans29 Apr 30, 2026
27317b8
Merge main
hariharans29 Apr 30, 2026
a56fb56
Merge remote-tracking branch 'origin' into hari/webgpu_perf_1
hariharans29 May 1, 2026
13bf979
Copilot comments + Fix builds + Fix lint + Fusion diagrams
hariharans29 May 1, 2026
d1090c8
Fix test
hariharans29 May 1, 2026
ffacd4c
Fix builds
hariharans29 May 1, 2026
92874ce
Fixes
hariharans29 May 1, 2026
a7899c6
Slim PR: drop benchmark harness, lazy buffer-mgr fix, consteval fix, …
hariharans29 May 2, 2026
2039c7f
Remove unused dp4a_matmul_mlp.wgsl.template
hariharans29 May 2, 2026
a02cf12
Cleanup: drop unused empty namespace + env_var_utils include in graph…
hariharans29 May 2, 2026
beb1709
Merge remote-tracking branch 'origin' into hari/webgpu_perf_1
hariharans29 May 2, 2026
9065063
Copilot comments
hariharans29 May 2, 2026
4ac9c81
Fixes
hariharans29 May 2, 2026
306fba3
Fix
hariharans29 May 3, 2026
6c8c7a3
Use fresh WebGPU EP per session in fusion-vs-unfused tests
hariharans29 May 3, 2026
a90a049
Remove unused file
hariharans29 May 10, 2026
007a78e
[WebGPU] Extract shared LayerNorm/SkipLayerNorm program runners
hariharans29 May 11, 2026
37db5b8
[WebGPU] MatMulNBitsMlp: adopt shared norm helpers + activation enum
hariharans29 May 11, 2026
2c1a2a3
[WebGPU] MatMulNBitsMlpFusion: match fused-QuickGelu MLP shape
hariharans29 May 11, 2026
234bcf4
[WebGPU/JSEP] Enable QuickGeluFusion for WebGPU and JSEP EPs
hariharans29 May 11, 2026
eaa6635
Copilot comments
hariharans29 May 12, 2026
106c07e
Merge main and resolve conflicts
hariharans29 May 12, 2026
186 changes: 186 additions & 0 deletions docs/ContribOperators.md
@@ -57,6 +57,8 @@ Do not modify directly.*
* <a href="#com.microsoft.MatMulInteger16">com.microsoft.MatMulInteger16</a>
* <a href="#com.microsoft.MatMulIntegerToFloat">com.microsoft.MatMulIntegerToFloat</a>
* <a href="#com.microsoft.MatMulNBits">com.microsoft.MatMulNBits</a>
* <a href="#com.microsoft.MatMulNBitsMlp">com.microsoft.MatMulNBitsMlp</a>
* <a href="#com.microsoft.MatMulNBitsQkv">com.microsoft.MatMulNBitsQkv</a>
* <a href="#com.microsoft.MaxpoolWithMask">com.microsoft.MaxpoolWithMask</a>
* <a href="#com.microsoft.MoE">com.microsoft.MoE</a>
* <a href="#com.microsoft.MulInteger">com.microsoft.MulInteger</a>
@@ -3189,6 +3191,190 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>


### <a name="com.microsoft.MatMulNBitsMlp"></a><a name="com.microsoft.matmulnbitsmlp">**com.microsoft.MatMulNBitsMlp**</a>

MatMulNBitsMlp fuses two MatMulNBits projections that share the same input and computes

gate = MatMulNBits(A, gate_weight) + gate_bias
up = MatMulNBits(A, up_weight) + up_bias
Y = activation(gate) * up

It can also optionally fuse SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization before the
two projections:

A_norm = SimplifiedLayerNormalization(A, norm_scale, epsilon)
gate = MatMulNBits(A_norm, gate_weight) + gate_bias
up = MatMulNBits(A_norm, up_weight) + up_bias
Y = activation(gate) * up

A_norm = SkipSimplifiedLayerNormalization(A, skip, norm_scale, epsilon)
gate = MatMulNBits(A_norm, gate_weight) + gate_bias
up = MatMulNBits(A_norm, up_weight) + up_bias
Y = activation(gate) * up

This operator is intended for decoder MLP patterns such as Qwen-style gate and up projections, but it remains
semantically valid for both prefill and decode because the output shape is the standard MatMul result shape
derived from the runtime shape of A and the shared attributes K and N.

The operator contract includes a string attribute describing the fused gate activation.

When fused from SkipSimplifiedLayerNormalization, the optional residual-sum output may also be materialized:

A_norm, input_skip_bias_sum = SkipSimplifiedLayerNormalization(A, skip, norm_scale, epsilon)
gate = MatMulNBits(A_norm, gate_weight) + gate_bias
up = MatMulNBits(A_norm, up_weight) + up_bias
Y = activation(gate) * up
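
For concreteness, the snippet below is a minimal scalar reference of the fused computation (norm fusion omitted), assuming already-dequantized float weights and using SiLU as an example value of the activation attribute. The function and parameter names are illustrative only and are not part of the operator schema.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Minimal scalar reference of the MatMulNBitsMlp contract (norm fusion omitted).
// Weights are assumed to be already dequantized to float, laid out [K, N] row-major,
// and SiLU stands in for whatever the `activation` attribute names.
std::vector<float> MatMulNBitsMlpReference(const std::vector<float>& A,          // [M, K]
                                           const std::vector<float>& gate_W,     // [K, N]
                                           const std::vector<float>& up_W,       // [K, N]
                                           const std::vector<float>& gate_bias,  // [N] or empty
                                           const std::vector<float>& up_bias,    // [N] or empty
                                           int M, int K, int N) {
  assert(A.size() == static_cast<size_t>(M) * K);
  auto silu = [](float x) { return x / (1.0f + std::exp(-x)); };  // example activation only
  std::vector<float> Y(static_cast<size_t>(M) * N);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float gate = gate_bias.empty() ? 0.0f : gate_bias[n];
      float up = up_bias.empty() ? 0.0f : up_bias[n];
      for (int k = 0; k < K; ++k) {
        gate += A[m * K + k] * gate_W[k * N + n];
        up += A[m * K + k] * up_W[k * N + n];
      }
      Y[m * N + n] = silu(gate) * up;  // Y = activation(gate) * up
    }
  }
  return Y;
}
```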

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>K</tt> : int (required)</dt>
<dd>Input feature dimension shared by both quantized weight matrices.</dd>
<dt><tt>N</tt> : int (required)</dt>
<dd>Output feature dimension shared by both quantized weight matrices.</dd>
<dt><tt>accuracy_level</tt> : int</dt>
<dd>The minimum accuracy level of input A. It follows the same semantics as MatMulNBits.</dd>
<dt><tt>activation</tt> : string (required)</dt>
<dd>Activation applied to the gate projection.</dd>
<dt><tt>bits</tt> : int</dt>
<dd>Bit-width used to quantize both weight matrices (valid range: 2~8)</dd>
<dt><tt>block_size</tt> : int (required)</dt>
<dd>Size of each quantization block along the K dimension. Must be a power of two and >= 16.</dd>
<dt><tt>epsilon</tt> : float</dt>
<dd>Epsilon used by the optional fused (Skip)SimplifiedLayerNormalization. Defaults to 1e-5.</dd>
</dl>
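
Assuming the same weight packing as MatMulNBits (blocks of block_size elements along K, packed at bits bits per element), the sizes of each projection's packed weight and scale tensors follow directly from the attributes above. The helper below is only a sketch of that arithmetic, not part of the schema.

```cpp
#include <cstdint>

// Sketch of the per-projection buffer sizes implied by K, N, bits and block_size,
// assuming MatMulNBits-style packing: B is [N, n_blocks_per_col, blob_size] uint8
// and scales has one entry per (column, block) pair.
struct NBitsLayout {
  int64_t n_blocks_per_col;  // quantization blocks along K
  int64_t blob_size;         // packed bytes per block
  int64_t b_bytes;           // total bytes in gate_B / up_B
  int64_t num_scales;        // elements in gate_scales / up_scales
};

inline NBitsLayout ComputeNBitsLayout(int64_t K, int64_t N, int64_t bits, int64_t block_size) {
  NBitsLayout layout;
  layout.n_blocks_per_col = (K + block_size - 1) / block_size;
  layout.blob_size = block_size * bits / 8;  // block_size is a power of two >= 16
  layout.b_bytes = N * layout.n_blocks_per_col * layout.blob_size;
  layout.num_scales = N * layout.n_blocks_per_col;
  return layout;
}
```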

#### Inputs (8 - 9)

<dl>
<dt><tt>A</tt> : T1</dt>
<dd>The shared input tensor.</dd>
<dt><tt>skip</tt> (optional) : T1</dt>
<dd>Optional skip input used by SkipSimplifiedLayerNormalization.</dd>
<dt><tt>norm_scale</tt> (optional) : T1</dt>
<dd>Optional RMSNorm scale with shape [K] used by SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization.</dd>
<dt><tt>gate_B</tt> : T2</dt>
<dd>Packed uint8 tensor for the gate projection weights.</dd>
<dt><tt>gate_scales</tt> : T1</dt>
<dd>Per-block scaling factors for the gate projection.</dd>
<dt><tt>gate_bias</tt> (optional) : T1</dt>
<dd>Optional bias for the gate projection with shape [N].</dd>
<dt><tt>up_B</tt> : T2</dt>
<dd>Packed uint8 tensor for the up projection weights.</dd>
<dt><tt>up_scales</tt> : T1</dt>
<dd>Per-block scaling factors for the up projection.</dd>
<dt><tt>up_bias</tt> (optional) : T1</dt>
<dd>Optional bias for the up projection with shape [N].</dd>
</dl>

#### Outputs (1 - 2)

<dl>
<dt><tt>Y</tt> : T1</dt>
<dd>The fused gated MLP output tensor.</dd>
<dt><tt>input_skip_bias_sum</tt> (optional) : T1</dt>
<dd>Optional residual-sum output for SkipSimplifiedLayerNormalization.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T1</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>T2</tt> : tensor(uint8)</dt>
<dd>Constrain quantized weight types to uint8.</dd>
</dl>


### <a name="com.microsoft.MatMulNBitsQkv"></a><a name="com.microsoft.matmulnbitsqkv">**com.microsoft.MatMulNBitsQkv**</a>

MatMulNBitsQkv fuses either SimplifiedLayerNormalization (RMSNorm)
or SkipSimplifiedLayerNormalization with three MatMulNBits projections that share the
same normalized activation.

A_norm = SimplifiedLayerNormalization(A, norm_scale, epsilon)
Q = MatMulNBits(A_norm, q_weight)
K = MatMulNBits(A_norm, k_weight)
V = MatMulNBits(A_norm, v_weight)

If skip is provided, the operator computes the SkipSimplifiedLayerNormalization variant
and may also return the input+skip residual sum as the optional fourth output (input_skip_bias_sum).

This operator is intended as a decode-oriented QKV fusion primitive.
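
The sketch below spells out the non-skip data flow as a scalar reference, again assuming already-dequantized float weights; the names are illustrative and not part of the operator schema.

```cpp
#include <cmath>
#include <vector>

// Scalar reference of the MatMulNBitsQkv data flow (non-skip variant), with weights
// already dequantized to float. Shapes: A [M, K], norm_scale [K], q_W [K, Nq],
// k_W and v_W [K, Nkv]; outputs Q [M, Nq], Kout and V [M, Nkv].
void MatMulNBitsQkvReference(const std::vector<float>& A, const std::vector<float>& norm_scale,
                             const std::vector<float>& q_W, const std::vector<float>& k_W,
                             const std::vector<float>& v_W, float epsilon,
                             int M, int K, int Nq, int Nkv,
                             std::vector<float>& Q, std::vector<float>& Kout, std::vector<float>& V) {
  Q.assign(static_cast<size_t>(M) * Nq, 0.0f);
  Kout.assign(static_cast<size_t>(M) * Nkv, 0.0f);
  V.assign(static_cast<size_t>(M) * Nkv, 0.0f);
  std::vector<float> a_norm(K);
  for (int m = 0; m < M; ++m) {
    // SimplifiedLayerNormalization (RMSNorm) over the K dimension.
    float sum_sq = 0.0f;
    for (int k = 0; k < K; ++k) sum_sq += A[m * K + k] * A[m * K + k];
    const float inv_rms = 1.0f / std::sqrt(sum_sq / K + epsilon);
    for (int k = 0; k < K; ++k) a_norm[k] = A[m * K + k] * inv_rms * norm_scale[k];
    // Three projections sharing the normalized activation.
    for (int n = 0; n < Nq; ++n)
      for (int k = 0; k < K; ++k) Q[m * Nq + n] += a_norm[k] * q_W[k * Nq + n];
    for (int n = 0; n < Nkv; ++n)
      for (int k = 0; k < K; ++k) {
        Kout[m * Nkv + n] += a_norm[k] * k_W[k * Nkv + n];
        V[m * Nkv + n] += a_norm[k] * v_W[k * Nkv + n];
      }
  }
}
```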

#### Version

This version of the operator has been available since version 1 of the 'com.microsoft' operator set.

#### Attributes

<dl>
<dt><tt>K</tt> : int (required)</dt>
<dd>Input feature dimension shared by the normalized input and all projection weights.</dd>
<dt><tt>Nkv</tt> : int (required)</dt>
<dd>Output feature dimension shared by the K and V projections.</dd>
<dt><tt>Nq</tt> : int (required)</dt>
<dd>Output feature dimension of the Q projection.</dd>
<dt><tt>accuracy_level</tt> : int</dt>
<dd>The minimum accuracy level of input A. It follows the same semantics as MatMulNBits.</dd>
<dt><tt>bits</tt> : int</dt>
<dd>Bit-width used to quantize all weight matrices (valid range: 2~8)</dd>
<dt><tt>block_size</tt> : int (required)</dt>
<dd>Size of each quantization block along the K dimension. Must be a power of two and >= 16.</dd>
<dt><tt>epsilon</tt> : float</dt>
<dd>Epsilon used by the simplified layer norm reduction.</dd>
</dl>

#### Inputs

<dl>
<dt><tt>A</tt> : T1</dt>
<dd>The shared input tensor.</dd>
<dt><tt>skip</tt> (optional) : T1</dt>
<dd>Optional residual input for SkipSimplifiedLayerNormalization.</dd>
<dt><tt>norm_scale</tt> : T1</dt>
<dd>Scale input for the simplified layer norm with shape [K].</dd>
<dt><tt>q_B</tt> : T2</dt>
<dd>Packed uint8 tensor for the Q projection weights.</dd>
<dt><tt>q_scales</tt> : T1</dt>
<dd>Per-block scaling factors for the Q projection.</dd>
<dt><tt>k_B</tt> : T2</dt>
<dd>Packed uint8 tensor for the K projection weights.</dd>
<dt><tt>k_scales</tt> : T1</dt>
<dd>Per-block scaling factors for the K projection.</dd>
<dt><tt>v_B</tt> : T2</dt>
<dd>Packed uint8 tensor for the V projection weights.</dd>
<dt><tt>v_scales</tt> : T1</dt>
<dd>Per-block scaling factors for the V projection.</dd>
</dl>

#### Outputs (3 - 4)

<dl>
<dt><tt>Q</tt> : T1</dt>
<dd>The Q projection output tensor.</dd>
<dt><tt>K</tt> : T1</dt>
<dd>The K projection output tensor.</dd>
<dt><tt>V</tt> : T1</dt>
<dd>The V projection output tensor.</dd>
<dt><tt>input_skip_bias_sum</tt> (optional) : T1</dt>
<dd>Optional residual-sum output for SkipSimplifiedLayerNormalization.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T1</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>T2</tt> : tensor(uint8)</dt>
<dd>Constrain quantized weight types to uint8.</dd>
</dl>


### <a name="com.microsoft.MaxpoolWithMask"></a><a name="com.microsoft.maxpoolwithmask">**com.microsoft.MaxpoolWithMask**</a>

For internal use.
31 changes: 24 additions & 7 deletions onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc
@@ -154,8 +154,26 @@ Status SkipLayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeCo
auto* output = context.Output(0, x_shape);
auto* input_skip_bias_sum = context.Output(3, x_shape);

int64_t data_size = x_shape.Size();
if (data_size == 0) {
if (x_shape.Size() == 0) {
return Status::OK();
}

return RunSkipLayerNormProgram(context, x, skip, gamma, beta, bias, epsilon_, simplified,
output, input_skip_bias_sum);
}

Status RunSkipLayerNormProgram(ComputeContext& context,
const Tensor* x,
const Tensor* skip,
const Tensor* gamma,
const Tensor* beta,
const Tensor* bias,
float epsilon,
bool simplified,
Tensor* output,
Tensor* input_skip_bias_sum) {
const auto& x_shape = x->Shape();
if (x_shape.Size() == 0) {
return Status::OK();
}

@@ -165,26 +183,25 @@ Status SkipLayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeCo
const uint32_t norm_count = onnxruntime::narrow<uint32_t>(x_shape.SizeToDimension(x_shape.NumDimensions() - 1));
const bool split_hidden_dim = hidden_size % 512 == 0 && norm_count == 1;

const auto skip_shape = skip->Shape();
const uint32_t skip_size = onnxruntime::narrow<uint32_t>(skip_shape.Size());
const uint32_t skip_size = onnxruntime::narrow<uint32_t>(skip->Shape().Size());

SkipLayerNormProgram program{
beta != nullptr, bias != nullptr, epsilon_, hidden_size, has_input_skip_bias_sum, simplified, split_hidden_dim};
beta != nullptr, bias != nullptr, epsilon, hidden_size, has_input_skip_bias_sum, simplified, split_hidden_dim};
program
.CacheHint(simplified, beta != nullptr, bias != nullptr, has_input_skip_bias_sum, split_hidden_dim)
.AddInputs({{x, ProgramTensorMetadataDependency::Type, components}})
.AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}})
.AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}})
.AddOutputs({{output, ProgramTensorMetadataDependency::None, components}})
.SetDispatchGroupSize(onnxruntime::narrow<uint32_t>(ceil(1.0 * data_size / hidden_size)))
.SetDispatchGroupSize(onnxruntime::narrow<uint32_t>(ceil(1.0 * x_shape.Size() / hidden_size)))
.AddUniformVariables({
{static_cast<uint32_t>(components)},
})
.AddUniformVariables({
{static_cast<uint32_t>(hidden_size)},
})
.AddUniformVariables({
{static_cast<float>(epsilon_)},
{static_cast<float>(epsilon)},
})
.AddUniformVariables({
{static_cast<uint32_t>(skip_size)},
15 changes: 15 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h
@@ -60,6 +60,21 @@ class SkipLayerNorm final : public WebGpuKernel {
float epsilon_;
};

// Configures and dispatches a SkipLayerNormProgram. Centralizes program-setup logic
// (uniform variables, components, split_hidden_dim heuristic, workgroup sizing) so callers
// other than the SkipLayerNorm kernel (e.g. fused MatMulNBits ops) do not need to duplicate it.
// `beta`, `bias` and `input_skip_bias_sum` may be nullptr.
Status RunSkipLayerNormProgram(ComputeContext& context,
const Tensor* x,
const Tensor* skip,
const Tensor* gamma,
const Tensor* beta,
const Tensor* bias,
float epsilon,
bool simplified,
Tensor* output,
Tensor* input_skip_bias_sum);

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
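
A fused kernel would call the shared runner roughly as in the hypothetical sketch below. The wrapper, tensor names, and staging of the normalized output are illustrative assumptions that mirror the declaration above; they are not code from this PR.

```cpp
// Hypothetical call site inside a fused kernel's ComputeInternal, showing how the
// shared runner declared above is meant to be reused.
Status FusedMlpNormSketch(onnxruntime::webgpu::ComputeContext& context,
                          const Tensor* a, const Tensor* skip, const Tensor* norm_scale,
                          float epsilon, Tensor* a_norm, Tensor* input_skip_bias_sum) {
  // simplified == true selects the SkipSimplifiedLayerNormalization (RMSNorm) variant;
  // beta is not used by that variant and this sketch has no input bias, so pass nullptr.
  ORT_RETURN_IF_ERROR(RunSkipLayerNormProgram(context, a, skip, /*gamma=*/norm_scale,
                                              /*beta=*/nullptr, /*bias=*/nullptr,
                                              epsilon, /*simplified=*/true,
                                              a_norm, input_skip_bias_sum));
  // ...dispatch the fused MatMulNBits projections on a_norm here...
  return Status::OK();
}
```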
44 changes: 28 additions & 16 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -18,10 +18,6 @@ namespace onnxruntime {
namespace contrib {
namespace webgpu {

namespace {
constexpr unsigned int kMinMForTileOptimization = 4;
} // namespace

ONNX_OPERATOR_KERNEL_EX(
MatMulNBits,
kMSDomain,
@@ -226,29 +222,44 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
uint32_t zero_blocks_per_col = (n_blocks_per_col + zp_elements_per_byte - 1) / zp_elements_per_byte * zp_elements_per_byte;

#if !defined(__wasm__)
// apple|intel - Experimental dawn support for subgroup matrix matmul.
int32_t subgroup_matrix_config_index = -1;
// Experimental dawn support for subgroup matrix matmul (vendor-agnostic).
if ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) &&
CanApplySubgroupMatrixMatMulNBits(context, accuracy_level, block_size, batch_count, N, K, static_cast<uint32_t>(nbits), y->DataType() == DataTypeImpl::GetType<MLFloat16>(), subgroup_matrix_config_index)) {
if (WouldApplySubgroupMatrixMatMulNBitsInCurrentDispatch(M,
N,
K,
batch_count,
block_size,
accuracy_level,
nbits,
context,
y,
has_weight_idx_indirect,
&subgroup_matrix_config_index,
override_M)) {
return ApplySubgroupMatrixMatMulNBits(a, b, scales, zero_points, bias, M, N, K, static_cast<uint32_t>(nbits), zero_blocks_per_col, subgroup_matrix_config_index, context, y, weight_index, weight_index_indirect);
}
#endif

// On FP32 only GPUs and Qualcomm GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M.
// DP4A Q2 path now supports custom zero points via a 1024-entry LUT (4 zero-point sections × 256 byte values).
if (((M >= kMinMForTileOptimization && !has_weight_idx_indirect) || y->DataType() == DataTypeImpl::GetType<float>() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) &&
CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)) {
if (WouldApplyDP4AMatMulNBitsInCurrentDispatch(M,
N,
K,
block_size,
accuracy_level,
context,
y,
has_weight_idx_indirect)) {
return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, batch_count, M, dispatch_M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast<uint32_t>(nbits), context, y, weight_index, weight_index_indirect);
}

// WideTileProgram
// This program is optimized for Block32 prefill using Tile16x128.
const bool use_wide_tile_program = !has_weight_idx_indirect &&
block_size == 32 &&
components_a == 4 &&
components_b == 4 &&
nbits != 2 &&
M >= kMinMForTileOptimization;
const bool use_wide_tile_program = WouldApplyWideTileMatMulNBitsInCurrentDispatch(M,
K,
block_size,
nbits,
has_weight_idx_indirect);

if (use_wide_tile_program) {
// Enforce output components to 1.
Expand Down Expand Up @@ -308,7 +319,8 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,

// Use tile_size_k_vec=32 by default for better K-dimension parallelism.
// Intel devices use 16 as they have different subgroup/cache characteristics.
const uint32_t tile_size_k_vec = (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u;
const uint32_t tile_size_k_vec =
(context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u;

constexpr uint32_t workgroup_size = 128;
constexpr uint32_t tile_size = 8;
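
As a design note on the refactor above: each Would...InCurrentDispatch predicate now owns the eligibility check that used to be written inline, so the call site reduces to an ordered cascade of predicate/apply pairs. The standalone sketch below shows the shape of that pattern with made-up path names and conditions; it is not the kernel's actual selection logic.

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Standalone illustration of a predicate-driven dispatch cascade: each path owns a
// "would apply" check, and the call site walks an ordered list instead of repeating
// eligibility logic inline. Conditions here are placeholders, not the real heuristics.
struct MatMulPath {
  std::string name;
  std::function<bool(int M, int N, int K)> would_apply;  // eligibility check
};

std::string SelectMatMulPath(int M, int N, int K) {
  const std::vector<MatMulPath> paths = {
      {"subgroup_matrix", [](int m, int, int) { return m >= 4; }},                   // placeholder
      {"dp4a",            [](int, int n, int k) { return k % 32 == 0 && n % 16 == 0; }},  // placeholder
      {"wide_tile",       [](int m, int, int) { return m >= 4; }},                   // placeholder
  };
  for (const auto& path : paths) {
    if (path.would_apply(M, N, K)) return path.name;
  }
  return "default_tiled";  // fallback program
}

int main() {
  std::cout << SelectMatMulPath(1, 4096, 4096) << "\n";    // decode-like shape
  std::cout << SelectMatMulPath(128, 4096, 4096) << "\n";  // prefill-like shape
}
```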