@@ -285,6 +285,349 @@ def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor,
285285 # src[test_indexing.py:N]: return out
286286 return out
287287
288+ --- assertExpectedJournal(TestIndexing.test_indirect_indexing_2d_direct_gather)
289+ from __future__ import annotations
290+
291+ import torch
292+ import triton
293+ import triton.language as tl
294+ from helion.runtime import default_launcher as _default_launcher
295+
296+ @triton.jit
297+ def _helion_test(col, B, val, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
298+ # src[test_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
299+ num_blocks_0 = tl.cdiv(32, _BLOCK_SIZE_0)
300+ pid_0 = tl.program_id(0) % num_blocks_0
301+ pid_1 = tl.program_id(0) // num_blocks_0
302+ offset_0 = pid_0 * _BLOCK_SIZE_0
303+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
304+ offset_1 = pid_1 * _BLOCK_SIZE_1
305+ indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
306+ # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
307+ acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
308+ # src[test_indexing.py:N]: for tile_k in hl.tile(K):
309+ # src[test_indexing.py:N]: cols_2d = col[tile_m, tile_k]
310+ # src[test_indexing.py:N]: B_slice = B[cols_2d[:, :, None], tile_n.index[None, None, :]]
311+ # src[test_indexing.py:N-N]: ...
312+ for offset_3 in tl.range(0, 16, _BLOCK_SIZE_2):
313+ indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
314+ acc_copy = acc
315+ acc_copy_0 = acc_copy
316+ # src[test_indexing.py:N]: cols_2d = col[tile_m, tile_k]
317+ cols_2d = tl.load(col + (indices_0[:, None] * 16 + indices_3[None, :] * 1), None)
318+ # src[test_indexing.py:N]: B_slice = B[cols_2d[:, :, None], tile_n.index[None, None, :]]
319+ subscript = cols_2d[:, :, None]
320+ load_1 = indices_1[None, None, :]
321+ B_slice = tl.load(B + (subscript * 24 + load_1 * 1), None)
322+ # src[test_indexing.py:N]: vals_2d = val[tile_m, tile_k]
323+ vals_2d = tl.load(val + (indices_0[:, None] * 16 + indices_3[None, :] * 1), None)
324+ # src[test_indexing.py:N]: contrib = vals_2d[:, :, None] * B_slice
325+ subscript_1 = vals_2d[:, :, None]
326+ v_0 = subscript_1 * B_slice
327+ # src[test_indexing.py:N]: contrib = contrib.sum(dim=1)
328+ contrib_1 = tl.cast(tl.sum(v_0, 1), tl.float32)
329+ # src[test_indexing.py:N]: acc = acc + contrib
330+ acc = acc_copy_0 + contrib_1
331+ # src[test_indexing.py:N]: C[tile_m, tile_n] = acc.to(out_dtype)
332+ tl.store(C + (indices_0[:, None] * 24 + indices_1[None, :] * 1), acc, None)
333+
334+ def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
335+ # src[test_indexing.py:N]: M, K = col.shape
336+ M, K = col.shape
337+ # src[test_indexing.py:N]: _, N = B.shape
338+ _, N = B.shape
339+ # src[test_indexing.py:N]: out_dtype = torch.promote_types(val.dtype, B.dtype)
340+ out_dtype = torch.promote_types(val.dtype, B.dtype)
341+ # src[test_indexing.py:N]: C = torch.empty((M, N), dtype=out_dtype, device=B.device)
342+ C = torch.empty((M, N), dtype=out_dtype, device=B.device)
343+ # src[test_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
344+ _BLOCK_SIZE_0 = 8
345+ _BLOCK_SIZE_1 = 8
346+ # src[test_indexing.py:N]: for tile_k in hl.tile(K):
347+ # src[test_indexing.py:N]: cols_2d = col[tile_m, tile_k]
348+ # src[test_indexing.py:N]: B_slice = B[cols_2d[:, :, None], tile_n.index[None, None, :]]
349+ # src[test_indexing.py:N-N]: ...
350+ _BLOCK_SIZE_2 = 4
351+ # src[test_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
352+ # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
353+ # src[test_indexing.py:N-N]: ...
354+ _RDIM_SIZE_3 = triton.next_power_of_2(_BLOCK_SIZE_1)
355+ _launcher(_helion_test, (triton.cdiv(32, _BLOCK_SIZE_0) * triton.cdiv(24, _BLOCK_SIZE_1),), col, B, val, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
356+ # src[test_indexing.py:N]: return C
357+ return C
358+
359+ --- assertExpectedJournal(TestIndexing.test_indirect_indexing_2d_flat_load)
360+ from __future__ import annotations
361+
362+ import torch
363+ import triton
364+ import triton.language as tl
365+ from helion.runtime import default_launcher as _default_launcher
366+
367+ @triton.jit
368+ def _helion_test(col, B_flat, val, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
369+ # src[test_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
370+ num_blocks_0 = tl.cdiv(32, _BLOCK_SIZE_0)
371+ pid_0 = tl.program_id(0) % num_blocks_0
372+ pid_1 = tl.program_id(0) // num_blocks_0
373+ offset_0 = pid_0 * _BLOCK_SIZE_0
374+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
375+ offset_1 = pid_1 * _BLOCK_SIZE_1
376+ indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
377+ # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
378+ acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
379+ # src[test_indexing.py:N]: for tile_k in hl.tile(K):
380+ # src[test_indexing.py:N]: cols_2d = col[tile_m, tile_k]
381+ # src[test_indexing.py:N]: B_indices = (cols_2d * N)[:, :, None] + tile_n.index[None, None, :]
382+ # src[test_indexing.py:N-N]: ...
383+ for offset_3 in tl.range(0, 16, _BLOCK_SIZE_2):
384+ indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
385+ acc_copy = acc
386+ acc_copy_0 = acc_copy
387+ # src[test_indexing.py:N]: cols_2d = col[tile_m, tile_k]
388+ cols_2d = tl.load(col + (indices_0[:, None] * 16 + indices_3[None, :] * 1), None)
389+ # src[test_indexing.py:N]: B_indices = (cols_2d * N)[:, :, None] + tile_n.index[None, None, :]
390+ v_0 = tl.full([], 24, tl.int64)
391+ v_1 = tl.cast(cols_2d * v_0, tl.int64)
392+ subscript = v_1[:, :, None]
393+ load_1 = indices_1[None, None, :]
394+ v_2 = tl.cast(load_1, tl.int64)
395+ v_3 = subscript + v_2
396+ # src[test_indexing.py:N]: B_slice = hl.load(B_flat, [B_indices])
397+ B_slice = tl.load(B_flat + v_3 * 1, None)
398+ # src[test_indexing.py:N]: vals_2d = val[tile_m, tile_k]
399+ vals_2d = tl.load(val + (indices_0[:, None] * 16 + indices_3[None, :] * 1), None)
400+ # src[test_indexing.py:N]: contrib = vals_2d[:, :, None] * B_slice
401+ subscript_1 = vals_2d[:, :, None]
402+ v_4 = subscript_1 * B_slice
403+ # src[test_indexing.py:N]: contrib = contrib.sum(dim=1)
404+ contrib_1 = tl.cast(tl.sum(v_4, 1), tl.float32)
405+ # src[test_indexing.py:N]: acc = acc + contrib
406+ acc = acc_copy_0 + contrib_1
407+ # src[test_indexing.py:N]: C[tile_m, tile_n] = acc.to(out_dtype)
408+ tl.store(C + (indices_0[:, None] * 24 + indices_1[None, :] * 1), acc, None)
409+
410+ def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
411+ # src[test_indexing.py:N]: M, K = col.shape
412+ M, K = col.shape
413+ # src[test_indexing.py:N]: _, N = B.shape
414+ _, N = B.shape
415+ # src[test_indexing.py:N]: out_dtype = torch.promote_types(val.dtype, B.dtype)
416+ out_dtype = torch.promote_types(val.dtype, B.dtype)
417+ # src[test_indexing.py:N]: C = torch.empty((M, N), dtype=out_dtype, device=B.device)
418+ C = torch.empty((M, N), dtype=out_dtype, device=B.device)
419+ # src[test_indexing.py:N]: B_flat = B.reshape(-1) # [K*N]
420+ B_flat = B.reshape(-1)
421+ # src[test_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
422+ _BLOCK_SIZE_0 = 8
423+ _BLOCK_SIZE_1 = 8
424+ # src[test_indexing.py:N]: for tile_k in hl.tile(K):
425+ # src[test_indexing.py:N]: cols_2d = col[tile_m, tile_k]
426+ # src[test_indexing.py:N]: B_indices = (cols_2d * N)[:, :, None] + tile_n.index[None, None, :]
427+ # src[test_indexing.py:N-N]: ...
428+ _BLOCK_SIZE_2 = 4
429+ # src[test_indexing.py:N]: for tile_m, tile_n in hl.tile([M, N]):
430+ # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
431+ # src[test_indexing.py:N-N]: ...
432+ _RDIM_SIZE_3 = triton.next_power_of_2(_BLOCK_SIZE_1)
433+ _launcher(_helion_test, (triton.cdiv(32, _BLOCK_SIZE_0) * triton.cdiv(24, _BLOCK_SIZE_1),), col, B_flat, val, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=1)
434+ # src[test_indexing.py:N]: return C
435+ return C
436+
437+ --- assertExpectedJournal(TestIndexing.test_indirect_indexing_3d_direct_gather)
438+ from __future__ import annotations
439+
440+ import torch
441+ import triton
442+ import triton.language as tl
443+ from helion.runtime import default_launcher as _default_launcher
444+
445+ @triton.jit
446+ def _helion_test(col, B, val, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
447+ # src[test_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
448+ num_blocks_0 = tl.cdiv(16, _BLOCK_SIZE_0)
449+ num_blocks_1 = tl.cdiv(12, _BLOCK_SIZE_1)
450+ num_blocks_2 = tl.cdiv(10, _BLOCK_SIZE_2)
451+ pid_0 = tl.program_id(0) % num_blocks_0
452+ pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
453+ pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1) % num_blocks_2
454+ pid_3 = tl.program_id(0) // (num_blocks_0 * num_blocks_1 * num_blocks_2)
455+ offset_0 = pid_0 * _BLOCK_SIZE_0
456+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
457+ offset_1 = pid_1 * _BLOCK_SIZE_1
458+ indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
459+ offset_2 = pid_2 * _BLOCK_SIZE_2
460+ indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
461+ mask_2 = indices_2 < 10
462+ offset_3 = pid_3 * _BLOCK_SIZE_3
463+ indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32)
464+ mask_3 = indices_3 < 14
465+ # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
466+ acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32)
467+ # src[test_indexing.py:N]: for tile_k in hl.tile(K):
468+ # src[test_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
469+ # src[test_indexing.py:N]: B_slice = B[
470+ # src[test_indexing.py:N-N]: ...
471+ for offset_5 in tl.range(0, 8, _BLOCK_SIZE_4):
472+ indices_5 = offset_5 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
473+ acc_copy = acc
474+ acc_copy_0 = acc_copy
475+ # src[test_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
476+ cols_3d = tl.load(col + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 8 + indices_5[None, None, :] * 1), None)
477+ # src[test_indexing.py:N]: cols_3d[:, :, :, None, None],
478+ subscript = cols_3d[:, :, :, None, None]
479+ # src[test_indexing.py:N]: tile_p.index[None, None, :, None],
480+ load_1 = indices_2[None, None, :, None]
481+ # src[test_indexing.py:N]: tile_q.index[None, None, None, :],
482+ load_2 = indices_3[None, None, None, :]
483+ # src[test_indexing.py:N]: B_slice = B[
484+ # src[test_indexing.py:N]: cols_3d[:, :, :, None, None],
485+ # src[test_indexing.py:N]: tile_p.index[None, None, :, None],
486+ # src[test_indexing.py:N-N]: ...
487+ B_slice = tl.load(B + (subscript * 140 + load_1 * 14 + load_2 * 1), mask_2[None, None, None, :, None] & mask_3[None, None, None, None, :], other=0)
488+ # src[test_indexing.py:N]: vals_3d = val[tile_m, tile_n, tile_k]
489+ vals_3d = tl.load(val + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 8 + indices_5[None, None, :] * 1), None)
490+ # src[test_indexing.py:N]: contrib = vals_3d[:, :, :, None, None] * B_slice
491+ subscript_1 = vals_3d[:, :, :, None, None]
492+ v_0 = subscript_1 * B_slice
493+ # src[test_indexing.py:N]: contrib = contrib.sum(dim=2)
494+ contrib_1 = tl.cast(tl.sum(v_0, 2), tl.float32)
495+ # src[test_indexing.py:N]: acc = acc + contrib
496+ acc = acc_copy_0 + contrib_1
497+ # src[test_indexing.py:N]: C[tile_m, tile_n, tile_p, tile_q] = acc.to(out_dtype)
498+ tl.store(C + (indices_0[:, None, None, None] * 1680 + indices_1[None, :, None, None] * 140 + indices_2[None, None, :, None] * 14 + indices_3[None, None, None, :] * 1), acc, mask_2[None, None, :, None] & mask_3[None, None, None, :])
499+
500+ def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
501+ # src[test_indexing.py:N]: M, N, K = col.shape
502+ M, N, K = col.shape
503+ # src[test_indexing.py:N]: _, P, Q = B.shape
504+ _, P, Q = B.shape
505+ # src[test_indexing.py:N]: out_dtype = torch.promote_types(val.dtype, B.dtype)
506+ out_dtype = torch.promote_types(val.dtype, B.dtype)
507+ # src[test_indexing.py:N]: C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
508+ C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
509+ # src[test_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
510+ _BLOCK_SIZE_0 = 4
511+ _BLOCK_SIZE_1 = 4
512+ _BLOCK_SIZE_2 = 4
513+ _BLOCK_SIZE_3 = 4
514+ # src[test_indexing.py:N]: for tile_k in hl.tile(K):
515+ # src[test_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
516+ # src[test_indexing.py:N]: B_slice = B[
517+ # src[test_indexing.py:N-N]: ...
518+ _BLOCK_SIZE_4 = 4
519+ # src[test_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
520+ # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
521+ # src[test_indexing.py:N-N]: ...
522+ _RDIM_SIZE_5 = triton.next_power_of_2(_BLOCK_SIZE_2)
523+ _launcher(_helion_test, (triton.cdiv(16, _BLOCK_SIZE_0) * triton.cdiv(12, _BLOCK_SIZE_1) * triton.cdiv(10, _BLOCK_SIZE_2) * triton.cdiv(14, _BLOCK_SIZE_3),), col, B, val, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=1)
524+ # src[test_indexing.py:N]: return C
525+ return C
526+
527+ --- assertExpectedJournal(TestIndexing.test_indirect_indexing_3d_flat_load)
528+ from __future__ import annotations
529+
530+ import torch
531+ import triton
532+ import triton.language as tl
533+ from helion.runtime import default_launcher as _default_launcher
534+
535+ @triton.jit
536+ def _helion_test(col, B_flat, val, C, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
537+ # src[test_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
538+ num_blocks_0 = tl.cdiv(16, _BLOCK_SIZE_0)
539+ num_blocks_1 = tl.cdiv(12, _BLOCK_SIZE_1)
540+ num_blocks_2 = tl.cdiv(10, _BLOCK_SIZE_2)
541+ pid_0 = tl.program_id(0) % num_blocks_0
542+ pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
543+ pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1) % num_blocks_2
544+ pid_3 = tl.program_id(0) // (num_blocks_0 * num_blocks_1 * num_blocks_2)
545+ offset_0 = pid_0 * _BLOCK_SIZE_0
546+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
547+ offset_1 = pid_1 * _BLOCK_SIZE_1
548+ indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
549+ offset_2 = pid_2 * _BLOCK_SIZE_2
550+ indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
551+ mask_2 = indices_2 < 10
552+ offset_3 = pid_3 * _BLOCK_SIZE_3
553+ indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32)
554+ mask_3 = indices_3 < 14
555+ # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
556+ acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32)
557+ # src[test_indexing.py:N]: for tile_k in hl.tile(K):
558+ # src[test_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
559+ # src[test_indexing.py:N]: B_indices = (
560+ # src[test_indexing.py:N-N]: ...
561+ for offset_5 in tl.range(0, 8, _BLOCK_SIZE_4):
562+ indices_5 = offset_5 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
563+ acc_copy = acc
564+ acc_copy_0 = acc_copy
565+ # src[test_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
566+ cols_3d = tl.load(col + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 8 + indices_5[None, None, :] * 1), None)
567+ # src[test_indexing.py:N]: cols_3d[:, :, :, None, None] * (P * Q)
568+ subscript = cols_3d[:, :, :, None, None]
569+ v_0 = tl.full([], 140, tl.int64)
570+ v_1 = tl.cast(subscript * v_0, tl.int64)
571+ # src[test_indexing.py:N]: + tile_p.index[None, None, :, None] * Q
572+ load_1 = indices_2[None, None, :, None]
573+ v_2 = tl.full([], 14, tl.int32)
574+ v_3 = tl.cast(load_1 * v_2, tl.int32)
575+ # src[test_indexing.py:N]: cols_3d[:, :, :, None, None] * (P * Q)
576+ # src[test_indexing.py:N]: + tile_p.index[None, None, :, None] * Q
577+ v_4 = v_3[None, :, :, :, :]
578+ v_5 = tl.cast(v_4, tl.int64)
579+ v_6 = v_1 + v_5
580+ # src[test_indexing.py:N]: + tile_q.index[None, None, None, :]
581+ load_2 = indices_3[None, None, None, :]
582+ # src[test_indexing.py:N]: cols_3d[:, :, :, None, None] * (P * Q)
583+ # src[test_indexing.py:N]: + tile_p.index[None, None, :, None] * Q
584+ # src[test_indexing.py:N]: + tile_q.index[None, None, None, :]
585+ v_7 = load_2[None, :, :, :, :]
586+ v_8 = tl.cast(v_7, tl.int64)
587+ v_9 = v_6 + v_8
588+ # src[test_indexing.py:N]: B_slice = hl.load(B_flat, [B_indices])
589+ B_slice = tl.load(B_flat + v_9 * 1, mask_2[None, None, None, :, None] & mask_3[None, None, None, None, :], other=0)
590+ # src[test_indexing.py:N]: vals_3d = val[tile_m, tile_n, tile_k]
591+ vals_3d = tl.load(val + (indices_0[:, None, None] * 96 + indices_1[None, :, None] * 8 + indices_5[None, None, :] * 1), None)
592+ # src[test_indexing.py:N]: contrib = vals_3d[:, :, :, None, None] * B_slice
593+ subscript_1 = vals_3d[:, :, :, None, None]
594+ v_10 = subscript_1 * B_slice
595+ # src[test_indexing.py:N]: contrib = contrib.sum(dim=2)
596+ contrib_1 = tl.cast(tl.sum(v_10, 2), tl.float32)
597+ # src[test_indexing.py:N]: acc = acc + contrib
598+ acc = acc_copy_0 + contrib_1
599+ # src[test_indexing.py:N]: C[tile_m, tile_n, tile_p, tile_q] = acc.to(out_dtype)
600+ tl.store(C + (indices_0[:, None, None, None] * 1680 + indices_1[None, :, None, None] * 140 + indices_2[None, None, :, None] * 14 + indices_3[None, None, None, :] * 1), acc, mask_2[None, None, :, None] & mask_3[None, None, None, :])
601+
602+ def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher):
603+ # src[test_indexing.py:N]: M, N, K = col.shape
604+ M, N, K = col.shape
605+ # src[test_indexing.py:N]: _, P, Q = B.shape
606+ _, P, Q = B.shape
607+ # src[test_indexing.py:N]: out_dtype = torch.promote_types(val.dtype, B.dtype)
608+ out_dtype = torch.promote_types(val.dtype, B.dtype)
609+ # src[test_indexing.py:N]: C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
610+ C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
611+ # src[test_indexing.py:N]: B_flat = B.reshape(-1) # [K*P*Q]
612+ B_flat = B.reshape(-1)
613+ # src[test_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
614+ _BLOCK_SIZE_0 = 4
615+ _BLOCK_SIZE_1 = 4
616+ _BLOCK_SIZE_2 = 4
617+ _BLOCK_SIZE_3 = 4
618+ # src[test_indexing.py:N]: for tile_k in hl.tile(K):
619+ # src[test_indexing.py:N]: cols_3d = col[tile_m, tile_n, tile_k]
620+ # src[test_indexing.py:N]: B_indices = (
621+ # src[test_indexing.py:N-N]: ...
622+ _BLOCK_SIZE_4 = 4
623+ # src[test_indexing.py:N]: for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
624+ # src[test_indexing.py:N]: acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
625+ # src[test_indexing.py:N-N]: ...
626+ _RDIM_SIZE_5 = triton.next_power_of_2(_BLOCK_SIZE_2)
627+ _launcher(_helion_test, (triton.cdiv(16, _BLOCK_SIZE_0) * triton.cdiv(12, _BLOCK_SIZE_1) * triton.cdiv(10, _BLOCK_SIZE_2) * triton.cdiv(14, _BLOCK_SIZE_3),), col, B_flat, val, C, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=1)
628+ # src[test_indexing.py:N]: return C
629+ return C
630+
288631--- assertExpectedJournal(TestIndexing.test_mask_load)
289632from __future__ import annotations
290633
0 commit comments