Commit e7125bf

Elide kernel wrappers by altering device function IR

1 parent 793d238 commit e7125bf

6 files changed: +198 -33 lines changed

numba_cuda/numba/cuda/compiler.py

Lines changed: 180 additions & 10 deletions
@@ -1,6 +1,7 @@
 from llvmlite import ir
 from numba.core.typing.templates import ConcreteTemplate
-from numba.core import types, typing, funcdesc, config, compiler, sigutils
+from numba.core import (cgutils, types, typing, funcdesc, config, compiler,
+                        sigutils, utils)
 from numba.core.compiler import (sanitize_compile_result_entries, CompilerBase,
                                  DefaultPassBuilder, Flags, Option,
                                  CompileResult)
@@ -11,7 +12,10 @@
 from numba.core.typed_passes import (IRLegalization, NativeLowering,
                                      AnnotateTypes)
 from warnings import warn
+from numba.cuda import nvvmutils
 from numba.cuda.api import get_current_device
+from numba.cuda.cudadrv import nvvm
+from numba.cuda.descriptor import cuda_target
 from numba.cuda.target import CUDACABICallConv


@@ -24,6 +28,15 @@ def _nvvm_options_type(x):
     return x


+def _optional_int_type(x):
+    if x is None:
+        return None
+
+    else:
+        assert isinstance(x, int)
+        return x
+
+
 class CUDAFlags(Flags):
     nvvm_options = Option(
         type=_nvvm_options_type,
@@ -35,6 +48,16 @@ class CUDAFlags(Flags):
         default=None,
         doc="Compute Capability",
     )
+    max_registers = Option(
+        type=_optional_int_type,
+        default=None,
+        doc="Max registers"
+    )
+    lto = Option(
+        type=bool,
+        default=False,
+        doc="Enable Link-time Optimization"
+    )


 # The CUDACompileResult (CCR) has a specially-defined entry point equal to its
@@ -109,7 +132,11 @@ def run_pass(self, state):
         codegen = state.targetctx.codegen()
         name = state.func_id.func_qualname
         nvvm_options = state.flags.nvvm_options
-        state.library = codegen.create_library(name, nvvm_options=nvvm_options)
+        max_registers = state.flags.max_registers
+        lto = state.flags.lto
+        state.library = codegen.create_library(name, nvvm_options=nvvm_options,
+                                               max_registers=max_registers,
+                                               lto=lto)
         # Enable object caching upfront so that the library can be serialized.
         state.library.enable_object_caching()

@@ -152,7 +179,7 @@ def define_cuda_lowering_pipeline(self, state):
 @global_compiler_lock
 def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
                  inline=False, fastmath=False, nvvm_options=None,
-                 cc=None):
+                 cc=None, max_registers=None, lto=False):
     if cc is None:
         raise ValueError('Compute Capability must be supplied')

@@ -189,6 +216,8 @@ def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
     if nvvm_options:
         flags.nvvm_options = nvvm_options
     flags.compute_capability = cc
+    flags.max_registers = max_registers
+    flags.lto = lto

     # Run compilation pipeline
    from numba.core.target_extension import target_override
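
For context, a minimal sketch of how the new parameters might be exercised directly. This is a hypothetical invocation: compile_cuda is an internal API, and the compute capability must match a real device.

from numba.core import types
from numba.cuda.compiler import compile_cuda

def inc(x):
    x[0] += 1

# Cap register usage and request LTO-ready code generation.
cres = compile_cuda(inc, types.void, (types.int32[::1],),
                    cc=(7, 5), max_registers=32, lto=True)

In normal use these values arrive via the dispatcher (see the dispatcher.py changes below) rather than from user code calling compile_cuda directly.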
@@ -247,11 +276,155 @@ def cabi_wrap_function(context, lib, fndesc, wrapper_function_name,
         builder, func, restype, argtypes, callargs)
     builder.ret(return_value)

+    if config.DUMP_LLVM:
+        utils.dump_llvm(fndesc, wrapper_module)
+
     library.add_ir_module(wrapper_module)
     library.finalize()
     return library


+def kernel_fixup(kernel, debug):
+    if debug:
+        exc_helper = add_exception_store_helper(kernel)
+
+    # Pass 1 - replace:
+    #
+    #     ret <value>
+    #
+    # with:
+    #
+    #     exc_helper(<value>)
+    #     ret void
+
+    for block in kernel.blocks:
+        for i, inst in enumerate(block.instructions):
+            if isinstance(inst, ir.Ret):
+                old_ret = block.instructions.pop()
+                block.terminator = None
+
+                # The original return's metadata will be set on the new
+                # instructions in order to preserve debug info
+                metadata = old_ret.metadata
+
+                builder = ir.IRBuilder(block)
+                if debug:
+                    status_code = old_ret.operands[0]
+                    exc_helper_call = builder.call(exc_helper, (status_code,))
+                    exc_helper_call.metadata = metadata
+
+                new_ret = builder.ret_void()
+                new_ret.metadata = old_ret.metadata
+
+                # Need to break out so we don't carry on modifying what we are
+                # iterating over. There can only be one return in a block
+                # anyway.
+                break
+
+    # Pass 2: remove stores of null pointer to return value argument pointer
+
+    return_value = kernel.args[0]
+
+    for block in kernel.blocks:
+        remove_list = []
+
+        # Find all stores first
+        for inst in block.instructions:
+            if (isinstance(inst, ir.StoreInstr)
+                    and inst.operands[1] == return_value):
+                remove_list.append(inst)
+
+        # Remove all stores
+        for to_remove in remove_list:
+            block.instructions.remove(to_remove)
+
+    # Replace non-void return type with void return type and remove return
+    # value
+
+    if isinstance(kernel.type, ir.PointerType):
+        new_type = ir.PointerType(ir.FunctionType(ir.VoidType(),
+                                                  kernel.type.pointee.args[1:]))
+    else:
+        new_type = ir.FunctionType(ir.VoidType(), kernel.type.args[1:])
+
+    kernel.type = new_type
+    kernel.return_value = ir.ReturnValue(kernel, ir.VoidType())
+    kernel.args = kernel.args[1:]
+
+    # Mark as a kernel for NVVM
+
+    nvvm.set_cuda_kernel(kernel)
+
+    if config.DUMP_LLVM:
+        print(f"LLVM DUMP: Post kernel fixup {kernel.name}".center(80, '-'))
+        print(kernel.module)
+        print('=' * 80)
+
+
+def add_exception_store_helper(kernel):
+
+    # Create global variables for exception state
+
+    def define_error_gv(postfix):
+        name = kernel.name + postfix
+        gv = cgutils.add_global_variable(kernel.module, ir.IntType(32),
+                                         name)
+        gv.initializer = ir.Constant(gv.type.pointee, None)
+        return gv
+
+    gv_exc = define_error_gv("__errcode__")
+    gv_tid = []
+    gv_ctaid = []
+    for i in 'xyz':
+        gv_tid.append(define_error_gv("__tid%s__" % i))
+        gv_ctaid.append(define_error_gv("__ctaid%s__" % i))
+
+    # Create exception store helper function
+
+    helper_name = kernel.name + "__exc_helper__"
+    helper_type = ir.FunctionType(ir.VoidType(), (ir.IntType(32),))
+    helper_func = ir.Function(kernel.module, helper_type, helper_name)
+
+    block = helper_func.append_basic_block(name="entry")
+    builder = ir.IRBuilder(block)
+
+    # Implement status check / exception store logic
+
+    status_code = helper_func.args[0]
+    call_conv = cuda_target.target_context.call_conv
+    status = call_conv._get_return_status(builder, status_code)
+
+    # Check error status
+    with cgutils.if_likely(builder, status.is_ok):
+        builder.ret_void()
+
+    with builder.if_then(builder.not_(status.is_python_exc)):
+        # User exception raised
+        old = ir.Constant(gv_exc.type.pointee, None)
+
+        # Use atomic cmpxchg to prevent rewriting the error status
+        # Only the first error is recorded
+
+        xchg = builder.cmpxchg(gv_exc, old, status.code,
+                               'monotonic', 'monotonic')
+        changed = builder.extract_value(xchg, 1)
+
+        # If the xchange is successful, save the thread ID.
+        sreg = nvvmutils.SRegBuilder(builder)
+        with builder.if_then(changed):
+            for dim, ptr, in zip("xyz", gv_tid):
+                val = sreg.tid(dim)
+                builder.store(val, ptr)
+
+            for dim, ptr, in zip("xyz", gv_ctaid):
+                val = sreg.ctaid(dim)
+                builder.store(val, ptr)
+
+    builder.ret_void()
+
+    return helper_func
+
+
 @global_compiler_lock
 def compile(pyfunc, sig, debug=None, lineinfo=False, device=True,
             fastmath=False, cc=None, opt=None, abi="c", abi_info=None,
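
To see kernel_fixup in isolation, here is a self-contained llvmlite sketch that applies the same two instruction-level rewrites to a toy function using the device-function ABI (i32 status return, return-value out-pointer as the first argument). The module and function names are invented, and this follows the debug=False path only:

from llvmlite import ir

i32 = ir.IntType(32)
mod = ir.Module(name="fixup_sketch")
fn = ir.Function(mod, ir.FunctionType(i32, (i32.as_pointer(), i32)),
                 "toy_kernel")
builder = ir.IRBuilder(fn.append_basic_block("entry"))
builder.store(ir.Constant(i32, 0), fn.args[0])  # store to return-value pointer
builder.ret(ir.Constant(i32, 0))                # status-code return

block = fn.blocks[0]

# Pass 1: replace `ret i32 0` with `ret void`
block.instructions.pop()
block.terminator = None
ir.IRBuilder(block).ret_void()

# Pass 2: remove stores to the return-value argument
for inst in [i for i in block.instructions
             if isinstance(i, ir.StoreInstr) and i.operands[1] == fn.args[0]]:
    block.instructions.remove(inst)

# The real pass additionally rewrites fn.type / fn.return_value / fn.args
# to a true `void (i32)` signature and calls nvvm.set_cuda_kernel(fn) so
# that NVVM treats the function as a kernel entry point.
print(mod)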
@@ -347,13 +520,10 @@ def compile(pyfunc, sig, debug=None, lineinfo=False, device=True,
         lib = cabi_wrap_function(tgt, lib, cres.fndesc, wrapper_name,
                                  nvvm_options)
     else:
-        code = pyfunc.__code__
-        filename = code.co_filename
-        linenum = code.co_firstlineno
-
-        lib, kernel = tgt.prepare_cuda_kernel(cres.library, cres.fndesc, debug,
-                                              lineinfo, nvvm_options, filename,
-                                              linenum)
+        lib = cres.library
+        kernel = lib.get_function(cres.fndesc.llvm_func_name)
+        lib._entry_name = cres.fndesc.llvm_func_name
+        kernel_fixup(kernel, debug)

     if lto:
         code = lib.get_ltoir(cc=cc)
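
The exception helper's cmpxchg deserves a note: the compare-and-swap succeeds only while the error-code global still holds its initial zero, so when many threads fail concurrently, exactly one wins the race and records its error code and thread indices. A minimal llvmlite sketch of that first-error-wins pattern (global and function names invented):

from llvmlite import ir

i32 = ir.IntType(32)
mod = ir.Module(name="first_error_sketch")

gv_exc = ir.GlobalVariable(mod, i32, "errcode")
gv_exc.initializer = ir.Constant(i32, 0)

fn = ir.Function(mod, ir.FunctionType(ir.VoidType(), (i32,)),
                 "store_first_error")
builder = ir.IRBuilder(fn.append_basic_block("entry"))

# cmpxchg only succeeds if errcode is still 0: the first error wins and
# later errors leave it untouched.
old = ir.Constant(i32, 0)
xchg = builder.cmpxchg(gv_exc, old, fn.args[0], 'monotonic', 'monotonic')
changed = builder.extract_value(xchg, 1)  # i1 flag: did this thread win?
builder.ret_void()

print(mod)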

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 8 additions & 9 deletions
@@ -14,7 +14,7 @@

 from numba.cuda.api import get_current_device
 from numba.cuda.args import wrap_arg
-from numba.cuda.compiler import compile_cuda, CUDACompiler
+from numba.cuda.compiler import compile_cuda, CUDACompiler, kernel_fixup
 from numba.cuda.cudadrv import driver
 from numba.cuda.cudadrv.devices import get_context
 from numba.cuda.descriptor import cuda_target
@@ -102,15 +102,14 @@ def __init__(self, py_func, argtypes, link=None, debug=False,
                                 inline=inline,
                                 fastmath=fastmath,
                                 nvvm_options=nvvm_options,
-                                cc=cc)
+                                cc=cc,
+                                max_registers=max_registers,
+                                lto=lto)
         tgt_ctx = cres.target_context
-        code = self.py_func.__code__
-        filename = code.co_filename
-        linenum = code.co_firstlineno
-        lib, kernel = tgt_ctx.prepare_cuda_kernel(cres.library, cres.fndesc,
-                                                  debug, lineinfo, nvvm_options,
-                                                  filename, linenum,
-                                                  max_registers, lto)
+        lib = cres.library
+        kernel = lib.get_function(cres.fndesc.llvm_func_name)
+        lib._entry_name = cres.fndesc.llvm_func_name
+        kernel_fixup(kernel, self.debug)

         if not link:
             link = []
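
At the user level, this is the plumbing behind the existing max_registers kwarg on cuda.jit (lto follows the same path). A sketch with an invented kernel body; a CUDA device is required:

from numba import cuda

@cuda.jit(max_registers=32)
def scale(x, s):
    i = cuda.grid(1)
    if i < x.size:
        x[i] *= s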

numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py

Lines changed: 1 addition & 0 deletions
@@ -72,6 +72,7 @@ def test_issue_5835(self):
         def f(x):
             x[0] = 0

+    @unittest.skip("Wrappers no longer exist")
     def test_wrapper_has_debuginfo(self):
         sig = (types.int32[::1],)

numba_cuda/numba/cuda/tests/cudapy/test_inspect.py

Lines changed: 3 additions & 10 deletions
@@ -33,10 +33,7 @@ def foo(x, y):
         self.assertIn("foo", llvm)

         # Kernel in LLVM
-        self.assertIn('cuda.kernel.wrapper', llvm)
-
-        # Wrapped device function body in LLVM
-        self.assertIn("define linkonce_odr i32", llvm)
+        self.assertIn("define void @", llvm)

         asm = foo.inspect_asm(sig)

@@ -72,12 +69,8 @@ def foo(x, y):
         self.assertIn("foo", llvmirs[float64, float64])

         # Kernels in LLVM
-        self.assertIn('cuda.kernel.wrapper', llvmirs[intp, intp])
-        self.assertIn('cuda.kernel.wrapper', llvmirs[float64, float64])
-
-        # Wrapped device function bodies in LLVM
-        self.assertIn("define linkonce_odr i32", llvmirs[intp, intp])
-        self.assertIn("define linkonce_odr i32", llvmirs[float64, float64])
+        self.assertIn("define void @", llvmirs[intp, intp])
+        self.assertIn("define void @", llvmirs[float64, float64])

         asmdict = foo.inspect_asm()

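As a user-level illustration of what the updated assertions express: the kernel itself now appears as a void function in the inspected IR, with no separate wrapper. A sketch (kernel body invented; needs a real device rather than the simulator):

from numba import cuda, int32

@cuda.jit("void(int32[::1])")
def foo(x):
    x[0] = 1

llvm = next(iter(foo.inspect_llvm().values()))
assert "define void @" in llvm              # the kernel itself
assert "cuda.kernel.wrapper" not in llvm    # no separate wrapper anymore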
numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py

Lines changed: 1 addition & 2 deletions
@@ -170,10 +170,9 @@ def caller(x):
                 subprograms += 1

         # One DISubprogram for each of:
-        # - The kernel wrapper
         # - The caller
         # - The callee
-        expected_subprograms = 3
+        expected_subprograms = 2

         self.assertEqual(subprograms, expected_subprograms,
                          f'"Expected {expected_subprograms} DISubprograms; '

numba_cuda/numba/cuda/tests/cudapy/test_optimization.py

Lines changed: 5 additions & 2 deletions
@@ -14,8 +14,11 @@ def device_func(x, y, z):


 # Fragments of code that are removed from kernel_func's PTX when optimization
-# is on
-removed_by_opt = ( '__local_depot0', 'call.uni', 'st.param.b64')
+# is on. Previously this list was longer when kernel wrappers were used - if
+# the test function were more complex it may be possible to isolate additional
+# fragments of PTX we could check for the absence / presence of, but removal of
+# the use of local memory is a good indicator that optimization was applied.
+removed_by_opt = ( '__local_depot0',)


 @skip_on_cudasim('Simulator does not optimize code')
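
A quick way to see the indicator this comment describes, as a sketch (kernel body invented; needs a device, and the exact PTX depends on the toolkit version):

from numba import cuda

@cuda.jit("void(float64[::1], float64)", opt=True)
def kernel_func(x, v):
    x[0] = v

ptx = next(iter(kernel_func.inspect_asm().values()))
assert '__local_depot0' not in ptx  # local depot elided when optimized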
