[Ref Mode] Make Tile apis work in ref eager mode #378

Closed · wants to merge 3 commits
74 changes: 47 additions & 27 deletions examples/fp8_attention.py
@@ -23,7 +23,7 @@ def fp8_attention_kernel(

# Output tensor with 4D shape in FP8 format
out = torch.empty(
[batch, heads, seq_len, head_dim], dtype=torch.float8_e5m2, device=q.device
[batch, heads, seq_len, head_dim], dtype=torch.float8_e4m3fn, device=q.device
)

# Scale factor for attention
@@ -54,9 +54,7 @@ def fp8_attention_kernel(
k_tile_t = k_tile.transpose(0, 1) # [dim, tile_n]

# Compute Q @ K^T with FP8 inputs, result in FP32
qk = torch.matmul(q_tile, k_tile_t).to(
torch.float32
) # [tile_m, tile_n]
qk = hl.dot(q_tile, k_tile_t) # [tile_m, tile_n]

# Scale QK scores first
qk_scaled = qk * sm_scale # [tile_m, tile_n]
@@ -90,28 +88,28 @@ def fp8_attention_kernel(
p_fp8 = p.to(v.dtype) # Convert to same FP8 type as V

# Accumulate attention @ V with FP8 GEMM
v_t = v_tile.transpose(0, 1) # [tile_n, dim]
pv = torch.matmul(p_fp8, v_t).to(torch.float32) # [tile_m, dim]
acc = acc + pv
# v_tile is [dim, tile_n], we need to transpose for P @ V^T
v_t = v_tile.t() # [tile_n, dim]
acc = hl.dot(p_fp8, v_t, acc=acc) # [tile_m, dim]

# Update max tracker
m_i = m_new

# Final normalization
acc = acc / l_i[:, None]
# Convert to FP8 before writing to output
out[b, h, tile_m, :] = acc.to(torch.float8_e5m2)
out[b, h, tile_m, :] = acc.to(torch.float8_e4m3fn)

return out


def preprocess_fp8_attention_inputs(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q_fp8 = q.to(torch.float8_e5m2)
k_fp8 = k.to(torch.float8_e5m2)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v = v.permute(0, 1, 3, 2)
v_fp8 = v.to(torch.float8_e5m2)
v_fp8 = v.to(torch.float8_e4m3fn)
batch, heads, seq_len, head_dim = q.shape
q_fp8_reshaped = q_fp8.reshape(batch * heads, seq_len, head_dim)
k_fp8_reshaped = k_fp8.reshape(batch * heads, seq_len, head_dim)
@@ -147,13 +145,25 @@ def _fp8_attention_pytorch_impl(
k_i = k_fp8[i] # [seq, dim] - already FP8
v_i = v_fp8[i] # [dim, seq] - pre-transposed, already FP8

# For Q @ K^T, we need K^T to be column-major
kt_fp8 = k_i.t() # column-major [dim, seq]

# Q @ K^T - dequantize and use regular matmul since e5m2 not supported by _scaled_mm
q_deq = q_i.to(torch.float32)
kt_deq = kt_fp8.to(torch.float32)
qk = torch.matmul(q_deq, kt_deq)
# For Q @ K^T using torch._scaled_mm
# torch._scaled_mm requires column-major for second operand
# k_i is [seq, dim], we need K^T as [dim, seq] in column-major
# Direct conversion: k_i -> contiguous -> transpose view
kt_fp8_col_major = k_i.contiguous().t() # [dim, seq] in column-major

# Create scale tensors
scale_q = torch.tensor(1.0, device=q_i.device)
scale_k = torch.tensor(1.0, device=k_i.device)

# Q @ K^T using torch._scaled_mm
qk = torch._scaled_mm(
q_i,
kt_fp8_col_major,
scale_q,
scale_k,
use_fast_accum=False,
out_dtype=torch.float32,
)

# Compute max before scaling
qk_max = torch.amax(qk, dim=-1, keepdim=True)
@@ -168,16 +178,26 @@ def _fp8_attention_pytorch_impl(
# Step 2: Attention @ V using FP8
# P is [seq, seq], V is [dim, seq]
# We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim]
p_fp8 = p_norm.to(torch.float8_e5m2) # row-major [seq, seq]
p_fp8 = p_norm.to(torch.float8_e4m3fn) # row-major [seq, seq]

# v_i is [dim, seq], already FP8
vt_fp8 = v_i.t() # column-major [seq, dim]

# P @ V^T - dequantize and use regular matmul since e5m2 not supported by torch._scaled_mm
p_deq = p_fp8.to(torch.float32)
vt_deq = vt_fp8.to(torch.float32)
out_i = torch.matmul(p_deq, vt_deq)
out_i = out_i.to(torch.float8_e5m2) # convert back to FP8
# Direct conversion: v_i -> contiguous -> transpose view
vt_fp8_col_major = v_i.contiguous().t() # [seq, dim] in column-major

# Create scale tensors for P @ V^T
scale_p = torch.tensor(1.0, device=p_fp8.device)
scale_v = torch.tensor(1.0, device=v_i.device)

# P @ V^T using torch._scaled_mm
out_i = torch._scaled_mm(
p_fp8,
vt_fp8_col_major,
scale_p,
scale_v,
use_fast_accum=False,
out_dtype=torch.float32,
)
out_i = out_i.to(torch.float8_e4m3fn) # convert back to FP8 to match kernel

outputs.append(out_i)

@@ -192,7 +212,7 @@ def fp8_attention_pytorch(
v: torch.Tensor, # [batch, heads, seq, dim]
) -> Callable[[], torch.Tensor]:
"""
Baseline PyTorch implementation of FP8 attention using FP8 e5m2.
Baseline PyTorch implementation of FP8 attention using torch._scaled_mm.
"""
batch, heads, seq_len, head_dim = q.shape
q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v)
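Both GEMMs in the `_fp8_attention_pytorch_impl` baseline above now go through `torch._scaled_mm` with `float8_e4m3fn` operands, unit scales, and a column-major second operand. A minimal standalone sketch of that call pattern; the shapes and unit scales here are illustrative, not taken from the example:

```python
import torch

# Minimal sketch of the torch._scaled_mm call pattern used above; shapes and
# unit scales are illustrative. The second operand must be column-major, which
# the .t() view of a contiguous tensor provides.
M, K, N = 64, 64, 64
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)      # row-major [M, K]
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # column-major [K, N]

scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")

out = torch._scaled_mm(
    a,
    b,
    scale_a,
    scale_b,
    use_fast_accum=False,
    out_dtype=torch.float32,
)  # [M, N] in FP32
```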
17 changes: 11 additions & 6 deletions examples/fp8_gemm.py
@@ -1,13 +1,21 @@
from __future__ import annotations

import os

import torch

import helion
from helion._testing import run_example
import helion.language as hl

# Override default config to work around Triton tl.dot requirement:
# `AssertionError: Input shapes should have M >= 16, N >= 16 and K >= 32`
config = None
if os.environ.get("HELION_USE_DEFAULT_CONFIG") == "1":
config = helion.Config(block_sizes=[32, 32, 32])


@helion.kernel(static_shapes=True)
@helion.kernel(static_shapes=True, config=config)
def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""FP8 General Matrix Multiplication (GEMM).

@@ -37,11 +45,8 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
x_tile = x[tile_m, tile_k]
y_tile = y[tile_k, tile_n]

# Use torch.matmul which will be lowered to tl.dot
# When the inputs are FP8, tl.dot handles them natively
# The result needs to be converted to FP32 for accumulation
result = torch.matmul(x_tile, y_tile).to(torch.float32)
acc = acc + result
# Use hl.dot for FP8 GEMM
acc = hl.dot(x_tile, y_tile, acc=acc)
out[tile_m, tile_n] = acc.to(torch.float16)

return out
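The kernel body now accumulates through `hl.dot(x_tile, y_tile, acc=acc)` instead of a `torch.matmul(...).to(torch.float32)` round trip. A hedged sketch of invoking the example; the 256x256 shapes and the use of `float8_e4m3fn` for both inputs are assumptions, not taken from this diff:

```python
import torch

# Illustrative invocation of the fp8_gemm example above; the shapes and e4m3fn
# inputs are assumptions, chosen to satisfy the tl.dot size minimums mentioned
# in the config workaround at the top of the file.
x = torch.randn(256, 256, device="cuda").to(torch.float8_e4m3fn)
y = torch.randn(256, 256, device="cuda").to(torch.float8_e4m3fn)
out = fp8_gemm(x, y)  # accumulated in FP32 via hl.dot, stored as float16
```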
2 changes: 2 additions & 0 deletions helion/__init__.py
@@ -11,12 +11,14 @@
from .runtime import Kernel
from .runtime import kernel
from .runtime import kernel as jit # alias
from .runtime.settings import RefMode
from .runtime.settings import Settings
from .runtime.settings import set_default_settings

__all__ = [
"Config",
"Kernel",
"RefMode",
"Settings",
"cdiv",
"exc",
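`RefMode` is now re-exported from the package root. A hedged sketch of how calling code might reference it; the `settings.ref_mode` attribute is assumed from the check added in `helion/_testing.py` below, and the helper itself is hypothetical:

```python
import helion
from helion import RefMode

# Hypothetical helper mirroring the check added in helion/_testing.py below;
# it assumes a kernel's Settings object exposes a `ref_mode` attribute.
def runs_in_ref_mode(kernel: helion.Kernel) -> bool:
    return kernel.settings.ref_mode != RefMode.OFF
```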
9 changes: 8 additions & 1 deletion helion/_compiler/indexing_strategy.py
@@ -70,7 +70,14 @@ def codegen_load(
extra_mask: ast.AST | None,
) -> ast.AST:
indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask)
extra = ", other=0" if indexing.has_mask() else ""
extra = ""
if indexing.has_mask():
# For FP8 dtypes, use other=0.0 (float literal) instead of other=0 (int literal)
# because Triton cannot cast integer 0 to FP8 types
if fake_tensor.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
extra = ", other=0.0"
else:
extra = ", other=0"
name = state.device_function.tensor_arg(fake_tensor).name
return expr_from_string(
f"tl.load({name} + offset, mask{extra})",
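The only codegen difference here is the `other=` literal appended to masked loads. A small standalone sketch of the dtype branch added above; `masked_load_suffix` is a hypothetical helper, not part of the codegen:

```python
import torch

# Standalone sketch of the dtype branch added above; masked_load_suffix is a
# hypothetical helper that returns the `extra` suffix appended to tl.load(...).
def masked_load_suffix(dtype: torch.dtype) -> str:
    if dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return ", other=0.0"  # float literal: Triton cannot cast integer 0 to FP8
    return ", other=0"

assert masked_load_suffix(torch.float8_e5m2) == ", other=0.0"
assert masked_load_suffix(torch.float32) == ", other=0"
```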
14 changes: 9 additions & 5 deletions helion/_compiler/inductor_lowering.py
@@ -848,15 +848,19 @@ def reduce_3d_dot(
rhs_node = node.args[1]
assert isinstance(lhs, ast.AST)
assert isinstance(rhs, ast.AST)
assert isinstance(lhs_node, torch.fx.Node)
assert isinstance(rhs_node, torch.fx.Node)

# Check if inputs are FP8 - if so, don't specify input_precision to allow native FP8 computation
lhs_dtype = lhs_node.meta["val"].dtype # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
rhs_dtype = rhs_node.meta["val"].dtype # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
# Check if inputs are FP8 - if so, redirect user to hl.dot()
lhs_dtype = lhs_node.meta["val"].dtype
rhs_dtype = rhs_node.meta["val"].dtype
if lhs_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and rhs_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
datatype = None # Let Triton use native FP8 computation
raise NotImplementedError(
"FP8 GEMM via torch API is not supported yet. Please use hl.dot() instead."
)

lhs_size = lhs_node.meta["val"].size() # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
rhs_size = rhs_node.meta["val"].size() # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
Expand Down Expand Up @@ -1138,7 +1142,7 @@ def proxy_arg(self, i: int) -> object:

def ast_arg(self, i: int) -> ast.AST:
rv = self.ast_args[i]
if isinstance(rv, int | float | bool):
if isinstance(rv, int | float | bool | None):
rv = ast.Constant(value=rv)
assert isinstance(rv, ast.AST), "TODO: convert nested/defaults"
return rv
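With this change, an FP8 `torch.matmul` inside a kernel raises instead of silently picking an input precision, and `hl.dot` is the supported path. A hedged sketch of the supported pattern, mirroring `examples/fp8_gemm.py`; `hl.zeros` and the tiling below are assumptions based on the existing examples, not part of this diff:

```python
import torch

import helion
import helion.language as hl

# Hedged sketch of the supported FP8 matmul pattern after this change; it
# mirrors examples/fp8_gemm.py. hl.zeros and the tiling are assumptions based
# on the existing examples, not part of this diff.
@helion.kernel(static_shapes=True)
def fp8_matmul_sketch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.size()
    k2, n = y.size()
    assert k == k2, "size mismatch"
    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            # torch.matmul on FP8 tiles would now raise NotImplementedError here;
            # hl.dot accumulates into the FP32 accumulator instead.
            acc = hl.dot(x[tile_m, tile_k], y[tile_k, tile_n], acc=acc)
        out[tile_m, tile_n] = acc.to(torch.float16)
    return out
```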
9 changes: 9 additions & 0 deletions helion/_testing.py
@@ -45,6 +45,15 @@ def code_and_output(
args: tuple[object, ...],
**kwargs: object,
) -> tuple[str, object]:
bound = fn.bind(args)
from helion.runtime.settings import RefMode

if bound.kernel.settings.ref_mode != RefMode.OFF:
result = fn(*args)
# Return the original kernel source code
code = inspect.getsource(fn.fn)
return code, result

if kwargs:
config = Config(
**kwargs # pyright: ignore[reportArgumentType]
1 change: 1 addition & 0 deletions helion/language/__init__.py
@@ -10,6 +10,7 @@
from .loops import grid as grid
from .loops import static_range as static_range
from .loops import tile as tile
from .matmul_ops import dot as dot
from .memory_ops import atomic_add as atomic_add
from .memory_ops import load as load
from .memory_ops import store as store
27 changes: 27 additions & 0 deletions helion/language/_decorators.py
@@ -79,6 +79,7 @@ class APIFunc(Protocol):
_to_device_ir: Callable[..., object] | None
_allow_host_tensor: bool
_signature: inspect.Signature
_ref_fn: Callable[..., object] | None

def __call__(self, *args: object, **kwargs: object) -> object: ...

@@ -133,6 +134,15 @@ def api(
def _impl(fn: _C) -> _C:
@functools.wraps(fn)
def wrapper(*args: object, **kwargs: object) -> object:
from ..runtime.ref_mode import is_ref_mode_enabled

if is_ref_mode_enabled() and api._ref_fn is not None:
# In ref mode, use the registered ref implementation
bound = api._signature.bind(*args, **kwargs)
bound.apply_defaults()
flat_args = api._prepare_args(*bound.arguments.values())
return api._ref_fn(*flat_args)

bound = api._signature.bind(*args, **kwargs)
bound.apply_defaults()
flat_args = api._prepare_args(*bound.arguments.values())
@@ -187,6 +197,7 @@ def wrapper(*args: object, **kwargs: object) -> object:
api._signature = signature or inspect.signature(
cast("Callable[..., object]", fn)
)
api._ref_fn = None
return wrapper # pyright: ignore[reportReturnType]

return _impl
@@ -289,6 +300,22 @@ def _impl(to_device_ir_fn: Callable[..., object]) -> Callable[..., Never]:
return _impl # pyright: ignore[reportReturnType]


def ref(
original_fn: Callable[..., object],
) -> _NoReturnDecorator[object]:
def _impl(ref_fn: Callable[..., object]) -> Callable[..., Never]:
assert is_api_func(original_fn), (
f"{ref.__qualname__} can only be used on API functions"
)
assert original_fn._ref_fn is None, (
"ref mode implementation can only be registered once per function"
)
original_fn._ref_fn = ref_fn
return _no_call

return _impl # pyright: ignore[reportReturnType]


def _default_type_function(
fake_fn: Callable[..., object], tiles_as_sizes: bool
) -> Callable[..., TypeInfo]:
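The new `_decorators.ref(...)` hook stores a plain-eager implementation on an API function, and the wrapper dispatches to it whenever ref mode is enabled. The ref implementations registered below for `specialize`, `full`, `device_print`, and `inline_asm_elementwise` all follow this pattern. A condensed, hypothetical sketch of the registration shape; `my_op` and its bare `api()` arguments are illustrative assumptions, not part of this diff:

```python
import torch

from helion.language import _decorators

# Hypothetical API function used only to illustrate the registration shape;
# the api() arguments are elided/assumed, and my_op is not part of this diff.
@_decorators.api()
def my_op(x: torch.Tensor) -> torch.Tensor:
    raise NotImplementedError("host call")


@_decorators.ref(my_op)
def _(x: torch.Tensor) -> torch.Tensor:
    # Plain eager implementation used when ref mode is enabled.
    return x * 2
```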
5 changes: 5 additions & 0 deletions helion/language/constexpr.py
@@ -95,3 +95,8 @@ def _(state: CodegenState) -> ast.AST:
value = value.__int__()
assert isinstance(value, int)
return expr_from_string(repr(value))


@_decorators.ref(specialize)
def _(value: int | torch.SymInt) -> int:
return int(value)
20 changes: 20 additions & 0 deletions helion/language/creation_ops.py
@@ -144,6 +144,26 @@ def _(
return None


@_decorators.ref(full)
def _(
shape: list[int | slice],
value: float,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
from .tile_proxy import RefTile

processed_shape = []
for s in shape:
if isinstance(s, RefTile):
# RefTile is a slice subclass with a block_size property
processed_shape.append(s.block_size)
elif isinstance(s, slice):
processed_shape.append(s.stop - s.start)
else:
processed_shape.append(s)
return torch.full(processed_shape, value, dtype=dtype, device="cuda")


def arange(
*args: int,
dtype: torch.dtype | None = None,
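The ref implementation of `full` has to turn tile objects into concrete sizes before calling `torch.full`. A standalone sketch of that translation, with a plain `slice` standing in for a `RefTile` (which subclasses `slice` per the comment above); the CPU default device is used here only so the snippet runs anywhere, while the ref implementation itself targets `"cuda"`:

```python
import torch

# Standalone sketch of the shape translation done by the ref `full` above; a
# plain slice stands in for a RefTile (a slice subclass). Default device used
# here only so the snippet runs anywhere; the ref impl passes device="cuda".
shape: list[int | slice] = [slice(0, 32), 64]
sizes = [s.stop - s.start if isinstance(s, slice) else s for s in shape]
t = torch.full(sizes, 0.0, dtype=torch.float32)
assert list(t.shape) == [32, 64]
```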
5 changes: 5 additions & 0 deletions helion/language/device_print.py
@@ -90,3 +90,8 @@ def _(state: CodegenState) -> None:
)
stmt = create(ast.Expr, value=call_expr)
state.add_statement(stmt)


@_decorators.ref(device_print)
def _(prefix: str, *args: object) -> None:
print(prefix, *args)
12 changes: 12 additions & 0 deletions helion/language/inline_asm_ops.py
@@ -205,3 +205,15 @@ def _(state: CodegenState) -> ast.AST | list[ast.AST]:
]

return inline_asm_call


@_decorators.ref(inline_asm_elementwise)
def _(
asm: str,
constraints: str,
args: Sequence[torch.Tensor],
dtype: torch.dtype | Sequence[torch.dtype],
is_pure: bool,
pack: int,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
raise NotImplementedError("inline_asm_elementwise is not supported in ref mode")