Conversation


@timmy-feng timmy-feng commented Nov 7, 2025

This PR allows the kernel to be compiled for only a single q tile. We enable this by default for split KV.

Previously, FA4 used a CTA tile size of 256 along the query dimension. The Blackwell tensor core op has an M dimension of 128, so a 256-wide tile constitutes two parallel tensor core pipelines. During inference, however, we never actually reach 128 queries per tile, even with GQA, so FA4 was wastefully performing twice the necessary computation.
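As a back-of-the-envelope sketch of the waste described above (the helper name and model are ours, not FA4 code): with a 256-wide q tile, a single decode query occupies 1 row out of 256 issued MMA rows; halving the tile to 128 doubles the useful fraction.

```python
# Illustrative model of CTA tile utilization during decode.
# MMA_M = 128 is the Blackwell tensor core M dimension mentioned above;
# a cta_tile_m of 256 therefore means two parallel MMA pipelines,
# both of which issue full-width work even when most rows are padding.
MMA_M = 128

def useful_fraction(q_len: int, cta_tile_m: int) -> float:
    """Fraction of the CTA's query rows that hold real queries."""
    assert cta_tile_m % MMA_M == 0
    return min(q_len, cta_tile_m) / cta_tile_m

# Decode with one query token:
#   useful_fraction(1, 256) -> 1/256 (two pipelines, both mostly padding)
#   useful_fraction(1, 128) -> 1/128 (one pipeline, half the wasted math)
```

This is why compiling for a single q tile recovers the factor of two during inference: the second 128-row pipeline never held real queries to begin with.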

Benchmarks

With this change, split KV now clearly beats an FA2 baseline. Here is a sweep across num splits:

[figure: q_stage_1 benchmark plot; data below]
benchmark_type,num_splits,B,Q,K,H,D,kvheads_per_group,arithmetic_intensity,avg_time,tflops,bw,correctness
standard,1,1,1,131072,32,128,1,0.999992370663676,1.8188637526247513,1.180673178461568,1.180682186283059,True
standard,2,1,1,131072,32,128,1,0.999992370663676,0.6147790486642788,3.49309829712936,3.493124947354405,True
standard,4,1,1,131072,32,128,1,0.999992370663676,0.3908908886900404,5.493818633626075,5.493860548135914,True
standard,8,1,1,131072,32,128,1,0.999992370663676,0.3996654222757412,5.373203505502126,5.373244499791566,True
standard,16,1,1,131072,32,128,1,0.999992370663676,0.40813399929365457,5.261712211471199,5.261752355149571,True
standard,32,1,1,131072,32,128,1,0.999992370663676,0.38960500021684,5.511950942120324,5.511992994968698,True
standard,64,1,1,131072,32,128,1,0.999992370663676,0.41058456134784693,5.230307834640313,5.230347738722304,True
standard,128,1,1,131072,32,128,1,0.999992370663676,0.44936073867275433,4.778974803946766,4.7790112646309995,True
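For readers unfamiliar with how the num_splits sweep above works: split KV partitions the KV cache across CTAs and merges the partial attention outputs with a numerically stable log-sum-exp combine. A minimal numpy sketch of that combine (illustrative reference code, not the kernel's implementation; function names are ours):

```python
import numpy as np

def attention_ref(q, k, v):
    # Reference single-pass attention for one head: q is (d,), k and v are (S, d).
    s = k @ q / np.sqrt(q.shape[0])
    p = np.exp(s - s.max())
    return (p / p.sum()) @ v

def split_kv_attention(q, k, v, num_splits):
    # Process the KV cache in num_splits chunks, keeping each chunk's
    # unnormalized partial output plus its row max and sum-of-exp,
    # then merge with a numerically stable log-sum-exp combine.
    d = q.shape[0]
    outs, maxes, sums = [], [], []
    for idx in np.array_split(np.arange(k.shape[0]), num_splits):
        s = k[idx] @ q / np.sqrt(d)
        m = s.max()
        e = np.exp(s - m)
        outs.append(e @ v[idx])  # unnormalized partial output
        maxes.append(m)
        sums.append(e.sum())
    m_global = max(maxes)
    scale = [np.exp(m - m_global) for m in maxes]
    numer = sum(o * c for o, c in zip(outs, scale))
    denom = sum(l * c for l, c in zip(sums, scale))
    return numer / denom
```

The combine is invariant to num_splits, which is why correctness holds across the whole sweep; the performance knee around num_splits=4 in the table above comes from occupancy, not math.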

Correctness

All tests, including split KV tests, continue to pass:

============================= test session starts ==============================
platform linux -- Python 3.12.1, pytest-8.4.2, pluggy-1.6.0
rootdir: /root
collected 2592 items

............................... [  1%]
............ssssssssssssssssss.......................................... [  4%]
............................................................ssssssssssss [  7%]
ssssss......ssssssssssssssssss.......................................... [  9%]
........................................................................ [ 12%]
....................................ssssssssssssssssss......ssssssssssss [ 15%]
ssssss......ssssssssssssssssss.......................................... [ 18%]
............ssssssssssssssssss.......................................... [ 21%]
............................... [ 23%]
........................................................................ [ 26%]
........................................................................ [ 29%]
........................ssssssssssss...... [ 32%]
..............................ssssssssssss.............................. [ 34%]
......ssssssssssss....................................ssssssssssss...... [ 37%]
.............................................................. [ 40%]
........................................................................ [ 43%]
........................................................................ [ 46%]
........................................................................ [ 48%]
........................................................................ [ 51%]
........................................................................ [ 54%]
........................................................................ [ 57%]
........................................................................ [ 59%]
........................................................................ [ 62%]
........................................................................ [ 65%]
........................................................................ [ 68%]
........................................................................ [ 71%]
........................................................................ [ 73%]
........................................................................ [ 76%]
........................................................................ [ 79%]
........................................................................ [ 82%]
........................................................................ [ 84%]
........................................................................ [ 87%]
........................................................................ [ 90%]
........................................................................ [ 93%]
................................................................... [ 96%]
........................................................................ [ 98%]
..............................                                           [100%]

=============================== warnings summary ===============================
tests/cute/test_flash_attn.py: 6358 warnings
  /usr/local/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py:60: DeprecationWarning: `make_fragment` is deprecated, use `make_rmem_tensor` instead
    res_or_list = opFunc(*args, **kwargs, loc=loc)

tests/cute/test_flash_attn.py: 37748 warnings
  /usr/local/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py:60: DeprecationWarning: cute.arch.exp2 is deprecated, use cute.math.exp2 with `fastmath=True` instead
    res_or_list = opFunc(*args, **kwargs, loc=loc)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======== 2418 passed, 174 skipped, 44106 warnings in 479.10s (0:07:59) =========

@tridao
Member

tridao commented Nov 10, 2025

Clarifying my understanding: if q_stage == 1, is there overlap between softmax and MMA?

@timmy-feng
Contributor Author

No, I didn't do anything special to make that happen.
