-
Notifications
You must be signed in to change notification settings - Fork 121
Expand file tree
/
Copy pathbench_attention.py
More file actions
124 lines (100 loc) · 4.06 KB
/
bench_attention.py
File metadata and controls
124 lines (100 loc) · 4.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
from math import ceil, sqrt
from conftest import dtype_id, shape_id
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import sdpa_kernel, SDPBackend
import cuda.tile as ct
import pytest
import torch
from util import estimate_bench_iter
from kernels.attention import fmha_kernel
def qkv_id(qkv_shape: tuple[tuple[int, ...], tuple[int, ...]]) -> str:
    """Build a readable pytest id from a (q_shape, kv_shape) pair.

    The id is prefixed with "decode-" when the query length is 1 and
    "prefill-" otherwise, followed by the shape_id of the combined
    (B, q_heads, kv_heads, q_len, kv_len, head_dim) tuple.
    """
    q_shape, kv_shape = qkv_shape
    mode = "decode-" if q_shape[2] == 1 else "prefill-"
    batch, q_heads, q_len, head_dim = q_shape
    _, kv_heads, kv_len, _ = kv_shape
    return mode + shape_id((batch, q_heads, kv_heads, q_len, kv_len, head_dim))
@pytest.fixture(
    params=[
        # Each param is a ((B, H, L, D) query shape, (B, H, L, D) key/value
        # shape) pair.  q_len == kv_len selects the prefill path, q_len == 1
        # the decode path; fewer kv heads than query heads exercises
        # grouped-query attention (GQA).
        # B, H, L, D
        ((1, 32, 1024, 128), (1, 32, 1024, 128)),  # prefill
        ((1, 32, 1024, 128), (1, 8, 1024, 128)),  # prefill + gqa
        ((1, 32, 8192, 128), (1, 32, 8192, 128)),  # prefill
        ((1, 32, 8192, 128), (1, 8, 8192, 128)),  # prefill + gqa
        ((1, 32, 1, 128), (1, 32, 1024, 128)),  # decode
        ((8, 32, 1, 128), (8, 32, 1024, 128)),  # decode
        ((1, 32, 1, 128), (1, 8, 1024, 128)),  # decode + gqa
        ((8, 32, 1, 128), (8, 8, 1024, 128)),  # decode + gqa
    ],
    ids=qkv_id)
def qkv_shape(request):
    """Parametrized (q_shape, kv_shape) pair covering prefill/decode x MHA/GQA."""
    return request.param
@pytest.fixture(params=[torch.float16, torch.bfloat16], ids=dtype_id)
def dtype(request):
    """Half-precision dtypes to benchmark: fp16 and bf16."""
    return request.param
