-
Notifications
You must be signed in to change notification settings - Fork 122
Expand file tree
/
Copy pathtest_loop_split.py
More file actions
32 lines (24 loc) · 969 Bytes
/
test_loop_split.py
File metadata and controls
32 lines (24 loc) · 969 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
import torch
import cuda.tile as ct
from cuda.tile._ir.ops import Loop
from cuda.tile._compiler_options import CompilerOptions
from cuda.tile._compile import compile_tile
from util import assert_equal
@ct.kernel
def split_ge_kernel(x):
    # Fixture kernel for the loop-split pass: a single range loop whose body
    # branches on `i >= 3`. The compiler is expected to split the loop at the
    # i == 3 boundary, yielding two Loop IR ops (counted in test_split_ge).
    #
    # NOTE(review): keep this body's exact shape — the test below asserts on
    # the number of Loop ops in the final IR, which depends on this structure.
    for i in range(x.shape[0]):
        val = i
        if i >= 3:
            # Indices past the split point get scaled; indices 0..2 are stored
            # unchanged, so the output is [0, 1, 2, 30, 40, ...].
            val *= 10
        ct.store(x, i, val)
def test_split_ge():
    """Check that the compiler splits the kernel's loop at the `i >= 3` branch.

    Two things are verified: the final IR contains exactly two Loop ops
    (the loop was split at the boundary), and the launched kernel still
    produces the correct output values.
    """
    buf = torch.zeros(10, dtype=torch.int32, device="cuda")

    # Compile the kernel and count Loop ops in the final IR: the single
    # source-level loop should have been split into two.
    final_ir = compile_tile(split_ge_kernel._pyfunc, (buf,), CompilerOptions()).final_ir
    n_loops = sum(1 for op in final_ir.traverse() if isinstance(op, Loop))
    assert n_loops == 2

    # Run the kernel and check the computed values are unaffected by the split.
    ct.launch(torch.cuda.current_stream(), (1,), split_ge_kernel, (buf,))
    expected = torch.tensor(
        [0, 1, 2, 30, 40, 50, 60, 70, 80, 90], dtype=torch.int32, device="cuda"
    )
    assert_equal(buf, expected)