diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index c288a82994e98..9a5aebd696555 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -222,52 +222,62 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_seqlen_k_)); } -Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const { +Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - if (use_indirect_dispatch_) { + shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + if (use_indirect_dispatch_ || use_seqlen_k_) { shader.AddInput("seqlens_k", ShaderUsage::None); } if (has_attention_bias_) { shader.AddInput("attention_bias", ShaderUsage::UseUniform); } - shader.AddOutput("output", ShaderUsage::UseUniform); + shader.AddOutput("out_split_vx", ShaderUsage::UseUniform); shader.AddOutput("metadata", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const uint32_t tile_size_k_vec = 8; const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec; - return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_qkt.wgsl.template", + return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_qkv.wgsl.template", WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_), + WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_), + WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_), + WGSL_TEMPLATE_PARAMETER(q_BNSH, q_BNSH_), WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count), WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), 
WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), - WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_)); + WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_indirect_dispatch_ || use_seqlen_k_), + WGSL_TEMPLATE_PARAMETER(v_head_size_vec, head_size_vec_)); } -Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q, - const Tensor* attention_bias, Tensor* output, Tensor* present_key, Tensor* metadata, const Tensor* seqlen_k, - const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length) { +Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q, + const Tensor* attention_bias, Tensor* out_split_vx, Tensor* present_key, Tensor* present_value, + Tensor* metadata, const Tensor* seqlen_k, + const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile) { const float alpha = parameters.scale_ == 0.0f ? 
1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; const bool has_attention_bias = attention_bias != nullptr; const int components = 4; + const int head_size_vec = parameters.v_head_size_ / components; - FlashAttentionDecodeQKTProgram program{"FlashAttentionDecodeQKT", has_attention_bias, tile_size, use_indirect_dispatch}; + bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH; + bool is_unidirectional = parameters.is_unidirectional_; + bool decode_use_seqlen_k = !use_indirect_dispatch && seqlen_k != nullptr; + FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, decode_use_seqlen_k, m_tile}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, - {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}}); - if (use_indirect_dispatch) { + {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}, + {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); + if (use_indirect_dispatch || decode_use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } if (has_attention_bias) { program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); } - program.AddOutputs({{output, ProgramTensorMetadataDependency::Rank}, + program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}, {metadata, ProgramTensorMetadataDependency::Rank, 2}}); const uint32_t vectorized_head_size = parameters.head_size_ / components; - // Get attention bias dimensions for broadcasting uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; if (has_attention_bias) { @@ -279,10 +289,10 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte if (use_indirect_dispatch) { program.SetIndirectDispatchTensor(indirect_buffer); } else { - program.SetDispatchGroupSize(parameters.batch_size_ * 
parameters.num_heads_ * num_total_seq_length_tile); + program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_total_seq_length_tile); } program.SetWorkgroupSize(64) - .CacheHint(tile_size, has_attention_bias, use_indirect_dispatch) + .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, decode_use_seqlen_k, m_tile) .AddUniformVariables({{static_cast(vectorized_head_size)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(alpha)}, @@ -292,124 +302,70 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte {static_cast(parameters.num_heads_)}, {static_cast(parameters.batch_size_)}, {attn_bias_dim0}, - {attn_bias_dim1}}); + {attn_bias_dim1}, + {static_cast(parameters.sequence_length_)}}); return context.RunProgram(program); } -Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shader) const { +Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("input", ShaderUsage::UseUniform); shader.AddInput("metadata", ShaderUsage::UseUniform); - shader.AddInput("qk", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - if (use_indirect_dispatch_) { + if (use_indirect_dispatch_ || use_seqlen_k_) { shader.AddInput("seqlens_k", ShaderUsage::None); } if (has_head_sink_) { shader.AddInput("head_sink", ShaderUsage::UseUniform); } - shader.AddOutput("out_split_vx", ShaderUsage::UseUniform); - - const uint32_t tile_size_k_vec = 8u; - - return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_split_vx.wgsl.template", - WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_), - WGSL_TEMPLATE_PARAMETER(head_size_vec, head_size_vec_), - WGSL_TEMPLATE_PARAMETER(sub_tile_count, WorkgroupSizeX() / 
tile_size_k_vec), - WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), - WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), - WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_)); -} - -Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeContext& context, - const Tensor* metadata, - const Tensor* qk, - Tensor* out_split_vx, - Tensor* present_value, - const Tensor* seqlen_k, - const WebgpuAttentionParameters& parameters, - const Tensor* indirect_buffer, - uint32_t num_total_seq_length_tile, - uint32_t num_present_sequence_length_tile, - uint32_t tile_size, - bool use_indirect_dispatch, - uint32_t present_sequence_length, - const Tensor* head_sink) { - const int components = 4; - const bool has_head_sink = head_sink != nullptr; - int head_size_vec = parameters.v_head_size_ / components; - FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec, use_indirect_dispatch, has_head_sink}; - program.AddInputs({{metadata, ProgramTensorMetadataDependency::TypeAndRank, 2}, - {qk, ProgramTensorMetadataDependency::TypeAndRank}, - {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); - program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}}); // [B, N, split_k, head_size] - const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_); - if (use_indirect_dispatch) { - program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); - } - if (has_head_sink) { - program.AddInput({head_sink, ProgramTensorMetadataDependency::Type}); - } - // SetIndirectDispatchTensor must be called after all AddInput calls because it - // appends the indirect buffer as the last program input. 
- if (use_indirect_dispatch) { - program.SetIndirectDispatchTensor(indirect_buffer); - } else { - program.SetDispatchGroupSize(batch_heads * num_total_seq_length_tile); - } - program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch, has_head_sink) - .SetWorkgroupSize(64) - .AddUniformVariables({{static_cast(parameters.total_sequence_length_)}, - {static_cast(head_size_vec)}, - present_sequence_length, - {static_cast(parameters.n_reps)}, - num_present_sequence_length_tile, - {batch_heads}, - {static_cast(parameters.num_heads_)}}); - - return context.RunProgram(program); -} - -Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("input", ShaderUsage::UseUniform); - if (use_indirect_dispatch_) { - shader.AddInput("seqlens_k", ShaderUsage::None); - } - shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_vx_reduce.wgsl.template", + WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_), + WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_), WGSL_TEMPLATE_PARAMETER(seq_tile_size, seq_tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), - WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_)); + WGSL_TEMPLATE_PARAMETER(use_seqlen_k, use_indirect_dispatch_ || use_seqlen_k_)); } Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& context, const Tensor* out_split_vx, + const Tensor* metadata, Tensor* output, const Tensor* seqlen_k, const WebgpuAttentionParameters& parameters, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t seq_tile_size, - bool use_indirect_dispatch) { + bool use_indirect_dispatch, + const Tensor* head_sink, + uint32_t m_tile) { const int components = 4; constexpr int tile_size = 8; int 
tile_head_size = tile_size * components; - FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch}; - program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}}); - if (use_indirect_dispatch) { + bool decode_use_seqlen_k = !use_indirect_dispatch && seqlen_k != nullptr; + bool has_head_sink = head_sink != nullptr; + FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch, decode_use_seqlen_k, has_head_sink, m_tile}; + program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}, + {metadata, ProgramTensorMetadataDependency::TypeAndRank, 2}}); + if (use_indirect_dispatch || decode_use_seqlen_k) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } + if (has_head_sink) { + program.AddInput({head_sink, ProgramTensorMetadataDependency::Type}); + } program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, components}}); const uint32_t num_head_size_tile = static_cast((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size); const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_); - program.SetDispatchGroupSize(batch_heads * num_head_size_tile) - .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch) + program.SetDispatchGroupSize(batch_heads * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_head_size_tile) + .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch, decode_use_seqlen_k, has_head_sink, m_tile) .SetWorkgroupSize(tile_size * tile_size) .AddUniformVariables({{static_cast(parameters.v_head_size_ / components)}, num_total_seq_length_tile, num_present_sequence_length_tile, {num_head_size_tile}, - {batch_heads}}); + {batch_heads}, + {static_cast(parameters.sequence_length_)}, + {static_cast(parameters.num_heads_)}}); return context.RunProgram(program); } @@ -479,7 
+435,15 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr)); } - if (parameters.sequence_length_ > 1) { + // Route between prefill path (FlashAttentionProgram, single kernel, requires subgroups) + // and fused decode path (QKV + VxReduce, 2 kernels, no subgroups needed). + const bool use_split_reduce = (parameters.sequence_length_ <= 4) || + !context.HasFeature(wgpu::FeatureName::Subgroups) || + (parameters.sequence_length_ < 64 && + static_cast(parameters.total_sequence_length_) > 1000); + + if (!use_split_reduce) { + // Prefill path: FlashAttentionProgram (single kernel with subgroup shuffles) bool has_attention_bias = attention_bias != nullptr; bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"}; @@ -514,7 +478,6 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co : parameters.scale_; const uint32_t num_seq_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size; - // Get attention bias dimensions for broadcasting uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; if (has_attention_bias) { @@ -539,37 +502,40 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co return context.RunProgram(program); } - // For decode path (sequence_length == 1) - const TensorShapeVector qk_dims({parameters.batch_size_, parameters.num_heads_, - parameters.sequence_length_, present_sequence_length}); - const TensorShape qk_shape(qk_dims); - Tensor qk = context.CreateGPUTensor(Q->DataType(), qk_shape); - const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size; + // Split-reduce path (QKV + VxReduce) + + // Compute m_tile: process multiple Q rows per 
workgroup to amortize K/V loads + const uint32_t m_tile = parameters.sequence_length_ >= 4 ? 4u : (parameters.sequence_length_ >= 2 ? 2u : 1u); + + // When use_seqlen_k is true, total_sequence_length_ may be 0 (actual value is in seqlen_k tensor). + // Use present_sequence_length for tile count calculations; shaders will read the actual value from seqlen_k. + const uint32_t effective_total_seq_len = use_seqlen_k ? present_sequence_length + : static_cast(parameters.total_sequence_length_); + + const uint32_t num_total_seq_length_tile = (effective_total_seq_len + tile_size - 1) / tile_size; const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size; // The metadata is used to store the max and sum of each tile. const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_, - num_present_sequence_length_tile, 2}); + parameters.sequence_length_, num_present_sequence_length_tile, 2}); const TensorShape metadata_shape(metadata_dims); Tensor metadata = context.CreateGPUTensor(DataTypeImpl::GetType(), metadata_shape); - ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKT(context, Q, attention_bias, &qk, present_key, &metadata, seqlen_k, - parameters, indirect_buffer_ptr, num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, use_indirect_dispatch, - present_sequence_length)); const TensorShapeVector out_split_vx_dims({parameters.batch_size_, parameters.num_heads_, - num_present_sequence_length_tile, parameters.head_size_}); + parameters.sequence_length_, num_present_sequence_length_tile, parameters.head_size_}); const TensorShape out_split_vx_shape(out_split_vx_dims); Tensor out_split_vx = context.CreateGPUTensor(Q->DataType(), out_split_vx_shape); - ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeSplitVxScore(context, &metadata, &qk, &out_split_vx, present_value, - seqlen_k, parameters, indirect_buffer_ptr, - num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, - 
use_indirect_dispatch, present_sequence_length, - head_sink)); - ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, seqlen_k, parameters, + + ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKV(context, Q, attention_bias, &out_split_vx, present_key, present_value, + &metadata, seqlen_k, + parameters, indirect_buffer_ptr, num_total_seq_length_tile, + num_present_sequence_length_tile, tile_size, use_indirect_dispatch, + present_sequence_length, m_tile)); + + ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, &metadata, output, seqlen_k, parameters, num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, use_indirect_dispatch)); + num_present_sequence_length_tile, tile_size, use_indirect_dispatch, + head_sink, m_tile)); return Status::OK(); } @@ -577,7 +543,6 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { return !parameters.is_packed_qkv_ && parameters.head_size_ == parameters.v_head_size_ && - context.HasFeature(wgpu::FeatureName::Subgroups) && ((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 980ddc3a5373b..38797d0d8edd6 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -117,11 +117,14 @@ class FlashAttentionProgram final : public Program { bool has_head_sink_; }; -class FlashAttentionDecodeQKTProgram final : public Program { +class FlashAttentionDecodeQKVProgram final : public Program { public: - FlashAttentionDecodeQKTProgram(const std::string& kernel_name, - bool has_attention_bias, uint32_t tile_size, bool use_indirect_dispatch) - : 
Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), use_indirect_dispatch_(use_indirect_dispatch) { + FlashAttentionDecodeQKVProgram(const std::string& kernel_name, + bool has_attention_bias, uint32_t tile_size, int head_size_vec, + bool use_indirect_dispatch, bool q_BNSH = false, + bool is_unidirectional = false, bool use_seqlen_k = false, + uint32_t m_tile = 1) + : Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), q_BNSH_(q_BNSH), is_unidirectional_(is_unidirectional), use_seqlen_k_(use_seqlen_k), m_tile_(m_tile) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -135,41 +138,24 @@ class FlashAttentionDecodeQKTProgram final : public Program { - public: - FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec, bool use_indirect_dispatch, bool has_head_sink = false) - : Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink) { - } - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"total_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"head_size_vec", ProgramUniformVariableDataType::Uint32}, - {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"n_reps", ProgramUniformVariableDataType::Uint32}, - {"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32}, - {"batch_heads", ProgramUniformVariableDataType::Uint32}, - {"num_heads", ProgramUniformVariableDataType::Uint32}); - - private: uint32_t tile_size_; int head_size_vec_; bool use_indirect_dispatch_; - bool has_head_sink_; + bool q_BNSH_; + bool is_unidirectional_; + bool use_seqlen_k_; + uint32_t m_tile_; }; class FlashAttentionDecodeVxReduceProgram final : public Program { public: - 
FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch) - : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch) { + FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch, bool use_seqlen_k = false, bool has_head_sink = false, uint32_t m_tile = 1) + : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch), use_seqlen_k_(use_seqlen_k), has_head_sink_(has_head_sink), m_tile_(m_tile) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -178,12 +164,17 @@ class FlashAttentionDecodeVxReduceProgram final : public Program tile_q: array; -var inner_qk_values: array, tile_size>; -var tile_qk: array; - -#if has_attention_bias - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t - { - // Handle broadcasting: if dimension size is 1, use index 0 - let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); - let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - - // Calculate flat offset with broadcasting applied - // attention_bias shape: [attn_bias_dim0, attn_bias_dim1, new_seq_length, total_seq_length] - // For decode, new_seq_length is 1, so we can simplify: - let offset = bias_batch_idx * uniforms.attn_bias_dim1 * total_seq_length + - bias_head_idx * total_seq_length + - k_idx; - return attention_bias[offset]; - } -#else - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t - { - return q_element_t(0); - } -#endif - -$MAIN { - let local_row = u32(local_idx / tile_size_k_vec); - let local_col = local_idx % tile_size_k_vec; -#if use_indirect_dispatch - let total_sequence_length = 
u32(seqlens_k[0]) + 1u; -#else - let total_sequence_length = uniforms.total_sequence_length; -#endif - let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; - let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; - let batch_head_idx = u32(workgroup_idx / num_total_seq_length_tile); - let head_idx = batch_head_idx % uniforms.num_heads; - let batch_idx = batch_head_idx / uniforms.num_heads; - if (batch_idx >= uniforms.batch_size) { - return; - } - let q_offset = batch_idx * uniforms.num_heads * uniforms.head_size_vec + head_idx * uniforms.head_size_vec; - let present_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec; - for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) { - if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) { - tile_q[local_idx] = q[q_offset + k + local_idx]; - } - workgroupBarrier(); - let q_data = tile_q[local_col] * q_element_t(uniforms.alpha); - if (k + local_col < uniforms.head_size_vec) { - for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { - if (total_seq_offset + row_offset + local_row < total_sequence_length) { - inner_qk_values[row_offset + local_row][local_col] += dot(present_key[present_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col], q_data); - } - } - } - workgroupBarrier(); - } - - if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { - var sum = q_element_t(0); - for (var i = 0u; i < tile_size_k_vec; i++) { - sum += inner_qk_values[local_idx][i]; - } - - sum = sum + loadAttentionBias(batch_idx, head_idx, 0u, total_seq_offset + local_idx, total_sequence_length); - tile_qk[local_idx] = sum; - output[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx] = sum; - } - workgroupBarrier(); - - if (local_idx == 0u) { - // Calculate the max and sum in 
current split. - var l_max = f32(-3.4028234663852886e+38f); - var l_sum = f32(0); - for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { - l_max = max(l_max, f32(tile_qk[i])); - } - for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { - l_sum += exp(f32(tile_qk[i]) - l_max); - } - let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile; - metadata[meta_offset] = metadata_value_t(l_max, l_sum); - } -} diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template new file mode 100644 index 0000000000000..04f732a488038 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template @@ -0,0 +1,195 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#param has_attention_bias +#param v_head_size_vec +#param is_unidirectional +#param m_tile +#param q_BNSH +#param sub_tile_count +#param tile_size +#param tile_size_k_vec +#param use_seqlen_k + +// Fused QK^T + softmax + V multiply shader. +// +// Each workgroup processes one KV tile (tile_size rows of present_key/value) +// for m_tile Q rows. The computation has two phases: +// +// Phase 1: QK^T (dot product of Q with K, attention bias, causal mask, +// per-tile max/sum for online softmax) +// Phase 2: Local softmax normalization + V multiply (using local max/sum, +// no cross-workgroup dependency) +// +// The VxReduce shader performs the final rescaling across tiles. 
+ +var tile_q: array, m_tile>; +var inner_qk_values: array, tile_size>, m_tile>; +var tile_qk: array, m_tile>; +var tile_output: array, m_tile>; +var qkv_values: array, sub_tile_count>, m_tile>; +var tile_max: array; +var tile_sum: array; + +#if has_attention_bias + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + { + let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); + let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); + let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * total_seq_length + + bias_head_idx * uniforms.new_sequence_length * total_seq_length + + q_idx * total_seq_length + + k_idx; + return attention_bias[offset]; + } +#else + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + { + return q_element_t(0); + } +#endif + +$MAIN { + let local_row = u32(local_idx / tile_size_k_vec); + let local_col = local_idx % tile_size_k_vec; + #if use_seqlen_k + let total_sequence_length = u32(seqlens_k[0]) + 1u; + #else + let total_sequence_length = uniforms.total_sequence_length; + #endif + let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; + let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile; + // Workgroup layout: [batch_heads, num_q_tiles, num_total_seq_length_tile] + let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; + let q_tile_idx = (workgroup_idx / num_total_seq_length_tile) % num_q_tiles; + let q_base = q_tile_idx * m_tile; + let batch_head_idx = u32(workgroup_idx / (num_total_seq_length_tile * num_q_tiles)); + let head_idx = batch_head_idx % uniforms.num_heads; + let batch_idx = batch_head_idx / uniforms.num_heads; + if (batch_idx >= uniforms.batch_size) { + return; + } + let present_key_offset = batch_head_idx / uniforms.n_reps * 
uniforms.present_sequence_length * uniforms.head_size_vec; + let present_value_offset = u32(batch_head_idx / uniforms.n_reps) * v_head_size_vec * uniforms.present_sequence_length; + + // ============================================================ + // Phase 1: QK^T computation + // ============================================================ + for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) { + let q_idx = q_base + m; +#if q_BNSH + let q_offset = batch_idx * uniforms.num_heads * uniforms.new_sequence_length * uniforms.head_size_vec + + head_idx * uniforms.new_sequence_length * uniforms.head_size_vec + + q_idx * uniforms.head_size_vec; +#else + let q_offset = batch_idx * uniforms.new_sequence_length * uniforms.num_heads * uniforms.head_size_vec + + q_idx * uniforms.num_heads * uniforms.head_size_vec + + head_idx * uniforms.head_size_vec; +#endif + tile_q[m][local_idx] = q[q_offset + k + local_idx]; + } + } + workgroupBarrier(); + if (k + local_col < uniforms.head_size_vec) { + for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { + if (total_seq_offset + row_offset + local_row < total_sequence_length) { + let k_data = present_key[present_key_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col]; + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_data = tile_q[m][local_col] * q_element_t(uniforms.alpha); + inner_qk_values[m][row_offset + local_row][local_col] += dot(k_data, q_data); + } + } + } + } + workgroupBarrier(); + } + + // Reduce inner_qk_values to tile_qk, apply attention bias and causal mask + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx = q_base + m; + if (local_idx < tile_size && total_seq_offset + local_idx < 
total_sequence_length) { + var sum = q_element_t(0); + for (var i = 0u; i < tile_size_k_vec; i++) { + sum += inner_qk_values[m][local_idx][i]; + } + + sum = sum + loadAttentionBias(batch_idx, head_idx, q_idx, total_seq_offset + local_idx, total_sequence_length); +#if is_unidirectional + if (total_seq_offset + local_idx > total_sequence_length - uniforms.new_sequence_length + q_idx) { + sum = q_element_t(-65504.0f); + } +#endif + tile_qk[m][local_idx] = present_value_element_t(sum); + } + workgroupBarrier(); + + // Compute per-tile max and sum for online softmax + if (local_idx == 0u) { + var l_max = f32(-3.4028234663852886e+38f); + var l_sum = f32(0); + for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { + l_max = max(l_max, f32(tile_qk[m][i])); + } + for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { + l_sum += exp(f32(tile_qk[m][i]) - l_max); + } + tile_max[m] = l_max; + tile_sum[m] = l_sum; + let meta_offset = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile; + metadata[meta_offset] = metadata_value_t(l_max, l_sum); + } + } + workgroupBarrier(); + + // ============================================================ + // Phase 2: Local softmax + V multiply + // ============================================================ + + // Normalize tile_qk with local max/sum + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { + tile_qk[m][local_idx] = present_value_element_t(exp(f32(tile_qk[m][local_idx]) - tile_max[m]) / tile_sum[m]); + } + } + workgroupBarrier(); + + for (var k: u32 = 0u; k < v_head_size_vec; k += tile_size_k_vec) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + qkv_values[m][local_row][local_col] = present_value_value_t(0); + } + 
workgroupBarrier(); + + if (k + local_col < v_head_size_vec) { + for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { + if (total_seq_offset + row_offset + local_row < total_sequence_length) { + let v_data = present_value[present_value_offset + (total_seq_offset + row_offset + local_row) * v_head_size_vec + k + local_col]; + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + qkv_values[m][local_row][local_col] += v_data * tile_qk[m][row_offset + local_row]; + } + } + } + } + workgroupBarrier(); + + if (local_idx < tile_size_k_vec) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + for (var i = 0u; i < sub_tile_count; i++) { + tile_output[m][k + local_idx] += qkv_values[m][i][local_idx]; + } + } + } + workgroupBarrier(); + } + + // Write output + let tile_idx = workgroup_idx % num_total_seq_length_tile; + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx = q_base + m; + let out_base = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile * v_head_size_vec; + for (var i = local_idx; i < v_head_size_vec; i += workgroup_size_x) { + out_split_vx[out_base + tile_idx * v_head_size_vec + i] = tile_output[m][i]; + } + } +} diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template deleted file mode 100644 index 6f1ad1ca41b71..0000000000000 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#param has_head_sink -#param tile_size -#param head_size_vec -#param tile_size_k_vec -#param sub_tile_count -#param use_indirect_dispatch - -// Note that this shader adopts similar algorithm with dp4a generation shader. 
-// -// This algorithm works to compute dot product of v with qk parallelly, by -// processing on the head_size dimension at each step amongst tile_size_k_vec -// threads, and utilizing the remaining threads in the workgroup to process -// additional rows of |present_value| in parallel (such that the values in -// shared memory (tile_qk) for |qk| can be reused). The tile_size_k_vec threads -// also reload |present_value| tile_size/sub_tile_count times to compute partial -// dot products of other |present_value| rows in order to complete all tile_size -// |present_value| rows in this workgroup and also reusing the values in -// tile_qk. -// -// The difference with FlashAttentionDecodeQKTProgram is that the dot products -// go through the rows (total_sequence_length) of |present_value| instead of -// columns (head_size_vec). And each workgroup only calculate current -// tile_size's dot products instead of iterating the whole row -// |total_sequence_length|. That's why this shader is a split shader. The final -// reduce will be done in FlashAttentionDecodeReduceProgram. - -// TODO: Ideally, there should only be two shaders FlashAttentionDecodeSplitVx -// and FlashAttentionDecodeVxReduce, which can also reduce the intermediate -// memory. The FlashAttentionDecodeQKT can be merged into split shader and do -// the final softmax adjustment in the reduce shader. However, some issues are -// met that when the total sequence length exceeds some value, the result will -// become garbage. Since it can't be resolved in a short time, leave it as TODO -// to fix it in future. 
- -var tile_qk: array; -var tile_output: array; -var qkv_values: array, sub_tile_count>; - -$MAIN { - let local_row = u32(local_idx / tile_size_k_vec); - let local_col = local_idx % tile_size_k_vec; - #if use_indirect_dispatch - let total_sequence_length = u32(seqlens_k[0]) + 1u; - #else - let total_sequence_length = uniforms.total_sequence_length; - #endif - let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; - let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; - let batch_head_idx = u32(workgroup_idx / num_total_seq_length_tile); - if (batch_head_idx >= uniforms.batch_heads) { - return; - } - let present_offset = u32(batch_head_idx / uniforms.n_reps) * head_size_vec * uniforms.present_sequence_length; - - // Calculate the global max and sum in qk. - var g_max = f32(-3.4028234663852886e+38f); -#if has_head_sink - let head_idx = batch_head_idx % uniforms.num_heads; - let sink_value = f32(head_sink[head_idx]); - g_max = max(g_max, sink_value); -#endif - var g_sum = f32(0); - for (var i = 0u; i < num_total_seq_length_tile; i++) - { - let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + i; - g_max = max(g_max, metadata[meta_offset].x); - } - for (var i = 0u; i < num_total_seq_length_tile; i++) - { - let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + i; - let m_value = metadata[meta_offset]; - g_sum += exp(m_value.x - g_max) * m_value.y; - } -#if has_head_sink - g_sum += exp(sink_value - g_max); -#endif - - if (total_seq_offset + local_idx < total_sequence_length) { - tile_qk[local_idx] = present_value_element_t(exp(f32(qk[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx]) - g_max) / g_sum); - } - - for (var k: u32 = 0u; k < head_size_vec; k += tile_size_k_vec) { - var value = present_value_value_t(0); - qkv_values[local_row][local_col] = present_value_value_t(0); - workgroupBarrier(); - - if (k + local_col < head_size_vec) 
{ - for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { - if (total_seq_offset + row_offset + local_row < total_sequence_length) { - value += present_value[present_offset + (total_seq_offset + row_offset + local_row) * head_size_vec + k + local_col] * tile_qk[row_offset + local_row]; - } - } - } - - qkv_values[local_row][local_col] = value; - workgroupBarrier(); - - if (local_idx < tile_size_k_vec) { - for (var i = 0u; i < sub_tile_count; i++) { - tile_output[k + local_idx] += qkv_values[i][local_idx]; - } - } - workgroupBarrier(); - } - - for (var i = local_idx; i < head_size_vec; i += workgroup_size_x) { - let out_offset = batch_head_idx * uniforms.num_present_sequence_length_tile * head_size_vec + (workgroup_idx % num_total_seq_length_tile) * head_size_vec + i; - out_split_vx[out_offset] = tile_output[i]; - } -} diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template index f909a87724da6..dccabb75254c8 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template @@ -1,57 +1,96 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#param has_head_sink +#param m_tile #param seq_tile_size #param tile_size -#param use_indirect_dispatch - -// Inputs are splits of the GQA output, split into num_total_seq_length_tiles -// rows. This shader needs to add these splits across the row dimension to -// arrive at the final result. The column is head size wide. The reduction -// achieves maximum parallelization by splitting this task first into tile_size -// columns that each workgroup is responsible for. Then within each workgroup -// the task of summation over the num_total_seq_length_tile for the tile_size -// columns is further split in two ways. 
First across the row dimension to have -// WORKGROUP_SIZE/TILE_SIZE parallel computations of summation of TILE_SIZE -// rows. Then across the column dimension where each thread is responsible for 1 -// column of the TILE_SIZE columns the workgroup is responsible for. +#param use_seqlen_k + +// This shader reduces partial V outputs from the fused QKV shader. +// Each tile produced a locally-normalized V contribution. To get the +// correct global result, we rescale each tile's contribution using +// per-tile metadata (max, sum) with online softmax: +// +// global_max = max(local_max_i for all tiles) +// global_sum = sum(local_sum_i * exp(local_max_i - global_max)) +// output[h] = sum(partial_i[h] * local_sum_i * exp(local_max_i - global_max)) / global_sum var tile_input: array, tile_size>; $MAIN { + let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile; + // Workgroup layout: [batch_heads, num_q_tiles, num_head_size_tile] let head_size_offset = (workgroup_idx % uniforms.num_head_size_tile) * tile_size; - let batch_head_idx = u32(workgroup_idx / uniforms.num_head_size_tile); + let q_tile_idx = (workgroup_idx / uniforms.num_head_size_tile) % num_q_tiles; + let q_base = q_tile_idx * m_tile; + let batch_head_idx = u32(workgroup_idx / (uniforms.num_head_size_tile * num_q_tiles)); if (batch_head_idx >= uniforms.batch_heads) { return; } - let in_offset = batch_head_idx * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec; - var value = output_value_t(0); let local_row = u32(local_idx / tile_size); let local_col = local_idx % tile_size; - #if use_indirect_dispatch + #if use_seqlen_k let total_sequence_length = u32(seqlens_k[0]) + 1u; let num_total_seq_length_tile = (total_sequence_length + seq_tile_size - 1) / seq_tile_size; #else let num_total_seq_length_tile = uniforms.num_total_seq_length_tile; #endif - if (head_size_offset + local_col < uniforms.head_size_vec) { - for (var r = 0u; r < num_total_seq_length_tile; r += tile_size) { - if (r + local_row < 
num_total_seq_length_tile) { - value += input[in_offset + (r + local_row) * uniforms.head_size_vec + head_size_offset + local_col]; + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx = q_base + m; + let in_offset = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec; + let meta_base = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile; + + // Compute global max across all tiles + var g_max = f32(-3.4028234663852886e+38f); +#if has_head_sink + let head_idx_for_sink = batch_head_idx % uniforms.num_heads; + let sink_value = f32(head_sink[head_idx_for_sink]); + g_max = max(g_max, sink_value); +#endif + for (var i = 0u; i < num_total_seq_length_tile; i++) { + g_max = max(g_max, metadata[meta_base + i].x); + } + + // Compute global sum with rescaling + var g_sum = f32(0); + for (var i = 0u; i < num_total_seq_length_tile; i++) { + let m_value = metadata[meta_base + i]; + g_sum += m_value.y * exp(m_value.x - g_max); + } +#if has_head_sink + g_sum += exp(sink_value - g_max); +#endif + + // Accumulate rescaled partial outputs + var value = output_value_t(0); + if (head_size_offset + local_col < uniforms.head_size_vec) { + for (var r = 0u; r < num_total_seq_length_tile; r += tile_size) { + if (r + local_row < num_total_seq_length_tile) { + let tile_meta = metadata[meta_base + r + local_row]; + let rescale_f32 = tile_meta.y * exp(tile_meta.x - g_max) / g_sum; + value += input[in_offset + (r + local_row) * uniforms.head_size_vec + head_size_offset + local_col] * output_value_t(output_element_t(rescale_f32)); + } } } - } - tile_input[local_row][local_col] = value; - workgroupBarrier(); + tile_input[local_row][local_col] = value; + workgroupBarrier(); - if (local_idx < tile_size && head_size_offset + local_idx < uniforms.head_size_vec) { - value = output_value_t(0); - for (var i = 0u; i < tile_size; i++) { - value += 
tile_input[i][local_idx]; + if (local_idx < tile_size && head_size_offset + local_idx < uniforms.head_size_vec) { + value = output_value_t(0); + for (var i = 0u; i < tile_size; i++) { + value += tile_input[i][local_idx]; + } + let head_idx = batch_head_idx % uniforms.num_heads; + let batch_idx = batch_head_idx / uniforms.num_heads; + let output_id = batch_idx * uniforms.new_sequence_length * uniforms.num_heads * uniforms.head_size_vec + + q_idx * uniforms.num_heads * uniforms.head_size_vec + + head_idx * uniforms.head_size_vec + + head_size_offset + local_idx; + output[output_id] = value; } - let output_id = batch_head_idx * uniforms.head_size_vec + head_size_offset + local_idx; - output[output_id] = value; + workgroupBarrier(); } }