
Commit 13367e2

Use torch._scaled_mm instead of torch.matmul for FP8 GEMM ops in Helion kernel
stack-info: PR: #356, branch: yf225/stack/39
1 parent 41fe6e9 commit 13367e2
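
As context for the diffs below, the torch._scaled_mm calling convention this commit standardizes on looks roughly like the following. This is a minimal sketch and not part of the commit: it assumes a CUDA device with FP8 support, the signature used in the diff (positional scale tensors), and illustrative 64x64 shapes.

    import torch

    # Illustrative shapes; multiples of 16 are a safe choice for _scaled_mm.
    m, k, n = 64, 64, 64
    a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(k, n, device="cuda").to(torch.float8_e4m3fn)

    # torch._scaled_mm requires its second operand to be column-major;
    # transposing a contiguous tensor twice keeps the values, flips the layout.
    b_col_major = b.T.contiguous().T

    scale_a = torch.tensor(1.0, device="cuda")  # per-tensor scales
    scale_b = torch.tensor(1.0, device="cuda")

    out = torch._scaled_mm(
        a,
        b_col_major,
        scale_a,
        scale_b,
        use_fast_accum=False,
        out_dtype=torch.float32,  # accumulate and return in FP32
    )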

4 files changed: +367 -156 lines changed

examples/fp8_attention.py

Lines changed: 65 additions & 25 deletions
@@ -23,7 +23,7 @@ def fp8_attention_kernel(
 
     # Output tensor with 4D shape in FP8 format
     out = torch.empty(
-        [batch, heads, seq_len, head_dim], dtype=torch.float8_e5m2, device=q.device
+        [batch, heads, seq_len, head_dim], dtype=torch.float8_e4m3fn, device=q.device
     )
 
     # Scale factor for attention
@@ -54,8 +54,15 @@ def fp8_attention_kernel(
             k_tile_t = k_tile.transpose(0, 1)  # [dim, tile_n]
 
             # Compute Q @ K^T with FP8 inputs, result in FP32
-            qk = torch.matmul(q_tile, k_tile_t).to(
-                torch.float32
+            scale_a = hl.full([], 1.0, dtype=torch.float32)
+            scale_b = hl.full([], 1.0, dtype=torch.float32)
+            qk = torch._scaled_mm(
+                q_tile,
+                k_tile_t,
+                scale_a,
+                scale_b,
+                use_fast_accum=False,
+                out_dtype=torch.float32,
             )  # [tile_m, tile_n]
 
             # Scale QK scores first
@@ -90,8 +97,19 @@ def fp8_attention_kernel(
             p_fp8 = p.to(v.dtype)  # Convert to same FP8 type as V
 
             # Accumulate attention @ V with FP8 GEMM
-            v_t = v_tile.transpose(0, 1)  # [tile_n, dim]
-            pv = torch.matmul(p_fp8, v_t).to(torch.float32)  # [tile_m, dim]
+            # torch._scaled_mm requires second matrix to be column-major
+            # v_tile is [dim, tile_n], we need [tile_n, dim] in column-major
+            v_t = v_tile.contiguous().t()  # [tile_n, dim] in column-major format
+            scale_p = hl.full([], 1.0, dtype=torch.float32)
+            scale_v = hl.full([], 1.0, dtype=torch.float32)
+            pv = torch._scaled_mm(
+                p_fp8,
+                v_t,
+                scale_p,
+                scale_v,
+                use_fast_accum=False,
+                out_dtype=torch.float32,
+            )  # [tile_m, dim]
             acc = acc + pv
 
             # Update max tracker
@@ -100,18 +118,18 @@ def fp8_attention_kernel(
         # Final normalization
         acc = acc / l_i[:, None]
         # Convert to FP8 before writing to output
-        out[b, h, tile_m, :] = acc.to(torch.float8_e5m2)
+        out[b, h, tile_m, :] = acc.to(torch.float8_e4m3fn)
 
     return out
 
 
 def preprocess_fp8_attention_inputs(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    q_fp8 = q.to(torch.float8_e5m2)
-    k_fp8 = k.to(torch.float8_e5m2)
+    q_fp8 = q.to(torch.float8_e4m3fn)
+    k_fp8 = k.to(torch.float8_e4m3fn)
     v = v.permute(0, 1, 3, 2)
-    v_fp8 = v.to(torch.float8_e5m2)
+    v_fp8 = v.to(torch.float8_e4m3fn)
     batch, heads, seq_len, head_dim = q.shape
     q_fp8_reshaped = q_fp8.reshape(batch * heads, seq_len, head_dim)
     k_fp8_reshaped = k_fp8.reshape(batch * heads, seq_len, head_dim)
@@ -147,13 +165,25 @@ def _fp8_attention_pytorch_impl(
         k_i = k_fp8[i]  # [seq, dim] - already FP8
         v_i = v_fp8[i]  # [dim, seq] - pre-transposed, already FP8
 
-        # For Q @ K^T, we need K^T to be column-major
-        kt_fp8 = k_i.t()  # column-major [dim, seq]
-
-        # Q @ K^T - dequantize and use regular matmul since e5m2 not supported by _scaled_mm
-        q_deq = q_i.to(torch.float32)
-        kt_deq = kt_fp8.to(torch.float32)
-        qk = torch.matmul(q_deq, kt_deq)
+        # For Q @ K^T using torch._scaled_mm
+        # torch._scaled_mm requires column-major for second operand
+        # k_i is [seq, dim], we need K^T as [dim, seq] in column-major
+        # Direct conversion: k_i -> contiguous -> transpose view
+        kt_fp8_col_major = k_i.contiguous().t()  # [dim, seq] in column-major
+
+        # Create scale tensors
+        scale_q = torch.tensor(1.0, device=q_i.device)
+        scale_k = torch.tensor(1.0, device=k_i.device)
+
+        # Q @ K^T using torch._scaled_mm
+        qk = torch._scaled_mm(
+            q_i,
+            kt_fp8_col_major,
+            scale_q,
+            scale_k,
+            use_fast_accum=False,
+            out_dtype=torch.float32,
+        )
 
         # Compute max before scaling
         qk_max = torch.amax(qk, dim=-1, keepdim=True)
@@ -168,16 +198,26 @@ def _fp8_attention_pytorch_impl(
         # Step 2: Attention @ V using FP8
         # P is [seq, seq], V is [dim, seq]
         # We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim]
-        p_fp8 = p_norm.to(torch.float8_e5m2)  # row-major [seq, seq]
+        p_fp8 = p_norm.to(torch.float8_e4m3fn)  # row-major [seq, seq]
 
         # v_i is [dim, seq], already FP8
-        vt_fp8 = v_i.t()  # column-major [seq, dim]
-
-        # P @ V^T - dequantize and use regular matmul since e5m2 not supported by torch._scaled_mm
-        p_deq = p_fp8.to(torch.float32)
-        vt_deq = vt_fp8.to(torch.float32)
-        out_i = torch.matmul(p_deq, vt_deq)
-        out_i = out_i.to(torch.float8_e5m2)  # convert back to FP8
+        # Direct conversion: v_i -> contiguous -> transpose view
+        vt_fp8_col_major = v_i.contiguous().t()  # [seq, dim] in column-major
+
+        # Create scale tensors for P @ V^T
+        scale_p = torch.tensor(1.0, device=p_fp8.device)
+        scale_v = torch.tensor(1.0, device=v_i.device)
+
+        # P @ V^T using torch._scaled_mm
+        out_i = torch._scaled_mm(
+            p_fp8,
+            vt_fp8_col_major,
+            scale_p,
+            scale_v,
+            use_fast_accum=False,
+            out_dtype=torch.float32,
+        )
+        out_i = out_i.to(torch.float8_e4m3fn)  # convert back to FP8 to match kernel
 
         outputs.append(out_i)
 
@@ -192,7 +232,7 @@ def fp8_attention_pytorch(
     v: torch.Tensor,  # [batch, heads, seq, dim]
 ) -> Callable[[], torch.Tensor]:
     """
-    Baseline PyTorch implementation of FP8 attention using FP8 e5m2.
+    Baseline PyTorch implementation of FP8 attention using torch._scaled_mm.
     """
     batch, heads, seq_len, head_dim = q.shape
     q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)
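
The dtype switch from float8_e5m2 to float8_e4m3fn across this file tracks the removed baseline comments ("e5m2 not supported by _scaled_mm"): with e4m3fn operands, the dequantize-then-matmul fallback is no longer needed. A hypothetical self-check of that equivalence, not part of the commit, assuming a supported GPU, unit scales, and illustrative 64x64 shapes:

    import torch

    seq, dim = 64, 64
    q = torch.randn(seq, dim, device="cuda").to(torch.float8_e4m3fn)
    k = torch.randn(seq, dim, device="cuda").to(torch.float8_e4m3fn)
    kt = k.contiguous().t()  # [dim, seq], column-major as _scaled_mm requires

    one = torch.tensor(1.0, device="cuda")
    qk_scaled = torch._scaled_mm(
        q, kt, one, one, use_fast_accum=False, out_dtype=torch.float32
    )
    # Old-style baseline: dequantize to FP32, then a regular matmul.
    qk_ref = torch.matmul(q.to(torch.float32), kt.to(torch.float32))
    # With unit scales the two paths should agree up to accumulation-order noise.
    torch.testing.assert_close(qk_scaled, qk_ref, rtol=1e-2, atol=1e-2)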

examples/fp8_gemm.py

Lines changed: 35 additions & 9 deletions
@@ -1,13 +1,21 @@
 from __future__ import annotations
 
+import os
+
 import torch
 
 import helion
 from helion._testing import run_example
 import helion.language as hl
 
+# Override default config to work around Triton tl.dot requirement:
+# `AssertionError: Input shapes should have M >= 16, N >= 16 and K >= 32`
+config = None
+if os.environ.get("HELION_USE_DEFAULT_CONFIG") == "1":
+    config = helion.Config(block_sizes=[32, 32, 32])
+
 
-@helion.kernel(static_shapes=True)
+@helion.kernel(static_shapes=True, config=config)
 def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     """FP8 General Matrix Multiplication (GEMM).
 
@@ -37,11 +45,24 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             x_tile = x[tile_m, tile_k]
             y_tile = y[tile_k, tile_n]
 
-            # Use torch.matmul which will be lowered to tl.dot
-            # When the inputs are FP8, tl.dot handles them natively
-            # The result needs to be converted to FP32 for accumulation
-            result = torch.matmul(x_tile, y_tile).to(torch.float32)
-            acc = acc + result
+            # torch._scaled_mm(A, B) requires B to be column-major
+            # We make y_tile column-major by transposing twice
+            y_tile_col_major = y_tile.transpose(0, 1).contiguous().transpose(0, 1)
+
+            # Create scale tensors
+            scale_a = hl.full([], 1.0, dtype=torch.float32)
+            scale_b = hl.full([], 1.0, dtype=torch.float32)
+
+            # Use torch._scaled_mm for FP8 GEMM, then accumulate result in FP32
+            mm_out = torch._scaled_mm(
+                x_tile,
+                y_tile_col_major,
+                scale_a,
+                scale_b,
+                use_fast_accum=False,
+                out_dtype=torch.float32,
+            )
+            acc = acc + mm_out
         out[tile_m, tile_n] = acc.to(torch.float16)
 
     return out
@@ -52,12 +73,17 @@ def reference_fp8_gemm_pytorch(
 ) -> torch.Tensor:
     """Reference implementation using torch._scaled_mm."""
     # torch._scaled_mm requires column-major for second operand
-    y_fp8_t = y_fp8.T.contiguous().T
+    y_fp8_col_major = y_fp8.T.contiguous().T
     scale_a = torch.tensor(1.0, device=x_fp8.device)
     scale_b = torch.tensor(1.0, device=x_fp8.device)
     return torch._scaled_mm(
-        x_fp8, y_fp8_t, scale_a, scale_b, use_fast_accum=False, out_dtype=torch.float16
-    )
+        x_fp8,
+        y_fp8_col_major,
+        scale_a,
+        scale_b,
+        use_fast_accum=False,
+        out_dtype=torch.float32,
+    ).to(torch.float16)
 
 
 def fp8_gemm_tritonbench(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
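
A design note on the `.T.contiguous().T` / `.contiguous().t()` pattern that appears in both files: it does not change any values, only the memory layout, which is what torch._scaled_mm checks for its second operand. A small sketch, with assumed device and shapes, showing the stride flip:

    import torch

    y = torch.randn(64, 64, device="cuda").to(torch.float8_e4m3fn)
    print(y.stride())            # (64, 1): row-major
    y_col = y.T.contiguous().T   # same logical [64, 64] tensor, column-major storage
    print(y_col.stride())        # (1, 64): column-major, accepted as _scaled_mm's mat2
    # Values are unchanged; only the storage order differs.
    assert torch.equal(y_col.to(torch.float16), y.to(torch.float16))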
