Skip to content

Commit b8c03c4

Browse files
committed
initial
1 parent 9479123 commit b8c03c4

File tree

2 files changed

+77
-17
lines changed

2 files changed

+77
-17
lines changed

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@
4242
from .linkable_code import LinkableCode, LTOIR, Fatbin, Object
4343
from numba.cuda.cudadrv import enums, drvapi, nvrtc
4444

45+
try:
46+
from pynvjitlink.api import NvJitLinker, NvJitLinkError
47+
except ImportError:
48+
NvJitLinker, NvJitLinkError = None, None
49+
4550
USE_NV_BINDING = config.CUDA_USE_NVIDIA_BINDING
4651

4752
if USE_NV_BINDING:
@@ -92,20 +97,6 @@ def _readenv(name, ctor, default):
9297
if not hasattr(config, "CUDA_ENABLE_PYNVJITLINK"):
9398
config.CUDA_ENABLE_PYNVJITLINK = ENABLE_PYNVJITLINK
9499

95-
if ENABLE_PYNVJITLINK:
96-
try:
97-
from pynvjitlink.api import NvJitLinker, NvJitLinkError
98-
except ImportError:
99-
raise ImportError(
100-
"Using pynvjitlink requires the pynvjitlink package to be available"
101-
)
102-
103-
if config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY:
104-
raise ValueError(
105-
"Can't set CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY and "
106-
"CUDA_ENABLE_PYNVJITLINK at the same time"
107-
)
108-
109100

110101
def make_logger():
111102
logger = logging.getLogger(__name__)
@@ -3061,6 +3052,17 @@ def __init__(
30613052
lto=False,
30623053
additional_flags=None,
30633054
):
3055+
if NvJitLinker is None:
3056+
raise ImportError(
3057+
"Using pynvjitlink requires the pynvjitlink package to be "
3058+
"available"
3059+
)
3060+
3061+
if config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY:
3062+
raise ValueError(
3063+
"Can't set CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY and "
3064+
"CUDA_ENABLE_PYNVJITLINK at the same time"
3065+
)
30643066

30653067
if cc is None:
30663068
raise RuntimeError("PyNvJitLinker requires CC to be specified")

numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@
22
from numba.cuda.testing import skip_on_cudasim
33
from numba.cuda.testing import CUDATestCase
44
from numba.cuda.cudadrv.driver import PyNvJitLinker
5+
from numba.cuda import get_current_device
6+
7+
from numba import cuda
8+
from numba import config
9+
from numba.tests.support import run_in_subprocess, override_config
510

611
import itertools
712
import os
813
import io
914
import contextlib
1015
import warnings
1116

12-
from numba.cuda import get_current_device
13-
from numba import cuda
14-
from numba import config
1517

1618
TEST_BIN_DIR = os.getenv("NUMBA_CUDA_TEST_BIN_DIR")
1719
if TEST_BIN_DIR:
@@ -251,5 +253,61 @@ def kernel():
251253
pass
252254

253255

256+
class TestLinkerUsage(CUDATestCase):
257+
"""Test that whether pynvjitlink can be enabled by both environment variable
258+
and modification of config at runtime.
259+
"""
260+
def test_linker_enabled_envvar(self):
261+
# Linkable code is only supported via pynvjitlink
262+
src = """if 1:
263+
import os
264+
from numba import cuda
265+
266+
TEST_BIN_DIR = os.getenv("NUMBA_CUDA_TEST_BIN_DIR")
267+
if TEST_BIN_DIR:
268+
test_device_functions_cubin = os.path.join(
269+
TEST_BIN_DIR, "test_device_functions.cubin"
270+
)
271+
print(TEST_BIN_DIR)
272+
files = (
273+
test_device_functions_cubin,
274+
)
275+
for lto in [True, False]:
276+
for file in files:
277+
sig = "uint32(uint32, uint32)"
278+
add_from_numba = cuda.declare_device("add_from_numba", sig)
279+
280+
@cuda.jit(link=[file], lto=lto)
281+
def kernel(result):
282+
result[0] = add_from_numba(1, 2)
283+
284+
result = cuda.device_array(1)
285+
kernel[1, 1](result)
286+
assert result[0] == 3
287+
"""
288+
env = os.environ.copy()
289+
env['NUMBA_CUDA_ENABLE_PYNVJITLINK'] = "1"
290+
print(env['NUMBA_CUDA_TEST_BIN_DIR'])
291+
run_in_subprocess(src, env=env)
292+
293+
def test_linker_enabled_config(self):
294+
with override_config("CUDA_ENABLE_PYNVJITLINK", True):
295+
files = (
296+
test_device_functions_cubin,
297+
)
298+
for lto in [True, False]:
299+
for file in files:
300+
sig = "uint32(uint32, uint32)"
301+
add_from_numba = cuda.declare_device("add_from_numba", sig)
302+
303+
@cuda.jit(link=[file], lto=lto)
304+
def kernel(result):
305+
result[0] = add_from_numba(1, 2)
306+
307+
result = cuda.device_array(1)
308+
kernel[1, 1](result)
309+
assert result[0] == 3
310+
311+
254312
if __name__ == "__main__":
255313
unittest.main()

0 commit comments

Comments
 (0)