Skip to content

Commit afcce87

Browse files
committed
Conditionally add LTO-able objects for PTX prints
1 parent b42c67d commit afcce87

File tree

3 files changed

+39
-25
lines changed

3 files changed

+39
-25
lines changed

numba_cuda/numba/cuda/codegen.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import os
99
import subprocess
1010
import tempfile
11-
from warnings import warn
1211

1312
CUDA_TRIPLE = 'nvptx64-nvidia-cuda'
1413

@@ -179,7 +178,7 @@ def get_ltoir(self, cc=None):
179178

180179
return ltoir
181180

182-
def _link_all(self, linker, cc):
181+
def _link_all(self, linker, cc, ignore_nonlto=False):
183182
if linker.lto:
184183
ltoir = self.get_ltoir(cc=cc)
185184
linker.add_ltoir(ltoir)
@@ -188,9 +187,11 @@ def _link_all(self, linker, cc):
188187
linker.add_ptx(ptx.encode())
189188

190189
for path in self._linking_files:
191-
linker.add_file_guess_ext(path)
190+
linker.add_file_guess_ext(path, ignore_nonlto)
192191
if self.needs_cudadevrt:
193-
linker.add_file_guess_ext(get_cudalib('cudadevrt', static=True))
192+
linker.add_file_guess_ext(
193+
get_cudalib('cudadevrt', static=True), ignore_nonlto
194+
)
194195

195196
def get_cubin(self, cc=None):
196197
cc = self._ensure_cc(cc)
@@ -206,30 +207,22 @@ def get_cubin(self, cc=None):
206207
additional_flags=["-ptx"],
207208
lto=self._lto
208209
)
209-
self._link_all(linker, cc)
210-
211-
try:
212-
ptx = linker.get_linked_ptx().decode('utf-8')
213-
214-
print(("ASSEMBLY (AFTER LTO) %s" % self._name).center(80, '-'))
215-
print(ptx)
216-
print('=' * 80)
217-
except driver.LinkerError as e:
218-
if linkererr_cause := getattr(e, "__cause__", None):
219-
if "-ptx requires that all inputs have LTOIR" in str(
220-
linkererr_cause
221-
):
222-
warn(
223-
"Linker input contains non-LTOIR objects, nvjitlink"
224-
" cannot generate LTO-ed PTX."
225-
)
210+
# `-ptx` flag is meant to view the optimized PTX for LTO objects.
211+
# Non-LTO objects are not passed to linker.
212+
self._link_all(linker, cc, ignore_nonlto=True)
213+
214+
ptx = linker.get_linked_ptx().decode('utf-8')
215+
216+
print(("ASSEMBLY (AFTER LTO) %s" % self._name).center(80, '-'))
217+
print(ptx)
218+
print('=' * 80)
226219

227220
linker = driver.Linker.new(
228221
max_registers=self._max_registers,
229222
cc=cc,
230223
lto=self._lto
231224
)
232-
self._link_all(linker, cc)
225+
self._link_all(linker, cc, ignore_nonlto=False)
233226
cubin = linker.complete()
234227

235228
self._cubin_cache[cc] = cubin

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from .drvapi import API_PROTOTYPES
3737
from .drvapi import cu_occupancy_b2d_size, cu_stream_callback_pyobj, cu_uuid
3838
from .mappings import FILE_EXTENSION_MAP
39-
from .linkable_code import LinkableCode
39+
from .linkable_code import LinkableCode, LTOIR
4040
from numba.cuda.cudadrv import enums, drvapi, nvrtc
4141

4242
USE_NV_BINDING = config.CUDA_USE_NVIDIA_BINDING
@@ -2683,12 +2683,18 @@ def add_cu_file(self, path):
26832683
cu = f.read()
26842684
self.add_cu(cu, os.path.basename(path))
26852685

2686-
def add_file_guess_ext(self, path_or_code):
2686+
def add_file_guess_ext(self, path_or_code, ignore_nonlto=False):
26872687
"""
26882688
Add a file or LinkableCode object to the link. If a file is
26892689
passed, the type will be inferred from the extension. A LinkableCode
26902690
object represents a file already in memory.
2691+
2692+
When `ignore_nonlto` is set to true, do not add code that are will not
2693+
be LTO-ed in the linking process. This is useful in inspecting the
2694+
LTO-ed portion of the PTX when linker is added with objects that can be
2695+
both LTO-ed and not LTO-ed.
26912696
"""
2697+
26922698
if isinstance(path_or_code, str):
26932699
ext = pathlib.Path(path_or_code).suffix
26942700
if ext == '':
@@ -2704,6 +2710,13 @@ def add_file_guess_ext(self, path_or_code):
27042710
"Don't know how to link file with extension "
27052711
f"{ext}"
27062712
)
2713+
if ignore_nonlto and kind != FILE_EXTENSION_MAP["ltoir"]:
2714+
warnings.warn(
2715+
f"Not adding {path_or_code} as it is not optimizable "
2716+
"at link time, and `ignore_nonlto == True`."
2717+
)
2718+
return
2719+
27072720
self.add_file(path_or_code, kind)
27082721
return
27092722
else:
@@ -2716,6 +2729,13 @@ def add_file_guess_ext(self, path_or_code):
27162729
if path_or_code.kind == "cu":
27172730
self.add_cu(path_or_code.data, path_or_code.name)
27182731
else:
2732+
if ignore_nonlto and not isinstance(path_or_code.kind, LTOIR):
2733+
warnings.warn(
2734+
f"Not adding {path_or_code.name} as it is not "
2735+
"optimizable at link time, and `ignore_nonlto == True`."
2736+
)
2737+
return
2738+
27192739
self.add_data(
27202740
path_or_code.data, path_or_code.kind, path_or_code.name
27212741
)

numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,8 @@ def kernel(result):
230230
assert result[0] == 3
231231

232232
assert len(w) == 1
233-
self.assertIn("cannot generate LTO-ed PTX", str(w[0].message))
233+
self.assertIn("it is not optimizable at link time, and "
234+
"`ignore_nonlto == True`", str(w[0].message))
234235

235236
config.DUMP_ASSEMBLY = False
236237

0 commit comments

Comments
 (0)