@foreverYoungGitHub foreverYoungGitHub commented Sep 22, 2025

Summary

Implements the missing flash_attn_varlen_qkvpacked_func function and FlashAttnVarlenQKVPackedFunc class in the hopper implementation (flash_attn_3) to achieve API compatibility with flash_attn. (Related issue: #1501)

What's Added

  • FlashAttnVarlenQKVPackedFunc class - PyTorch autograd function for variable-length sequences with QKV packed format
  • flash_attn_varlen_qkvpacked_func function - Public API wrapper maintaining flash_attn compatibility
  • Test suite - test_flash_attn_varlen_qkvpacked_output with parametrized testing
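
For readers unfamiliar with the variable-length packed layout, here is a minimal sketch of how the bookkeeping inputs are typically derived. It assumes the convention used by flash_attn's existing varlen functions: sequences are concatenated along a single token dimension, described by cumulative sequence lengths (`cu_seqlens`) and a `max_seqlen`, with a packed tensor of shape `(total_tokens, 3, nheads, headdim)`. The helper name `build_cu_seqlens` and the head sizes are hypothetical, and the exact signature of the hopper function added here may differ.

```python
from itertools import accumulate


def build_cu_seqlens(seqlens):
    """Return (cu_seqlens, max_seqlen, total_tokens) for a batch.

    Hypothetical helper for illustration only; not part of flash_attn.
    """
    if not seqlens or any(n <= 0 for n in seqlens):
        raise ValueError("seqlens must be a non-empty list of positive ints")
    # cu_seqlens[i] is the offset of sequence i in the packed token dimension;
    # the final entry equals the total number of tokens.
    cu_seqlens = [0, *accumulate(seqlens)]
    return cu_seqlens, max(seqlens), cu_seqlens[-1]


seqlens = [3, 5, 2]  # three sequences of different lengths
cu_seqlens, max_seqlen, total_tokens = build_cu_seqlens(seqlens)

# Assumed packed layout: (total_tokens, 3, nheads, headdim), where the
# size-3 axis holds Q, K and V. nheads=8 and headdim=64 are made-up values.
packed_shape = (total_tokens, 3, 8, 64)

# Tokens of sequence i occupy rows cu_seqlens[i]:cu_seqlens[i + 1].
slices = [(cu_seqlens[i], cu_seqlens[i + 1]) for i in range(len(seqlens))]
```

In a real call, `cu_seqlens` would usually be an int32 tensor on the same device as the packed QKV tensor; this sketch sticks to plain Python lists so the indexing logic is easy to check.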

Risks & Considerations

  • ⚠️ Numerical precision - Flash attention implementations can have subtle numerical differences; tolerances set conservatively

Test output:

Output max diff: 0.0078125
Output mean diff: 0.00017070770263671875
dQKV max diff: 0.0078125
dQKV mean diff: 0.00020885467529296875
.Output max diff: 0.0009765625
Output mean diff: 2.1457672119140625e-05
dQKV max diff: 0.0009765625
dQKV mean diff: 2.5987625122070312e-05
.Output max diff: 0.015625
Output mean diff: 0.0002117156982421875
dQKV max diff: 0.03125
dQKV mean diff: 0.00026702880859375
.Output max diff: 0.0009765625
Output mean diff: 2.6464462280273438e-05
dQKV max diff: 0.001953125
dQKV mean diff: 3.36766242980957e-05
.Output max diff: 0.00390625
Output mean diff: 0.0001697540283203125
dQKV max diff: 0.015625
dQKV mean diff: 0.0002079010009765625
.Output max diff: 0.00048828125
Output mean diff: 2.1278858184814453e-05
dQKV max diff: 0.0009765625
dQKV mean diff: 2.5987625122070312e-05
.Output max diff: 0.015625
Output mean diff: 0.0002117156982421875
dQKV max diff: 0.03125
dQKV mean diff: 0.0002651214599609375
.Output max diff: 0.0009765625
Output mean diff: 2.6404857635498047e-05
dQKV max diff: 0.00390625
dQKV mean diff: 3.421306610107422e-05
.Output max diff: 0.0078125
Output mean diff: 0.00017261505126953125
dQKV max diff: 0.0078125
dQKV mean diff: 0.00020599365234375
.Output max diff: 0.0009765625
Output mean diff: 2.1338462829589844e-05
dQKV max diff: 0.0009765625
dQKV mean diff: 2.574920654296875e-05
.Output max diff: 0.015625
Output mean diff: 0.000213623046875
dQKV max diff: 0.015625
dQKV mean diff: 0.0002689361572265625
.Output max diff: 0.001953125
Output mean diff: 2.6524066925048828e-05
dQKV max diff: 0.00390625
dQKV mean diff: 3.325939178466797e-05
.Output max diff: 0.00390625
Output mean diff: 0.00018215179443359375
dQKV max diff: 0.015625
dQKV mean diff: 0.00021839141845703125
.Output max diff: 0.00048828125
Output mean diff: 2.282857894897461e-05
dQKV max diff: 0.001953125
dQKV mean diff: 2.7120113372802734e-05
.Output max diff: 0.015625
Output mean diff: 0.0002288818359375
dQKV max diff: 0.03125
dQKV mean diff: 0.0002880096435546875
.Output max diff: 0.001953125
Output mean diff: 2.8789043426513672e-05
dQKV max diff: 0.00390625
dQKV mean diff: 3.534555435180664e-05
.Output max diff: 0.0078125
Output mean diff: 0.00018310546875
dQKV max diff: 0.0078125
dQKV mean diff: 0.000217437744140625
.Output max diff: 0.00048828125
Output mean diff: 2.2649765014648438e-05
dQKV max diff: 0.0009765625
dQKV mean diff: 2.7477741241455078e-05
.Output max diff: 0.015625
Output mean diff: 0.00022983551025390625
dQKV max diff: 0.03125
dQKV mean diff: 0.0002880096435546875
.Output max diff: 0.001953125
Output mean diff: 2.86102294921875e-05
dQKV max diff: 0.00390625
dQKV mean diff: 3.5881996154785156e-05
.Output max diff: 0.00390625
Output mean diff: 0.00014495849609375
dQKV max diff: 0.0078125
dQKV mean diff: 0.00016689300537109375
.Output max diff: 0.00048828125
Output mean diff: 1.811981201171875e-05
dQKV max diff: 0.0009765625
dQKV mean diff: 2.086162567138672e-05
.Output max diff: 0.015625
Output mean diff: 0.0001983642578125
dQKV max diff: 0.03125
dQKV mean diff: 0.00023174285888671875
.Output max diff: 0.0009765625
Output mean diff: 2.4974346160888672e-05
dQKV max diff: 0.00390625
dQKV mean diff: 2.8789043426513672e-05
.Output max diff: 0.00390625
Output mean diff: 0.000148773193359375
dQKV max diff: 0.0078125
dQKV mean diff: 0.00017070770263671875
.Output max diff: 0.00048828125
Output mean diff: 1.8477439880371094e-05
dQKV max diff: 0.0009765625
dQKV mean diff: 2.1338462829589844e-05
.Output max diff: 0.015625
Output mean diff: 0.0002079010009765625
dQKV max diff: 0.03125
dQKV mean diff: 0.0002384185791015625
.Output max diff: 0.0009765625
Output mean diff: 2.580881118774414e-05
dQKV max diff: 0.00390625
dQKV mean diff: 2.9802322387695312e-05
.Output max diff: 0.00390625
Output mean diff: 0.000148773193359375
dQKV max diff: 0.0078125
dQKV mean diff: 0.00016880035400390625
.Output max diff: 0.00048828125
Output mean diff: 1.8537044525146484e-05
dQKV max diff: 0.0009765625
dQKV mean diff: 2.1159648895263672e-05
.Output max diff: 0.015625
Output mean diff: 0.00020694732666015625
dQKV max diff: 0.03125
dQKV mean diff: 0.00023746490478515625
.Output max diff: 0.001953125
Output mean diff: 2.580881118774414e-05
dQKV max diff: 0.00390625
dQKV mean diff: 2.9742717742919922e-05
.Output max diff: 0.00390625
Output mean diff: 0.000148773193359375
dQKV max diff: 0.0078125
dQKV mean diff: 0.0001735687255859375
.Output max diff: 0.00048828125
Output mean diff: 1.8596649169921875e-05
dQKV max diff: 0.0009765625
dQKV mean diff: 2.1755695343017578e-05
.Output max diff: 0.015625
Output mean diff: 0.00021266937255859375
dQKV max diff: 0.03125
dQKV mean diff: 0.000240325927734375
.Output max diff: 0.001953125
Output mean diff: 2.6524066925048828e-05
dQKV max diff: 0.00390625
dQKV mean diff: 3.0279159545898438e-05
.Output max diff: 0.00390625
Output mean diff: 0.00014781951904296875
dQKV max diff: 0.0078125
dQKV mean diff: 0.0001735687255859375
.Output max diff: 0.00048828125
Output mean diff: 1.8358230590820312e-05
dQKV max diff: 0.0009765625
dQKV mean diff: 2.1576881408691406e-05
.Output max diff: 0.015625
Output mean diff: 0.0002117156982421875
dQKV max diff: 0.03125
dQKV mean diff: 0.00024318695068359375
.Output max diff: 0.001953125
Output mean diff: 2.6464462280273438e-05
dQKV max diff: 0.00390625
dQKV mean diff: 3.0279159545898438e-05
.Output max diff: 0.00390625
Output mean diff: 0.00011110305786132812
dQKV max diff: 0.00390625
dQKV mean diff: 0.00012683868408203125
.Output max diff: 0.00048828125
Output mean diff: 1.3887882232666016e-05
dQKV max diff: 0.00048828125
dQKV mean diff: 1.5854835510253906e-05
.Output max diff: 0.015625
Output mean diff: 0.00016880035400390625
dQKV max diff: 0.03125
dQKV mean diff: 0.00018405914306640625
.Output max diff: 0.0009765625
Output mean diff: 2.1338462829589844e-05
dQKV max diff: 0.001953125
dQKV mean diff: 2.3066997528076172e-05
.Output max diff: 0.00390625
Output mean diff: 0.0001125335693359375
dQKV max diff: 0.0078125
dQKV mean diff: 0.00012874603271484375
.Output max diff: 0.00048828125
Output mean diff: 1.4007091522216797e-05
dQKV max diff: 0.0009765625
dQKV mean diff: 1.609325408935547e-05
.Output max diff: 0.015625
Output mean diff: 0.00017261505126953125
dQKV max diff: 0.03125
dQKV mean diff: 0.00018978118896484375
.Output max diff: 0.001953125
Output mean diff: 2.1457672119140625e-05
dQKV max diff: 0.00390625
dQKV mean diff: 2.3663043975830078e-05
.Output max diff: 0.00390625
Output mean diff: 0.0001125335693359375
dQKV max diff: 0.0078125
dQKV mean diff: 0.0001277923583984375
.Output max diff: 0.00048828125
Output mean diff: 1.4066696166992188e-05
dQKV max diff: 0.00048828125
dQKV mean diff: 1.5974044799804688e-05
.Output max diff: 0.015625
Output mean diff: 0.000171661376953125
dQKV max diff: 0.03125
dQKV mean diff: 0.0001888275146484375
.Output max diff: 0.0009765625
Output mean diff: 2.1398067474365234e-05
dQKV max diff: 0.00390625
dQKV mean diff: 2.3305416107177734e-05
.Output max diff: 0.00390625
Output mean diff: 0.00010919570922851562
dQKV max diff: 0.00390625
dQKV mean diff: 0.000125885009765625
.Output max diff: 0.00048828125
Output mean diff: 1.3649463653564453e-05
dQKV max diff: 0.0009765625
dQKV mean diff: 1.5676021575927734e-05
.Output max diff: 0.015625
Output mean diff: 0.00016880035400390625
dQKV max diff: 0.03125
dQKV mean diff: 0.00018310546875
.Output max diff: 0.001953125
Output mean diff: 2.104043960571289e-05
dQKV max diff: 0.001953125
dQKV mean diff: 2.300739288330078e-05
.Output max diff: 0.00390625
Output mean diff: 0.0001087188720703125
dQKV max diff: 0.00390625
dQKV mean diff: 0.00012493133544921875
.Output max diff: 0.00048828125
Output mean diff: 1.3530254364013672e-05
dQKV max diff: 0.0009765625
dQKV mean diff: 1.5676021575927734e-05
.Output max diff: 0.015625
Output mean diff: 0.00016689300537109375
dQKV max diff: 0.03125
dQKV mean diff: 0.0001850128173828125
.Output max diff: 0.001953125
Output mean diff: 2.0802021026611328e-05
dQKV max diff: 0.00390625
dQKV mean diff: 2.3066997528076172e-05
.Output max diff: 0.001953125
Output mean diff: 8.344650268554688e-05
dQKV max diff: 0.0078125
dQKV mean diff: 9.34600830078125e-05
.Output max diff: 0.000244140625
Output mean diff: 1.043081283569336e-05
dQKV max diff: 0.0009765625
dQKV mean diff: 1.1682510375976562e-05
.Output max diff: 0.015625
Output mean diff: 0.000133514404296875
dQKV max diff: 0.03125
dQKV mean diff: 0.000141143798828125
.Output max diff: 0.0009765625
Output mean diff: 1.6748905181884766e-05
dQKV max diff: 0.001953125
dQKV mean diff: 1.7702579498291016e-05
.Output max diff: 0.001953125
Output mean diff: 8.20159912109375e-05
dQKV max diff: 0.00390625
dQKV mean diff: 9.250640869140625e-05
.Output max diff: 0.000244140625
Output mean diff: 1.0251998901367188e-05
dQKV max diff: 0.0009765625
dQKV mean diff: 1.1563301086425781e-05
.Output max diff: 0.015625
Output mean diff: 0.00013256072998046875
dQKV max diff: 0.03125
dQKV mean diff: 0.0001392364501953125
.Output max diff: 0.0009765625
Output mean diff: 1.6570091247558594e-05
dQKV max diff: 0.00390625
dQKV mean diff: 1.7404556274414062e-05
.Output max diff: 0.001953125
Output mean diff: 8.249282836914062e-05
dQKV max diff: 0.00390625
dQKV mean diff: 9.250640869140625e-05
.Output max diff: 0.000244140625
Output mean diff: 1.0311603546142578e-05
dQKV max diff: 0.0009765625
dQKV mean diff: 1.1563301086425781e-05
.Output max diff: 0.015625
Output mean diff: 0.00013256072998046875
dQKV max diff: 0.03125
dQKV mean diff: 0.0001392364501953125
.Output max diff: 0.001953125
Output mean diff: 1.6570091247558594e-05
dQKV max diff: 0.001953125
dQKV mean diff: 1.7404556274414062e-05
.Output max diff: 0.001953125
Output mean diff: 8.153915405273438e-05
dQKV max diff: 0.0078125
dQKV mean diff: 9.202957153320312e-05
.Output max diff: 0.000244140625
Output mean diff: 1.0192394256591797e-05
dQKV max diff: 0.0009765625
dQKV mean diff: 1.1563301086425781e-05
.Output max diff: 0.015625
Output mean diff: 0.00013256072998046875
dQKV max diff: 0.03125
dQKV mean diff: 0.00014019012451171875
.Output max diff: 0.001953125
Output mean diff: 1.6570091247558594e-05
dQKV max diff: 0.00390625
dQKV mean diff: 1.7404556274414062e-05
.Output max diff: 0.001953125
Output mean diff: 8.106231689453125e-05
dQKV max diff: 0.00390625
dQKV mean diff: 9.202957153320312e-05
.Output max diff: 0.000244140625
Output mean diff: 1.0132789611816406e-05
dQKV max diff: 0.00048828125
dQKV mean diff: 1.150369644165039e-05
.Output max diff: 0.015625
Output mean diff: 0.0001316070556640625
dQKV max diff: 0.03125
dQKV mean diff: 0.0001392364501953125
.Output max diff: 0.001953125
Output mean diff: 1.633167266845703e-05
dQKV max diff: 0.00390625
dQKV mean diff: 1.7523765563964844e-05
..
81 passed in 4.07s
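
The "max diff" and "mean diff" figures in the log above are presumably the maximum and mean absolute element-wise differences between the new implementation and a reference. A minimal sketch of that metric, assuming flat sequences of floats (the function name `diff_stats` is hypothetical, and the actual test likely operates on tensors):

```python
def diff_stats(actual, reference):
    """Return (max_abs_diff, mean_abs_diff) between two equal-length sequences.

    Hypothetical helper for illustration; not taken from the test suite.
    """
    if len(actual) != len(reference) or not actual:
        raise ValueError("inputs must be non-empty and of equal length")
    diffs = [abs(a - r) for a, r in zip(actual, reference)]
    return max(diffs), sum(diffs) / len(diffs)


# Toy data: one element differs by exactly 2**-7 (0.0078125), the same
# magnitude as several of the maxima reported in the log.
actual = [0.5, 1.0, -0.25]
reference = [0.5, 1.0078125, -0.25]
max_diff, mean_diff = diff_stats(actual, reference)
```

Every maximum in the log is a power of two (for example 0.0078125 = 2^-7 and 0.03125 = 2^-5). That is what one would expect when the compared values are stored in a reduced-precision format, although the description does not state which dtypes were tested.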
