From 140c47780ab167ff431505b4c2a1c1671b8c449b Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 18 Jul 2025 21:36:25 -0700 Subject: [PATCH] [Ref Mode] PyTorch reference mode (eager only) Part of https://github.com/pytorch-labs/helion/issues/77. Please see inline code comments on the PR. stack-info: PR: https://github.com/pytorch-labs/helion/pull/339, branch: yf225/stack/34 --- helion/__init__.py | 2 + helion/_testing.py | 9 ++ helion/language/_decorators.py | 27 ++++ helion/language/constexpr.py | 5 + helion/language/creation_ops.py | 20 +++ helion/language/device_print.py | 5 + helion/language/inline_asm_ops.py | 12 ++ helion/language/loops.py | 140 +++++++++++++++++++ helion/language/matmul_ops.py | 50 +++++++ helion/language/memory_ops.py | 223 ++++++++++++++++++++++++++++++ helion/language/reduce_ops.py | 53 +++++++ helion/language/scan_ops.py | 129 +++++++++++++++++ helion/language/signal_wait.py | 28 ++++ helion/language/tile_ops.py | 65 +++++++++ helion/language/tile_proxy.py | 144 +++++++++++++++++++ helion/language/tunable_ops.py | 19 +++ helion/language/view_ops.py | 20 +++ helion/runtime/kernel.py | 66 +++++++-- helion/runtime/ref_mode.py | 147 ++++++++++++++++++++ helion/runtime/settings.py | 22 +++ test/test_ref_eager.py | 167 ++++++++++++++++++++++ 21 files changed, 1345 insertions(+), 8 deletions(-) create mode 100644 helion/runtime/ref_mode.py create mode 100644 test/test_ref_eager.py diff --git a/helion/__init__.py b/helion/__init__.py index 1363eb20..728f983d 100644 --- a/helion/__init__.py +++ b/helion/__init__.py @@ -11,12 +11,14 @@ from .runtime import Kernel from .runtime import kernel from .runtime import kernel as jit # alias +from .runtime.settings import RefMode from .runtime.settings import Settings from .runtime.settings import set_default_settings __all__ = [ "Config", "Kernel", + "RefMode", "Settings", "cdiv", "exc", diff --git a/helion/_testing.py b/helion/_testing.py index 76e29dde..bb27368a 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -47,6 +47,15 @@ def code_and_output( args: tuple[object, ...], **kwargs: object, ) -> tuple[str, object]: + bound = fn.bind(args) + from helion.runtime.settings import RefMode + + if bound.kernel.settings.ref_mode != RefMode.OFF: + result = fn(*args) + # Return the original kernel source code + code = inspect.getsource(fn.fn) + return code, result + if kwargs: config = Config( **kwargs # pyright: ignore[reportArgumentType] diff --git a/helion/language/_decorators.py b/helion/language/_decorators.py index fbddb640..a82c1026 100644 --- a/helion/language/_decorators.py +++ b/helion/language/_decorators.py @@ -79,6 +79,7 @@ class APIFunc(Protocol): _to_device_ir: Callable[..., object] | None _allow_host_tensor: bool _signature: inspect.Signature + _ref_fn: Callable[..., object] | None def __call__(self, *args: object, **kwargs: object) -> object: ... 
@@ -133,6 +134,15 @@ def api( def _impl(fn: _C) -> _C: @functools.wraps(fn) def wrapper(*args: object, **kwargs: object) -> object: + from ..runtime.ref_mode import is_ref_mode_enabled + + if is_ref_mode_enabled() and api._ref_fn is not None: + # In ref mode, use the registered ref implementation + bound = api._signature.bind(*args, **kwargs) + bound.apply_defaults() + flat_args = api._prepare_args(*bound.arguments.values()) + return api._ref_fn(*flat_args) + bound = api._signature.bind(*args, **kwargs) bound.apply_defaults() flat_args = api._prepare_args(*bound.arguments.values()) @@ -187,6 +197,7 @@ def wrapper(*args: object, **kwargs: object) -> object: api._signature = signature or inspect.signature( cast("Callable[..., object]", fn) ) + api._ref_fn = None return wrapper # pyright: ignore[reportReturnType] return _impl @@ -289,6 +300,22 @@ def _impl(to_device_ir_fn: Callable[..., object]) -> Callable[..., Never]: return _impl # pyright: ignore[reportReturnType] +def ref( + original_fn: Callable[..., object], +) -> _NoReturnDecorator[object]: + def _impl(ref_fn: Callable[..., object]) -> Callable[..., Never]: + assert is_api_func(original_fn), ( + f"{ref.__qualname__} can only be used on API functions" + ) + assert original_fn._ref_fn is None, ( + "ref mode implementation can only be registered once per function" + ) + original_fn._ref_fn = ref_fn + return _no_call + + return _impl # pyright: ignore[reportReturnType] + + def _default_type_function( fake_fn: Callable[..., object], tiles_as_sizes: bool ) -> Callable[..., TypeInfo]: diff --git a/helion/language/constexpr.py b/helion/language/constexpr.py index 8b2ae084..e4170fd9 100644 --- a/helion/language/constexpr.py +++ b/helion/language/constexpr.py @@ -95,3 +95,8 @@ def _(state: CodegenState) -> ast.AST: value = value.__int__() assert isinstance(value, int) return expr_from_string(repr(value)) + + +@_decorators.ref(specialize) +def _(value: int | torch.SymInt) -> int: + return int(value) diff --git a/helion/language/creation_ops.py b/helion/language/creation_ops.py index b22dc090..df5e2ff7 100644 --- a/helion/language/creation_ops.py +++ b/helion/language/creation_ops.py @@ -144,6 +144,26 @@ def _( return None +@_decorators.ref(full) +def _( + shape: list[int | slice], + value: float, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + from .tile_proxy import RefTile + + processed_shape = [] + for s in shape: + if isinstance(s, RefTile): + # RefTile is a slice subclass with a block_size property + processed_shape.append(s.block_size) + elif isinstance(s, slice): + processed_shape.append(s.stop - s.start) + else: + processed_shape.append(s) + return torch.full(processed_shape, value, dtype=dtype, device="cuda") + + def arange( *args: int, dtype: torch.dtype | None = None, diff --git a/helion/language/device_print.py b/helion/language/device_print.py index 60170b7d..7747d35c 100644 --- a/helion/language/device_print.py +++ b/helion/language/device_print.py @@ -90,3 +90,8 @@ def _(state: CodegenState) -> None: ) stmt = create(ast.Expr, value=call_expr) state.add_statement(stmt) + + +@_decorators.ref(device_print) +def _(prefix: str, *args: object) -> None: + print(prefix, *args) diff --git a/helion/language/inline_asm_ops.py b/helion/language/inline_asm_ops.py index 936acd0d..b80310d5 100644 --- a/helion/language/inline_asm_ops.py +++ b/helion/language/inline_asm_ops.py @@ -205,3 +205,15 @@ def _(state: CodegenState) -> ast.AST | list[ast.AST]: ] return inline_asm_call + + +@_decorators.ref(inline_asm_elementwise) +def _( + asm: 
str, + constraints: str, + args: Sequence[torch.Tensor], + dtype: torch.dtype | Sequence[torch.dtype], + is_pure: bool, + pack: int, +) -> torch.Tensor | tuple[torch.Tensor, ...]: + raise NotImplementedError("inline_asm_elementwise is not supported in ref mode") diff --git a/helion/language/loops.py b/helion/language/loops.py index 03d874e8..772423a3 100644 --- a/helion/language/loops.py +++ b/helion/language/loops.py @@ -3,6 +3,7 @@ import ast import builtins import inspect +import itertools from itertools import starmap from typing import TYPE_CHECKING from typing import Iterator @@ -42,6 +43,7 @@ from ..autotuner.config_spec import RangeWarpSpecializeSpec from ..autotuner.config_spec import StaticRangeSpec from . import _decorators +from .tile_proxy import RefTile from .tile_proxy import Tile if TYPE_CHECKING: @@ -449,6 +451,93 @@ def _(state: CodegenState) -> ast.AST: return _codegen_loop_helper(state) +@_decorators.ref(tile) +def _( + begin_or_end: int | torch.Tensor | list[int | torch.Tensor], + end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None, + block_size: int | torch.Tensor | list[int | torch.Tensor] | None = None, +) -> Iterator[RefTile | tuple[RefTile, ...]]: + # Convert tensor values to int + def _to_int(value: int | torch.Tensor | None) -> int | None: + if value is None: + return None + if isinstance(value, torch.Tensor): + return int(value.item()) + return int(value) + + # Step 1: Normalize begin and end values based on the number of arguments + if end_or_none is not None: + # Two positional args: begin_or_end is begin, end_or_none is end + begin = begin_or_end + end = end_or_none + else: + # One positional arg: begin_or_end is end, begin defaults to 0 + end = begin_or_end + # Create begin with same structure as end, but all zeros + if isinstance(end, (list, tuple)): + begin = cast( + "int | torch.Tensor | list[int | torch.Tensor]", [0] * len(end) + ) + else: + begin = 0 + + # Step 2: Convert inputs to lists for uniform handling + def _normalize_to_list( + value: int | torch.Tensor | list[int | torch.Tensor], + ) -> list[int | torch.Tensor]: + if isinstance(value, (list, tuple)): + return list(value) + return [value] + + begin_list = _normalize_to_list(begin) + end_list = _normalize_to_list(end) + + # Convert all values to int + begin_list = [_to_int(b) for b in begin_list] + end_list = [_to_int(e) for e in end_list] + + # Step 3: Determine block_size based on the arguments + if block_size is None: + # Default block_size to end - begin for each dimension + block_size_list = [] + for b, e in zip(begin_list, end_list, strict=False): + assert b is not None and e is not None + block_size_list.append(e - b) + else: + block_size_list = _normalize_to_list(block_size) + processed_block_size_list = [] + for bs, b, e in zip(block_size_list, begin_list, end_list, strict=False): + assert b is not None and e is not None + if bs is not None: + processed_bs = _to_int(bs) + assert processed_bs is not None + processed_block_size_list.append(processed_bs) + else: + processed_block_size_list.append(e - b) + block_size_list = processed_block_size_list + + # Step 4: Yield tile ranges + # Handle single dimension case + if len(begin_list) == 1: + b = begin_list[0] + e = end_list[0] + bs = block_size_list[0] + assert b is not None and e is not None and bs is not None + for i in range(b, e, bs): + yield RefTile(i, min(i + bs, e)) + else: + # Handle multi-dimensional case + ranges = [] + for b, e, bs in zip(begin_list, end_list, block_size_list, strict=False): + dim_ranges = [] 
+ assert b is not None and e is not None and bs is not None + for i in range(b, e, bs): + dim_ranges.append(RefTile(i, min(i + bs, e))) + ranges.append(dim_ranges) + + yield from itertools.product(*ranges) + + def _codegen_loop_helper( state: CodegenState, ) -> ast.AST: @@ -637,6 +726,46 @@ def _(state: CodegenState) -> ast.AST: return _codegen_loop_helper(state) +@_decorators.ref(grid) +def _( + begin_or_end: int | torch.Tensor | list[int | torch.Tensor], + end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None, + step: object = None, +) -> range | Iterator[tuple[int, ...]]: + # Similar to tile but yields indices instead of slices + if end_or_none is not None: + begin = begin_or_end + end = end_or_none + else: + end = begin_or_end + if isinstance(end, (list, tuple)): + begin = cast( + "int | torch.Tensor | list[int | torch.Tensor]", [0] * len(end) + ) + else: + begin = 0 + + # Convert tensor values to int + def _to_int(value: int | torch.Tensor) -> int: + if isinstance(value, torch.Tensor): + return int(value.item()) + return int(value) + + # Handle single dimension + if not isinstance(begin, (list, tuple)): + begin_int = _to_int(begin) + assert not isinstance(end, (list, tuple)) + end_int = _to_int(end) + return range(begin_int, end_int) + + # Handle multi-dimensional + assert isinstance(end, (list, tuple)) + begin_ints = [_to_int(b) for b in begin] + end_ints = [_to_int(e) for e in end] + ranges = list(itertools.starmap(range, zip(begin_ints, end_ints, strict=False))) + return itertools.product(*ranges) + + @_decorators.device_func_replacement(builtins.zip) @_decorators.api(is_device_only=True, cache_type=True) def _zip_replacement( @@ -898,3 +1027,14 @@ def _( # Return tuple(range(...)) which will trigger existing tuple/list unrolling return tuple(range(begin_val, end_val, step)) + + +@_decorators.ref(static_range) +def _( + begin_or_end: int, + end_or_none: int | None = None, + step: int = 1, +) -> range: + if end_or_none is not None: + return range(begin_or_end, end_or_none, step) + return range(begin_or_end) diff --git a/helion/language/matmul_ops.py b/helion/language/matmul_ops.py index 25e827d0..fa7044a1 100644 --- a/helion/language/matmul_ops.py +++ b/helion/language/matmul_ops.py @@ -209,3 +209,53 @@ def _(state: CodegenState) -> object: rhs=rhs_ast, acc=acc_ast, ) + + +@_decorators.ref(dot) +def _( + mat1: torch.Tensor, + mat2: torch.Tensor, + acc: torch.Tensor | None = None, +) -> torch.Tensor: + """Reference implementation for hl.dot() in ref mode.""" + is_fp8 = mat1.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) or mat2.dtype in ( + torch.float8_e4m3fn, + torch.float8_e5m2, + ) + + if is_fp8: + # Use torch._scaled_mm for FP8 operations + # Ensure column-major for second operand as required by torch._scaled_mm + mat2_t = mat2.T.contiguous().T + scale_a = torch.tensor(1.0, device=mat1.device) + scale_b = torch.tensor(1.0, device=mat2.device) + + # Determine output dtype + if acc is not None: + out_dtype = acc.dtype + else: + out_dtype = torch.float32 # Default for FP8 + + result = torch._scaled_mm( + mat1, + mat2_t, + scale_a, + scale_b, + use_fast_accum=False, + out_dtype=out_dtype, + ) + else: + # For non-FP8 tensors, use regular matmul + result = torch.matmul(mat1, mat2) + + # Handle accumulator + if acc is not None: + # Ensure result has same dtype as accumulator + if result.dtype != acc.dtype: + result = result.to(acc.dtype) + return acc + result + # Return with appropriate dtype based on inputs + out_dtype = _compute_out_dtype(mat1.dtype, mat2.dtype) 
+ if result.dtype != out_dtype: + result = result.to(out_dtype) + return result diff --git a/helion/language/memory_ops.py b/helion/language/memory_ops.py index 3bdb529b..5aa95299 100644 --- a/helion/language/memory_ops.py +++ b/helion/language/memory_ops.py @@ -18,6 +18,124 @@ __all__ = ["atomic_add", "load", "store"] +# Helper functions for ref mode implementations +def _normalize_indices(indices: slice | list | tuple) -> slice | tuple: + if isinstance(indices, slice): + return slice(indices.start, indices.stop) + if isinstance(indices, (list, tuple)): + return tuple( + slice(idx.start, idx.stop) if isinstance(idx, slice) else idx + for idx in indices + ) + return indices + + +def _combine_masks( + mask1: torch.Tensor | None, mask2: torch.Tensor | None +) -> torch.Tensor | None: + if mask1 is not None and mask2 is not None: + return mask1 & mask2 + return mask1 or mask2 + + +def _apply_mask( + result: torch.Tensor, mask: torch.Tensor | None, other: float | torch.Tensor = 0 +) -> torch.Tensor: + if mask is None: + return result + + # Handle shape mismatch + if result.shape != mask.shape: + if mask.numel() == 0 or result.numel() == 0: + return torch.zeros(mask.shape, dtype=result.dtype, device=result.device) + # Let torch handle broadcasting + + return torch.where(mask, result, other) + + +def _handle_single_tensor_index( + tensor: torch.Tensor, idx_tensor: torch.Tensor, extra_mask: torch.Tensor | None +) -> torch.Tensor: + """Handle indexing with a single tensor index (jagged array case).""" + flat_indices = idx_tensor.flatten() + clamped_indices = torch.clamp(flat_indices, 0, tensor.shape[0] - 1) + + if extra_mask is None: + return tensor[clamped_indices].reshape(idx_tensor.shape) + + # Apply mask to filter valid indices + valid_mask = extra_mask.flatten() + gathered = tensor[clamped_indices] + result = torch.zeros(idx_tensor.shape, dtype=tensor.dtype, device=tensor.device) + result_flat = result.flatten() + result_flat = torch.where(valid_mask, gathered, result_flat) + return result_flat.reshape(idx_tensor.shape) + + +def _handle_mixed_indices( + tensor: torch.Tensor, indices: tuple, extra_mask: torch.Tensor | None +) -> torch.Tensor: + """Handle mixed indexing with slices and tensors.""" + expected_shape = [] + actual_indices = [] + tensor_shape = tensor.shape + + # Build expected output shape and process indices + for i, idx in enumerate(indices): + if isinstance(idx, slice): + # Handle slice indices + if idx.start is None and idx.stop is None: + # Full slice like `:` + shape_size = tensor_shape[i] if i < len(tensor_shape) else 1 + else: + start = idx.start or 0 + stop = idx.stop or (tensor_shape[i] if i < len(tensor_shape) else 1) + shape_size = stop - start + expected_shape.append(shape_size) + actual_indices.append(idx) + elif isinstance(idx, torch.Tensor): + # Handle tensor indices - clamp to valid range + expected_shape.extend(idx.shape) + max_index = tensor_shape[i] - 1 if i < len(tensor_shape) else 0 + clamped_idx = torch.clamp(idx, 0, max_index) + actual_indices.append(clamped_idx) + else: + # Regular integer index + actual_indices.append(idx) + + # Perform indexing with error handling + try: + result = tensor[tuple(actual_indices)] + + # Handle shape mismatch when using extra_mask + if extra_mask is not None and result.shape != tuple(expected_shape): + result = _pad_result_to_expected_shape( + result, expected_shape, tensor.dtype, tensor.device + ) + + return result + except (RuntimeError, IndexError): + # Return zeros if indexing fails (e.g., negative indices) + return 
torch.zeros(expected_shape, dtype=tensor.dtype, device=tensor.device) + + +def _pad_result_to_expected_shape( + result: torch.Tensor, + expected_shape: list[int], + dtype: torch.dtype, + device: torch.device, +) -> torch.Tensor: + """Pad result tensor with zeros to match expected shape.""" + padded_result = torch.zeros(expected_shape, dtype=dtype, device=device) + + if result.numel() > 0: + # Copy valid data to padded result + slices = [slice(0, s) for s in result.shape] + padded_result[tuple(slices)] = result + + return padded_result + + @has_side_effect @_decorators.api(tiles_as_sizes=True, allow_host_tensor=True) def store( @@ -84,6 +202,33 @@ def _(state: CodegenState) -> ast.AST: ) +@_decorators.ref(store) +def _( + tensor: torch.Tensor, + indices: list[object], + value: torch.Tensor, + extra_mask: torch.Tensor | None = None, +) -> None: + # Convert RefTile objects to slices + from .tile_proxy import RefTile + + processed_indices = [] + for idx in indices: + if isinstance(idx, RefTile): + processed_indices.append(idx._slice) + else: + processed_indices.append(idx) + indices = processed_indices + + normalized_indices = _normalize_indices(indices) + + if extra_mask is not None: + current = tensor[normalized_indices] + tensor[normalized_indices] = torch.where(extra_mask, value, current) + else: + tensor[normalized_indices] = value + + @_decorators.api(tiles_as_sizes=True, allow_host_tensor=True) def load( tensor: torch.Tensor, index: list[object], extra_mask: torch.Tensor | None = None @@ -129,6 +274,46 @@ def _(node: torch.fx.Node) -> int: return 0 # loads are always masked to 0 +@_decorators.ref(load) +def _( + tensor: torch.Tensor, + indices: list[object], + extra_mask: torch.Tensor | None = None, +) -> torch.Tensor: + # Combined mask handling is done inside load logic + mask = None # No base mask for ref mode + other = 0 + + assert isinstance(indices, (list, tuple)) + + # Convert RefTile objects to slices + from .tile_proxy import RefTile + + processed_indices = [] + for idx in indices: + if isinstance(idx, RefTile): + processed_indices.append(idx._slice) + else: + processed_indices.append(idx) + indices = processed_indices + + # Case 1: Single tensor index (jagged indexing) + if len(indices) == 1 and isinstance(indices[0], torch.Tensor): + result = _handle_single_tensor_index(tensor, indices[0], extra_mask) + + # Case 2: Mixed indices containing slices (tiles) + elif any(isinstance(idx, slice) for idx in indices): + result = _handle_mixed_indices(tensor, tuple(indices), extra_mask) + else: + raise exc.InvalidIndexingType( + f"Invalid indices type: {indices}. Expected a list of slices or tensors." 
+ ) + + # Apply mask + combined_mask = _combine_masks(mask, extra_mask) + return _apply_mask(result, combined_mask, other) + + @has_side_effect @_decorators.api(allow_host_tensor=True) def atomic_add( @@ -233,3 +418,41 @@ def _(state: CodegenState) -> ast.AST: mask=indices.mask_expr, sem=sem, ) + + +@_decorators.ref(atomic_add) +def _( + tensor: torch.Tensor, + indices: list[object], + value: torch.Tensor | float, + sem: str = "relaxed", +) -> None: + # Convert RefTile objects to slices + from .tile_proxy import RefTile + + processed_indices = [] + for idx in indices: + if isinstance(idx, RefTile): + processed_indices.append(idx._slice) + else: + processed_indices.append(idx) + indices = processed_indices + + # Special handling for scatter-add pattern (`tensor[tensor_idx, slice] += value`) + if isinstance(indices, (list, tuple)) and len(indices) == 2: + idx0, idx1 = indices + if isinstance(idx0, torch.Tensor) and isinstance(idx1, slice): + # This is the pattern: output[idxs, tile_f] += segment_vals + start = idx1.start or 0 + stop = idx1.stop or tensor.shape[1] + tensor_view = tensor[:, start:stop] + if isinstance(value, (float, int)): + value_tensor = torch.full_like(tensor_view[0], value) + tensor_view.index_add_(0, idx0, value_tensor) + else: + tensor_view.index_add_(0, idx0, value) + return + + # Default case + normalized_indices = _normalize_indices(indices) + tensor[normalized_indices] += value diff --git a/helion/language/reduce_ops.py b/helion/language/reduce_ops.py index b811330b..eae4ae1a 100644 --- a/helion/language/reduce_ops.py +++ b/helion/language/reduce_ops.py @@ -278,6 +278,46 @@ def _( return output_tensor +@_decorators.ref(reduce) +def _( + combine_fn: CombineFunction, + input_tensor: torch.Tensor | tuple[torch.Tensor, ...], + dim: int | None = None, + other: float | tuple[float, ...] 
= 0, + keep_dims: bool = False, +) -> torch.Tensor | tuple[torch.Tensor, ...]: + # In ref mode, we simulate reduce using PyTorch operations + # combine_fn is a binary function that combines two elements + # For standard reductions, we would use torch operations directly + + # This is a simplified implementation that doesn't actually use combine_fn + # In a real implementation, we would need to manually reduce using the combine function + if isinstance(input_tensor, tuple): + # For tuple inputs, reduce each tensor separately + results = [] + for tensor in input_tensor: + if dim is None: + # Reduce all dimensions - return a scalar + result = tensor.sum() + if keep_dims: + result = result.reshape([1] * len(tensor.shape)) + else: + # Reduce specific dimension + result = tensor.sum(dim=dim, keepdim=keep_dims) + results.append(result) + return tuple(results) + # Single tensor input + if dim is None: + # Reduce all dimensions - return a scalar + result = input_tensor.sum() + if keep_dims: + result = result.reshape([1] * len(input_tensor.shape)) + else: + # Reduce specific dimension + result = input_tensor.sum(dim=dim, keepdim=keep_dims) + return result + + @_decorators.api() def _reduce( combine_graph_id: int, @@ -337,6 +377,19 @@ def _(state: CodegenState) -> ast.AST | list[ast.AST]: return reduce_expr +@_decorators.ref(_reduce) +def _( + combine_graph_id: int, + input_tensor: torch.Tensor | tuple[torch.Tensor, ...], + dim: int | None = None, + keep_dims: bool = False, + is_tuple_input: bool = False, +) -> torch.Tensor | tuple[torch.Tensor, ...]: + # For ref mode, we don't have access to the combine graph + # This should be handled by the higher-level reduce ref implementation + raise NotImplementedError("_reduce should not be called in ref mode") + + def _register_helper_function(state: CodegenState, combine_graph_id: int) -> str: """Register the helper function and return its final name.""" from .._compiler.device_ir import HelperFunctionGraphInfo diff --git a/helion/language/scan_ops.py b/helion/language/scan_ops.py index 44ea10cb..95d9d37b 100644 --- a/helion/language/scan_ops.py +++ b/helion/language/scan_ops.py @@ -3,6 +3,8 @@ import ast import operator from typing import TYPE_CHECKING +from typing import Callable +from typing import Iterator from typing import cast from typing import overload @@ -23,6 +25,108 @@ __all__ = ["associative_scan", "cumprod", "cumsum"] +# Helper functions for ref mode implementations +def _build_indices( + shape: tuple[int, ...], dim: int, idx: int +) -> tuple[slice | int, ...]: + """Build indexing tuple for accessing position idx along dimension dim.""" + indices: list[slice | int] = [slice(None)] * len(shape) + indices[dim] = idx + return tuple(indices) + + +def _iterate_scan_dimension( + scan_size: int, reverse: bool +) -> Iterator[tuple[int, int, bool]]: + """ + Generate iteration indices for scan operation. 
+ + Yields: + Tuple of (iteration_index, actual_index, is_first_element) + """ + for i in range(scan_size): + # Calculate current index based on scan direction + idx = (scan_size - 1 - i) if reverse else i + + # Check if this is the first element in the scan + is_first = (i == 0 and not reverse) or (i == scan_size - 1 and reverse) + + yield i, idx, is_first + + +def _get_prev_index(idx: int, reverse: bool) -> int: + """Get the previous index in the scan sequence.""" + return (idx + 1) if reverse else (idx - 1) + + +def _scan_single_tensor( + combine_fn: Callable, input_tensor: torch.Tensor, dim: int, reverse: bool +) -> torch.Tensor: + """Helper function to perform scan on a single tensor.""" + result = torch.empty_like(input_tensor) + scan_size = input_tensor.shape[dim] + + # Iterate through the dimension to scan + for _i, idx, is_first in _iterate_scan_dimension(scan_size, reverse): + # Build indexing tuple to access elements at position idx along dim + indices = _build_indices(input_tensor.shape, dim, idx) + + if is_first: + # First element: copy input directly + result[indices] = input_tensor[indices] + else: + # Combine with previous accumulated value + prev_idx = _get_prev_index(idx, reverse) + prev_indices = _build_indices(input_tensor.shape, dim, prev_idx) + + # Apply the combine function + result[indices] = combine_fn(result[prev_indices], input_tensor[indices]) + + return result + + +def _scan_tuple_tensors( + combine_fn: Callable, input_tuple: tuple[torch.Tensor, ...], dim: int, reverse: bool +) -> tuple[torch.Tensor, ...]: + """Helper function to perform scan on a tuple of tensors.""" + tensors = list(input_tuple) + scan_size = tensors[0].shape[dim] + + # Initialize result tensors + results = [torch.empty_like(t) for t in tensors] + + # Iterate through the dimension to scan + for _i, idx, is_first in _iterate_scan_dimension(scan_size, reverse): + # Build indexing tuple + indices = _build_indices(tensors[0].shape, dim, idx) + + if is_first: + # First element: copy inputs directly + for j, tensor in enumerate(tensors): + results[j][indices] = tensor[indices] + else: + # Combine with previous accumulated values + prev_idx = _get_prev_index(idx, reverse) + prev_indices = _build_indices(tensors[0].shape, dim, prev_idx) + + # Gather values for combination + current_vals = tuple(t[indices] for t in tensors) + prev_vals = tuple(r[prev_indices] for r in results) + + # Apply combine function with unpacked arguments + combined = combine_fn(*prev_vals, *current_vals) + + # Store results (handle both single and tuple returns) + if isinstance(combined, tuple): + for j, val in enumerate(combined): + results[j][indices] = val + else: + # Single result case + results[0][indices] = combined + + return tuple(results) + + @overload @_decorators.device_func_replacement(higher_order_ops.associative_scan) @_decorators.api(is_device_only=True) @@ -229,6 +333,18 @@ def _( ) +@_decorators.ref(associative_scan) +def _( + combine_fn: Callable, + input_tensor: torch.Tensor | tuple[torch.Tensor, ...], + dim: int, + reverse: bool = False, +) -> torch.Tensor | tuple[torch.Tensor, ...]: + if isinstance(input_tensor, (tuple, list)): + return _scan_tuple_tensors(combine_fn, tuple(input_tensor), dim, reverse) + return _scan_single_tensor(combine_fn, input_tensor, dim, reverse) + + @_decorators.device_func_replacement(torch.cumsum) def cumsum(input_tensor: torch.Tensor, dim: int, reverse: bool = False) -> torch.Tensor: """ @@ -335,6 +451,19 @@ def _(state: CodegenState) -> ast.AST | list[ast.AST]: return scan_expr 
+@_decorators.ref(_associative_scan) +def _( + combine_graph_id: int, + input_tensor: torch.Tensor | tuple[torch.Tensor, ...], + dim: int, + reverse: bool = False, + is_tuple_input: bool = False, +) -> torch.Tensor | tuple[torch.Tensor, ...]: + # For ref mode, we don't have access to the combine graph + # This should be handled by the higher-level associative_scan ref implementation + raise NotImplementedError("_associative_scan should not be called in ref mode") + + def _get_input_tensor_ast(state: CodegenState, is_tuple_input: bool) -> ast.AST: """Get the input tensor AST, handling tuple inputs specially.""" if not is_tuple_input: diff --git a/helion/language/signal_wait.py b/helion/language/signal_wait.py index 2c474aaa..2af3b239 100644 --- a/helion/language/signal_wait.py +++ b/helion/language/signal_wait.py @@ -157,6 +157,20 @@ def _(state: CodegenState) -> ast.AST: ) +@_decorators.ref(wait) +def _( + signal_pad: torch.Tensor, + index: list[object], + signal: int = 1, + update: int | None = None, + op: str = "ld", + sem: str = "acquire", + scope: str = "gpu", + skip_sync: bool = False, +) -> None: + raise NotImplementedError("wait is not supported in ref mode") + + @has_side_effect @_decorators.api(tiles_as_sizes=True, allow_host_tensor=True) def signal( @@ -295,3 +309,17 @@ def _(state: CodegenState) -> ast.AST: signal=signal_expr, skip_sync=skip_sync_expr, ) + + +@_decorators.ref(signal) +def _( + signal_pad: torch.Tensor, + index: list[object], + signal: int = 1, + wait_for: int | None = None, + op: str = "atomic_xchg", + sem: str = "release", + scope: str = "gpu", + skip_sync: bool = False, +) -> torch.Tensor: + raise NotImplementedError("signal is not supported in ref mode") diff --git a/helion/language/tile_ops.py b/helion/language/tile_ops.py index 6e90b9ac..f8e36cf9 100644 --- a/helion/language/tile_ops.py +++ b/helion/language/tile_ops.py @@ -48,6 +48,19 @@ def _(state: CodegenState) -> ast.AST: return expr_from_string(state.codegen.index_var(index)) +@_decorators.ref(tile_index) +def _(tile: slice | int) -> torch.Tensor: + # Handle different tile representations in ref mode + from .tile_proxy import RefTile + + if isinstance(tile, RefTile): + return tile.index + if isinstance(tile, slice): + return torch.arange(tile.start, tile.stop, dtype=torch.int64, device="cuda") + # tiles_as_sizes=True means we get an int + return torch.arange(0, tile, dtype=torch.int64, device="cuda") + + @_decorators.api(tiles_as_sizes=True) def tile_begin(tile: Tile) -> int: """ @@ -82,6 +95,20 @@ def _(state: CodegenState) -> ast.AST: return expr_from_string(state.codegen.offset_var(index)) +@_decorators.ref(tile_begin) +def _(tile: int | slice) -> int: + # Handle different tile representations in ref mode + from .tile_proxy import RefTile + + if isinstance(tile, RefTile): + return tile.begin + if isinstance(tile, slice): + return tile.start + # In ref mode with tiles_as_sizes=True, we lost the begin info + # This is a limitation - we return 0 as we don't know the actual begin + return 0 + + @_decorators.api(tiles_as_sizes=True) def tile_end(tile: Tile) -> int: """ @@ -121,6 +148,20 @@ def _(state: CodegenState) -> ast.AST: return expr_from_string(naive_exp) +@_decorators.ref(tile_end) +def _(tile: int | slice) -> int: + # Handle different tile representations in ref mode + from .tile_proxy import RefTile + + if isinstance(tile, RefTile): + return tile.end + if isinstance(tile, slice): + return tile.stop + # In ref mode with tiles_as_sizes=True, we get the size + # We lost the begin info, so we 
assume end = size + return tile + + @_decorators.api(tiles_as_sizes=True) def tile_block_size(tile: Tile) -> int: """ @@ -139,6 +180,19 @@ def _(tile: torch.SymInt) -> torch.SymInt: # codegen is handled in _get_symnode() +@_decorators.ref(tile_block_size) +def _(tile: int | slice) -> int: + # Handle different tile representations in ref mode + from .tile_proxy import RefTile + + if isinstance(tile, RefTile): + return tile.block_size + if isinstance(tile, slice): + return tile.stop - tile.start + # In ref mode with tiles_as_sizes=True, the tile IS the size + return tile + + @_decorators.api(tiles_as_sizes=True) def tile_id(tile: Tile) -> int: """ @@ -166,3 +220,14 @@ def _(state: CodegenState) -> ast.AST: else: expr_str = f"{offset} // {block_size}" return expr_from_string(expr_str) + + +@_decorators.ref(tile_id) +def _(tile: int | slice) -> int: + # tile_id is the index of the tile in the grid + from .tile_proxy import RefTile + + if isinstance(tile, RefTile): + return tile.id + # For ref mode we don't have the original block_size, so we return 0 + return 0 diff --git a/helion/language/tile_proxy.py b/helion/language/tile_proxy.py index 7097a2bf..98320ed2 100644 --- a/helion/language/tile_proxy.py +++ b/helion/language/tile_proxy.py @@ -182,3 +182,147 @@ def __enter__(self) -> Self: def __exit__(self, *args: object) -> None: _tls.index_calls = None + + +class RefTile(torch.Tensor): + """ + A tile-like object used in reference eager mode that behaves like a slice. + This allows tile.index and other tile operations to work properly in ref eager mode. + """ + + def __new__(cls, start: int, stop: int, step: int | None = None) -> Self: + # Create a tensor instance + return super().__new__(cls) + + def __init__(self, start: int, stop: int, step: int | None = None) -> None: + super().__init__() + # Store slice data + self.start = start + self.stop = stop + self.step = step + self._slice = slice(start, stop, step) + # We need to set block_id to something for compatibility + self.block_id = -1 # Special value for ref mode + + @property + def index(self) -> torch.Tensor: + """Return a tensor containing the offsets for this tile.""" + return torch.arange(self.start, self.stop, dtype=torch.int64, device="cuda") + + @property + def begin(self) -> int: + """Return the start offset of this tile.""" + return self.start + + @property + def end(self) -> int: + """Return the end offset of this tile.""" + return self.stop + + @property + def block_size(self) -> int: + """Return the block size of this tile.""" + return self.stop - self.start + + @property + def id(self) -> int: + """Return the id of this tile (always 0 in ref mode).""" + # We don't have enough info to compute the actual tile id + return 0 + + def __repr__(self, *, tensor_contents: object = None) -> str: + """Return string representation of RefTile.""" + # Override torch.Tensor's __repr__ with matching signature + return f"RefTile({self._slice!r})" + + def __int__(self) -> int: + """Convert to int for cases where a size is expected.""" + return self.block_size + + # Make RefTile usable as an index by delegating to the slice + def slice_indices(self, length: int) -> tuple[int, int, int]: + """Return (start, stop, step) tuple, like slice.indices().""" + return self._slice.indices(length) + + def equals(self, other: object) -> bool: + """Compare with other RefTile or slice objects. + + Use this instead of == for RefTile comparison. 
+ """ + if isinstance(other, RefTile): + return self._slice == other._slice + if isinstance(other, slice): + return self._slice == other + return False + + def __hash__(self) -> int: + """Hash based on the slice.""" + return hash(self._slice) + + def __index__(self) -> int: + """Convert to int for use in tensor indexing. + + This is called when RefTile is used in advanced indexing contexts. + We return the start value which works for single-element tiles. + """ + # For single-element access (when block_size=1), return the index + if self.block_size == 1: + return self.start + # For larger tiles, we can't meaningfully convert to a single index + # This might happen in user lambdas trying to do advanced indexing + raise TypeError( + f"Cannot convert RefTile with block_size={self.block_size} to index" + ) + + @classmethod + def __torch_function__( + cls, + func: Callable[..., object], + types: object, + args: tuple[object, ...] = (), + kwargs: dict[str, object] | None = None, + ) -> object: + from ..language.memory_ops import load + from ..language.memory_ops import store + + if func is torch.Tensor.__getitem__: + if len(args) != 2 or kwargs: + raise exc.IncorrectTileUsage(func) + tensor, index = args + assert isinstance(tensor, torch.Tensor) + + # If a single RefTile is used as index, we want to use it as a slice + # e.g., tensor[ref_tile] should behave like tensor[ref_tile._slice] + if isinstance(index, RefTile): + return tensor[index._slice] + + # For multi-dimensional indexing (including lists) + return load(tensor, cls._prepare_index(index)) + + if func is torch.Tensor.__setitem__: + if len(args) != 3 or kwargs: + raise exc.IncorrectTileUsage(func) + tensor, index, value = args + assert isinstance(tensor, torch.Tensor) + assert isinstance(value, torch.Tensor) + + # Similar handling for setitem + if isinstance(index, RefTile): + tensor[index._slice] = value + return None + + return store(tensor, cls._prepare_index(index), value) + + if func is torch.Tensor.__format__: + return repr(args[0]) + raise exc.IncorrectTileUsage(func) + + @staticmethod + def _prepare_index(index: object) -> list[object]: + if isinstance(index, (list, tuple)): + # When indexing with a list of RefTiles like bias[[tile_m, tile_n]], + # we want it to be interpreted as bias[tile_m, tile_n] + # So we return the list as-is for multi-dimensional indexing + return [*index] + assert isinstance(index, RefTile) + return [index] diff --git a/helion/language/tunable_ops.py b/helion/language/tunable_ops.py index e6086ae2..0c33f170 100644 --- a/helion/language/tunable_ops.py +++ b/helion/language/tunable_ops.py @@ -104,6 +104,13 @@ def _(state: CodegenState) -> ast.AST: return expr_from_string(constant_repr(block_size)) +@_decorators.ref(register_block_size) +def _(min_or_max: int, max_or_none: int | None = None) -> int: + # In ref mode, block_size represents the full dimension size + # Return a very large value to simulate this behavior + return 2**31 - 1 # Max value for a 32-bit signed integer + + @_decorators.api(is_device_only=False, cache_type=True, tiles_as_sizes=True) def register_reduction_dim( size: int, @@ -158,6 +165,11 @@ def _(state: CodegenState) -> ast.AST: ] +@_decorators.ref(register_reduction_dim) +def _(size: int) -> int: + return size + + @_decorators.api(is_device_only=False) def register_tunable(name: str, fragment: ConfigSpecFragment) -> int: """ @@ -220,3 +232,10 @@ def _register_tunable_codegen(state: CodegenState) -> ast.AST: config_value = state.config[name] assert isinstance(config_value, (int, float, 
bool)) return expr_from_string(constant_repr(config_value)) + + +@_decorators.ref(register_tunable) +def _(name: str, fragment: ConfigSpecFragment) -> int: + default_value = fragment.default() + assert isinstance(default_value, int) + return default_value diff --git a/helion/language/view_ops.py b/helion/language/view_ops.py index cdbb7cd8..ed1dc2d0 100644 --- a/helion/language/view_ops.py +++ b/helion/language/view_ops.py @@ -2,6 +2,8 @@ import collections from typing import TYPE_CHECKING +from typing import Any +from typing import cast import torch @@ -102,6 +104,24 @@ def _(state: CodegenState) -> ast.AST: ) +@_decorators.ref(subscript) +def _(tensor: torch.Tensor, indices: list[object]) -> torch.Tensor: + # Convert indices to proper types for tensor indexing + typed_indices: list[Any] = [] + for idx in indices: + if isinstance(idx, (int, slice, torch.Tensor)) or idx is ...: + typed_indices.append(idx) + else: + # Fallback for other types, try to convert to int + try: + typed_indices.append(int(cast("Any", idx))) + except (TypeError, ValueError): + raise exc.InvalidIndexingType( + f"Cannot convert {idx!r} to index" + ) from None + return tensor[tuple(typed_indices)] + + @_decorators.get_masked_value(subscript) def _(node: torch.fx.Node) -> float | bool | None: from .._compiler.node_masking import cached_masked_value diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index d87b45ba..a510cd91 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -36,6 +36,8 @@ from .._logging import LazyString from ..language.constexpr import ConstExpr from .config import Config +from .ref_mode import RefModeContext +from .settings import RefMode from .settings import Settings if TYPE_CHECKING: @@ -278,7 +280,11 @@ def reset(self) -> None: class BoundKernel(Generic[_R]): - def __init__(self, kernel: Kernel[_R], args: tuple[object, ...]) -> None: + def __init__( + self, + kernel: Kernel[_R], + args: tuple[object, ...], + ) -> None: """ Initialize a BoundKernel object. @@ -294,8 +300,17 @@ def __init__(self, kernel: Kernel[_R], args: tuple[object, ...]) -> None: self._run: Callable[..., _R] | None = None self._config: Config | None = None self._compile_cache: dict[Config, CompiledConfig] = {} + self._ref_func: Callable[..., _R] | None = None + + # If in ref mode, skip all compilation infrastructure + if self.kernel.settings.ref_mode != RefMode.OFF: + self.env = None + self.fake_args = [] # type: ignore[assignment] + self.host_function = None # type: ignore[assignment] + return + self.env = CompileEnvironment(_find_device(args), self.kernel.settings) - with self.env: + with self.env: # pyright: ignore[reportOptionalContextManager] assert len(args) == len(self.kernel.signature.parameters) self.fake_args: list[object] = [] constexpr_args = {} @@ -345,7 +360,7 @@ def config_spec(self) -> ConfigSpec: Returns: ConfigSpec: The configuration specification. 
""" - return self.env.config_spec + return self.env.config_spec # pyright: ignore[reportOptionalMemberAccess] @property def configs(self) -> list[Config]: @@ -369,10 +384,10 @@ def to_triton_code(self, config: ConfigLike | None = None) -> str: """ if config is None: config = self._require_implicit_config() - with self.env: + with self.env: # pyright: ignore[reportOptionalContextManager] if not isinstance(config, Config): config = Config(**config) # pyright: ignore[reportArgumentType] - self.env.config_spec.normalize(config) + self.env.config_spec.normalize(config) # pyright: ignore[reportOptionalMemberAccess] root = generate_ast(self.host_function, config) return get_needed_imports(root) + unparse(root) @@ -415,7 +430,7 @@ def _debug_str(self) -> str: Returns: str: A string containing debug information about the kernel. """ - with self.env: + with self.env: # pyright: ignore[reportOptionalContextManager] return self.host_function.debug_str() def autotune( @@ -496,6 +511,10 @@ def _specialize_extra(self) -> list[Callable[[Sequence[object]], Hashable]]: Returns: list[Callable[[Sequence[object]], Hashable]]: A list of functions that generate extra specialization keys. """ + if self.kernel.settings.ref_mode != RefMode.OFF: + return [] # No specialization in ref mode + + assert self.env is not None # Should be set in non-ref mode if not self.env.specialized_vars: return [] @@ -516,6 +535,7 @@ def make_extractor(v: Source) -> Callable[[Sequence[object]], Hashable]: n: i for i, n in enumerate(self.kernel.signature.parameters.keys()) } extractors = [] + assert self.env is not None # Should be set in non-ref mode for v in sorted(self.env.specialized_vars, key=lambda v: v.name): source = self.env.shape_env.var_to_sources[v][0] extractors.append(make_extractor(source)) @@ -525,6 +545,8 @@ def _implicit_config(self) -> Config | None: """ Returns a single config that is implicitly used by this kernel, if any. """ + if self.kernel.settings.ref_mode != RefMode.OFF: + return None # No config needed in ref mode configs = self.kernel.configs if self._config is not None: return self._config @@ -542,6 +564,23 @@ def _require_implicit_config(self) -> Config: raise RuntimeError("no config provided and no implicit config available") return config + def run_ref(self, *args: object) -> _R: + if self._ref_func is None: + # Use the original function directly without AST transformation + fn = self.kernel.fn + + def ref_wrapper(*args: object) -> _R: # pyright: ignore[reportReturnType] + from .ref_mode import HelionTorchFunctionMode + + with RefModeContext(), HelionTorchFunctionMode(): + # EAGER mode only + return fn(*args) # pyright: ignore[reportReturnType] + + self._ref_func = ref_wrapper + + assert self._ref_func is not None + return self._ref_func(*args) + def __call__(self, *args: object) -> _R: """ Execute the kernel with the given arguments. @@ -552,6 +591,9 @@ def __call__(self, *args: object) -> _R: Returns: _R: The result of the kernel execution. 
""" + if self.kernel.settings.ref_mode != RefMode.OFF: + return self.run_ref(*args) + if self._run is None: if (config := self._implicit_config()) is not None: self.set_config(config) @@ -627,8 +669,16 @@ def kernel( settings_obj = Settings(**settings) if fn is None: - return functools.partial(kernel, configs=configs, settings=settings_obj) - return Kernel(fn, configs=configs, settings=settings_obj) + return functools.partial( + kernel, + configs=configs, + settings=settings_obj, + ) + return Kernel( + fn, + configs=configs, + settings=settings_obj, + ) def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable: diff --git a/helion/runtime/ref_mode.py b/helion/runtime/ref_mode.py new file mode 100644 index 00000000..ba5720c8 --- /dev/null +++ b/helion/runtime/ref_mode.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING +from typing import cast + +import torch +from torch.overrides import BaseTorchFunctionMode + +if TYPE_CHECKING: + from typing_extensions import Self + +_thread_local = threading.local() + + +def is_ref_mode_enabled() -> bool: + """Check if ref mode is currently active.""" + return getattr(_thread_local, "ref_mode_enabled", False) + + +class RefModeContext: + """Context manager to enable ref mode execution.""" + + def __enter__(self) -> Self: + self._old_value = getattr(_thread_local, "ref_mode_enabled", False) + _thread_local.ref_mode_enabled = True + return self + + def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> bool: + _thread_local.ref_mode_enabled = self._old_value + return False + + +class HelionTorchFunctionMode(BaseTorchFunctionMode): + """Torch function mode for Helion ref mode operations.""" + + def __torch_function__( + self, + func: object, + types: list[type[object]], + args: tuple[object, ...] 
= (), + kwargs: dict[str, object] | None = None, + ) -> object: + if kwargs is None: + kwargs = {} + + # Replace torch.addmm with _helion_mixed_addmm + if func == torch.addmm: + # Cast args to expected types + assert len(args) >= 3, "addmm requires at least 3 arguments" + bias = cast("torch.Tensor", args[0]) + mat1 = cast("torch.Tensor", args[1]) + mat2 = cast("torch.Tensor", args[2]) + return _helion_mixed_addmm(bias, mat1, mat2, *args[3:], **kwargs) + + # Replace torch.baddbmm with _helion_mixed_baddbmm + if func == torch.baddbmm: + # Cast args to expected types + assert len(args) >= 3, "baddbmm requires at least 3 arguments" + bias = cast("torch.Tensor", args[0]) + batch1 = cast("torch.Tensor", args[1]) + batch2 = cast("torch.Tensor", args[2]) + return _helion_mixed_baddbmm(bias, batch1, batch2, *args[3:], **kwargs) + + return super().__torch_function__(func, types, args, kwargs) + + +def _helion_mixed_addmm( + bias: torch.Tensor, + mat1: torch.Tensor, + mat2: torch.Tensor, + *, + beta: float = 1, + alpha: float = 1, +) -> torch.Tensor: + """Mixed precision addmm that handles dtype mismatches.""" + # Ensure both matrices have the same dtype + if mat1.dtype != mat2.dtype: + raise ValueError( + f"Matrix dtypes must match for torch.addmm: mat1.dtype={mat1.dtype}, mat2.dtype={mat2.dtype}" + ) + + # Use torch.mm with out_dtype to perform mixed precision computation + # out_dtype must be the same as bias dtype or fp32 for fp16/bf16 inputs + if ( + mat1.dtype in (torch.float16, torch.bfloat16) and bias.dtype == torch.float32 + ) or mat1.dtype == bias.dtype: + result = torch.mm(mat1, mat2, out_dtype=bias.dtype) + else: + raise ValueError( + f"Unsupported dtype combination for torch.addmm: bias.dtype={bias.dtype}, " + f"mat1.dtype={mat1.dtype}. out_dtype must be the same as bias dtype or " + f"fp32 for fp16/bf16 inputs." + ) + + # Scale the result + if alpha != 1: + result = result * alpha + + # Add the bias term, converting result to bias's dtype if needed + if result.dtype != bias.dtype: + result = result.to(bias.dtype) + + if beta == 0: + return result + return result + (beta * bias) + + +def _helion_mixed_baddbmm( + bias: torch.Tensor, + batch1: torch.Tensor, + batch2: torch.Tensor, + *, + beta: float = 1, + alpha: float = 1, +) -> torch.Tensor: + """Mixed precision baddbmm that handles dtype mismatches.""" + # Ensure both batch matrices have the same dtype + if batch1.dtype != batch2.dtype: + raise ValueError( + f"Batch matrix dtypes must match for torch.baddbmm: batch1.dtype={batch1.dtype}, batch2.dtype={batch2.dtype}" + ) + + # Use torch.bmm with out_dtype to perform mixed precision computation + # out_dtype must be the same as bias dtype or fp32 for fp16/bf16 inputs + if ( + batch1.dtype in (torch.float16, torch.bfloat16) and bias.dtype == torch.float32 + ) or batch1.dtype == bias.dtype: + result = torch.bmm(batch1, batch2, out_dtype=bias.dtype) + else: + raise ValueError( + f"Unsupported dtype combination for torch.baddbmm: bias.dtype={bias.dtype}, " + f"batch1.dtype={batch1.dtype}. out_dtype must be the same as bias dtype or " + f"fp32 for fp16/bf16 inputs." 
+ ) + + # Scale the result + if alpha != 1: + result = result * alpha + + # Add the bias term, converting result to bias's dtype if needed + if result.dtype != bias.dtype: + result = result.to(bias.dtype) + + if beta == 0: + return result + return result + (beta * bias) diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index 339f8cf2..6273b011 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -1,6 +1,7 @@ from __future__ import annotations import dataclasses +import enum import logging import os import sys @@ -25,6 +26,23 @@ class _TLS(Protocol): _tls: _TLS = cast("_TLS", threading.local()) +class RefMode(enum.Enum): + """Reference mode for kernel execution.""" + + OFF = "off" + EAGER = "eager" + + +def _get_ref_mode_from_env() -> RefMode: + """Get reference mode from environment variables.""" + # Check for environment variables + ref_eager = os.environ.get("HELION_REF_EAGER", "").lower() in ("1", "true", "yes") + + if ref_eager: + return RefMode.EAGER + return RefMode.OFF + + def set_default_settings(settings: Settings) -> AbstractContextManager[None, None]: """ Set the default settings for the current thread and return a context manager @@ -72,6 +90,9 @@ class _Settings: allow_warp_specialize: bool = ( os.environ.get("HELION_ALLOW_WARP_SPECIALIZE", "1") == "1" ) + ref_mode: RefMode = dataclasses.field( + default_factory=lambda: _get_ref_mode_from_env() + ) class Settings(_Settings): @@ -92,6 +113,7 @@ class Settings(_Settings): "print_output_code": "If True, print the output code of the kernel to stderr.", "force_autotune": "If True, force autotuning even if a config is provided.", "allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.", + "ref_mode": "Reference mode for kernel execution. Can be RefMode.OFF or RefMode.EAGER.", } assert __slots__.keys() == {field.name for field in dataclasses.fields(_Settings)} diff --git a/test/test_ref_eager.py b/test/test_ref_eager.py new file mode 100644 index 00000000..fc338897 --- /dev/null +++ b/test/test_ref_eager.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import contextlib +import io +import math +from pathlib import Path +from typing import TYPE_CHECKING +import unittest +from unittest.mock import patch + +import torch + +from . import test_examples +import helion +from helion._testing import EXAMPLES_DIR +from helion._testing import TestCase +from helion._testing import import_path +import helion.language as hl + +if TYPE_CHECKING: + from helion.runtime.settings import RefMode + + +def clear_kernel_caches_and_set_ref_mode(ref_mode: RefMode) -> None: + """Clear kernel caches and set ref_mode on all kernels in examples.""" + # Get all Python files in the examples directory + example_files = Path(EXAMPLES_DIR).glob("*.py") + + for example_file in example_files: + try: + # Import the module + mod = import_path(example_file) + + # Find all Helion kernels in the module and update their settings + for attr_name in dir(mod): + attr = getattr(mod, attr_name) + if isinstance(attr, helion.Kernel): + # Reset the kernel to clear any cached bound kernels + attr.reset() + # Update the kernel's ref_mode setting + attr.settings.ref_mode = ref_mode + except Exception: + # Skip files that can't be imported or have issues + pass + + +class TestExamplesRefEager(test_examples.TestExamples): + """Run all TestExamples tests in reference eager mode.""" + + # NOTE: All tests in TestExamples are run in ref eager mode by default in this test file. 
+ + def assertExpectedJournal(self, value: str) -> None: + """Skip journal assertions in ref mode since we don't generate Triton code.""" + + def setUp(self): + """Set up test environment.""" + super().setUp() + # Clear kernel caches and set ref mode to EAGER + clear_kernel_caches_and_set_ref_mode(helion.RefMode.EAGER) + + def tearDown(self): + """Restore original environment.""" + super().tearDown() + # Clear kernel caches and reset to OFF mode + clear_kernel_caches_and_set_ref_mode(helion.RefMode.OFF) + + def test_add(self): + # Mock the critical functions that should NOT be called in ref eager mode + # These functions are only used in normal Helion mode to generate and compile Triton code + with ( + patch("helion.runtime.kernel.BoundKernel.to_triton_code") as mock_to_triton, + patch("helion.runtime.kernel.BoundKernel.compile_config") as mock_compile, + ): + # Set up the mocks to fail if called + mock_to_triton.side_effect = AssertionError( + "to_triton_code should NOT be called in reference eager mode!" + ) + mock_compile.side_effect = AssertionError( + "compile_config should NOT be called in reference eager mode!" + ) + + # Run the original test + super().test_add() + + # Assert that neither function was called + mock_to_triton.assert_not_called() + mock_compile.assert_not_called() + + +class TestRefEagerMisc(TestCase): + def test_print_intermediate_tensor(self): + @helion.kernel(ref_mode=helion.RefMode.EAGER) + def print_intermediate_tensor_kernel( + x: torch.Tensor, y: torch.Tensor + ) -> torch.Tensor: + out = torch.empty_like(x) + m, n = x.shape + for tile_m, tile_n in hl.tile([m, n]): + x_val = x[tile_m, tile_n] + y_val = y[tile_m, tile_n] + sum_val = x_val + y_val + print("x: ", x_val) + print("y: ", y_val) + print("sum: ", sum_val) + out[tile_m, tile_n] = sum_val + return out + + x = torch.ones([2, 2], device="cuda", dtype=torch.float32) * 10.0 + y = torch.ones([2, 2], device="cuda", dtype=torch.float32) * 5.0 + expected = x + y + + # Capture stdout to check print output + captured_output = io.StringIO() + with contextlib.redirect_stdout(captured_output): + result = print_intermediate_tensor_kernel(x, y) + + torch.testing.assert_close(result, expected, atol=1e-6, rtol=1e-6) + + # Check that the print statements produced output + output = captured_output.getvalue() + self.assertIn("x: ", output) + self.assertIn("y: ", output) + self.assertIn("sum: ", output) + self.assertIn("[[10., 10.]", output) # x values + self.assertIn("[[5., 5.]", output) # y values + self.assertIn("[[15., 15.]", output) # sum values + + def test_print_in_invalid_helion_kernel(self): + """Test that print works even in invalid Helion kernels in ref eager mode.""" + + @helion.kernel(ref_mode=helion.RefMode.EAGER) + def incorrect_kernel(x: torch.Tensor) -> torch.Tensor: + m, n = x.shape + for tile_m, tile_n in hl.tile([m, n]): + val = x[tile_m, tile_n] + print("processing tile: ", val) + # `pass` below causes this kernel to be invalid. + # But we show that in ref-eager mode, the `print` statement above still works, + # which is useful for debugging. 
+ pass # noqa: PIE790 + return x + + x = torch.ones([2, 2], device="cuda", dtype=torch.float32) * math.pi + + # Capture stdout to check print output + captured_output = io.StringIO() + with contextlib.redirect_stdout(captured_output): + _ = incorrect_kernel(x) + + # Check that the print statement produced output + output = captured_output.getvalue() + self.assertIn("processing tile: ", output) + self.assertIn("[[3.14", output) # The value printed + + def test_ref_eager_kernel_config(self): + @helion.kernel(ref_mode=helion.RefMode.EAGER) + def kernel(x: torch.Tensor) -> torch.Tensor: + return x + x * 2.0 + + x = torch.randn(128, device="cuda") + result = kernel(x) + expected = x + x * 2.0 + torch.testing.assert_close(result, expected) + + +if __name__ == "__main__": + unittest.main()
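
Usage sketch (not part of the patch): the snippet below mirrors the new test/test_ref_eager.py and shows how the feature added here is exercised — ref eager mode runs the kernel body with plain PyTorch operations instead of generating and compiling Triton code. It assumes a CUDA device is available; tensor shapes and names are illustrative only.

    import torch
    import helion
    import helion.language as hl

    # Enable reference eager mode per-kernel via the new ref_mode setting.
    @helion.kernel(ref_mode=helion.RefMode.EAGER)
    def add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        m, n = x.shape
        for tile_m, tile_n in hl.tile([m, n]):
            # In ref eager mode tiles are RefTile slices, so this indexing and
            # store run eagerly; ordinary print() calls also work for debugging.
            out[tile_m, tile_n] = x[tile_m, tile_n] + y[tile_m, tile_n]
        return out

    x = torch.randn(128, 128, device="cuda")
    y = torch.randn(128, 128, device="cuda")
    torch.testing.assert_close(add_kernel(x, y), x + y)

Per helion/runtime/settings.py in this patch, ref eager mode can also be enabled globally by setting the HELION_REF_EAGER=1 environment variable instead of passing ref_mode explicitly.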