Skip to content

Commit 06dc5c7

Browse files
committed
Minor fixups following #23 / #56
- Update the codegen class docstring for LTO. - Simplify / correct some logic in `_readenv()` (`value.lower()` could never be `"True"`, only `"true"`. - Simplify additional flags and linker checks. - Setting `self._linker.complete` in `complete()` is unnecessary, as calling `get_linked_cubin()` sets the link as complete already.
1 parent 7e01ab0 commit 06dc5c7

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

numba_cuda/numba/cuda/codegen.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,16 @@ def __init__(
7070
):
7171
"""
7272
codegen:
73-
Codegen object.
73+
Codegen object.
7474
name:
7575
Name of the function in the source.
7676
entry_name:
7777
Name of the kernel function in the binary, if this is a global
7878
kernel and not a device function.
7979
max_registers:
8080
The maximum register usage to aim for when linking.
81+
lto:
82+
Whether to enable link-time optimization.
8183
nvvm_options:
8284
Dict of options to pass to NVVM.
8385
"""

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _readenv(name, ctor, default):
6464
return default() if callable(default) else default
6565
try:
6666
if ctor is bool:
67-
return bool(value.lower() in {'1', "True"})
67+
return value.lower() in {'1', "true"}
6868
return ctor(value)
6969
except Exception:
7070
warnings.warn(
@@ -2631,7 +2631,7 @@ def new(cls,
26312631

26322632
if linker is PyNvJitLinker:
26332633
return linker(max_registers, lineinfo, cc, lto, additional_flags)
2634-
elif additional_flags is not None or lto is True:
2634+
elif additional_flags or lto:
26352635
raise ValueError("LTO and additional flags require PyNvJitLinker")
26362636
else:
26372637
return linker(max_registers, lineinfo, cc)
@@ -3088,9 +3088,7 @@ def add_data(self, data, kind, name):
30883088

30893089
def complete(self):
30903090
try:
3091-
cubin = self._linker.get_linked_cubin()
3092-
self._linker._complete = True
3093-
return cubin
3091+
return self._linker.get_linked_cubin()
30943092
except NvJitLinkError as e:
30953093
raise LinkerError from e
30963094

0 commit comments

Comments
 (0)