Skip to content

Commit cc70f1d

Browse files
Drop to ctypes pointers in NRT when using the NVIDIA binding
1 parent 0c8171a commit cc70f1d

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2553,7 +2553,6 @@ def launch_kernel(cufunc_handle,
25532553
hstream,
25542554
args,
25552555
cooperative=False):
2556-
25572556
param_ptrs = [addressof(arg) for arg in args]
25582557
params = (c_void_p * len(param_ptrs))(*param_ptrs)
25592558

numba_cuda/numba/cuda/runtime/nrt.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from numba.cuda.utils import _readenv
1212

1313

14+
1415
# Check environment variable or config for NRT statistics enablement
1516
NRT_STATS = (
1617
_readenv("NUMBA_CUDA_NRT_STATS", bool, False) or
@@ -117,6 +118,13 @@ def _single_thread_launch(self, module, stream, name, params=()):
117118
if stream is None:
118119
stream = cuda.default_stream()
119120

121+
if config.CUDA_USE_NVIDIA_BINDING:
122+
from numba.cuda.cudadrv.drvapi import cu_device_ptr
123+
from cuda.cuda import CUdeviceptr
124+
params = tuple(
125+
cu_device_ptr.from_address(ptr.getPtr()) if isinstance(ptr, CUdeviceptr) else ptr for ptr in params
126+
)
127+
120128
func = module.get_function(name)
121129
launch_kernel(
122130
func.handle,
@@ -294,6 +302,7 @@ def set_memsys_to_module(self, module, stream=None):
294302
raise RuntimeError(
295303
"Please allocate NRT Memsys first before setting to module.")
296304

305+
297306
self._single_thread_launch(
298307
module,
299308
stream,

0 commit comments

Comments
 (0)