Skip to content

Commit 0c8171a

Browse files
workaround passing lto=False to cuda-python
1 parent 4334b8b commit 0c8171a

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2617,12 +2617,7 @@ def new(cls,
26172617
else:
26182618
linker = CtypesLinker
26192619

2620-
if linker is PyNvJitLinker:
2621-
return linker(max_registers, lineinfo, cc, lto, additional_flags)
2622-
elif additional_flags or lto:
2623-
raise ValueError("LTO and additional flags require PyNvJitLinker")
2624-
else:
2625-
return linker(max_registers, lineinfo, cc)
2620+
return linker(max_registers, lineinfo, cc, lto, additional_flags)
26262621

26272622
@abstractmethod
26282623
def __init__(self, max_registers, lineinfo, cc):
@@ -2762,19 +2757,31 @@ def complete(self):
27622757

27632758

27642759
class CUDALinker(Linker):
2765-
def __init__(self, max_registers=None, lineinfo=False, cc=None):
2760+
def __init__(
2761+
self,
2762+
max_registers=None,
2763+
lineinfo=False,
2764+
cc=None,
2765+
lto=None,
2766+
additional_flags=None
2767+
):
27662768
arch = f"sm_{cc[0] * 10 + cc[1]}"
2769+
# TODO: cuda-python/xyz
2770+
if lto is False:
2771+
lto = None
27672772
self.options = _CUDALinkerOptions(
27682773
max_register_count=max_registers,
27692774
lineinfo=lineinfo,
2770-
arch=arch
2775+
arch=arch,
2776+
link_time_optimization=lto,
27712777
)
27722778

27732779
self.max_registers = max_registers
27742780
self.lineinfo = lineinfo
27752781
self.cc = cc
27762782
self.arch = arch
2777-
self.lto = False
2783+
self.lto = lto
2784+
self.additional_flags = additional_flags
27782785

27792786
self._complete = False
27802787
self._object_codes = []
@@ -2833,6 +2840,10 @@ def add_cubin(self, cubin, name='<cudapy-cubin>'):
28332840
obj = ObjectCode.from_cubin(cubin)
28342841
self._object_codes.append(obj)
28352842

2843+
def add_ltoir(self, ltoir, name='<cudapy-ltoir>'):
2844+
obj = ObjectCode._init(ltoir, 'ltoir')
2845+
self._object_codes.append(obj)
2846+
28362847
def add_file(self, path, kind):
28372848
try:
28382849
with open(path, 'rb') as f:

0 commit comments

Comments
 (0)