Skip to content

Commit 63f5b0e

Browse files
Enable NV bindings by default and add them as a dependency (#284)
* use nv binding by default and add dep * docs * address review * move code around * reset file * generic cuda-core conda dep * fix * fix logic * update * update deps * updates * Update numba_cuda/numba/cuda/cudadrv/driver.py Co-authored-by: Graham Markall <[email protected]> * use minimal bindings dep --------- Co-authored-by: Graham Markall <[email protected]>
1 parent 5625de8 commit 63f5b0e

File tree

9 files changed

+48
-37
lines changed

9 files changed

+48
-37
lines changed

conda/recipes/numba-cuda/meta.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ requirements:
2828
run:
2929
- python
3030
- numba >=0.59.1
31+
- cuda-bindings
3132

3233
about:
3334
home: {{ project_urls["Homepage"] }}

docs/source/reference/envvars.rst

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,8 @@ target.
108108

109109
When set to 1, Numba will attempt to use the `NVIDIA CUDA Python binding
110110
<https://nvidia.github.io/cuda-python/>`_ to make calls to the driver API
111-
instead of using its own ctypes binding. This defaults to 0 (off), as the
112-
NVIDIA binding is currently missing support for Per-Thread Default
113-
Streams and the profiler APIs.
111+
instead of using its own ctypes binding. This defaults to 1 (on). Set to
112+
0 to use the ctypes bindings.
114113

115114
.. envvar:: NUMBA_CUDA_INCLUDE_PATH
116115

numba_cuda/numba/cuda/__init__.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,36 @@
1+
import importlib
12
from numba import runtests
23
from numba.core import config
4+
from .utils import _readenv
5+
6+
# Enable pynvjitlink if the environment variables NUMBA_CUDA_ENABLE_PYNVJITLINK
7+
# or CUDA_ENABLE_PYNVJITLINK are set, or if the pynvjitlink module is found. If
8+
# explicitly disabled, do not use pynvjitlink, even if present in the env.
9+
_pynvjitlink_enabled_in_env = _readenv(
10+
"NUMBA_CUDA_ENABLE_PYNVJITLINK", bool, None
11+
)
12+
_pynvjitlink_enabled_in_cfg = getattr(config, "CUDA_ENABLE_PYNVJITLINK", None)
13+
14+
if _pynvjitlink_enabled_in_env is not None:
15+
ENABLE_PYNVJITLINK = _pynvjitlink_enabled_in_env
16+
elif _pynvjitlink_enabled_in_cfg is not None:
17+
ENABLE_PYNVJITLINK = _pynvjitlink_enabled_in_cfg
18+
else:
19+
ENABLE_PYNVJITLINK = importlib.util.find_spec("pynvjitlink") is not None
20+
21+
if not hasattr(config, "CUDA_ENABLE_PYNVJITLINK"):
22+
config.CUDA_ENABLE_PYNVJITLINK = ENABLE_PYNVJITLINK
23+
24+
# Upstream numba sets CUDA_USE_NVIDIA_BINDING to 0 by default, so it always
25+
# exists. Override, but not if explicitly set to 0 in the envioronment.
26+
_nvidia_binding_enabled_in_env = _readenv(
27+
"NUMBA_CUDA_USE_NVIDIA_BINDING", bool, None
28+
)
29+
if _nvidia_binding_enabled_in_env is False:
30+
USE_NV_BINDING = False
31+
else:
32+
USE_NV_BINDING = True
33+
config.CUDA_USE_NVIDIA_BINDING = USE_NV_BINDING
334

435
if config.ENABLE_CUDASIM:
536
from .simulator_init import *

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -49,23 +49,14 @@
4949
from .drvapi import cu_occupancy_b2d_size, cu_stream_callback_pyobj, cu_uuid
5050
from .mappings import FILE_EXTENSION_MAP
5151
from .linkable_code import LinkableCode, LTOIR, Fatbin, Object
52-
from numba.cuda.utils import _readenv, cached_file_read
52+
from numba.cuda.utils import cached_file_read
5353
from numba.cuda.cudadrv import enums, drvapi, nvrtc
5454

5555
try:
5656
from pynvjitlink.api import NvJitLinker, NvJitLinkError
5757
except ImportError:
5858
NvJitLinker, NvJitLinkError = None, None
5959

60-
USE_NV_BINDING = config.CUDA_USE_NVIDIA_BINDING
61-
62-
if USE_NV_BINDING:
63-
from cuda import cuda as binding
64-
65-
# There is no definition of the default stream in the Nvidia bindings (nor
66-
# is there at the C/C++ level), so we define it here so we don't need to
67-
# use a magic number 0 in places where we want the default stream.
68-
CU_STREAM_DEFAULT = 0
6960

7061
MIN_REQUIRED_CC = (3, 5)
7162
SUPPORTS_IPC = sys.platform.startswith("linux")
@@ -82,23 +73,15 @@
8273
"to be available"
8374
)
8475

