
Flash Attention 2 in CUDA

Iterative implementations of Flash Attention 2 in CUDA, progressively optimized for performance. Tested on an Nvidia A10G GPU on an Amazon EC2 g5.xlarge instance.

Worklog (optimizing with Nsight Compute)

| Version | Optimization | Code | Duration | Cycles | Compute Throughput % | Memory Throughput % | Notes |
| --- | --- | --- | --- | --- | --- | --- | --- |
| V1 | Baseline | Link | 2.08s | 2,741,237,357 | 0.27% | 1.19% | Estimated speedup 98.75%, since only 1 of 80 SMs is being used. Compiled with `nvcc -o flash_attention2_v1 flash_attention2_v1.cu -lineinfo`. |
| V2 | Parallelize work over multiple thread blocks | Link | 32.52ms | 42,890,564 | 17.25% | 76.00% | Uncoalesced shared accesses est. speedup 86.73%, shared load bank conflicts est. speedup 78.20%, L1TEX local store access pattern est. speedup 74.97%. Matrix multiplication is the primary memory overhead. |
| V3 | Matrix multiplication computes (A @ B) instead of (A @ B.T) | Link | 8.64ms | 11,394,958 | 64.93% | 64.93% | L1TEX local store access pattern est. speedup 55.43%; memory I/O is causing warp stalls. `matrix_block_load_transpose()` seems to have a big memory overhead. |
| V4 | Faster matrix multiplication using registers, based on https://siboehm.com/articles/22/CUDA-MMM | Link | 38.32ms | 50,579,102 | 45.72% | 45.72% | Why is this slower than V3? It seems to be using local memory, not registers. |
| V5 (builds off V3) | Add padding to the matrix load transpose to attempt to reduce smem bank store conflicts | Link | 8.65ms | 11,411,135 | 64.84% | 64.84% | Matrix multiplication needs to be improved. `li_update` and `mi_update` also have excessive L1 wavefronts. Padding doesn't seem to have done much... |
| V6 (builds off V4) | The V4 matmul code is correct but uses local memory (slow) because the indexing isn't computable at compile time (ref) | Link | 6.32ms | 8,342,705 | 55.81% | 55.81% | Threads per block reduced to 512 to allow for more register space (only 64K registers per thread block, according to the technical specifications). However, matmul is much faster, so this is worth doing. New compile command: `nvcc -o flash_attention2_v6 flash_attention2_v6.cu -lineinfo -Xptxas -v -O3 -maxrregcount 128` to utilize as many registers as possible. Tried various blocktiling sizes (constant T); T=4 has the best performance. Padding (V5) didn't do much. See the register-tiling sketch below the table. |
| V7 | Use `cooperative_groups::memcpy_async` to load HBM to shared memory | Link | 6.69ms | 8,833,354 | 52.65% | 52.65% | Async memory seems to be slower; maybe there is not enough time between operations to reap the benefits. Could try double buffering in the matrix multiply to get the async memory load benefits. |
| V8 | Use warp reduction techniques in the rowmax op of `mi_update` and the rowsum op of `li_update` (this also hopefully reduces bank conflicts) | Link | 6.08ms | 8,105,621 | 59.97% | 59.97% | See the warp-reduction sketch below the table. |
| V9 | Fuse the `divide_by_scalar`, `mi_update`, `si_to_pi`, and `li_update` operations into a single function (saves repeated smem loads and `__syncthreads()` calls between ops) | Link | 5.61ms | 7,400,516 | 62.72% | 62.72% | |
| V10 | Remove bounds checking / last-block logic; assume the parameters used are N=M=8192, d=32 | Link | 5.53ms | 7,291,137 | 63.53% | 63.53% | |
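
The V4/V6 rows hinge on keeping the per-thread accumulators in registers. As a rough illustration (not the repository's actual kernel), the sketch below shows the register-blocktiling idea under these assumptions: `Qi` is a Br x d tile of Q and `Kj` a Bc x d tile of K, both already in shared memory; each thread owns a T x T sub-tile of the Br x Bc score block; the names `tile_matmul`, `Sij`, `qreg`, and `kreg` are made up for the example. Because every loop bound is a compile-time constant and fully unrolled, the array indices are resolvable at compile time, so the accumulators can live in registers instead of spilling to local memory (the V4 problem).

```cuda
// Hypothetical register-blocktiling sketch (illustrative names, not the repo's kernel).
constexpr int T = 4;  // per-thread blocktile size; T=4 was reported fastest in V6

__device__ void tile_matmul(const float* Qi,  // Br x d tile of Q in shared memory
                            const float* Kj,  // Bc x d tile of K in shared memory
                            float* Sij,       // Br x Bc score tile to fill
                            int d, int Bc) {
    const int tr = threadIdx.y;  // which T-row block this thread owns
    const int tc = threadIdx.x;  // which T-column block this thread owns

    float acc[T][T] = {};    // accumulators stay in registers: the acc/qreg/kreg
    float qreg[T], kreg[T];  // indices are compile-time constants once unrolled

    for (int k = 0; k < d; ++k) {
        #pragma unroll
        for (int i = 0; i < T; ++i) qreg[i] = Qi[(tr * T + i) * d + k];
        #pragma unroll
        for (int j = 0; j < T; ++j) kreg[j] = Kj[(tc * T + j) * d + k];

        // Outer product of the two register vectors.
        #pragma unroll
        for (int i = 0; i < T; ++i) {
            #pragma unroll
            for (int j = 0; j < T; ++j) {
                acc[i][j] += qreg[i] * kreg[j];
            }
        }
    }

    #pragma unroll
    for (int i = 0; i < T; ++i) {
        #pragma unroll
        for (int j = 0; j < T; ++j) {
            Sij[(tr * T + i) * Bc + tc * T + j] = acc[i][j];
        }
    }
}
```

A full kernel would launch this with a (Bc/T, Br/T) thread block and wrap the online-softmax updates around it.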

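Similarly, here is a minimal sketch of a V8-style warp-shuffle reduction, assuming one 32-lane warp cooperates on each Bc-wide row of the score tile; `warp_reduce_max`, `warp_reduce_sum`, and `row_max` are illustrative names rather than the repo's functions:

```cuda
#include <cfloat>  // FLT_MAX

// Butterfly reduction across a full warp: every lane ends up with the maximum,
// using only register shuffles (no shared-memory traffic, so no bank conflicts).
__device__ __forceinline__ float warp_reduce_max(float val) {
    for (int offset = 16; offset > 0; offset >>= 1)
        val = fmaxf(val, __shfl_xor_sync(0xffffffffu, val, offset));
    return val;
}

// Same pattern for the rowsum needed by li_update.
__device__ __forceinline__ float warp_reduce_sum(float val) {
    for (int offset = 16; offset > 0; offset >>= 1)
        val += __shfl_xor_sync(0xffffffffu, val, offset);
    return val;
}

// Hypothetical rowmax over one Bc-wide row of the score tile in shared memory:
// each lane scans a strided slice of the row, then the partials are combined.
__device__ float row_max(const float* s_row, int Bc) {
    const int lane = threadIdx.x & 31;
    float m = -FLT_MAX;
    for (int c = lane; c < Bc; c += 32)
        m = fmaxf(m, s_row[c]);
    return warp_reduce_max(m);  // all 32 lanes return the row maximum
}
```
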
Testing

The Python script `verify_output.py` computes the attention operation on the same input matrices in PyTorch (on the GPU), then verifies that the output generated by the CUDA code matches it.

Steps for running the correctness test:

  1. Compile the CUDA program.
  2. Run the compiled binary with a filepath argument (ex: `./flash_attention2_v1 result.out`). The output matrix O will be saved to that file.
  3. Run the verify script with the same filepath (ex: `python3 verify_output.py result.out`).

Future

  • Utilize tensor cores for matrix multiplication and switch to fp16
  • Extend to multi-head attention
