Skip to content

Commit 0035b79

Browse files
committed
Fix the cuda.is_supported_version() API
PR #313 removed the `runtime.is_supported_version()` API, but it is used by the `cuda.is_supported_version()` public API. This commit restores the `cuda.is_supported_version()` API by checking whether the CUDA runtime major version is 12 or 13. The version number check will need bumping as appropriate when future toolkit major versions are added and existing toolkit major version are dropped. This situation will be caught by the test that is added to exercise this API.
1 parent 96e10e6 commit 0035b79

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

numba_cuda/numba/cuda/device_init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def is_supported_version():
133133
- Generating an error or otherwise preventing the use of CUDA.
134134
"""
135135

136-
return runtime.is_supported_version()
136+
return runtime.get_version()[0] in (12, 13)
137137

138138

139139
def cuda_error():

numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,16 @@ def test_visible_devices_set_after_import(self):
4343
visible_gpu_count = future.result()
4444
assert visible_gpu_count == 1
4545

46+
def test_is_supported_version(self):
47+
# Exercise the `cuda.is_supported_version()` API.
48+
#
49+
# Assume for the purpose of the test that we're running on a supported
50+
# toolkit version; if not, there's not much point in running the test
51+
# suite.
52+
from numba import cuda
53+
54+
self.assertTrue(cuda.is_supported_version())
55+
4656

4757
if __name__ == "__main__":
4858
unittest.main()

0 commit comments

Comments
 (0)