diff --git a/.github/workflows/test-template.yml b/.github/workflows/test-template.yml new file mode 100644 index 00000000..db4f0df8 --- /dev/null +++ b/.github/workflows/test-template.yml @@ -0,0 +1,99 @@ +name: Reusable Test Workflow + +on: + workflow_call: + inputs: + test-name: + required: true + type: string + ref-eager: + required: false + type: boolean + default: false + +jobs: + test: + name: ${{ inputs.test-name }}-cuda12.6-py${{ matrix.python-version }}-a10g + + container: + image: nvidia/cuda:12.6.3-devel-ubuntu24.04 + options: --gpus all + + runs-on: linux.g5.4xlarge.nvidia.gpu + + strategy: + matrix: + python-version: ["3.10", "3.12"] + + defaults: + run: + shell: bash -l {0} + + steps: + - name: Check out code + uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + python-version: ${{ matrix.python-version }} + enable-cache: true + + - name: Create virtual environment + run: | + uv venv --python ${{ matrix.python-version }} + + - name: Get current month + id: date + run: echo "month=$(date +'%Y-%m')" >> $GITHUB_OUTPUT + + - name: Cache dependencies + id: cache + uses: actions/cache@v4 + with: + path: | + ~/.cache/uv + ~/.venv + key: ${{ runner.os }}-deps-${{ matrix.python-version }}-${{ hashFiles('.github/workflows/test.yml', 'requirements.txt') }}-${{ steps.date.outputs.month }} + restore-keys: | + ${{ runner.os }}-deps- + + - name: Install PyTorch + run: | + source .venv/bin/activate + uv pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 + + - name: Install Triton + if: steps.cache.outputs.cache-hit != 'true' + run: | + set -x + source .venv/bin/activate + apt-get update + apt-get install -y git + apt-get install -y gcc-13 g++-13 zlib1g-dev + export CC=gcc-13 + export CXX=g++-13 + mkdir -p /tmp/$USER + cd /tmp/$USER + uv pip uninstall triton pytorch-triton || true + rm -rf triton/ || true + git clone https://github.com/triton-lang/triton.git + cd triton/ + uv pip install -r python/requirements.txt + MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 uv pip install . 
+ cd /tmp/$USER + rm -rf triton/ + + - name: Install Requirements + run: | + source .venv/bin/activate + uv pip install -r requirements.txt + + - name: Run Tests + run: | + source .venv/bin/activate + if [[ "${{ inputs.ref-eager }}" == "true" ]]; then + HELION_INTERPRET=1 pytest + else + pytest + fi diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 706dd5cb..4a87286d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,83 +13,13 @@ concurrency: jobs: test: - name: test-cuda12.6-py${{ matrix.python-version }}-a10g - - container: - image: nvidia/cuda:12.6.3-devel-ubuntu24.04 - options: --gpus all - - runs-on: linux.g5.4xlarge.nvidia.gpu - - strategy: - matrix: - python-version: ["3.10", "3.12"] - - defaults: - run: - shell: bash -l {0} - - steps: - - name: Check out code - uses: actions/checkout@v4 - - - name: Install uv - uses: astral-sh/setup-uv@v6 - with: - python-version: ${{ matrix.python-version }} - enable-cache: true - - - name: Create virtual environment - run: | - uv venv --python ${{ matrix.python-version }} - - - name: Get current month - id: date - run: echo "month=$(date +'%Y-%m')" >> $GITHUB_OUTPUT - - - name: Cache dependencies - id: cache - uses: actions/cache@v4 - with: - path: | - ~/.cache/uv - ~/.venv - key: ${{ runner.os }}-deps-${{ matrix.python-version }}-${{ hashFiles('.github/workflows/test.yml', 'requirements.txt') }}-${{ steps.date.outputs.month }} - restore-keys: | - ${{ runner.os }}-deps- - - - name: Install PyTorch - run: | - source .venv/bin/activate - uv pip install -U --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 - - - name: Install Triton - if: steps.cache.outputs.cache-hit != 'true' - run: | - set -x - source .venv/bin/activate - apt-get update - apt-get install -y git - apt-get install -y gcc-13 g++-13 zlib1g-dev - export CC=gcc-13 - export CXX=g++-13 - mkdir -p /tmp/$USER - cd /tmp/$USER - uv pip uninstall triton pytorch-triton || true - rm -rf triton/ || true - git clone https://github.com/triton-lang/triton.git - cd triton/ - uv pip install -r python/requirements.txt - MAX_JOBS=$(nproc) TRITON_PARALLEL_LINK_JOBS=2 uv pip install . 
- cd /tmp/$USER - rm -rf triton/ - - - name: Install Requirements - run: | - source .venv/bin/activate - uv pip install -r requirements.txt - - - name: Run Tests - run: | - source .venv/bin/activate - pytest + uses: ./.github/workflows/test-template.yml + with: + test-name: test + ref-eager: false + + test-ref-eager: + uses: ./.github/workflows/test-template.yml + with: + test-name: test-ref-eager + ref-eager: true diff --git a/helion/__init__.py b/helion/__init__.py index 1363eb20..728f983d 100644 --- a/helion/__init__.py +++ b/helion/__init__.py @@ -11,12 +11,14 @@ from .runtime import Kernel from .runtime import kernel from .runtime import kernel as jit # alias +from .runtime.settings import RefMode from .runtime.settings import Settings from .runtime.settings import set_default_settings __all__ = [ "Config", "Kernel", + "RefMode", "Settings", "cdiv", "exc", diff --git a/helion/_testing.py b/helion/_testing.py index 76e29dde..256732ab 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -2,6 +2,7 @@ import collections import contextlib +import functools import importlib import inspect import operator @@ -11,6 +12,7 @@ import sys from typing import TYPE_CHECKING from typing import Callable +from typing import Generator import unittest import torch @@ -18,7 +20,9 @@ from ._utils import counters from .runtime.config import Config +import helion from helion._compat import get_tensor_descriptor_fn_name +from helion.runtime.ref_mode import is_ref_mode_enabled if TYPE_CHECKING: import types @@ -30,6 +34,218 @@ EXAMPLES_DIR: Path = Path(__file__).parent.parent / "examples" +def skipIfRefEager(reason: str) -> Callable[[Callable], Callable]: + """Skip test if running in ref eager mode (HELION_INTERPRET=1).""" + return unittest.skipIf(os.environ.get("HELION_INTERPRET") == "1", reason) + + +@contextlib.contextmanager +def track_run_ref_calls() -> Generator[list[int], None, None]: + """Context manager that tracks BoundKernel.run_ref calls. + + Yields: + A list that will contain the count of run_ref calls. + """ + from helion.runtime.kernel import BoundKernel + + original_run_ref = BoundKernel.run_ref + run_ref_count = [0] + + def tracked_run_ref(self: BoundKernel, *args: object) -> object: + run_ref_count[0] += 1 + return original_run_ref(self, *args) + + BoundKernel.run_ref = tracked_run_ref + + try: + yield run_ref_count + finally: + BoundKernel.run_ref = original_run_ref + + +@contextlib.contextmanager +def assert_helion_ref_mode( + ref_mode: helion.RefMode = helion.RefMode.OFF, +) -> Generator[None, None, None]: + """Context manager that asserts Helion compilation behavior based on RefMode. 
+ + - RefMode.OFF: expects compilation (run_ref should not be called) + - RefMode.EAGER: expects no compilation (run_ref should be called) + """ + with track_run_ref_calls() as run_ref_count: + yield + + if ref_mode == helion.RefMode.OFF: + # In normal mode (RefMode.OFF), run_ref should not be called + assert run_ref_count[0] == 0, ( + f"Expected run_ref to not be called in normal mode (RefMode.OFF), but got: run_ref={run_ref_count[0]}" + ) + elif ref_mode == helion.RefMode.EAGER: + # In ref eager mode (RefMode.EAGER), run_ref should be called + assert run_ref_count[0] > 0, ( + f"Expected run_ref to be called in ref eager mode (RefMode.EAGER), but got: run_ref={run_ref_count[0]}" + ) + else: + raise ValueError(f"Unknown RefMode: {ref_mode}") + + +assert_helion_compilation = functools.partial( + assert_helion_ref_mode, ref_mode=helion.RefMode.OFF +) + +assert_ref_eager_mode = functools.partial( + assert_helion_ref_mode, ref_mode=helion.RefMode.EAGER +) + + +class RefEagerTestBase: + """Base class for all ref eager mode test shards of normal Helion unit test files.""" + + # Class-level tracking for assert_close counting + _assert_close_count = 0 + _original_assert_close_func = None + # Class-level tracking for assertRaises counting + _assert_raises_count = 0 + _original_assert_raises_func = None + # Class-level tracking for skipTest counting + _skip_test_count = 0 + _original_skip_test_func = None + + def setUp(self) -> None: + """Common setup for all ref eager tests.""" + super().setUp() # type: ignore[misc] + + # Check if HELION_INTERPRET is already set + self._in_ref_eager_mode = os.environ.get("HELION_INTERPRET") == "1" + + # If not in ref eager mode, skip the setup + if not self._in_ref_eager_mode: + return + + # Reset assert_close counter for this test + RefEagerTestBase._assert_close_count = 0 + # Reset assertRaises counter for this test + RefEagerTestBase._assert_raises_count = 0 + # Reset skipTest counter for this test + RefEagerTestBase._skip_test_count = 0 + + # Patch torch.testing.assert_close to count calls + if RefEagerTestBase._original_assert_close_func is None: + RefEagerTestBase._original_assert_close_func = torch.testing.assert_close + + def counting_assert_close(*args: object, **kwargs: object) -> None: + RefEagerTestBase._assert_close_count += 1 + return RefEagerTestBase._original_assert_close_func(*args, **kwargs) # type: ignore[misc] + + torch.testing.assert_close = counting_assert_close + + # Patch self.assertRaises to count calls + if RefEagerTestBase._original_assert_raises_func is None: + RefEagerTestBase._original_assert_raises_func = self.assertRaises + + def counting_assert_raises(*args: object, **kwargs: object) -> object: + RefEagerTestBase._assert_raises_count += 1 + return RefEagerTestBase._original_assert_raises_func(*args, **kwargs) # type: ignore[misc] + + self.assertRaises = counting_assert_raises + + # Patch self.skipTest to count calls + if RefEagerTestBase._original_skip_test_func is None: + RefEagerTestBase._original_skip_test_func = self.skipTest + + def counting_skip_test(*args: object, **kwargs: object) -> object: + RefEagerTestBase._skip_test_count += 1 + return RefEagerTestBase._original_skip_test_func(*args, **kwargs) # type: ignore[misc] + + self.skipTest = counting_skip_test + + # Store the tracking context manager instance so we can check counts in tearDown + self._run_ref_tracker = track_run_ref_calls() + self._run_ref_count = self._run_ref_tracker.__enter__() + + def tearDown(self) -> None: + """Common teardown with assertion counting 
check.""" + # If not in ref eager mode, skip the teardown logic + if not self._in_ref_eager_mode: + super().tearDown() # type: ignore[misc] + return + + try: + # Exit the run_ref tracker + self._run_ref_tracker.__exit__(None, None, None) + + # Check if the test was skipped + test_method = getattr(self, self._testMethodName, None) # type: ignore[attr-defined] + is_skipped = ( + test_method is not None + and hasattr(test_method, "__unittest_skip__") + and test_method.__unittest_skip__ + ) or RefEagerTestBase._skip_test_count > 0 + + # Assert that either run_ref was called or the test was skipped + if not is_skipped and self._run_ref_count[0] == 0: + self.fail( # type: ignore[attr-defined] + f"Test {self._testMethodName} did not call run_ref and was not skipped" # pyright: ignore[reportAttributeAccessIssue] + ) + + if not is_skipped: + # Check that either assert_close, assertRaises, or skipTest was called + total_assertions = ( + RefEagerTestBase._assert_close_count + + RefEagerTestBase._assert_raises_count + + RefEagerTestBase._skip_test_count + ) + self.assertGreater( # type: ignore[attr-defined] + total_assertions, + 0, + f"Test {self._testMethodName} did not call torch.testing.assert_close, assertRaises, or skipTest", # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + ) + finally: + # Restore the original assert_close function + if RefEagerTestBase._original_assert_close_func is not None: + torch.testing.assert_close = ( + RefEagerTestBase._original_assert_close_func + ) + + # Restore the original assertRaises function + if RefEagerTestBase._original_assert_raises_func is not None: + self.assertRaises = RefEagerTestBase._original_assert_raises_func + + # Restore the original skipTest function + if RefEagerTestBase._original_skip_test_func is not None: + self.skipTest = RefEagerTestBase._original_skip_test_func + + super().tearDown() # type: ignore[misc] + + # NOTE: We no-op these methods because they commonly check behaviors that are not relevant in ref eager mode. + # Instead, we solely rely on the unit test's `torch.testing.assert_close` and `assertRaises` checks to ensure ref eager mode's correctness. 
+ def assertExpectedJournal(self, value: str) -> None: + if not self._in_ref_eager_mode: + super().assertExpectedJournal(value) # type: ignore[misc] + + def assertIn( + self, member: object, container: object, msg: str | None = None + ) -> None: + if not self._in_ref_eager_mode: + super().assertIn(member, container, msg) # type: ignore[misc] + + def assertNotIn( + self, member: object, container: object, msg: str | None = None + ) -> None: + if not self._in_ref_eager_mode: + super().assertNotIn(member, container, msg) # type: ignore[misc] + + def assertEqualCode(self, first: str, second: str, msg: str | None = None) -> None: + if not self._in_ref_eager_mode: + super().assertEqual(first, second, msg) # type: ignore[misc] + + def assertNotEqualCode( + self, first: str, second: str, msg: str | None = None + ) -> None: + if not self._in_ref_eager_mode: + super().assertNotEqual(first, second, msg) # type: ignore[misc] + + def import_path(filename: Path) -> types.ModuleType: module_name = f"{__name__}.{filename.stem}" if module_name not in sys.modules: @@ -47,6 +263,16 @@ def code_and_output( args: tuple[object, ...], **kwargs: object, ) -> tuple[str, object]: + bound = fn.bind(args) + if is_ref_mode_enabled(bound.kernel.settings): + if kwargs: + config = Config(**kwargs) # pyright: ignore[reportArgumentType] + bound._config = config + result = fn(*args) + # Return the original kernel source code + code = inspect.getsource(fn.fn) + return code, result + if kwargs: config = Config( **kwargs # pyright: ignore[reportArgumentType] @@ -280,6 +506,16 @@ def lookup(self, test_id: str, value: str) -> tuple[str, str]: return value, expected +class RefEagerTestDisabled: + """Base class for test classes that should be skipped when ref eager mode is enabled.""" + + def setUp(self) -> None: + """Skip test if ref eager mode is enabled.""" + super().setUp() # type: ignore[misc] + if os.environ.get("HELION_INTERPRET") == "1": + self.skipTest("Test class disabled in ref eager mode") # type: ignore[attr-defined] + + class TestCase(unittest.TestCase): maxDiff = 16384 diff --git a/helion/language/_decorators.py b/helion/language/_decorators.py index fbddb640..ca8d653d 100644 --- a/helion/language/_decorators.py +++ b/helion/language/_decorators.py @@ -79,6 +79,7 @@ class APIFunc(Protocol): _to_device_ir: Callable[..., object] | None _allow_host_tensor: bool _signature: inspect.Signature + _ref_fn: Callable[..., object] | None def __call__(self, *args: object, **kwargs: object) -> object: ... 
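A minimal standalone sketch (not part of the patch) of the registration-and-dispatch pattern behind the new `_ref_fn` slot on `APIFunc`: the `@_decorators.ref(...)` hunks later in this diff (constexpr.py, creation_ops.py, loops.py, tile_ops.py, matmul_ops.py, view_ops.py) populate it, and the `wrapper` change in the next hunk consults it when ref mode is active. Here `api`, `ref`, `_REF_MODE`, and the `specialize` stub are simplified stand-ins for the real decorators and the `is_in_ref_mode_context()` check; only the shape of the mechanism is meant to match.

import functools

_REF_MODE = False  # stand-in for is_in_ref_mode_context()


def api(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if _REF_MODE:
            # Mirrors the assert added to wrapper() in the next hunk.
            assert wrapper._ref_fn is not None, (
                f"{fn.__qualname__} must have a ref() implementation to run in ref mode"
            )
            return wrapper._ref_fn(*args, **kwargs)
        return fn(*args, **kwargs)  # normal (compiled) path

    wrapper._ref_fn = None
    return wrapper


def ref(original_fn):
    def _impl(ref_fn):
        assert original_fn._ref_fn is None, "ref impl can only be registered once"
        original_fn._ref_fn = ref_fn
        return ref_fn

    return _impl


@api
def specialize(value):  # placeholder API function, analogous to hl.specialize
    raise NotImplementedError


@ref(specialize)
def _(value):
    return int(value)  # eager reference behavior, as in the constexpr.py hunk


_REF_MODE = True
assert specialize(3.0) == 3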
@@ -133,8 +134,17 @@ def api( def _impl(fn: _C) -> _C: @functools.wraps(fn) def wrapper(*args: object, **kwargs: object) -> object: + from ..runtime.ref_mode import is_in_ref_mode_context + bound = api._signature.bind(*args, **kwargs) bound.apply_defaults() + + if is_in_ref_mode_context(): + assert api._ref_fn is not None, ( + f"{fn.__qualname__} must be decorated with @helion.ref() to be used in ref mode" + ) + return api._ref_fn(*bound.arguments.values()) + flat_args = api._prepare_args(*bound.arguments.values()) del args, kwargs @@ -187,6 +197,7 @@ def wrapper(*args: object, **kwargs: object) -> object: api._signature = signature or inspect.signature( cast("Callable[..., object]", fn) ) + api._ref_fn = None return wrapper # pyright: ignore[reportReturnType] return _impl @@ -289,6 +300,22 @@ def _impl(to_device_ir_fn: Callable[..., object]) -> Callable[..., Never]: return _impl # pyright: ignore[reportReturnType] +def ref( + original_fn: Callable[..., object], +) -> _NoReturnDecorator[object]: + def _impl(ref_fn: Callable[..., object]) -> Callable[..., Never]: + assert is_api_func(original_fn), ( + f"{ref.__qualname__} can only be used on API functions" + ) + assert original_fn._ref_fn is None, ( + "ref mode implementation can only be registered once per function" + ) + original_fn._ref_fn = ref_fn + return _no_call + + return _impl # pyright: ignore[reportReturnType] + + def _default_type_function( fake_fn: Callable[..., object], tiles_as_sizes: bool ) -> Callable[..., TypeInfo]: diff --git a/helion/language/constexpr.py b/helion/language/constexpr.py index 8b2ae084..e4170fd9 100644 --- a/helion/language/constexpr.py +++ b/helion/language/constexpr.py @@ -95,3 +95,8 @@ def _(state: CodegenState) -> ast.AST: value = value.__int__() assert isinstance(value, int) return expr_from_string(repr(value)) + + +@_decorators.ref(specialize) +def _(value: int | torch.SymInt) -> int: + return int(value) diff --git a/helion/language/creation_ops.py b/helion/language/creation_ops.py index b22dc090..e279feb1 100644 --- a/helion/language/creation_ops.py +++ b/helion/language/creation_ops.py @@ -9,6 +9,7 @@ from .._compiler.compile_environment import CompileEnvironment from ..exc import NotInsideKernel from . import _decorators +from .ref_tile import RefTile if TYPE_CHECKING: import ast @@ -144,6 +145,22 @@ def _( return None +@_decorators.ref(full) +def _( + shape: list[int | RefTile], + value: float, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + processed_shape = [] + for s in shape: + if isinstance(s, RefTile): + processed_shape.append(s.end - s.begin) + else: + processed_shape.append(s) + env = CompileEnvironment.current() + return torch.full(processed_shape, value, dtype=dtype, device=env.device) + + def arange( *args: int, dtype: torch.dtype | None = None, diff --git a/helion/language/loops.py b/helion/language/loops.py index 03d874e8..bffec627 100644 --- a/helion/language/loops.py +++ b/helion/language/loops.py @@ -3,6 +3,7 @@ import ast import builtins import inspect +import itertools from itertools import starmap from typing import TYPE_CHECKING from typing import Iterator @@ -42,6 +43,7 @@ from ..autotuner.config_spec import RangeWarpSpecializeSpec from ..autotuner.config_spec import StaticRangeSpec from . 
import _decorators +from .ref_tile import RefTile from .tile_proxy import Tile if TYPE_CHECKING: @@ -449,6 +451,88 @@ def _(state: CodegenState) -> ast.AST: return _codegen_loop_helper(state) +def _to_int(value: int | torch.Tensor | None) -> int | None: + """Convert tensor values to int.""" + if value is None: + return None + if isinstance(value, torch.Tensor): + return int(value.item()) + return int(value) + + +def _normalize_to_list( + value: int | torch.Tensor | list[int | torch.Tensor], +) -> list[int | torch.Tensor]: + """Convert single values to lists for uniform handling.""" + if isinstance(value, (list, tuple)): + return list(value) + return [value] + + +def _normalize_begin_end_ref( + begin_or_end: int | torch.Tensor | list[int | torch.Tensor], + end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None, +) -> tuple[ + int | torch.Tensor | list[int | torch.Tensor], + int | torch.Tensor | list[int | torch.Tensor], +]: + if end_or_none is not None: + # Two positional args: begin_or_end is begin, end_or_none is end + return begin_or_end, end_or_none + # One positional arg: begin_or_end is end, begin defaults to 0 + end = begin_or_end + if isinstance(end, (list, tuple)): + begin = cast("int | torch.Tensor | list[int | torch.Tensor]", [0] * len(end)) + else: + begin = 0 + return begin, end + + +@_decorators.ref(tile) +def _( + begin_or_end: int | torch.Tensor | list[int | torch.Tensor], + end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None, + block_size: int | torch.Tensor | list[int | torch.Tensor] | None = None, +) -> Iterator[RefTile | tuple[RefTile, ...]]: + # Step 1: Normalize begin and end values + begin, end = _normalize_begin_end_ref(begin_or_end, end_or_none) + + # Step 2: Convert to lists and then to ints + begin_list = _normalize_to_list(begin) + end_list = _normalize_to_list(end) + begin_ints = [_to_int(b) for b in begin_list] + end_ints = [_to_int(e) for e in end_list] + + # Step 3: Determine block sizes - always return full dimension size, ignoring block_size parameter + block_size_list = [] + for b, e in zip(begin_ints, end_ints, strict=True): + assert b is not None and e is not None + block_size_list.append(e - b) + + # Step 4: Determine return type + # Return single tiles if input was not a list + return_single = not isinstance(begin, list) and not isinstance(end, list) + + # Step 5: Generate tiles + # Build tiles for each dimension + tiles = [] + for b, e in zip(begin_ints, end_ints, strict=True): + assert b is not None and e is not None + if b != e: + # Only create tile if range is non-empty + tiles.append(RefTile(b, e, e - b)) + + # Yield result based on return type + if tiles: # Only yield if we have at least one non-empty dimension + if return_single: + # Single dimension case - yield the tile directly + assert len(tiles) == 1 + yield tiles[0] + else: + # Multi-dimensional case - yield as tuple + yield tuple(tiles) + + def _codegen_loop_helper( state: CodegenState, ) -> ast.AST: @@ -484,7 +568,7 @@ def grid( begin_or_end: int | torch.Tensor, end_or_none: int | torch.Tensor | None = None, /, - step: object = None, + step: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None, ) -> Iterator[torch.SymInt]: ... @@ -497,7 +581,7 @@ def grid( begin_or_end: Sequence[int | torch.Tensor], end_or_none: Sequence[int | torch.Tensor] | None = None, /, - step: object = None, + step: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None, ) -> Iterator[Sequence[torch.SymInt]]: ... 
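The `@_decorators.ref(tile)` implementation above yields a single `RefTile` spanning each full dimension, so a ref-eager `hl.tile` loop runs exactly once and tile indexing degenerates to ordinary slicing (handled by `RefTile.__torch_function__`, added later in this diff). A standalone sketch (not part of the patch) of that behavior for one dimension, with `ref_tile_1d` as a hypothetical stand-in for the single-tile iterator and a plain `slice` standing in for `RefTile`:

import torch


def ref_tile_1d(end: int, begin: int = 0):
    """Yield one slice covering [begin, end), the way the ref tile
    implementation yields one full-dimension RefTile per loop."""
    if begin != end:  # empty ranges yield nothing
        yield slice(begin, end)


x = torch.randn(128)
out = torch.empty_like(x)
for tile in ref_tile_1d(x.shape[0]):  # exactly one iteration
    out[tile] = x[tile] + 1  # RefTile indexing reduces to plain slicing
torch.testing.assert_close(out, x + 1)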
@@ -509,7 +593,7 @@ def grid( begin_or_end: int | torch.Tensor | Sequence[int | torch.Tensor], end_or_none: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None, /, - step: object = None, + step: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None, ) -> Iterator[torch.SymInt] | Iterator[Sequence[torch.SymInt]]: # type: ignore[type-arg] """Iterate over individual indices of the given iteration space. @@ -637,6 +721,109 @@ def _(state: CodegenState) -> ast.AST: return _codegen_loop_helper(state) +def _extract_step_value( + step: int | torch.Tensor | Sequence[int | torch.Tensor] | None, + index: int = 0, +) -> int | torch.Tensor | None: + """Extract step value from various input formats.""" + if step is None: + return None + + if isinstance(step, (list, tuple)): + # Extract from sequence at index + if index < len(step): + val = step[index] + # Type narrow to valid types for _to_int + if isinstance(val, (int, torch.Tensor, type(None))): + return val + return None + + # Single value - type narrow to valid types + if isinstance(step, (int, torch.Tensor)): + return step + return None + + +def _normalize_step_values( + step: int | torch.Tensor | Sequence[int | torch.Tensor] | None, + num_dims: int, +) -> list[int | None]: + """Normalize step values to a list of ints for each dimension.""" + if step is None: + return [None] * num_dims + + assert isinstance(step, (list, tuple)) + step_ints = [] + for i in range(num_dims): + step_val = _extract_step_value(step, i) + step_ints.append(_to_int(step_val)) + return step_ints + + +def _create_ranges( + begin_ints: list[int | None], + end_ints: list[int | None], + step_ints: list[int | None] | None = None, +) -> list[range]: + """Create range objects from begin, end, and optional step values.""" + ranges = [] + + if step_ints is None: + # No steps provided - use default ranges + for b, e in zip(begin_ints, end_ints, strict=True): + assert b is not None and e is not None + ranges.append(range(b, e)) + else: + # Steps provided - use them where available + for b, e, s in zip(begin_ints, end_ints, step_ints, strict=True): + assert b is not None and e is not None + if s is not None: + ranges.append(range(b, e, s)) + else: + ranges.append(range(b, e)) + + return ranges + + +@_decorators.ref(grid) +def _( + begin_or_end: int | torch.Tensor | list[int | torch.Tensor], + end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None, + step: int | torch.Tensor | Sequence[int | torch.Tensor] | None = None, +) -> range | Iterator[tuple[int, ...]]: + # Step 1: Normalize begin and end values + begin, end = _normalize_begin_end_ref(begin_or_end, end_or_none) + + # Step 2: Handle single dimension case + if not isinstance(begin, (list, tuple)): + begin_int = _to_int(begin) + assert not isinstance(end, (list, tuple)) + end_int = _to_int(end) + assert begin_int is not None and end_int is not None + + # Extract step for single dimension + step_val = _extract_step_value(step, 0) + step_int = _to_int(step_val) + + if step_int is not None: + return range(begin_int, end_int, step_int) + return range(begin_int, end_int) + + # Step 3: Handle multi-dimensional case + assert isinstance(end, (list, tuple)) + begin_ints = [_to_int(b) for b in begin] + end_ints = [_to_int(e) for e in end] + + # Step 4: Normalize step values + step_ints = ( + _normalize_step_values(step, len(begin_ints)) if step is not None else None + ) + + # Step 5: Create ranges and return product + ranges = _create_ranges(begin_ints, end_ints, step_ints) + return 
itertools.product(*ranges) + + @_decorators.device_func_replacement(builtins.zip) @_decorators.api(is_device_only=True, cache_type=True) def _zip_replacement( @@ -898,3 +1085,14 @@ def _( # Return tuple(range(...)) which will trigger existing tuple/list unrolling return tuple(range(begin_val, end_val, step)) + + +@_decorators.ref(static_range) +def _( + begin_or_end: int, + end_or_none: int | None = None, + step: int = 1, +) -> range: + if end_or_none is not None: + return range(begin_or_end, end_or_none, step) + return range(begin_or_end) diff --git a/helion/language/matmul_ops.py b/helion/language/matmul_ops.py index 25e827d0..3fcb388b 100644 --- a/helion/language/matmul_ops.py +++ b/helion/language/matmul_ops.py @@ -209,3 +209,41 @@ def _(state: CodegenState) -> object: rhs=rhs_ast, acc=acc_ast, ) + + +@_decorators.ref(dot) +def _( + mat1: torch.Tensor, + mat2: torch.Tensor, + acc: torch.Tensor | None = None, +) -> torch.Tensor: + out_dtype = _compute_out_dtype( + mat1.dtype, mat2.dtype, None if acc is None else acc.dtype + ) + + is_fp8 = mat1.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) or mat2.dtype in ( + torch.float8_e4m3fn, + torch.float8_e5m2, + ) + if is_fp8: + # Use torch._scaled_mm for FP8 operations + # Ensure column-major for second operand as required by torch._scaled_mm + mat2_t = mat2.T.contiguous().T + scale_a = torch.tensor(1.0, device=mat1.device) + scale_b = torch.tensor(1.0, device=mat2.device) + + result = torch._scaled_mm( + mat1, + mat2_t, + scale_a, + scale_b, + use_fast_accum=False, + out_dtype=out_dtype, + ) + else: + # For non-FP8 tensors, use regular matmul + result = torch.mm(mat1, mat2, out_dtype=out_dtype) + + if acc is not None: + return acc + result + return result diff --git a/helion/language/ref_tile.py b/helion/language/ref_tile.py new file mode 100644 index 00000000..80403938 --- /dev/null +++ b/helion/language/ref_tile.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from .. import exc +from .tile_interface import TileInterface + +if TYPE_CHECKING: + from collections.abc import Callable + + +class RefTile(TileInterface, torch.Tensor): + _slice: slice + _block_size: int + + def __init__(self, begin: int, end: int, block_size: int) -> None: + super().__init__() + + from ..runtime.ref_mode import is_in_ref_mode_context + + assert is_in_ref_mode_context() + self._slice = slice(begin, end, None) + self._block_size = block_size + + @classmethod + def __torch_function__( + cls, + func: Callable[..., object], + types: object, + args: tuple[object, ...] 
= (), + kwargs: dict[str, object] | None = None, + ) -> object: + if func is torch.Tensor.__getitem__: + return cls._handle_getitem(func, args, kwargs) + + if func is torch.Tensor.__setitem__: + return cls._handle_setitem(func, args, kwargs) + + if func is torch.Tensor.__format__: + return repr(args[0]) + + raise exc.IncorrectTileUsage(func) + + @classmethod + def _handle_getitem( + cls, + func: Callable[..., object], + args: tuple[object, ...], + kwargs: dict[str, object] | None, + ) -> object: + """Handle tensor[index] operations.""" + tensor, index = args + assert isinstance(tensor, torch.Tensor) + + if isinstance(index, RefTile): + return tensor[index._slice] + + if isinstance(index, tuple): + new_index = cls._convert_tile_indices_to_slices(index) + return tensor[tuple(new_index)] # pyright: ignore[reportArgumentType] + + # Non-tile index in ref mode + return tensor[index] # pyright: ignore[reportArgumentType] + + @classmethod + def _handle_setitem( + cls, + func: Callable[..., object], + args: tuple[object, ...], + kwargs: dict[str, object] | None, + ) -> object: + """Handle tensor[index] = value operations.""" + tensor, index, value = args + assert isinstance(tensor, torch.Tensor) + assert isinstance(value, (int, float, bool, torch.Tensor)) + + if isinstance(index, RefTile): + tensor[index._slice] = value + return None + + if isinstance(index, tuple): + new_index = cls._convert_tile_indices_to_slices(index) + tensor[tuple(new_index)] = value # pyright: ignore[reportArgumentType] + return None + + # Non-tile index in ref mode + tensor[index] = value # pyright: ignore[reportArgumentType] + return None + + @classmethod + def _convert_tile_indices_to_slices( + cls, indices: tuple[object, ...] + ) -> list[object]: + """Convert RefTile objects in a tuple of indices to slices.""" + new_index = [] + for idx in indices: + if isinstance(idx, RefTile): + new_index.append(idx._slice) + else: + new_index.append(idx) + return new_index + + def __repr__(self, tensor_contents: None = None) -> str: # pyright: ignore[reportIncompatibleMethodOverride] + return f"RefTile({self._slice!r})" + + def __index__(self) -> int: + return self.block_size diff --git a/helion/language/tile_interface.py b/helion/language/tile_interface.py new file mode 100644 index 00000000..b054f3bb --- /dev/null +++ b/helion/language/tile_interface.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import torch + + +class TileInterface: + """Base interface for tile objects in Helion.""" + + @property + def index(self) -> torch.Tensor: + """ + Alias for :func:`~helion.language.tile_index`, which retrieves a tensor containing the offsets for a tile. + """ + from .tile_ops import tile_index + + return tile_index(self) + + @property + def begin(self) -> int: + """ + Alias for :func:`~helion.language.tile_begin`, which retrieves the start offset of a tile. + """ + from .tile_ops import tile_begin + + return tile_begin(self) + + @property + def end(self) -> int: + """ + Alias for :func:`~helion.language.tile_end`, which retrieves the end offset of a tile. + """ + from .tile_ops import tile_end + + return tile_end(self) + + @property + def block_size(self) -> int: + """ + Alias for :func:`~helion.language.tile_block_size`, which retrieves the block_size of a tile. + """ + from .tile_ops import tile_block_size + + return tile_block_size(self) + + @property + def id(self) -> int: + """ + Alias for :func:`~helion.language.tile_id`, which retrieves the id of a tile. 
+ """ + from .tile_ops import tile_id + + return tile_id(self) diff --git a/helion/language/tile_ops.py b/helion/language/tile_ops.py index 6e90b9ac..fc3abc1b 100644 --- a/helion/language/tile_ops.py +++ b/helion/language/tile_ops.py @@ -13,11 +13,12 @@ import ast from .._compiler.inductor_lowering import CodegenState - from .loops import Tile + from .ref_tile import RefTile + from .tile_interface import TileInterface @_decorators.api(tiles_as_sizes=True) -def tile_index(tile: Tile) -> torch.Tensor: +def tile_index(tile: TileInterface) -> torch.Tensor: """ Retrieve the index (a 1D tensor containing offsets) of the given tile. This can also be written as: `tile.index`. @@ -48,8 +49,16 @@ def _(state: CodegenState) -> ast.AST: return expr_from_string(state.codegen.index_var(index)) +@_decorators.ref(tile_index) +def _(tile: RefTile) -> torch.Tensor: + env = CompileEnvironment.current() + return torch.arange( + tile._slice.start, tile._slice.stop, dtype=torch.int32, device=env.device + ) + + @_decorators.api(tiles_as_sizes=True) -def tile_begin(tile: Tile) -> int: +def tile_begin(tile: TileInterface) -> int: """ Retrieve the start offset of the given tile. This can also be written as: `tile.begin`. @@ -82,8 +91,13 @@ def _(state: CodegenState) -> ast.AST: return expr_from_string(state.codegen.offset_var(index)) +@_decorators.ref(tile_begin) +def _(tile: RefTile) -> int: + return tile._slice.start + + @_decorators.api(tiles_as_sizes=True) -def tile_end(tile: Tile) -> int: +def tile_end(tile: TileInterface) -> int: """ Retrieve the end offset of the given tile. For the first 0 to N-1 tiles, this is equivalent to `tile.begin + tile.block_size`. @@ -121,8 +135,13 @@ def _(state: CodegenState) -> ast.AST: return expr_from_string(naive_exp) +@_decorators.ref(tile_end) +def _(tile: RefTile) -> int: + return tile._slice.stop + + @_decorators.api(tiles_as_sizes=True) -def tile_block_size(tile: Tile) -> int: +def tile_block_size(tile: TileInterface) -> int: """ Retrieve block size of a given tile, usually set the autotuner. This can also be written as: `tile.block_size`. @@ -139,8 +158,13 @@ def _(tile: torch.SymInt) -> torch.SymInt: # codegen is handled in _get_symnode() +@_decorators.ref(tile_block_size) +def _(tile: RefTile) -> int: + return tile._block_size + + @_decorators.api(tiles_as_sizes=True) -def tile_id(tile: Tile) -> int: +def tile_id(tile: TileInterface) -> int: """ Retrieve tile_id of a given tile or list of tiles. This is equivalent to `tile.begin // tile.block_size`. @@ -166,3 +190,9 @@ def _(state: CodegenState) -> ast.AST: else: expr_str = f"{offset} // {block_size}" return expr_from_string(expr_str) + + +@_decorators.ref(tile_id) +def _(tile: RefTile) -> int: + # ID is always 0 since we always have one tile per dim in ref mode + return 0 diff --git a/helion/language/tile_proxy.py b/helion/language/tile_proxy.py index 7097a2bf..3aeeb8e3 100644 --- a/helion/language/tile_proxy.py +++ b/helion/language/tile_proxy.py @@ -13,6 +13,7 @@ from .. import exc from .._compiler.compile_environment import CompileEnvironment +from .tile_interface import TileInterface if TYPE_CHECKING: from collections.abc import Callable @@ -26,7 +27,7 @@ class _TLS(Protocol): _tls: _TLS = cast("_TLS", threading.local()) -class Tile(torch.Tensor): +class Tile(TileInterface, torch.Tensor): """ This class should not be instantiated directly, it is the result of hl.tile(...) and represents a single tile of the iteration space. 
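The tile_proxy.py hunk below removes the property aliases from `Tile` now that they live on the shared `TileInterface` mixin and are inherited by both `Tile` and `RefTile`. A standalone sketch (not part of the patch) of that shape, with the per-mode dispatch reduced to subclass polymorphism for brevity; the real properties route through `tile_begin()` and friends, whose `@_decorators.ref` registrations appear in the tile_ops.py hunk above:

class TileInterfaceSketch:
    """Shared property alias, mirroring TileInterface.begin."""

    @property
    def begin(self) -> int:
        return self._begin_impl()


class RefTileSketch(TileInterfaceSketch):
    def __init__(self, begin: int, end: int) -> None:
        self._slice = slice(begin, end)

    def _begin_impl(self) -> int:
        return self._slice.start  # matches the @_decorators.ref(tile_begin) impl


assert RefTileSketch(4, 16).begin == 4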
@@ -97,51 +98,6 @@ def _tiles_to_sizes(cls, it: _T) -> _T: def _tile_to_size(x: Tile) -> torch.SymInt: return CompileEnvironment.current().block_sizes[x.block_id].var - @property - def index(self) -> torch.Tensor: - """ - Alias for :func:`~helion.language.tile_index`, which retrieves a tensor containing the offsets for a tile. - """ - from .tile_ops import tile_index - - return tile_index(self) - - @property - def begin(self) -> int: - """ - Alias for :func:`~helion.language.tile_begin`, which retrieves the start offset of a tile. - """ - from .tile_ops import tile_begin - - return tile_begin(self) - - @property - def end(self) -> int: - """ - Alias for :func:`~helion.language.tile_end`, which retrieves the end offset of a tile. - """ - from .tile_ops import tile_end - - return tile_end(self) - - @property - def block_size(self) -> int: - """ - Alias for :func:`~helion.language.tile_block_size`, which retrieves the block_size of a tile. - """ - from .tile_ops import tile_block_size - - return tile_block_size(self) - - @property - def id(self) -> int: - """ - Alias for :func:`~helion.language.tile_id`, which retrieves the id of a tile. - """ - from .tile_ops import tile_id - - return tile_id(self) - class _CheckForIndexCalls: """ diff --git a/helion/language/view_ops.py b/helion/language/view_ops.py index cdbb7cd8..bee7f559 100644 --- a/helion/language/view_ops.py +++ b/helion/language/view_ops.py @@ -102,6 +102,11 @@ def _(state: CodegenState) -> ast.AST: ) +@_decorators.ref(subscript) +def _(tensor: torch.Tensor, indices: list[object]) -> torch.Tensor: + return tensor[indices] # pyright: ignore[reportArgumentType] + + @_decorators.get_masked_value(subscript) def _(node: torch.fx.Node) -> float | bool | None: from .._compiler.node_masking import cached_masked_value diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index 8e53d27d..0c90c374 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -36,6 +36,8 @@ from .._logging import LazyString from ..language.constexpr import ConstExpr from .config import Config +from .ref_mode import RefModeContext +from .ref_mode import is_ref_mode_enabled from .settings import Settings if TYPE_CHECKING: @@ -278,7 +280,11 @@ def reset(self) -> None: class BoundKernel(Generic[_R]): - def __init__(self, kernel: Kernel[_R], args: tuple[object, ...]) -> None: + def __init__( + self, + kernel: Kernel[_R], + args: tuple[object, ...], + ) -> None: """ Initialize a BoundKernel object. @@ -295,6 +301,12 @@ def __init__(self, kernel: Kernel[_R], args: tuple[object, ...]) -> None: self._config: Config | None = None self._compile_cache: dict[Config, CompiledConfig] = {} self.env = CompileEnvironment(_find_device(args), self.kernel.settings) + + if is_ref_mode_enabled(self.kernel.settings): + self.fake_args = [] # type: ignore[assignment] + self.host_function = None # type: ignore[assignment] + return + with self.env: assert len(args) == len(self.kernel.signature.parameters) self.fake_args: list[object] = [] @@ -540,6 +552,11 @@ def _require_implicit_config(self) -> Config: raise RuntimeError("no config provided and no implicit config available") return config + def run_ref(self, *args: object) -> _R: # pyright: ignore[reportReturnType] + with RefModeContext(self.env): + result = self.kernel.fn(*args) + return cast("_R", result) + def __call__(self, *args: object) -> _R: """ Execute the kernel with the given arguments. @@ -550,6 +567,11 @@ def __call__(self, *args: object) -> _R: Returns: _R: The result of the kernel execution. 
""" + if is_ref_mode_enabled(self.kernel.settings): + if (config := self._implicit_config()) is not None: + self._config = config + return self.run_ref(*args) + if self._run is None: if (config := self._implicit_config()) is not None: self.set_config(config) diff --git a/helion/runtime/ref_mode.py b/helion/runtime/ref_mode.py new file mode 100644 index 00000000..07dfbfd9 --- /dev/null +++ b/helion/runtime/ref_mode.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import enum +from typing import TYPE_CHECKING +from typing import cast + +import torch +from torch.overrides import BaseTorchFunctionMode + +from helion._compiler.compile_environment import CompileEnvironment +from helion._compiler.compile_environment import NoCurrentEnvironment +from helion._compiler.compile_environment import tls as ce_tls + +if TYPE_CHECKING: + from typing_extensions import Self + + from .settings import Settings + + +class RefMode(enum.Enum): + """Reference mode for kernel execution.""" + + OFF = "off" + EAGER = "eager" + + +def is_ref_mode_enabled(settings: Settings) -> bool: + """Check if ref mode is enabled based on settings.""" + return settings.ref_mode != RefMode.OFF + + +def is_in_ref_mode_context() -> bool: + """Check if we're currently executing in ref mode context. + + This checks if there's a current CompileEnvironment with ref mode enabled. + """ + try: + env = CompileEnvironment.current() + return is_ref_mode_enabled(env.settings) + except NoCurrentEnvironment: + return False + + +class RefModeContext: + """Context manager to enable ref mode execution.""" + + def __init__(self, env: CompileEnvironment) -> None: + self.env = env + self.func_mode = RefModeTorchFunctionMode() + + def __enter__(self) -> Self: + ce_tls.env = self.env + self.func_mode.__enter__() + return self + + def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> bool: + self.func_mode.__exit__(exc_type, exc_val, exc_tb) + ce_tls.env = None + return False + + +class RefModeTorchFunctionMode(BaseTorchFunctionMode): + """Torch function mode for Helion ref mode operations.""" + + def __torch_function__( + self, + func: object, + types: list[type[object]], + args: tuple[object, ...] = (), + kwargs: dict[str, object] | None = None, + ) -> object: + if kwargs is None: + kwargs = {} + + # Handle matrix multiplication operations + if func == torch.addmm: + return self._handle_addmm(args, kwargs) + if func == torch.baddbmm: + return self._handle_baddbmm(args, kwargs) + + return super().__torch_function__(func, types, args, kwargs) + + def _handle_addmm( + self, args: tuple[object, ...], kwargs: dict[str, object] + ) -> torch.Tensor: + """Handle torch.addmm with mixed precision support (e.g. 
torch.addmm(fp32, bf16, bf16)).""" + assert len(args) >= 3, "addmm requires at least 3 arguments" + bias = cast("torch.Tensor", args[0]) + mat1 = cast("torch.Tensor", args[1]) + mat2 = cast("torch.Tensor", args[2]) + beta = cast("float", kwargs.get("beta", 1)) + alpha = cast("float", kwargs.get("alpha", 1)) + + assert mat1.dtype == mat2.dtype, ( + f"Matrix dtypes must match for torch.addmm: " + f"mat1.dtype={mat1.dtype}, mat2.dtype={mat2.dtype}" + ) + + result = torch.mm(mat1, mat2, out_dtype=bias.dtype) + if alpha != 1: + result = result * alpha + if result.dtype != bias.dtype: + result = result.to(bias.dtype) + if beta == 0: + return result + return result + (beta * bias) + + def _handle_baddbmm( + self, args: tuple[object, ...], kwargs: dict[str, object] + ) -> torch.Tensor: + """Handle torch.baddbmm with mixed precision support (e.g. torch.baddbmm(fp32, bf16, bf16)).""" + assert len(args) >= 3, "baddbmm requires at least 3 arguments" + bias = cast("torch.Tensor", args[0]) + batch1 = cast("torch.Tensor", args[1]) + batch2 = cast("torch.Tensor", args[2]) + beta = cast("float", kwargs.get("beta", 1)) + alpha = cast("float", kwargs.get("alpha", 1)) + + assert batch1.dtype == batch2.dtype, ( + f"Matrix dtypes must match for torch.baddbmm: " + f"mat1.dtype={batch1.dtype}, mat2.dtype={batch2.dtype}" + ) + + result = torch.bmm(batch1, batch2, out_dtype=bias.dtype) + if alpha != 1: + result = result * alpha + if result.dtype != bias.dtype: + result = result.to(bias.dtype) + if beta == 0: + return result + return result + (beta * bias) diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index 339f8cf2..91bb80f0 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -14,6 +14,7 @@ from torch._environment import is_fbcode from helion import exc +from helion.runtime.ref_mode import RefMode if TYPE_CHECKING: from contextlib import AbstractContextManager @@ -72,6 +73,9 @@ class _Settings: allow_warp_specialize: bool = ( os.environ.get("HELION_ALLOW_WARP_SPECIALIZE", "1") == "1" ) + ref_mode: RefMode = ( + RefMode.EAGER if os.environ.get("HELION_INTERPRET", "") == "1" else RefMode.OFF + ) class Settings(_Settings): @@ -92,6 +96,7 @@ class Settings(_Settings): "print_output_code": "If True, print the output code of the kernel to stderr.", "force_autotune": "If True, force autotuning even if a config is provided.", "allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.", + "ref_mode": "Reference mode for kernel execution. 
Can be RefMode.OFF or RefMode.EAGER.", } assert __slots__.keys() == {field.name for field in dataclasses.fields(_Settings)} diff --git a/test/test_associative_scan.py b/test/test_associative_scan.py index f947ab22..ded756a4 100644 --- a/test/test_associative_scan.py +++ b/test/test_associative_scan.py @@ -6,6 +6,7 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl @@ -96,7 +97,7 @@ def jit_add_combine_fn(x, y): return x + y -class TestAssociativeScan(TestCase): +class TestAssociativeScan(RefEagerTestDisabled, TestCase): def test_associative_scan_basic_addition(self): """Test basic associative_scan functionality with prefix sum.""" diff --git a/test/test_atomic_add.py b/test/test_atomic_add.py index f271ce7d..576ae532 100644 --- a/test/test_atomic_add.py +++ b/test/test_atomic_add.py @@ -6,6 +6,7 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl @@ -56,7 +57,7 @@ def atomic_add_w_tile_attr(x: torch.Tensor) -> torch.Tensor: return y -class TestAtomicOperations(TestCase): +class TestAtomicOperations(RefEagerTestDisabled, TestCase): def test_basic_atomic_add(self): x = torch.zeros(10, device=DEVICE) y = torch.ones(10, device=DEVICE) diff --git a/test/test_autotuner.py b/test/test_autotuner.py index 59aeec3f..8281506f 100644 --- a/test/test_autotuner.py +++ b/test/test_autotuner.py @@ -13,6 +13,7 @@ import helion from helion import _compat from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import import_path from helion.autotuner import DifferentialEvolutionSearch @@ -27,7 +28,7 @@ examples_matmul = import_path(examples_dir / "matmul.py").matmul -class TestAutotuner(TestCase): +class TestAutotuner(RefEagerTestDisabled, TestCase): def setUp(self): super().setUp() random.seed(112) diff --git a/test/test_broadcasting.py b/test/test_broadcasting.py index 47e55649..b4c52829 100644 --- a/test/test_broadcasting.py +++ b/test/test_broadcasting.py @@ -6,6 +6,7 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl @@ -36,7 +37,7 @@ def _check_broadcast_fn(**config): return code -class TestBroadcasting(TestCase): +class TestBroadcasting(RefEagerTestDisabled, TestCase): def test_broadcast_no_flatten(self): args = [torch.randn(512, 512, device=DEVICE), torch.randn(512, device=DEVICE)] assert not broadcast_fn.bind(args).config_spec.flatten_loops diff --git a/test/test_cache.py b/test/test_cache.py index 595756a4..d6ec1acf 100644 --- a/test/test_cache.py +++ b/test/test_cache.py @@ -6,6 +6,7 @@ import torch from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import import_path from helion._utils import counters @@ -21,7 +22,7 @@ def autotune(self): return self.config_spec.default_config() -class TestCache(TestCase): +class TestCache(RefEagerTestDisabled, TestCase): def test_basic(self): a = torch.randn(16, device=DEVICE, dtype=torch.bfloat16) args_a = (a, a) diff --git a/test/test_closures.py b/test/test_closures.py index cde29363..b8a3df2d 100644 --- 
a/test/test_closures.py +++ b/test/test_closures.py @@ -7,6 +7,7 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output from helion._testing import import_path @@ -26,7 +27,7 @@ def sin_func_arg(a, fn) -> torch.Tensor: return out -class TestClosures(TestCase): +class TestClosures(RefEagerTestDisabled, TestCase): def test_add_global(self): args = (torch.randn([512, 512], device=DEVICE),) code, out = code_and_output(basic_kernels.use_globals, args) diff --git a/test/test_constexpr.py b/test/test_constexpr.py index c637f456..a20093e6 100644 --- a/test/test_constexpr.py +++ b/test/test_constexpr.py @@ -6,12 +6,13 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl -class TestConstExpr(TestCase): +class TestConstExpr(RefEagerTestDisabled, TestCase): def test_constexpr_float(self): @helion.kernel() def fn(x: torch.Tensor, v: hl.constexpr) -> torch.Tensor: diff --git a/test/test_control_flow.py b/test/test_control_flow.py index d96932d1..87e55511 100644 --- a/test/test_control_flow.py +++ b/test/test_control_flow.py @@ -6,12 +6,13 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl -class TestControlFlow(TestCase): +class TestControlFlow(RefEagerTestDisabled, TestCase): def test_if_arg(self): @helion.kernel() def fn(x, v): diff --git a/test/test_dot.py b/test/test_dot.py index c3dee01a..7d0d6222 100644 --- a/test/test_dot.py +++ b/test/test_dot.py @@ -8,8 +8,10 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestBase from helion._testing import TestCase from helion._testing import code_and_output +from helion._testing import skipIfRefEager import helion.language as hl @@ -163,10 +165,38 @@ def run_kernel(): return test_impl -class TestDot(TestCase): +class TestDot(RefEagerTestBase, TestCase): pass +# Define ref mode test failures +REF_EAGER_TEST_FAILURES = { + "test_input_float8_e5m2_acc_None_dynamic_shape": "Matmul with float8_e5m2 dtype not supported in ref eager mode", + "test_input_float8_e5m2_acc_None_static_shape": "Matmul with float8_e5m2 dtype not supported in ref eager mode", + "test_input_float8_e5m2_acc_float16_dynamic_shape": "Matmul with float8_e5m2 dtype not supported in ref eager mode", + "test_input_float8_e5m2_acc_float16_static_shape": "Matmul with float8_e5m2 dtype not supported in ref eager mode", + "test_input_float8_e5m2_acc_float32_dynamic_shape": "Matmul with float8_e5m2 dtype not supported in ref eager mode", + "test_input_float8_e5m2_acc_float32_static_shape": "Matmul with float8_e5m2 dtype not supported in ref eager mode", + "test_input_float8_e5m2_acc_int32_dynamic_shape": "Matmul with float8_e5m2 dtype not supported in ref eager mode", + "test_input_float8_e5m2_acc_int32_static_shape": "Matmul with float8_e5m2 dtype not supported in ref eager mode", + "test_input_int8_acc_None_dynamic_shape": "int8 @ int8 -> int32 is not supported in ref eager mode", + "test_input_int8_acc_None_static_shape": "int8 @ int8 -> int32 is not supported in ref eager mode", + "test_input_int8_acc_int32_dynamic_shape": "int8 @ int8 -> int32 is not supported in ref eager mode", + 
"test_input_int8_acc_int32_static_shape": "int8 @ int8 -> int32 is not supported in ref eager mode", +} + +# Define ref mode test failures for FP8 e4m3fn on GPUs with low compute capability (< 9.0) +REF_EAGER_TEST_FAILURES_FP8_E4M3FN_LOW_COMPUTE_CAP = { + "test_input_float8_e4m3fn_acc_None_dynamic_shape": "Matmul with float8_e4m3fn dtype not supported on this GPU in ref eager mode", + "test_input_float8_e4m3fn_acc_None_static_shape": "Matmul with float8_e4m3fn dtype not supported on this GPU in ref eager mode", + "test_input_float8_e4m3fn_acc_float16_dynamic_shape": "Matmul with float8_e4m3fn dtype not supported on this GPU in ref eager mode", + "test_input_float8_e4m3fn_acc_float16_static_shape": "Matmul with float8_e4m3fn dtype not supported on this GPU in ref eager mode", + "test_input_float8_e4m3fn_acc_float32_dynamic_shape": "Matmul with float8_e4m3fn dtype not supported on this GPU in ref eager mode", + "test_input_float8_e4m3fn_acc_float32_static_shape": "Matmul with float8_e4m3fn dtype not supported on this GPU in ref eager mode", + "test_input_float8_e4m3fn_acc_int32_dynamic_shape": "Matmul with float8_e4m3fn dtype not supported on this GPU in ref eager mode", + "test_input_float8_e4m3fn_acc_int32_static_shape": "Matmul with float8_e4m3fn dtype not supported on this GPU in ref eager mode", +} + # Dynamically generate test methods for input_dtype, acc_dtype, static_shapes_option in itertools.product( INPUT_DTYPES, ACC_DTYPES, STATIC_SHAPES_OPTIONS @@ -182,6 +212,17 @@ class TestDot(TestCase): # Create and add the test method _test_func = make_test_function(input_dtype, acc_dtype, static_shapes_option) _test_func.__name__ = test_name + + # Apply skipIfRefEager decorator if needed + if test_name in REF_EAGER_TEST_FAILURES: + _test_func = skipIfRefEager(REF_EAGER_TEST_FAILURES[test_name])(_test_func) + elif test_name in REF_EAGER_TEST_FAILURES_FP8_E4M3FN_LOW_COMPUTE_CAP: + # For e4m3fn tests, only skip if GPU capability < 9 + if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] < 9: + _test_func = skipIfRefEager( + REF_EAGER_TEST_FAILURES_FP8_E4M3FN_LOW_COMPUTE_CAP[test_name] + )(_test_func) + setattr(TestDot, test_name, _test_func) diff --git a/test/test_errors.py b/test/test_errors.py index 5e56eacb..593c6cb3 100644 --- a/test/test_errors.py +++ b/test/test_errors.py @@ -6,12 +6,13 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl -class TestErrors(TestCase): +class TestErrors(RefEagerTestDisabled, TestCase): def test_tile_unpacking(self): @helion.kernel() def sum_kernel(x: torch.Tensor) -> torch.Tensor: diff --git a/test/test_examples.py b/test/test_examples.py index d2597255..b15fc5e2 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -7,15 +7,17 @@ from helion._testing import DEVICE from helion._testing import EXAMPLES_DIR +from helion._testing import RefEagerTestBase from helion._testing import TestCase from helion._testing import check_example from helion._testing import import_path +from helion._testing import skipIfRefEager torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True -class TestExamples(TestCase): +class TestExamples(RefEagerTestBase, TestCase): def test_add(self): args = ( torch.randn([512, 512], device=DEVICE, dtype=torch.float32), @@ -42,6 +44,9 @@ def test_matmul(self): ) ) + @skipIfRefEager( + "AssertionError: register_reduction_dim must 
be decorated with @helion.ref() to be used in ref mode" + ) def test_matmul_layernorm_static_shapes(self): args = ( torch.randn([128, 256], device=DEVICE, dtype=torch.float32), @@ -64,6 +69,9 @@ def test_matmul_layernorm_static_shapes(self): ) ) + @skipIfRefEager( + "AssertionError: register_reduction_dim must be decorated with @helion.ref() to be used in ref mode" + ) def test_matmul_layernorm_dynamic_shapes(self): args = ( torch.randn([128, 256], device=DEVICE, dtype=torch.float32), @@ -134,6 +142,9 @@ def test_fp8_gemm(self): ) ) + @skipIfRefEager( + "RuntimeError: The size of tensor a (64) must match the size of tensor b (0) at non-singleton dimension 0" + ) def test_template_via_closure0(self): bias = torch.randn([1, 1024], device=DEVICE, dtype=torch.float16) args = ( @@ -156,6 +167,9 @@ def test_template_via_closure0(self): ) ) + @skipIfRefEager( + "RuntimeError: The size of tensor a (64) must match the size of tensor b (0) at non-singleton dimension 0" + ) def test_template_via_closure1(self): bias = torch.randn([1, 1024], device=DEVICE, dtype=torch.float16) args = ( @@ -243,6 +257,9 @@ def test_softmax_decomposed(self): ) ) + @skipIfRefEager( + "AssertionError: register_block_size must be decorated with @helion.ref() to be used in ref mode" + ) def test_softmax_two_pass(self): args = (torch.randn([1024, 1024], device=DEVICE, dtype=torch.float32),) self.assertExpectedJournal( @@ -254,6 +271,9 @@ def test_softmax_two_pass(self): ) ) + @skipIfRefEager( + "AssertionError: register_block_size must be decorated with @helion.ref() to be used in ref mode" + ) def test_softmax_two_pass_block_ptr(self): args = (torch.randn([1024, 1024], device=DEVICE, dtype=torch.float32),) self.assertExpectedJournal( @@ -267,6 +287,9 @@ def test_softmax_two_pass_block_ptr(self): ) ) + @skipIfRefEager( + "AssertionError: load must be decorated with @helion.ref() to be used in ref mode" + ) def test_cross_entropy(self): n, v = 128, 1000 args = ( @@ -379,6 +402,9 @@ def test_attention_dynamic(self): ) ) + @skipIfRefEager( + "AssertionError: load must be decorated with @helion.ref() to be used in ref mode" + ) def test_concat(self): args = ( torch.randn(512, 500, device=DEVICE), @@ -393,6 +419,9 @@ def test_concat(self): ) ) + @skipIfRefEager( + "AssertionError: load must be decorated with @helion.ref() to be used in ref mode" + ) def test_concat_block_ptr(self): args = ( torch.randn(222, 100, device=DEVICE), @@ -409,6 +438,9 @@ def test_concat_block_ptr(self): ) ) + @skipIfRefEager( + "AssertionError: load must be decorated with @helion.ref() to be used in ref mode" + ) def test_jagged_dense_add(self): mod = import_path(EXAMPLES_DIR / "jagged_dense_add.py") args = ( @@ -424,6 +456,7 @@ def test_jagged_dense_add(self): ) ) + @skipIfRefEager("Test has skip_accuracy=True and doesn't call assert_close") def test_moe_matmul_ogs(self): mod = import_path(EXAMPLES_DIR / "moe_matmul_ogs.py") @@ -449,6 +482,9 @@ def test_moe_matmul_ogs(self): ) ) + @skipIfRefEager( + "AssertionError: register_tunable must be decorated with @helion.ref() to be used in ref mode" + ) def test_matmul_split_k(self): args = ( torch.randn(64, 1024, device=DEVICE), @@ -478,6 +514,9 @@ def test_sum(self): ) ) + @skipIfRefEager( + "AssertionError: load must be decorated with @helion.ref() to be used in ref mode" + ) def test_jagged_mean(self): num_rows, max_cols = 32, 64 M = 8 # number of features @@ -513,6 +552,9 @@ def test_jagged_mean(self): ) ) + @skipIfRefEager( + "AssertionError: load must be decorated with @helion.ref() to be used in ref 
mode" + ) def test_segment_reduction(self): num_nodes = 100 num_edges = 1000 diff --git a/test/test_generate_ast.py b/test/test_generate_ast.py index 94dc7c43..e27b71f9 100644 --- a/test/test_generate_ast.py +++ b/test/test_generate_ast.py @@ -6,6 +6,7 @@ import torch from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output from helion._testing import import_path @@ -14,7 +15,7 @@ basic_kernels = import_path(datadir / "basic_kernels.py") -class TestGenerateAst(TestCase): +class TestGenerateAst(RefEagerTestDisabled, TestCase): def test_add1d(self): args = (torch.randn([4096], device=DEVICE), torch.randn([4096], device=DEVICE)) code, result = code_and_output(basic_kernels.add, args, block_size=1024) diff --git a/test/test_grid.py b/test/test_grid.py index 79892216..283db13a 100644 --- a/test/test_grid.py +++ b/test/test_grid.py @@ -8,6 +8,7 @@ import helion from helion import _compat from helion._testing import DEVICE +from helion._testing import RefEagerTestBase from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl @@ -31,7 +32,7 @@ def grid_2d_pytorch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return out -class TestGrid(TestCase): +class TestGrid(RefEagerTestBase, TestCase): @patch.object(_compat, "_min_dot_size", lambda *args: (16, 16, 16)) def test_grid_1d(self): @helion.kernel(static_shapes=True) diff --git a/test/test_indexing.py b/test/test_indexing.py index 6c32a1a5..843c56d6 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -8,6 +8,7 @@ from helion._compat import get_tensor_descriptor_fn_name from helion._compat import supports_tensor_descriptor from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl @@ -39,7 +40,7 @@ def reduction_sum(x: torch.Tensor) -> torch.Tensor: return out -class TestIndexing(TestCase): +class TestIndexing(RefEagerTestDisabled, TestCase): def test_arange(self): @helion.kernel def arange(length: int, device: torch.device) -> torch.Tensor: diff --git a/test/test_inline_asm_elementwise.py b/test/test_inline_asm_elementwise.py index e91fa910..d272039b 100644 --- a/test/test_inline_asm_elementwise.py +++ b/test/test_inline_asm_elementwise.py @@ -7,12 +7,13 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl -class TestInlineAsmElementwise(TestCase): +class TestInlineAsmElementwise(RefEagerTestDisabled, TestCase): @pytest.mark.skipif( DEVICE.type != "cuda", reason="inline_asm_elementwise is only supported on CUDA" ) diff --git a/test/test_logging.py b/test/test_logging.py index 4dbe4872..bedbdc68 100644 --- a/test/test_logging.py +++ b/test/test_logging.py @@ -6,11 +6,12 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase import helion.language as hl -class TestLogging(TestCase): +class TestLogging(RefEagerTestDisabled, TestCase): def test_log_set(self): import logging diff --git a/test/test_loop_dependencies.py b/test/test_loop_dependencies.py index 7250451d..d1c0535c 100644 --- a/test/test_loop_dependencies.py +++ b/test/test_loop_dependencies.py @@ -8,12 +8,13 @@ import helion from helion import 
exc from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl -class TestLoops(TestCase): +class TestLoops(RefEagerTestDisabled, TestCase): def test_loop_dependency_error1(self): @helion.kernel def kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: diff --git a/test/test_loops.py b/test/test_loops.py index 792af2c6..30814750 100644 --- a/test/test_loops.py +++ b/test/test_loops.py @@ -8,9 +8,11 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestBase from helion._testing import TestCase from helion._testing import code_and_output from helion._testing import import_path +from helion._testing import skipIfRefEager import helion.language as hl datadir = Path(__file__).parent / "data" @@ -40,7 +42,7 @@ def nested_loop_kernel(x: torch.Tensor) -> torch.Tensor: return out -class TestLoops(TestCase): +class TestLoops(RefEagerTestBase, TestCase): def test_pointwise_device_loop(self): args = (torch.randn([512, 512], device=DEVICE),) code, result = code_and_output( @@ -158,6 +160,9 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ) self.assertExpectedJournal(code) + @skipIfRefEager( + "AssertionError: register_block_size must be decorated with @helion.ref() to be used in ref mode" + ) def test_data_dependent_bounds1(self): @helion.kernel() def fn(x: torch.Tensor, end: torch.Tensor) -> torch.Tensor: @@ -223,6 +228,9 @@ def fn(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor) -> torch.Tensor: result, args[0][:, : args[1][0].item(), : args[2][0].item()].sum(-1).sum(-1) ) + @skipIfRefEager( + "AssertionError: register_block_size must be decorated with @helion.ref() to be used in ref mode" + ) def test_data_dependent_bounds4(self): @helion.kernel() def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor) -> torch.Tensor: @@ -268,6 +276,9 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor) -> torch.Tensor: result, args[0][:, args[1][0].item() : args[2][0].item()].sum(-1) ) + @skipIfRefEager( + "AssertionError: register_block_size must be decorated with @helion.ref() to be used in ref mode" + ) def test_register_block_size_minimum(self): @helion.kernel() def fn(x: torch.Tensor) -> torch.Tensor: @@ -285,6 +296,9 @@ def fn(x: torch.Tensor) -> torch.Tensor: self.assertEqual(spec.min_size, 32) self.assertEqual(spec.max_size, 256) + @skipIfRefEager( + "AssertionError: register_block_size must be decorated with @helion.ref() to be used in ref mode" + ) def test_reorder_with_register_block_size(self): @helion.kernel( config={ @@ -306,6 +320,9 @@ def fn(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, args[0] + 1) self.assertExpectedJournal(code) + @skipIfRefEager( + "AssertionError: register_block_size must be decorated with @helion.ref() to be used in ref mode" + ) def test_l2_grouping_with_register_block_size(self): @helion.kernel( config={ @@ -418,6 +435,9 @@ def addToBoth(a, b, c): self.assertExpectedJournal(code) + @skipIfRefEager( + "Test requires block_size=1 which is incompatible with full dimension tile implementation" + ) def test_chebyshev_polynomials(self): """Test nested loops with sequential computation - Chebyshev polynomials.""" @@ -504,6 +524,9 @@ def fn(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(output, x + 6) self.assertExpectedJournal(code) + @skipIfRefEager( + "Test requires block_size=1 which is incompatible with full dimension 
tile implementation" + ) def test_variable_assignment_phi_nodes(self): """Test for phi node issue with variable assignments like U1 = two_x. @@ -590,7 +613,7 @@ def test_range_unroll_factors(self): torch.testing.assert_close(result0, result2) torch.testing.assert_close(result0, args[0] + 1) - self.assertNotEqual(code0, code2) + self.assertNotEqualCode(code0, code2) self.assertNotIn("loop_unroll_factor", code0) self.assertExpectedJournal(code2) @@ -631,9 +654,9 @@ def test_range_warp_specialize(self): torch.testing.assert_close(result_none, args[0] + 1) # Ensure different code is generated for different settings - self.assertNotEqual(code_none, code_true) - self.assertNotEqual(code_none, code_false) - self.assertNotEqual(code_true, code_false) + self.assertNotEqualCode(code_none, code_true) + self.assertNotEqualCode(code_none, code_false) + self.assertNotEqualCode(code_true, code_false) # Check that warp_specialize appears in the generated code self.assertNotIn("warp_specialize", code_none) @@ -656,7 +679,7 @@ def test_range_num_stages(self): torch.testing.assert_close(result0, result3) torch.testing.assert_close(result0, args[0] + 1) - self.assertNotEqual(code0, code3) + self.assertNotEqualCode(code0, code3) # Check that range_num_stages parameter appears in tl.range call self.assertNotIn( "tl.range(0, x_size_1.to(tl.int32), _BLOCK_SIZE_1, num_stages=", code0 @@ -696,9 +719,9 @@ def test_range_multi_buffers(self): torch.testing.assert_close(result_none, result_true) torch.testing.assert_close(result_none, result_false) torch.testing.assert_close(result_none, args[0] + 1) - self.assertNotEqual(code_none, code_true) - self.assertNotEqual(code_none, code_false) - self.assertNotEqual(code_true, code_false) + self.assertNotEqualCode(code_none, code_true) + self.assertNotEqualCode(code_none, code_false) + self.assertNotEqualCode(code_true, code_false) # Check that disallow_acc_multi_buffer parameter appears in tl.range call self.assertNotIn("disallow_acc_multi_buffer", code_none) self.assertIn("disallow_acc_multi_buffer=False", code_true) @@ -726,14 +749,17 @@ def test_range_flatten(self): torch.testing.assert_close(result_none, result_true) torch.testing.assert_close(result_none, result_false) torch.testing.assert_close(result_none, args[0] + 1) - self.assertNotEqual(code_none, code_true) - self.assertNotEqual(code_none, code_false) - self.assertNotEqual(code_true, code_false) + self.assertNotEqualCode(code_none, code_true) + self.assertNotEqualCode(code_none, code_false) + self.assertNotEqualCode(code_true, code_false) # Check that flatten parameter appears in tl.range call self.assertNotIn("flatten", code_none) self.assertIn("flatten=True", code_true) self.assertIn("flatten=False", code_false) + @skipIfRefEager( + "Static range test checks code generation, not relevant in ref eager mode" + ) def test_static_range_2d(self): @helion.kernel() def nested_loop_kernel_2d(x: torch.Tensor) -> torch.Tensor: @@ -781,13 +807,16 @@ def nested_loop_kernel_2d(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result_false, result_true) torch.testing.assert_close(result_true, args[0] + 1) - self.assertEqual(code_default, code_false) - self.assertEqual(code_ignore, code_true) - self.assertNotEqual(code_true, code_false) + self.assertEqualCode(code_default, code_false) + self.assertEqualCode(code_ignore, code_true) + self.assertNotEqualCode(code_true, code_false) # Check that tl.range / tl.static_range is used according to setups. 
self.assertIn("tl.range", code_false) self.assertIn("tl.static_range", code_true) + @skipIfRefEager( + "Static range test checks code generation, not relevant in ref eager mode" + ) def test_static_range_scalar(self): @helion.kernel() def nested_loop_kernel_scalar(x: torch.Tensor) -> torch.Tensor: @@ -830,9 +859,9 @@ def nested_loop_kernel_scalar(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result_default, result_true) torch.testing.assert_close(result_default, result_false) torch.testing.assert_close(result_default, x + 4) - self.assertNotEqual(code_default, code_true) - self.assertNotEqual(code_true, code_false) - self.assertEqual(code_default, code_false) + self.assertNotEqualCode(code_default, code_true) + self.assertNotEqualCode(code_true, code_false) + self.assertEqualCode(code_default, code_false) # Check that tl.range / tl.static_range is used according to setups. self.assertIn("tl.range", code_false) self.assertIn("tl.static_range", code_true) diff --git a/test/test_masking.py b/test/test_masking.py index 0e376024..d85a4fc4 100644 --- a/test/test_masking.py +++ b/test/test_masking.py @@ -6,12 +6,13 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl -class TestMasking(TestCase): +class TestMasking(RefEagerTestDisabled, TestCase): def test_mask_dot(self): @helion.kernel(config={"block_sizes": [[32, 32], 32]}, dot_precision="ieee") def add1mm(x, y): diff --git a/test/test_matmul.py b/test/test_matmul.py index 381a2c2a..a35b13d8 100644 --- a/test/test_matmul.py +++ b/test/test_matmul.py @@ -9,6 +9,7 @@ from helion import Config from helion._compat import supports_tensor_descriptor from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output from helion._testing import import_path @@ -66,7 +67,7 @@ def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return out -class TestMatmul(TestCase): +class TestMatmul(RefEagerTestDisabled, TestCase): def test_matmul0(self): args = ( torch.randn([128, 128], device=DEVICE, dtype=torch.float32), diff --git a/test/test_misc.py b/test/test_misc.py index 626ec422..06d4c0d1 100644 --- a/test/test_misc.py +++ b/test/test_misc.py @@ -11,12 +11,13 @@ import helion from helion._compat import supports_tensor_descriptor from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl -class TestMisc(TestCase): +class TestMisc(RefEagerTestDisabled, TestCase): def test_binary_operation_duplicate_args(self): """Test case to reproduce issue #221: binary operations with duplicate tensor references""" diff --git a/test/test_persistent_kernels.py b/test/test_persistent_kernels.py index bd1acce7..472a5c33 100644 --- a/test/test_persistent_kernels.py +++ b/test/test_persistent_kernels.py @@ -8,6 +8,7 @@ from helion._compat import get_tensor_descriptor_fn_name from helion._compat import supports_tensor_descriptor from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl @@ -53,7 +54,7 @@ def add1_kernel(x: torch.Tensor) -> torch.Tensor: return result -class TestPersistentKernels(TestCase): +class 
TestPersistentKernels(RefEagerTestDisabled, TestCase): """Test persistent kernel codegen with different PID strategies.""" def test_persistent_blocked_simple_add(self): diff --git a/test/test_print.py b/test/test_print.py index 9cbd5667..561f4079 100644 --- a/test/test_print.py +++ b/test/test_print.py @@ -10,6 +10,7 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl @@ -25,7 +26,7 @@ def _store_capfd_on_class(request, capfd): request.cls._capfd = capfd -class TestPrint(TestCase): +class TestPrint(RefEagerTestDisabled, TestCase): def run_kernel_and_capture_output(self, kernel_fn, args): """Helper to run kernel and capture output""" if hasattr(self, "_capfd"): diff --git a/test/test_reduce.py b/test/test_reduce.py index 83d8c1b9..4f6017a0 100644 --- a/test/test_reduce.py +++ b/test/test_reduce.py @@ -6,6 +6,7 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl @@ -78,7 +79,7 @@ def jit_add_combine_fn(x, y): return x + y -class TestReduce(TestCase): +class TestReduce(RefEagerTestDisabled, TestCase): def test_reduce_basic_sum(self): """Test basic reduce functionality with sum reduction along a dimension.""" diff --git a/test/test_reductions.expected b/test/test_reductions.expected index 8b2a4d71..f0b00796 100644 --- a/test/test_reductions.expected +++ b/test/test_reductions.expected @@ -276,42 +276,42 @@ def layer_norm_fwd_repro(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tens --- assertExpectedJournal(TestReductions.test_mean) def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], out_dtype=torch.float32): - # Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=) + # Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=) # Attribute: TensorAttributeType AttributeOrigin(value=ArgumentOrigin(name='x'), key='size') # Name: TensorType([x_size0, x_size1], torch.float32) ArgumentOrigin(name='x') n, _m = x.size() - # Call: TensorType([x_size0], torch.float32) SourceOrigin(location=) + # Call: TensorType([x_size0], torch.float32) SourceOrigin(location=) # Attribute: CallableType(_VariableFunctionsClass.empty) AttributeOrigin(value=GlobalOrigin(name='torch'), key='empty') # Name: PythonModuleType(torch) GlobalOrigin(name='torch') - # List: SequenceType([SymIntType(s77)]) SourceOrigin(location=) - # Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=), key=0) + # List: SequenceType([SymIntType(s77)]) SourceOrigin(location=) + # Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=), key=0) # Name: LiteralType(torch.float32) ArgumentOrigin(name='out_dtype') # Attribute: LiteralType(device(type='cuda', index=0)) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device') # Name: TensorType([x_size0, x_size1], torch.float32) ArgumentOrigin(name='x') # For: loop_type=GRID out = torch.empty([n], dtype=out_dtype, device=x.device) - # Call: IterType(TileIndexType(0)) SourceOrigin(location=) + # Call: IterType(TileIndexType(0)) SourceOrigin(location=) # Attribute: CallableType(tile) AttributeOrigin(value=GlobalOrigin(name='hl'), key='tile') # Name: PythonModuleType(helion.language) GlobalOrigin(name='hl') - # Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=), key=0) + # 
Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=), key=0) for tile_n in hl.tile(n): - # Subscript: TensorType([block_size_0], torch.float32) DeviceOrigin(location=) - # Name: TensorType([x_size0], torch.float32) SourceOrigin(location=) - # Name: TileIndexType(0) SourceOrigin(location=) - # Call: TensorType([block_size_0], torch.float32) DeviceOrigin(location=) + # Subscript: TensorType([block_size_0], torch.float32) DeviceOrigin(location=) + # Name: TensorType([x_size0], torch.float32) SourceOrigin(location=) + # Name: TileIndexType(0) SourceOrigin(location=) + # Call: TensorType([block_size_0], torch.float32) DeviceOrigin(location=) # Name: CallableType(_VariableFunctionsClass.mean) ArgumentOrigin(name='fn') - # Subscript: TensorType([block_size_0, rdim_1], torch.float32) DeviceOrigin(location=) + # Subscript: TensorType([block_size_0, rdim_1], torch.float32) DeviceOrigin(location=) # Name: TensorType([x_size0, x_size1], torch.float32) ArgumentOrigin(name='x') - # Name: TileIndexType(0) SourceOrigin(location=) - # Slice: SliceType(LiteralType(None):LiteralType(None):LiteralType(None)) DeviceOrigin(location=) - # UnaryOp: LiteralType(-1) DeviceOrigin(location=) - # Constant: LiteralType(1) DeviceOrigin(location=) + # Name: TileIndexType(0) SourceOrigin(location=) + # Slice: SliceType(LiteralType(None):LiteralType(None):LiteralType(None)) DeviceOrigin(location=) + # UnaryOp: LiteralType(-1) DeviceOrigin(location=) + # Constant: LiteralType(1) DeviceOrigin(location=) out[tile_n] = fn(x[tile_n, :], dim=-1) - # Name: TensorType([x_size0], torch.float32) SourceOrigin(location=) + # Name: TensorType([x_size0], torch.float32) SourceOrigin(location=) return out def root_graph_0(): - # File: .../test_reductions.py:55 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1) + # File: .../test_reductions.py:56 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1) x: "f32[s77, s27]" = helion_language__tracing_ops__host_tensor('x') block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0') load: "f32[u0, u1]" = helion_language_memory_ops_load(x, [block_size_0, slice(None, None, None)], None); x = None @@ -322,7 +322,7 @@ def root_graph_0(): return None def reduction_loop_1(): - # File: .../test_reductions.py:55 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1) + # File: .../test_reductions.py:56 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1) x: "f32[s77, s27]" = helion_language__tracing_ops__host_tensor('x') block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0') load: "f32[u0, u1]" = helion_language_memory_ops_load(x, [block_size_0, slice(None, None, None)], None); x = block_size_0 = None @@ -330,7 +330,7 @@ def reduction_loop_1(): return [mean_extra] def root_graph_2(): - # File: .../test_reductions.py:55 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1) + # File: .../test_reductions.py:56 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1) block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0') _get_symnode: "Sym(s27)" = helion_language__tracing_ops__get_symnode('rdim1') _for_loop = helion_language__tracing_ops__for_loop(1, [0], [_get_symnode], []); _get_symnode = None diff --git a/test/test_reductions.py b/test/test_reductions.py index 19fd87d6..0bef1abb 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -7,6 +7,7 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled 
from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl @@ -56,7 +57,7 @@ def reduce_kernel( return out -class TestReductions(TestCase): +class TestReductions(RefEagerTestDisabled, TestCase): def test_sum(self): args = (torch.randn([512, 512], device=DEVICE),) code, output = code_and_output(sum_kernel, args, block_size=1) diff --git a/test/test_ref_eager.py b/test/test_ref_eager.py new file mode 100644 index 00000000..6aa6b758 --- /dev/null +++ b/test/test_ref_eager.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import contextlib +import io +import math +import unittest + +import torch + +import helion +from helion._testing import TestCase +from helion._testing import assert_ref_eager_mode +import helion.language as hl + + +class TestRefEagerMisc(TestCase): + def test_print_intermediate_tensor(self): + @helion.kernel(ref_mode=helion.RefMode.EAGER) + def print_intermediate_tensor_kernel( + x: torch.Tensor, y: torch.Tensor + ) -> torch.Tensor: + out = torch.empty_like(x) + m, n = x.shape + for tile_m, tile_n in hl.tile([m, n]): + x_val = x[tile_m, tile_n] + y_val = y[tile_m, tile_n] + sum_val = x_val + y_val + print("x: ", x_val) + print("y: ", y_val) + print("sum: ", sum_val) + out[tile_m, tile_n] = sum_val + return out + + x = torch.ones([2, 2], device="cuda", dtype=torch.float32) * 10.0 + y = torch.ones([2, 2], device="cuda", dtype=torch.float32) * 5.0 + expected = x + y + + # Capture stdout to check print output + captured_output = io.StringIO() + with contextlib.redirect_stdout(captured_output): + result = print_intermediate_tensor_kernel(x, y) + + torch.testing.assert_close(result, expected, atol=1e-6, rtol=1e-6) + + # Check that the print statements produced output + output = captured_output.getvalue() + self.assertIn("x: ", output) + self.assertIn("y: ", output) + self.assertIn("sum: ", output) + self.assertIn("[[10., 10.]", output) # x values + self.assertIn("[[5., 5.]", output) # y values + self.assertIn("[[15., 15.]", output) # sum values + + def test_print_in_invalid_helion_kernel(self): + """Test that print works even in invalid Helion kernels in ref eager mode.""" + + @helion.kernel(ref_mode=helion.RefMode.EAGER) + def incorrect_kernel(x: torch.Tensor) -> torch.Tensor: + m, n = x.shape + for tile_m, tile_n in hl.tile([m, n]): + val = x[tile_m, tile_n] + print("processing tile: ", val) + # `pass` below causes this kernel to be invalid. + # But we show that in ref-eager mode, the `print` statement above still works, + # which is useful for debugging. 
+ pass # noqa: PIE790 + return x + + x = torch.ones([2, 2], device="cuda", dtype=torch.float32) * math.pi + + # Capture stdout to check print output + captured_output = io.StringIO() + with contextlib.redirect_stdout(captured_output): + _ = incorrect_kernel(x) + + # Check that the print statement produced output + output = captured_output.getvalue() + self.assertIn("processing tile: ", output) + self.assertIn("[[3.14", output) # The value printed + + def test_ref_eager_kernel_config(self): + @helion.kernel(ref_mode=helion.RefMode.EAGER) + def kernel(x: torch.Tensor) -> torch.Tensor: + m, n = x.shape + out = torch.empty_like(x) + for tile_m, tile_n in hl.tile([m, n]): + out[tile_m, tile_n] = x[tile_m, tile_n] * 2.0 + return out + + with assert_ref_eager_mode(): + x = torch.randn(128, 128, device="cuda") + result = kernel(x) + expected = x * 2.0 + torch.testing.assert_close(result, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_register_tunable.py b/test/test_register_tunable.py index 18ca8287..bce4be97 100644 --- a/test/test_register_tunable.py +++ b/test/test_register_tunable.py @@ -8,6 +8,7 @@ import helion from helion import _compat from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output from helion.autotuner import EnumFragment @@ -17,7 +18,7 @@ from helion.language import loops -class TestRegisterTunable(TestCase): +class TestRegisterTunable(RefEagerTestDisabled, TestCase): maxDiff = 10000 def test_power_of_two_fragment_basic(self): diff --git a/test/test_signal_wait.py b/test/test_signal_wait.py index dbff6046..ec47c59e 100644 --- a/test/test_signal_wait.py +++ b/test/test_signal_wait.py @@ -6,12 +6,13 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl -class TestWait(TestCase): +class TestWait(RefEagerTestDisabled, TestCase): def test_wait_basic(self): @helion.kernel def gmem_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor: diff --git a/test/test_specialize.py b/test/test_specialize.py index e7f11b50..4e91b995 100644 --- a/test/test_specialize.py +++ b/test/test_specialize.py @@ -7,13 +7,14 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output from helion.exc import ShapeSpecializingAllocation import helion.language as hl -class TestSpecialize(TestCase): +class TestSpecialize(RefEagerTestDisabled, TestCase): maxDiff = 163842 def test_sqrt_does_not_specialize(self): diff --git a/test/test_tensor_descriptor.py b/test/test_tensor_descriptor.py index 3ad0d0c6..a64edd7a 100644 --- a/test/test_tensor_descriptor.py +++ b/test/test_tensor_descriptor.py @@ -8,13 +8,14 @@ from helion._compat import get_tensor_descriptor_fn_name from helion._compat import supports_tensor_descriptor from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import check_example from helion._testing import code_and_output import helion.language as hl -class TestTensorDescriptor(TestCase): +class TestTensorDescriptor(RefEagerTestDisabled, TestCase): @unittest.skipUnless( supports_tensor_descriptor(), "Tensor descriptor support is required" ) diff --git a/test/test_type_propagation.expected 
b/test/test_type_propagation.expected index 9500bbd1..ca6707dd 100644 --- a/test/test_type_propagation.expected +++ b/test/test_type_propagation.expected @@ -683,33 +683,33 @@ def root_graph_1(): --- assertExpectedJournal(TestTypePropagation.test_method_call) def fn(x): - # Call: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=) + # Call: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=) # Attribute: CallableType(_VariableFunctionsClass.empty_like) AttributeOrigin(value=GlobalOrigin(name='torch'), key='empty_like') # Name: PythonModuleType(torch) GlobalOrigin(name='torch') # Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x') # For: loop_type=GRID out = torch.empty_like(x) - # Call: IterType(SequenceType([TileIndexType(0), TileIndexType(1)])) SourceOrigin(location=) + # Call: IterType(SequenceType([TileIndexType(0), TileIndexType(1)])) SourceOrigin(location=) # Attribute: CallableType(tile) AttributeOrigin(value=GlobalOrigin(name='hl'), key='tile') # Name: PythonModuleType(helion.language) GlobalOrigin(name='hl') - # Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=) + # Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=) # Attribute: TensorAttributeType AttributeOrigin(value=ArgumentOrigin(name='x'), key='size') # Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x') for tile in hl.tile(x.size()): - # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) - # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=) - # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=) - # Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=) - # Attribute: TensorAttributeType AttributeOrigin(value=DeviceOrigin(location=), key='sin') - # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) + # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) + # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=) + # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=) + # Call: TensorType([block_size_0, block_size_1], torch.float32) DeviceOrigin(location=) + # Attribute: TensorAttributeType AttributeOrigin(value=DeviceOrigin(location=), key='sin') + # Subscript: TensorType([block_size_0, block_size_1], torch.int32) DeviceOrigin(location=) # Name: TensorType([x_size0, x_size1], torch.int32) ArgumentOrigin(name='x') - # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=) + # Name: SequenceType([TileIndexType(0), TileIndexType(1)]) SourceOrigin(location=) out[tile] = x[tile].sin() - # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=) + # Name: TensorType([x_size0, x_size1], torch.int32) SourceOrigin(location=) return out def root_graph_0(): - # File: .../test_type_propagation.py:78 in fn, code: out[tile] = x[tile].sin() + # File: .../test_type_propagation.py:79 in fn, code: out[tile] = x[tile].sin() x: "i32[s77, s27]" = helion_language__tracing_ops__host_tensor('x') block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0') block_size_1: "Sym(u1)" = helion_language__tracing_ops__get_symnode('block_size_1') diff --git a/test/test_type_propagation.py b/test/test_type_propagation.py index 564ff1ee..45a89c26 100644 --- a/test/test_type_propagation.py +++ 
b/test/test_type_propagation.py @@ -7,6 +7,7 @@ import torch import helion +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import import_path import helion.language as hl @@ -23,7 +24,7 @@ def type_propagation_report(fn: Kernel, *args, ignore=False): return fn.bind(args)._debug_str() -class TestTypePropagation(TestCase): +class TestTypePropagation(RefEagerTestDisabled, TestCase): def test_add(self): output = type_propagation_report( basic_kernels.add, diff --git a/test/test_unroll_tuples.py b/test/test_unroll_tuples.py index 0869df72..82115666 100644 --- a/test/test_unroll_tuples.py +++ b/test/test_unroll_tuples.py @@ -6,6 +6,7 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl @@ -209,7 +210,7 @@ def kernel_enumerate_constants( return result -class TestUnrollTuples(TestCase): +class TestUnrollTuples(RefEagerTestDisabled, TestCase): def test_basic_tuple_addition(self): """Test basic iteration over tuple of tensors with addition.""" size = (32,) diff --git a/test/test_views.py b/test/test_views.py index 98573b47..79c634d2 100644 --- a/test/test_views.py +++ b/test/test_views.py @@ -6,12 +6,13 @@ import helion from helion._testing import DEVICE +from helion._testing import RefEagerTestDisabled from helion._testing import TestCase from helion._testing import code_and_output import helion.language as hl -class TestViews(TestCase): +class TestViews(RefEagerTestDisabled, TestCase): def test_softmax_unsqueeze(self): @helion.kernel(config={"block_size": 1}) def softmax(x: torch.Tensor) -> torch.Tensor: