Skip to content

Commit 2379935

Browse files
committed
PyTorch reference mode (both eager and torch.compile)
Fixes #77. stack-info: PR: #339, branch: yf225/stack/34
1 parent 41fe6e9 commit 2379935

16 files changed

+1392
-29
lines changed

examples/concatenate.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,17 @@ def concat2d_dim1(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
1515
)
1616
for tile0, tile1 in hl.tile(out.size()):
1717
# Most masking is automatic in Helion, but because tile1 spans both x and y, we need to do some manual masking
18+
tile1_indices = hl.tile_index(tile1)
1819
x_part = hl.load(
19-
x, [tile0, tile1], extra_mask=(tile1.index < x.size(1))[None, :]
20+
x, [tile0, tile1], extra_mask=(tile1_indices < x.size(1))[None, :]
2021
)
2122
y_part = hl.load(
2223
y,
23-
[tile0, tile1.index - x.size(1)],
24-
extra_mask=(tile1.index >= x.size(1))[None, :],
24+
[tile0, tile1_indices - x.size(1)],
25+
extra_mask=(tile1_indices >= x.size(1))[None, :],
2526
)
2627
out[tile0, tile1] = torch.where(
27-
(tile1.index < x.size(1))[None, :], x_part, y_part
28+
(tile1_indices < x.size(1))[None, :], x_part, y_part
2829
)
2930
return out
3031

