Skip to content

Commit b2f4245

Browse files
small updates
1 parent ff18c5c commit b2f4245

File tree

2 files changed

+4
-16
lines changed

2 files changed

+4
-16
lines changed

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
system to freeze in some cases.
1111
1212
"""
13-
1413
import sys
1514
import os
1615
import ctypes
@@ -85,12 +84,11 @@ def _readenv(name, ctor, default):
8584

8685
ENABLE_PYNVJITLINK = (
8786
_readenv("ENABLE_PYNVJITLINK", bool, False)
88-
or getattr(config, "ENABLE_PYNVJITLINK", None)
87+
or getattr(config, "ENABLE_PYNVJITLINK", False)
8988
)
9089
if not hasattr(config, "ENABLE_PYNVJITLINK"):
9190
config.ENABLE_PYNVJITLINK = ENABLE_PYNVJITLINK
9291

93-
9492
if ENABLE_PYNVJITLINK:
9593
try:
9694
from pynvjitlink.api import NvJitLinker, NvJitLinkError

numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,8 @@
88
from numba import cuda
99
from numba import config
1010

11-
HAVE_PYNVJITLINK = False
12-
try:
13-
import pynvjitlink # noqa: F401
14-
from pynvjitlink.api import NvJitLinkError
1511

16-
HAVE_PYNVJITLINK = True
17-
except ImportError:
18-
pass
19-
20-
21-
@unittest.skipIf(not HAVE_PYNVJITLINK, "pynvjitlink not available")
12+
@unittest.skipIf(config.ENABLE_PYNVJITLINK, "pynvjitlink not enabled")
2213
@skip_on_cudasim("Linking unsupported in the simulator")
2314
class TestLinker(CUDATestCase):
2415
_NUMBA_NVIDIA_BINDING_0_ENV = {"NUMBA_CUDA_USE_NVIDIA_BINDING": "0"}
@@ -35,6 +26,8 @@ def test_nvjitlink_create_no_cc_error(self):
3526
PyNvJitLinker()
3627

3728
def test_nvjitlink_invalid_arch_error(self):
29+
from pynvjitlink.api import NvJitLinkError
30+
3831
# CC 0.0 is not a valid compute capability
3932
with self.assertRaisesRegex(
4033
NvJitLinkError, "NVJITLINK_ERROR_UNRECOGNIZED_OPTION error"
@@ -126,7 +119,6 @@ def test_nvjitlink_test_add_file_guess_ext_invalid_input(self):
126119
# because there's no way to know what kind of file to treat it as
127120
patched_linker.add_file_guess_ext(content)
128121

129-
@unittest.skipIf(not HAVE_PYNVJITLINK, "pynvjitlink not available")
130122
def test_nvjitlink_jit_with_linkable_code(self):
131123
files = (
132124
"test_device_functions.a",
@@ -138,8 +130,6 @@ def test_nvjitlink_jit_with_linkable_code(self):
138130
)
139131
for file in files:
140132
with self.subTest(file=file):
141-
# TODO: unsafe teardown if test errors
142-
config.ENABLE_PYNVJITLINK = True
143133
sig = "uint32(uint32, uint32)"
144134
add_from_numba = cuda.declare_device("add_from_numba", sig)
145135

0 commit comments

Comments
 (0)