85-
# Enable pynvjitlink if the environment variables NUMBA_CUDA_ENABLE_PYNVJITLINK
86-
# or CUDA_ENABLE_PYNVJITLINK are set, or if the pynvjitlink module is found. If
87-
# explicitly disabled, do not use pynvjitlink, even if present in the env.
88-
_pynvjitlink_enabled_in_env = _readenv(
89-
"NUMBA_CUDA_ENABLE_PYNVJITLINK", bool, None
90-
)
91-
_pynvjitlink_enabled_in_cfg = getattr(config, "CUDA_ENABLE_PYNVJITLINK", None)
76+
USE_NV_BINDING = config.CUDA_USE_NVIDIA_BINDING
9277

93-
if _pynvjitlink_enabled_in_env is not None:
94-
ENABLE_PYNVJITLINK = _pynvjitlink_enabled_in_env
95-
elif _pynvjitlink_enabled_in_cfg is not None:
96-
ENABLE_PYNVJITLINK = _pynvjitlink_enabled_in_cfg
97-
else:
98-
ENABLE_PYNVJITLINK = importlib.util.find_spec("pynvjitlink") is not None
78+
if USE_NV_BINDING:
79+
from cuda.bindings import driver as binding
9980

100-
if not hasattr(config, "CUDA_ENABLE_PYNVJITLINK"):
101-
config.CUDA_ENABLE_PYNVJITLINK = ENABLE_PYNVJITLINK
81+
# There is no definition of the default stream in the Nvidia bindings (nor
82+
# is there at the C/C++ level), so we define it here so we don't need to
83+
# use a magic number 0 in places where we want the default stream.
84+
CU_STREAM_DEFAULT = 0
10285

10386

10487
def make_logger():
@@ -3192,7 +3175,6 @@ def __init__(self, max_registers=0, lineinfo=False, cc=None):
31923175

31933176
raw_keys = list(options.keys())
31943177
raw_values = list(options.values())
3195-
31963178
self.handle = driver.cuLinkCreate(len(raw_keys), raw_keys, raw_values)
31973179

31983180
weakref.finalize(self, driver.cuLinkDestroy, self.handle)

numba_cuda/numba/cuda/cudadrv/mappings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
from . import enums
33

44
if config.CUDA_USE_NVIDIA_BINDING:
5-
from cuda import cuda
5+
from cuda.bindings import driver
66

7-
jitty = cuda.CUjitInputType
7+
jitty = driver.CUjitInputType
88
FILE_EXTENSION_MAP = {
99
"o": jitty.CU_JIT_INPUT_OBJECT,
1010
"ptx": jitty.CU_JIT_INPUT_PTX,

numba_cuda/numba/cuda/tests/cudadrv/test_linker.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,6 @@ def simple_lmem(A, B, dty):
104104

105105
@skip_on_cudasim("Linking unsupported in the simulator")
106106
class TestLinker(CUDATestCase):
107-
_NUMBA_NVIDIA_BINDING_0_ENV = {"NUMBA_CUDA_USE_NVIDIA_BINDING": "0"}
108-
109107
@require_context
110108
def test_linker_basic(self):
111109
"""Simply go through the constructor and destructor"""

numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from cuda.bindings.driver import cuModuleGetGlobal, cuMemcpyHtoD
1616

1717
if config.CUDA_USE_NVIDIA_BINDING:
18-
from cuda.cuda import CUmodule as cu_module_type
18+
from cuda.bindings.driver import CUmodule as cu_module_type
1919
else:
2020
from numba.cuda.cudadrv.drvapi import cu_module as cu_module_type
2121

numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@
5757
)
5858
@skip_on_cudasim("Linking unsupported in the simulator")
5959
class TestLinker(CUDATestCase):
60-
_NUMBA_NVIDIA_BINDING_0_ENV = {"NUMBA_CUDA_USE_NVIDIA_BINDING": "0"}
61-
6260
def test_nvjitlink_create(self):
6361
patched_linker = PyNvJitLinker(cc=(7, 5))
6462
assert "-arch=sm_75" in patched_linker.options

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@ dependencies = ["numba>=0.59.1"]
2222

2323
[project.optional-dependencies]
2424
cu11 = [
25+
"cuda-bindings==11.8.*",
2526
"cuda-python==11.8.*", # supports all CTK 11.x
2627
"nvidia-cuda-nvcc-cu11", # for libNVVM
2728
"nvidia-cuda-runtime-cu11",
2829
"nvidia-cuda-nvrtc-cu11",
2930
]
3031
cu12 = [
32+
"cuda-bindings==12.9.*",
3133
"cuda-python==12.9.*", # supports all CTK 12.x
3234
"nvidia-cuda-nvcc-cu12", # for libNVVM
3335
"nvidia-cuda-runtime-cu12",
@@ -110,7 +112,7 @@ exclude = [
110112
# Slightly long line in the standard version file
111113
"numba_cuda/_version.py" = ["E501"]
112114
# "Unused" imports / potentially undefined names in init files
113-
"numba_cuda/numba/cuda/__init__.py" = ["F401", "F403", "F405"]
115+
"numba_cuda/numba/cuda/__init__.py" = ["F401", "F403", "F405", "E402"]
114116
"numba_cuda/numba/cuda/simulator/__init__.py" = ["F401", "F403"]
115117
"numba_cuda/numba/cuda/simulator/cudadrv/__init__.py" = ["F401"]
116118
# Ignore star imports", " unused imports", " and "may be defined by star imports"

0 commit comments

Comments
 (0)