Skip to content

Commit 21216f1

Browse files
committed
raise warning for inputs that has non-LTOIR objects
1 parent 1ce439b commit 21216f1

File tree

2 files changed

+47
-5
lines changed

2 files changed

+47
-5
lines changed

numba_cuda/numba/cuda/codegen.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import os
99
import subprocess
1010
import tempfile
11-
11+
from warnings import warn
1212

1313
CUDA_TRIPLE = 'nvptx64-nvidia-cuda'
1414

@@ -207,11 +207,20 @@ def get_cubin(self, cc=None):
207207
lto=self._lto
208208
)
209209
self._link_all(linker, cc)
210-
ptx = linker.get_linked_ptx().decode('utf-8')
211210

212-
print(("ASSEMBLY (AFTER LTO) %s" % self._name).center(80, '-'))
213-
print(ptx)
214-
print('=' * 80)
211+
try:
212+
ptx = linker.get_linked_ptx().decode('utf-8')
213+
214+
print(("ASSEMBLY (AFTER LTO) %s" % self._name).center(80, '-'))
215+
print(ptx)
216+
print('=' * 80)
217+
except driver.LinkerError as e:
218+
if linkererr_cause := getattr(e, "__cause__", None):
219+
if "-ptx requires that all inputs have LTOIR" in str(linkererr_cause):
220+
warn(
221+
"Linker input contains non-LTOIR objects, nvjitlink "
222+
"cannot generate LTO-ed PTX."
223+
)
215224

216225
linker = driver.Linker.new(
217226
max_registers=self._max_registers,

numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import io
99
import contextlib
10+
import warnings
1011

1112
from numba.cuda import get_current_device
1213
from numba import cuda
@@ -201,6 +202,38 @@ def kernel(result):
201202
config.DUMP_ASSEMBLY = False
202203

203204

205+
def test_nvjitlink_jit_with_linkable_code_lto_dump_assembly_warn(self):
206+
files = [
207+
test_device_functions_a,
208+
test_device_functions_cubin,
209+
test_device_functions_fatbin,
210+
test_device_functions_o,
211+
test_device_functions_ptx,
212+
]
213+
214+
config.DUMP_ASSEMBLY = True
215+
216+
for file in files:
217+
with self.subTest(file=file):
218+
with warnings.catch_warnings(record=True) as w:
219+
with contextlib.redirect_stdout(None): # suppress other PTX
220+
sig = "uint32(uint32, uint32)"
221+
add_from_numba = cuda.declare_device("add_from_numba", sig)
222+
223+
@cuda.jit(link=[file], lto=True)
224+
def kernel(result):
225+
result[0] = add_from_numba(1, 2)
226+
227+
result = cuda.device_array(1)
228+
kernel[1, 1](result)
229+
assert result[0] == 3
230+
231+
assert len(w) == 1
232+
self.assertIn("cannot generate LTO-ed PTX", str(w[0].message))
233+
234+
235+
config.DUMP_ASSEMBLY = False
236+
204237
def test_nvjitlink_jit_with_invalid_linkable_code(self):
205238
with open(test_device_functions_cubin, "rb") as f:
206239
content = f.read()

0 commit comments

Comments
 (0)