Skip to content

Commit 27b304d

Browse files
committed
enable LTO nvrtc function only for 11.*
1 parent 9ea2fd4 commit 27b304d

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

numba_cuda/numba/cuda/cudadrv/nvrtc.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from numba.core import config
44
from numba.cuda.cudadrv.error import (NvrtcError, NvrtcCompilationError,
55
NvrtcSupportError)
6+
from cuda.cuda.cudadrv.driver import get_version
67

78
import functools
89
import os
@@ -62,6 +63,14 @@ class NVRTC:
6263
NVVM interface. Initialization is protected by a lock and uses the standard
6364
(for Numba) open_cudalib function to load the NVRTC library.
6465
"""
66+
67+
_CU12ONLY_PROTOTYPES = {
68+
# nvrtcResult nvrtcGetLTOIRSize(nvrtcProgram prog, size_t *ltoSizeRet);
69+
"nvrtcGetLTOIRSize": (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
70+
# nvrtcResult nvrtcGetLTOIR(nvrtcProgram prog, char *lto);
71+
"nvrtcGetLTOIR": (nvrtc_result, nvrtc_program, c_char_p)
72+
}
73+
6574
_PROTOTYPES = {
6675
# nvrtcResult nvrtcVersion(int *major, int *minor)
6776
'nvrtcVersion': (nvrtc_result, POINTER(c_int), POINTER(c_int)),
@@ -84,10 +93,6 @@ class NVRTC:
8493
'nvrtcGetPTXSize': (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
8594
# nvrtcResult nvrtcGetPTX(nvrtcProgram prog, char *ptx);
8695
'nvrtcGetPTX': (nvrtc_result, nvrtc_program, c_char_p),
87-
# nvrtcResult nvrtcGetLTOIRSize(nvrtcProgram prog, size_t *ltoSizeRet);
88-
"nvrtcGetLTOIRSize": (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
89-
# nvrtcResult nvrtcGetLTOIR(nvrtcProgram prog, char *lto);
90-
"nvrtcGetLTOIR": (nvrtc_result, nvrtc_program, c_char_p),
9196
# nvrtcResult nvrtcGetCUBINSize(nvrtcProgram prog,
9297
# size_t *cubinSizeRet);
9398
'nvrtcGetCUBINSize': (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
@@ -101,6 +106,9 @@ class NVRTC:
101106
'nvrtcGetProgramLog': (nvrtc_result, nvrtc_program, c_char_p),
102107
}
103108

109+
if get_version() >= (12, 0):
110+
_PROTOTYPES |= _CU12ONLY_PROTOTYPES
111+
104112
# Singleton reference
105113
__INSTANCE = None
106114

0 commit comments

Comments
 (0)