Commit f5f81fa

Updates for recent API changes (#313)
* Use NVRTC for supported CCs

  Changes in the compute capability support matrix in nvvm.py will continue to
  be needed with new CUDA versions if we maintain a list of
  explicitly-supported compute capabilities. NVRTC supports retrieving the
  supported list programmatically, so we switch to using it instead. This
  assumes that the user's environment has a consistent set of components
  (NVVM, NVRTC, etc.) - this is generally expected to be the case with recent
  developments in package management, and there is little we can do about an
  inconsistent environment anyway.

  Changes outside of nvvm.py / nvrtc.py accommodate the movement of this
  functionality. A major side effect is that we no longer need to initialize
  the list of supported CCs prior to forking, because we no longer need the
  CUDA runtime to populate the supported CC list.

* Use NVRTC to get runtime version

  We only used the CUDA runtime library to get the runtime version so that we
  could populate the list of supported compute capabilities in nvvm.py. Now
  that we don't do this, and NVRTC provides the CUDA toolkit version, there is
  no need to use the CUDA runtime API at all. The Numba API for the runtime
  version is not deleted in case it was used by external code - instead, it
  uses NVRTC to obtain the toolkit version.

  Because NVRTC used the runtime version to determine which prototypes to
  bind, we need to stop doing that to avoid a circular dependency / deadlock -
  instead of checking the runtime version and creating the list of prototypes,
  we try to add all known prototypes and ignore errors in those related to
  LTOIR, which can occur with CUDA 11 where they were not present.

  The `runtime.is_supported_version()` API and its test are removed - it would
  always (incorrectly) have been `False` on CUDA 12, and this has never been
  reported as an issue, so it seems very unlikely that anyone was using it.
* Update for new cccl search location

  Recent toolkits move the CCCL headers into their own subdirectory, so we
  need to add this subdirectory to the include path so that headers such as
  `cuda/atomic` can be located successfully in all cases.

* Handle variants of cuCtxCreate()

  The most recent `cuCtxCreate()` API in the CUDA bindings requires an
  additional optional parameter. We don't have to supply a value for it (other
  than `None`), but we do need to provide the argument on binding versions
  where it is required.

* Delete docs on runtime binding

* Default to at least compute capability 7.5

  The change to use NVRTC for the supported compute capabilities also had the
  implicit effect of making the default compute capability the lowest
  supported by the installed NVRTC version. We need it to default to at least
  7.5 (unless the user specifies something higher) to maintain the behaviour
  of the compute capability logic from nvvm.py that was replaced.

* Locate NVRTC DLL by searching possible paths

  We use NVRTC to get the CUDA version, so we can no longer use the CUDA
  version to determine the NVRTC DLL / SO. Instead, check for the presence of
  each version, preferring the highest.
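Several of the changes above hinge on NVRTC reporting its supported architectures as plain integers (e.g. `75` for compute capability 7.5), which the new nvrtc.py code decodes with divmod-style arithmetic. A minimal standalone sketch of that decoding, using illustrative sample values rather than the output of a real `nvrtcGetSupportedArchs` call (`decode_archs` is a hypothetical helper name):

```python
# Sketch: decode NVRTC-style architecture integers into (major, minor)
# tuples, mirroring the `(archs[i] // 10, archs[i] % 10)` logic in nvrtc.py.
def decode_archs(raw_archs):
    """Convert e.g. 75 -> (7, 5); the list stays in NVRTC's ascending order."""
    return [(a // 10, a % 10) for a in raw_archs]

# Illustrative values only, not a real query result.
ccs = decode_archs([52, 60, 70, 75, 80, 86, 90])
print(ccs[0], ccs[-1])  # lowest and highest supported CC
```

The lowest element of this list is what `get_lowest_supported_cc()` returns, and the whole list is what the closest-arch rounding logic walks over.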
1 parent 2fa741c commit f5f81fa

File tree

21 files changed: +176 additions, -531 deletions

docs/source/reference/host.rst

Lines changed: 0 additions & 17 deletions
@@ -229,20 +229,3 @@ stream, and the stream must remain valid whilst the Numba ``Stream`` object is
 in use.

 .. autofunction:: numba.cuda.external_stream
-
-
-Runtime
--------
-
-Numba generally uses the Driver API, but it provides a simple wrapper to the
-Runtime API so that the version of the runtime in use can be queried. This is
-accessed through ``cuda.runtime``, which is an instance of the
-:class:`numba.cuda.cudadrv.runtime.Runtime` class:
-
-.. autoclass:: numba.cuda.cudadrv.runtime.Runtime
-   :members: get_version, is_supported_version, supported_versions
-
-Whether the current runtime is officially supported and tested with the current
-version of Numba can also be queried:
-
-.. autofunction:: numba.cuda.is_supported_version

numba_cuda/numba/cuda/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -83,6 +83,19 @@
 implementation = "NVIDIA"


+# The default compute capability as set by the upstream Numba implementation.
+config_default_cc = config.CUDA_DEFAULT_PTX_CC
+
+# The default compute capability for Numba-CUDA. This will usually override the
+# upstream Numba built-in default of 5.0, unless the user has set it even
+# higher, in which case we should use the user-specified value. This default is
+# aligned with recent toolkit versions.
+numba_cuda_default_ptx_cc = (7, 5)
+
+if numba_cuda_default_ptx_cc > config_default_cc:
+    config.CUDA_DEFAULT_PTX_CC = numba_cuda_default_ptx_cc
+
+
 def test(*args, **kwargs):
     if not is_available():
         raise cuda_error()
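The override above relies on Python's lexicographic tuple comparison: `(8, 6) > (7, 5) > (5, 0)`. A small sketch of the same decision, with `effective_default_cc` as a hypothetical helper name (the real code instead mutates `config.CUDA_DEFAULT_PTX_CC` in place):

```python
# Sketch of the default-CC override added to __init__.py: Numba-CUDA raises
# the default to (7, 5) unless the user already configured something higher.
def effective_default_cc(config_default_cc, numba_cuda_default=(7, 5)):
    # Tuple comparison is lexicographic, so max() picks the higher CC.
    return max(config_default_cc, numba_cuda_default)

print(effective_default_cc((5, 0)))  # upstream default 5.0 is overridden
print(effective_default_cc((8, 6)))  # a higher user-specified value wins
```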

numba_cuda/numba/cuda/codegen.py

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@

 from numba.core import config, serialize
 from numba.core.codegen import Codegen, CodeLibrary
-from .cudadrv import devices, driver, nvvm, runtime
+from .cudadrv import devices, driver, nvrtc, nvvm, runtime
 from numba.cuda.cudadrv.libs import get_cudalib
 from numba.cuda.cudadrv.linkable_code import LinkableCode
 from numba.cuda.memory_management.nrt import NRT_LIBRARY
@@ -211,7 +211,7 @@ def get_asm_str(self, cc=None):
         if ptxes:
             return ptxes

-        arch = nvvm.get_arch_option(*cc)
+        arch = nvrtc.get_arch_option(*cc)
         options = self._nvvm_options.copy()
         options["arch"] = arch

@@ -240,7 +240,7 @@ def get_ltoir(self, cc=None):
         if ltoir is not None:
             return ltoir

-        arch = nvvm.get_arch_option(*cc)
+        arch = nvrtc.get_arch_option(*cc)
         options = self._nvvm_options.copy()
         options["arch"] = arch
         options["gen-lto"] = None

numba_cuda/numba/cuda/compiler.py

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@
 from numba.cuda import nvvmutils
 from numba.cuda.api import get_current_device
 from numba.cuda.codegen import ExternalCodeLibrary
-from numba.cuda.cudadrv import nvvm
+from numba.cuda.cudadrv import nvvm, nvrtc
 from numba.cuda.descriptor import cuda_target
 from numba.cuda.flags import CUDAFlags
 from numba.cuda.target import CUDACABICallConv
@@ -640,7 +640,7 @@ def compile(
     # If the user has used the config variable to specify a non-default that is
     # greater than the lowest non-deprecated one, then we should default to
     # their specified CC instead of the lowest non-deprecated one.
-    MIN_CC = max(config.CUDA_DEFAULT_PTX_CC, nvvm.LOWEST_CURRENT_CC)
+    MIN_CC = max(config.CUDA_DEFAULT_PTX_CC, nvrtc.get_lowest_supported_cc())
     cc = cc or MIN_CC

     cres = compile_cuda(

numba_cuda/numba/cuda/cuda_paths.py

Lines changed: 20 additions & 22 deletions
@@ -132,16 +132,9 @@ def _get_nvvm_wheel():
     return None


-def get_major_cuda_version():
-    # TODO: remove once cuda-python is
-    # a hard dependency
-    from numba.cuda.cudadrv.runtime import get_version
-
-    return get_version()[0]
-
-
 def get_nvrtc_dso_path():
     site_paths = [site.getusersitepackages()] + site.getsitepackages()
+
     for sp in site_paths:
         lib_dir = os.path.join(
             sp,
@@ -150,23 +143,28 @@ def get_nvrtc_dso_path():
             ("bin" if IS_WIN32 else "lib") if sp else None,
         )
         if lib_dir and os.path.exists(lib_dir):
-            try:
-                major = get_major_cuda_version()
-                if major == 11:
-                    cu_ver = "112" if IS_WIN32 else "11.2"
-                elif major == 12:
-                    cu_ver = "120" if IS_WIN32 else "12"
-                else:
-                    raise NotImplementedError(f"CUDA {major} is not supported")
-
-                return os.path.join(
+            chosen_path = None
+
+            # Check for each version of the NVRTC DLL, preferring the most
+            # recent.
+            versions = (
+                "112" if IS_WIN32 else "11.2",
+                "120" if IS_WIN32 else "12",
+                "130" if IS_WIN32 else "13",
+            )
+
+            for version in versions:
+                dso_path = os.path.join(
                     lib_dir,
-                    f"nvrtc64_{cu_ver}_0.dll"
+                    f"nvrtc64_{version}_0.dll"
                     if IS_WIN32
-                    else f"libnvrtc.so.{cu_ver}",
+                    else f"libnvrtc.so.{version}",
                 )
-            except RuntimeError:
-                continue
+
+                if os.path.exists(dso_path) and os.path.isfile(dso_path):
+                    chosen_path = dso_path
+
+            return chosen_path


 def _get_nvrtc_wheel():
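The search strategy above can be exercised without a CUDA installation: candidate names are ordered oldest-to-newest, and the last one found on disk wins, so the highest available version is preferred. A sketch assuming the Linux soname pattern from the diff, with `find_best_nvrtc` as a hypothetical stand-in for the relevant part of `get_nvrtc_dso_path()`:

```python
import os
import tempfile

def find_best_nvrtc(lib_dir, versions=("11.2", "12", "13")):
    """Return the path of the newest libnvrtc present, or None."""
    chosen_path = None
    for version in versions:  # ascending, so later hits overwrite earlier ones
        dso_path = os.path.join(lib_dir, f"libnvrtc.so.{version}")
        if os.path.isfile(dso_path):
            chosen_path = dso_path
    return chosen_path

# Demonstrate with stand-in files rather than a real CUDA installation.
with tempfile.TemporaryDirectory() as d:
    for v in ("11.2", "12"):
        open(os.path.join(d, f"libnvrtc.so.{v}"), "w").close()
    print(find_best_nvrtc(d))  # path ending in libnvrtc.so.12
```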

numba_cuda/numba/cuda/cudadrv/error.py

Lines changed: 4 additions & 0 deletions
@@ -38,3 +38,7 @@ class NvrtcBuiltinOperationFailure(NvrtcError):

 class NvrtcSupportError(ImportError):
     pass
+
+
+class CCSupportError(RuntimeError):
+    pass

numba_cuda/numba/cuda/cudadrv/libs.py

Lines changed: 1 addition & 1 deletion
@@ -154,7 +154,7 @@ def test():
         print(f"\t\t{location}")

     # Checks for dynamic libraries
-    libs = "nvvm nvrtc cudart".split()
+    libs = "nvvm nvrtc".split()
     for lib in libs:
         path = get_cudalib(lib)
         print("Finding {} from {}".format(lib, _get_source_variable(lib)))

numba_cuda/numba/cuda/cudadrv/nvrtc.py

Lines changed: 80 additions & 71 deletions
@@ -1,6 +1,7 @@
 from ctypes import byref, c_char, c_char_p, c_int, c_size_t, c_void_p, POINTER
 from enum import IntEnum
 from numba.cuda.cudadrv.error import (
+    CCSupportError,
     NvrtcError,
     NvrtcBuiltinOperationFailure,
     NvrtcCompilationError,
@@ -79,20 +80,6 @@ class NVRTC:
     (for Numba) open_cudalib function to load the NVRTC library.
     """

-    _CU11_2ONLY_PROTOTYPES = {
-        # nvrtcResult nvrtcGetNumSupportedArchs(int *numArchs);
-        "nvrtcGetNumSupportedArchs": (nvrtc_result, POINTER(c_int)),
-        # nvrtcResult nvrtcGetSupportedArchs(int *supportedArchs);
-        "nvrtcGetSupportedArchs": (nvrtc_result, POINTER(c_int)),
-    }
-
-    _CU12ONLY_PROTOTYPES = {
-        # nvrtcResult nvrtcGetLTOIRSize(nvrtcProgram prog, size_t *ltoSizeRet);
-        "nvrtcGetLTOIRSize": (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
-        # nvrtcResult nvrtcGetLTOIR(nvrtcProgram prog, char *lto);
-        "nvrtcGetLTOIR": (nvrtc_result, nvrtc_program, c_char_p),
-    }
-
     _PROTOTYPES = {
         # nvrtcResult nvrtcVersion(int *major, int *minor)
         "nvrtcVersion": (nvrtc_result, POINTER(c_int), POINTER(c_int)),
@@ -140,6 +127,14 @@ class NVRTC:
         ),
         # nvrtcResult nvrtcGetProgramLog(nvrtcProgram prog, char *log);
         "nvrtcGetProgramLog": (nvrtc_result, nvrtc_program, c_char_p),
+        # nvrtcResult nvrtcGetNumSupportedArchs(int *numArchs);
+        "nvrtcGetNumSupportedArchs": (nvrtc_result, POINTER(c_int)),
+        # nvrtcResult nvrtcGetSupportedArchs(int *supportedArchs);
+        "nvrtcGetSupportedArchs": (nvrtc_result, POINTER(c_int)),
+        # nvrtcResult nvrtcGetLTOIRSize(nvrtcProgram prog, size_t *ltoSizeRet);
+        "nvrtcGetLTOIRSize": (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
+        # nvrtcResult nvrtcGetLTOIR(nvrtcProgram prog, char *lto);
+        "nvrtcGetLTOIR": (nvrtc_result, nvrtc_program, c_char_p),
     }

     # Singleton reference
@@ -157,18 +152,18 @@ def __new__(cls):
             cls.__INSTANCE = None
             raise NvrtcSupportError("NVRTC cannot be loaded") from e

-        from numba.cuda.cudadrv.runtime import get_version
-
-        if get_version() >= (11, 2):
-            inst._PROTOTYPES |= inst._CU11_2ONLY_PROTOTYPES
-        if get_version() >= (12, 0):
-            inst._PROTOTYPES |= inst._CU12ONLY_PROTOTYPES
-
         # Find & populate functions
         for name, proto in inst._PROTOTYPES.items():
-            func = getattr(lib, name)
-            func.restype = proto[0]
-            func.argtypes = proto[1:]
+            try:
+                func = getattr(lib, name)
+                func.restype = proto[0]
+                func.argtypes = proto[1:]
+            except AttributeError:
+                if "LTOIR" in name:
+                    # CUDA 11 does not have LTOIR functions; ignore
+                    continue
+                else:
+                    raise

             @functools.wraps(func)
             def checked_call(*args, func=func, name=name):
@@ -195,52 +190,16 @@ def checked_call(*args, func=func, name=name):

         return cls.__INSTANCE

+    @functools.cache
     def get_supported_archs(self):
         """
         Get Supported Architectures by NVRTC as list of arch tuples.
         """
-        ver = self.get_version()
-        if ver < (11, 0):
-            raise RuntimeError(
-                "Unsupported CUDA version. CUDA 11.0 or higher is required."
-            )
-        elif ver == (11, 0):
-            return [
-                (3, 0),
-                (3, 2),
-                (3, 5),
-                (3, 7),
-                (5, 0),
-                (5, 2),
-                (5, 3),
-                (6, 0),
-                (6, 1),
-                (6, 2),
-                (7, 0),
-                (7, 2),
-                (7, 5),
-            ]
-        elif ver == (11, 1):
-            return [
-                (3, 5),
-                (3, 7),
-                (5, 0),
-                (5, 2),
-                (5, 3),
-                (6, 0),
-                (6, 1),
-                (6, 2),
-                (7, 0),
-                (7, 2),
-                (7, 5),
-                (8, 0),
-            ]
-        else:
-            num = c_int()
-            self.nvrtcGetNumSupportedArchs(byref(num))
-            archs = (c_int * num.value)()
-            self.nvrtcGetSupportedArchs(archs)
-            return [(archs[i] // 10, archs[i] % 10) for i in range(num.value)]
+        num = c_int()
+        self.nvrtcGetNumSupportedArchs(byref(num))
+        archs = (c_int * num.value)()
+        self.nvrtcGetSupportedArchs(archs)
+        return [(archs[i] // 10, archs[i] % 10) for i in range(num.value)]

     def get_version(self):
         """
@@ -349,9 +308,9 @@ def compile(src, name, cc, ltoir=False):

     version = nvrtc.get_version()
     ver_str = lambda v: ".".join(v)
-    if version < (11, 0):
+    if version < (11, 2):
         raise RuntimeError(
-            "Unsupported CUDA version. CUDA 11.0 or higher is required."
+            "Unsupported CUDA version. CUDA 11.2 or higher is required."
         )
     else:
         supported_arch = nvrtc.get_supported_archs()
@@ -383,8 +342,10 @@ def compile(src, name, cc, ltoir=False):
     else:
         arch = f"--gpu-architecture=compute_{major}{minor}"

-    cuda_include = [
-        f"{get_cuda_paths()['include_dir'].info}",
+    cuda_include_dir = get_cuda_paths()["include_dir"].info
+    cuda_includes = [
+        f"{cuda_include_dir}",
+        f"{os.path.join(cuda_include_dir, 'cccl')}",
     ]

     nvrtc_version = nvrtc.get_version()
@@ -405,7 +366,7 @@ def compile(src, name, cc, ltoir=False):

     nrt_include = os.path.join(numba_cuda_path, "memory_management")

-    includes = [numba_include, *cuda_include, nrt_include, *extra_includes]
+    includes = [numba_include, *cuda_includes, nrt_include, *extra_includes]

     if config.CUDA_USE_NVIDIA_BINDING:
         options = ProgramOptions(
@@ -474,3 +435,51 @@ def write(self, msg):
     else:
         ptx = nvrtc.get_ptx(program)
     return ptx, log
+
+
+def find_closest_arch(mycc):
+    """
+    Given a compute capability, return the closest compute capability supported
+    by the CUDA toolkit.
+
+    :param mycc: Compute capability as a tuple ``(MAJOR, MINOR)``
+    :return: Closest supported CC as a tuple ``(MAJOR, MINOR)``
+    """
+    supported_ccs = get_supported_ccs()
+
+    for i, cc in enumerate(supported_ccs):
+        if cc == mycc:
+            # Matches
+            return cc
+        elif cc > mycc:
+            # Exceeded
+            if i == 0:
+                # CC lower than supported
+                msg = (
+                    "GPU compute capability %d.%d is not supported"
+                    "(requires >=%d.%d)" % (mycc + cc)
+                )
+                raise CCSupportError(msg)
+            else:
+                # return the previous CC
+                return supported_ccs[i - 1]
+
+    # CC higher than supported
+    return supported_ccs[-1]  # Choose the highest
+
+
+def get_arch_option(major, minor):
+    """Matches with the closest architecture option"""
+    if config.FORCE_CUDA_CC:
+        arch = config.FORCE_CUDA_CC
+    else:
+        arch = find_closest_arch((major, minor))
+    return "compute_%d%d" % arch
+
+
+def get_lowest_supported_cc():
+    return min(get_supported_ccs())
+
+
+def get_supported_ccs():
+    return NVRTC().get_supported_archs()
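The rounding behaviour of the new `find_closest_arch()` can be illustrated in isolation by passing the supported-CC list in explicitly instead of querying NVRTC; `closest_arch` here is a hypothetical, GPU-free analogue of the function added above:

```python
# Pure-Python sketch of the find_closest_arch() rounding logic: exact matches
# are returned as-is, in-between CCs round down to the previous supported CC,
# CCs above the list clamp to the highest, and CCs below the list are errors.
def closest_arch(mycc, supported_ccs):
    supported_ccs = sorted(supported_ccs)
    for i, cc in enumerate(supported_ccs):
        if cc == mycc:
            return cc  # exact match
        elif cc > mycc:
            if i == 0:
                # Requested CC is below everything the toolkit supports.
                raise RuntimeError("CC %d.%d not supported" % mycc)
            return supported_ccs[i - 1]  # round down to the previous CC
    return supported_ccs[-1]  # above everything: use the highest

ccs = [(5, 2), (6, 0), (7, 0), (7, 5), (8, 0)]
print(closest_arch((7, 2), ccs))  # rounds down -> (7, 0)
print(closest_arch((9, 0), ccs))  # above the list -> (8, 0)
```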
