
Commit 24c090c

Use hl.dot instead of torch.matmul for FP8 GEMM ops in Helion kernel
stack-info: PR: #356, branch: yf225/stack/39
1 parent 6f83712 commit 24c090c

3 files changed: +83 -63 lines changed

examples/fp8_attention.py

Lines changed: 47 additions & 27 deletions
@@ -23,7 +23,7 @@ def fp8_attention_kernel(
 
     # Output tensor with 4D shape in FP8 format
     out = torch.empty(
-        [batch, heads, seq_len, head_dim], dtype=torch.float8_e5m2, device=q.device
+        [batch, heads, seq_len, head_dim], dtype=torch.float8_e4m3fn, device=q.device
     )
 
     # Scale factor for attention
@@ -54,9 +54,7 @@ def fp8_attention_kernel(
                 k_tile_t = k_tile.transpose(0, 1) # [dim, tile_n]
 
                 # Compute Q @ K^T with FP8 inputs, result in FP32
-                qk = torch.matmul(q_tile, k_tile_t).to(
-                    torch.float32
-                ) # [tile_m, tile_n]
+                qk = hl.dot(q_tile, k_tile_t) # [tile_m, tile_n]
 
                 # Scale QK scores first
                 qk_scaled = qk * sm_scale # [tile_m, tile_n]
@@ -90,28 +88,28 @@ def fp8_attention_kernel(
                 p_fp8 = p.to(v.dtype) # Convert to same FP8 type as V
 
                 # Accumulate attention @ V with FP8 GEMM
-                v_t = v_tile.transpose(0, 1) # [tile_n, dim]
-                pv = torch.matmul(p_fp8, v_t).to(torch.float32) # [tile_m, dim]
-                acc = acc + pv
+                # v_tile is [dim, tile_n], we need to transpose for P @ V^T
+                v_t = v_tile.t() # [tile_n, dim]
+                acc = hl.dot(p_fp8, v_t, acc=acc) # [tile_m, dim]
 
                 # Update max tracker
                 m_i = m_new
 
             # Final normalization
             acc = acc / l_i[:, None]
             # Convert to FP8 before writing to output
-            out[b, h, tile_m, :] = acc.to(torch.float8_e5m2)
+            out[b, h, tile_m, :] = acc.to(torch.float8_e4m3fn)
 
     return out
 
 
 def preprocess_fp8_attention_inputs(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    q_fp8 = q.to(torch.float8_e5m2)
-    k_fp8 = k.to(torch.float8_e5m2)
+    q_fp8 = q.to(torch.float8_e4m3fn)
+    k_fp8 = k.to(torch.float8_e4m3fn)
     v = v.permute(0, 1, 3, 2)
-    v_fp8 = v.to(torch.float8_e5m2)
+    v_fp8 = v.to(torch.float8_e4m3fn)
     batch, heads, seq_len, head_dim = q.shape
     q_fp8_reshaped = q_fp8.reshape(batch * heads, seq_len, head_dim)
     k_fp8_reshaped = k_fp8.reshape(batch * heads, seq_len, head_dim)
@@ -147,13 +145,25 @@ def _fp8_attention_pytorch_impl(
         k_i = k_fp8[i] # [seq, dim] - already FP8
         v_i = v_fp8[i] # [dim, seq] - pre-transposed, already FP8
 
-        # For Q @ K^T, we need K^T to be column-major
-        kt_fp8 = k_i.t() # column-major [dim, seq]
-
-        # Q @ K^T - dequantize and use regular matmul since e5m2 not supported by _scaled_mm
-        q_deq = q_i.to(torch.float32)
-        kt_deq = kt_fp8.to(torch.float32)
-        qk = torch.matmul(q_deq, kt_deq)
+        # For Q @ K^T using torch._scaled_mm
+        # torch._scaled_mm requires column-major for second operand
+        # k_i is [seq, dim], we need K^T as [dim, seq] in column-major
+        # Direct conversion: k_i -> contiguous -> transpose view
+        kt_fp8_col_major = k_i.contiguous().t() # [dim, seq] in column-major
+
+        # Create scale tensors
+        scale_q = torch.tensor(1.0, device=q_i.device)
+        scale_k = torch.tensor(1.0, device=k_i.device)
+
+        # Q @ K^T using torch._scaled_mm
+        qk = torch._scaled_mm(
+            q_i,
+            kt_fp8_col_major,
+            scale_q,
+            scale_k,
+            use_fast_accum=False,
+            out_dtype=torch.float32,
+        )
 
         # Compute max before scaling
         qk_max = torch.amax(qk, dim=-1, keepdim=True)
@@ -168,16 +178,26 @@ def _fp8_attention_pytorch_impl(
         # Step 2: Attention @ V using FP8
         # P is [seq, seq], V is [dim, seq]
         # We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim]
-        p_fp8 = p_norm.to(torch.float8_e5m2) # row-major [seq, seq]
+        p_fp8 = p_norm.to(torch.float8_e4m3fn) # row-major [seq, seq]
 
         # v_i is [dim, seq], already FP8
-        vt_fp8 = v_i.t() # column-major [seq, dim]
-
-        # P @ V^T - dequantize and use regular matmul since e5m2 not supported by torch._scaled_mm
-        p_deq = p_fp8.to(torch.float32)
-        vt_deq = vt_fp8.to(torch.float32)
-        out_i = torch.matmul(p_deq, vt_deq)
-        out_i = out_i.to(torch.float8_e5m2) # convert back to FP8
+        # Direct conversion: v_i -> contiguous -> transpose view
+        vt_fp8_col_major = v_i.contiguous().t() # [seq, dim] in column-major
+
+        # Create scale tensors for P @ V^T
+        scale_p = torch.tensor(1.0, device=p_fp8.device)
+        scale_v = torch.tensor(1.0, device=v_i.device)
+
+        # P @ V^T using torch._scaled_mm
+        out_i = torch._scaled_mm(
+            p_fp8,
+            vt_fp8_col_major,
+            scale_p,
+            scale_v,
+            use_fast_accum=False,
+            out_dtype=torch.float32,
+        )
+        out_i = out_i.to(torch.float8_e4m3fn) # convert back to FP8 to match kernel
 
         outputs.append(out_i)
 
@@ -192,7 +212,7 @@ def fp8_attention_pytorch(
     v: torch.Tensor, # [batch, heads, seq, dim]
 ) -> Callable[[], torch.Tensor]:
     """
-    Baseline PyTorch implementation of FP8 attention using FP8 e5m2.
+    Baseline PyTorch implementation of FP8 attention using torch._scaled_mm.
     """
     batch, heads, seq_len, head_dim = q.shape
     q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)
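
For reference, the torch._scaled_mm pattern the baseline now relies on can be exercised in isolation. The snippet below is a minimal sketch, not part of this commit, assuming a CUDA device with FP8 support; shapes and scales are illustrative. It shows the same ingredients as the diff: a row-major FP8 e4m3 first operand, a column-major second operand, explicit unit scales, and an FP32 output.

import torch

# Minimal sketch of the torch._scaled_mm call pattern used by the new baseline.
device = "cuda"
a = torch.randn(64, 64, device=device).to(torch.float8_e4m3fn)  # row-major [M, K]
b = torch.randn(64, 64, device=device).to(torch.float8_e4m3fn)  # [N, K]
b_col_major = b.contiguous().t()  # [K, N], column-major view

scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)

out = torch._scaled_mm(
    a,
    b_col_major,
    scale_a,
    scale_b,
    use_fast_accum=False,
    out_dtype=torch.float32,
)
print(out.shape)  # torch.Size([64, 64])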

examples/fp8_gemm.py

Lines changed: 11 additions & 6 deletions
@@ -1,13 +1,21 @@
 from __future__ import annotations
 
+import os
+
 import torch
 
 import helion
 from helion._testing import run_example
 import helion.language as hl
 
+# Override default config to work around Triton tl.dot requirement:
+# `AssertionError: Input shapes should have M >= 16, N >= 16 and K >= 32`
+config = None
+if os.environ.get("HELION_USE_DEFAULT_CONFIG") == "1":
+    config = helion.Config(block_sizes=[32, 32, 32])
+
 
-@helion.kernel(static_shapes=True)
+@helion.kernel(static_shapes=True, config=config)
 def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     """FP8 General Matrix Multiplication (GEMM).
 
@@ -37,11 +45,8 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             x_tile = x[tile_m, tile_k]
             y_tile = y[tile_k, tile_n]
 
-            # Use torch.matmul which will be lowered to tl.dot
-            # When the inputs are FP8, tl.dot handles them natively
-            # The result needs to be converted to FP32 for accumulation
-            result = torch.matmul(x_tile, y_tile).to(torch.float32)
-            acc = acc + result
+            # Use hl.dot for FP8 GEMM
+            acc = hl.dot(x_tile, y_tile, acc=acc)
         out[tile_m, tile_n] = acc.to(torch.float16)
 
     return out
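
A hypothetical usage sketch of the example above, not part of this commit; the import path and tolerances are assumptions. It converts FP32 operands to FP8 e4m3 and lets hl.dot accumulate into the FP32 acc tile, producing an FP16 result.

import torch

from examples.fp8_gemm import fp8_gemm  # assumed import path

# Shapes match the 256x256 case in the expected Triton output below.
x = torch.randn(256, 256, device="cuda").to(torch.float8_e4m3fn)
y = torch.randn(256, 256, device="cuda").to(torch.float8_e4m3fn)

out = fp8_gemm(x, y)  # FP16 output, accumulated in FP32 via hl.dot(..., acc=acc)

# Loose comparison against a dequantized matmul; FP8 inputs limit precision.
ref = (x.to(torch.float32) @ y.to(torch.float32)).to(torch.float16)
torch.testing.assert_close(out, ref, rtol=0.1, atol=0.1)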

test/test_examples.expected

Lines changed: 25 additions & 30 deletions
@@ -608,41 +608,38 @@ def _fp8_attention_kernel_kernel(q, k, v, out, out_stride_0, heads, _RDIM_SIZE_2
         acc_copy_0 = acc_copy
         k_tile = tl.load(k + (offset_0 * 16384 + indices_2[:, None] * 64 + indices_5[None, :] * 1), None)
         k_tile_t = tl.permute(k_tile, [1, 0])
-        mm = tl.dot(q_tile_copy_0, k_tile_t, input_precision='tf32')
-        v_0 = mm.to(tl.float32)
-        v_1 = 0.18033688
-        v_2 = v_0 * v_1
-        qk_max = tl.max(v_2, 1)
-        v_3 = triton_helpers.maximum(m_i_copy_0, qk_max)
-        subscript = v_3[:, None]
-        v_4 = v_2 - subscript
-        v_5 = libdevice.exp2(v_4)
-        l_ij = tl.sum(v_5, 1)
-        v_6 = m_i_copy_0 - v_3
-        v_7 = libdevice.exp2(v_6)
-        v_8 = l_i_copy_0 * v_7
-        l_i = v_8 + l_ij
-        subscript_1 = v_7[:, None]
-        v_10 = acc_copy_0 * subscript_1
+        qk = tl.dot(q_tile_copy_0, k_tile_t, acc=None, input_precision='tf32', out_dtype=tl.float32)
+        v_0 = 0.18033688
+        v_1 = qk * v_0
+        qk_max = tl.max(v_1, 1)
+        v_2 = triton_helpers.maximum(m_i_copy_0, qk_max)
+        subscript = v_2[:, None]
+        v_3 = v_1 - subscript
+        v_4 = libdevice.exp2(v_3)
+        l_ij = tl.sum(v_4, 1)
+        v_5 = m_i_copy_0 - v_2
+        v_6 = libdevice.exp2(v_5)
+        v_7 = l_i_copy_0 * v_6
+        l_i = v_7 + l_ij
+        subscript_1 = v_6[:, None]
+        v_9 = acc_copy_0 * subscript_1
         v_tile = tl.load(v + (offset_0 * 16384 + indices_5[:, None] * 1 + indices_2[None, :] * 64), None)
-        v_11 = v_5.to(tl.float8e5)
+        v_10 = v_4.to(tl.float8e4nv)
         v_t = tl.permute(v_tile, [1, 0])
-        mm_1 = tl.dot(v_11, v_t, input_precision='tf32')
-        v_12 = mm_1.to(tl.float32)
-        acc = v_10 + v_12
-        m_i = v_3
+        acc = tl.dot(v_10, v_t, acc=v_9, input_precision='tf32', out_dtype=tl.float32)
+        m_i = v_2
     subscript_2 = l_i[:, None]
-    v_14 = acc / subscript_2
-    v_15 = v_14.to(tl.float8e5)
+    v_11 = acc / subscript_2
+    v_12 = v_11.to(tl.float8e4nv)
     symnode_0 = triton_helpers.div_floor_integer(offset_0, heads)
     symnode_1 = triton_helpers.remainder_integer(offset_0, heads)
-    tl.store(out + (symnode_0 * out_stride_0 + symnode_1 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), v_15, None)
+    tl.store(out + (symnode_0 * out_stride_0 + symnode_1 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), v_12, None)
 
 def fp8_attention_kernel(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, batch: int, heads: int, *, _launcher=_default_launcher):
     batch_heads = q.size(0)
     seq_len = q.size(1)
     head_dim = q.size(2)
-    out = torch.empty([batch, heads, seq_len, head_dim], dtype=torch.float8_e5m2, device=q.device)
+    out = torch.empty([batch, heads, seq_len, head_dim], dtype=torch.float8_e4m3fn, device=q.device)
     sm_scale = 1.0 / math.sqrt(float(head_dim))
     sm_scale = sm_scale * 1.44269504
     _RDIM_SIZE_2 = 64
@@ -675,11 +672,9 @@ def _fp8_gemm_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.c
         acc_copy_0 = acc_copy
         x_tile = tl.load(x + (indices_0[:, None] * 256 + indices_2[None, :] * 1), None)
         y_tile = tl.load(y + (indices_2[:, None] * 256 + indices_1[None, :] * 1), None)
-        mm = tl.dot(x_tile, y_tile, input_precision='tf32')
-        v_0 = mm.to(tl.float32)
-        acc = acc_copy_0 + v_0
-    v_2 = acc.to(tl.float16)
-    tl.store(out + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_2, None)
+        acc = tl.dot(x_tile, y_tile, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32)
+    v_0 = acc.to(tl.float16)
+    tl.store(out + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_0, None)
 
 def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     """FP8 General Matrix Multiplication (GEMM).
