From 9a3d6c68046818e638860098ec0cd4ec677dc327 Mon Sep 17 00:00:00 2001 From: joydddd Date: Thu, 10 Jul 2025 11:50:42 -0700 Subject: [PATCH] [BC breaking] Add MulticastTensor support to hl.signal & hl.wait (as_ptrs) stack-info: PR: https://github.com/pytorch-labs/helion/pull/261, branch: joydddd/stack/13 --- examples/all_gather_matmul.py | 4 - helion/language/multicast_tensor.py | 7 + helion/language/signal_wait.py | 293 ++++++++++++++++------------ helion/runtime/triton_helpers.py | 8 +- test/test_signal_wait.expected | 81 +++++++- test/test_signal_wait.py | 70 ++++++- 6 files changed, 323 insertions(+), 140 deletions(-) diff --git a/examples/all_gather_matmul.py b/examples/all_gather_matmul.py index e93f28bd..0372fac6 100644 --- a/examples/all_gather_matmul.py +++ b/examples/all_gather_matmul.py @@ -96,10 +96,6 @@ def helion_matmul_w_progress( tile_m.begin // (M_per_rank // SPLITS_PER_RANK), ], signal=1, - update=None, - op="ld", - scope="gpu", - sem="acquire", ) for tile_k in hl.tile(K): # TODO(joydddd): use a_shared and skip barrier when data is available on local rank. diff --git a/helion/language/multicast_tensor.py b/helion/language/multicast_tensor.py index 09101637..d60ada4e 100644 --- a/helion/language/multicast_tensor.py +++ b/helion/language/multicast_tensor.py @@ -9,6 +9,8 @@ from . import _decorators if TYPE_CHECKING: + from typing import Sequence + from .._compiler.type_propagation import TypeInfo from .._compiler.variable_origin import Origin @@ -62,6 +64,11 @@ def __setitem__( # pyright ignore[reportIncompatibleMethodOverride] ) -> None: raise exc.NotInsideKernel + def new_empty( + self, *args: Sequence[int | torch.SymInt], **kwargs: dict + ) -> torch.Tensor: + return self.tensor_like.new_empty(*args, **kwargs) # pyright: ignore[reportCallIssue] + def multicast_like( tensor_like: torch.Tensor, diff --git a/helion/language/signal_wait.py b/helion/language/signal_wait.py index 2c474aaa..926a44e6 100644 --- a/helion/language/signal_wait.py +++ b/helion/language/signal_wait.py @@ -3,11 +3,13 @@ from typing import TYPE_CHECKING import torch +from torch._inductor.utils import triton_type from torch.fx import has_side_effect from .. import exc from .._compiler.indexing_strategy import SubscriptIndexing from . import _decorators +from helion.language.multicast_tensor import MulticastTensor if TYPE_CHECKING: import ast @@ -20,25 +22,21 @@ @has_side_effect @_decorators.api(tiles_as_sizes=True, allow_host_tensor=True) def wait( - signal_pad: torch.Tensor, - index: list[object], + signal_pad: torch.Tensor | MulticastTensor, + index: list[object] | None = None, signal: int = 1, update: int | None = None, - op: str = "ld", - sem: str = "acquire", scope: str = "gpu", - skip_sync: bool = False, + hasSubsequentMemAccess: bool = True, ) -> None: """Wait until all entries of the signal_pad slice are equal to the signal value. Args: - signal_pad: The signal pad tensor to wait on + signal_pad: The signal pad tensor / multicast tensor to wait on index: Indices to index into the signal_pad tensor signal: the value to wait for update: Atomically update the signal_pad tensor with this value once the signal is observed. 
(default: None) - op: The memory op for acquiring the lock (default: 'ld') - sem: The memory semantic for acquiring the lock (default: 'acquire') scope: The scope of the lock (default: 'gpu') - skip_sync: Skip the syncthreads after the wait (default: False) + hasSubsequentMemAccess: Whether the wait is followed by a subsequent memory access (default: True) Returns: None @@ -48,42 +46,23 @@ def wait( @_decorators.prepare_args(wait) def _( - signal_pad: torch.Tensor, + signal_pad: torch.Tensor | MulticastTensor, index: list[object], signal: int = 1, update: int | None = None, - op: str = "ld", - sem: str = "acquire", scope: str = "gpu", - skip_sync: bool = False, -) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]: + hasSubsequentMemAccess: bool = True, +) -> tuple[torch.Tensor | tuple[object, ...], object, int, int | None, str, bool]: from .tile_proxy import Tile - valid_ops = {"ld", "atomic_cas"} - valid_sems = {"relaxed", "acquire", "acq_rel"} - valid_scopes = {"sys", "gpu"} - - if op not in valid_ops: - raise ValueError(f"Invalid Wait op '{op}'. Must be one of {valid_ops}. ") - - if sem == "release": - raise ValueError( - f"Do not use '{sem}' for wait patterns. Wait sem must be one of {valid_sems}." - ) - - if sem not in valid_sems: - raise ValueError( - f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}." - ) + assert isinstance(signal_pad, (torch.Tensor, MulticastTensor)) - if op == "atomic_cas" and update is None: - raise ValueError( - f"{op} without an update value. Do you want to use 'ld' instead? " + if signal_pad.dtype not in (torch.int32, torch.uint32): + raise NotImplementedError( + f"Unsupported signal pad dtype: {signal_pad.dtype}. Must be of torch.int32 or torch.uint32." ) - if op == "ld": - assert update is None - update = 0 + valid_scopes = {"sys", "gpu"} if scope not in valid_scopes: raise ValueError(f"Invalid scope '{scope}'. 
Must be one of {valid_scopes}.") @@ -91,19 +70,20 @@ def _( index = Tile._prepare_index(index) index = Tile._tiles_to_sizes(index) - return (signal_pad, index, signal, update, op, sem, scope, skip_sync) + if isinstance(signal_pad, MulticastTensor): + return (tuple(signal_pad), index, signal, update, scope, hasSubsequentMemAccess) + return (signal_pad, index, signal, update, scope, hasSubsequentMemAccess) @_decorators.register_fake(wait) def _( - signal_pad: torch.Tensor, + signal_pad: torch.Tensor | tuple[object, ...], index: list[object], signal: int = 1, update: int | None = None, - op: str = "ld", - sem: str = "acquire", - scope: str = "sys", - skip_sync: bool = False, + scope: str = "gpu", + hasSubsequentMemAccess: bool = True, + as_ptrs: bool = False, ) -> None: return None @@ -119,39 +99,73 @@ def _(state: CodegenState) -> ast.AST: index = state.proxy_arg(1) signal = state.proxy_arg(2) update = state.proxy_arg(3) - op = state.proxy_arg(4) - sem = state.proxy_arg(5) - scope = state.proxy_arg(6) - skip_sync = state.proxy_arg(7) + scope = state.proxy_arg(4) + has_subsequent_load = state.proxy_arg(5) + + if isinstance(signal_pad, tuple): + signal_pad = MulticastTensor(*signal_pad) - assert isinstance(signal_pad, torch.Tensor) + assert isinstance(signal_pad, (torch.Tensor, MulticastTensor)) assert isinstance(index, (list)) - indices = SubscriptIndexing.create(state, signal_pad, index) - signal_pad_name = state.device_function.tensor_arg(signal_pad).name + assert type(scope) is str - signal_expr = ast.Constant(value=signal) # pyright: ignore[reportArgumentType] - update_expr = ast.Constant(value=update) # pyright: ignore[reportArgumentType] + assert type(has_subsequent_load) is bool - assert type(op) is str - assert type(sem) is str - assert type(scope) is str + sem = "acquire" if has_subsequent_load else "relaxed" + op = "atomic_cas" if update is not None else "ld" + update = 0 if update is None else update + skip_sync = not has_subsequent_load - bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index) - is_scalar = len(bar_tensor_shape) == 0 + if isinstance(signal_pad, torch.Tensor): + indices = SubscriptIndexing.create(state, signal_pad, index) + shape = SubscriptIndexing.compute_shape(signal_pad, index) + signal_pad_name = state.device_function.tensor_arg(signal_pad).name - if is_scalar: - call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})" - else: - if signal_pad.dtype not in (torch.int32, torch.uint32): - raise NotImplementedError( - f"Unsupported signal pad dtype: {signal_pad.dtype}. Must be of torch.int32 or torch.uint32." 
+ bar_addrs_expr = expr_from_string( + f"{signal_pad_name} + offset", offset=indices.index_expr + ) + elif isinstance(signal_pad, MulticastTensor): + from .._compiler.indexing_strategy import MulticastIndexingStrategy + + subscript_shape = SubscriptIndexing.compute_shape(signal_pad.tensor_like, index) + multicast_shape = signal_pad.dev_ptrs.shape + shape = subscript_shape + list(multicast_shape) + + multicast_broadcast, tensor_broadcast = ( + MulticastIndexingStrategy.get_broadcast_str( + multicast_shape, subscript_shape ) - call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})" + ) + + tensor_like_indices = SubscriptIndexing.create( + state, signal_pad.tensor_like, index + ) + + dtype = triton_type(signal_pad.dtype) + + ast_tensors = state.ast_args[0] + assert isinstance(ast_tensors, tuple) + assert len(ast_tensors) == 2 + tensor_like_ast, dev_ptrs_ast = ast_tensors + bar_addrs_expr = expr_from_string( + f"base.to(tl.pointer_type({dtype})){multicast_broadcast} + offset{tensor_broadcast}", + base=dev_ptrs_ast, + offset=tensor_like_indices.index_expr, + ) + else: + raise NotImplementedError(f"Unsupported signal pad type: {type(signal_pad)}") + + signal_expr = ast.Constant(value=signal) # pyright: ignore[reportArgumentType] + update_expr = ast.Constant(value=update) # pyright: ignore[reportArgumentType] + + is_scalar = len(shape) == 0 + + call_triton_wait_signal = f"helion.runtime.triton_wait_{'' if is_scalar else 'multiple_'}signal(addr=bar_addrs, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})" return expr_from_string( call_triton_wait_signal, - offset=indices.index_expr, + bar_addrs=bar_addrs_expr, signal=signal_expr, update=update_expr, ) @@ -160,62 +174,46 @@ def _(state: CodegenState) -> ast.AST: @has_side_effect @_decorators.api(tiles_as_sizes=True, allow_host_tensor=True) def signal( - signal_pad: torch.Tensor, - index: list[object], + signal_pad: torch.Tensor | MulticastTensor, + index: list[object] | None = None, signal: int = 1, wait_for: int | None = None, - op: str = "atomic_xchg", - sem: str = "release", scope: str = "gpu", - skip_sync: bool = False, + hasPreviousMemAccess: bool = True, ) -> torch.Tensor: """Set the signal_pad slice to the signal value. Args: - signal_pad: The signal pad to signal + signal_pad: The signal pad tensor / multicast tensor to signal index: Indices to index into the signal_pad tensor signal: the value to send - wait_for: The value to wait for before sending the signal. Only valid for op = 'atomic_cas'. - op: The memory op for acquiring the lock (default: 'atomic_xchg') - sem: The memory semantic for acquiring the lock (default: 'release') + wait_for: The value to wait for before sending the signal. scope: The scope of the lock (default: 'gpu') - skip_sync: Skip the syncthreads before sending signal (default: False) + hasPreviousMemAccess: Whether the signal is preceded by a memory access (default: True) + Returns: + The old value of the signal_pad slice before the update. 
""" raise exc.NotInsideKernel @_decorators.prepare_args(signal) def _( - signal_pad: torch.Tensor, + signal_pad: torch.Tensor | MulticastTensor, index: list[object], signal: int = 1, wait_for: int | None = None, - op: str = "atomic_xchg", - sem: str = "release", scope: str = "gpu", - skip_sync: bool = False, -) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]: + hasPreviousMemAccess: bool = True, +) -> tuple[torch.Tensor | tuple, object, int, int | None, str, bool]: from .tile_proxy import Tile - valid_ops = {"atomic_add", "atomic_xchg", "atomic_cas"} - valid_sems = {"relaxed", "release", "acq_rel"} - valid_scopes = {"sys", "gpu"} - - if op not in valid_ops: - raise ValueError(f"Invalid signal op '{op}'. Must be one of {valid_ops}. ") + assert isinstance(signal_pad, (torch.Tensor, MulticastTensor)) - if op == "atomic_cas" and wait_for is None: - raise ValueError( - f"{op} without a wait_for value. Do you want to use 'atomic_add' or 'atomic_xchg' instead? " - ) - if op in {"atomic_add", "atomic_xchg"} and wait_for is not None: - raise ValueError( - f"{op} with a wait_for value. Do you want to use 'atomic_cas' instead? " + if signal_pad.dtype not in (torch.int32, torch.uint32): + raise NotImplementedError( + f"Unsupported signal pad dtype: {signal_pad.dtype}. Must be of torch.int32 or torch.uint32." ) - if sem not in valid_sems: - raise ValueError( - f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}." - ) + valid_scopes = {"sys", "gpu"} if scope not in valid_scopes: raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.") @@ -223,21 +221,31 @@ def _( index = Tile._prepare_index(index) index = Tile._tiles_to_sizes(index) - return (signal_pad, index, signal, wait_for, op, sem, scope, skip_sync) + if isinstance(signal_pad, MulticastTensor): + return (tuple(signal_pad), index, signal, wait_for, scope, hasPreviousMemAccess) + return (signal_pad, index, signal, wait_for, scope, hasPreviousMemAccess) @_decorators.register_fake(signal) def _( - signal_pad: torch.Tensor, + signal_pad: torch.Tensor | tuple, index: list[object], signal: int = 1, wait_for: int | None = None, - op: str = "atomic_xchg", - sem: str = "release", scope: str = "gpu", - skip_sync: bool = False, + hasPreviousMemAccess: bool = True, ) -> torch.Tensor: - return signal_pad.new_empty(SubscriptIndexing.compute_shape(signal_pad, index)) + if isinstance(signal_pad, tuple): + signal_pad = MulticastTensor(*signal_pad) + multicast_shape = signal_pad.dev_ptrs.shape + subscript_shape = SubscriptIndexing.compute_shape(signal_pad.tensor_like, index) + shape = list(multicast_shape) + subscript_shape + elif isinstance(signal_pad, torch.Tensor): + shape = SubscriptIndexing.compute_shape(signal_pad, index) + else: + raise NotImplementedError(f"Unsupported signal pad type: {type(signal_pad)}") + + return signal_pad.new_empty(shape) @_decorators.codegen(signal) @@ -251,16 +259,62 @@ def _(state: CodegenState) -> ast.AST: index = state.proxy_arg(1) signal = state.proxy_arg(2) wait_for = state.proxy_arg(3) - op = state.proxy_arg(4) - sem = state.proxy_arg(5) - scope = state.proxy_arg(6) - skip_sync = state.proxy_arg(7) + scope = state.proxy_arg(4) + hasPreviousMemAccess = state.proxy_arg(5) - assert isinstance(signal_pad, torch.Tensor) + if isinstance(signal_pad, tuple): + signal_pad = MulticastTensor(*signal_pad) + assert isinstance(signal_pad, (torch.Tensor, MulticastTensor)) assert isinstance(index, list) - indices = SubscriptIndexing.create(state, signal_pad, index) - signal_pad_name = 
state.device_function.tensor_arg(signal_pad).name + assert type(scope) is str + + assert type(hasPreviousMemAccess) is bool + + sem = "release" if hasPreviousMemAccess else "relaxed" + op = "atomic_xchg" if wait_for is None else "atomic_cas" + skip_sync = not hasPreviousMemAccess + + if isinstance(signal_pad, torch.Tensor): + indices = SubscriptIndexing.create(state, signal_pad, index) + shape = SubscriptIndexing.compute_shape(signal_pad, index) + signal_pad_name = state.device_function.tensor_arg(signal_pad).name + + bar_addrs_expr = expr_from_string( + f"{signal_pad_name} + offset", offset=indices.index_expr + ) + elif isinstance(signal_pad, MulticastTensor): + from .._compiler.indexing_strategy import MulticastIndexingStrategy + + subscript_shape = SubscriptIndexing.compute_shape(signal_pad.tensor_like, index) + multicast_shape = signal_pad.dev_ptrs.shape + shape = subscript_shape + list(multicast_shape) + + multicast_broadcast, tensor_broadcast = ( + MulticastIndexingStrategy.get_broadcast_str( + multicast_shape, subscript_shape + ) + ) + + tensor_like_indices = SubscriptIndexing.create( + state, signal_pad.tensor_like, index + ) + + dtype = triton_type(signal_pad.dtype) + + ast_tensors = state.ast_args[0] + assert isinstance(ast_tensors, tuple) + assert len(ast_tensors) == 2 + tensor_like_ast, dev_ptrs_ast = ast_tensors + bar_addrs_expr = expr_from_string( + f"base.to(tl.pointer_type({dtype})){multicast_broadcast} + offset{tensor_broadcast}", + base=dev_ptrs_ast, + offset=tensor_like_indices.index_expr, + ) + else: + raise NotImplementedError(f"Unsupported signal pad type: {type(signal_pad)}") + + is_scalar = len(shape) == 0 signal_expr = ast.Constant(value=signal) # pyright: ignore[reportArgumentType] if wait_for is not None: @@ -268,30 +322,19 @@ def _(state: CodegenState) -> ast.AST: else: wait_for_expr = ast.Constant(value=0) skip_sync_expr = ast.Constant(value=skip_sync) # pyright: ignore[reportArgumentType] - assert type(op) is str - assert type(sem) is str - assert type(scope) is str - - if op == "atomic_cas": - bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index) - is_scalar = len(bar_tensor_shape) == 0 - if is_scalar: - call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))" - else: - call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))" + if wait_for is not None: + call_triton_wait_signal = f"helion.runtime.triton_wait_{'' if is_scalar else 'multiple_'}signal(addr=bar_addrs, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))" return expr_from_string( call_triton_wait_signal, - offset=indices.index_expr, + bar_addrs=bar_addrs_expr, wait_for=wait_for_expr, signal=signal_expr, skip_sync=skip_sync_expr, ) - call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={signal_pad_name} + offset, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=skip_sync)" - return expr_from_string( - call_triton_send_signal, - offset=indices.index_expr, + f"helion.runtime.triton_send_signal(addr=bar_addrs, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=skip_sync)", + bar_addrs=bar_addrs_expr, signal=signal_expr, skip_sync=skip_sync_expr, ) diff --git 
a/helion/runtime/triton_helpers.py b/helion/runtime/triton_helpers.py index 1e93d446..87354d45 100644 --- a/helion/runtime/triton_helpers.py +++ b/helion/runtime/triton_helpers.py @@ -109,9 +109,6 @@ def triton_wait_signal( # Triton generates smem broadcasting of tl.atomic_add return value in ptx, # but it is optimized away by ptxas in SASS, hence no performance overhead. if op == "ld": - tl.static_assert( - update == 0, "ld wait on gmem_barriers cannot update the lock. " - ) while tl.atomic_add(addr, 0, sem=sem, scope=scope) != expect: pass elif op == "atomic_cas": @@ -171,6 +168,11 @@ def triton_wait_multiple_signal( "Invalid barrier value type. Only supports int32 for multi barrier signal. ", ) + if sync_before: + tl.inline_asm_elementwise( + "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1 + ) + addr = tl.ravel(addr) tl.static_assert(len(addr.shape) == 1, "addr must be a 1D tensor. ") diff --git a/test/test_signal_wait.expected b/test/test_signal_wait.expected index fff0ebce..d0006a62 100644 --- a/test/test_signal_wait.expected +++ b/test/test_signal_wait.expected @@ -16,7 +16,7 @@ def _gmem_multi_bar_sync_kernel_kernel(signal_pad, signal_pad_stride_0, signal_p offset_0 = pid_0 for offset_1 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_1): indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) - helion.runtime.triton_send_signal(addr=signal_pad + (indices_1 * signal_pad_stride_0 + offset_0 * signal_pad_stride_1), update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=True) + helion.runtime.triton_send_signal(addr=signal_pad + (indices_1 * signal_pad_stride_0 + offset_0 * signal_pad_stride_1), update=1, sem='relaxed', scope='gpu', op='atomic_xchg', skip_sync=True) helion.runtime.triton_wait_multiple_signal(addr=signal_pad + (offset_0 * signal_pad_stride_0 + indices_1 * signal_pad_stride_1), expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False) def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor, *, _launcher=_default_launcher): @@ -66,6 +66,29 @@ def gmem_signal_cas_kernel(signal_pad: torch.Tensor, *, _launcher=_default_launc _launcher(_gmem_signal_cas_kernel_kernel, (n,), signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3) return signal_pad +--- assertExpectedJournal(TestWait.test_signal_multicast_signalpad) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _gmem_signal_pointers_kernel_kernel(signal_pad_ptrs, signal_pad_ptrs_size_0, example_stride_0, signal_pad_ptrs_stride_0, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) + mask_1 = indices_1 < signal_pad_ptrs_size_0 + ptr_tile = tl.load(signal_pad_ptrs + indices_1 * signal_pad_ptrs_stride_0, mask_1, other=0) + helion.runtime.triton_send_signal(addr=ptr_tile.to(tl.pointer_type(tl.int32))[:] + (offset_0 * example_stride_0)[None], update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False) + +def gmem_signal_pointers_kernel(signal_pad_ptrs: torch.Tensor, example: torch.Tensor, *, _launcher=_default_launcher): + _RDIM_SIZE_1 = triton.next_power_of_2(signal_pad_ptrs.size(0)) + _launcher(_gmem_signal_pointers_kernel_kernel, (example.size(0),), signal_pad_ptrs, signal_pad_ptrs.size(0), example.stride(0), signal_pad_ptrs.stride(0), _RDIM_SIZE_1, num_warps=4, num_stages=3) + return signal_pad_ptrs + --- 
assertExpectedJournal(TestWait.test_signal_multiple) from __future__ import annotations @@ -215,3 +238,59 @@ def gmem_wait_multi_bar_kernel_cas(signal_pad: torch.Tensor, *, _launcher=_defau _BLOCK_SIZE_0 = 4 _launcher(_gmem_wait_multi_bar_kernel_cas_kernel, (triton.cdiv(N, _BLOCK_SIZE_0),), signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) return signal_pad + +--- assertExpectedJournal(TestWait.test_wait_multicast_signalpad) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _gmem_wait_pointers_kernel_kernel(signal_pad_ptrs, out, signal_pad_ptrs_size_0, example_stride_0, out_stride_0, signal_pad_ptrs_stride_0, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) + mask_1 = indices_1 < signal_pad_ptrs_size_0 + dev_tile = tl.load(signal_pad_ptrs + indices_1 * signal_pad_ptrs_stride_0, mask_1, other=0) + helion.runtime.triton_wait_multiple_signal(addr=dev_tile.to(tl.pointer_type(tl.int32))[:] + (offset_0 * example_stride_0)[None], expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False) + tl.store(out + offset_0 * out_stride_0, offset_0, None) + +def gmem_wait_pointers_kernel(signal_pad_ptrs: torch.Tensor, example: torch.Tensor, *, _launcher=_default_launcher): + out = torch.empty_like(example) + _RDIM_SIZE_1 = triton.next_power_of_2(signal_pad_ptrs.size(0)) + _launcher(_gmem_wait_pointers_kernel_kernel, (example.size(0),), signal_pad_ptrs, out, signal_pad_ptrs.size(0), example.stride(0), out.stride(0), signal_pad_ptrs.stride(0), _RDIM_SIZE_1, num_warps=4, num_stages=3) + return out + +--- assertExpectedJournal(TestWait.test_wait_pointers) +from __future__ import annotations + +import torch +import helion +import helion.language as hl +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _gmem_wait_pointers_kernel_kernel(signal_pad_ptrs, out, out_stride_0, signal_pad_ptrs_stride_0, N, _BLOCK_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + for offset_1 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_1): + indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) + mask_1 = indices_1 < N + load = tl.load(signal_pad_ptrs + indices_1 * signal_pad_ptrs_stride_0, mask_1, other=0) + symnode_0 = 4 * offset_0 + v_0 = symnode_0.to(tl.uint64) + v_1 = load + v_0 + helion.runtime.triton_wait_multiple_signal(addr=v_1.to(tl.pointer_type(tl.int32)), expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False) + tl.store(out + offset_0 * out_stride_0, offset_0, None) + +def gmem_wait_pointers_kernel(signal_pad_ptrs: torch.Tensor, pad_shape: hl.constexpr, *, _launcher=_default_launcher): + out = torch.empty(4, device=signal_pad_ptrs.device, dtype=torch.int32) + N = signal_pad_ptrs.size(0) + _BLOCK_SIZE_1 = N + _launcher(_gmem_wait_pointers_kernel_kernel, (4,), signal_pad_ptrs, out, out.stride(0), signal_pad_ptrs.stride(0), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3) + return out diff --git a/test/test_signal_wait.py b/test/test_signal_wait.py index dbff6046..b44596d7 100644 --- a/test/test_signal_wait.py +++ b/test/test_signal_wait.py @@ -84,7 +84,7 @@ def gmem_wait_multi_bar_kernel_cas(signal_pad: torch.Tensor) -> torch.Tensor: n = hl.register_block_size(N) for tile in hl.tile(N, block_size=n): - hl.wait(signal_pad, [tile], signal=1, 
update=2, op="atomic_cas") + hl.wait(signal_pad, [tile], signal=1, update=2) return signal_pad @@ -118,7 +118,7 @@ def test_signal_cas(self): def gmem_signal_cas_kernel(signal_pad: torch.Tensor) -> torch.Tensor: (n,) = signal_pad.shape for i in hl.grid(n): - hl.signal(signal_pad, [i], signal=1, wait_for=0, op="atomic_cas") + hl.signal(signal_pad, [i], signal=1, wait_for=0) return signal_pad signal_pad = torch.zeros(4, device=DEVICE, dtype=torch.int32) @@ -152,7 +152,7 @@ def test_signal_multiple_cas(self): def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor: (n,) = signal_pad.shape for tile in hl.tile(n): - hl.signal(signal_pad, [tile], wait_for=0, signal=1, op="atomic_cas") + hl.signal(signal_pad, [tile], wait_for=0, signal=1) return signal_pad signal_pad = torch.zeros(16, device=DEVICE, dtype=torch.int32) @@ -192,7 +192,9 @@ def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor) -> torch.Tensor: assert M == N for i in hl.grid(N): for tile in hl.tile(N, block_size=N): - hl.signal(signal_pad, [tile, i], signal=1, skip_sync=True) + hl.signal( + signal_pad, [tile, i], signal=1, hasPreviousMemAccess=False + ) hl.wait(signal_pad, [i, tile], signal=1) return signal_pad @@ -216,10 +218,9 @@ def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor) -> torch.Tensor: [tile, i], signal=1, wait_for=0, - skip_sync=True, - op="atomic_cas", + hasPreviousMemAccess=False, ) - hl.wait(signal_pad, [i, tile], signal=1, update=2, op="atomic_cas") + hl.wait(signal_pad, [i, tile], signal=1, update=2) return signal_pad signal_pad = torch.zeros(4, 4, device=DEVICE, dtype=torch.int32) @@ -230,6 +231,61 @@ def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor) -> torch.Tensor: ) self.assertIn("atomic_cas", code) + def test_wait_multicast_signalpad(self): + @helion.kernel + def gmem_wait_pointers_kernel( + signal_pad_ptrs: torch.Tensor, example: torch.Tensor + ) -> torch.Tensor: + out = torch.empty_like(example) + for i in hl.grid(example.size(0)): + dev_tile = signal_pad_ptrs[:] + multicast_tensor = hl.multicast_like(example, dev_tile) + hl.wait(multicast_tensor, [i], signal=1) + out[i] = i + return out + + signal_pad_list = [ + torch.ones(4, device=DEVICE, dtype=torch.int32) for _ in range(4) + ] + signal_pad_ptrs = torch.as_tensor( + [p.data_ptr() for p in signal_pad_list], device=DEVICE, dtype=torch.uint64 + ) + code, result = code_and_output( + gmem_wait_pointers_kernel, (signal_pad_ptrs, signal_pad_list[0]) + ) + torch.testing.assert_close( + result, torch.arange(4, device=DEVICE, dtype=torch.int32) + ) + self.assertExpectedJournal(code) + + def test_signal_multicast_signalpad(self): + @helion.kernel + def gmem_signal_pointers_kernel( + signal_pad_ptrs: torch.Tensor, + example: torch.Tensor, + ) -> torch.Tensor: + for i in hl.grid(example.size(0)): + ptr_tile = signal_pad_ptrs[:] + multicast_signal_pad = hl.multicast_like(example, ptr_tile) + hl.signal(multicast_signal_pad, [i], signal=1) + return signal_pad_ptrs + + signal_pad_list = [ + torch.zeros(4, device=DEVICE, dtype=torch.int32) for _ in range(4) + ] + signal_pad_ptrs = torch.as_tensor( + [p.data_ptr() for p in signal_pad_list], device=DEVICE, dtype=torch.uint64 + ) + code, result = code_and_output( + gmem_signal_pointers_kernel, (signal_pad_ptrs, signal_pad_list[0]) + ) + + for tensor in signal_pad_list: + torch.testing.assert_close( + tensor, torch.ones(4, device=DEVICE, dtype=torch.int32) + ) + self.assertExpectedJournal(code) + if __name__ == "__main__": unittest.main()