Skip to content

Commit 7806707

Browse files
committed
mocking object type using cuda.core objects
1 parent 2827d0d commit 7806707

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

numba_cuda/numba/cuda/codegen.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from numba.cuda.cudadrv.libs import get_cudalib
77
from numba.cuda.cudadrv.linkable_code import LinkableCode
88

9+
from cuda.core.experimental import ObjectCode
10+
911
import os
1012
import subprocess
1113
import tempfile
@@ -250,14 +252,15 @@ def get_cufunc(self):
250252
return cufunc
251253

252254
cubin = self.get_cubin(cc=device.compute_capability)
253-
module = ctx.create_module_image(cubin)
255+
256+
# just a mock, https://github.com/NVIDIA/numba-cuda/pull/133 will
257+
# formalize the object code interface
258+
obj_code = ObjectCode.from_cubin(cubin)
259+
cufunc = obj_code.get_kernel(self._entry_name)
254260

255261
# Init
256262
for init_fn in self._init_functions:
257-
init_fn(module)
258-
259-
# Load
260-
cufunc = module.get_function(self._entry_name)
263+
init_fn(obj_code)
261264

262265
# Populate caches
263266
self._cufunc_cache[device.id] = cufunc

numba_cuda/numba/cuda/tests/cudadrv/test_module_init.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
from numba.cuda.cudadrv.linkable_code import CUSource
55
from numba.cuda.testing import CUDATestCase
66

7-
from cuda.bindings.driver import cuModuleGetGlobal, cuMemcpyHtoD
7+
from cuda.bindings.driver import (
8+
cuModuleGetGlobal,
9+
cuMemcpyHtoD,
10+
cuLibraryGetModule
11+
)
812

913

1014
class TestModuleInitCallback(CUDATestCase):
@@ -21,11 +25,13 @@ def setUp(self):
2125
}
2226
"""
2327

24-
def set_fourty_two(mod):
28+
def set_fourty_two(obj):
2529
# Initialize 42 to global variable `num`
26-
res, dptr, size = cuModuleGetGlobal(
27-
mod.handle.value, "num".encode()
28-
)
30+
culib = obj._handle
31+
res, mod = cuLibraryGetModule(culib)
32+
self.assertEqual(res, 0)
33+
34+
res, dptr, size = cuModuleGetGlobal(mod, "num".encode())
2935
self.assertEqual(res, 0)
3036
self.assertEqual(size, 4)
3137

0 commit comments

Comments
 (0)