-
Notifications
You must be signed in to change notification settings - Fork 121
Expand file tree
/
Copy pathtest_cudagraph.py
More file actions
75 lines (57 loc) · 1.9 KB
/
test_cudagraph.py
File metadata and controls
75 lines (57 loc) · 1.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
import cuda.tile as ct
import torch
import pytest
@ct.kernel
def add_one(x):
xi = ct.load(x, 0, ())
xi = xi + 1
ct.store(x, 0, xi)
def test_simple():
x = torch.zeros(1, device='cuda')
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
stream = torch.cuda.current_stream()
ct.launch(stream, (1,), add_one, (x,))
assert x.item() == 0
for _ in range(10):
graph.replay()
assert x.item() == 10
@ct.kernel
def matmul_accumulate(x, y, z):
acc = ct.load(z, (0, 0), (16, 16))
for k in range(4):
tx = ct.load(x, (0, k), (16, 4))
ty = ct.load(y, (k, 0), (4, 16))
acc = ct.mma(tx, ty, acc)
ct.store(z, (0, 0), acc)
def test_matmul():
x = torch.ones((16, 16), dtype=torch.float16, device='cuda')
y = torch.ones((16, 16), dtype=torch.float16, device='cuda')
z = torch.zeros_like(x)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
stream = torch.cuda.current_stream()
ct.launch(stream, (1,), matmul_accumulate, (x, y, z))
N = 4
for _ in range(N):
graph.replay()
ref = torch.full((16, 16), 16 * N, dtype=torch.float16, device='cuda')
assert torch.all(z == ref)
@ct.kernel
def list_copy(As, Bs, N: ct.Constant[int]):
for i in range(len(As)):
tx = ct.load(As[i], (0,), (N,))
ct.store(Bs[i], (0,), tile=tx)
def test_list_of_array():
N = 8
x = torch.rand((N, ), device='cuda')
y = torch.empty_like(x)
graph = torch.cuda.CUDAGraph()
with pytest.raises(RuntimeError, match=r'List argument in CUDAGraph isn\'t supported yet'):
with torch.cuda.graph(graph):
ct.launch(torch.cuda.current_stream(),
(1,),
list_copy, ([x,], [y,], N))