-
Notifications
You must be signed in to change notification settings - Fork 120
Expand file tree
/
Copy pathtest_generate.py
More file actions
30 lines (24 loc) · 950 Bytes
/
test_generate.py
File metadata and controls
30 lines (24 loc) · 950 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from math import ceil
import cuda.tile as ct
from util import assert_equal
from conftest import int_dtypes, float_dtypes, dtype_id
# Tile kernel: block `bid` produces the values
#   bid * TILE, bid * TILE + 1, ..., bid * TILE + TILE - 1
# (computed in x's dtype) and stores them at tile index (bid,) of the 1-D
# array `x`.
# TILE is annotated ct.Constant[int]; presumably it must be known when the
# kernel is compiled — TODO confirm against the cuda.tile docs.
# NOTE(review): the store index appears to be tile-granular (test_arange
# launches several blocks and expects x[i] == i overall); confirm against the
# ct.store documentation.
@ct.kernel
def arange(x, TILE: ct.Constant[int]):
    block_idx = ct.bid(0)
    # Cast this block's first global position to x's dtype; the arange below
    # is also built in x.dtype, so the sum stays in a single dtype.
    base = ct.astype(block_idx * TILE, x.dtype)
    values = base + ct.arange(TILE, dtype=x.dtype)
    ct.store(x, index=(block_idx,), tile=values)
@pytest.mark.parametrize("shape", [(128,)])
@pytest.mark.parametrize("tile", [64])
@pytest.mark.parametrize("dtype", int_dtypes + float_dtypes, ids=dtype_id)
def test_arange(shape, dtype, tile):
    """Check that the ``arange`` kernel writes 0, 1, ..., n-1 into a CUDA tensor.

    A zero-initialised tensor of ``shape`` and ``dtype`` is filled by launching
    one block per ``tile`` elements (rounded up), then compared against
    ``torch.arange`` of the same length and dtype.
    """
    num_elems = shape[0]
    out = torch.zeros(shape, dtype=dtype, device='cuda')
    num_blocks = ceil(num_elems / tile)
    launch_grid = (num_blocks, 1, 1)
    stream = torch.cuda.current_stream()
    ct.launch(stream, launch_grid, arange, (out, tile))
    expected = torch.arange(len(out), dtype=dtype, device=out.device)
    assert_equal(out, expected)