203 changes: 84 additions & 119 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Large diffs are not rendered by default.

51 changes: 21 additions & 30 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -117,11 +117,14 @@
bool has_head_sink_;
};

class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecodeQKTProgram> {
class FlashAttentionDecodeQKVProgram final : public Program<FlashAttentionDecodeQKVProgram> {
public:
FlashAttentionDecodeQKTProgram(const std::string& kernel_name,
bool has_attention_bias, uint32_t tile_size, bool use_indirect_dispatch)
: Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), use_indirect_dispatch_(use_indirect_dispatch) {
FlashAttentionDecodeQKVProgram(const std::string& kernel_name,
bool has_attention_bias, uint32_t tile_size, int head_size_vec,
bool use_indirect_dispatch, bool q_BNSH = false,
bool is_unidirectional = false, bool use_seqlen_k = false,
uint32_t m_tile = 1)
: Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), q_BNSH_(q_BNSH), is_unidirectional_(is_unidirectional), use_seqlen_k_(use_seqlen_k), m_tile_(m_tile) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -135,41 +138,24 @@
{"num_heads", ProgramUniformVariableDataType::Uint32},
{"batch_size", ProgramUniformVariableDataType::Uint32},
{"attn_bias_dim0", ProgramUniformVariableDataType::Uint32},
{"attn_bias_dim1", ProgramUniformVariableDataType::Uint32});
{"attn_bias_dim1", ProgramUniformVariableDataType::Uint32},
{"new_sequence_length", ProgramUniformVariableDataType::Uint32});

private:
bool has_attention_bias_;
uint32_t tile_size_;
bool use_indirect_dispatch_;
};

class FlashAttentionDecodeSplitVxProgram final : public Program<FlashAttentionDecodeSplitVxProgram> {
public:
FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec, bool use_indirect_dispatch, bool has_head_sink = false)
: Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"total_sequence_length", ProgramUniformVariableDataType::Uint32},
{"head_size_vec", ProgramUniformVariableDataType::Uint32},
{"present_sequence_length", ProgramUniformVariableDataType::Uint32},
{"n_reps", ProgramUniformVariableDataType::Uint32},
{"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32},
{"batch_heads", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32});

private:
uint32_t tile_size_;
int head_size_vec_;
bool use_indirect_dispatch_;
bool has_head_sink_;
bool q_BNSH_;
bool is_unidirectional_;
bool use_seqlen_k_;
uint32_t m_tile_;
};
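For orientation, here is a hypothetical call site for the renamed FlashAttentionDecodeQKVProgram, using only the parameter list visible in the constructor above. The real call site lives in flash_attention.cc, whose diff is collapsed in this view, so every value below is a placeholder rather than the PR's actual arguments:

  // Hypothetical usage sketch; all argument values are placeholders, not taken from this PR.
  FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV",  // kernel_name (placeholder)
                                         /*has_attention_bias=*/false,
                                         /*tile_size=*/64u,
                                         /*head_size_vec=*/32,
                                         /*use_indirect_dispatch=*/true,
                                         /*q_BNSH=*/false,
                                         /*is_unidirectional=*/true,
                                         /*use_seqlen_k=*/true,
                                         /*m_tile=*/1u};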

class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionDecodeVxReduceProgram> {
public:
FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch)
: Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch) {
cpplint annotation (GitHub Actions / Optional Lint C++, via reviewdog) on flash_attention.h line 157: Add #include <string> for string [build/include_what_you_use] [4]. A minimal sketch of the suggested fix follows this class.
FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch, bool use_seqlen_k = false, bool has_head_sink = false, uint32_t m_tile = 1)
    : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch), use_seqlen_k_(use_seqlen_k), has_head_sink_(has_head_sink), m_tile_(m_tile) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -178,12 +164,17 @@
{"num_total_seq_length_tile", ProgramUniformVariableDataType::Uint32},
{"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32},
{"num_head_size_tile", ProgramUniformVariableDataType::Uint32},
{"batch_heads", ProgramUniformVariableDataType::Uint32});
{"batch_heads", ProgramUniformVariableDataType::Uint32},
{"new_sequence_length", ProgramUniformVariableDataType::Uint32},
{"num_heads", ProgramUniformVariableDataType::Uint32});

private:
uint32_t tile_size_;
uint32_t seq_tile_size_;
bool use_indirect_dispatch_;
bool use_seqlen_k_;
bool has_head_sink_;
uint32_t m_tile_;
};
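The cpplint annotation above flags the std::string parameter in the constructor of this class. A minimal sketch of the suggested fix, assuming <string> is not already pulled in near the top of flash_attention.h (the file's include block is outside this diff):

  // At the top of flash_attention.h, alongside the existing includes (not shown in this diff):
  #include <string>  // for std::string used by the Program constructors above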

Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,

This file was deleted.
