Iterative implementations of Flash Attention 2 in CUDA, progressively optimized for performance. Tested on an NVIDIA A10G GPU on an Amazon EC2 g5.xlarge instance.
Version | Optimization | Code | Duration | Cycles | Compute Throughput % | Memory Throughput % | Notes |
---|---|---|---|---|---|---|---|
V1 | Baseline | Link | 2.08s | 2,741,237,357 | 0.27% | 1.19% | Estimated speedup 98.75% since only 1 of 80 SMs is being used. Compiled with nvcc -o flash_attention2_v1 flash_attention2_v1.cu -lineinfo. |
V2 | Parallelized work over multiple thread blocks | Link | 32.52ms | 42,890,564 | 17.25% | 76.00% | Uncoalesced shared accesses est. speedup 86.73%, shared load bank conflicts est. speedup 78.20%, L1TEX local store access pattern est. speedup 74.97%. Matrix multiplication is the primary memory overhead. |
V3 | Matrix multiplication computes (A @ B) instead of (A @ B.T) | Link | 8.64ms | 11,394,958 | 64.93% | 64.93% | L1TEX local store access pattern est. speedup 55.43%; memory I/O causing warp stalls. matrix_block_load_transpose() seems to have a big memory overhead. |
V4 | Faster matrix multiplication using registers, based on https://siboehm.com/articles/22/CUDA-MMM | Link | 38.32ms | 50,579,102 | 45.72% | 45.72% | Why is this slower than V3? Seems to be using local memory, not registers. |
V5 | (builds off V3) Add padding to the matrix load transpose to attempt to reduce smem bank store conflicts (see the padding sketch below the table) | Link | 8.65ms | 11,411,135 | 64.84% | 64.84% | Matrix multiplication needs to be improved. li_update and mi_update also have excessive L1 wavefronts. Padding doesn't seem to have done much... |
V6 | The V4 matmul code is correct but uses local memory (slow) because the indexing isn't computable at compile time (ref); see the register-tiling sketch below the table | Link | 6.32ms | 8,342,705 | 55.81% | 55.81% | Threads per block reduced to 512 to allow for more register space (only 64K registers per thread block, according to the technical specifications). However, matmul is much faster, so this is worth doing. New compile command: nvcc -o flash_attention2_v6 flash_attention2_v6.cu -lineinfo -Xptxas -v -O3 -maxrregcount 128 to utilize as many registers as possible. Tried various blocktiling sizes (constant T); T=4 has the best performance. This builds off V4, since padding didn't do much. |
V7 | Use cooperative_groups::memcpy_async to load from HBM to shared memory (see the memcpy_async sketch below the table) | Link | 6.69ms | 8,833,354 | 52.65% | 52.65% | Async memory seems to be slower; maybe there isn't enough time between operations to reap the benefits. Could try double buffering in the matrix multiply to get the async memory load benefits. |
V8 | Use warp reduction techniques in the rowmax op of mi_update and the rowsum op of li_update (this also hopefully reduces bank conflicts); see the warp-reduction sketch below the table | Link | 6.08ms | 8,105,621 | 59.97% | 59.97% | |
V9 | Fuse the divide_by_scalar, mi_update, si_to_pi, and li_update operations into a single function (saves repeated smem loads and syncthreads between ops) | Link | 5.61ms | 7,400,516 | 62.72% | 62.72% | |
V10 | Remove bounds checking / last-block logic; assume the parameters used are N=M=8192, d=32 | Link | 5.53ms | 7,291,137 | 63.53% | 63.53% | |
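
As a rough illustration of the V5 padding idea, the sketch below uses an assumed 32x32 tile and an illustrative function name (not the repo's actual matrix_block_load_transpose): padding the leading dimension of the shared-memory tile by one element spreads the transposed stores across different banks.

```cuda
// Minimal sketch of the V5 padding idea, assuming a 32x32 tile and one thread
// per element; the function and constant names are illustrative, not the
// repo's actual identifiers.
#define TILE 32

__device__ void load_tile_transposed(const float* __restrict__ src, int ld,
                                     float dst[TILE][TILE + 1])  // +1 column of padding
{
    int row = threadIdx.y;
    int col = threadIdx.x;
    // Store transposed. Without the +1 padding, a warp writing dst[0][row],
    // dst[1][row], ..., dst[31][row] strides by 32 floats and lands in a
    // single shared-memory bank (a 32-way conflict); the extra column makes
    // the stride 33, so each store hits a different bank.
    dst[col][row] = src[row * ld + col];
}
```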
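The register blocktiling that V4 attempted and V6 fixed can be sketched roughly as follows, in the style of the linked siboehm article; the tile parameters BK, BN, T, the layout, and the function name are assumptions rather than the repo's actual code. The key detail is that the accumulator is only ever indexed through fully unrolled loops with compile-time bounds, which is what lets nvcc keep it in registers instead of spilling to local memory.

```cuda
// Sketch of a register-blocktiling inner loop (tile names BK, BN, T and the
// row-major layout are illustrative assumptions). Each thread accumulates a
// T x T patch of the output entirely in registers; every index into acc,
// regA, and regB comes from an unrolled loop with compile-time bounds, so
// the compiler does not demote the arrays to local memory.
template <int BK, int BN, int T>
__device__ void matmul_thread_tile(const float* __restrict__ As,  // BM x BK smem tile (row-major)
                                   const float* __restrict__ Bs,  // BK x BN smem tile (row-major)
                                   int threadRow, int threadCol,  // which T x T patch this thread owns
                                   float acc[T][T])
{
    float regA[T];
    float regB[T];
    for (int k = 0; k < BK; ++k) {
        #pragma unroll
        for (int i = 0; i < T; ++i) regA[i] = As[(threadRow * T + i) * BK + k];
        #pragma unroll
        for (int j = 0; j < T; ++j) regB[j] = Bs[k * BN + threadCol * T + j];
        #pragma unroll
        for (int i = 0; i < T; ++i) {
            #pragma unroll
            for (int j = 0; j < T; ++j) {
                acc[i][j] += regA[i] * regB[j];  // outer product of the two register strips
            }
        }
    }
}
```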
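A minimal sketch of the cooperative_groups::memcpy_async pattern tried in V7 is below; the kernel name, tile size, and the work placed between issue and wait are assumptions. The copy is issued collectively by the block and can overlap independent work until cg::wait is called.

```cuda
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>

namespace cg = cooperative_groups;

// Minimal sketch of the cooperative_groups::memcpy_async pattern tried in V7.
// The kernel name, tile size, and what happens between issue and wait are
// illustrative assumptions, not the repo's actual code.
__global__ void async_tile_load(const float* __restrict__ gmem)
{
    __shared__ float tile[1024];

    cg::thread_block block = cg::this_thread_block();

    // The whole block issues one asynchronous global->shared copy;
    // no per-thread index arithmetic is needed.
    cg::memcpy_async(block, tile, gmem, sizeof(float) * 1024);

    // ... independent register work could overlap the copy here
    //     (the V7 notes suggest there wasn't enough of it to hide the latency) ...

    cg::wait(block);   // wait until the copy has landed in shared memory
    // tile[] is now safe to read from every thread in the block.
}
```

Double buffering, as the V7 notes suggest, would alternate two such shared-memory tiles so the copy into one overlaps the matmul on the other.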
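The warp reductions used in V8 for the rowmax of mi_update and the rowsum of li_update can be sketched as below, assuming a full warp cooperates on one row; the helper names are illustrative. __shfl_xor_sync exchanges values lane-to-lane through registers, so the reduction touches no shared memory and avoids bank conflicts.

```cuda
// Sketch of warp-level reductions like the ones V8 uses in mi_update (rowmax)
// and li_update (rowsum). Helper names are illustrative. Assumes a full warp
// (32 active lanes), each holding one partial value for the row.
__device__ float warp_reduce_max(float v)
{
    for (int offset = 16; offset > 0; offset >>= 1)
        v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, offset));
    return v;   // every lane ends up holding the row maximum
}

__device__ float warp_reduce_sum(float v)
{
    for (int offset = 16; offset > 0; offset >>= 1)
        v += __shfl_xor_sync(0xffffffffu, v, offset);
    return v;   // every lane ends up holding the row sum
}
```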
The Python script `verify_output.py` computes the attention operation on the same input matrices in PyTorch (using the GPU), then verifies that the output generated by the CUDA code matches it.
Steps for running the correctness test:
- Compile the CUDA program.
- Run the CUDA program with a filepath argument (ex: `./flash_attention2_v1 result.out`). The output matrix O will be saved to that file.
- Run the verify_output script with the same filepath (ex: `python3 verify_output.py result.out`).
Future work:
- Utilize tensor cores for matrix multiplication and switch to fp16.
- Extend to multi-head attention.