-
Notifications
You must be signed in to change notification settings - Fork 121
Expand file tree
/
Copy pathtest_stream.py
More file actions
57 lines (40 loc) · 1.42 KB
/
test_stream.py
File metadata and controls
57 lines (40 loc) · 1.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
import torch
from torch.testing import make_tensor
import cupy
import cuda.tile as ct
@ct.kernel
def array_copy_1d(x, y, TILE: ct.Constant[int]):
bid = ct.bid(0)
tx = ct.load(x, index=(bid,), shape=(TILE,))
ct.store(y, index=(bid,), tile=tx)
def _test_stream(stream, sync):
x = make_tensor(4096, dtype=torch.float32, device="cuda")
y = torch.zeros_like(x)
torch.cuda.synchronize()
ct.launch(stream, (1,), array_copy_1d, (x, y, 4096))
sync()
torch.testing.assert_close(x, y)
# -- Test PyTorch Stream --
def test_torch_pass_stream():
stream = torch.cuda.Stream()
_test_stream(stream, stream.synchronize)
def test_torch_pass_stream_ptr():
stream = torch.cuda.Stream()
_test_stream(stream.cuda_stream, stream.synchronize)
# -- Test CuPy Stream --
def test_cupy_pass_stream():
stream = cupy.cuda.Stream()
_test_stream(stream, stream.synchronize)
def test_cupy_pass_stream_ptr():
stream = cupy.cuda.Stream()
_test_stream(stream.ptr, stream.synchronize)
# -- Test Numba Stream --
def test_numba_pass_stream(numba_cuda):
stream = numba_cuda.stream()
_test_stream(stream, stream.synchronize)
def test_numba_pass_stream_ptr(numba_cuda):
stream = numba_cuda.stream()
_test_stream(stream.handle.value, stream.synchronize)