@@ -23,7 +23,7 @@ def fp8_attention_kernel(
 
     # Output tensor with 4D shape in FP8 format
     out = torch.empty(
-        [batch, heads, seq_len, head_dim], dtype=torch.float8_e5m2, device=q.device
+        [batch, heads, seq_len, head_dim], dtype=torch.float8_e4m3fn, device=q.device
     )
 
     # Scale factor for attention
@@ -54,8 +54,15 @@ def fp8_attention_kernel(
                 k_tile_t = k_tile.transpose(0, 1)  # [dim, tile_n]
 
                 # Compute Q @ K^T with FP8 inputs, result in FP32
-                qk = torch.matmul(q_tile, k_tile_t).to(
-                    torch.float32
+                scale_a = hl.full([], 1.0, dtype=torch.float32)
+                scale_b = hl.full([], 1.0, dtype=torch.float32)
+                qk = torch._scaled_mm(
+                    q_tile,
+                    k_tile_t,
+                    scale_a,
+                    scale_b,
+                    use_fast_accum=False,
+                    out_dtype=torch.float32,
                 )  # [tile_m, tile_n]
 
                 # Scale QK scores first
@@ -90,8 +97,19 @@ def fp8_attention_kernel(
                 p_fp8 = p.to(v.dtype)  # Convert to same FP8 type as V
 
                 # Accumulate attention @ V with FP8 GEMM
-                v_t = v_tile.transpose(0, 1)  # [tile_n, dim]
-                pv = torch.matmul(p_fp8, v_t).to(torch.float32)  # [tile_m, dim]
+                # torch._scaled_mm requires second matrix to be column-major
+                # v_tile is [dim, tile_n], we need [tile_n, dim] in column-major
+                v_t = v_tile.contiguous().t()  # [tile_n, dim] in column-major format
+                scale_p = hl.full([], 1.0, dtype=torch.float32)
+                scale_v = hl.full([], 1.0, dtype=torch.float32)
+                pv = torch._scaled_mm(
+                    p_fp8,
+                    v_t,
+                    scale_p,
+                    scale_v,
+                    use_fast_accum=False,
+                    out_dtype=torch.float32,
+                )  # [tile_m, dim]
                 acc = acc + pv
 
                 # Update max tracker
@@ -100,18 +118,18 @@ def fp8_attention_kernel(
             # Final normalization
             acc = acc / l_i[:, None]
             # Convert to FP8 before writing to output
-            out[b, h, tile_m, :] = acc.to(torch.float8_e5m2)
+            out[b, h, tile_m, :] = acc.to(torch.float8_e4m3fn)
 
     return out
 
 
 def preprocess_fp8_attention_inputs(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    q_fp8 = q.to(torch.float8_e5m2)
-    k_fp8 = k.to(torch.float8_e5m2)
+    q_fp8 = q.to(torch.float8_e4m3fn)
+    k_fp8 = k.to(torch.float8_e4m3fn)
     v = v.permute(0, 1, 3, 2)
-    v_fp8 = v.to(torch.float8_e5m2)
+    v_fp8 = v.to(torch.float8_e4m3fn)
     batch, heads, seq_len, head_dim = q.shape
     q_fp8_reshaped = q_fp8.reshape(batch * heads, seq_len, head_dim)
     k_fp8_reshaped = k_fp8.reshape(batch * heads, seq_len, head_dim)
@@ -147,13 +165,25 @@ def _fp8_attention_pytorch_impl(
         k_i = k_fp8[i]  # [seq, dim] - already FP8
         v_i = v_fp8[i]  # [dim, seq] - pre-transposed, already FP8
 
-        # For Q @ K^T, we need K^T to be column-major
-        kt_fp8 = k_i.t()  # column-major [dim, seq]
-
-        # Q @ K^T - dequantize and use regular matmul since e5m2 not supported by _scaled_mm
-        q_deq = q_i.to(torch.float32)
-        kt_deq = kt_fp8.to(torch.float32)
-        qk = torch.matmul(q_deq, kt_deq)
+        # For Q @ K^T using torch._scaled_mm
+        # torch._scaled_mm requires column-major for second operand
+        # k_i is [seq, dim], we need K^T as [dim, seq] in column-major
+        # Direct conversion: k_i -> contiguous -> transpose view
+        kt_fp8_col_major = k_i.contiguous().t()  # [dim, seq] in column-major
+
+        # Create scale tensors
+        scale_q = torch.tensor(1.0, device=q_i.device)
+        scale_k = torch.tensor(1.0, device=k_i.device)
+
+        # Q @ K^T using torch._scaled_mm
+        qk = torch._scaled_mm(
+            q_i,
+            kt_fp8_col_major,
+            scale_q,
+            scale_k,
+            use_fast_accum=False,
+            out_dtype=torch.float32,
+        )
 
         # Compute max before scaling
         qk_max = torch.amax(qk, dim=-1, keepdim=True)
@@ -168,16 +198,26 @@ def _fp8_attention_pytorch_impl(
         # Step 2: Attention @ V using FP8
         # P is [seq, seq], V is [dim, seq]
         # We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim]
-        p_fp8 = p_norm.to(torch.float8_e5m2)  # row-major [seq, seq]
+        p_fp8 = p_norm.to(torch.float8_e4m3fn)  # row-major [seq, seq]
 
         # v_i is [dim, seq], already FP8
-        vt_fp8 = v_i.t()  # column-major [seq, dim]
-
-        # P @ V^T - dequantize and use regular matmul since e5m2 not supported by torch._scaled_mm
-        p_deq = p_fp8.to(torch.float32)
-        vt_deq = vt_fp8.to(torch.float32)
-        out_i = torch.matmul(p_deq, vt_deq)
-        out_i = out_i.to(torch.float8_e5m2)  # convert back to FP8
+        # Direct conversion: v_i -> contiguous -> transpose view
+        vt_fp8_col_major = v_i.contiguous().t()  # [seq, dim] in column-major
+
+        # Create scale tensors for P @ V^T
+        scale_p = torch.tensor(1.0, device=p_fp8.device)
+        scale_v = torch.tensor(1.0, device=v_i.device)
+
+        # P @ V^T using torch._scaled_mm
+        out_i = torch._scaled_mm(
+            p_fp8,
+            vt_fp8_col_major,
+            scale_p,
+            scale_v,
+            use_fast_accum=False,
+            out_dtype=torch.float32,
+        )
+        out_i = out_i.to(torch.float8_e4m3fn)  # convert back to FP8 to match kernel
 
         outputs.append(out_i)
 
@@ -192,7 +232,7 @@ def fp8_attention_pytorch(
     v: torch.Tensor,  # [batch, heads, seq, dim]
 ) -> Callable[[], torch.Tensor]:
     """
-    Baseline PyTorch implementation of FP8 attention using FP8 e5m2.
+    Baseline PyTorch implementation of FP8 attention using torch._scaled_mm.
     """
     batch, heads, seq_len, head_dim = q.shape
     q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)
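
Below is a minimal standalone sketch (separate from the patch above) of the torch._scaled_mm calling convention the new code relies on. It assumes a CUDA device with FP8 support (SM89+/H100-class) and a PyTorch build where scale_a/scale_b are positional and a single tensor is returned; shapes and names are illustrative only.

import torch

# Illustrative example; assumes CUDA FP8 support and a positional-scale _scaled_mm API
device = "cuda"
a = torch.randn(64, 128, device=device).to(torch.float8_e4m3fn)  # row-major LHS [m, k]
b = torch.randn(32, 128, device=device).to(torch.float8_e4m3fn)  # row-major [n, k]
b_col_major = b.t()  # [k, n] view with stride (1, k): the column-major layout required for the RHS

# Unit scales make the FP8 GEMM equivalent to dequantize-then-matmul
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)

out = torch._scaled_mm(
    a,
    b_col_major,
    scale_a,
    scale_b,
    use_fast_accum=False,
    out_dtype=torch.float32,
)  # [m, n] = [64, 32] in FP32

# Cross-check against the dequantized FP32 matmul the old code used
ref = a.to(torch.float32) @ b_col_major.to(torch.float32)
print((out - ref).abs().max())  # expected to be ~0 up to accumulation-order effects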
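
The trickiest layout detail in the baseline is V: preprocess_fp8_attention_inputs permutes it to [batch, heads, dim, seq], so each per-head slice v_i is [dim, seq], and v_i.contiguous().t() then yields the [seq, dim] column-major operand that torch._scaled_mm expects. A small stride check (again separate from the patch, with illustrative shapes) makes this concrete:

import torch

# Illustrative shapes; mirrors the [dim, seq] pre-transposed V layout described above
batch, heads, seq, dim = 2, 4, 128, 64
v = torch.randn(batch, heads, seq, dim, device="cuda").to(torch.float8_e4m3fn)

v_perm = v.permute(0, 1, 3, 2)                    # [batch, heads, dim, seq]
v_i = v_perm.reshape(batch * heads, dim, seq)[0]  # [dim, seq] slice for one (batch, head)

vt = v_i.contiguous().t()                         # [seq, dim]
print(vt.shape, vt.stride())                      # torch.Size([128, 64]) (1, 128)
# stride (1, rows) == column-major, which is what torch._scaled_mm requires for its second operand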