[Ref Mode] PyTorch reference mode (eager only) #339

Draft · wants to merge 1 commit into base: yf225/stack/39
2 changes: 2 additions & 0 deletions helion/__init__.py
@@ -11,12 +11,14 @@
from .runtime import Kernel
from .runtime import kernel
from .runtime import kernel as jit # alias
from .runtime.settings import RefMode
from .runtime.settings import Settings
from .runtime.settings import set_default_settings

__all__ = [
"Config",
"Kernel",
"RefMode",
"Settings",
"cdiv",
"exc",
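For reference, a minimal sketch of how a kernel could opt into the new mode. `RefMode.OFF` appears in this diff; the `EAGER` member and the `ref_mode` kernel setting are assumptions inferred from the PR title and from `bound.kernel.settings.ref_mode` in `_testing.py`:

```python
# Hedged sketch -- RefMode.EAGER and the ref_mode kwarg are assumed, not shown in this diff.
import torch
import helion
import helion.language as hl

@helion.kernel(ref_mode=helion.RefMode.EAGER)  # assumed setting name
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):  # in ref mode, tile yields plain slices
        out[tile] = x[tile] + y[tile]
    return out
```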
9 changes: 9 additions & 0 deletions helion/_testing.py
@@ -45,6 +45,15 @@ def code_and_output(
args: tuple[object, ...],
**kwargs: object,
) -> tuple[str, object]:
bound = fn.bind(args)
from helion.runtime.settings import RefMode

if bound.kernel.settings.ref_mode != RefMode.OFF:
result = fn(*args)
# Return the original kernel source code
code = inspect.getsource(fn.fn)
return code, result

if kwargs:
config = Config(
**kwargs # pyright: ignore[reportArgumentType]
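With ref mode on, `code_and_output` now short-circuits compilation: it runs the kernel eagerly and returns the kernel's original Python source instead of generated Triton code. A sketch of what callers see (using the hypothetical `add` kernel above):

```python
# code_and_output in ref mode: eager result plus original source, no Triton codegen.
code, result = code_and_output(add, (x, y))
assert "def add(" in code            # original Python source, not Triton
torch.testing.assert_close(result, x + y)
```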
27 changes: 27 additions & 0 deletions helion/language/_decorators.py
@@ -79,6 +79,7 @@ class APIFunc(Protocol):
_to_device_ir: Callable[..., object] | None
_allow_host_tensor: bool
_signature: inspect.Signature
_ref_fn: Callable[..., object] | None

def __call__(self, *args: object, **kwargs: object) -> object: ...

@@ -133,6 +134,15 @@ def api(
def _impl(fn: _C) -> _C:
@functools.wraps(fn)
def wrapper(*args: object, **kwargs: object) -> object:
from ..runtime.ref_mode import is_ref_mode_enabled

if is_ref_mode_enabled() and api._ref_fn is not None:
# In ref mode, use the registered ref implementation
bound = api._signature.bind(*args, **kwargs)
bound.apply_defaults()
flat_args = api._prepare_args(*bound.arguments.values())
return api._ref_fn(*flat_args)

bound = api._signature.bind(*args, **kwargs)
bound.apply_defaults()
flat_args = api._prepare_args(*bound.arguments.values())
@@ -187,6 +197,7 @@ def wrapper(*args: object, **kwargs: object) -> object:
api._signature = signature or inspect.signature(
cast("Callable[..., object]", fn)
)
api._ref_fn = None
return wrapper # pyright: ignore[reportReturnType]

return _impl
@@ -289,6 +300,22 @@ def _impl(to_device_ir_fn: Callable[..., object]) -> Callable[..., Never]:
return _impl # pyright: ignore[reportReturnType]


def ref(
original_fn: Callable[..., object],
) -> _NoReturnDecorator[object]:
def _impl(ref_fn: Callable[..., object]) -> Callable[..., Never]:
assert is_api_func(original_fn), (
f"{register_ref.__qualname__} can only be used on API functions"
)
assert original_fn._ref_fn is None, (
"ref mode implementation can only be registered once per function"
)
original_fn._ref_fn = ref_fn
return _no_call

return _impl # pyright: ignore[reportReturnType]


def _default_type_function(
fake_fn: Callable[..., object], tiles_as_sizes: bool
) -> Callable[..., TypeInfo]:
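The new `_decorators.ref` follows the same registration pattern as the existing fake/codegen decorators: it attaches a plain-PyTorch implementation to an API function, and `wrapper` dispatches to it whenever ref mode is enabled. A minimal sketch of the pattern with a hypothetical API function (`my_op` is illustrative, not part of this PR):

```python
# Hypothetical example of the registration pattern added in this diff.
import torch
from helion.language import _decorators

@_decorators.api()
def my_op(x: torch.Tensor) -> torch.Tensor:
    raise NotImplementedError("only callable inside a kernel")

@_decorators.ref(my_op)
def _(x: torch.Tensor) -> torch.Tensor:
    return x * 2  # plain eager PyTorch, used only when ref mode is on
```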
5 changes: 5 additions & 0 deletions helion/language/constexpr.py
@@ -95,3 +95,8 @@ def _(state: CodegenState) -> ast.AST:
value = value.__int__()
assert isinstance(value, int)
return expr_from_string(repr(value))


@_decorators.ref(specialize)
def _(value: int | torch.SymInt) -> int:
return int(value)
10 changes: 10 additions & 0 deletions helion/language/creation_ops.py
@@ -144,6 +144,16 @@ def _(
return None


@_decorators.ref(full)
def _(
shape: list[int | slice],
value: float,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
processed_shape = [s.stop - s.start if isinstance(s, slice) else s for s in shape]
return torch.full(processed_shape, value, dtype=dtype, device="cuda")


def arange(
*args: int,
dtype: torch.dtype | None = None,
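Because the `tile` ref yields `slice` objects, the `full` ref accepts slices as shape entries and converts each to its length (`stop - start`). A quick illustration of that conversion:

```python
# Shape normalization as done by the full() ref above.
shape = [slice(0, 32), 64]  # e.g. a tile slice plus a plain int dimension
sizes = [s.stop - s.start if isinstance(s, slice) else s for s in shape]
assert sizes == [32, 64]
```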
5 changes: 5 additions & 0 deletions helion/language/device_print.py
@@ -90,3 +90,8 @@ def _(state: CodegenState) -> None:
)
stmt = create(ast.Expr, value=call_expr)
state.add_statement(stmt)


@_decorators.ref(device_print)
def _(prefix: str, *args: object) -> None:
print(prefix, *args)
12 changes: 12 additions & 0 deletions helion/language/inline_asm_ops.py
@@ -205,3 +205,15 @@ def _(state: CodegenState) -> ast.AST | list[ast.AST]:
]

return inline_asm_call


@_decorators.ref(inline_asm_elementwise)
def _(
asm: str,
constraints: str,
args: Sequence[torch.Tensor],
dtype: torch.dtype | Sequence[torch.dtype],
is_pure: bool,
pack: int,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
raise NotImplementedError("inline_asm_elementwise is not supported in ref mode")
113 changes: 113 additions & 0 deletions helion/language/loops.py
@@ -3,6 +3,7 @@
import ast
import builtins
import inspect
import itertools
from itertools import starmap
from typing import TYPE_CHECKING
from typing import Iterator
@@ -449,6 +450,81 @@ def _(state: CodegenState) -> ast.AST:
return _codegen_loop_helper(state)


@_decorators.ref(tile)
def _(
begin_or_end: int | torch.Tensor | list[int | torch.Tensor],
end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None,
block_size: int | torch.Tensor | list[int | torch.Tensor] | None = None,
) -> Iterator[slice | tuple[slice, ...]]:
# Convert tensor values to int
def _to_int(value: int | torch.Tensor | None) -> int | None:
if value is None:
return None
if isinstance(value, torch.Tensor):
return int(value.item())
return int(value)

# Step 1: Normalize begin and end values based on the number of arguments
if end_or_none is not None:
# Two positional args: begin_or_end is begin, end_or_none is end
begin = begin_or_end
end = end_or_none
else:
# One positional arg: begin_or_end is end, begin defaults to 0
end = begin_or_end
# Create begin with same structure as end, but all zeros
if isinstance(end, (list, tuple)):
begin = [0] * len(end)
else:
begin = 0

# Step 2: Convert inputs to lists for uniform handling
def _normalize_to_list(
value: int | torch.Tensor | list[int | torch.Tensor],
) -> list[int | torch.Tensor]:
if isinstance(value, (list, tuple)):
return list(value)
return [value]

begin_list = _normalize_to_list(begin)
end_list = _normalize_to_list(end)

# Convert all values to int
begin_list = [_to_int(b) for b in begin_list]
end_list = [_to_int(e) for e in end_list]

# Step 3: Determine block_size based on the arguments
if block_size is None:
# Default block_size to end - begin for each dimension
block_size_list = [e - b for b, e in zip(begin_list, end_list, strict=False)]
else:
block_size_list = _normalize_to_list(block_size)
block_size_list = [
_to_int(bs) if bs is not None else (e - b)
for bs, b, e in zip(block_size_list, begin_list, end_list, strict=False)
]

# Step 4: Yield tile ranges
# Handle single dimension case
if len(begin_list) == 1:
b = begin_list[0]
e = end_list[0]
bs = block_size_list[0]
for i in range(b, e, bs):
yield slice(i, min(i + bs, e))
else:
# Handle multi-dimensional case
ranges = []
for b, e, bs in zip(begin_list, end_list, block_size_list, strict=False):
dim_ranges = []
for i in range(b, e, bs):
dim_ranges.append(slice(i, min(i + bs, e)))
ranges.append(dim_ranges)

for combo in itertools.product(*ranges):
yield combo


def _codegen_loop_helper(
state: CodegenState,
) -> ast.AST:
@@ -637,6 +713,32 @@ def _(state: CodegenState) -> ast.AST:
return _codegen_loop_helper(state)


@_decorators.ref(grid)
def _(
begin_or_end: int | torch.Tensor | list[int | torch.Tensor],
end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None,
step: object = None,
) -> range | Iterator[tuple[int, ...]]:
# Similar to tile but yields indices instead of slices
if end_or_none is not None:
begin = begin_or_end
end = end_or_none
else:
end = begin_or_end
if isinstance(end, (list, tuple)):
begin = [0] * len(end)
else:
begin = 0

# Handle single dimension
if not isinstance(begin, (list, tuple)):
return range(begin, end)

# Handle multi-dimensional
ranges = list(itertools.starmap(range, zip(begin, end, strict=False)))
return itertools.product(*ranges)


@_decorators.device_func_replacement(builtins.zip)
@_decorators.api(is_device_only=True, cache_type=True)
def _zip_replacement(
@@ -898,3 +1000,14 @@ def _(

# Return tuple(range(...)) which will trigger existing tuple/list unrolling
return tuple(range(begin_val, end_val, step))


@_decorators.ref(static_range)
def _(
begin_or_end: int,
end_or_none: int | None = None,
step: int = 1,
) -> range:
if end_or_none is not None:
return range(begin_or_end, end_or_none, step)
return range(begin_or_end)
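Taken together: the `tile` ref yields half-open slices stepping by `block_size` and clamping the final tile at `end`, `grid` yields bare indices, and multi-dimensional variants are Cartesian products. A worked example of what a ref-mode loop iterates over (values follow directly from the implementations above):

```python
import itertools

# tile(10, block_size=4) in ref mode yields:
#   slice(0, 4), slice(4, 8), slice(8, 10)   -- last tile clamped to end
tiles = [slice(i, min(i + 4, 10)) for i in range(0, 10, 4)]

# A 2-D tile is the Cartesian product of the per-dimension slices,
# matching the itertools.product loop in the ref implementation.
assert len(list(itertools.product(tiles, tiles))) == 9
```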
50 changes: 50 additions & 0 deletions helion/language/matmul_ops.py
@@ -209,3 +209,53 @@ def _(state: CodegenState) -> object:
rhs=rhs_ast,
acc=acc_ast,
)


@_decorators.ref(dot)
def _(
mat1: torch.Tensor,
mat2: torch.Tensor,
acc: torch.Tensor | None = None,
) -> torch.Tensor:
"""Reference implementation for hl.dot() in ref mode."""
is_fp8 = mat1.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) or mat2.dtype in (
torch.float8_e4m3fn,
torch.float8_e5m2,
)

if is_fp8:
# Use torch._scaled_mm for FP8 operations
# Ensure column-major for second operand as required by torch._scaled_mm
mat2_t = mat2.T.contiguous().T
scale_a = torch.tensor(1.0, device=mat1.device)
scale_b = torch.tensor(1.0, device=mat2.device)

# Determine output dtype
if acc is not None:
out_dtype = acc.dtype
else:
out_dtype = torch.float32 # Default for FP8

result = torch._scaled_mm(
mat1,
mat2_t,
scale_a,
scale_b,
use_fast_accum=False,
out_dtype=out_dtype,
)
else:
# For non-FP8 tensors, use regular matmul
result = torch.matmul(mat1, mat2)

# Handle accumulator
if acc is not None:
# Ensure result has same dtype as accumulator
if result.dtype != acc.dtype:
result = result.to(acc.dtype)
return acc + result
# Return with appropriate dtype based on inputs
out_dtype = _compute_out_dtype(mat1.dtype, mat2.dtype)
if result.dtype != out_dtype:
result = result.to(out_dtype)
return result
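One note on the FP8 path: `torch._scaled_mm` requires a column-major second operand, which is what the `mat2.T.contiguous().T` round-trip produces; it materializes a Fortran-ordered copy without changing the logical shape. A small demonstration of the stride effect:

```python
import torch

m = torch.randn(4, 8)
col_major = m.T.contiguous().T   # same shape, column-major memory layout
assert col_major.shape == m.shape
assert col_major.stride() == (1, 4)  # column-major strides for a 4x8 tensor
```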