33from numba .core import config
44from numba .cuda .cudadrv .error import (NvrtcError , NvrtcCompilationError ,
55 NvrtcSupportError )
6+ from cuda .cuda .cudadrv .driver import get_version
67
78import functools
89import os
@@ -62,6 +63,14 @@ class NVRTC:
6263 NVVM interface. Initialization is protected by a lock and uses the standard
6364 (for Numba) open_cudalib function to load the NVRTC library.
6465 """
66+
67+ _CU12ONLY_PROTOTYPES = {
68+ # nvrtcResult nvrtcGetLTOIRSize(nvrtcProgram prog, size_t *ltoSizeRet);
69+ "nvrtcGetLTOIRSize" : (nvrtc_result , nvrtc_program , POINTER (c_size_t )),
70+ # nvrtcResult nvrtcGetLTOIR(nvrtcProgram prog, char *lto);
71+ "nvrtcGetLTOIR" : (nvrtc_result , nvrtc_program , c_char_p )
72+ }
73+
6574 _PROTOTYPES = {
6675 # nvrtcResult nvrtcVersion(int *major, int *minor)
6776 'nvrtcVersion' : (nvrtc_result , POINTER (c_int ), POINTER (c_int )),
@@ -84,10 +93,6 @@ class NVRTC:
8493 'nvrtcGetPTXSize' : (nvrtc_result , nvrtc_program , POINTER (c_size_t )),
8594 # nvrtcResult nvrtcGetPTX(nvrtcProgram prog, char *ptx);
8695 'nvrtcGetPTX' : (nvrtc_result , nvrtc_program , c_char_p ),
87- # nvrtcResult nvrtcGetLTOIRSize(nvrtcProgram prog, size_t *ltoSizeRet);
88- "nvrtcGetLTOIRSize" : (nvrtc_result , nvrtc_program , POINTER (c_size_t )),
89- # nvrtcResult nvrtcGetLTOIR(nvrtcProgram prog, char *lto);
90- "nvrtcGetLTOIR" : (nvrtc_result , nvrtc_program , c_char_p ),
9196 # nvrtcResult nvrtcGetCUBINSize(nvrtcProgram prog,
9297 # size_t *cubinSizeRet);
9398 'nvrtcGetCUBINSize' : (nvrtc_result , nvrtc_program , POINTER (c_size_t )),
@@ -101,6 +106,9 @@ class NVRTC:
101106 'nvrtcGetProgramLog' : (nvrtc_result , nvrtc_program , c_char_p ),
102107 }
103108
109+ if get_version () >= (12 , 0 ):
110+ _PROTOTYPES |= _CU12ONLY_PROTOTYPES
111+
104112 # Singleton reference
105113 __INSTANCE = None
106114
0 commit comments