@@ -199,12 +199,52 @@ def check_error(self, error, msg, exit=False):
199199
200200
201201class CompilationUnit (object ):
202- def __init__ (self ):
202+ """
203+ A CompilationUnit is a set of LLVM modules that are compiled to PTX or
204+ LTO-IR with NVVM.
205+
206+ Compilation options are accepted as a dict mapping option names to values,
207+ with the following considerations:
208+
209+ - Underscores (`_`) in option names are converted to dashes (`-`), to match
210+ NVVM's option name format.
211+ - Options that take a value will be emitted in the form "-<name>=<value>".
212+ - Booleans passed as option values will be converted to integers.
213+ - Options which take no value (such as `-gen-lto`) should have a value of
214+ `None` and will be emitted in the form "-<name>".
215+
216+ For documentation on NVVM compilation options, see the CUDA Toolkit
217+ Documentation:
218+
219+ https://docs.nvidia.com/cuda/libnvvm-api/index.html#_CPPv418nvvmCompileProgram11nvvmProgramiPPKc
220+ """
221+
222+ def __init__ (self , options ):
203223 self .driver = NVVM ()
204224 self ._handle = nvvm_program ()
205225 err = self .driver .nvvmCreateProgram (byref (self ._handle ))
206226 self .driver .check_error (err , 'Failed to create CU' )
207227
228+ def stringify_option (k , v ):
229+ k = k .replace ('_' , '-' )
230+
231+ if v is None :
232+ return f'-{ k } ' .encode ('utf-8' )
233+
234+ if isinstance (v , bool ):
235+ v = int (v )
236+
237+ return f'-{ k } ={ v } ' .encode ('utf-8' )
238+
239+ options = [stringify_option (k , v ) for k , v in options .items ()]
240+ option_ptrs = (c_char_p * len (options ))(* [c_char_p (x ) for x in options ])
241+
242+ # We keep both the options and the pointers to them so that options are
243+ # not destroyed before we've used their values
244+ self .options = options
245+ self .option_ptrs = option_ptrs
246+ self .n_options = len (options )
247+
208248 def __del__ (self ):
209249 driver = NVVM ()
210250 err = driver .nvvmDestroyProgram (byref (self ._handle ))
@@ -230,60 +270,35 @@ def lazy_add_module(self, buffer):
230270 len (buffer ), None )
231271 self .driver .check_error (err , 'Failed to add module' )
232272
233- def compile (self , ** options ):
234- """Perform Compilation.
235-
236- Compilation options are accepted as keyword arguments, with the
237- following considerations:
238-
239- - Underscores (`_`) in option names are converted to dashes (`-`), to
240- match NVVM's option name format.
241- - Options that take a value will be emitted in the form
242- "-<name>=<value>".
243- - Booleans passed as option values will be converted to integers.
244- - Options which take no value (such as `-gen-lto`) should have a value
245- of `None` passed in and will be emitted in the form "-<name>".
246-
247- For documentation on NVVM compilation options, see the CUDA Toolkit
248- Documentation:
249-
250- https://docs.nvidia.com/cuda/libnvvm-api/index.html#_CPPv418nvvmCompileProgram11nvvmProgramiPPKc
273+ def verify (self ):
251274 """
252-
253- def stringify_option (k , v ):
254- k = k .replace ('_' , '-' )
255-
256- if v is None :
257- return f'-{ k } '
258-
259- if isinstance (v , bool ):
260- v = int (v )
261-
262- return f'-{ k } ={ v } '
263-
264- options = [stringify_option (k , v ) for k , v in options .items ()]
265-
266- c_opts = (c_char_p * len (options ))(* [c_char_p (x .encode ('utf8' ))
267- for x in options ])
268- # verify
269- err = self .driver .nvvmVerifyProgram (self ._handle , len (options ), c_opts )
275+ Run the NVVM verifier on all code added to the compilation unit.
276+ """
277+ err = self .driver .nvvmVerifyProgram (self ._handle , self .n_options ,
278+ self .option_ptrs )
270279 self ._try_error (err , 'Failed to verify\n ' )
271280
272- # compile
273- err = self .driver .nvvmCompileProgram (self ._handle , len (options ), c_opts )
281+ def compile (self ):
282+ """
283+ Compile all modules added to the compilation unit and return the
284+ resulting PTX or LTO-IR (depending on the options).
285+ """
286+ err = self .driver .nvvmCompileProgram (self ._handle , self .n_options ,
287+ self .option_ptrs )
274288 self ._try_error (err , 'Failed to compile\n ' )
275289
276- # get result
277- reslen = c_size_t ()
278- err = self .driver .nvvmGetCompiledResultSize (self ._handle , byref (reslen ))
290+ # Get result
291+ result_size = c_size_t ()
292+ err = self .driver .nvvmGetCompiledResultSize (self ._handle ,
293+ byref (result_size ))
279294
280295 self ._try_error (err , 'Failed to get size of compiled result.' )
281296
282- output_buffer = (c_char * reslen .value )()
297+ output_buffer = (c_char * result_size .value )()
283298 err = self .driver .nvvmGetCompiledResult (self ._handle , output_buffer )
284299 self ._try_error (err , 'Failed to get compiled result.' )
285300
286- # get log
301+ # Get log
287302 self .log = self .get_log ()
288303 if self .log :
289304 warnings .warn (self .log , category = NvvmWarning )
@@ -620,27 +635,31 @@ def llvm_replace(llvmir):
620635 return llvmir
621636
622637
623- def compile_ir (llvmir , ** opts ):
638+ def compile_ir (llvmir , ** options ):
624639 if isinstance (llvmir , str ):
625640 llvmir = [llvmir ]
626641
627- if opts .pop ('fastmath' , False ):
628- opts .update ({
642+ if options .pop ('fastmath' , False ):
643+ options .update ({
629644 'ftz' : True ,
630645 'fma' : True ,
631646 'prec_div' : False ,
632647 'prec_sqrt' : False ,
633648 })
634649
635- cu = CompilationUnit ()
636- libdevice = LibDevice ()
650+ cu = CompilationUnit (options )
637651
638652 for mod in llvmir :
639653 mod = llvm_replace (mod )
640654 cu .add_module (mod .encode ('utf8' ))
655+ cu .verify ()
656+
657+ # We add libdevice following verification so that it is not subject to the
658+ # verifier's requirements
659+ libdevice = LibDevice ()
641660 cu .lazy_add_module (libdevice .get ())
642661
643- return cu .compile (** opts )
662+ return cu .compile ()
644663
645664
646665re_attributes_def = re .compile (r"^attributes #\d+ = \{ ([\w\s]+)\ }" )
0 commit comments