-
Notifications
You must be signed in to change notification settings - Fork 120
Expand file tree
/
Copy pathtest_num_tiles.py
More file actions
43 lines (34 loc) · 1.14 KB
/
test_num_tiles.py
File metadata and controls
43 lines (34 loc) · 1.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
import cuda.tile as ct
import torch
import math
import pytest
# Kernel: query the tile count of `x` along axis 0 using an empty tile
# shape `()` and write it to index 0 of `x`, so the host can read it back.
# NOTE(review): test_num_tiles pairs this kernel with tile_size == 1, i.e. it
# expects the empty shape to count like a size-1 tile -- confirm against the
# cuda.tile documentation.
@ct.kernel
def check_dim_0d(x):
    tile_count = ct.num_tiles(
        x,
        axis=0,
        shape=(),
    )
    ct.store(x, 0, tile=tile_count)
# Kernel: query the tile count of `x` along axis 0 for a tile shape given by
# the compile-time constant `M`, then write that count to index 0 of `x`.
@ct.kernel
def check_dim_1d(x, M: ct.Constant[int]):
    tile_count = ct.num_tiles(
        x,
        axis=0,
        shape=M,
    )
    ct.store(x, 0, tile=tile_count)
# Kernel: query the tile count of `x` along axis 0 for a square (M, M) tile
# shape, where `M` is a compile-time constant, then write that count to
# index (0, 0) of `x`.
@ct.kernel
def check_dim_2d(x, M: ct.Constant[int]):
    tile_count = ct.num_tiles(
        x,
        axis=0,
        shape=(M, M),
    )
    ct.store(x, (0, 0), tile=tile_count)
@pytest.mark.parametrize("shape", [(5,), (10, 10)])
@pytest.mark.parametrize("tile_size", [1, 2])
def test_num_tiles(shape, tile_size):
    """Check that ``ct.num_tiles`` along axis 0 equals ceil(shape[0] / tile_size).

    A zero-filled int32 CUDA tensor of the given ``shape`` is handed to one of
    the ``check_dim_*`` kernels, which stores the tile count into the tensor's
    first element. That element is then read back and compared with the
    expected count.

    Kernel selection:
      * 1-D tensor, ``tile_size == 1`` -> ``check_dim_0d`` (empty tile shape)
      * 1-D tensor, otherwise          -> ``check_dim_1d``
      * 2-D tensor                     -> ``check_dim_2d``
    """
    x = torch.zeros(shape, dtype=torch.int32, device='cuda')
    stream = torch.cuda.current_stream()
    is_one_dim = len(shape) == 1

    # Choose the kernel and its argument tuple; the launch itself is shared.
    if not is_one_dim:
        kernel, kernel_args = check_dim_2d, (x, tile_size)
    elif tile_size == 1:
        kernel, kernel_args = check_dim_0d, (x,)
    else:
        kernel, kernel_args = check_dim_1d, (x, tile_size)

    ct.launch(stream, (1,), kernel, kernel_args)

    # The kernels write their result to the first element of the tensor.
    first_element = x[0] if is_one_dim else x[0][0]
    res = first_element.item()

    expected = math.ceil(shape[0] / tile_size)
    assert res == expected