Skip to content

Commit c479585

Browse files
committed
Start of tests
1 parent 29935f6 commit c479585

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from numba import cuda, typeof
2+
from numba.cuda.testing import unittest, CUDATestCase
3+
import numpy as np
4+
5+
types = (
6+
np.int8, np.int16, np.int32, np.int64,
7+
np.uint8, np.uint16, np.uint32, np.uint64,
8+
np.float16, np.float32, np.float64,
9+
)
10+
11+
complex_types = (
12+
np.complex64, np.complex128
13+
)
14+
15+
load_operators = (
16+
(cuda.ldca, 'ca'),
17+
(cuda.ldcg, 'cg'),
18+
(cuda.ldcs, 'cs'),
19+
(cuda.ldlu, 'lu'),
20+
(cuda.ldcv, 'cv')
21+
)
22+
store_operators = (
23+
(cuda.stcg, 'cg'),
24+
(cuda.stcs, 'cs'),
25+
(cuda.stwb, 'wb'),
26+
(cuda.stwt, 'wt')
27+
)
28+
29+
30+
class TestCacheHints(CUDATestCase):
31+
def test_loads(self):
32+
for operator, modifier in load_operators:
33+
@cuda.jit
34+
def f(r, x):
35+
for i in range(len(r)):
36+
r[i] = operator(x, i)
37+
38+
for ty in types:
39+
with self.subTest(operator=operator, ty=ty):
40+
x = np.arange(5).astype(ty)
41+
r = np.zeros_like(x)
42+
43+
f[1, 1](r, x)
44+
np.testing.assert_equal(r, x)
45+
46+
# Check PTX contains a cache-policy load instruction
47+
numba_type = typeof(x)
48+
bitwidth = numba_type.dtype.bitwidth
49+
sig = (numba_type, numba_type)
50+
ptx, _ = cuda.compile_ptx(f, sig)
51+
52+
self.assertIn(f"ld.global.{modifier}.b{bitwidth}", ptx)
53+
54+
def test_stores(self):
55+
for operator, modifier in store_operators:
56+
@cuda.jit
57+
def f(r, x):
58+
for i in range(len(r)):
59+
operator(r, i, x[i])
60+
61+
for ty in types:
62+
with self.subTest(operator=operator, ty=ty):
63+
x = np.arange(5).astype(ty)
64+
r = np.zeros_like(x)
65+
66+
f[1, 1](r, x)
67+
np.testing.assert_equal(r, x)
68+
69+
# Check PTX contains a cache-policy store instruction
70+
numba_type = typeof(x)
71+
bitwidth = numba_type.dtype.bitwidth
72+
sig = (numba_type, numba_type)
73+
ptx, _ = cuda.compile_ptx(f, sig)
74+
75+
self.assertIn(f"st.global.{modifier}.b{bitwidth}", ptx)
76+
77+
78+
if __name__ == '__main__':
79+
unittest.main()

0 commit comments

Comments
 (0)