1- from numba import cuda , typeof
1+ from numba import cuda , errors , typeof , types
22from numba .cuda .testing import unittest , CUDATestCase
33import numpy as np
44
5- types = (
5+ tested_types = (
66 np .int8 , np .int16 , np .int32 , np .int64 ,
77 np .uint8 , np .uint16 , np .uint32 , np .uint64 ,
88 np .float16 , np .float32 , np .float64 ,
@@ -35,7 +35,7 @@ def f(r, x):
3535 for i in range (len (r )):
3636 r [i ] = operator (x , i )
3737
38- for ty in types :
38+ for ty in tested_types :
3939 with self .subTest (operator = operator , ty = ty ):
4040 x = np .arange (5 ).astype (ty )
4141 r = np .zeros_like (x )
@@ -58,7 +58,7 @@ def f(r, x):
5858 for i in range (len (r )):
5959 operator (r , i , x [i ])
6060
61- for ty in types :
61+ for ty in tested_types :
6262 with self .subTest (operator = operator , ty = ty ):
6363 x = np .arange (5 ).astype (ty )
6464 r = np .zeros_like (x )
@@ -74,6 +74,39 @@ def f(r, x):
7474
7575 self .assertIn (f"st.global.{ modifier } .b{ bitwidth } " , ptx )
7676
77+ def test_bad_indices (self ):
78+ def float_indices (x ):
79+ cuda .ldcs (x , 1.0 )
80+
81+ sig_1d = (types .float32 [::1 ],)
82+
83+ msg = "float64 is not a valid index"
84+ with self .assertRaisesRegex (errors .TypingError , msg ):
85+ cuda .compile_ptx (float_indices , sig_1d )
86+
87+ def too_long_indices (x ):
88+ cuda .ldcs (x , (1 , 2 ))
89+
90+ msg = "Expected 1 indices, got 2"
91+ with self .assertRaisesRegex (errors .TypingError , msg ):
92+ cuda .compile_ptx (too_long_indices , sig_1d )
93+
94+ def too_short_indices_scalar (x ):
95+ cuda .ldcs (x , 1 )
96+
97+ def too_short_indices_tuple (x ):
98+ cuda .ldcs (x , (1 ,))
99+
100+ sig_2d = (types .float32 [:,::1 ],)
101+
102+ msg = "Expected 2 indices, got a scalar"
103+ with self .assertRaisesRegex (errors .TypingError , msg ):
104+ cuda .compile_ptx (too_short_indices_scalar , sig_2d )
105+
106+ msg = "Expected 2 indices, got 1"
107+ with self .assertRaisesRegex (errors .TypingError , msg ):
108+ cuda .compile_ptx (too_short_indices_tuple , sig_2d )
109+
77110
78111if __name__ == '__main__' :
79112 unittest .main ()
0 commit comments