-
Notifications
You must be signed in to change notification settings - Fork 12.3k
vulkan: support softmax/FA batch and broadcast #14449
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
vulkan: support softmax/FA batch and broadcast #14449
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Feel free to merge to the target branch when review is ready.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code looks fine to me, the tests pass on AMD and Nvidia. A lot of FA tests fail on Intel, but that is already the case on master.
5b5e27e
into
ggml-org:gg/ggml-batch-soft-max-ops
@@ -652,6 +652,7 @@ struct vk_flash_attn_push_constants { | |||
uint32_t split_kv; | |||
uint32_t k_num; | |||
}; | |||
static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this 128 byte a hard limit of the size of the struct? I need to add nem3
for the change in #14505 and it will cross the limit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
128 is the minimum required by the Vulkan spec, although I think it went up to 256 with Vulkan 1.4. The Vulkan device database has a list: https://vulkan.gpuinfo.org/displaydevicelimit.php?name=maxPushConstantsSize&platform=all
If we want to support older implementations/drivers, it has to stay at or below 128 (or shaders that need more have to be disabled on devices that don't support them)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
128 bytes is the minimum required push constant size, so if we go over that then we would need to check for support and we would probably lose a lot of devices (see http://vulkan.gpuinfo.org/displaydevicelimit.php?name=maxPushConstantsSize&platform=all).
If you just need one more dword, there are some existing fields you could pack together, e.g. mask (which is just a bool) and maybe n_head_log2. Sadly, 16-bit push constants are not available everywhere so you'll probably need to pack them into a uint32_t.
* origin/master: llama : initial Mamba-2 support (ggml-org#9126) sync : ggml ggml : add version function to get lib version (ggml/1286) Set RPATH to "@loader_path" / "$ORIGIN" to ensure executables and dynamic libraries search for dependencies in their origin directory. (ggml-org#14309) CUDA: add softmax broadcast (ggml-org#14475) CUDA: broadcasting for FlashAttention mask (ggml-org#14500) vulkan: support softmax/FA batch and broadcast (ggml-org#14449) ggml : support bcast ggml_soft_max_ext, ggml_flash_attn_ext (ggml-org#14435) opencl : fix possible buffer overflow in dump_tensor (ggml-org#14490) simple-chat : fix context-exceeded condition (ggml-org#14494) opencl : skip empty nodes on cgraph compute (ggml-org#14491) opencl : update upscale to support align corners (ggml-org#14488) ci : add OpenCL to labeler workflow (ggml-org#14496) github : add OpenCL backend to issue templates (ggml-org#14492) ggml : Callback before abort (ggml-org#14481) ci : disable fast-math for Metal GHA CI (ggml-org#14478)
For #14435.