Skip to content

Commit 70713b4

Browse files
committed
Restore Runtime.is_supported_version()
1 parent 9daa42b commit 70713b4

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

numba_cuda/numba/cuda/cudadrv/runtime.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,19 @@
1010

1111
from numba.cuda.cudadrv.nvrtc import _get_nvrtc_version
1212

13+
SUPPORTED_TOOLKIT_MAJOR_VERSIONS = (12, 13)
14+
1315

1416
class Runtime:
1517
def get_version(self):
1618
return _get_nvrtc_version()
1719

20+
def is_supported_version(self):
21+
"""
22+
Returns True if the CUDA Runtime is a supported version.
23+
"""
24+
return self.get_version()[0] in SUPPORTED_TOOLKIT_MAJOR_VERSIONS
25+
1826

1927
runtime = Runtime()
2028

numba_cuda/numba/cuda/device_init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def is_supported_version():
133133
- Generating an error or otherwise preventing the use of CUDA.
134134
"""
135135

136-
return runtime.get_version()[0] in (12, 13)
136+
return runtime.is_supported_version()
137137

138138

139139
def cuda_error():

0 commit comments

Comments
 (0)