examples/cross_entropy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ def cross_entropy(
2828
for tile_n in hl.tile(n):
2929
# Get data for this tile
3030
labels_tile = labels[tile_n] # [tile_size]
31-
base_indices_tile = tile_n.index * v # [tile_size]
31+
tile_n_indices = hl.tile_index(tile_n)
32+
base_indices_tile = tile_n_indices * v # [tile_size]
3233

3334
# Compute the actual flat indices by adding the label offset
3435
flat_indices = base_indices_tile + labels_tile

examples/fp8_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def fp8_attention_kernel(
2222
head_dim = q.size(2)
2323

2424
# Output tensor with 4D shape in FP8 format
25-
out = torch.empty(
25+
out = torch.zeros(
2626
[batch, heads, seq_len, head_dim], dtype=torch.float8_e5m2, device=q.device
2727
)
2828

examples/jagged_dense_add.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,19 @@ def jagged_dense_add_2d(
4444
out = torch.zeros_like(y)
4545
for tile0 in hl.tile(num_rows):
4646
starts = x_offsets[tile0]
47-
ends = x_offsets[tile0.index + 1]
47+
tile0_indices = hl.tile_index(tile0)
48+
ends = x_offsets[tile0_indices + 1]
4849
nnz = ends - starts
4950
max_nnz = nnz.amax()
5051
# Note, the dynamic loop bounds aren't strictly necessary for this example, since
5152
# the output is dense, and we iterate over the rest in the next loop. However,
5253
# it is useful to illustrate how more complex jagged+jagged ops can be handled.
5354
for tile1 in hl.tile(0, max_nnz):
55+
tile1_indices = hl.tile_index(tile1)
5456
x_slice = hl.load(
5557
x_data,
56-
[starts[:, None] + tile1.index[None, :]],
57-
extra_mask=tile1.index[None, :] < nnz[:, None],
58+
[starts[:, None] + tile1_indices[None, :]],
59+
extra_mask=tile1_indices[None, :] < nnz[:, None],
5860
)
5961
out[tile0, tile1] = y[tile0, tile1] + x_slice
6062
for tile1 in hl.tile(max_nnz, out.size(1)):

examples/jagged_mean.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ def jagged_mean_kernel(
4848
# Process rows in tiles
4949
for tile_b in hl.tile(num_rows):
5050
starts = x_offsets[tile_b]
51-
ends = x_offsets[tile_b.index + 1]
51+
tile_b_indices = hl.tile_index(tile_b)
52+
ends = x_offsets[tile_b_indices + 1]
5253
nnz = ends - starts
5354
max_nnz = nnz.amax()
5455

@@ -58,21 +59,23 @@ def jagged_mean_kernel(
5859
# Process features in tiles
5960
for tile_m in hl.tile(max_M):
6061
# Create mask for valid features
61-
feature_valid = tile_m.index < feature_counts[:, None]
62+
tile_m_indices = hl.tile_index(tile_m)
63+
feature_valid = tile_m_indices < feature_counts[:, None]
6264

6365
# Initialize accumulator
6466
row_sums = hl.zeros([tile_b, tile_m], dtype=x_data.dtype)
6567

6668
# Process elements within each row
6769
for tile_k in hl.tile(0, max_nnz):
6870
# Compute flattened indices
69-
base_indices = starts[:, None] + tile_k.index[None, :]
71+
tile_k_indices = hl.tile_index(tile_k)
72+
base_indices = starts[:, None] + tile_k_indices[None, :]
7073
flat_indices = (
71-
base_indices[:, :, None] * max_M + tile_m.index[None, None, :]
74+
base_indices[:, :, None] * max_M + tile_m_indices[None, None, :]
7275
)
7376

7477
# Combined mask: valid row element AND valid feature
75-
row_mask = tile_k.index[None, :] < nnz[:, None]
78+
row_mask = tile_k_indices[None, :] < nnz[:, None]
7679
combined_mask = row_mask[:, :, None] & feature_valid[:, None, :]
7780

7881
x_slice = hl.load(

examples/matmul_split_k.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
2121
k_block = helion.next_power_of_2(helion.cdiv(k, split_k))
2222
for tile_m, tile_n, outer_k in hl.tile([m, n, k], block_size=[None, None, k_block]):
2323
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
24-
for inner_k in hl.tile(outer_k.begin, outer_k.end):
24+
for inner_k in hl.tile(hl.tile_begin(outer_k), hl.tile_end(outer_k)):
2525
acc = torch.addmm(acc, x[tile_m, inner_k], y[inner_k, tile_n])
2626
hl.atomic_add(out, [tile_m, tile_n], acc)
2727
return out

examples/moe_matmul_ogs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def moe_matmul_ogs(
4646
for tile_t, tile_n in hl.tile([max_T_per_expert, N]):
4747
# Get local token offsets for this tile
4848
# (i.e. the tile's corresponding chunk in the [0 .. max_T_per_expert-1] token range)
49-
local_token_offsets = tile_t.index # [BLOCK_T]
49+
local_token_offsets = hl.tile_index(tile_t) # [BLOCK_T]
5050

5151
# Create mask for valid tokens (some tiles may be partially filled)
5252
token_valid = local_token_offsets < num_tokens # bool[BLOCK_T]

examples/segment_reduction.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@ def segmented_reduction_helion(
3434
for tile_e, tile_f in hl.tile([num_elements, num_features]):
3535
vals = input_data[tile_e, tile_f]
3636
idxs = indices[tile_e]
37+
tile_e_indices = hl.tile_index(tile_e)
3738
idxs_next = hl.load(
38-
indices, [tile_e.index + 1], extra_mask=tile_e.index < num_elements - 1
39+
indices, [tile_e_indices + 1], extra_mask=tile_e_indices < num_elements - 1
3940
)
4041
tuple_in = (vals, idxs.float().unsqueeze(1).expand_as(vals))
4142
out_vals, _ = hl.associative_scan(combine_fn_helion, tuple_in, dim=0)
42-
mask = (idxs != idxs_next) | (
43-
tile_e.index % tile_e.block_size == tile_e.block_size - 1
44-
)
43+
block_size = hl.tile_block_size(tile_e)
44+
mask = (idxs != idxs_next) | (tile_e_indices % block_size == block_size - 1)
4545
segment_vals = torch.where(mask.unsqueeze(1), out_vals, 0.0)
4646
hl.atomic_add(output, [idxs, tile_f], segment_vals)
4747
return output

helion/_testing.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def code_and_output(
4545
args: tuple[object, ...],
4646
**kwargs: object,
4747
) -> tuple[str, object]:
48+
bound = fn.bind(args)
49+
if bound.ref_eager or bound.ref_compile:
50+
result = fn(*args)
51+
return "", result
52+
4853
if kwargs:
4954
config = Config(
5055
**kwargs # pyright: ignore[reportArgumentType]
@@ -306,6 +311,13 @@ def assertExpectedJournal(self, value: str) -> None:
306311
Note:
307312
Set the EXPECTTEST_ACCEPT=1 environment variable to update expected outputs.
308313
"""
314+
# Skip expected code checks in ref modes, since they use the exact same code as the original Helion kernel.
315+
if (
316+
os.environ.get("HELION_REF_EAGER") == "1"
317+
or os.environ.get("HELION_REF_COMPILE") == "1"
318+
):
319+
return
320+
309321
value, expected = self._expected_journal.lookup(self.id(), value)
310322
self.assertMultiLineEqual(
311323
value,

helion/ref/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from __future__ import annotations
2+
3+
from . import hl_patch
4+
from . import torch_patch
5+
6+
__all__ = ["hl_patch", "torch_patch"]

0 commit comments

Comments
 (0)