11from llvmlite import ir
22from numba .core .typing .templates import ConcreteTemplate
3- from numba .core import types , typing , funcdesc , config , compiler , sigutils
3+ from numba .core import (cgutils , types , typing , funcdesc , config , compiler ,
4+ sigutils , utils )
45from numba .core .compiler import (sanitize_compile_result_entries , CompilerBase ,
56 DefaultPassBuilder , Flags , Option ,
67 CompileResult )
1112from numba .core .typed_passes import (IRLegalization , NativeLowering ,
1213 AnnotateTypes )
1314from warnings import warn
15+ from numba .cuda import nvvmutils
1416from numba .cuda .api import get_current_device
17+ from numba .cuda .cudadrv import nvvm
18+ from numba .cuda .descriptor import cuda_target
1519from numba .cuda .target import CUDACABICallConv
1620
1721
@@ -24,6 +28,15 @@ def _nvvm_options_type(x):
2428 return x
2529
2630
31+ def _optional_int_type (x ):
32+ if x is None :
33+ return None
34+
35+ else :
36+ assert isinstance (x , int )
37+ return x
38+
39+
2740class CUDAFlags (Flags ):
2841 nvvm_options = Option (
2942 type = _nvvm_options_type ,
@@ -35,6 +48,16 @@ class CUDAFlags(Flags):
3548 default = None ,
3649 doc = "Compute Capability" ,
3750 )
51+ max_registers = Option (
52+ type = _optional_int_type ,
53+ default = None ,
54+ doc = "Max registers"
55+ )
56+ lto = Option (
57+ type = bool ,
58+ default = False ,
59+ doc = "Enable Link-time Optimization"
60+ )
3861
3962
4063# The CUDACompileResult (CCR) has a specially-defined entry point equal to its
@@ -109,7 +132,11 @@ def run_pass(self, state):
109132 codegen = state .targetctx .codegen ()
110133 name = state .func_id .func_qualname
111134 nvvm_options = state .flags .nvvm_options
112- state .library = codegen .create_library (name , nvvm_options = nvvm_options )
135+ max_registers = state .flags .max_registers
136+ lto = state .flags .lto
137+ state .library = codegen .create_library (name , nvvm_options = nvvm_options ,
138+ max_registers = max_registers ,
139+ lto = lto )
113140 # Enable object caching upfront so that the library can be serialized.
114141 state .library .enable_object_caching ()
115142
@@ -152,7 +179,7 @@ def define_cuda_lowering_pipeline(self, state):
152179@global_compiler_lock
153180def compile_cuda (pyfunc , return_type , args , debug = False , lineinfo = False ,
154181 inline = False , fastmath = False , nvvm_options = None ,
155- cc = None ):
182+ cc = None , max_registers = None , lto = False ):
156183 if cc is None :
157184 raise ValueError ('Compute Capability must be supplied' )
158185
@@ -189,6 +216,8 @@ def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
189216 if nvvm_options :
190217 flags .nvvm_options = nvvm_options
191218 flags .compute_capability = cc
219+ flags .max_registers = max_registers
220+ flags .lto = lto
192221
193222 # Run compilation pipeline
194223 from numba .core .target_extension import target_override
@@ -247,11 +276,155 @@ def cabi_wrap_function(context, lib, fndesc, wrapper_function_name,
247276 builder , func , restype , argtypes , callargs )
248277 builder .ret (return_value )
249278
279+ if config .DUMP_LLVM :
280+ utils .dump_llvm (fndesc , wrapper_module )
281+
250282 library .add_ir_module (wrapper_module )
251283 library .finalize ()
252284 return library
253285
254286
def kernel_fixup(kernel, debug):
    """Rewrite ``kernel`` in place so that it has a void, kernel-style
    signature.

    The function is mutated as follows:

    1. Every ``ret <value>`` is replaced by ``ret void``. When ``debug`` is
       true, a call to the helper created by ``add_exception_store_helper``
       is inserted just before it, receiving the value that used to be
       returned.
    2. Every store whose destination operand is the first argument of the
       function is removed.
    3. The first argument is dropped and the function type becomes a void
       function over the remaining arguments.
    4. The function is marked as a CUDA kernel through
       ``nvvm.set_cuda_kernel``.

    When ``config.DUMP_LLVM`` is set, the containing module is printed
    afterwards.
    """
    if debug:
        exc_helper = add_exception_store_helper(kernel)

    # Pass 1 - turn
    #
    #     ret <value>
    #
    # into
    #
    #     exc_helper(<value>)     (debug only)
    #     ret void

    for block in kernel.blocks:
        for inst in block.instructions:
            if not isinstance(inst, ir.Ret):
                continue

            # Detach the old return so the builder can append to the block.
            old_ret = block.instructions.pop()
            block.terminator = None

            # Carry the original return's metadata over to the replacement
            # instructions so that debug info is preserved.
            metadata = old_ret.metadata

            builder = ir.IRBuilder(block)
            if debug:
                status_code = old_ret.operands[0]
                exc_helper_call = builder.call(exc_helper, (status_code,))
                exc_helper_call.metadata = metadata

            new_ret = builder.ret_void()
            new_ret.metadata = metadata

            # Stop scanning this block: its instruction list has just been
            # modified, and a block holds at most one return anyway.
            break

    # Pass 2 - drop stores that write through the return value pointer
    # argument.

    return_value = kernel.args[0]

    for block in kernel.blocks:
        # Gather first, then delete, so the list is not mutated while it is
        # being scanned.
        remove_list = [
            inst for inst in block.instructions
            if isinstance(inst, ir.StoreInstr)
            and inst.operands[1] == return_value
        ]
        for to_remove in remove_list:
            block.instructions.remove(to_remove)

    # Swap the non-void function type for a void one that no longer has the
    # leading return value argument.

    if isinstance(kernel.type, ir.PointerType):
        remaining_args = kernel.type.pointee.args[1:]
        new_type = ir.PointerType(ir.FunctionType(ir.VoidType(),
                                                  remaining_args))
    else:
        new_type = ir.FunctionType(ir.VoidType(), kernel.type.args[1:])

    kernel.type = new_type
    kernel.return_value = ir.ReturnValue(kernel, ir.VoidType())
    kernel.args = kernel.args[1:]

    # Mark as a kernel for NVVM

    nvvm.set_cuda_kernel(kernel)

    if config.DUMP_LLVM:
        print(f"LLVM DUMP: Post kernel fixup {kernel.name} ".center(80, '-'))
        print(kernel.module)
        print('=' * 80)
