Skip to content

Commit 69394c7

Browse files
authored
Fix the cuda.is_supported_version() API (#571)
1 parent bb850ff commit 69394c7

File tree

3 files changed

+22
-0
lines changed

3 files changed

+22
-0
lines changed

numba_cuda/numba/cuda/cudadrv/runtime.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,19 @@
1010

1111
from numba.cuda.cudadrv.nvrtc import _get_nvrtc_version
1212

13+
SUPPORTED_TOOLKIT_MAJOR_VERSIONS = (12, 13)
14+
1315

1416
class Runtime:
1517
def get_version(self):
1618
return _get_nvrtc_version()
1719

20+
def is_supported_version(self):
21+
"""
22+
Returns True if the CUDA Runtime is a supported version.
23+
"""
24+
return self.get_version()[0] in SUPPORTED_TOOLKIT_MAJOR_VERSIONS
25+
1826

1927
runtime = Runtime()
2028

numba_cuda/numba/cuda/simulator/api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,7 @@ def jitwrapper(fn):
160160
def defer_cleanup():
161161
# No effect for simulator
162162
yield
163+
164+
165+
def is_supported_version():
166+
return True

numba_cuda/numba/cuda/tests/cudadrv/test_detect.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@ def test_cuda_detect(self):
2525
self.assertIn("CUDA devices", output)
2626

2727

28+
class TestSupportedVersion(CUDATestCase):
29+
def test_is_supported_version(self):
30+
# Exercise the `cuda.is_supported_version()` API.
31+
#
32+
# Assume for the purpose of the test that we're running on a supported
33+
# toolkit version; if not, there's not much point in running the test
34+
# suite.
35+
self.assertTrue(cuda.is_supported_version())
36+
37+
2838
@skip_under_cuda_memcheck("Hangs cuda-memcheck")
2939
class TestCUDAFindLibs(CUDATestCase):
3040
def run_cmd(self, cmdline, env):

0 commit comments

Comments
 (0)