Skip to content

Commit 6b66886

Browse files
begin replacing pynvjitlinker
1 parent d4eb970 commit 6b66886

File tree

1 file changed

+141
-1
lines changed

1 file changed

+141
-1
lines changed

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747
except ImportError:
4848
NvJitLinker, NvJitLinkError = None, None
4949

50+
from cuda.bindings.nvjitlink import nvJitLinkError
51+
from cuda.bindings import nvjitlink as _nvjitlink
52+
5053
USE_NV_BINDING = config.CUDA_USE_NVIDIA_BINDING
5154

5255
if USE_NV_BINDING:
@@ -2594,7 +2597,7 @@ def new(cls,
25942597
"Enabling pynvjitlink requires CUDA 12."
25952598
)
25962599
if config.CUDA_ENABLE_PYNVJITLINK:
2597-
linker = PyNvJitLinker
2600+
linker = PyNvJitLinkerNew
25982601

25992602
elif config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY:
26002603
linker = MVCLinker
@@ -3024,6 +3027,143 @@ def complete(self):
30243027
return bytes(np.ctypeslib.as_array(cubin_ptr, shape=(size,)))
30253028

30263029

3030+
class PyNvJitLinkerNew(Linker):
3031+
def __init__(
3032+
self,
3033+
max_registers=None,
3034+
lineinfo=False,
3035+
cc=None,
3036+
lto=False,
3037+
additional_flags=None,
3038+
):
3039+
3040+
if cc is None:
3041+
raise RuntimeError("NvJitLink requires CC to be specified")
3042+
if not any(isinstance(cc, t) for t in [list, tuple]):
3043+
raise TypeError("`cc` must be a list or tuple of length 2")
3044+
3045+
sm_ver = f"{cc[0] * 10 + cc[1]}"
3046+
arch = f"-arch=sm_{sm_ver}"
3047+
options = [arch]
3048+
if max_registers:
3049+
options.append(f"-maxrregcount={max_registers}")
3050+
if lineinfo:
3051+
options.append("-lineinfo")
3052+
if lto:
3053+
options.append("-lto")
3054+
if additional_flags is not None:
3055+
options.extend(additional_flags)
3056+
3057+
self.handle = _nvjitlink.create(len(options), options)
3058+
self.lto = lto
3059+
self.options = options
3060+
3061+
self._info_log = None
3062+
self._error_log = None
3063+
self._complete = False
3064+
3065+
@property
3066+
def info_log(self):
3067+
return self._info_log
3068+
3069+
@property
3070+
def error_log(self):
3071+
return self._error_log
3072+
3073+
def add_data(self, input_type, data, name):
3074+
if self._complete:
3075+
raise nvJitLinkError("Cannot add data to already-completeted link")
3076+
3077+
try:
3078+
_nvjitlink.add_data(
3079+
self.handle,
3080+
input_type.value,
3081+
data,
3082+
len(data),
3083+
name
3084+
)
3085+
except RuntimeError as e:
3086+
log_size = _nvjitlink.get_error_log_size(self.handle)
3087+
log = bytearray(log_size)
3088+
self._info_log = _nvjitlink.get_error_log(self.handle, log)
3089+
3090+
log_size = _nvjitlink.get_info_log_size(self.handle)
3091+
log = bytearray(log_size)
3092+
self._error_log = _nvjitlink.get_info_log(self.handle, log)
3093+
raise nvJitLinkError(f"{e}\n{self.error_log}")
3094+
3095+
def add_cubin(self, cubin, name=None):
3096+
name = name or "unnamed-cubin"
3097+
self.add_data(_nvjitlink.InputType.CUBIN, cubin, name)
3098+
3099+
def add_ptx(self, ptx, name=None):
3100+
name = name or "unnamed-ptx"
3101+
self.add_data(_nvjitlink.InputType.PTX, ptx, name)
3102+
3103+
def add_ltoir(self, ltoir, name=None):
3104+
name = name or "unnamed-ltoir"
3105+
self.add_data(_nvjitlink.InputType.LTOIR, ltoir, name)
3106+
3107+
def add_object(self, object_, name=None):
3108+
name = name or "unnamed-object"
3109+
self.add_data(_nvjitlink.InputType.OBJECT, object_, name)
3110+
3111+
def add_fatbin(self, fatbin, name=None):
3112+
name = name or "unnamed-fatbin"
3113+
self.add_data(_nvjitlink.InputType.FATBIN, fatbin, name)
3114+
3115+
def add_library(self, library, name=None):
3116+
self.add_data(_nvjitlink.InputType.LIBRARY, library, name)
3117+
3118+
def add_file(self, path, kind):
3119+
try:
3120+
with open(path, "rb") as f:
3121+
data = f.read()
3122+
except FileNotFoundError:
3123+
raise LinkerError(f"{path} not found")
3124+
3125+
name = pathlib.Path(path).name
3126+
self.add_data(data, kind, name)
3127+
3128+
def get_linked_cubin(self):
3129+
try:
3130+
_nvjitlink.complete(self.handle)
3131+
self._complete = True
3132+
size = _nvjitlink.get_linked_cubin_size(self.handle)
3133+
cubin = bytearray(size)
3134+
3135+
_nvjitlink.get_linked_cubin(self.handle, cubin)
3136+
return cubin
3137+
3138+
except nvJitLinkError as e:
3139+
size = _nvjitlink.get_error_log_size(self.handle)
3140+
log = bytearray(size)
3141+
self._error_log = _nvjitlink.get_error_log(self.handle, log)
3142+
raise nvJitLinkError(f"{e}\n{self.error_log}")
3143+
finally:
3144+
size = _nvjitlink.get_info_log_size(self.handle)
3145+
log = bytearray(size)
3146+
self._info_log = _nvjitlink.get_info_log(self.handle, log)
3147+
3148+
def get_linked_ptx(self):
3149+
try:
3150+
_nvjitlink.complete(self.handle)
3151+
self._complete = True
3152+
size = _nvjitlink.get_linked_ptx_size(self.handle)
3153+
return _nvjitlink.get_linked_ptx(self.handle, size)
3154+
except RuntimeError as e:
3155+
self._error_log = _nvjitlink.get_error_log(self.handle)
3156+
raise NvJitLinkError(f"{e}\n{self.error_log}")
3157+
finally:
3158+
self._info_log = _nvjitlink.get_info_log(self.handle)
3159+
3160+
def complete(self):
3161+
try:
3162+
return self.get_linked_cubin()
3163+
except NvJitLinkError as e:
3164+
raise LinkerError from e
3165+
3166+
30273167
class PyNvJitLinker(Linker):
30283168
def __init__(
30293169
self,

0 commit comments

Comments
 (0)