362+
363+
def add_exception_store_helper(kernel):
    """Add an exception-recording helper function to the module of
    ``kernel`` and return it.

    Seven 32-bit integer globals are created in ``kernel.module``, each named
    ``kernel.name`` plus a suffix: ``__errcode__``, and for each of ``x``,
    ``y``, ``z`` a ``__tid?__`` and a ``__ctaid?__`` variable.

    The helper is named ``kernel.name + "__exc_helper__"``, takes one 32-bit
    integer status code and returns void. It returns straight away when the
    status is OK. Otherwise, unless the status is flagged as a Python
    exception, it tries to swap the status code into the error code global
    with a compare-and-exchange against that global's initial value. If the
    exchange succeeds, the thread and block indices are stored into the
    corresponding globals.
    """
    module = kernel.module

    # Global variables holding the exception state

    def define_error_gv(postfix):
        gv_name = kernel.name + postfix
        gv = cgutils.add_global_variable(module, ir.IntType(32), gv_name)
        gv.initializer = ir.Constant(gv.type.pointee, None)
        return gv

    gv_exc = define_error_gv("__errcode__")
    gv_tid = []
    gv_ctaid = []
    # tid and ctaid globals are created pairwise per dimension.
    for axis in 'xyz':
        gv_tid.append(define_error_gv("__tid%s__" % axis))
        gv_ctaid.append(define_error_gv("__ctaid%s__" % axis))

    # The helper function itself

    helper_type = ir.FunctionType(ir.VoidType(), (ir.IntType(32),))
    helper_func = ir.Function(module, helper_type,
                              kernel.name + "__exc_helper__")

    entry = helper_func.append_basic_block(name="entry")
    builder = ir.IRBuilder(entry)

    # Status check / exception store logic

    status_code = helper_func.args[0]
    call_conv = cuda_target.target_context.call_conv
    status = call_conv._get_return_status(builder, status_code)

    # Nothing to record when the status is OK
    with cgutils.if_likely(builder, status.is_ok):
        builder.ret_void()

    is_user_exc = builder.not_(status.is_python_exc)
    with builder.if_then(is_user_exc):
        # User exception raised
        expected = ir.Constant(gv_exc.type.pointee, None)

        # An atomic cmpxchg keeps a later error from overwriting the error
        # status: only the first error is recorded.
        xchg = builder.cmpxchg(gv_exc, expected, status.code,
                               'monotonic', 'monotonic')
        changed = builder.extract_value(xchg, 1)

        # If the exchange succeeded, save the thread and block indices.
        sreg = nvvmutils.SRegBuilder(builder)
        with builder.if_then(changed):
            # All tid stores are emitted before any ctaid store.
            for read_sreg, targets in ((sreg.tid, gv_tid),
                                       (sreg.ctaid, gv_ctaid)):
                for dim, ptr in zip("xyz", targets):
                    val = read_sreg(dim)
                    builder.store(val, ptr)

    builder.ret_void()

    return helper_func
426+
427+
255428@global_compiler_lock
256429def compile (pyfunc , sig , debug = None , lineinfo = False , device = True ,
257430 fastmath = False , cc = None , opt = None , abi = "c" , abi_info = None ,
@@ -347,13 +520,10 @@ def compile(pyfunc, sig, debug=None, lineinfo=False, device=True,
347520 lib = cabi_wrap_function (tgt , lib , cres .fndesc , wrapper_name ,
348521 nvvm_options )
349522 else :
350- code = pyfunc .__code__
351- filename = code .co_filename
352- linenum = code .co_firstlineno
353-
354- lib , kernel = tgt .prepare_cuda_kernel (cres .library , cres .fndesc , debug ,
355- lineinfo , nvvm_options , filename ,
356- linenum )
523+ lib = cres .library
524+ kernel = lib .get_function (cres .fndesc .llvm_func_name )
525+ lib ._entry_name = cres .fndesc .llvm_func_name
526+ kernel_fixup (kernel , debug )
357527
358528 if lto :
359529 code = lib .get_ltoir (cc = cc )
0 commit comments