Skip to content

Commit f0a7672

Browse files
test
1 parent 1c5d05a commit f0a7672

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ def _prepare_args(self, ty, val, stream, retr, kernelargs):
542542
ty, val = extension.prepare_args(
543543
ty, val, stream=stream, retr=retr
544544
)
545-
elif handler := _arg_handlers.get(ty):
545+
elif handler := _arg_handlers.get(type(val)):
546546
ty, val = handler.prepare_args(ty, val, stream=stream, retr=retr)
547547

548548
if isinstance(ty, types.Array):

numba_cuda/numba/cuda/tests/cudapy/test_extending.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,5 +302,35 @@ def foo(r, x):
302302
np.testing.assert_equal(r, x * 2)
303303

304304

305+
@skip_on_cudasim("Extensions not supported in the simulator")
306+
class TestArgHandlerRegistration(CUDATestCase):
307+
def test_register_arg_handler(self):
308+
from numba.cuda.dispatcher import register_arg_handler, ArgHandlerBase
309+
310+
class NumpyArrayWrapper:
311+
def __init__(self, arr):
312+
self.arr = arr
313+
314+
class NumpyArrayWrapperArgHandler(ArgHandlerBase):
315+
def prepare_args(self, ty, val, **kwargs):
316+
return types.int32[::1], val.arr
317+
318+
register_arg_handler(
319+
NumpyArrayWrapperArgHandler(), (NumpyArrayWrapper,)
320+
)
321+
322+
@cuda.jit("void(int32[::1])")
323+
def kernel(arr):
324+
i = cuda.grid(1)
325+
if i < arr.size:
326+
arr[i] += 1
327+
328+
arr = np.zeros(10, dtype=np.int32)
329+
wrapped_arr = NumpyArrayWrapper(arr)
330+
331+
kernel.forall(len(arr))(wrapped_arr)
332+
np.testing.assert_equal(arr, np.ones(10, dtype=np.int32))
333+
334+
305335
if __name__ == "__main__":
306336
unittest.main()

0 commit comments

Comments
 (0)