Skip to content

Commit 4968438

Browse files
committed
Fix Issue #588
One test still fails, because the C ABI wrapper generator generates no debug info, and the separate compilation seems to lead NVVM to not generate a debug section for it. This should probably be addressed by generating debug info for the C ABI wrapper.
1 parent cb1978f commit 4968438

File tree

2 files changed

+29
-12
lines changed

2 files changed

+29
-12
lines changed

numba_cuda/numba/cuda/codegen.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ def _ensure_cc(self, cc):
206206
return device.compute_capability
207207

208208
def get_asm_str(self, cc=None):
209+
return "\n".join(self.get_asm_strs(cc=cc))
210+
211+
def get_asm_strs(self, cc=None):
209212
cc = self._ensure_cc(cc)
210213

211214
ptxes = self._ptx_cache.get(cc, None)
@@ -218,21 +221,25 @@ def get_asm_str(self, cc=None):
218221

219222
irs = self.llvm_strs
220223

221-
ptx = nvvm.compile_ir(irs, **options)
224+
if "g" in options:
225+
ptxes = [nvvm.compile_ir(ir) for ir in irs]
226+
else:
227+
ptxes = [nvvm.compile_ir(irs, **options)]
222228

223229
# Sometimes the result from NVVM contains trailing whitespace and
224230
# nulls, which we strip so that the assembly dump looks a little
225231
# tidier.
226-
ptx = ptx.decode().strip("\x00").strip()
232+
ptxes = [ptx.decode().strip("\x00").strip() for ptx in ptxes]
227233

228234
if config.DUMP_ASSEMBLY:
229235
print(("ASSEMBLY %s" % self._name).center(80, "-"))
230-
print(ptx)
236+
for ptx in ptxes:
237+
print(ptx)
231238
print("=" * 80)
232239

233-
self._ptx_cache[cc] = ptx
240+
self._ptx_cache[cc] = ptxes
234241

235-
return ptx
242+
return ptxes
236243

237244
def get_lto_ptx(self, cc=None):
238245
"""
@@ -284,8 +291,9 @@ def _link_all(self, linker, cc, ignore_nonlto=False):
284291
ltoir = self.get_ltoir(cc=cc)
285292
linker.add_ltoir(ltoir)
286293
else:
287-
ptx = self.get_asm_str(cc=cc)
288-
linker.add_ptx(ptx.encode())
294+
ptxes = self.get_asm_strs(cc=cc)
295+
for ptx in ptxes:
296+
linker.add_ptx(ptx.encode())
289297

290298
for path in self._linking_files:
291299
linker.add_file_guess_ext(path, ignore_nonlto)
@@ -432,7 +440,10 @@ def finalize(self):
432440
for mod in library.modules:
433441
for fn in mod.functions:
434442
if not fn.is_declaration:
435-
fn.linkage = "linkonce_odr"
443+
if "g" in self._nvvm_options:
444+
fn.linkage = "weak_odr"
445+
else:
446+
fn.linkage = "linkonce_odr"
436447

437448
self._finalized = True
438449

numba_cuda/numba/cuda/compiler.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,10 +1023,9 @@ def compile_all(
10231023
)
10241024

10251025
if lto:
1026-
code = lib.get_ltoir(cc=cc)
1026+
codes = [lib.get_ltoir(cc=cc)]
10271027
else:
1028-
code = lib.get_asm_str(cc=cc)
1029-
codes = [code]
1028+
codes = lib.get_asm_strs(cc=cc)
10301029

10311030
# linking_files
10321031
is_ltoir = output == "ltoir"
@@ -1241,7 +1240,14 @@ def compile(
12411240
if lto:
12421241
code = lib.get_ltoir(cc=cc)
12431242
else:
1244-
code = lib.get_asm_str(cc=cc)
1243+
codes = lib.get_asm_strs(cc=cc)
1244+
if len(codes) == 1:
1245+
code = codes[0]
1246+
else:
1247+
raise RuntimeError(
1248+
"Compiling this function results in multiple"
1249+
"PTX files. Use compile_all() instead"
1250+
)
12451251
return code, resty
12461252

12471253

0 commit comments

Comments
 (0)