Skip to content

Commit bca956d

Browse files
committed
Implement debuginfo bool name fix (numba/numba#9888) in numba-cuda
As a workaround until numba/numba#9888 is merged and available in an upstream Numba release, we interject our own modified bytecode translation pass that names bools correctly (`$bool<N>` as opposed to `bool<N>`). To avoid duplicating the whole untyped pipeline definition, we take the definition from upstream Numba and modify it to use our pass.
1 parent d4eb970 commit bca956d

File tree

1 file changed

+66
-2
lines changed

1 file changed

+66
-2
lines changed

numba_cuda/numba/cuda/compiler.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
from llvmlite import ir
22
from numba.core.typing.templates import ConcreteTemplate
3+
from numba.core import ir as numba_ir
34
from numba.core import (cgutils, types, typing, funcdesc, config, compiler,
45
sigutils, utils)
56
from numba.core.compiler import (sanitize_compile_result_entries, CompilerBase,
67
DefaultPassBuilder, Flags, Option,
78
CompileResult)
89
from numba.core.compiler_lock import global_compiler_lock
9-
from numba.core.compiler_machinery import (LoweringPass,
10+
from numba.core.compiler_machinery import (FunctionPass, LoweringPass,
1011
PassManager, register_pass)
12+
from numba.core.interpreter import Interpreter
1113
from numba.core.errors import NumbaInvalidConfigWarning
14+
from numba.core.untyped_passes import TranslateByteCode
1215
from numba.core.typed_passes import (IRLegalization, NativeLowering,
1316
AnnotateTypes)
1417
from warnings import warn
@@ -143,13 +146,74 @@ def run_pass(self, state):
143146
return True
144147

145148

149+
class CUDABytecodeInterpreter(Interpreter):
    """Interpreter subclass that overrides conditional-jump translation.

    ``_op_JUMP_IF`` is based on the superclass implementation, but names the
    variable holding the ``bool`` global "$bool<N>" instead of "bool<N>" -
    see Numba PR #9888: https://github.com/numba/numba/pull/9888

    This can be removed once that PR is available in an upstream Numba
    release.
    """

    def _op_JUMP_IF(self, inst, pred, iftrue):
        """Emit the IR for a conditional jump on ``pred``.

        Stores the ``bool`` global under "$bool<offset>", calls it on the
        value of ``pred``, stores the call result under "$<offset>pred", and
        appends a ``Branch`` on that result to the current block.

        ``iftrue`` selects which way the jump goes (expected to be a bool):
        when true, the jump target is taken if the predicate holds and the
        next instruction otherwise; when false, the two are swapped.
        """
        jump_target = inst.get_jump_target()
        fallthrough = inst.next
        if iftrue:
            truebr, falsebr = jump_target, fallthrough
        else:
            truebr, falsebr = fallthrough, jump_target

        # The "$" prefix is the whole point of this override (see the class
        # docstring).
        bool_name = "$bool%s" % (inst.offset)
        bool_global = numba_ir.Global("bool", bool, loc=self.loc)
        self.store(value=bool_global, name=bool_name)

        bool_var = self.get(bool_name)
        pred_var = self.get(pred)
        bool_call = numba_ir.Expr.call(bool_var, (pred_var,), (),
                                       loc=self.loc)

        pred_name = "$%spred" % (inst.offset)
        condition = self.store(value=bool_call, name=pred_name)

        branch = numba_ir.Branch(cond=condition, truebr=truebr,
                                 falsebr=falsebr, loc=self.loc)
        self.current_block.append(branch)
176+
177+
178+
@register_pass(mutates_CFG=True, analysis_only=False)
class CUDATranslateBytecode(FunctionPass):
    """Function pass that interprets bytecode with CUDABytecodeInterpreter."""

    _name = "cuda_translate_bytecode"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Interpret ``state['bc']`` and store the result in the state.

        Reads ``state['func_id']`` and ``state['bc']``, runs the bytecode
        through a ``CUDABytecodeInterpreter`` built for that function id,
        and writes the interpreter's result to ``state['func_ir']``.

        Always returns True.
        """
        func_id, bytecode = state['func_id'], state['bc']
        state['func_ir'] = CUDABytecodeInterpreter(func_id).interpret(bytecode)
        return True
192+
193+
146194
class CUDACompiler(CompilerBase):
147195
def define_pipelines(self):
148196
dpb = DefaultPassBuilder
149197
pm = PassManager('cuda')
150198

151199
untyped_passes = dpb.define_untyped_pipeline(self.state)
152-
pm.passes.extend(untyped_passes.passes)
200+
201+
# Rather than replicating the whole untyped passes definition in
202+
# numba-cuda, it seems cleaner to take the pass list and replace the
203+
# TranslateByteCode pass with our own.
204+
205+
def replace_translate_pass(implementation, description):
206+
if implementation is TranslateByteCode:
207+
return (CUDATranslateBytecode, description)
208+
else:
209+
return (implementation, description)
210+
211+
cuda_untyped_passes = [
212+
replace_translate_pass(implementation, description)
213+
for implementation, description in untyped_passes.passes
214+
]
215+
216+
pm.passes.extend(cuda_untyped_passes)
153217

154218
typed_passes = dpb.define_typed_pipeline(self.state)
155219
pm.passes.extend(typed_passes.passes)

0 commit comments

Comments
 (0)