diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc
index 54df00fd58d92..619aff1e806b8 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc
@@ -146,8 +146,8 @@ Status SkipLayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeCo
   const Tensor* skip = context.Input(1);
   const Tensor* gamma = context.Input(2);
   // optional
-  const Tensor* beta = context.Input(3);
-  const Tensor* bias = context.Input(4);
+  const Tensor* beta = simplified ? nullptr : context.Input(3);
+  const Tensor* bias = context.Input(simplified ? 3 : 4);
 
   const auto x_shape = x->Shape();
 
@@ -168,9 +168,10 @@ Status SkipLayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeCo
   const auto skip_shape = skip->Shape();
   const uint32_t skip_size = onnxruntime::narrow<uint32_t>(skip_shape.Size());
 
-  SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon_, hidden_size, has_input_skip_bias_sum, simplified, split_hidden_dim};
+  SkipLayerNormProgram program{
+      beta != nullptr, bias != nullptr, epsilon_, hidden_size, has_input_skip_bias_sum, simplified, split_hidden_dim};
   program
-      .CacheHint(simplified, has_input_skip_bias_sum, split_hidden_dim)
+      .CacheHint(simplified, beta != nullptr, bias != nullptr, has_input_skip_bias_sum, split_hidden_dim)
       .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}})
       .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}})
       .AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}})
diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
index a87fe8fe30b7b..b28fec5111f2b 100644
--- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
@@ -875,6 +875,61 @@ TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) {
           simplified);
 }
 
+TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Bias_Float16) {
+  int batch_size = 1;
+  int sequence_length = 1;
+  int hidden_size = 8;
+
+  std::vector<float> input_data = {
+      0.12573242f, -0.13208008f, 0.640625f, 0.10491943f,
+      -0.53564453f, 0.36157227f, 1.3037109f, 0.94726562f};
+
+  std::vector<float> skip_data = {
+      -0.70361328f, -1.265625f, -0.62304688f, 0.041320801f,
+      -2.3242188f, -0.21875f, -1.2460938f, -0.73242188f};
+
+  std::vector<float> gamma_data = {
+      0.94580078f, 0.96826172f, 1.0410156f, 1.1044922f,
+      0.98730469f, 1.1367188f, 0.93359375f, 1.0351562f};
+
+  std::vector<float> bias_data = {
+      0.45166016f, 0.04699707f, -0.37182617f, -0.4609375f,
+      -0.22888184f, 0.11010742f, -0.50488281f, -0.10461426f};
+
+  std::vector<float> output_data = {
+      -0.098144531f, -1.0732422f, -0.30273438f, -0.28515625f,
+      -2.5019531f, 0.23596191f, -0.34277344f, 0.09362793f};
+
+  std::vector<float> sum_output_data = {
+      -0.12646484f, -1.3505859f, -0.35424805f, -0.31469727f,
+      -3.0878906f, 0.25292969f, -0.44726562f, 0.11022949f};
+
+  bool use_float16 = true;
+  bool use_bfloat16 = false;
+  bool no_beta = true;
+  bool simplified = true;
+
+  if (DefaultWebGpuExecutionProvider().get() == nullptr && DefaultDmlExecutionProvider().get() != nullptr) {
+    GTEST_SKIP() << "DirectML does not support this test case.";
+  }
+
+  RunTest(input_data,
+          skip_data,
+          gamma_data,
+          std::vector<float>(),
+          bias_data,
+          output_data,
+          sum_output_data,
+          1e-5f,
+          batch_size,
+          sequence_length,
+          hidden_size,
+          use_float16,
+          use_bfloat16,
+          no_beta,
+          simplified);
+}
+
 TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_No_Batch_Size) {
   int batch_size = 2;
   int sequence_length = 2;