From 1cee7d4c3f54849664a1f2ffde92d61c36487562 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Sat, 26 Jul 2025 12:21:41 -0700 Subject: [PATCH 1/3] Add hl.dot() API; Use hl.dot instead of torch.matmul for FP8 GEMM ops in Helion kernel stack-info: PR: https://github.com/pytorch-labs/helion/pull/356, branch: yf225/stack/39 --- examples/fp8_attention.py | 74 +- examples/fp8_gemm.py | 17 +- helion/_compiler/indexing_strategy.py | 9 +- helion/_compiler/inductor_lowering.py | 14 +- helion/language/__init__.py | 1 + helion/language/matmul_ops.py | 211 ++++ test/test_dot.expected | 1286 +++++++++++++++++++++++++ test/test_dot.py | 189 ++++ test/test_examples.expected | 55 +- 9 files changed, 1787 insertions(+), 69 deletions(-) create mode 100644 helion/language/matmul_ops.py create mode 100644 test/test_dot.expected create mode 100644 test/test_dot.py diff --git a/examples/fp8_attention.py b/examples/fp8_attention.py index f9c5153b..c6bd5c2b 100644 --- a/examples/fp8_attention.py +++ b/examples/fp8_attention.py @@ -23,7 +23,7 @@ def fp8_attention_kernel( # Output tensor with 4D shape in FP8 format out = torch.empty( - [batch, heads, seq_len, head_dim], dtype=torch.float8_e5m2, device=q.device + [batch, heads, seq_len, head_dim], dtype=torch.float8_e4m3fn, device=q.device ) # Scale factor for attention @@ -54,9 +54,7 @@ def fp8_attention_kernel( k_tile_t = k_tile.transpose(0, 1) # [dim, tile_n] # Compute Q @ K^T with FP8 inputs, result in FP32 - qk = torch.matmul(q_tile, k_tile_t).to( - torch.float32 - ) # [tile_m, tile_n] + qk = hl.dot(q_tile, k_tile_t) # [tile_m, tile_n] # Scale QK scores first qk_scaled = qk * sm_scale # [tile_m, tile_n] @@ -90,9 +88,9 @@ def fp8_attention_kernel( p_fp8 = p.to(v.dtype) # Convert to same FP8 type as V # Accumulate attention @ V with FP8 GEMM - v_t = v_tile.transpose(0, 1) # [tile_n, dim] - pv = torch.matmul(p_fp8, v_t).to(torch.float32) # [tile_m, dim] - acc = acc + pv + # v_tile is [dim, tile_n], we need to transpose for P @ V^T + v_t = v_tile.t() # [tile_n, dim] + acc = hl.dot(p_fp8, v_t, acc=acc) # [tile_m, dim] # Update max tracker m_i = m_new @@ -100,7 +98,7 @@ def fp8_attention_kernel( # Final normalization acc = acc / l_i[:, None] # Convert to FP8 before writing to output - out[b, h, tile_m, :] = acc.to(torch.float8_e5m2) + out[b, h, tile_m, :] = acc.to(torch.float8_e4m3fn) return out @@ -108,10 +106,10 @@ def fp8_attention_kernel( def preprocess_fp8_attention_inputs( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - q_fp8 = q.to(torch.float8_e5m2) - k_fp8 = k.to(torch.float8_e5m2) + q_fp8 = q.to(torch.float8_e4m3fn) + k_fp8 = k.to(torch.float8_e4m3fn) v = v.permute(0, 1, 3, 2) - v_fp8 = v.to(torch.float8_e5m2) + v_fp8 = v.to(torch.float8_e4m3fn) batch, heads, seq_len, head_dim = q.shape q_fp8_reshaped = q_fp8.reshape(batch * heads, seq_len, head_dim) k_fp8_reshaped = k_fp8.reshape(batch * heads, seq_len, head_dim) @@ -147,13 +145,25 @@ def _fp8_attention_pytorch_impl( k_i = k_fp8[i] # [seq, dim] - already FP8 v_i = v_fp8[i] # [dim, seq] - pre-transposed, already FP8 - # For Q @ K^T, we need K^T to be column-major - kt_fp8 = k_i.t() # column-major [dim, seq] - - # Q @ K^T - dequantize and use regular matmul since e5m2 not supported by _scaled_mm - q_deq = q_i.to(torch.float32) - kt_deq = kt_fp8.to(torch.float32) - qk = torch.matmul(q_deq, kt_deq) + # For Q @ K^T using torch._scaled_mm + # torch._scaled_mm requires column-major for second operand + # k_i is [seq, dim], we need K^T as [dim, seq] in 
column-major + # Direct conversion: k_i -> contiguous -> transpose view + kt_fp8_col_major = k_i.contiguous().t() # [dim, seq] in column-major + + # Create scale tensors + scale_q = torch.tensor(1.0, device=q_i.device) + scale_k = torch.tensor(1.0, device=k_i.device) + + # Q @ K^T using torch._scaled_mm + qk = torch._scaled_mm( + q_i, + kt_fp8_col_major, + scale_q, + scale_k, + use_fast_accum=False, + out_dtype=torch.float32, + ) # Compute max before scaling qk_max = torch.amax(qk, dim=-1, keepdim=True) @@ -168,16 +178,26 @@ def _fp8_attention_pytorch_impl( # Step 2: Attention @ V using FP8 # P is [seq, seq], V is [dim, seq] # We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim] - p_fp8 = p_norm.to(torch.float8_e5m2) # row-major [seq, seq] + p_fp8 = p_norm.to(torch.float8_e4m3fn) # row-major [seq, seq] # v_i is [dim, seq], already FP8 - vt_fp8 = v_i.t() # column-major [seq, dim] - - # P @ V^T - dequantize and use regular matmul since e5m2 not supported by torch._scaled_mm - p_deq = p_fp8.to(torch.float32) - vt_deq = vt_fp8.to(torch.float32) - out_i = torch.matmul(p_deq, vt_deq) - out_i = out_i.to(torch.float8_e5m2) # convert back to FP8 + # Direct conversion: v_i -> contiguous -> transpose view + vt_fp8_col_major = v_i.contiguous().t() # [seq, dim] in column-major + + # Create scale tensors for P @ V^T + scale_p = torch.tensor(1.0, device=p_fp8.device) + scale_v = torch.tensor(1.0, device=v_i.device) + + # P @ V^T using torch._scaled_mm + out_i = torch._scaled_mm( + p_fp8, + vt_fp8_col_major, + scale_p, + scale_v, + use_fast_accum=False, + out_dtype=torch.float32, + ) + out_i = out_i.to(torch.float8_e4m3fn) # convert back to FP8 to match kernel outputs.append(out_i) @@ -192,7 +212,7 @@ def fp8_attention_pytorch( v: torch.Tensor, # [batch, heads, seq, dim] ) -> Callable[[], torch.Tensor]: """ - Baseline PyTorch implementation of FP8 attention using FP8 e5m2. + Baseline PyTorch implementation of FP8 attention using torch._scaled_mm. """ batch, heads, seq_len, head_dim = q.shape q_fp8, k_fp8, v_fp8 = preprocess_fp8_attention_inputs(q, k, v) diff --git a/examples/fp8_gemm.py b/examples/fp8_gemm.py index 81cc6815..8557ad33 100644 --- a/examples/fp8_gemm.py +++ b/examples/fp8_gemm.py @@ -1,13 +1,21 @@ from __future__ import annotations +import os + import torch import helion from helion._testing import run_example import helion.language as hl +# Override default config to work around Triton tl.dot requirement: +# `AssertionError: Input shapes should have M >= 16, N >= 16 and K >= 32` +config = None +if os.environ.get("HELION_USE_DEFAULT_CONFIG") == "1": + config = helion.Config(block_sizes=[32, 32, 32]) + -@helion.kernel(static_shapes=True) +@helion.kernel(static_shapes=True, config=config) def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """FP8 General Matrix Multiplication (GEMM). 
@@ -37,11 +45,8 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: x_tile = x[tile_m, tile_k] y_tile = y[tile_k, tile_n] - # Use torch.matmul which will be lowered to tl.dot - # When the inputs are FP8, tl.dot handles them natively - # The result needs to be converted to FP32 for accumulation - result = torch.matmul(x_tile, y_tile).to(torch.float32) - acc = acc + result + # Use hl.dot for FP8 GEMM + acc = hl.dot(x_tile, y_tile, acc=acc) out[tile_m, tile_n] = acc.to(torch.float16) return out diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py index 9bcd01ab..0d2d7cc2 100644 --- a/helion/_compiler/indexing_strategy.py +++ b/helion/_compiler/indexing_strategy.py @@ -70,7 +70,14 @@ def codegen_load( extra_mask: ast.AST | None, ) -> ast.AST: indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask) - extra = ", other=0" if indexing.has_mask() else "" + extra = "" + if indexing.has_mask(): + # For FP8 dtypes, use other=0.0 (float literal) instead of other=0 (int literal) + # because Triton cannot cast integer 0 to FP8 types + if fake_tensor.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + extra = ", other=0.0" + else: + extra = ", other=0" name = state.device_function.tensor_arg(fake_tensor).name return expr_from_string( f"tl.load({name} + offset, mask{extra})", diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index fdb17aab..599a0091 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -848,15 +848,19 @@ def reduce_3d_dot( rhs_node = node.args[1] assert isinstance(lhs, ast.AST) assert isinstance(rhs, ast.AST) + assert isinstance(lhs_node, torch.fx.Node) + assert isinstance(rhs_node, torch.fx.Node) - # Check if inputs are FP8 - if so, don't specify input_precision to allow native FP8 computation - lhs_dtype = lhs_node.meta["val"].dtype # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] - rhs_dtype = rhs_node.meta["val"].dtype # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] + # Check if inputs are FP8 - if so, redirect user to hl.dot() + lhs_dtype = lhs_node.meta["val"].dtype + rhs_dtype = rhs_node.meta["val"].dtype if lhs_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and rhs_dtype in [ torch.float8_e4m3fn, torch.float8_e5m2, ]: - datatype = None # Let Triton use native FP8 computation + raise NotImplementedError( + "FP8 GEMM via torch API is not supported yet. Please use hl.dot() instead." 
+ ) lhs_size = lhs_node.meta["val"].size() # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] rhs_size = rhs_node.meta["val"].size() # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] @@ -1138,7 +1142,7 @@ def proxy_arg(self, i: int) -> object: def ast_arg(self, i: int) -> ast.AST: rv = self.ast_args[i] - if isinstance(rv, int | float | bool): + if isinstance(rv, int | float | bool | None): rv = ast.Constant(value=rv) assert isinstance(rv, ast.AST), "TODO: convert nested/defaults" return rv diff --git a/helion/language/__init__.py b/helion/language/__init__.py index 245b9ab6..3b8c5946 100644 --- a/helion/language/__init__.py +++ b/helion/language/__init__.py @@ -10,6 +10,7 @@ from .loops import grid as grid from .loops import static_range as static_range from .loops import tile as tile +from .matmul_ops import dot as dot from .memory_ops import atomic_add as atomic_add from .memory_ops import load as load from .memory_ops import store as store diff --git a/helion/language/matmul_ops.py b/helion/language/matmul_ops.py new file mode 100644 index 00000000..25e827d0 --- /dev/null +++ b/helion/language/matmul_ops.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import ast +from typing import TYPE_CHECKING + +import torch +from torch._inductor.utils import triton_type +from torch._subclasses.fake_tensor import FakeTensor + +from .. import exc +from . import _decorators + +if TYPE_CHECKING: + from .._compiler.inductor_lowering import CodegenState + + +@_decorators.api(is_device_only=True) +def dot( + mat1: torch.Tensor, + mat2: torch.Tensor, + acc: torch.Tensor | None = None, +) -> torch.Tensor: + """ + Performs a matrix multiplication of tensors with support for multiple dtypes. + + This operation performs matrix multiplication with inputs of various dtypes including + float16, bfloat16, float32, int8, and FP8 formats (e4m3fn, e5m2). The computation is + performed with appropriate precision based on the input dtypes. + + Args: + mat1: First matrix (2D or 3D tensor of torch.float16, torch.bfloat16, torch.float32, torch.int8, torch.float8_e4m3fn, or torch.float8_e5m2) + mat2: Second matrix (2D or 3D tensor of torch.float16, torch.bfloat16, torch.float32, torch.int8, torch.float8_e4m3fn, or torch.float8_e5m2) + acc: The accumulator tensor (2D or 3D tensor of torch.float16, torch.float32, or torch.int32). + If not None, the result is added to this tensor. + If None, a new tensor is created with appropriate dtype based on inputs. + + Returns: + Result of matrix multiplication. If acc is provided, returns acc + (mat1 @ mat2). + Otherwise returns (mat1 @ mat2) with promoted dtype. 
+ + Example: + >>> # FP8 example + >>> a = torch.randn(32, 64, device="cuda").to(torch.float8_e4m3fn) + >>> b = torch.randn(64, 128, device="cuda").to(torch.float8_e4m3fn) + >>> c = torch.zeros(32, 128, device="cuda", dtype=torch.float32) + >>> result = hl.dot(a, b, acc=c) # result is c + (a @ b) + + >>> # Float16 example + >>> a = torch.randn(32, 64, device="cuda", dtype=torch.float16) + >>> b = torch.randn(64, 128, device="cuda", dtype=torch.float16) + >>> result = hl.dot(a, b) # result dtype will be torch.float16 + + >>> # Int8 example + >>> a = torch.randint(-128, 127, (32, 64), device="cuda", dtype=torch.int8) + >>> b = torch.randint(-128, 127, (64, 128), device="cuda", dtype=torch.int8) + >>> acc = torch.zeros(32, 128, device="cuda", dtype=torch.int32) + >>> result = hl.dot(a, b, acc=acc) # int8 x int8 -> int32 + """ + raise exc.NotInsideKernel + + +@_decorators.prepare_args(dot) +def _( + mat1: torch.Tensor, + mat2: torch.Tensor, + acc: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + # Define supported dtypes + supported_dtypes = ( + torch.float16, + torch.bfloat16, + torch.float32, + torch.int8, + torch.float8_e4m3fn, + torch.float8_e5m2, + ) + + # Validate input types + if mat1.dtype not in supported_dtypes: + raise TypeError( + f"hl.dot: mat1 must be one of {[str(d) for d in supported_dtypes]}, got {mat1.dtype}" + ) + if mat2.dtype not in supported_dtypes: + raise TypeError( + f"hl.dot: mat2 must be one of {[str(d) for d in supported_dtypes]}, got {mat2.dtype}" + ) + + # Validate shapes for matrix multiplication + if mat1.ndim not in (2, 3): + raise ValueError(f"hl.dot: mat1 must be 2D or 3D tensor, got {mat1.ndim}D") + if mat2.ndim not in (2, 3): + raise ValueError(f"hl.dot: mat2 must be 2D or 3D tensor, got {mat2.ndim}D") + + # Check matrix multiplication compatibility + if mat1.shape[-1] != mat2.shape[-2]: + raise ValueError( + f"hl.dot: incompatible matrix dimensions for multiplication: " + f"{mat1.shape} @ {mat2.shape}" + ) + + # Validate accumulator if provided + if acc is not None: + # Allow int32 accumulator for int8 inputs + valid_acc_dtypes = (torch.float16, torch.float32, torch.int32) + if acc.dtype not in valid_acc_dtypes: + raise TypeError( + f"hl.dot: acc must be one of {[str(d) for d in valid_acc_dtypes]}, got {acc.dtype}" + ) + + # Check int8 inputs require int32 accumulator + if mat1.dtype == torch.int8 or mat2.dtype == torch.int8: + if acc.dtype != torch.int32: + raise TypeError( + f"hl.dot: int8 inputs require int32 accumulator, got {acc.dtype}" + ) + + # Check accumulator shape compatibility + expected_shape = list(mat1.shape) + expected_shape[-1] = mat2.shape[-1] + + if acc.ndim not in (2, 3): + raise ValueError(f"hl.dot: acc must be 2D or 3D tensor, got {acc.ndim}D") + + if list(acc.shape) != expected_shape: + raise ValueError( + f"hl.dot: acc shape {list(acc.shape)} incompatible with result shape {expected_shape}" + ) + + return (mat1, mat2, acc) + + +def _compute_out_dtype( + mat1_dtype: torch.dtype, + mat2_dtype: torch.dtype, + acc_dtype: torch.dtype | None = None, +) -> torch.dtype: + """Compute the output dtype for dot operation.""" + if acc_dtype is not None: + # If accumulator is provided, use its dtype + return acc_dtype + + # When no accumulator is specified: + # For int8 inputs, default to int32 + if mat1_dtype == torch.int8 or mat2_dtype == torch.int8: + return torch.int32 + # For all other inputs (including FP8), default to float32 + return torch.float32 + + +@_decorators.register_fake(dot) +def _( + 
mat1: torch.Tensor, mat2: torch.Tensor, acc: torch.Tensor | None = None +) -> torch.Tensor: + # Matrix multiplication shape computation + result_shape = list(mat1.shape) + result_shape[-1] = mat2.shape[-1] + + if acc is not None: + return acc.new_empty(result_shape) + + # Determine output dtype using the helper function + out_dtype = _compute_out_dtype(mat1.dtype, mat2.dtype) + return torch.empty(result_shape, dtype=out_dtype, device=mat1.device) + + +@_decorators.codegen(dot) +def _(state: CodegenState) -> object: + # Import here to avoid circular imports + from .._compiler.ast_extension import expr_from_string + from .._compiler.compile_environment import CompileEnvironment + + # Get the AST representations of our arguments + lhs_ast = state.ast_arg(0) + rhs_ast = state.ast_arg(1) + acc_ast = state.ast_arg(2) + + # Get the dtypes of the inputs from proxy args + lhs_proxy = state.proxy_args[0] + assert isinstance(lhs_proxy, FakeTensor), "lhs_proxy must be a FakeTensor" + rhs_proxy = state.proxy_args[1] + assert isinstance(rhs_proxy, FakeTensor), "rhs_proxy must be a FakeTensor" + acc_proxy = state.proxy_args[2] if len(state.proxy_args) > 2 else None + + # Access dtype - proxy_args can be FakeTensor objects + lhs_dtype = None + rhs_dtype = None + acc_dtype = None + + # For FakeTensor objects, dtype is directly accessible + lhs_dtype = lhs_proxy.dtype + rhs_dtype = rhs_proxy.dtype + + # Get accumulator dtype if available + if acc_proxy is not None: + assert isinstance(acc_proxy, FakeTensor), "acc_proxy must be a FakeTensor" + acc_dtype = acc_proxy.dtype + + # Check if accumulator is None + is_acc_none = isinstance(acc_ast, ast.Constant) and acc_ast.value is None + + # Determine output dtype using the helper function + out_dtype = _compute_out_dtype( + lhs_dtype, rhs_dtype, None if is_acc_none else acc_dtype + ) + + return expr_from_string( + f"tl.dot(lhs, rhs, acc=acc, input_precision='{CompileEnvironment.current().settings.dot_precision}', out_dtype={triton_type(out_dtype)})", + lhs=lhs_ast, + rhs=rhs_ast, + acc=acc_ast, + ) diff --git a/test/test_dot.expected b/test/test_dot.expected new file mode 100644 index 00000000..4355b2da --- /dev/null +++ b/test/test_dot.expected @@ -0,0 +1,1286 @@ +This file is automatically generated by assertExpectedJournal calls in test_dot.py. +Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. 
+ +--- assertExpectedJournal(TestDot.test_input_bfloat16_acc_None_dynamic_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +import test.test_dot as _source_module + +@triton.jit +def _dot_kernel_no_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + dot = tl.dot(load, load_1, acc=None, input_precision='tf32', out_dtype=tl.float32) + acc = acc_copy_0 + dot + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + if x.dtype == torch.int8: + acc_dtype = torch.int32 + else: + acc_dtype = torch.float32 + out = torch.empty([m, n], dtype=torch.float32, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_no_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_bfloat16_acc_None_static_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +import test.test_dot as _source_module + +@triton.jit +def _dot_kernel_no_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, 
:] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + dot = tl.dot(load, load_1, acc=None, input_precision='tf32', out_dtype=tl.float32) + acc = acc_copy_0 + dot + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + if x.dtype == torch.int8: + acc_dtype = torch.int32 + else: + acc_dtype = torch.float32 + out = torch.empty([m, n], dtype=torch.float32, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_no_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_bfloat16_acc_float32_dynamic_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_bfloat16_acc_float32_static_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def 
_dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float16_acc_None_dynamic_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +import test.test_dot as _source_module + +@triton.jit +def _dot_kernel_no_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + dot = tl.dot(load, load_1, acc=None, input_precision='tf32', out_dtype=tl.float32) + acc = acc_copy_0 + dot + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, 
mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + if x.dtype == torch.int8: + acc_dtype = torch.int32 + else: + acc_dtype = torch.float32 + out = torch.empty([m, n], dtype=torch.float32, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_no_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float16_acc_None_static_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +import test.test_dot as _source_module + +@triton.jit +def _dot_kernel_no_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + dot = tl.dot(load, load_1, acc=None, input_precision='tf32', out_dtype=tl.float32) + acc = acc_copy_0 + dot + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + if x.dtype == torch.int8: + acc_dtype = torch.int32 + else: + acc_dtype = torch.float32 + out = torch.empty([m, n], dtype=torch.float32, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_no_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float16_acc_float16_dynamic_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % 
num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float16) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float16) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float16_acc_float16_static_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float16) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float16) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + 
_launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float16_acc_float32_dynamic_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float16_acc_float32_static_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + 
acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float32_acc_None_dynamic_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +import test.test_dot as _source_module + +@triton.jit +def _dot_kernel_no_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + dot = tl.dot(load, load_1, acc=None, input_precision='tf32', out_dtype=tl.float32) + acc = acc_copy_0 + dot + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + if x.dtype == torch.int8: + acc_dtype = torch.int32 + else: + acc_dtype = torch.float32 + out = torch.empty([m, n], dtype=torch.float32, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_no_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float32_acc_None_static_shape) +from __future__ import annotations + +import torch 
+import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +import test.test_dot as _source_module + +@triton.jit +def _dot_kernel_no_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + dot = tl.dot(load, load_1, acc=None, input_precision='tf32', out_dtype=tl.float32) + acc = acc_copy_0 + dot + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + if x.dtype == torch.int8: + acc_dtype = torch.int32 + else: + acc_dtype = torch.float32 + out = torch.empty([m, n], dtype=torch.float32, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_no_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float32_acc_float32_dynamic_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & 
mask_1[None, :], other=0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float32_acc_float32_static_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float8_e4m3fn_acc_None_dynamic_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +import test.test_dot as _source_module + +@triton.jit +def _dot_kernel_no_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, 
_BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0.0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0.0) + dot = tl.dot(load, load_1, acc=None, input_precision='tf32', out_dtype=tl.float32) + acc = acc_copy_0 + dot + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + if x.dtype == torch.int8: + acc_dtype = torch.int32 + else: + acc_dtype = torch.float32 + out = torch.empty([m, n], dtype=torch.float32, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_no_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float8_e4m3fn_acc_None_static_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +import test.test_dot as _source_module + +@triton.jit +def _dot_kernel_no_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0.0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0.0) + dot = tl.dot(load, load_1, acc=None, input_precision='tf32', out_dtype=tl.float32) + acc = acc_copy_0 + dot + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + 
m, k = x.size() + _, n = y.size() + if x.dtype == torch.int8: + acc_dtype = torch.int32 + else: + acc_dtype = torch.float32 + out = torch.empty([m, n], dtype=torch.float32, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_no_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float8_e4m3fn_acc_float16_dynamic_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float16) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0.0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0.0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float16) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float8_e4m3fn_acc_float16_static_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, 
_BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float16) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0.0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0.0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float16) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float8_e4m3fn_acc_float32_dynamic_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0.0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0.0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, 
num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float8_e4m3fn_acc_float32_static_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0.0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0.0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float8_e5m2_acc_None_dynamic_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +import test.test_dot as _source_module + +@triton.jit +def _dot_kernel_no_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0.0) + load_1 = tl.load(y + 
(indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0.0) + dot = tl.dot(load, load_1, acc=None, input_precision='tf32', out_dtype=tl.float32) + acc = acc_copy_0 + dot + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + if x.dtype == torch.int8: + acc_dtype = torch.int32 + else: + acc_dtype = torch.float32 + out = torch.empty([m, n], dtype=torch.float32, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_no_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float8_e5m2_acc_None_static_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +import test.test_dot as _source_module + +@triton.jit +def _dot_kernel_no_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0.0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0.0) + dot = tl.dot(load, load_1, acc=None, input_precision='tf32', out_dtype=tl.float32) + acc = acc_copy_0 + dot + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + if x.dtype == torch.int8: + acc_dtype = torch.int32 + else: + acc_dtype = torch.float32 + out = torch.empty([m, n], dtype=torch.float32, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_no_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float8_e5m2_acc_float16_dynamic_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import 
default_launcher as _default_launcher + +@triton.jit +def _dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float16) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0.0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0.0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float16) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float8_e5m2_acc_float16_static_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float16) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0.0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0.0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float16) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * 
out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float8_e5m2_acc_float32_dynamic_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0.0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0.0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_float8_e5m2_acc_float32_static_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + 
offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0.0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0.0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_int8_acc_None_dynamic_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +import test.test_dot as _source_module + +@triton.jit +def _dot_kernel_no_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0, tl.int32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + dot = tl.dot(load, load_1, acc=None, input_precision='tf32', out_dtype=tl.int32) + acc = acc_copy_0 + dot + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + if x.dtype == torch.int8: + acc_dtype = torch.int32 + else: + acc_dtype = torch.float32 + out = torch.empty([m, n], dtype=torch.int32, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_no_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, 
_BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_int8_acc_None_static_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +import test.test_dot as _source_module + +@triton.jit +def _dot_kernel_no_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0, tl.int32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + dot = tl.dot(load, load_1, acc=None, input_precision='tf32', out_dtype=tl.int32) + acc = acc_copy_0 + dot + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + if x.dtype == torch.int8: + acc_dtype = torch.int32 + else: + acc_dtype = torch.float32 + out = torch.empty([m, n], dtype=torch.int32, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_no_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_int8_acc_int32_dynamic_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0, tl.int32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, 
_BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.int32) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestDot.test_input_int8_acc_int32_static_shape) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _dot_kernel_acc_arg_kernel(x, y, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, y_stride_0, y_stride_1, m, n, k, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(m, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < n + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0, tl.int32) + for offset_2 in tl.range(0, k.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < k + acc_copy = acc + acc_copy_0 = acc_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + load_1 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.int32) + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def dot_kernel_acc_arg(x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype, *, _launcher=_default_launcher): + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + _BLOCK_SIZE_0 = 32 + _BLOCK_SIZE_1 = 32 + _BLOCK_SIZE_2 = 32 + _launcher(_dot_kernel_acc_arg_kernel, (triton.cdiv(m, _BLOCK_SIZE_0) * triton.cdiv(n, _BLOCK_SIZE_1),), x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), m, n, k, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return out diff --git a/test/test_dot.py b/test/test_dot.py new file mode 100644 index 00000000..e152155a --- /dev/null +++ b/test/test_dot.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +import itertools +import unittest + 
+import torch +import triton + +import helion +from helion._testing import DEVICE +from helion._testing import TestCase +from helion._testing import code_and_output +import helion.language as hl + + +@helion.kernel(config=helion.Config(block_sizes=[32, 32, 32]), dot_precision="tf32") +def dot_kernel_acc_arg( + x: torch.Tensor, y: torch.Tensor, acc_dtype: torch.dtype +) -> torch.Tensor: + m, k = x.size() + _, n = y.size() + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) + for tile_k in hl.tile(k): + acc = hl.dot(x[tile_m, tile_k], y[tile_k, tile_n], acc=acc) + out[tile_m, tile_n] = acc + return out + + +@helion.kernel(config=helion.Config(block_sizes=[32, 32, 32]), dot_precision="tf32") +def dot_kernel_no_acc_arg(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + m, k = x.size() + _, n = y.size() + if x.dtype == torch.int8: + acc_dtype = torch.int32 + else: + acc_dtype = torch.float32 + out = torch.empty([m, n], dtype=acc_dtype, device=x.device) + + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=acc_dtype) + for tile_k in hl.tile(k): + acc += hl.dot(x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = acc + return out + + +# Define test parameters +INPUT_DTYPES = [ + torch.float16, + torch.bfloat16, + torch.float32, + torch.int8, + torch.float8_e4m3fn, + torch.float8_e5m2, +] +ACC_DTYPES = [None, torch.float16, torch.float32, torch.int32] +STATIC_SHAPES_OPTIONS = [True, False] + +# Define expected failures +EXPECTED_FAILURES = { + # int8 requires int32 accumulator + (torch.int8, torch.int8, torch.float16), + (torch.int8, torch.int8, torch.float32), + # float16 accumulator only supported with float16 or fp8 inputs (Triton constraint) + (torch.float32, torch.float32, torch.float16), + (torch.bfloat16, torch.bfloat16, torch.float16), + # int32 accumulator only supported for int8 inputs + (torch.float16, torch.float16, torch.int32), + (torch.float32, torch.float32, torch.int32), + (torch.bfloat16, torch.bfloat16, torch.int32), + (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.int32), + (torch.float8_e5m2, torch.float8_e5m2, torch.int32), +} + + +def make_test_function(input_dtype, acc_dtype, static_shapes_option): + """Create a test function for a specific combination of parameters.""" + combo = (input_dtype, input_dtype, acc_dtype) + + def test_impl(self): + # Skip FP8 tests if GPU doesn't support it + if ( + input_dtype in (torch.float8_e4m3fn, torch.float8_e5m2) + and torch.cuda.get_device_capability(0)[0] < 9 + ): + self.skipTest(f"FP8 dtype {input_dtype} not supported on this GPU") + + # Create test tensors + if input_dtype == torch.int8: + x = torch.randint(-10, 10, (64, 64), device=DEVICE, dtype=input_dtype) + y = torch.randint(-10, 10, (64, 64), device=DEVICE, dtype=input_dtype) + elif input_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + x = torch.randn(64, 64, device=DEVICE, dtype=torch.float32) * 0.5 + y = torch.randn(64, 64, device=DEVICE, dtype=torch.float32) * 0.5 + x = x.to(input_dtype) + y = y.to(input_dtype) + else: + x = torch.randn(64, 64, device=DEVICE, dtype=input_dtype) + y = torch.randn(64, 64, device=DEVICE, dtype=input_dtype) + + def run_kernel(): + if acc_dtype is None: + dot_kernel_no_acc_arg._static_shapes = static_shapes_option + return code_and_output(dot_kernel_no_acc_arg, (x, y)) + dot_kernel_acc_arg._static_shapes = static_shapes_option + return code_and_output(dot_kernel_acc_arg, (x, y, acc_dtype)) + + # 
Check if this combination should fail + if combo in EXPECTED_FAILURES: + # Use assertRaises for expected failures + with self.assertRaises( + ( + triton.compiler.errors.CompilationError, + RuntimeError, + helion.exc.InternalError, + ValueError, + OSError, + ) + ): + code, result = run_kernel() + return + + # Normal test execution for non-failing cases + code, result = run_kernel() + + # Compute expected result based on accumulator dtype + if input_dtype == torch.int8: + expected = (x.cpu().to(torch.int32) @ y.cpu().to(torch.int32)).to(DEVICE) + else: + # For floating point, compute in float32 for accuracy + x_f32 = x.to(torch.float32) + y_f32 = y.to(torch.float32) + expected = x_f32 @ y_f32 + + # Convert expected to match kernel output dtype + if acc_dtype == torch.float16: + expected = expected.to(torch.float16) + elif acc_dtype == torch.int32: + expected = expected.to(torch.int32) + # else: already float32 for acc_f32 or implicit float32 acc + + # Check result with appropriate tolerance + if input_dtype == torch.int8: + torch.testing.assert_close(result, expected, atol=0, rtol=0) + elif input_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + # FP8 has lower precision, use higher tolerance + torch.testing.assert_close(result, expected, atol=5e-3, rtol=0.5) + elif input_dtype == torch.float16 and acc_dtype == torch.float16: + # Use higher tolerance when accumulator is float16 due to precision limits + torch.testing.assert_close(result, expected, atol=5e-3, rtol=0.5) + elif input_dtype == torch.float32: + # Use higher tolerance for TF32 mode + torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-1) + else: + torch.testing.assert_close(result, expected) + + # Verify generated code matches expected + self.assertExpectedJournal(code) + + return test_impl + + +class TestDot(TestCase): + pass + + +# Dynamically generate test methods +for input_dtype, acc_dtype, static_shapes_option in itertools.product( + INPUT_DTYPES, ACC_DTYPES, STATIC_SHAPES_OPTIONS +): + # Create test method name + input_dtype_name = str(input_dtype).split(".")[-1] + acc_dtype_name = "None" if acc_dtype is None else str(acc_dtype).split(".")[-1] + static_shapes_name = "static_shape" if static_shapes_option else "dynamic_shape" + test_name = ( + f"test_input_{input_dtype_name}_acc_{acc_dtype_name}_{static_shapes_name}" + ) + + # Create and add the test method + _test_func = make_test_function(input_dtype, acc_dtype, static_shapes_option) + _test_func.__name__ = test_name + setattr(TestDot, test_name, _test_func) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_examples.expected b/test/test_examples.expected index 9e5b39c0..68e922c9 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -608,41 +608,38 @@ def _fp8_attention_kernel_kernel(q, k, v, out, out_stride_0, heads, _RDIM_SIZE_2 acc_copy_0 = acc_copy k_tile = tl.load(k + (offset_0 * 16384 + indices_2[:, None] * 64 + indices_5[None, :] * 1), None) k_tile_t = tl.permute(k_tile, [1, 0]) - mm = tl.dot(q_tile_copy_0, k_tile_t) - v_0 = mm.to(tl.float32) - v_1 = 0.18033688 - v_2 = v_0 * v_1 - qk_max = tl.max(v_2, 1) - v_3 = triton_helpers.maximum(m_i_copy_0, qk_max) - subscript = v_3[:, None] - v_4 = v_2 - subscript - v_5 = libdevice.exp2(v_4) - l_ij = tl.sum(v_5, 1) - v_6 = m_i_copy_0 - v_3 - v_7 = libdevice.exp2(v_6) - v_8 = l_i_copy_0 * v_7 - l_i = v_8 + l_ij - subscript_1 = v_7[:, None] - v_10 = acc_copy_0 * subscript_1 + qk = tl.dot(q_tile_copy_0, k_tile_t, acc=None, input_precision='tf32', out_dtype=tl.float32) 
+ v_0 = 0.18033688 + v_1 = qk * v_0 + qk_max = tl.max(v_1, 1) + v_2 = triton_helpers.maximum(m_i_copy_0, qk_max) + subscript = v_2[:, None] + v_3 = v_1 - subscript + v_4 = libdevice.exp2(v_3) + l_ij = tl.sum(v_4, 1) + v_5 = m_i_copy_0 - v_2 + v_6 = libdevice.exp2(v_5) + v_7 = l_i_copy_0 * v_6 + l_i = v_7 + l_ij + subscript_1 = v_6[:, None] + v_9 = acc_copy_0 * subscript_1 v_tile = tl.load(v + (offset_0 * 16384 + indices_5[:, None] * 1 + indices_2[None, :] * 64), None) - v_11 = v_5.to(tl.float8e5) + v_10 = v_4.to(tl.float8e4nv) v_t = tl.permute(v_tile, [1, 0]) - mm_1 = tl.dot(v_11, v_t) - v_12 = mm_1.to(tl.float32) - acc = v_10 + v_12 - m_i = v_3 + acc = tl.dot(v_10, v_t, acc=v_9, input_precision='tf32', out_dtype=tl.float32) + m_i = v_2 subscript_2 = l_i[:, None] - v_14 = acc / subscript_2 - v_15 = v_14.to(tl.float8e5) + v_11 = acc / subscript_2 + v_12 = v_11.to(tl.float8e4nv) symnode_0 = triton_helpers.div_floor_integer(offset_0, heads) symnode_1 = triton_helpers.remainder_integer(offset_0, heads) - tl.store(out + (symnode_0 * out_stride_0 + symnode_1 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), v_15, None) + tl.store(out + (symnode_0 * out_stride_0 + symnode_1 * 16384 + indices_4[:, None] * 64 + indices_5[None, :] * 1), v_12, None) def fp8_attention_kernel(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, batch: int, heads: int, *, _launcher=_default_launcher): batch_heads = q.size(0) seq_len = q.size(1) head_dim = q.size(2) - out = torch.empty([batch, heads, seq_len, head_dim], dtype=torch.float8_e5m2, device=q.device) + out = torch.empty([batch, heads, seq_len, head_dim], dtype=torch.float8_e4m3fn, device=q.device) sm_scale = 1.0 / math.sqrt(float(head_dim)) sm_scale = sm_scale * 1.44269504 _RDIM_SIZE_2 = 64 @@ -675,11 +672,9 @@ def _fp8_gemm_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.c acc_copy_0 = acc_copy x_tile = tl.load(x + (indices_0[:, None] * 256 + indices_2[None, :] * 1), None) y_tile = tl.load(y + (indices_2[:, None] * 256 + indices_1[None, :] * 1), None) - mm = tl.dot(x_tile, y_tile) - v_0 = mm.to(tl.float32) - acc = acc_copy_0 + v_0 - v_2 = acc.to(tl.float16) - tl.store(out + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_2, None) + acc = tl.dot(x_tile, y_tile, acc=acc_copy_0, input_precision='tf32', out_dtype=tl.float32) + v_0 = acc.to(tl.float16) + tl.store(out + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_0, None) def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): """FP8 General Matrix Multiplication (GEMM). From 06840864460df0300a8685fe4508bb11a7e1a7b6 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 18 Jul 2025 21:36:25 -0700 Subject: [PATCH 2/3] [Ref Mode] PyTorch reference mode (eager only) Part of https://github.com/pytorch-labs/helion/issues/77. Please see inline code comments on the PR. 
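For reviewers, here is a minimal sketch of the intended behavior. When a kernel's
Settings has ref_mode set to anything other than RefMode.OFF, each hl.* op in the
kernel body dispatches to the @_decorators.ref implementation registered in this
patch (hl.tile yields plain Python slices, hl.load/hl.store become ordinary tensor
indexing, hl.dot falls back to torch.matmul / torch._scaled_mm), so the kernel runs
as eager PyTorch with no Triton compilation. The RefMode.EAGER name and the
ref_mode= kernel kwarg below are illustrative assumptions rather than something this
diff pins down; only settings.ref_mode and RefMode.OFF appear in the code.

import torch

import helion
import helion.language as hl

# Hypothetical enablement spelling, for illustration only.
@helion.kernel(ref_mode=helion.RefMode.EAGER)
def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        # In ref mode hl.tile yields Python slices, so this line is plain
        # eager PyTorch indexing -- no Triton codegen or kernel launch.
        out[tile] = x[tile] + 1.0
    return out

y = add_one(torch.randn(1024, device="cuda"))

test/test_ref_eager.py below exercises this path; with ref mode on, code_and_output
simply calls the function eagerly and returns the original kernel source instead of
generated Triton code.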
stack-info: PR: https://github.com/pytorch-labs/helion/pull/339, branch: yf225/stack/34 --- helion/__init__.py | 2 + helion/_testing.py | 9 ++ helion/language/_decorators.py | 27 +++++ helion/language/constexpr.py | 5 + helion/language/creation_ops.py | 10 ++ helion/language/device_print.py | 5 + helion/language/inline_asm_ops.py | 12 ++ helion/language/loops.py | 139 ++++++++++++++++++++++ helion/language/matmul_ops.py | 50 ++++++++ helion/language/memory_ops.py | 184 ++++++++++++++++++++++++++++++ helion/language/reduce_ops.py | 53 +++++++++ helion/language/scan_ops.py | 129 +++++++++++++++++++++ helion/language/signal_wait.py | 28 +++++ helion/language/tile_ops.py | 42 +++++++ helion/language/tunable_ops.py | 19 +++ helion/language/view_ops.py | 14 +++ helion/runtime/kernel.py | 66 +++++++++-- helion/runtime/ref_mode.py | 152 ++++++++++++++++++++++++ helion/runtime/settings.py | 22 ++++ test/ref_utils.py | 36 ++++++ test/test_ref_eager.py | 171 +++++++++++++++++++++++++++ 21 files changed, 1167 insertions(+), 8 deletions(-) create mode 100644 helion/runtime/ref_mode.py create mode 100644 test/ref_utils.py create mode 100644 test/test_ref_eager.py diff --git a/helion/__init__.py b/helion/__init__.py index 1363eb20..728f983d 100644 --- a/helion/__init__.py +++ b/helion/__init__.py @@ -11,12 +11,14 @@ from .runtime import Kernel from .runtime import kernel from .runtime import kernel as jit # alias +from .runtime.settings import RefMode from .runtime.settings import Settings from .runtime.settings import set_default_settings __all__ = [ "Config", "Kernel", + "RefMode", "Settings", "cdiv", "exc", diff --git a/helion/_testing.py b/helion/_testing.py index c323152d..8315289e 100644 --- a/helion/_testing.py +++ b/helion/_testing.py @@ -45,6 +45,15 @@ def code_and_output( args: tuple[object, ...], **kwargs: object, ) -> tuple[str, object]: + bound = fn.bind(args) + from helion.runtime.settings import RefMode + + if bound.kernel.settings.ref_mode != RefMode.OFF: + result = fn(*args) + # Return the original kernel source code + code = inspect.getsource(fn.fn) + return code, result + if kwargs: config = Config( **kwargs # pyright: ignore[reportArgumentType] diff --git a/helion/language/_decorators.py b/helion/language/_decorators.py index fbddb640..a82c1026 100644 --- a/helion/language/_decorators.py +++ b/helion/language/_decorators.py @@ -79,6 +79,7 @@ class APIFunc(Protocol): _to_device_ir: Callable[..., object] | None _allow_host_tensor: bool _signature: inspect.Signature + _ref_fn: Callable[..., object] | None def __call__(self, *args: object, **kwargs: object) -> object: ... 
@@ -133,6 +134,15 @@ def api( def _impl(fn: _C) -> _C: @functools.wraps(fn) def wrapper(*args: object, **kwargs: object) -> object: + from ..runtime.ref_mode import is_ref_mode_enabled + + if is_ref_mode_enabled() and api._ref_fn is not None: + # In ref mode, use the registered ref implementation + bound = api._signature.bind(*args, **kwargs) + bound.apply_defaults() + flat_args = api._prepare_args(*bound.arguments.values()) + return api._ref_fn(*flat_args) + bound = api._signature.bind(*args, **kwargs) bound.apply_defaults() flat_args = api._prepare_args(*bound.arguments.values()) @@ -187,6 +197,7 @@ def wrapper(*args: object, **kwargs: object) -> object: api._signature = signature or inspect.signature( cast("Callable[..., object]", fn) ) + api._ref_fn = None return wrapper # pyright: ignore[reportReturnType] return _impl @@ -289,6 +300,22 @@ def _impl(to_device_ir_fn: Callable[..., object]) -> Callable[..., Never]: return _impl # pyright: ignore[reportReturnType] +def ref( + original_fn: Callable[..., object], +) -> _NoReturnDecorator[object]: + def _impl(ref_fn: Callable[..., object]) -> Callable[..., Never]: + assert is_api_func(original_fn), ( + f"{ref.__qualname__} can only be used on API functions" + ) + assert original_fn._ref_fn is None, ( + "ref mode implementation can only be registered once per function" + ) + original_fn._ref_fn = ref_fn + return _no_call + + return _impl # pyright: ignore[reportReturnType] + + def _default_type_function( fake_fn: Callable[..., object], tiles_as_sizes: bool ) -> Callable[..., TypeInfo]: diff --git a/helion/language/constexpr.py b/helion/language/constexpr.py index 8b2ae084..e4170fd9 100644 --- a/helion/language/constexpr.py +++ b/helion/language/constexpr.py @@ -95,3 +95,8 @@ def _(state: CodegenState) -> ast.AST: value = value.__int__() assert isinstance(value, int) return expr_from_string(repr(value)) + + +@_decorators.ref(specialize) +def _(value: int | torch.SymInt) -> int: + return int(value) diff --git a/helion/language/creation_ops.py b/helion/language/creation_ops.py index b22dc090..eeb022d3 100644 --- a/helion/language/creation_ops.py +++ b/helion/language/creation_ops.py @@ -144,6 +144,16 @@ def _( return None +@_decorators.ref(full) +def _( + shape: list[int | slice], + value: float, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + processed_shape = [s.stop - s.start if isinstance(s, slice) else s for s in shape] + return torch.full(processed_shape, value, dtype=dtype, device="cuda") + + def arange( *args: int, dtype: torch.dtype | None = None, diff --git a/helion/language/device_print.py b/helion/language/device_print.py index 60170b7d..7747d35c 100644 --- a/helion/language/device_print.py +++ b/helion/language/device_print.py @@ -90,3 +90,8 @@ def _(state: CodegenState) -> None: ) stmt = create(ast.Expr, value=call_expr) state.add_statement(stmt) + + +@_decorators.ref(device_print) +def _(prefix: str, *args: object) -> None: + print(prefix, *args) diff --git a/helion/language/inline_asm_ops.py b/helion/language/inline_asm_ops.py index 936acd0d..b80310d5 100644 --- a/helion/language/inline_asm_ops.py +++ b/helion/language/inline_asm_ops.py @@ -205,3 +205,15 @@ def _(state: CodegenState) -> ast.AST | list[ast.AST]: ] return inline_asm_call + + +@_decorators.ref(inline_asm_elementwise) +def _( + asm: str, + constraints: str, + args: Sequence[torch.Tensor], + dtype: torch.dtype | Sequence[torch.dtype], + is_pure: bool, + pack: int, +) -> torch.Tensor | tuple[torch.Tensor, ...]: + raise 
NotImplementedError("inline_asm_elementwise is not supported in ref mode") diff --git a/helion/language/loops.py b/helion/language/loops.py index 03d874e8..20387bd9 100644 --- a/helion/language/loops.py +++ b/helion/language/loops.py @@ -3,6 +3,7 @@ import ast import builtins import inspect +import itertools from itertools import starmap from typing import TYPE_CHECKING from typing import Iterator @@ -449,6 +450,93 @@ def _(state: CodegenState) -> ast.AST: return _codegen_loop_helper(state) +@_decorators.ref(tile) +def _( + begin_or_end: int | torch.Tensor | list[int | torch.Tensor], + end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None, + block_size: int | torch.Tensor | list[int | torch.Tensor] | None = None, +) -> Iterator[slice | tuple[slice, ...]]: + # Convert tensor values to int + def _to_int(value: int | torch.Tensor | None) -> int | None: + if value is None: + return None + if isinstance(value, torch.Tensor): + return int(value.item()) + return int(value) + + # Step 1: Normalize begin and end values based on the number of arguments + if end_or_none is not None: + # Two positional args: begin_or_end is begin, end_or_none is end + begin = begin_or_end + end = end_or_none + else: + # One positional arg: begin_or_end is end, begin defaults to 0 + end = begin_or_end + # Create begin with same structure as end, but all zeros + if isinstance(end, (list, tuple)): + begin = cast( + "int | torch.Tensor | list[int | torch.Tensor]", [0] * len(end) + ) + else: + begin = 0 + + # Step 2: Convert inputs to lists for uniform handling + def _normalize_to_list( + value: int | torch.Tensor | list[int | torch.Tensor], + ) -> list[int | torch.Tensor]: + if isinstance(value, (list, tuple)): + return list(value) + return [value] + + begin_list = _normalize_to_list(begin) + end_list = _normalize_to_list(end) + + # Convert all values to int + begin_list = [_to_int(b) for b in begin_list] + end_list = [_to_int(e) for e in end_list] + + # Step 3: Determine block_size based on the arguments + if block_size is None: + # Default block_size to end - begin for each dimension + block_size_list = [] + for b, e in zip(begin_list, end_list, strict=False): + assert b is not None and e is not None + block_size_list.append(e - b) + else: + block_size_list = _normalize_to_list(block_size) + processed_block_size_list = [] + for bs, b, e in zip(block_size_list, begin_list, end_list, strict=False): + assert b is not None and e is not None + if bs is not None: + processed_bs = _to_int(bs) + assert processed_bs is not None + processed_block_size_list.append(processed_bs) + else: + processed_block_size_list.append(e - b) + block_size_list = processed_block_size_list + + # Step 4: Yield tile ranges + # Handle single dimension case + if len(begin_list) == 1: + b = begin_list[0] + e = end_list[0] + bs = block_size_list[0] + assert b is not None and e is not None and bs is not None + for i in range(b, e, bs): + yield slice(i, min(i + bs, e)) + else: + # Handle multi-dimensional case + ranges = [] + for b, e, bs in zip(begin_list, end_list, block_size_list, strict=False): + dim_ranges = [] + assert b is not None and e is not None and bs is not None + for i in range(b, e, bs): + dim_ranges.append(slice(i, min(i + bs, e))) + ranges.append(dim_ranges) + + yield from itertools.product(*ranges) + + def _codegen_loop_helper( state: CodegenState, ) -> ast.AST: @@ -637,6 +725,46 @@ def _(state: CodegenState) -> ast.AST: return _codegen_loop_helper(state) +@_decorators.ref(grid) +def _( + begin_or_end: int | 
torch.Tensor | list[int | torch.Tensor], + end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None, + step: object = None, +) -> range | Iterator[tuple[int, ...]]: + # Similar to tile but yields indices instead of slices + if end_or_none is not None: + begin = begin_or_end + end = end_or_none + else: + end = begin_or_end + if isinstance(end, (list, tuple)): + begin = cast( + "int | torch.Tensor | list[int | torch.Tensor]", [0] * len(end) + ) + else: + begin = 0 + + # Convert tensor values to int + def _to_int(value: int | torch.Tensor) -> int: + if isinstance(value, torch.Tensor): + return int(value.item()) + return int(value) + + # Handle single dimension + if not isinstance(begin, (list, tuple)): + begin_int = _to_int(begin) + assert not isinstance(end, (list, tuple)) + end_int = _to_int(end) + return range(begin_int, end_int) + + # Handle multi-dimensional + assert isinstance(end, (list, tuple)) + begin_ints = [_to_int(b) for b in begin] + end_ints = [_to_int(e) for e in end] + ranges = list(itertools.starmap(range, zip(begin_ints, end_ints, strict=False))) + return itertools.product(*ranges) + + @_decorators.device_func_replacement(builtins.zip) @_decorators.api(is_device_only=True, cache_type=True) def _zip_replacement( @@ -898,3 +1026,14 @@ def _( # Return tuple(range(...)) which will trigger existing tuple/list unrolling return tuple(range(begin_val, end_val, step)) + + +@_decorators.ref(static_range) +def _( + begin_or_end: int, + end_or_none: int | None = None, + step: int = 1, +) -> range: + if end_or_none is not None: + return range(begin_or_end, end_or_none, step) + return range(begin_or_end) diff --git a/helion/language/matmul_ops.py b/helion/language/matmul_ops.py index 25e827d0..fa7044a1 100644 --- a/helion/language/matmul_ops.py +++ b/helion/language/matmul_ops.py @@ -209,3 +209,53 @@ def _(state: CodegenState) -> object: rhs=rhs_ast, acc=acc_ast, ) + + +@_decorators.ref(dot) +def _( + mat1: torch.Tensor, + mat2: torch.Tensor, + acc: torch.Tensor | None = None, +) -> torch.Tensor: + """Reference implementation for hl.dot() in ref mode.""" + is_fp8 = mat1.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) or mat2.dtype in ( + torch.float8_e4m3fn, + torch.float8_e5m2, + ) + + if is_fp8: + # Use torch._scaled_mm for FP8 operations + # Ensure column-major for second operand as required by torch._scaled_mm + mat2_t = mat2.T.contiguous().T + scale_a = torch.tensor(1.0, device=mat1.device) + scale_b = torch.tensor(1.0, device=mat2.device) + + # Determine output dtype + if acc is not None: + out_dtype = acc.dtype + else: + out_dtype = torch.float32 # Default for FP8 + + result = torch._scaled_mm( + mat1, + mat2_t, + scale_a, + scale_b, + use_fast_accum=False, + out_dtype=out_dtype, + ) + else: + # For non-FP8 tensors, use regular matmul + result = torch.matmul(mat1, mat2) + + # Handle accumulator + if acc is not None: + # Ensure result has same dtype as accumulator + if result.dtype != acc.dtype: + result = result.to(acc.dtype) + return acc + result + # Return with appropriate dtype based on inputs + out_dtype = _compute_out_dtype(mat1.dtype, mat2.dtype) + if result.dtype != out_dtype: + result = result.to(out_dtype) + return result diff --git a/helion/language/memory_ops.py b/helion/language/memory_ops.py index 3bdb529b..be92c353 100644 --- a/helion/language/memory_ops.py +++ b/helion/language/memory_ops.py @@ -18,6 +18,118 @@ __all__ = ["atomic_add", "load", "store"] +# Helper functions for ref mode implementations +def _normalize_indices(indices: slice | list | 
tuple) -> slice | tuple: + if isinstance(indices, slice): + return slice(indices.start, indices.stop) + if isinstance(indices, (list, tuple)): + return tuple( + slice(idx.start, idx.stop) if isinstance(idx, slice) else idx + for idx in indices + ) + return indices + + +def _combine_masks( + mask1: torch.Tensor | None, mask2: torch.Tensor | None +) -> torch.Tensor | None: + if mask1 is not None and mask2 is not None: + return mask1 & mask2 + return mask1 or mask2 + + +def _apply_mask( + result: torch.Tensor, mask: torch.Tensor | None, other: float | torch.Tensor = 0 +) -> torch.Tensor: + if mask is None: + return result + + # Handle shape mismatch + if result.shape != mask.shape: + if mask.numel() == 0 or result.numel() == 0: + return torch.zeros(mask.shape, dtype=result.dtype, device=result.device) + # Let torch handle broadcasting + + return torch.where(mask, result, other) + + +def _handle_single_tensor_index( + tensor: torch.Tensor, idx_tensor: torch.Tensor, extra_mask: torch.Tensor | None +) -> torch.Tensor: + """Handle indexing with a single tensor index (jagged array case).""" + flat_indices = idx_tensor.flatten() + clamped_indices = torch.clamp(flat_indices, 0, tensor.shape[0] - 1) + + if extra_mask is None: + return tensor[clamped_indices].reshape(idx_tensor.shape) + + # Apply mask to filter valid indices + valid_mask = extra_mask.flatten() + gathered = tensor[clamped_indices] + result = torch.zeros(idx_tensor.shape, dtype=tensor.dtype, device=tensor.device) + result_flat = result.flatten() + result_flat = torch.where(valid_mask, gathered, result_flat) + return result_flat.reshape(idx_tensor.shape) + + +def _handle_mixed_indices( + tensor: torch.Tensor, indices: tuple, extra_mask: torch.Tensor | None +) -> torch.Tensor: + """Handle mixed indexing with slices and tensors.""" + expected_shape = [] + actual_indices = [] + tensor_shape = tensor.shape + + # Build expected output shape and process indices + for i, idx in enumerate(indices): + if isinstance(idx, slice): + # Handle slice indices + shape_size = idx.stop - idx.start + expected_shape.append(shape_size) + actual_indices.append(idx) + elif isinstance(idx, torch.Tensor): + # Handle tensor indices - clamp to valid range + expected_shape.extend(idx.shape) + max_index = tensor_shape[i] - 1 if i < len(tensor_shape) else 0 + clamped_idx = torch.clamp(idx, 0, max_index) + actual_indices.append(clamped_idx) + else: + # Regular integer index + actual_indices.append(idx) + + # Perform indexing with error handling + try: + result = tensor[tuple(actual_indices)] + + # Handle shape mismatch when using extra_mask + if extra_mask is not None and result.shape != tuple(expected_shape): + result = _pad_result_to_expected_shape( + result, expected_shape, tensor.dtype, tensor.device + ) + + return result + except (RuntimeError, IndexError): + # Return zeros if indexing fails (e.g., negative indices) + return torch.zeros(expected_shape, dtype=tensor.dtype, device=tensor.device) + + +def _pad_result_to_expected_shape( + result: torch.Tensor, + expected_shape: list[int], + dtype: torch.dtype, + device: torch.device, +) -> torch.Tensor: + """Pad result tensor with zeros to match expected shape.""" + padded_result = torch.zeros(expected_shape, dtype=dtype, device=device) + + if result.numel() > 0: + # Copy valid data to padded result + slices = [slice(0, s) for s in result.shape] + padded_result[tuple(slices)] = result + + return padded_result + + @has_side_effect @_decorators.api(tiles_as_sizes=True, allow_host_tensor=True) def store( @@ -84,6 +196,22 
@@ def _(state: CodegenState) -> ast.AST: ) +@_decorators.ref(store) +def _( + tensor: torch.Tensor, + indices: list[object], + value: torch.Tensor, + extra_mask: torch.Tensor | None = None, +) -> None: + normalized_indices = _normalize_indices(indices) + + if extra_mask is not None: + current = tensor[normalized_indices] + tensor[normalized_indices] = torch.where(extra_mask, value, current) + else: + tensor[normalized_indices] = value + + @_decorators.api(tiles_as_sizes=True, allow_host_tensor=True) def load( tensor: torch.Tensor, index: list[object], extra_mask: torch.Tensor | None = None @@ -129,6 +257,35 @@ def _(node: torch.fx.Node) -> int: return 0 # loads are always masked to 0 +@_decorators.ref(load) +def _( + tensor: torch.Tensor, + indices: list[object], + extra_mask: torch.Tensor | None = None, +) -> torch.Tensor: + # Combined mask handling is done inside load logic + mask = None # No base mask for ref mode + other = 0 + + assert isinstance(indices, (list, tuple)) + + # Case 1: Single tensor index (jagged indexing) + if len(indices) == 1 and isinstance(indices[0], torch.Tensor): + result = _handle_single_tensor_index(tensor, indices[0], extra_mask) + + # Case 2: Mixed indices containing slices (tiles) + elif any(isinstance(idx, slice) for idx in indices): + result = _handle_mixed_indices(tensor, tuple(indices), extra_mask) + else: + raise exc.InvalidIndexingType( + f"Invalid indices type: {indices}. Expected a list of slices or tensors." + ) + + # Apply mask + combined_mask = _combine_masks(mask, extra_mask) + return _apply_mask(result, combined_mask, other) + + @has_side_effect @_decorators.api(allow_host_tensor=True) def atomic_add( @@ -233,3 +390,30 @@ def _(state: CodegenState) -> ast.AST: mask=indices.mask_expr, sem=sem, ) + + +@_decorators.ref(atomic_add) +def _( + tensor: torch.Tensor, + indices: list[object], + value: torch.Tensor | float, + sem: str = "relaxed", +) -> None: + # Special handling for scatter-add pattern (`tensor[tensor_idx, slice] += value`) + if isinstance(indices, (list, tuple)) and len(indices) == 2: + idx0, idx1 = indices + if isinstance(idx0, torch.Tensor) and isinstance(idx1, slice): + # This is the pattern: output[idxs, tile_f] += segment_vals + start = idx1.start or 0 + stop = idx1.stop or tensor.shape[1] + tensor_view = tensor[:, start:stop] + if isinstance(value, (float, int)): + value_tensor = torch.full_like(tensor_view[0], value) + tensor_view.index_add_(0, idx0, value_tensor) + else: + tensor_view.index_add_(0, idx0, value) + return + + # Default case + normalized_indices = _normalize_indices(indices) + tensor[normalized_indices] += value diff --git a/helion/language/reduce_ops.py b/helion/language/reduce_ops.py index b811330b..eae4ae1a 100644 --- a/helion/language/reduce_ops.py +++ b/helion/language/reduce_ops.py @@ -278,6 +278,46 @@ def _( return output_tensor +@_decorators.ref(reduce) +def _( + combine_fn: CombineFunction, + input_tensor: torch.Tensor | tuple[torch.Tensor, ...], + dim: int | None = None, + other: float | tuple[float, ...] 
= 0, + keep_dims: bool = False, +) -> torch.Tensor | tuple[torch.Tensor, ...]: + # In ref mode, we simulate reduce using PyTorch operations + # combine_fn is a binary function that combines two elements + # For standard reductions, we would use torch operations directly + + # This is a simplified implementation that doesn't actually use combine_fn + # In a real implementation, we would need to manually reduce using the combine function + if isinstance(input_tensor, tuple): + # For tuple inputs, reduce each tensor separately + results = [] + for tensor in input_tensor: + if dim is None: + # Reduce all dimensions - return a scalar + result = tensor.sum() + if keep_dims: + result = result.reshape([1] * len(tensor.shape)) + else: + # Reduce specific dimension + result = tensor.sum(dim=dim, keepdim=keep_dims) + results.append(result) + return tuple(results) + # Single tensor input + if dim is None: + # Reduce all dimensions - return a scalar + result = input_tensor.sum() + if keep_dims: + result = result.reshape([1] * len(input_tensor.shape)) + else: + # Reduce specific dimension + result = input_tensor.sum(dim=dim, keepdim=keep_dims) + return result + + @_decorators.api() def _reduce( combine_graph_id: int, @@ -337,6 +377,19 @@ def _(state: CodegenState) -> ast.AST | list[ast.AST]: return reduce_expr +@_decorators.ref(_reduce) +def _( + combine_graph_id: int, + input_tensor: torch.Tensor | tuple[torch.Tensor, ...], + dim: int | None = None, + keep_dims: bool = False, + is_tuple_input: bool = False, +) -> torch.Tensor | tuple[torch.Tensor, ...]: + # For ref mode, we don't have access to the combine graph + # This should be handled by the higher-level reduce ref implementation + raise NotImplementedError("_reduce should not be called in ref mode") + + def _register_helper_function(state: CodegenState, combine_graph_id: int) -> str: """Register the helper function and return its final name.""" from .._compiler.device_ir import HelperFunctionGraphInfo diff --git a/helion/language/scan_ops.py b/helion/language/scan_ops.py index 44ea10cb..95d9d37b 100644 --- a/helion/language/scan_ops.py +++ b/helion/language/scan_ops.py @@ -3,6 +3,8 @@ import ast import operator from typing import TYPE_CHECKING +from typing import Callable +from typing import Iterator from typing import cast from typing import overload @@ -23,6 +25,108 @@ __all__ = ["associative_scan", "cumprod", "cumsum"] +# Helper functions for ref mode implementations +def _build_indices( + shape: tuple[int, ...], dim: int, idx: int +) -> tuple[slice | int, ...]: + """Build indexing tuple for accessing position idx along dimension dim.""" + indices: list[slice | int] = [slice(None)] * len(shape) + indices[dim] = idx + return tuple(indices) + + +def _iterate_scan_dimension( + scan_size: int, reverse: bool +) -> Iterator[tuple[int, int, bool]]: + """ + Generate iteration indices for scan operation. 
+ + Yields: + Tuple of (iteration_index, actual_index, is_first_element) + """ + for i in range(scan_size): + # Calculate current index based on scan direction + idx = (scan_size - 1 - i) if reverse else i + + # Check if this is the first element in the scan + is_first = (i == 0 and not reverse) or (i == scan_size - 1 and reverse) + + yield i, idx, is_first + + +def _get_prev_index(idx: int, reverse: bool) -> int: + """Get the previous index in the scan sequence.""" + return (idx + 1) if reverse else (idx - 1) + + +def _scan_single_tensor( + combine_fn: Callable, input_tensor: torch.Tensor, dim: int, reverse: bool +) -> torch.Tensor: + """Helper function to perform scan on a single tensor.""" + result = torch.empty_like(input_tensor) + scan_size = input_tensor.shape[dim] + + # Iterate through the dimension to scan + for _i, idx, is_first in _iterate_scan_dimension(scan_size, reverse): + # Build indexing tuple to access elements at position idx along dim + indices = _build_indices(input_tensor.shape, dim, idx) + + if is_first: + # First element: copy input directly + result[indices] = input_tensor[indices] + else: + # Combine with previous accumulated value + prev_idx = _get_prev_index(idx, reverse) + prev_indices = _build_indices(input_tensor.shape, dim, prev_idx) + + # Apply the combine function + result[indices] = combine_fn(result[prev_indices], input_tensor[indices]) + + return result + + +def _scan_tuple_tensors( + combine_fn: Callable, input_tuple: tuple[torch.Tensor, ...], dim: int, reverse: bool +) -> tuple[torch.Tensor, ...]: + """Helper function to perform scan on a tuple of tensors.""" + tensors = list(input_tuple) + scan_size = tensors[0].shape[dim] + + # Initialize result tensors + results = [torch.empty_like(t) for t in tensors] + + # Iterate through the dimension to scan + for _i, idx, is_first in _iterate_scan_dimension(scan_size, reverse): + # Build indexing tuple + indices = _build_indices(tensors[0].shape, dim, idx) + + if is_first: + # First element: copy inputs directly + for j, tensor in enumerate(tensors): + results[j][indices] = tensor[indices] + else: + # Combine with previous accumulated values + prev_idx = _get_prev_index(idx, reverse) + prev_indices = _build_indices(tensors[0].shape, dim, prev_idx) + + # Gather values for combination + current_vals = tuple(t[indices] for t in tensors) + prev_vals = tuple(r[prev_indices] for r in results) + + # Apply combine function with unpacked arguments + combined = combine_fn(*prev_vals, *current_vals) + + # Store results (handle both single and tuple returns) + if isinstance(combined, tuple): + for j, val in enumerate(combined): + results[j][indices] = val + else: + # Single result case + results[0][indices] = combined + + return tuple(results) + + @overload @_decorators.device_func_replacement(higher_order_ops.associative_scan) @_decorators.api(is_device_only=True) @@ -229,6 +333,18 @@ def _( ) +@_decorators.ref(associative_scan) +def _( + combine_fn: Callable, + input_tensor: torch.Tensor | tuple[torch.Tensor, ...], + dim: int, + reverse: bool = False, +) -> torch.Tensor | tuple[torch.Tensor, ...]: + if isinstance(input_tensor, (tuple, list)): + return _scan_tuple_tensors(combine_fn, tuple(input_tensor), dim, reverse) + return _scan_single_tensor(combine_fn, input_tensor, dim, reverse) + + @_decorators.device_func_replacement(torch.cumsum) def cumsum(input_tensor: torch.Tensor, dim: int, reverse: bool = False) -> torch.Tensor: """ @@ -335,6 +451,19 @@ def _(state: CodegenState) -> ast.AST | list[ast.AST]: return scan_expr 
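# Illustrative check (not part of the patch): for the standard combine functions,
# the ref-mode scan helper added above should agree with torch.cumsum / torch.cumprod.
# Assumes direct access to the private helper and a CUDA device, as used elsewhere
# in this ref-mode stack.
import operator
import torch
from helion.language.scan_ops import _scan_single_tensor

x = torch.randn(4, 8, device="cuda")
torch.testing.assert_close(
    _scan_single_tensor(operator.add, x, dim=1, reverse=False),
    torch.cumsum(x, dim=1),
)
# A reverse scan is the flip / forward-scan / flip-back of the same input.
torch.testing.assert_close(
    _scan_single_tensor(operator.mul, x, dim=1, reverse=True),
    torch.cumprod(x.flip(1), dim=1).flip(1),
)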
+@_decorators.ref(_associative_scan) +def _( + combine_graph_id: int, + input_tensor: torch.Tensor | tuple[torch.Tensor, ...], + dim: int, + reverse: bool = False, + is_tuple_input: bool = False, +) -> torch.Tensor | tuple[torch.Tensor, ...]: + # For ref mode, we don't have access to the combine graph + # This should be handled by the higher-level associative_scan ref implementation + raise NotImplementedError("_associative_scan should not be called in ref mode") + + def _get_input_tensor_ast(state: CodegenState, is_tuple_input: bool) -> ast.AST: """Get the input tensor AST, handling tuple inputs specially.""" if not is_tuple_input: diff --git a/helion/language/signal_wait.py b/helion/language/signal_wait.py index 2c474aaa..2af3b239 100644 --- a/helion/language/signal_wait.py +++ b/helion/language/signal_wait.py @@ -157,6 +157,20 @@ def _(state: CodegenState) -> ast.AST: ) +@_decorators.ref(wait) +def _( + signal_pad: torch.Tensor, + index: list[object], + signal: int = 1, + update: int | None = None, + op: str = "ld", + sem: str = "acquire", + scope: str = "gpu", + skip_sync: bool = False, +) -> None: + raise NotImplementedError("wait is not supported in ref mode") + + @has_side_effect @_decorators.api(tiles_as_sizes=True, allow_host_tensor=True) def signal( @@ -295,3 +309,17 @@ def _(state: CodegenState) -> ast.AST: signal=signal_expr, skip_sync=skip_sync_expr, ) + + +@_decorators.ref(signal) +def _( + signal_pad: torch.Tensor, + index: list[object], + signal: int = 1, + wait_for: int | None = None, + op: str = "atomic_xchg", + sem: str = "release", + scope: str = "gpu", + skip_sync: bool = False, +) -> torch.Tensor: + raise NotImplementedError("signal is not supported in ref mode") diff --git a/helion/language/tile_ops.py b/helion/language/tile_ops.py index 6e90b9ac..ba853c8c 100644 --- a/helion/language/tile_ops.py +++ b/helion/language/tile_ops.py @@ -48,6 +48,12 @@ def _(state: CodegenState) -> ast.AST: return expr_from_string(state.codegen.index_var(index)) +@_decorators.ref(tile_index) +def _(tile: slice) -> torch.Tensor: + # Handle different tile representations in ref mode + return torch.arange(tile.start, tile.stop, dtype=torch.int64, device="cuda") + + @_decorators.api(tiles_as_sizes=True) def tile_begin(tile: Tile) -> int: """ @@ -82,6 +88,16 @@ def _(state: CodegenState) -> ast.AST: return expr_from_string(state.codegen.offset_var(index)) +@_decorators.ref(tile_begin) +def _(tile: int | slice) -> int: + # Handle different tile representations in ref mode + if isinstance(tile, slice): + return tile.start + # In ref mode with tiles_as_sizes=True, we lost the begin info + # This is a limitation - we return 0 as we don't know the actual begin + return 0 + + @_decorators.api(tiles_as_sizes=True) def tile_end(tile: Tile) -> int: """ @@ -121,6 +137,16 @@ def _(state: CodegenState) -> ast.AST: return expr_from_string(naive_exp) +@_decorators.ref(tile_end) +def _(tile: int | slice) -> int: + # Handle different tile representations in ref mode + if isinstance(tile, slice): + return tile.stop + # In ref mode with tiles_as_sizes=True, we get the size + # We lost the begin info, so we assume end = size + return tile + + @_decorators.api(tiles_as_sizes=True) def tile_block_size(tile: Tile) -> int: """ @@ -139,6 +165,15 @@ def _(tile: torch.SymInt) -> torch.SymInt: # codegen is handled in _get_symnode() +@_decorators.ref(tile_block_size) +def _(tile: int | slice) -> int: + # Handle different tile representations in ref mode + if isinstance(tile, slice): + return tile.stop - tile.start + # 
In ref mode with tiles_as_sizes=True, the tile IS the size + return tile + + @_decorators.api(tiles_as_sizes=True) def tile_id(tile: Tile) -> int: """ @@ -166,3 +201,10 @@ def _(state: CodegenState) -> ast.AST: else: expr_str = f"{offset} // {block_size}" return expr_from_string(expr_str) + + +@_decorators.ref(tile_id) +def _(tile: int | slice) -> int: + # tile_id is the index of the tile in the grid + # For ref mode we don't have the original block_size, so we return 0 + return 0 diff --git a/helion/language/tunable_ops.py b/helion/language/tunable_ops.py index e6086ae2..0c33f170 100644 --- a/helion/language/tunable_ops.py +++ b/helion/language/tunable_ops.py @@ -104,6 +104,13 @@ def _(state: CodegenState) -> ast.AST: return expr_from_string(constant_repr(block_size)) +@_decorators.ref(register_block_size) +def _(min_or_max: int, max_or_none: int | None = None) -> int: + # In ref mode, block_size represents the full dimension size + # Return a very large value to simulate this behavior + return 2**31 - 1 # Max value for a 32-bit signed integer + + @_decorators.api(is_device_only=False, cache_type=True, tiles_as_sizes=True) def register_reduction_dim( size: int, @@ -158,6 +165,11 @@ def _(state: CodegenState) -> ast.AST: ] +@_decorators.ref(register_reduction_dim) +def _(size: int) -> int: + return size + + @_decorators.api(is_device_only=False) def register_tunable(name: str, fragment: ConfigSpecFragment) -> int: """ @@ -220,3 +232,10 @@ def _register_tunable_codegen(state: CodegenState) -> ast.AST: config_value = state.config[name] assert isinstance(config_value, (int, float, bool)) return expr_from_string(constant_repr(config_value)) + + +@_decorators.ref(register_tunable) +def _(name: str, fragment: ConfigSpecFragment) -> int: + default_value = fragment.default() + assert isinstance(default_value, int) + return default_value diff --git a/helion/language/view_ops.py b/helion/language/view_ops.py index cdbb7cd8..2e0d7d9d 100644 --- a/helion/language/view_ops.py +++ b/helion/language/view_ops.py @@ -2,6 +2,7 @@ import collections from typing import TYPE_CHECKING +from typing import Any import torch @@ -102,6 +103,19 @@ def _(state: CodegenState) -> ast.AST: ) +@_decorators.ref(subscript) +def _(tensor: torch.Tensor, indices: list[object]) -> torch.Tensor: + # Convert indices to proper types for tensor indexing + typed_indices: list[Any] = [] + for idx in indices: + if isinstance(idx, (int, slice, torch.Tensor)) or idx is ...: + typed_indices.append(idx) + else: + # Fallback for other types, try to convert to int + typed_indices.append(int(idx)) # type: ignore[arg-type] + return tensor[tuple(typed_indices)] + + @_decorators.get_masked_value(subscript) def _(node: torch.fx.Node) -> float | bool | None: from .._compiler.node_masking import cached_masked_value diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index 52247cd4..4f666490 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -36,6 +36,8 @@ from .._logging import LazyString from ..language.constexpr import ConstExpr from .config import Config +from .ref_mode import RefModeContext +from .settings import RefMode from .settings import Settings if TYPE_CHECKING: @@ -279,7 +281,11 @@ def reset(self) -> None: class BoundKernel(Generic[_R]): - def __init__(self, kernel: Kernel[_R], args: tuple[object, ...]) -> None: + def __init__( + self, + kernel: Kernel[_R], + args: tuple[object, ...], + ) -> None: """ Initialize a BoundKernel object. 
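# Illustrative sketch (not part of the patch): because the ref implementation of
# hl.register_block_size above returns INT32_MAX, a tiled loop driven by it is
# expected to degenerate to a single tile spanning the whole dimension in ref
# eager mode. The kernel below is hypothetical and assumes a CUDA device.
import torch
import helion
import helion.language as hl

@helion.kernel(ref_mode=helion.RefMode.EAGER)
def copy_whole_dim(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    block = hl.register_block_size(x.shape[0])  # 2**31 - 1 in ref mode
    for tile in hl.tile(x.shape[0], block_size=block):
        # One iteration covering [0, x.shape[0]) when running in ref mode
        out[tile] = x[tile]
    return out

x = torch.randn(256, device="cuda")
torch.testing.assert_close(copy_whole_dim(x), x)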
@@ -295,8 +301,17 @@ def __init__(self, kernel: Kernel[_R], args: tuple[object, ...]) -> None: self._run: Callable[..., _R] | None = None self._config: Config | None = None self._compile_cache: dict[Config, CompiledConfig] = {} + self._ref_func: Callable[..., _R] | None = None + + # If in ref mode, skip all compilation infrastructure + if self.kernel.settings.ref_mode != RefMode.OFF: + self.env = None + self.fake_args = [] # type: ignore[assignment] + self.host_function = None # type: ignore[assignment] + return + self.env = CompileEnvironment(_find_device(args), self.kernel.settings) - with self.env: + with self.env: # pyright: ignore[reportOptionalContextManager] assert len(args) == len(self.kernel.signature.parameters) self.fake_args: list[object] = [] constexpr_args = {} @@ -346,7 +361,7 @@ def config_spec(self) -> ConfigSpec: Returns: ConfigSpec: The configuration specification. """ - return self.env.config_spec + return self.env.config_spec # pyright: ignore[reportOptionalMemberAccess] @property def configs(self) -> list[Config]: @@ -370,10 +385,10 @@ def to_triton_code(self, config: ConfigLike | None = None) -> str: """ if config is None: config = self._require_implicit_config() - with self.env: + with self.env: # pyright: ignore[reportOptionalContextManager] if not isinstance(config, Config): config = Config(**config) # pyright: ignore[reportArgumentType] - self.env.config_spec.normalize(config) + self.env.config_spec.normalize(config) # pyright: ignore[reportOptionalMemberAccess] root = generate_ast(self.host_function, config) return get_needed_imports(root) + unparse(root) @@ -416,7 +431,7 @@ def _debug_str(self) -> str: Returns: str: A string containing debug information about the kernel. """ - with self.env: + with self.env: # pyright: ignore[reportOptionalContextManager] return self.host_function.debug_str() def autotune( @@ -491,6 +506,10 @@ def _specialize_extra(self) -> list[Callable[[Sequence[object]], Hashable]]: Returns: list[Callable[[Sequence[object]], Hashable]]: A list of functions that generate extra specialization keys. """ + if self.kernel.settings.ref_mode != RefMode.OFF: + return [] # No specialization in ref mode + + assert self.env is not None # Should be set in non-ref mode if not self.env.specialized_vars: return [] @@ -511,6 +530,7 @@ def make_extractor(v: Source) -> Callable[[Sequence[object]], Hashable]: n: i for i, n in enumerate(self.kernel.signature.parameters.keys()) } extractors = [] + assert self.env is not None # Should be set in non-ref mode for v in sorted(self.env.specialized_vars, key=lambda v: v.name): source = self.env.shape_env.var_to_sources[v][0] extractors.append(make_extractor(source)) @@ -520,6 +540,8 @@ def _implicit_config(self) -> Config | None: """ Returns a single config that is implicitly used by this kernel, if any. 
""" + if self.kernel.settings.ref_mode != RefMode.OFF: + return None # No config needed in ref mode configs = self.kernel.configs if self._config is not None: return self._config @@ -537,6 +559,23 @@ def _require_implicit_config(self) -> Config: raise RuntimeError("no config provided and no implicit config available") return config + def run_ref(self, *args: object) -> _R: + if self._ref_func is None: + # Use the original function directly without AST transformation + fn = self.kernel.fn + + def ref_wrapper(*args: object) -> _R: # pyright: ignore[reportReturnType] + from .ref_mode import HelionTorchFunctionMode + + with RefModeContext(), HelionTorchFunctionMode(): + # EAGER mode only + return fn(*args) # pyright: ignore[reportReturnType] + + self._ref_func = ref_wrapper + + assert self._ref_func is not None + return self._ref_func(*args) + def __call__(self, *args: object) -> _R: """ Execute the kernel with the given arguments. @@ -547,6 +586,9 @@ def __call__(self, *args: object) -> _R: Returns: _R: The result of the kernel execution. """ + if self.kernel.settings.ref_mode != RefMode.OFF: + return self.run_ref(*args) + if self._run is None: if (config := self._implicit_config()) is not None: self.set_config(config) @@ -622,8 +664,16 @@ def kernel( settings_obj = Settings(**settings) if fn is None: - return functools.partial(kernel, configs=configs, settings=settings_obj) - return Kernel(fn, configs=configs, settings=settings_obj) + return functools.partial( + kernel, + configs=configs, + settings=settings_obj, + ) + return Kernel( + fn, + configs=configs, + settings=settings_obj, + ) def _tensor_key(fn: Kernel, obj: torch.Tensor) -> Hashable: diff --git a/helion/runtime/ref_mode.py b/helion/runtime/ref_mode.py new file mode 100644 index 00000000..6d5ac01f --- /dev/null +++ b/helion/runtime/ref_mode.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING + +import torch +from torch.overrides import BaseTorchFunctionMode + +if TYPE_CHECKING: + from typing_extensions import Self + +_thread_local = threading.local() + + +def is_ref_mode_enabled() -> bool: + """Check if ref mode is currently active.""" + return getattr(_thread_local, "ref_mode_enabled", False) + + +class RefModeContext: + """Context manager to enable ref mode execution.""" + + def __enter__(self) -> Self: + self._old_value = getattr(_thread_local, "ref_mode_enabled", False) + _thread_local.ref_mode_enabled = True + return self + + def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> bool: + _thread_local.ref_mode_enabled = self._old_value + return False + + +class HelionTorchFunctionMode(BaseTorchFunctionMode): + """Torch function mode for Helion ref mode operations.""" + + def __torch_function__( + self, + func: object, + types: list[type[object]], + args: tuple[object, ...] 
= (), + kwargs: dict[str, object] | None = None, + ) -> object: + if kwargs is None: + kwargs = {} + + # Replace torch.addmm with _helion_mixed_addmm + if func == torch.addmm: + # Cast args to expected types + assert len(args) >= 3, "addmm requires at least 3 arguments" + return _helion_mixed_addmm( + args[0], # type: ignore[arg-type] + args[1], # type: ignore[arg-type] + args[2], # type: ignore[arg-type] + *args[3:], # type: ignore[arg-type] + **kwargs, # type: ignore[arg-type] + ) + + # Replace torch.baddbmm with _helion_mixed_baddbmm + if func == torch.baddbmm: + # Cast args to expected types + assert len(args) >= 3, "baddbmm requires at least 3 arguments" + return _helion_mixed_baddbmm( + args[0], # type: ignore[arg-type] + args[1], # type: ignore[arg-type] + args[2], # type: ignore[arg-type] + *args[3:], # type: ignore[arg-type] + **kwargs, # type: ignore[arg-type] + ) + + return super().__torch_function__(func, types, args, kwargs) + + +def _helion_mixed_addmm( + bias: torch.Tensor, + mat1: torch.Tensor, + mat2: torch.Tensor, + *, + beta: float = 1, + alpha: float = 1, +) -> torch.Tensor: + """Mixed precision addmm that handles dtype mismatches.""" + # Ensure both matrices have the same dtype + if mat1.dtype != mat2.dtype: + raise ValueError( + f"Matrix dtypes must match for torch.addmm: mat1.dtype={mat1.dtype}, mat2.dtype={mat2.dtype}" + ) + + # Use torch.mm with out_dtype to perform mixed precision computation + # out_dtype must be the same as bias dtype or fp32 for fp16/bf16 inputs + if ( + mat1.dtype in (torch.float16, torch.bfloat16) and bias.dtype == torch.float32 + ) or mat1.dtype == bias.dtype: + result = torch.mm(mat1, mat2, out_dtype=bias.dtype) + else: + raise ValueError( + f"Unsupported dtype combination for torch.addmm: bias.dtype={bias.dtype}, " + f"mat1.dtype={mat1.dtype}. out_dtype must be the same as bias dtype or " + f"fp32 for fp16/bf16 inputs." + ) + + # Scale the result + if alpha != 1: + result = result * alpha + + # Add the bias term, converting result to bias's dtype if needed + if result.dtype != bias.dtype: + result = result.to(bias.dtype) + + if beta == 0: + return result + return result + (beta * bias) + + +def _helion_mixed_baddbmm( + bias: torch.Tensor, + batch1: torch.Tensor, + batch2: torch.Tensor, + *, + beta: float = 1, + alpha: float = 1, +) -> torch.Tensor: + """Mixed precision baddbmm that handles dtype mismatches.""" + # Ensure both batch matrices have the same dtype + if batch1.dtype != batch2.dtype: + raise ValueError( + f"Batch matrix dtypes must match for torch.baddbmm: batch1.dtype={batch1.dtype}, batch2.dtype={batch2.dtype}" + ) + + # Use torch.bmm with out_dtype to perform mixed precision computation + # out_dtype must be the same as bias dtype or fp32 for fp16/bf16 inputs + if ( + batch1.dtype in (torch.float16, torch.bfloat16) and bias.dtype == torch.float32 + ) or batch1.dtype == bias.dtype: + result = torch.bmm(batch1, batch2, out_dtype=bias.dtype) + else: + raise ValueError( + f"Unsupported dtype combination for torch.baddbmm: bias.dtype={bias.dtype}, " + f"batch1.dtype={batch1.dtype}. out_dtype must be the same as bias dtype or " + f"fp32 for fp16/bf16 inputs." 
+ ) + + # Scale the result + if alpha != 1: + result = result * alpha + + # Add the bias term, converting result to bias's dtype if needed + if result.dtype != bias.dtype: + result = result.to(bias.dtype) + + if beta == 0: + return result + return result + (beta * bias) diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index 339f8cf2..6273b011 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -1,6 +1,7 @@ from __future__ import annotations import dataclasses +import enum import logging import os import sys @@ -25,6 +26,23 @@ class _TLS(Protocol): _tls: _TLS = cast("_TLS", threading.local()) +class RefMode(enum.Enum): + """Reference mode for kernel execution.""" + + OFF = "off" + EAGER = "eager" + + +def _get_ref_mode_from_env() -> RefMode: + """Get reference mode from environment variables.""" + # Check for environment variables + ref_eager = os.environ.get("HELION_REF_EAGER", "").lower() in ("1", "true", "yes") + + if ref_eager: + return RefMode.EAGER + return RefMode.OFF + + def set_default_settings(settings: Settings) -> AbstractContextManager[None, None]: """ Set the default settings for the current thread and return a context manager @@ -72,6 +90,9 @@ class _Settings: allow_warp_specialize: bool = ( os.environ.get("HELION_ALLOW_WARP_SPECIALIZE", "1") == "1" ) + ref_mode: RefMode = dataclasses.field( + default_factory=lambda: _get_ref_mode_from_env() + ) class Settings(_Settings): @@ -92,6 +113,7 @@ class Settings(_Settings): "print_output_code": "If True, print the output code of the kernel to stderr.", "force_autotune": "If True, force autotuning even if a config is provided.", "allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.", + "ref_mode": "Reference mode for kernel execution. Can be RefMode.OFF or RefMode.EAGER.", } assert __slots__.keys() == {field.name for field in dataclasses.fields(_Settings)} diff --git a/test/ref_utils.py b/test/ref_utils.py new file mode 100644 index 00000000..589e51af --- /dev/null +++ b/test/ref_utils.py @@ -0,0 +1,36 @@ +"""Helper utilities for reference mode tests.""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +import helion +from helion._testing import EXAMPLES_DIR +from helion._testing import import_path + +if TYPE_CHECKING: + from helion.runtime.settings import RefMode + + +def clear_kernel_caches_and_set_ref_mode(ref_mode: RefMode) -> None: + """Clear kernel caches and set ref_mode on all kernels in examples.""" + # Get all Python files in the examples directory + example_files = Path(EXAMPLES_DIR).glob("*.py") + + for example_file in example_files: + try: + # Import the module + mod = import_path(example_file) + + # Find all Helion kernels in the module and update their settings + for attr_name in dir(mod): + attr = getattr(mod, attr_name) + if isinstance(attr, helion.Kernel): + # Reset the kernel to clear any cached bound kernels + attr.reset() + # Update the kernel's ref_mode setting + attr.settings.ref_mode = ref_mode + except Exception: + # Skip files that can't be imported or have issues + pass diff --git a/test/test_ref_eager.py b/test/test_ref_eager.py new file mode 100644 index 00000000..45ff3e0d --- /dev/null +++ b/test/test_ref_eager.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +import contextlib +import io +import math +import unittest +from unittest.mock import patch + +import pytest +import torch + +from . 
import test_examples +from .ref_utils import clear_kernel_caches_and_set_ref_mode +import helion +from helion._testing import TestCase +import helion.language as hl + + +class TestExamplesRefEager(test_examples.TestExamples): + """Run all TestExamples tests in reference eager mode.""" + + # NOTE: All tests in TestExamples are run in ref eager mode by default in this test file. + + def assertExpectedJournal(self, value: str) -> None: + """Skip journal assertions in ref mode since we don't generate Triton code.""" + + def setUp(self): + """Set up test environment.""" + super().setUp() + # Clear kernel caches and set ref mode to EAGER + clear_kernel_caches_and_set_ref_mode(helion.RefMode.EAGER) + + def tearDown(self): + """Restore original environment.""" + super().tearDown() + # Clear kernel caches and reset to OFF mode + clear_kernel_caches_and_set_ref_mode(helion.RefMode.OFF) + + def test_add(self): + # Mock the critical functions that should NOT be called in ref eager mode + # These functions are only used in normal Helion mode to generate and compile Triton code + with ( + patch("helion.runtime.kernel.BoundKernel.to_triton_code") as mock_to_triton, + patch("helion.runtime.kernel.BoundKernel.compile_config") as mock_compile, + ): + # Set up the mocks to fail if called + mock_to_triton.side_effect = AssertionError( + "to_triton_code should NOT be called in reference eager mode!" + ) + mock_compile.side_effect = AssertionError( + "compile_config should NOT be called in reference eager mode!" + ) + + # Run the original test + super().test_add() + + # Assert that neither function was called + mock_to_triton.assert_not_called() + mock_compile.assert_not_called() + + @pytest.mark.skip(reason="tile.* API is not supported yet") + def test_concat(self): + super().test_concat() + + @pytest.mark.skip(reason="tile.* API is not supported yet") + def test_concat_block_ptr(self): + super().test_concat_block_ptr() + + @pytest.mark.skip(reason="tile.* API is not supported yet") + def test_cross_entropy(self): + super().test_cross_entropy() + + @pytest.mark.skip(reason="tile.* API is not supported yet") + def test_jagged_dense_add(self): + super().test_jagged_dense_add() + + @pytest.mark.skip(reason="tile.* API is not supported yet") + def test_jagged_mean(self): + super().test_jagged_mean() + + @pytest.mark.skip(reason="tile.* API is not supported yet") + def test_matmul_split_k(self): + super().test_matmul_split_k() + + @pytest.mark.skip(reason="tile.* API is not supported yet") + def test_moe_matmul_ogs(self): + super().test_moe_matmul_ogs() + + @pytest.mark.skip(reason="tile.* API is not supported yet") + def test_segment_reduction(self): + super().test_segment_reduction() + + +class TestRefEagerMisc(TestCase): + def test_print_intermediate_tensor(self): + @helion.kernel(ref_mode=helion.RefMode.EAGER) + def print_intermediate_tensor_kernel( + x: torch.Tensor, y: torch.Tensor + ) -> torch.Tensor: + out = torch.empty_like(x) + m, n = x.shape + for tile_m, tile_n in hl.tile([m, n]): + x_val = x[tile_m, tile_n] + y_val = y[tile_m, tile_n] + sum_val = x_val + y_val + print("x: ", x_val) + print("y: ", y_val) + print("sum: ", sum_val) + out[tile_m, tile_n] = sum_val + return out + + x = torch.ones([2, 2], device="cuda", dtype=torch.float32) * 10.0 + y = torch.ones([2, 2], device="cuda", dtype=torch.float32) * 5.0 + expected = x + y + + # Capture stdout to check print output + captured_output = io.StringIO() + with contextlib.redirect_stdout(captured_output): + result = print_intermediate_tensor_kernel(x, y) 
+ + torch.testing.assert_close(result, expected, atol=1e-6, rtol=1e-6) + + # Check that the print statements produced output + output = captured_output.getvalue() + self.assertIn("x: ", output) + self.assertIn("y: ", output) + self.assertIn("sum: ", output) + self.assertIn("[[10., 10.]", output) # x values + self.assertIn("[[5., 5.]", output) # y values + self.assertIn("[[15., 15.]", output) # sum values + + def test_print_in_invalid_helion_kernel(self): + """Test that print works even in invalid Helion kernels in ref eager mode.""" + + @helion.kernel(ref_mode=helion.RefMode.EAGER) + def incorrect_kernel(x: torch.Tensor) -> torch.Tensor: + m, n = x.shape + for tile_m, tile_n in hl.tile([m, n]): + val = x[tile_m, tile_n] + print("processing tile: ", val) + # `pass` below causes this kernel to be invalid. + # But we show that in ref-eager mode, the `print` statement above still works, + # which is useful for debugging. + pass # noqa: PIE790 + return x + + x = torch.ones([2, 2], device="cuda", dtype=torch.float32) * math.pi + + # Capture stdout to check print output + captured_output = io.StringIO() + with contextlib.redirect_stdout(captured_output): + _ = incorrect_kernel(x) + + # Check that the print statement produced output + output = captured_output.getvalue() + self.assertIn("processing tile: ", output) + self.assertIn("[[3.14", output) # The value printed + + def test_ref_eager_kernel_config(self): + @helion.kernel(ref_mode=helion.RefMode.EAGER) + def kernel(x: torch.Tensor) -> torch.Tensor: + return x + x * 2.0 + + x = torch.randn(128, device="cuda") + result = kernel(x) + expected = x + x * 2.0 + torch.testing.assert_close(result, expected) + + +if __name__ == "__main__": + unittest.main() From 9ba77c19b53e40d9eaaf73be6a6d595aff479c76 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 24 Jul 2025 22:12:05 -0700 Subject: [PATCH 3/3] [Ref Mode] Make Tile apis work in ref eager mode stack-info: PR: https://github.com/pytorch-labs/helion/pull/378, branch: yf225/stack/40 --- helion/language/creation_ops.py | 12 ++- helion/language/loops.py | 7 +- helion/language/memory_ops.py | 41 ++++++++- helion/language/tile_ops.py | 27 +++++- helion/language/tile_proxy.py | 144 ++++++++++++++++++++++++++++++++ helion/language/view_ops.py | 8 +- helion/runtime/ref_mode.py | 23 ++--- test/ref_utils.py | 36 -------- test/test_ref_eager.py | 64 +++++++------- 9 files changed, 270 insertions(+), 92 deletions(-) delete mode 100644 test/ref_utils.py diff --git a/helion/language/creation_ops.py b/helion/language/creation_ops.py index eeb022d3..df5e2ff7 100644 --- a/helion/language/creation_ops.py +++ b/helion/language/creation_ops.py @@ -150,7 +150,17 @@ def _( value: float, dtype: torch.dtype = torch.float32, ) -> torch.Tensor: - processed_shape = [s.stop - s.start if isinstance(s, slice) else s for s in shape] + from .tile_proxy import RefTile + + processed_shape = [] + for s in shape: + if isinstance(s, RefTile): + # RefTile is a slice subclass with a block_size property + processed_shape.append(s.block_size) + elif isinstance(s, slice): + processed_shape.append(s.stop - s.start) + else: + processed_shape.append(s) return torch.full(processed_shape, value, dtype=dtype, device="cuda") diff --git a/helion/language/loops.py b/helion/language/loops.py index 20387bd9..772423a3 100644 --- a/helion/language/loops.py +++ b/helion/language/loops.py @@ -43,6 +43,7 @@ from ..autotuner.config_spec import RangeWarpSpecializeSpec from ..autotuner.config_spec import StaticRangeSpec from . 
import _decorators +from .tile_proxy import RefTile from .tile_proxy import Tile if TYPE_CHECKING: @@ -455,7 +456,7 @@ def _( begin_or_end: int | torch.Tensor | list[int | torch.Tensor], end_or_none: int | torch.Tensor | list[int | torch.Tensor] | None = None, block_size: int | torch.Tensor | list[int | torch.Tensor] | None = None, -) -> Iterator[slice | tuple[slice, ...]]: +) -> Iterator[RefTile | tuple[RefTile, ...]]: # Convert tensor values to int def _to_int(value: int | torch.Tensor | None) -> int | None: if value is None: @@ -523,7 +524,7 @@ def _normalize_to_list( bs = block_size_list[0] assert b is not None and e is not None and bs is not None for i in range(b, e, bs): - yield slice(i, min(i + bs, e)) + yield RefTile(i, min(i + bs, e)) else: # Handle multi-dimensional case ranges = [] @@ -531,7 +532,7 @@ def _normalize_to_list( dim_ranges = [] assert b is not None and e is not None and bs is not None for i in range(b, e, bs): - dim_ranges.append(slice(i, min(i + bs, e))) + dim_ranges.append(RefTile(i, min(i + bs, e))) ranges.append(dim_ranges) yield from itertools.product(*ranges) diff --git a/helion/language/memory_ops.py b/helion/language/memory_ops.py index be92c353..5aa95299 100644 --- a/helion/language/memory_ops.py +++ b/helion/language/memory_ops.py @@ -84,7 +84,13 @@ def _handle_mixed_indices( for i, idx in enumerate(indices): if isinstance(idx, slice): # Handle slice indices - shape_size = idx.stop - idx.start + if idx.start is None and idx.stop is None: + # Full slice like `:` + shape_size = tensor_shape[i] if i < len(tensor_shape) else 1 + else: + start = idx.start or 0 + stop = idx.stop or (tensor_shape[i] if i < len(tensor_shape) else 1) + shape_size = stop - start expected_shape.append(shape_size) actual_indices.append(idx) elif isinstance(idx, torch.Tensor): @@ -203,6 +209,17 @@ def _( value: torch.Tensor, extra_mask: torch.Tensor | None = None, ) -> None: + # Convert RefTile objects to slices + from .tile_proxy import RefTile + + processed_indices = [] + for idx in indices: + if isinstance(idx, RefTile): + processed_indices.append(idx._slice) + else: + processed_indices.append(idx) + indices = processed_indices + normalized_indices = _normalize_indices(indices) if extra_mask is not None: @@ -269,6 +286,17 @@ def _( assert isinstance(indices, (list, tuple)) + # Convert RefTile objects to slices + from .tile_proxy import RefTile + + processed_indices = [] + for idx in indices: + if isinstance(idx, RefTile): + processed_indices.append(idx._slice) + else: + processed_indices.append(idx) + indices = processed_indices + # Case 1: Single tensor index (jagged indexing) if len(indices) == 1 and isinstance(indices[0], torch.Tensor): result = _handle_single_tensor_index(tensor, indices[0], extra_mask) @@ -399,6 +427,17 @@ def _( value: torch.Tensor | float, sem: str = "relaxed", ) -> None: + # Convert RefTile objects to slices + from .tile_proxy import RefTile + + processed_indices = [] + for idx in indices: + if isinstance(idx, RefTile): + processed_indices.append(idx._slice) + else: + processed_indices.append(idx) + indices = processed_indices + # Special handling for scatter-add pattern (`tensor[tensor_idx, slice] += value`) if isinstance(indices, (list, tuple)) and len(indices) == 2: idx0, idx1 = indices diff --git a/helion/language/tile_ops.py b/helion/language/tile_ops.py index ba853c8c..f8e36cf9 100644 --- a/helion/language/tile_ops.py +++ b/helion/language/tile_ops.py @@ -49,9 +49,16 @@ def _(state: CodegenState) -> ast.AST: @_decorators.ref(tile_index) -def _(tile: 
slice) -> torch.Tensor: +def _(tile: slice | int) -> torch.Tensor: # Handle different tile representations in ref mode - return torch.arange(tile.start, tile.stop, dtype=torch.int64, device="cuda") + from .tile_proxy import RefTile + + if isinstance(tile, RefTile): + return tile.index + if isinstance(tile, slice): + return torch.arange(tile.start, tile.stop, dtype=torch.int64, device="cuda") + # tiles_as_sizes=True means we get an int + return torch.arange(0, tile, dtype=torch.int64, device="cuda") @_decorators.api(tiles_as_sizes=True) @@ -91,6 +98,10 @@ def _(state: CodegenState) -> ast.AST: @_decorators.ref(tile_begin) def _(tile: int | slice) -> int: # Handle different tile representations in ref mode + from .tile_proxy import RefTile + + if isinstance(tile, RefTile): + return tile.begin if isinstance(tile, slice): return tile.start # In ref mode with tiles_as_sizes=True, we lost the begin info @@ -140,6 +151,10 @@ def _(state: CodegenState) -> ast.AST: @_decorators.ref(tile_end) def _(tile: int | slice) -> int: # Handle different tile representations in ref mode + from .tile_proxy import RefTile + + if isinstance(tile, RefTile): + return tile.end if isinstance(tile, slice): return tile.stop # In ref mode with tiles_as_sizes=True, we get the size @@ -168,6 +183,10 @@ def _(tile: torch.SymInt) -> torch.SymInt: @_decorators.ref(tile_block_size) def _(tile: int | slice) -> int: # Handle different tile representations in ref mode + from .tile_proxy import RefTile + + if isinstance(tile, RefTile): + return tile.block_size if isinstance(tile, slice): return tile.stop - tile.start # In ref mode with tiles_as_sizes=True, the tile IS the size @@ -206,5 +225,9 @@ def _(state: CodegenState) -> ast.AST: @_decorators.ref(tile_id) def _(tile: int | slice) -> int: # tile_id is the index of the tile in the grid + from .tile_proxy import RefTile + + if isinstance(tile, RefTile): + return tile.id # For ref mode we don't have the original block_size, so we return 0 return 0 diff --git a/helion/language/tile_proxy.py b/helion/language/tile_proxy.py index 7097a2bf..98320ed2 100644 --- a/helion/language/tile_proxy.py +++ b/helion/language/tile_proxy.py @@ -182,3 +182,147 @@ def __enter__(self) -> Self: def __exit__(self, *args: object) -> None: _tls.index_calls = None + + +class RefTile(torch.Tensor): + """ + A tile-like object used in reference eager mode that behaves like a slice. + This allows tile.index and other tile operations to work properly in ref eager mode. 
+ """ + + def __new__(cls, start: int, stop: int, step: int | None = None) -> Self: + # Create a tensor instance + return super().__new__(cls) + + def __init__(self, start: int, stop: int, step: int | None = None) -> None: + super().__init__() + # Store slice data + self.start = start + self.stop = stop + self.step = step + self._slice = slice(start, stop, step) + # We need to set block_id to something for compatibility + self.block_id = -1 # Special value for ref mode + + @property + def index(self) -> torch.Tensor: + """Return a tensor containing the offsets for this tile.""" + return torch.arange(self.start, self.stop, dtype=torch.int64, device="cuda") + + @property + def begin(self) -> int: + """Return the start offset of this tile.""" + return self.start + + @property + def end(self) -> int: + """Return the end offset of this tile.""" + return self.stop + + @property + def block_size(self) -> int: + """Return the block size of this tile.""" + return self.stop - self.start + + @property + def id(self) -> int: + """Return the id of this tile (always 0 in ref mode).""" + # We don't have enough info to compute the actual tile id + return 0 + + def __repr__(self, *, tensor_contents: object = None) -> str: + """Return string representation of RefTile.""" + # Override torch.Tensor's __repr__ with matching signature + return f"RefTile({self._slice!r})" + + def __int__(self) -> int: + """Convert to int for cases where a size is expected.""" + return self.block_size + + # Make RefTile usable as an index by delegating to the slice + def slice_indices(self, length: int) -> tuple[int, int, int]: + """Return (start, stop, step) tuple, like slice.indices().""" + return self._slice.indices(length) + + def equals(self, other: object) -> bool: + """Compare with other RefTile or slice objects. + + Use this instead of == for RefTile comparison. + """ + if isinstance(other, RefTile): + return self._slice == other._slice + if isinstance(other, slice): + return self._slice == other + return False + + def __hash__(self) -> int: + """Hash based on the slice.""" + return hash(self._slice) + + def __index__(self) -> int: + """Convert to int for use in tensor indexing. + + This is called when RefTile is used in advanced indexing contexts. + We return the start value which works for single-element tiles. + """ + # For single-element access (when block_size=1), return the index + if self.block_size == 1: + return self.start + # For larger tiles, we can't meaningfully convert to a single index + # This might happen in user lambdas trying to do advanced indexing + raise TypeError( + f"Cannot convert RefTile with block_size={self.block_size} to index" + ) + + @classmethod + def __torch_function__( + cls, + func: Callable[..., object], + types: object, + args: tuple[object, ...] 
= (), + kwargs: dict[str, object] | None = None, + ) -> object: + from ..language.memory_ops import load + from ..language.memory_ops import store + + if func is torch.Tensor.__getitem__: + if len(args) != 2 or kwargs: + raise exc.IncorrectTileUsage(func) + tensor, index = args + assert isinstance(tensor, torch.Tensor) + + # If a single RefTile is used as index, we want to use it as a slice + # e.g., tensor[ref_tile] should behave like tensor[ref_tile._slice] + if isinstance(index, RefTile): + return tensor[index._slice] + + # For multi-dimensional indexing (including lists) + return load(tensor, cls._prepare_index(index)) + + if func is torch.Tensor.__setitem__: + if len(args) != 3 or kwargs: + raise exc.IncorrectTileUsage(func) + tensor, index, value = args + assert isinstance(tensor, torch.Tensor) + assert isinstance(value, torch.Tensor) + + # Similar handling for setitem + if isinstance(index, RefTile): + tensor[index._slice] = value + return None + + return store(tensor, cls._prepare_index(index), value) + + if func is torch.Tensor.__format__: + return repr(args[0]) + raise exc.IncorrectTileUsage(func) + + @staticmethod + def _prepare_index(index: object) -> list[object]: + if isinstance(index, (list, tuple)): + # When indexing with a list of RefTiles like bias[[tile_m, tile_n]], + # we want it to be interpreted as bias[tile_m, tile_n] + # So we return the list as-is for multi-dimensional indexing + return [*index] + assert isinstance(index, RefTile) + return [index] diff --git a/helion/language/view_ops.py b/helion/language/view_ops.py index 2e0d7d9d..ed1dc2d0 100644 --- a/helion/language/view_ops.py +++ b/helion/language/view_ops.py @@ -3,6 +3,7 @@ import collections from typing import TYPE_CHECKING from typing import Any +from typing import cast import torch @@ -112,7 +113,12 @@ def _(tensor: torch.Tensor, indices: list[object]) -> torch.Tensor: typed_indices.append(idx) else: # Fallback for other types, try to convert to int - typed_indices.append(int(idx)) # type: ignore[arg-type] + try: + typed_indices.append(int(cast("Any", idx))) + except (TypeError, ValueError): + raise exc.InvalidIndexingType( + f"Cannot convert {idx!r} to index" + ) from None return tensor[tuple(typed_indices)] diff --git a/helion/runtime/ref_mode.py b/helion/runtime/ref_mode.py index 6d5ac01f..ba5720c8 100644 --- a/helion/runtime/ref_mode.py +++ b/helion/runtime/ref_mode.py @@ -2,6 +2,7 @@ import threading from typing import TYPE_CHECKING +from typing import cast import torch from torch.overrides import BaseTorchFunctionMode @@ -47,25 +48,19 @@ def __torch_function__( if func == torch.addmm: # Cast args to expected types assert len(args) >= 3, "addmm requires at least 3 arguments" - return _helion_mixed_addmm( - args[0], # type: ignore[arg-type] - args[1], # type: ignore[arg-type] - args[2], # type: ignore[arg-type] - *args[3:], # type: ignore[arg-type] - **kwargs, # type: ignore[arg-type] - ) + bias = cast("torch.Tensor", args[0]) + mat1 = cast("torch.Tensor", args[1]) + mat2 = cast("torch.Tensor", args[2]) + return _helion_mixed_addmm(bias, mat1, mat2, *args[3:], **kwargs) # Replace torch.baddbmm with _helion_mixed_baddbmm if func == torch.baddbmm: # Cast args to expected types assert len(args) >= 3, "baddbmm requires at least 3 arguments" - return _helion_mixed_baddbmm( - args[0], # type: ignore[arg-type] - args[1], # type: ignore[arg-type] - args[2], # type: ignore[arg-type] - *args[3:], # type: ignore[arg-type] - **kwargs, # type: ignore[arg-type] - ) + bias = cast("torch.Tensor", args[0]) + batch1 = 
cast("torch.Tensor", args[1]) + batch2 = cast("torch.Tensor", args[2]) + return _helion_mixed_baddbmm(bias, batch1, batch2, *args[3:], **kwargs) return super().__torch_function__(func, types, args, kwargs) diff --git a/test/ref_utils.py b/test/ref_utils.py deleted file mode 100644 index 589e51af..00000000 --- a/test/ref_utils.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Helper utilities for reference mode tests.""" - -from __future__ import annotations - -from pathlib import Path -from typing import TYPE_CHECKING - -import helion -from helion._testing import EXAMPLES_DIR -from helion._testing import import_path - -if TYPE_CHECKING: - from helion.runtime.settings import RefMode - - -def clear_kernel_caches_and_set_ref_mode(ref_mode: RefMode) -> None: - """Clear kernel caches and set ref_mode on all kernels in examples.""" - # Get all Python files in the examples directory - example_files = Path(EXAMPLES_DIR).glob("*.py") - - for example_file in example_files: - try: - # Import the module - mod = import_path(example_file) - - # Find all Helion kernels in the module and update their settings - for attr_name in dir(mod): - attr = getattr(mod, attr_name) - if isinstance(attr, helion.Kernel): - # Reset the kernel to clear any cached bound kernels - attr.reset() - # Update the kernel's ref_mode setting - attr.settings.ref_mode = ref_mode - except Exception: - # Skip files that can't be imported or have issues - pass diff --git a/test/test_ref_eager.py b/test/test_ref_eager.py index 45ff3e0d..fc338897 100644 --- a/test/test_ref_eager.py +++ b/test/test_ref_eager.py @@ -3,18 +3,46 @@ import contextlib import io import math +from pathlib import Path +from typing import TYPE_CHECKING import unittest from unittest.mock import patch -import pytest import torch from . import test_examples -from .ref_utils import clear_kernel_caches_and_set_ref_mode import helion +from helion._testing import EXAMPLES_DIR from helion._testing import TestCase +from helion._testing import import_path import helion.language as hl +if TYPE_CHECKING: + from helion.runtime.settings import RefMode + + +def clear_kernel_caches_and_set_ref_mode(ref_mode: RefMode) -> None: + """Clear kernel caches and set ref_mode on all kernels in examples.""" + # Get all Python files in the examples directory + example_files = Path(EXAMPLES_DIR).glob("*.py") + + for example_file in example_files: + try: + # Import the module + mod = import_path(example_file) + + # Find all Helion kernels in the module and update their settings + for attr_name in dir(mod): + attr = getattr(mod, attr_name) + if isinstance(attr, helion.Kernel): + # Reset the kernel to clear any cached bound kernels + attr.reset() + # Update the kernel's ref_mode setting + attr.settings.ref_mode = ref_mode + except Exception: + # Skip files that can't be imported or have issues + pass + class TestExamplesRefEager(test_examples.TestExamples): """Run all TestExamples tests in reference eager mode.""" @@ -58,38 +86,6 @@ def test_add(self): mock_to_triton.assert_not_called() mock_compile.assert_not_called() - @pytest.mark.skip(reason="tile.* API is not supported yet") - def test_concat(self): - super().test_concat() - - @pytest.mark.skip(reason="tile.* API is not supported yet") - def test_concat_block_ptr(self): - super().test_concat_block_ptr() - - @pytest.mark.skip(reason="tile.* API is not supported yet") - def test_cross_entropy(self): - super().test_cross_entropy() - - @pytest.mark.skip(reason="tile.* API is not supported yet") - def test_jagged_dense_add(self): - 
super().test_jagged_dense_add() - - @pytest.mark.skip(reason="tile.* API is not supported yet") - def test_jagged_mean(self): - super().test_jagged_mean() - - @pytest.mark.skip(reason="tile.* API is not supported yet") - def test_matmul_split_k(self): - super().test_matmul_split_k() - - @pytest.mark.skip(reason="tile.* API is not supported yet") - def test_moe_matmul_ogs(self): - super().test_moe_matmul_ogs() - - @pytest.mark.skip(reason="tile.* API is not supported yet") - def test_segment_reduction(self): - super().test_segment_reduction() - class TestRefEagerMisc(TestCase): def test_print_intermediate_tensor(self):
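# Illustrative sketch (not part of the patch): with RefTile standing in for Tile,
# the tile.* accessors exercised by the previously-skipped examples now work in
# ref eager mode. The kernel below is hypothetical; it assumes a CUDA device,
# matching the tests above.
import torch
import helion
import helion.language as hl

@helion.kernel(ref_mode=helion.RefMode.EAGER)
def tile_offsets(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.shape[0]):
        # tile.index is the tensor of offsets covered by this tile;
        # tile.begin / tile.end / tile.block_size behave like plain ints.
        out[tile] = tile.index.to(x.dtype)
    return out

x = torch.zeros(1024, device="cuda")
torch.testing.assert_close(
    tile_offsets(x), torch.arange(1024, device="cuda", dtype=x.dtype)
)
# The same kernel also runs in ref eager mode without the decorator argument
# when HELION_REF_EAGER=1 is set in the environment.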