@pytest.mark.benchmark(group='attention')
def bench_fmha(qkv_shape, dtype, backend, benchmark):
    """Benchmark one attention backend and record flop/byte counts.

    Parameters
    ----------
    qkv_shape : fixture
        ((B, Hq, Lq, D), (B, Hk, Lk, D)) query / key-value shape pair.
    dtype : fixture
        torch dtype used for Q, K, V (fp16 or bf16).
    backend : fixture
        Callable ``backend(q, k, v, o, is_causal, enable_gqa)`` writing the
        attention output into ``o`` (presumably supplied by conftest, e.g.
        one of the *_fmha functions in this file — confirm).
    benchmark : fixture
        pytest-benchmark fixture.
    """
    q_shape, kv_shape = qkv_shape
    q = torch.randn(q_shape, dtype=dtype, device='cuda')
    k = torch.randn(kv_shape, dtype=dtype, device='cuda')
    v = torch.randn(kv_shape, dtype=dtype, device='cuda')
    o = torch.empty_like(q)
    ref = torch.empty_like(q)
    # Convention in this file: equal q/kv lengths -> prefill (causal mask);
    # unequal -> decode (no mask).  Mismatched head counts -> GQA.
    is_causal = q_shape[2] == kv_shape[2]
    enable_gqa = q_shape[1] != kv_shape[1]
    # Sanity-check the backend against the flash-attention reference before
    # timing anything.
    backend(q, k, v, o, is_causal, enable_gqa)
    ref_fmha(q, k, v, ref, is_causal, enable_gqa)
    torch.testing.assert_close(o, ref, atol=1e-2, rtol=5e-2)
    torch.cuda.synchronize()
    warmup_rounds, iterations, rounds = estimate_bench_iter(
        backend, (q, k, v, o, is_causal, enable_gqa),
    )
    benchmark.pedantic(
        backend, (q, k, v, o, is_causal, enable_gqa),
        rounds=rounds, warmup_rounds=warmup_rounds, iterations=iterations,
    )
    B, H, Lq, D = q.shape
    Lk = k.shape[2]
    # first gemm  mma(q, k): 2 * B * H * Lq * Lk * D
    # second gemm mma(p, v): 2 * B * H * Lq * Lk * D
    # BUGFIX: the original used the query length for BOTH sequence
    # dimensions (4*B*H*L*L*D with L = q_len), which undercounts the decode
    # cases (Lq == 1, Lk == context length) by a factor of Lk.
    flop_count = 4 * B * H * Lq * Lk * D
    if is_causal:
        # Causal masking skips roughly the upper triangle of each gemm.
        flop_count //= 2
    bytes_rw = sum(t.numel() * t.dtype.itemsize for t in (q, k, v, o))
    benchmark.extra_info['flop_count'] = flop_count
    benchmark.extra_info['bytes_rw'] = bytes_rw
def cutile_fmha(q, k, v, o, is_causal, enable_gqa):
    """Launch the cuda.tile fused multi-head attention kernel.

    Writes the attention output into ``o`` in place.  ``enable_gqa`` is
    accepted for interface uniformity with the other backends but is not
    read here; GQA is derived from the head-count ratio instead.
    Tensors are laid out as (batch, heads, seq_len, head_dim).
    """
    b, qh, q_len, d = q.shape
    _, kh, k_len, _ = k.shape
    # Standard 1/sqrt(head_dim) softmax scaling.
    qk_scale = 1 / sqrt(d)
    # Larger query tiles for the causal (prefill) case; smaller for decode.
    # NOTE(review): presumably tuned empirically — confirm against the
    # kernel's autotuning assumptions.
    TILE_M, TILE_N = (256, 128) if is_causal else (64, 128)
    # Queries per kv head (GQA group size); 1 for plain MHA.
    query_group_size = qh // kh
    # One CTA per (query tile, batch*head) pair.
    grid = (ceil(q_len / TILE_M), b * qh, 1)
    # NOTE(review): looks like the position of the first query token within
    # the kv sequence — 0 for prefill, the last slot for single-token
    # decode; confirm against fmha_kernel's contract.
    input_pos = 0 if q_len == k_len else (k_len - 1)
    # True when the kv length tiles evenly, letting the kernel skip
    # bounds handling on the last kv tile.
    EVEN_K = (k_len % TILE_N) == 0
    ct.launch(torch.cuda.current_stream(), grid, fmha_kernel,
              (q, k, v, o,
               qk_scale,
               input_pos,
               d, qh,
               TILE_M, TILE_N,
               query_group_size, is_causal, EVEN_K))
def torch_fmha(q, k, v, o, is_causal, enable_gqa):
backend = SDPBackend.CUDNN_ATTENTION \
if (q.shape[2] == k.shape[2]) \
else SDPBackend.FLASH_ATTENTION
with sdpa_kernel(backend):
ret = scaled_dot_product_attention(q, k, v,
is_causal=is_causal,
enable_gqa=enable_gqa)
o.copy_(ret)
def ref_fmha(q, k, v, o, is_causal, enable_gqa):
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
ret = scaled_dot_product_attention(q, k, v,
is_causal=is_causal,
enable_gqa=enable_gqa)
o.copy_(ret)