|
1 | 1 | from llvmlite import ir |
2 | 2 | from numba.core.typing.templates import ConcreteTemplate |
| 3 | +from numba.core import ir as numba_ir |
3 | 4 | from numba.core import (cgutils, types, typing, funcdesc, config, compiler, |
4 | 5 | sigutils, utils) |
5 | 6 | from numba.core.compiler import (sanitize_compile_result_entries, CompilerBase, |
6 | 7 | DefaultPassBuilder, Flags, Option, |
7 | 8 | CompileResult) |
8 | 9 | from numba.core.compiler_lock import global_compiler_lock |
9 | | -from numba.core.compiler_machinery import (LoweringPass, |
| 10 | +from numba.core.compiler_machinery import (FunctionPass, LoweringPass, |
10 | 11 | PassManager, register_pass) |
| 12 | +from numba.core.interpreter import Interpreter |
11 | 13 | from numba.core.errors import NumbaInvalidConfigWarning |
| 14 | +from numba.core.untyped_passes import TranslateByteCode |
12 | 15 | from numba.core.typed_passes import (IRLegalization, NativeLowering, |
13 | 16 | AnnotateTypes) |
14 | 17 | from warnings import warn |
@@ -143,13 +146,74 @@ def run_pass(self, state): |
143 | 146 | return True |
144 | 147 |
|
145 | 148 |
|
| 149 | +class CUDABytecodeInterpreter(Interpreter): |
| 150 | + # Based on the superclass implementation, but names the resulting variable |
| 151 | + # "$bool<N>" instead of "bool<N>" - see Numba PR #9888: |
| 152 | + # https://github.com/numba/numba/pull/9888 |
| 153 | + # |
| 154 | + # This can be removed once that PR is available in an upstream Numba |
| 155 | + # release. |
| 156 | + def _op_JUMP_IF(self, inst, pred, iftrue): |
| 157 | + brs = { |
| 158 | + True: inst.get_jump_target(), |
| 159 | + False: inst.next, |
| 160 | + } |
| 161 | + truebr = brs[iftrue] |
| 162 | + falsebr = brs[not iftrue] |
| 163 | + |
| 164 | + name = "$bool%s" % (inst.offset) |
| 165 | + gv_fn = numba_ir.Global("bool", bool, loc=self.loc) |
| 166 | + self.store(value=gv_fn, name=name) |
| 167 | + |
| 168 | + callres = numba_ir.Expr.call(self.get(name), (self.get(pred),), (), |
| 169 | + loc=self.loc) |
| 170 | + |
| 171 | + pname = "$%spred" % (inst.offset) |
| 172 | + predicate = self.store(value=callres, name=pname) |
| 173 | + bra = numba_ir.Branch(cond=predicate, truebr=truebr, falsebr=falsebr, |
| 174 | + loc=self.loc) |
| 175 | + self.current_block.append(bra) |
| 176 | + |
| 177 | + |
| 178 | +@register_pass(mutates_CFG=True, analysis_only=False) |
| 179 | +class CUDATranslateBytecode(FunctionPass): |
| 180 | + _name = "cuda_translate_bytecode" |
| 181 | + |
| 182 | + def __init__(self): |
| 183 | + FunctionPass.__init__(self) |
| 184 | + |
| 185 | + def run_pass(self, state): |
| 186 | + func_id = state['func_id'] |
| 187 | + bc = state['bc'] |
| 188 | + interp = CUDABytecodeInterpreter(func_id) |
| 189 | + func_ir = interp.interpret(bc) |
| 190 | + state['func_ir'] = func_ir |
| 191 | + return True |
| 192 | + |
| 193 | + |
146 | 194 | class CUDACompiler(CompilerBase): |
147 | 195 | def define_pipelines(self): |
148 | 196 | dpb = DefaultPassBuilder |
149 | 197 | pm = PassManager('cuda') |
150 | 198 |
|
151 | 199 | untyped_passes = dpb.define_untyped_pipeline(self.state) |
152 | | - pm.passes.extend(untyped_passes.passes) |
| 200 | + |
| 201 | + # Rather than replicating the whole untyped passes definition in |
| 202 | + # numba-cuda, it seems cleaner to take the pass list and replace the |
| 203 | + # TranslateBytecode pass with our own. |
| 204 | + |
| 205 | + def replace_translate_pass(implementation, description): |
| 206 | + if implementation is TranslateByteCode: |
| 207 | + return (CUDATranslateBytecode, description) |
| 208 | + else: |
| 209 | + return (implementation, description) |
| 210 | + |
| 211 | + cuda_untyped_passes = [ |
| 212 | + replace_translate_pass(implementation, description) |
| 213 | + for implementation, description in untyped_passes.passes |
| 214 | + ] |
| 215 | + |
| 216 | + pm.passes.extend(cuda_untyped_passes) |
153 | 217 |
|
154 | 218 | typed_passes = dpb.define_typed_pipeline(self.state) |
155 | 219 | pm.passes.extend(typed_passes.passes) |
|
0 commit comments