-
Notifications
You must be signed in to change notification settings - Fork 121
Expand file tree
/
Copy pathtest_tuple.py
More file actions
27 lines (23 loc) · 841 Bytes
/
test_tuple.py
File metadata and controls
27 lines (23 loc) · 841 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
import cuda.tile as ct
import torch
from util import assert_equal
def test_tuple_concatenation():
    """Check that tuple concatenation of tiles works inside a cuTile kernel.

    The kernel loads three 16-element tiles from ``src`` and builds a 3-tuple
    with ``(first,) + (second, third)``. It writes each tuple element back to
    the tile slot it was loaded from in ``dst``, and records the tuple's length
    in the zero-dimensional tensor ``count``.

    If concatenation preserves both order and length, the output buffer is an
    exact copy of the input and the recorded length is 3.
    """

    @ct.kernel
    def kernel(src, dst, count):
        # Tile-space indices 0, 1 and 2 with tile shape (16,); together they
        # span the 48-element input allocated on the host below.
        first = ct.load(src, (0,), (16,))
        second = ct.load(src, (1,), (16,))
        third = ct.load(src, (2,), (16,))

        # The operation under test: a 1-tuple concatenated with a 2-tuple.
        tiles = (first,) + (second, third)

        # Store each element into the slot it came from, so any reordering
        # or dropped element surfaces as a mismatch against the input.
        ct.store(dst, (0,), tiles[0])
        ct.store(dst, (1,), tiles[1])
        ct.store(dst, (2,), tiles[2])

        # Record the concatenated length; the empty index addresses the
        # zero-dimensional output tensor.
        ct.scatter(count, (), len(tiles))

    source = torch.arange(48, dtype=torch.int32, device="cuda")
    target = torch.zeros((48,), dtype=torch.int32, device="cuda")
    length = torch.zeros((), dtype=torch.int32, device="cuda")

    stream = torch.cuda.current_stream()
    ct.launch(stream, (1,), kernel, (source, target, length))

    assert_equal(target, source)
    assert length.item() == 3