Commit 52006ea

Elide kernel wrappers by altering device function IR
1 parent cea934b commit 52006ea

File tree

6 files changed: +189 −33 lines changed


numba_cuda/numba/cuda/compiler.py

Lines changed: 172 additions & 10 deletions
@@ -1,6 +1,7 @@
 from llvmlite import ir
 from numba.core.typing.templates import ConcreteTemplate
-from numba.core import types, typing, funcdesc, config, compiler, sigutils
+from numba.core import (cgutils, types, typing, funcdesc, config, compiler,
+                        sigutils, utils)
 from numba.core.compiler import (sanitize_compile_result_entries, CompilerBase,
                                  DefaultPassBuilder, Flags, Option,
                                  CompileResult)
@@ -11,7 +12,10 @@
 from numba.core.typed_passes import (IRLegalization, NativeLowering,
                                      AnnotateTypes)
 from warnings import warn
+from numba.cuda import nvvmutils
 from numba.cuda.api import get_current_device
+from numba.cuda.cudadrv import nvvm
+from numba.cuda.descriptor import cuda_target
 from numba.cuda.target import CUDACABICallConv


@@ -24,6 +28,15 @@ def _nvvm_options_type(x):
     return x


+def _optional_int_type(x):
+    if x is None:
+        return None
+
+    else:
+        assert isinstance(x, int)
+        return x
+
+
 class CUDAFlags(Flags):
     nvvm_options = Option(
         type=_nvvm_options_type,
@@ -35,6 +48,11 @@ class CUDAFlags(Flags):
         default=None,
         doc="Compute Capability",
     )
+    max_registers = Option(
+        type=_optional_int_type,
+        default=None,
+        doc="Max registers"
+    )


 # The CUDACompileResult (CCR) has a specially-defined entry point equal to its
@@ -109,7 +127,9 @@ def run_pass(self, state):
         codegen = state.targetctx.codegen()
         name = state.func_id.func_qualname
         nvvm_options = state.flags.nvvm_options
-        state.library = codegen.create_library(name, nvvm_options=nvvm_options)
+        max_registers = state.flags.max_registers
+        state.library = codegen.create_library(name, nvvm_options=nvvm_options,
+                                               max_registers=max_registers)
         # Enable object caching upfront so that the library can be serialized.
         state.library.enable_object_caching()

@@ -152,7 +172,7 @@ def define_cuda_lowering_pipeline(self, state):
 @global_compiler_lock
 def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
                  inline=False, fastmath=False, nvvm_options=None,
-                 cc=None):
+                 cc=None, max_registers=None):
     if cc is None:
         raise ValueError('Compute Capability must be supplied')

@@ -189,6 +209,7 @@ def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
     if nvvm_options:
         flags.nvvm_options = nvvm_options
     flags.compute_capability = cc
+    flags.max_registers = max_registers

     # Run compilation pipeline
     from numba.core.target_extension import target_override
@@ -247,11 +268,155 @@ def cabi_wrap_function(context, lib, fndesc, wrapper_function_name,
                                         builder, func, restype, argtypes, callargs)
     builder.ret(return_value)

+    if config.DUMP_LLVM:
+        utils.dump_llvm(fndesc, wrapper_module)
+
     library.add_ir_module(wrapper_module)
     library.finalize()
     return library


+def kernel_fixup(kernel, debug):
+    if debug:
+        exc_helper = add_exception_store_helper(kernel)
+
+    # Pass 1 - replace:
+    #
+    #     ret <value>
+    #
+    # with:
+    #
+    #     exc_helper(<value>)
+    #     ret void
+
+    for block in kernel.blocks:
+        for i, inst in enumerate(block.instructions):
+            if isinstance(inst, ir.Ret):
+                old_ret = block.instructions.pop()
+                block.terminator = None
+
+                # The original return's metadata will be set on the new
+                # instructions in order to preserve debug info
+                metadata = old_ret.metadata
+
+                builder = ir.IRBuilder(block)
+                if debug:
+                    status_code = old_ret.operands[0]
+                    exc_helper_call = builder.call(exc_helper, (status_code,))
+                    exc_helper_call.metadata = metadata
+
+                new_ret = builder.ret_void()
+                new_ret.metadata = old_ret.metadata
+
+                # Need to break out so we don't carry on modifying what we are
+                # iterating over. There can only be one return in a block
+                # anyway.
+                break
+
+    # Pass 2: remove stores of null pointer to return value argument pointer
+
+    return_value = kernel.args[0]
+
+    for block in kernel.blocks:
+        remove_list = []
+
+        # Find all stores first
+        for inst in block.instructions:
+            if (isinstance(inst, ir.StoreInstr)
+                    and inst.operands[1] == return_value):
+                remove_list.append(inst)
+
+        # Remove all stores
+        for to_remove in remove_list:
+            block.instructions.remove(to_remove)
+
+    # Replace non-void return type with void return type and remove return
+    # value
+
+    if isinstance(kernel.type, ir.PointerType):
+        new_type = ir.PointerType(ir.FunctionType(ir.VoidType(),
+                                                  kernel.type.pointee.args[1:]))
+    else:
+        new_type = ir.FunctionType(ir.VoidType(), kernel.type.args[1:])
+
+    kernel.type = new_type
+    kernel.return_value = ir.ReturnValue(kernel, ir.VoidType())
+    kernel.args = kernel.args[1:]
+
+    # Mark as a kernel for NVVM
+
+    nvvm.set_cuda_kernel(kernel)
+
+    if config.DUMP_LLVM:
+        print(f"LLVM DUMP: Post kernel fixup {kernel.name}".center(80, '-'))
+        print(kernel.module)
+        print('=' * 80)
+
+
+def add_exception_store_helper(kernel):
+
+    # Create global variables for exception state
+
+    def define_error_gv(postfix):
+        name = kernel.name + postfix
+        gv = cgutils.add_global_variable(kernel.module, ir.IntType(32),
+                                         name)
+        gv.initializer = ir.Constant(gv.type.pointee, None)
+        return gv
+
+    gv_exc = define_error_gv("__errcode__")
+    gv_tid = []
+    gv_ctaid = []
+    for i in 'xyz':
+        gv_tid.append(define_error_gv("__tid%s__" % i))
+        gv_ctaid.append(define_error_gv("__ctaid%s__" % i))
+
+    # Create exception store helper function
+
+    helper_name = kernel.name + "__exc_helper__"
+    helper_type = ir.FunctionType(ir.VoidType(), (ir.IntType(32),))
+    helper_func = ir.Function(kernel.module, helper_type, helper_name)
+
+    block = helper_func.append_basic_block(name="entry")
+    builder = ir.IRBuilder(block)
+
+    # Implement status check / exception store logic
+
+    status_code = helper_func.args[0]
+    call_conv = cuda_target.target_context.call_conv
+    status = call_conv._get_return_status(builder, status_code)
+
+    # Check error status
+    with cgutils.if_likely(builder, status.is_ok):
+        builder.ret_void()
+
+    with builder.if_then(builder.not_(status.is_python_exc)):
+        # User exception raised
+        old = ir.Constant(gv_exc.type.pointee, None)
+
+        # Use atomic cmpxchg to prevent rewriting the error status
+        # Only the first error is recorded
+
+        xchg = builder.cmpxchg(gv_exc, old, status.code,
+                               'monotonic', 'monotonic')
+        changed = builder.extract_value(xchg, 1)
+
+        # If the exchange is successful, save the thread ID.
+        sreg = nvvmutils.SRegBuilder(builder)
+        with builder.if_then(changed):
+            for dim, ptr in zip("xyz", gv_tid):
+                val = sreg.tid(dim)
+                builder.store(val, ptr)
+
+            for dim, ptr in zip("xyz", gv_ctaid):
+                val = sreg.ctaid(dim)
+                builder.store(val, ptr)
+
+    builder.ret_void()

+    return helper_func
+
+
 @global_compiler_lock
 def compile(pyfunc, sig, debug=False, lineinfo=False, device=True,
             fastmath=False, cc=None, opt=True, abi="c", abi_info=None,
@@ -344,13 +509,10 @@ def compile(pyfunc, sig, debug=False, lineinfo=False, device=True,
         lib = cabi_wrap_function(tgt, lib, cres.fndesc, wrapper_name,
                                  nvvm_options)
     else:
-        code = pyfunc.__code__
-        filename = code.co_filename
-        linenum = code.co_firstlineno
-
-        lib, kernel = tgt.prepare_cuda_kernel(cres.library, cres.fndesc, debug,
-                                              lineinfo, nvvm_options, filename,
-                                              linenum)
+        lib = cres.library
+        kernel = lib.get_function(cres.fndesc.llvm_func_name)
+        lib._entry_name = cres.fndesc.llvm_func_name
+        kernel_fixup(kernel, debug)

     if lto:
         code = lib.get_ltoir(cc=cc)
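
A minimal llvmlite sketch of the in-place rewrite that kernel_fixup performs, applied to a toy function (illustrative only, not the commit's code): a stand-in device function i32 @f(i64* %ret, i64 %x) that writes through its first argument and returns a status code becomes a void kernel over the remaining arguments.

from llvmlite import ir

# Build the stand-in "device function": i32 @f(i64* %ret, i64 %x)
i32, i64 = ir.IntType(32), ir.IntType(64)
mod = ir.Module(name="example")
func = ir.Function(mod, ir.FunctionType(i32, (i64.as_pointer(), i64)),
                   name="f")
builder = ir.IRBuilder(func.append_basic_block("entry"))
builder.store(func.args[1], func.args[0])  # write the "return value"
builder.ret(ir.Constant(i32, 0))           # return a status code

ret_arg = func.args[0]

for block in func.blocks:
    # Pass 1 analogue: replace `ret <value>` with `ret void`
    if isinstance(block.terminator, ir.Ret):
        block.instructions.pop()
        block.terminator = None
        ir.IRBuilder(block).ret_void()

    # Pass 2 analogue: drop stores through the return-value argument
    block.instructions = [inst for inst in block.instructions
                          if not (isinstance(inst, ir.StoreInstr)
                                  and inst.operands[1] == ret_arg)]

# Rewrite the signature: void return type, return-value argument removed
func.type = ir.PointerType(ir.FunctionType(ir.VoidType(),
                                           func.type.pointee.args[1:]))
func.return_value = ir.ReturnValue(func, ir.VoidType())
func.args = func.args[1:]

print(mod)  # the function now prints roughly as: define void @"f"(i64 %".2")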

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 7 additions & 9 deletions
@@ -13,7 +13,7 @@

 from numba.cuda.api import get_current_device
 from numba.cuda.args import wrap_arg
-from numba.cuda.compiler import compile_cuda, CUDACompiler
+from numba.cuda.compiler import compile_cuda, CUDACompiler, kernel_fixup
 from numba.cuda.cudadrv import driver
 from numba.cuda.cudadrv.devices import get_context
 from numba.cuda.descriptor import cuda_target
@@ -86,15 +86,13 @@ def __init__(self, py_func, argtypes, link=None, debug=False,
                            inline=inline,
                            fastmath=fastmath,
                            nvvm_options=nvvm_options,
-                           cc=cc)
+                           cc=cc,
+                           max_registers=max_registers)
         tgt_ctx = cres.target_context
-        code = self.py_func.__code__
-        filename = code.co_filename
-        linenum = code.co_firstlineno
-        lib, kernel = tgt_ctx.prepare_cuda_kernel(cres.library, cres.fndesc,
-                                                  debug, lineinfo, nvvm_options,
-                                                  filename, linenum,
-                                                  max_registers)
+        lib = cres.library
+        kernel = lib.get_function(cres.fndesc.llvm_func_name)
+        lib._entry_name = cres.fndesc.llvm_func_name
+        kernel_fixup(kernel, self.debug)

         if not link:
             link = []
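
Both compile paths now fetch the compiled device function from its library and run kernel_fixup on it, which finishes by calling nvvm.set_cuda_kernel. Conceptually, NVVM recognizes kernels through the module-level "nvvm.annotations" metadata; the following minimal llvmlite sketch shows that marking done by hand (illustrative, not Numba's implementation of set_cuda_kernel).

from llvmlite import ir

mod = ir.Module(name="example")
kern = ir.Function(mod, ir.FunctionType(ir.VoidType(), ()), name="k")
ir.IRBuilder(kern.append_basic_block("entry")).ret_void()

# Each annotation entry pairs a function with the string "kernel" and i32 1
entry = mod.add_metadata([kern, "kernel", ir.Constant(ir.IntType(32), 1)])
mod.add_named_metadata("nvvm.annotations").add(entry)

print(mod)  # the module now carries: !nvvm.annotations = !{ !0 }, etc.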

numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@ def test_issue_5835(self):
         def f(x):
             x[0] = 0

+    @unittest.skip("Wrappers no longer exist")
     def test_wrapper_has_debuginfo(self):
         sig = (types.int32[::1],)
7778

numba_cuda/numba/cuda/tests/cudapy/test_inspect.py

Lines changed: 3 additions & 10 deletions
@@ -33,10 +33,7 @@ def foo(x, y):
         self.assertIn("foo", llvm)

         # Kernel in LLVM
-        self.assertIn('cuda.kernel.wrapper', llvm)
-
-        # Wrapped device function body in LLVM
-        self.assertIn("define linkonce_odr i32", llvm)
+        self.assertIn("define void @", llvm)

         asm = foo.inspect_asm(sig)

@@ -72,12 +69,8 @@ def foo(x, y):
         self.assertIn("foo", llvmirs[float64, float64])

         # Kernels in LLVM
-        self.assertIn('cuda.kernel.wrapper', llvmirs[intp, intp])
-        self.assertIn('cuda.kernel.wrapper', llvmirs[float64, float64])
-
-        # Wrapped device function bodies in LLVM
-        self.assertIn("define linkonce_odr i32", llvmirs[intp, intp])
-        self.assertIn("define linkonce_odr i32", llvmirs[float64, float64])
+        self.assertIn("define void @", llvmirs[intp, intp])
+        self.assertIn("define void @", llvmirs[float64, float64])

         asmdict = foo.inspect_asm()
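
The updated assertions look for the kernel itself as a void-returning definition rather than a separate wrapper. A hedged sketch of the property being checked (kernel body illustrative; requires a CUDA-capable environment):

from numba import cuda, intp

@cuda.jit("void(intp, intp)")
def foo(x, y):
    pass

llvm = foo.inspect_llvm()[(intp, intp)]
assert "define void @" in llvm            # the kernel is now a void definition
assert "cuda.kernel.wrapper" not in llvm  # no separate wrapper function remains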

numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py

Lines changed: 1 addition & 2 deletions
@@ -170,10 +170,9 @@ def caller(x):
                     subprograms += 1

         # One DISubprogram for each of:
-        # - The kernel wrapper
         # - The caller
         # - The callee
-        expected_subprograms = 3
+        expected_subprograms = 2

         self.assertEqual(subprograms, expected_subprograms,
                          f'"Expected {expected_subprograms} DISubprograms; '

numba_cuda/numba/cuda/tests/cudapy/test_optimization.py

Lines changed: 5 additions & 2 deletions
@@ -14,8 +14,11 @@ def device_func(x, y, z):


 # Fragments of code that are removed from kernel_func's PTX when optimization
-# is on
-removed_by_opt = ( '__local_depot0', 'call.uni', 'st.param.b64')
+# is on. Previously this list was longer when kernel wrappers were used - if
+# the test function were more complex it may be possible to isolate additional
+# fragments of PTX we could check for the absence / presence of, but removal of
+# the use of local memory is a good indicator that optimization was applied.
+removed_by_opt = ( '__local_depot0',)


 @skip_on_cudasim('Simulator does not optimize code')
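
A hedged sketch of how the remaining fragment detects optimization in practice (kernel body illustrative; requires a CUDA toolchain):

from numba import cuda, float32

@cuda.jit("void(float32[::1])", opt=True)
def kernel_func(x):
    x[0] = 1

ptx = kernel_func.inspect_asm()[(float32[::1],)]
# With optimization on, no local-memory depot should be allocated
assert '__local_depot0' not in ptx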
