55from textwrap import dedent
66
77from numba import cuda
8- from numba .cuda import uint32 , uint64 , float32 , float64
8+ from numba .cuda import uint32 , uint64 , float32 , float64 , int32
99from numba .cuda .testing import unittest , CUDATestCase , cc_X_or_above
1010from numba .cuda .core import config
1111
@@ -239,19 +239,19 @@ def atomic_add_double_3(ary):
239239
240240def atomic_sub (ary ):
241241 atomic_binary_1dim_shared (
242- ary , ary , 1 , uint32 , 32 , cuda .atomic .sub , atomic_cast_none , 0 , False
242+ ary , ary , 1 , int32 , 32 , cuda .atomic .sub , atomic_cast_none , 0 , False
243243 )
244244
245245
246246def atomic_sub2 (ary ):
247247 atomic_binary_2dim_shared (
248- ary , 1 , uint32 , (4 , 8 ), cuda .atomic .sub , atomic_cast_none , False
248+ ary , 1 , int32 , (4 , 8 ), cuda .atomic .sub , atomic_cast_none , False
249249 )
250250
251251
252252def atomic_sub3 (ary ):
253253 atomic_binary_2dim_shared (
254- ary , 1 , uint32 , (4 , 8 ), cuda .atomic .sub , atomic_cast_to_uint64 , False
254+ ary , 1 , int32 , (4 , 8 ), cuda .atomic .sub , atomic_cast_to_uint64 , False
255255 )
256256
257257
@@ -789,7 +789,7 @@ def test_atomic_add_double_global_3(self):
789789 self .assertCorrectFloat64Atomics (cuda_func , shared = False )
790790
791791 def test_atomic_sub (self ):
792- ary = np .random .randint (0 , 32 , size = 32 ). astype ( np .int32 )
792+ ary = np .random .randint (0 , 32 , size = 32 , dtype = np .int32 )
793793 orig = ary .copy ()
794794 cuda_atomic_sub = cuda .jit ("void(int32[:])" )(atomic_sub )
795795 cuda_atomic_sub [1 , 32 ](ary )
@@ -801,16 +801,16 @@ def test_atomic_sub(self):
801801 self .assertTrue (np .all (ary == gold ))
802802
803803 def test_atomic_sub2 (self ):
804- ary = np .random .randint (0 , 32 , size = 32 ). astype ( np . uint32 ). reshape ( 4 , 8 )
804+ ary = np .random .randint (0 , 32 , size = ( 4 , 8 ), dtype = np . int32 )
805805 orig = ary .copy ()
806- cuda_atomic_sub2 = cuda .jit ("void(uint32 [:,:])" )(atomic_sub2 )
806+ cuda_atomic_sub2 = cuda .jit ("void(int32 [:,:])" )(atomic_sub2 )
807807 cuda_atomic_sub2 [1 , (4 , 8 )](ary )
808808 self .assertTrue (np .all (ary == orig - 1 ))
809809
810810 def test_atomic_sub3 (self ):
811- ary = np .random .randint (0 , 32 , size = 32 ). astype ( np . uint32 ). reshape ( 4 , 8 )
811+ ary = np .random .randint (0 , 32 , size = ( 4 , 8 ), dtype = np . uint32 )
812812 orig = ary .copy ()
813- cuda_atomic_sub3 = cuda .jit ("void(uint32 [:,:])" )(atomic_sub3 )
813+ cuda_atomic_sub3 = cuda .jit ("void(int32 [:,:])" )(atomic_sub3 )
814814 cuda_atomic_sub3 [1 , (4 , 8 )](ary )
815815 self .assertTrue (np .all (ary == orig - 1 ))
816816
0 commit comments