Skip to content

Commit c534d9e

Browse files
committed
test
1 parent 078e708 commit c534d9e

File tree

2 files changed

+600
-0
lines changed

2 files changed

+600
-0
lines changed

test/test_indexing.expected

Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,349 @@ def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor,
285285
# src[test_indexing.py:N]: return out
286286
return out
287287

288+
--- assertExpectedJournal(TestIndexing.test_indirect_indexing_2d_direct_gather)
289+
from __future__ import annotations
290+
291+
import torch
292+
import triton
293+
import triton.language as tl
294+
from helion.runtime import default_launcher as _default_launcher
295+
296+
@triton.jit
def _helion_test(col, B, val, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
    # Each program accumulates one [_BLOCK_SIZE_0, _BLOCK_SIZE_1] tile of C.
    # For every (row, k) position it reads an index from `col`, reads the row of
    # B that index points at, scales it by the matching entry of `val`, and sums
    # the products over k.
    #
    # src[test_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
    # The launch grid is 1-D, so split the flat program id into tile coordinates.
    program = tl.program_id(0)
    tiles_m = tl.cdiv(32, _BLOCK_SIZE_0)
    tile_m_id = program % tiles_m
    tile_n_id = program // tiles_m
    rows = (tile_m_id * _BLOCK_SIZE_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    out_cols = (tile_n_id * _BLOCK_SIZE_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
    # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
    # src[test_indexing.py:N]: for tile_k in hl.tile(K):
    for k_start in tl.range(0, 16, _BLOCK_SIZE_2):
        ks = k_start + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
        # `col` and `val` are read at the same offsets (row stride 16).
        mk_offsets = rows[:, None] * 16 + ks[None, :]
        # src[test_indexing.py:N]: cols_2d = col[tile_m, tile_k]
        gathered_rows = tl.load(col + mk_offsets, None)
        # src[test_indexing.py:N]: B_slice = B[cols_2d[:, :, None], tile_n.index[None, None, :]]
        # B is addressed with row stride 24; the loaded indices select its rows.
        b_vals = tl.load(B + (gathered_rows[:, :, None] * 24 + out_cols[None, None, :]), None)
        # src[test_indexing.py:N]: vals_2d = val[tile_m, tile_k]
        weights = tl.load(val + mk_offsets, None)
        # src[test_indexing.py:N]: contrib = (vals_2d[:, :, None] * B_slice).sum(dim=1)
        products = weights[:, :, None] * b_vals
        # src[test_indexing.py:N]: acc = acc + contrib
        acc = acc + tl.cast(tl.sum(products, 1), tl.float32)
    # src[test_indexing.py:N]: C[tile_m, tile_n] = acc.to(out_dtype)
    tl.store(C + (rows[:, None] * 24 + out_cols[None, :]), acc, None)
333+
334+
def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
    """Allocate the output tensor and launch ``_helion_test`` over it.

    Args:
        col: 2-D tensor; only its first dimension (M) is used here.
        val: Tensor whose dtype takes part in choosing the output dtype.
        B: 2-D tensor; its second dimension (N) sizes the output and its
            device is where the output is allocated.
        _launcher: Callable that launches the kernel; called with the kernel,
            the grid tuple, the kernel arguments and the launch options.

    Returns:
        The ``(M, N)`` tensor ``C`` with dtype
        ``torch.promote_types(val.dtype, B.dtype)``, passed to the kernel as
        its output argument.
    """
    # src[test_indexing.py:N]: M, K = col.shape
    M, _ = col.shape  # K is not needed on the host side
    # src[test_indexing.py:N]: _, N = B.shape
    _, N = B.shape
    # src[test_indexing.py:N]: out_dtype = torch.promote_types(val.dtype, B.dtype)
    out_dtype = torch.promote_types(val.dtype, B.dtype)
    # src[test_indexing.py:N]: C = torch.empty((M, N), dtype=out_dtype, device=B.device)
    C = torch.empty((M, N), dtype=out_dtype, device=B.device)
    # src[test_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
    _BLOCK_SIZE_0 = 8
    _BLOCK_SIZE_1 = 8
    # src[test_indexing.py:N]: for tile_k in hl.tile(K):
    _BLOCK_SIZE_2 = 4
    # NOTE(review): the grid extents 32 and 24 are literals rather than M and N;
    # presumably this code was generated for those exact shapes -- confirm
    # before calling it with differently sized inputs.
    grid = (triton.cdiv(32, _BLOCK_SIZE_0) * triton.cdiv(24, _BLOCK_SIZE_1),)
    _launcher(_helion_test, grid, col, B, val, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
    # src[test_indexing.py:N]: return C
    return C
358+
359+
--- assertExpectedJournal(TestIndexing.test_indirect_indexing_2d_flat_load)
360+
from __future__ import annotations
361+
362+
import torch
363+
import triton
364+
import triton.language as tl
365+
from helion.runtime import default_launcher as _default_launcher
366+
367+
@triton.jit
def _helion_test(col, B_flat, val, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
    # Each program accumulates one [_BLOCK_SIZE_0, _BLOCK_SIZE_1] tile of C.
    # Indices read from `col` are turned into flat offsets into `B_flat`
    # (index * 24 + output column); the loaded values are scaled by `val`
    # and summed over k.
    #
    # src[test_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
    # The launch grid is 1-D, so split the flat program id into tile coordinates.
    program = tl.program_id(0)
    tiles_m = tl.cdiv(32, _BLOCK_SIZE_0)
    tile_m_id = program % tiles_m
    tile_n_id = program // tiles_m
    rows = (tile_m_id * _BLOCK_SIZE_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    out_cols = (tile_n_id * _BLOCK_SIZE_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
    # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
    # src[test_indexing.py:N]: for tile_k in hl.tile(K):
    for k_start in tl.range(0, 16, _BLOCK_SIZE_2):
        ks = k_start + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
        # `col` and `val` are read at the same offsets (row stride 16).
        mk_offsets = rows[:, None] * 16 + ks[None, :]
        # src[test_indexing.py:N]: cols_2d = col[tile_m, tile_k]
        gathered_rows = tl.load(col + mk_offsets, None)
        # src[test_indexing.py:N]: B_indices = (cols_2d * N)[:, :, None] + tile_n.index[None, None, :]
        # The flat index is built in int64: a 64-bit scalar 24 multiplies the
        # loaded indices and the output columns are widened before the add.
        row_stride = tl.full([], 24, tl.int64)
        row_base = tl.cast(gathered_rows * row_stride, tl.int64)
        flat_idx = row_base[:, :, None] + tl.cast(out_cols[None, None, :], tl.int64)
        # src[test_indexing.py:N]: B_slice = hl.load(B_flat, [B_indices])
        b_vals = tl.load(B_flat + flat_idx, None)
        # src[test_indexing.py:N]: vals_2d = val[tile_m, tile_k]
        weights = tl.load(val + mk_offsets, None)
        # src[test_indexing.py:N]: contrib = (vals_2d[:, :, None] * B_slice).sum(dim=1)
        products = weights[:, :, None] * b_vals
        # src[test_indexing.py:N]: acc = acc + contrib
        acc = acc + tl.cast(tl.sum(products, 1), tl.float32)
    # src[test_indexing.py:N]: C[tile_m, tile_n] = acc.to(out_dtype)
    tl.store(C + (rows[:, None] * 24 + out_cols[None, :]), acc, None)
409+
410+
def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
    """Allocate the output, flatten ``B`` and launch ``_helion_test``.

    Args:
        col: 2-D tensor; only its first dimension (M) is used here.
        val: Tensor whose dtype takes part in choosing the output dtype.
        B: 2-D tensor; its second dimension (N) sizes the output, and it is
            handed to the kernel as ``B.reshape(-1)``.
        _launcher: Callable that launches the kernel; called with the kernel,
            the grid tuple, the kernel arguments and the launch options.

    Returns:
        The ``(M, N)`` tensor ``C`` with dtype
        ``torch.promote_types(val.dtype, B.dtype)``, passed to the kernel as
        its output argument.
    """
    # src[test_indexing.py:N]: M, K = col.shape
    M, _ = col.shape  # K is not needed on the host side
    # src[test_indexing.py:N]: _, N = B.shape
    _, N = B.shape
    # src[test_indexing.py:N]: out_dtype = torch.promote_types(val.dtype, B.dtype)
    out_dtype = torch.promote_types(val.dtype, B.dtype)
    # src[test_indexing.py:N]: C = torch.empty((M, N), dtype=out_dtype, device=B.device)
    C = torch.empty((M, N), dtype=out_dtype, device=B.device)
    # src[test_indexing.py:N]: B_flat = B.reshape(-1)  # [K*N]
    B_flat = B.reshape(-1)
    # src[test_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
    _BLOCK_SIZE_0 = 8
    _BLOCK_SIZE_1 = 8
    # src[test_indexing.py:N]: for tile_k in hl.tile(K):
    _BLOCK_SIZE_2 = 4
    # NOTE(review): the grid extents 32 and 24 are literals rather than M and N;
    # presumably this code was generated for those exact shapes -- confirm
    # before calling it with differently sized inputs.
    grid = (triton.cdiv(32, _BLOCK_SIZE_0) * triton.cdiv(24, _BLOCK_SIZE_1),)
    _launcher(_helion_test, grid, col, B_flat, val, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
    # src[test_indexing.py:N]: return C
    return C
436+
437+
--- assertExpectedJournal(TestIndexing.test_indirect_indexing_3d_direct_gather)
438+
from __future__ import annotations
439+
440+
import torch
441+
import triton
442+
import triton.language as tl
443+
from helion.runtime import default_launcher as _default_launcher
444+
445+
@triton.jit
def _helion_test(col, B, val, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
    # Each program accumulates one 4-D tile of C. For every (m, n, k) position
    # it reads an index from `col`, reads the [p, q] plane of B that index
    # points at, scales it by the matching entry of `val`, and sums over k.
    #
    # src[test_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
    # The launch grid is 1-D, so peel the four tile coordinates off the flat id.
    program = tl.program_id(0)
    tiles_m = tl.cdiv(16, _BLOCK_SIZE_0)
    tiles_n = tl.cdiv(12, _BLOCK_SIZE_1)
    tiles_p = tl.cdiv(10, _BLOCK_SIZE_2)
    tile_m_id = program % tiles_m
    tile_n_id = program // tiles_m % tiles_n
    tile_p_id = program // (tiles_m * tiles_n) % tiles_p
    tile_q_id = program // (tiles_m * tiles_n * tiles_p)
    idx_m = (tile_m_id * _BLOCK_SIZE_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    idx_n = (tile_n_id * _BLOCK_SIZE_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
    idx_p = (tile_p_id * _BLOCK_SIZE_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
    idx_q = (tile_q_id * _BLOCK_SIZE_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32)
    # Only the p and q axes are bounds-checked (against 10 and 14).
    in_p = idx_p < 10
    in_q = idx_q < 14
    gather_mask = in_p[None, None, None, :, None] & in_q[None, None, None, None, :]
    # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32)
    # src[test_indexing.py:N]: for tile_k in hl.tile(K):
    for k_start in tl.range(0, 8, _BLOCK_SIZE_4):
        ks = k_start + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
        # `col` and `val` are read at the same offsets (strides 96, 8, 1).
        mnk_offsets = idx_m[:, None, None] * 96 + idx_n[None, :, None] * 8 + ks[None, None, :]
        # src[test_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
        gathered = tl.load(col + mnk_offsets, None)
        # src[test_indexing.py:N]: B_slice = B[cols_3d[:, :, :, None, None], tile_p.index[...], tile_q.index[...]]
        # B is addressed with strides 140, 14, 1; masked-out lanes read as 0.
        b_offsets = gathered[:, :, :, None, None] * 140 + idx_p[None, None, :, None] * 14 + idx_q[None, None, None, :]
        b_vals = tl.load(B + b_offsets, gather_mask, other=0)
        # src[test_indexing.py:N]: vals_3d = val[tile_m, tile_n, tile_k]
        weights = tl.load(val + mnk_offsets, None)
        # src[test_indexing.py:N]: contrib = (vals_3d[:, :, :, None, None] * B_slice).sum(dim=2)
        products = weights[:, :, :, None, None] * b_vals
        # src[test_indexing.py:N]: acc = acc + contrib
        acc = acc + tl.cast(tl.sum(products, 2), tl.float32)
    # src[test_indexing.py:N]: C[tile_m, tile_n, tile_p, tile_q] = acc.to(out_dtype)
    out_offsets = idx_m[:, None, None, None] * 1680 + idx_n[None, :, None, None] * 140 + idx_p[None, None, :, None] * 14 + idx_q[None, None, None, :]
    tl.store(C + out_offsets, acc, in_p[None, None, :, None] & in_q[None, None, None, :])
499+
500+
def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
    """Allocate the 4-D output tensor and launch ``_helion_test`` over it.

    Args:
        col: 3-D tensor; its first two dimensions (M, N) size the output.
        val: Tensor whose dtype takes part in choosing the output dtype.
        B: 3-D tensor; its last two dimensions (P, Q) size the output and its
            device is where the output is allocated.
        _launcher: Callable that launches the kernel; called with the kernel,
            the grid tuple, the kernel arguments and the launch options.

    Returns:
        The ``(M, N, P, Q)`` tensor ``C`` with dtype
        ``torch.promote_types(val.dtype, B.dtype)``, passed to the kernel as
        its output argument.
    """
    # src[test_indexing.py:N]: M, N, K = col.shape
    M, N, _ = col.shape  # K is not needed on the host side
    # src[test_indexing.py:N]: _, P, Q = B.shape
    _, P, Q = B.shape
    # src[test_indexing.py:N]: out_dtype = torch.promote_types(val.dtype, B.dtype)
    out_dtype = torch.promote_types(val.dtype, B.dtype)
    # src[test_indexing.py:N]: C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
    C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
    # src[test_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
    _BLOCK_SIZE_0 = 4
    _BLOCK_SIZE_1 = 4
    _BLOCK_SIZE_2 = 4
    _BLOCK_SIZE_3 = 4
    # src[test_indexing.py:N]: for tile_k in hl.tile(K):
    _BLOCK_SIZE_4 = 4
    # NOTE(review): the grid extents 16, 12, 10 and 14 are literals rather than
    # M, N, P and Q; presumably this code was generated for those exact
    # shapes -- confirm before calling it with differently sized inputs.
    grid = (
        triton.cdiv(16, _BLOCK_SIZE_0)
        * triton.cdiv(12, _BLOCK_SIZE_1)
        * triton.cdiv(10, _BLOCK_SIZE_2)
        * triton.cdiv(14, _BLOCK_SIZE_3),
    )
    _launcher(_helion_test, grid, col, B, val, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=1)
    # src[test_indexing.py:N]: return C
    return C
526+
527+
--- assertExpectedJournal(TestIndexing.test_indirect_indexing_3d_flat_load)
528+
from __future__ import annotations
529+
530+
import torch
531+
import triton
532+
import triton.language as tl
533+
from helion.runtime import default_launcher as _default_launcher
534+
535+
@triton.jit
def _helion_test(col, B_flat, val, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
    # Each program accumulates one 4-D tile of C. Indices read from `col` are
    # turned into flat offsets into `B_flat` (index * 140 + p * 14 + q); the
    # loaded values are scaled by `val` and summed over k.
    #
    # src[test_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
    # The launch grid is 1-D, so peel the four tile coordinates off the flat id.
    program = tl.program_id(0)
    tiles_m = tl.cdiv(16, _BLOCK_SIZE_0)
    tiles_n = tl.cdiv(12, _BLOCK_SIZE_1)
    tiles_p = tl.cdiv(10, _BLOCK_SIZE_2)
    tile_m_id = program % tiles_m
    tile_n_id = program // tiles_m % tiles_n
    tile_p_id = program // (tiles_m * tiles_n) % tiles_p
    tile_q_id = program // (tiles_m * tiles_n * tiles_p)
    idx_m = (tile_m_id * _BLOCK_SIZE_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    idx_n = (tile_n_id * _BLOCK_SIZE_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
    idx_p = (tile_p_id * _BLOCK_SIZE_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
    idx_q = (tile_q_id * _BLOCK_SIZE_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32)
    # Only the p and q axes are bounds-checked (against 10 and 14).
    in_p = idx_p < 10
    in_q = idx_q < 14
    gather_mask = in_p[None, None, None, :, None] & in_q[None, None, None, None, :]
    # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
    acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32)
    # src[test_indexing.py:N]: for tile_k in hl.tile(K):
    for k_start in tl.range(0, 8, _BLOCK_SIZE_4):
        ks = k_start + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
        # `col` and `val` are read at the same offsets (strides 96, 8, 1).
        mnk_offsets = idx_m[:, None, None] * 96 + idx_n[None, :, None] * 8 + ks[None, None, :]
        # src[test_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
        gathered = tl.load(col + mnk_offsets, None)
        # src[test_indexing.py:N]: cols_3d[:, :, :, None, None] * (P * Q)
        plane_stride = tl.full([], 140, tl.int64)
        plane_base = tl.cast(gathered[:, :, :, None, None] * plane_stride, tl.int64)
        # src[test_indexing.py:N]: + tile_p.index[None, None, :, None] * Q
        # The p term is computed in int32, then given a leading axis and
        # widened to int64 before being added to the plane base.
        row_stride = tl.full([], 14, tl.int32)
        p_term = tl.cast(idx_p[None, None, :, None] * row_stride, tl.int32)
        plane_plus_p = plane_base + tl.cast(p_term[None, :, :, :, :], tl.int64)
        # src[test_indexing.py:N]: + tile_q.index[None, None, None, :]
        q_term = idx_q[None, None, None, :]
        flat_idx = plane_plus_p + tl.cast(q_term[None, :, :, :, :], tl.int64)
        # src[test_indexing.py:N]: B_slice = hl.load(B_flat, [B_indices])
        # Masked-out lanes read as 0.
        b_vals = tl.load(B_flat + flat_idx, gather_mask, other=0)
        # src[test_indexing.py:N]: vals_3d = val[tile_m, tile_n, tile_k]
        weights = tl.load(val + mnk_offsets, None)
        # src[test_indexing.py:N]: contrib = (vals_3d[:, :, :, None, None] * B_slice).sum(dim=2)
        products = weights[:, :, :, None, None] * b_vals
        # src[test_indexing.py:N]: acc = acc + contrib
        acc = acc + tl.cast(tl.sum(products, 2), tl.float32)
    # src[test_indexing.py:N]: C[tile_m, tile_n, tile_p, tile_q] = acc.to(out_dtype)
    out_offsets = idx_m[:, None, None, None] * 1680 + idx_n[None, :, None, None] * 140 + idx_p[None, None, :, None] * 14 + idx_q[None, None, None, :]
    tl.store(C + out_offsets, acc, in_p[None, None, :, None] & in_q[None, None, None, :])
601+
602+
def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
    """Allocate the 4-D output, flatten ``B`` and launch ``_helion_test``.

    Args:
        col: 3-D tensor; its first two dimensions (M, N) size the output.
        val: Tensor whose dtype takes part in choosing the output dtype.
        B: 3-D tensor; its last two dimensions (P, Q) size the output, and it
            is handed to the kernel as ``B.reshape(-1)``.
        _launcher: Callable that launches the kernel; called with the kernel,
            the grid tuple, the kernel arguments and the launch options.

    Returns:
        The ``(M, N, P, Q)`` tensor ``C`` with dtype
        ``torch.promote_types(val.dtype, B.dtype)``, passed to the kernel as
        its output argument.
    """
    # src[test_indexing.py:N]: M, N, K = col.shape
    M, N, _ = col.shape  # K is not needed on the host side
    # src[test_indexing.py:N]: _, P, Q = B.shape
    _, P, Q = B.shape
    # src[test_indexing.py:N]: out_dtype = torch.promote_types(val.dtype, B.dtype)
    out_dtype = torch.promote_types(val.dtype, B.dtype)
    # src[test_indexing.py:N]: C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
    C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
    # src[test_indexing.py:N]: B_flat = B.reshape(-1)  # [K*P*Q]
    B_flat = B.reshape(-1)
    # src[test_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
    _BLOCK_SIZE_0 = 4
    _BLOCK_SIZE_1 = 4
    _BLOCK_SIZE_2 = 4
    _BLOCK_SIZE_3 = 4
    # src[test_indexing.py:N]: for tile_k in hl.tile(K):
    _BLOCK_SIZE_4 = 4
    # NOTE(review): the grid extents 16, 12, 10 and 14 are literals rather than
    # M, N, P and Q; presumably this code was generated for those exact
    # shapes -- confirm before calling it with differently sized inputs.
    grid = (
        triton.cdiv(16, _BLOCK_SIZE_0)
        * triton.cdiv(12, _BLOCK_SIZE_1)
        * triton.cdiv(10, _BLOCK_SIZE_2)
        * triton.cdiv(14, _BLOCK_SIZE_3),
    )
    _launcher(_helion_test, grid, col, B_flat, val, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=1)
    # src[test_indexing.py:N]: return C
    return C
630+
288631
--- assertExpectedJournal(TestIndexing.test_mask_load)
289632
from __future__ import annotations
290633

0 commit comments

Comments
 (0)