Skip to content

Commit c643a92

Browse files
committed
Validate indices
1 parent cb35c60 commit c643a92

File tree

2 files changed

+64
-10
lines changed

2 files changed

+64
-10
lines changed

numba_cuda/numba/cuda/cache_hints.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,35 @@ def stwt(array, i, value):
6161
}
6262

6363

64+
def _validate_arguments(instruction, array, index):
65+
if not isinstance(array, types.Array):
66+
msg = f"{instruction} operates on arrays. Got type {array}"
67+
raise NumbaTypeError(msg)
68+
69+
valid_index = False
70+
71+
if isinstance(index, types.Integer):
72+
if array.ndim != 1:
73+
msg = f"Expected {array.ndim} indices, got a scalar"
74+
raise NumbaTypeError(msg)
75+
valid_index = True
76+
77+
if isinstance(index, types.UniTuple):
78+
if index.count != array.ndim:
79+
msg = f"Expected {array.ndim} indices, got {index.count}"
80+
raise NumbaTypeError(msg)
81+
82+
if all([isinstance(t, types.Integer) for t in index.dtype]):
83+
valid_index = True
84+
85+
if not valid_index:
86+
raise NumbaTypeError(f"{index} is not a valid index")
87+
88+
6489
def ld_cache_operator(operator):
6590
@intrinsic
6691
def impl(typingctx, array, index):
67-
if not isinstance(array, types.Array):
68-
msg = f"ldcs operates on arrays. Got type {array}"
69-
raise NumbaTypeError(msg)
92+
_validate_arguments(f"ld{operator}", array, index)
7093

7194
# Need to validate bitwidth
7295

@@ -111,9 +134,7 @@ def codegen(context, builder, sig, args):
111134
def st_cache_operator(operator):
112135
@intrinsic
113136
def impl(typingctx, array, index, value):
114-
if not isinstance(array, types.Array):
115-
msg = f"ldcs operates on arrays. Got type {array}"
116-
raise NumbaTypeError(msg)
137+
_validate_arguments(f"st{operator}", array, index)
117138

118139
# Need to validate bitwidth
119140

numba_cuda/numba/cuda/tests/cudapy/test_cache_hints.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from numba import cuda, typeof
1+
from numba import cuda, errors, typeof, types
22
from numba.cuda.testing import unittest, CUDATestCase
33
import numpy as np
44

5-
types = (
5+
tested_types = (
66
np.int8, np.int16, np.int32, np.int64,
77
np.uint8, np.uint16, np.uint32, np.uint64,
88
np.float16, np.float32, np.float64,
@@ -35,7 +35,7 @@ def f(r, x):
3535
for i in range(len(r)):
3636
r[i] = operator(x, i)
3737

38-
for ty in types:
38+
for ty in tested_types:
3939
with self.subTest(operator=operator, ty=ty):
4040
x = np.arange(5).astype(ty)
4141
r = np.zeros_like(x)
@@ -58,7 +58,7 @@ def f(r, x):
5858
for i in range(len(r)):
5959
operator(r, i, x[i])
6060

61-
for ty in types:
61+
for ty in tested_types:
6262
with self.subTest(operator=operator, ty=ty):
6363
x = np.arange(5).astype(ty)
6464
r = np.zeros_like(x)
@@ -74,6 +74,39 @@ def f(r, x):
7474

7575
self.assertIn(f"st.global.{modifier}.b{bitwidth}", ptx)
7676

77+
def test_bad_indices(self):
78+
def float_indices(x):
79+
cuda.ldcs(x, 1.0)
80+
81+
sig_1d = (types.float32[::1],)
82+
83+
msg = "float64 is not a valid index"
84+
with self.assertRaisesRegex(errors.TypingError, msg):
85+
cuda.compile_ptx(float_indices, sig_1d)
86+
87+
def too_long_indices(x):
88+
cuda.ldcs(x, (1, 2))
89+
90+
msg = "Expected 1 indices, got 2"
91+
with self.assertRaisesRegex(errors.TypingError, msg):
92+
cuda.compile_ptx(too_long_indices, sig_1d)
93+
94+
def too_short_indices_scalar(x):
95+
cuda.ldcs(x, 1)
96+
97+
def too_short_indices_tuple(x):
98+
cuda.ldcs(x, (1,))
99+
100+
sig_2d = (types.float32[:,::1],)
101+
102+
msg = "Expected 2 indices, got a scalar"
103+
with self.assertRaisesRegex(errors.TypingError, msg):
104+
cuda.compile_ptx(too_short_indices_scalar, sig_2d)
105+
106+
msg = "Expected 2 indices, got 1"
107+
with self.assertRaisesRegex(errors.TypingError, msg):
108+
cuda.compile_ptx(too_short_indices_tuple, sig_2d)
109+
77110

78111
if __name__ == '__main__':
79112
unittest.main()

0 commit comments

Comments
 (0)