11from llvmlite import ir
22from numba .core .typing .templates import ConcreteTemplate
3- from numba .core import types , typing , funcdesc , config , compiler , sigutils
3+ from numba .core import (cgutils , types , typing , funcdesc , config , compiler ,
4+ sigutils , utils )
45from numba .core .compiler import (sanitize_compile_result_entries , CompilerBase ,
56 DefaultPassBuilder , Flags , Option ,
67 CompileResult )
1112from numba .core .typed_passes import (IRLegalization , NativeLowering ,
1213 AnnotateTypes )
1314from warnings import warn
15+ from numba .cuda import nvvmutils
1416from numba .cuda .api import get_current_device
17+ from numba .cuda .cudadrv import nvvm
18+ from numba .cuda .descriptor import cuda_target
1519from numba .cuda .target import CUDACABICallConv
1620
1721
@@ -24,6 +28,15 @@ def _nvvm_options_type(x):
2428 return x
2529
2630
31+ def _optional_int_type (x ):
32+ if x is None :
33+ return None
34+
35+ else :
36+ assert isinstance (x , int )
37+ return x
38+
39+
2740class CUDAFlags (Flags ):
2841 nvvm_options = Option (
2942 type = _nvvm_options_type ,
@@ -35,6 +48,11 @@ class CUDAFlags(Flags):
3548 default = None ,
3649 doc = "Compute Capability" ,
3750 )
51+ max_registers = Option (
52+ type = _optional_int_type ,
53+ default = None ,
54+ doc = "Max registers"
55+ )
3856
3957
4058# The CUDACompileResult (CCR) has a specially-defined entry point equal to its
@@ -109,7 +127,9 @@ def run_pass(self, state):
109127 codegen = state .targetctx .codegen ()
110128 name = state .func_id .func_qualname
111129 nvvm_options = state .flags .nvvm_options
112- state .library = codegen .create_library (name , nvvm_options = nvvm_options )
130+ max_registers = state .flags .max_registers
131+ state .library = codegen .create_library (name , nvvm_options = nvvm_options ,
132+ max_registers = max_registers )
113133 # Enable object caching upfront so that the library can be serialized.
114134 state .library .enable_object_caching ()
115135
@@ -152,7 +172,7 @@ def define_cuda_lowering_pipeline(self, state):
152172@global_compiler_lock
153173def compile_cuda (pyfunc , return_type , args , debug = False , lineinfo = False ,
154174 inline = False , fastmath = False , nvvm_options = None ,
155- cc = None ):
175+ cc = None , max_registers = None ):
156176 if cc is None :
157177 raise ValueError ('Compute Capability must be supplied' )
158178
@@ -189,6 +209,7 @@ def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
189209 if nvvm_options :
190210 flags .nvvm_options = nvvm_options
191211 flags .compute_capability = cc
212+ flags .max_registers = max_registers
192213
193214 # Run compilation pipeline
194215 from numba .core .target_extension import target_override
@@ -247,11 +268,155 @@ def cabi_wrap_function(context, lib, fndesc, wrapper_function_name,
247268 builder , func , restype , argtypes , callargs )
248269 builder .ret (return_value )
249270
271+ if config .DUMP_LLVM :
272+ utils .dump_llvm (fndesc , wrapper_module )
273+
250274 library .add_ir_module (wrapper_module )
251275 library .finalize ()
252276 return library
253277
254278
279+ def kernel_fixup (kernel , debug ):
280+ if debug :
281+ exc_helper = add_exception_store_helper (kernel )
282+
283+ # Pass 1 - replace:
284+ #
285+ # ret <value>
286+ #
287+ # with:
288+ #
289+ # exc_helper(<value>)
290+ # ret void
291+
292+ for block in kernel .blocks :
293+ for i , inst in enumerate (block .instructions ):
294+ if isinstance (inst , ir .Ret ):
295+ old_ret = block .instructions .pop ()
296+ block .terminator = None
297+
298+ # The original return's metadata will be set on the new
299+ # instructions in order to preserve debug info
300+ metadata = old_ret .metadata
301+
302+ builder = ir .IRBuilder (block )
303+ if debug :
304+ status_code = old_ret .operands [0 ]
305+ exc_helper_call = builder .call (exc_helper , (status_code ,))
306+ exc_helper_call .metadata = metadata
307+
308+ new_ret = builder .ret_void ()
309+ new_ret .metadata = old_ret .metadata
310+
311+ # Need to break out so we don't carry on modifying what we are
312+ # iterating over. There can only be one return in a block
313+ # anyway.
314+ break
315+
316+ # Pass 2: remove stores of null pointer to return value argument pointer
317+
318+ return_value = kernel .args [0 ]
319+
320+ for block in kernel .blocks :
321+ remove_list = []
322+
323+ # Find all stores first
324+ for inst in block .instructions :
325+ if (isinstance (inst , ir .StoreInstr )
326+ and inst .operands [1 ] == return_value ):
327+ remove_list .append (inst )
328+
329+ # Remove all stores
330+ for to_remove in remove_list :
331+ block .instructions .remove (to_remove )
332+
333+ # Replace non-void return type with void return type and remove return
334+ # value
335+
336+ if isinstance (kernel .type , ir .PointerType ):
337+ new_type = ir .PointerType (ir .FunctionType (ir .VoidType (),
338+ kernel .type .pointee .args [1 :]))
339+ else :
340+ new_type = ir .FunctionType (ir .VoidType (), kernel .type .args [1 :])
341+
342+ kernel .type = new_type
343+ kernel .return_value = ir .ReturnValue (kernel , ir .VoidType ())
344+ kernel .args = kernel .args [1 :]
345+
346+ # Mark as a kernel for NVVM
347+
348+ nvvm .set_cuda_kernel (kernel )
349+
350+ if config .DUMP_LLVM :
351+ print (f"LLVM DUMP: Post kernel fixup { kernel .name } " .center (80 , '-' ))
352+ print (kernel .module )
353+ print ('=' * 80 )
354+
355+
356+ def add_exception_store_helper (kernel ):
357+
358+ # Create global variables for exception state
359+
360+ def define_error_gv (postfix ):
361+ name = kernel .name + postfix
362+ gv = cgutils .add_global_variable (kernel .module , ir .IntType (32 ),
363+ name )
364+ gv .initializer = ir .Constant (gv .type .pointee , None )
365+ return gv
366+
367+ gv_exc = define_error_gv ("__errcode__" )
368+ gv_tid = []
369+ gv_ctaid = []
370+ for i in 'xyz' :
371+ gv_tid .append (define_error_gv ("__tid%s__" % i ))
372+ gv_ctaid .append (define_error_gv ("__ctaid%s__" % i ))
373+
374+ # Create exception store helper function
375+
376+ helper_name = kernel .name + "__exc_helper__"
377+ helper_type = ir .FunctionType (ir .VoidType (), (ir .IntType (32 ),))
378+ helper_func = ir .Function (kernel .module , helper_type , helper_name )
379+
380+ block = helper_func .append_basic_block (name = "entry" )
381+ builder = ir .IRBuilder (block )
382+
383+ # Implement status check / exception store logic
384+
385+ status_code = helper_func .args [0 ]
386+ call_conv = cuda_target .target_context .call_conv
387+ status = call_conv ._get_return_status (builder , status_code )
388+
389+ # Check error status
390+ with cgutils .if_likely (builder , status .is_ok ):
391+ builder .ret_void ()
392+
393+ with builder .if_then (builder .not_ (status .is_python_exc )):
394+ # User exception raised
395+ old = ir .Constant (gv_exc .type .pointee , None )
396+
397+ # Use atomic cmpxchg to prevent rewriting the error status
398+ # Only the first error is recorded
399+
400+ xchg = builder .cmpxchg (gv_exc , old , status .code ,
401+ 'monotonic' , 'monotonic' )
402+ changed = builder .extract_value (xchg , 1 )
403+
404+ # If the xchange is successful, save the thread ID.
405+ sreg = nvvmutils .SRegBuilder (builder )
406+ with builder .if_then (changed ):
407+ for dim , ptr , in zip ("xyz" , gv_tid ):
408+ val = sreg .tid (dim )
409+ builder .store (val , ptr )
410+
411+ for dim , ptr , in zip ("xyz" , gv_ctaid ):
412+ val = sreg .ctaid (dim )
413+ builder .store (val , ptr )
414+
415+ builder .ret_void ()
416+
417+ return helper_func
418+
419+
255420@global_compiler_lock
256421def compile (pyfunc , sig , debug = False , lineinfo = False , device = True ,
257422 fastmath = False , cc = None , opt = True , abi = "c" , abi_info = None ,
@@ -344,13 +509,10 @@ def compile(pyfunc, sig, debug=False, lineinfo=False, device=True,
344509 lib = cabi_wrap_function (tgt , lib , cres .fndesc , wrapper_name ,
345510 nvvm_options )
346511 else :
347- code = pyfunc .__code__
348- filename = code .co_filename
349- linenum = code .co_firstlineno
350-
351- lib , kernel = tgt .prepare_cuda_kernel (cres .library , cres .fndesc , debug ,
352- lineinfo , nvvm_options , filename ,
353- linenum )
512+ lib = cres .library
513+ kernel = lib .get_function (cres .fndesc .llvm_func_name )
514+ lib ._entry_name = cres .fndesc .llvm_func_name
515+ kernel_fixup (kernel , debug )
354516
355517 if lto :
356518 code = lib .get_ltoir (cc = cc )
0 commit comments