-
Notifications
You must be signed in to change notification settings - Fork 121
Expand file tree
/
Copy path: test_fma.py
More file actions
109 lines (93 loc) · 3.81 KB
/
test_fma.py
File metadata and controls
109 lines (93 loc) · 3.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from math import ceil
import cuda.tile as ct
from util import assert_close, filecheck, get_bytecode
from torch.testing import make_tensor
def mul_add_kernel(x, y, z, output,
                   TILE: ct.Constant[int],
                   DIM: ct.Constant[int]):
    """Compute output = x * y + z, one (TILE, DIM) tile per block.

    The arithmetic is written as a single expression so the compiler
    can fuse the mul and add into one fma op (checked by test_fma).
    """
    # Block index along grid dim 0 selects which row-tile to process.
    bidx = ct.bid(0)
    tx = ct.load(x, index=(bidx, 0), shape=(TILE, DIM))
    ty = ct.load(y, index=(bidx, 0), shape=(TILE, DIM))
    tz = ct.load(z, index=(bidx, 0), shape=(TILE, DIM))
    # NOTE(review): keep this as one expression — the FMA FileCheck
    # test inspects the exact IR shape this line produces.
    output_tile = tx * ty + tz
    ct.store(output, index=(bidx, 0), tile=output_tile)
def mul_add_kernel_local_var(x, y, z, output,
                             TILE: ct.Constant[int],
                             DIM: ct.Constant[int]):
    """Compute output = x * y + z via an intermediate local variable.

    Deliberately differs from mul_add_kernel only in routing the
    product through `tmp` — exercises FMA fusion across a named
    temporary rather than within a single expression.
    """
    bidx = ct.bid(0)
    tx = ct.load(x, index=(bidx, 0), shape=(TILE, DIM))
    ty = ct.load(y, index=(bidx, 0), shape=(TILE, DIM))
    tz = ct.load(z, index=(bidx, 0), shape=(TILE, DIM))
    # The two-statement form is the point of this kernel; do not
    # collapse it into one expression.
    tmp = tx * ty
    output_tile = tmp + tz
    ct.store(output, index=(bidx, 0), tile=output_tile)
def mul_sub_kernel(x, y, z, output,
                   TILE: ct.Constant[int],
                   DIM: ct.Constant[int]):
    """Compute output = x * y - z, one (TILE, DIM) tile per block.

    The mul/sub pair is expected to fuse into a single fma op
    (checked by test_fma's CHECK-NOT: subf directive).
    """
    bidx = ct.bid(0)
    tx = ct.load(x, index=(bidx, 0), shape=(TILE, DIM))
    ty = ct.load(y, index=(bidx, 0), shape=(TILE, DIM))
    tz = ct.load(z, index=(bidx, 0), shape=(TILE, DIM))
    # Single-expression form; the FMA test inspects the resulting IR.
    output_tile = tx * ty - tz
    ct.store(output, index=(bidx, 0), tile=output_tile)
def add_mul_kernel(x, y, z, output,
                   TILE: ct.Constant[int],
                   DIM: ct.Constant[int]):
    """Compute output = z + x * y, one (TILE, DIM) tile per block.

    Same math as mul_add_kernel but with the addend on the LEFT of
    the product — exercises FMA fusion for the commuted operand order.
    """
    bidx = ct.bid(0)
    tx = ct.load(x, index=(bidx, 0), shape=(TILE, DIM))
    ty = ct.load(y, index=(bidx, 0), shape=(TILE, DIM))
    tz = ct.load(z, index=(bidx, 0), shape=(TILE, DIM))
    # Operand order is deliberate; do not rewrite as tx * ty + tz.
    output_tile = tz + tx * ty
    ct.store(output, index=(bidx, 0), tile=output_tile)
@ct.kernel
def mul_add_same_operand_kernel(x, output,
                                TILE: ct.Constant[int],
                                DIM: ct.Constant[int]):
    """Compute output = 2 * x * x via tmp = x*x; tmp + tmp.

    Negative test case for FMA fusion: both addends are the SAME
    product result, so fusing mul+add would reference a value the
    rewrite deletes — the pass must skip this pattern (see
    test_fma_skip_when_new_op_uses_deleted_var).
    """
    bidx = ct.bid(0)
    tx = ct.load(x, index=(bidx, 0), shape=(TILE, DIM))
    # tmp feeds BOTH operands of the add; keep this exact structure.
    tmp = tx * tx
    output_tile = tmp + tmp
    ct.store(output, index=(bidx, 0), tile=output_tile)
def test_fma_skip_when_new_op_uses_deleted_var():
    """Run the same-operand kernel end-to-end and verify it still
    computes 2 * x * x (i.e. skipping FMA fusion preserved semantics)."""
    rows, cols = 128, 32
    TILE = 32
    x = make_tensor((rows, cols), dtype=torch.float32, device='cuda')
    output = make_tensor((rows, cols), dtype=torch.float32, device='cuda')
    launch_grid = (ceil(rows / TILE), 1, 1)
    ct.launch(torch.cuda.current_stream(), launch_grid,
              mul_add_same_operand_kernel, (x, output, TILE, cols))
    assert_close(output, 2 * x * x, atol=1e-3, rtol=1e-3)
@pytest.mark.use_mlir
@pytest.mark.parametrize(
    "kernel, kernel_ref",
    [
        pytest.param(mul_add_kernel_local_var, lambda x, y, z: x * y + z),
        pytest.param(mul_add_kernel, lambda x, y, z: x * y + z),
        pytest.param(mul_sub_kernel, lambda x, y, z: x * y - z),
        pytest.param(add_mul_kernel, lambda x, y, z: z + x * y),
    ]
)
def test_fma(kernel, kernel_ref):
    """Compile `kernel`, assert its bytecode contains a single fused fma
    (and no residual mulf/addf/subf), then launch it and compare the
    result against the eager torch reference `kernel_ref`."""
    shape = (128, 32)
    TILE = 32
    x, y, z, output = (
        make_tensor(shape, dtype=torch.float32, device='cuda')
        for _ in range(4)
    )
    jitted = ct.kernel(kernel)
    launch_args = (x, y, z, output, TILE, shape[1])
    # Static check: the mul and add/sub must have been fused into fma.
    bytecode = get_bytecode(jitted, launch_args)
    check_directive = """\
// CHECK: %[[VAL:.*]] = fma
// CHECK-NOT: mulf
// CHECK-NOT: addf
// CHECK-NOT: subf
"""
    filecheck(bytecode, check_directive)
    # Dynamic check: fused kernel still computes the right values.
    launch_grid = (ceil(shape[0] / TILE), 1, 1)
    ct.launch(torch.cuda.current_stream(), launch_grid, jitted, launch_args)
    assert_close(output, kernel_ref(x, y, z), atol=1e-3, rtol=1e-3)