 from functools import reduce
 import operator
 import math
+import struct
 
 from llvmlite import ir
 import llvmlite.binding as ll
@@ -91,35 +92,85 @@ def _get_unique_smem_id(name):
     return "{0}_{1}".format(name, _unique_smem_id)
 
 
+def _validate_alignment(alignment: int):
+    """
+    Ensures that *alignment*, if not None, is a) greater than zero, b) a
+    power of two, and c) a multiple of the size of a pointer. If any of
+    these conditions are not met, a ValueError is raised. Otherwise, this
+    function returns None, indicating that the alignment is valid.
+    """
+    if alignment is None:
+        return
+    if not isinstance(alignment, int):
+        raise ValueError("Alignment must be an integer")
+    if alignment <= 0:
+        raise ValueError("Alignment must be positive")
+    if (alignment & (alignment - 1)) != 0:
+        raise ValueError("Alignment must be a power of 2")
+    pointer_size = struct.calcsize("P")
+    if (alignment % pointer_size) != 0:
+        msg = f"Alignment must be a multiple of {pointer_size}"
+        raise ValueError(msg)
+
+
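For illustration only (not part of the diff), a minimal sketch of what the new helper accepts and rejects, assuming a 64-bit build where struct.calcsize("P") == 8:

    for candidate in (None, 16, 64, 0, 24, 4):
        try:
            _validate_alignment(candidate)
            print(candidate, "accepted")
        except ValueError as exc:
            print(candidate, "rejected:", exc)
    # Expected on such a build: None, 16 and 64 are accepted; 0 (not
    # positive), 24 (not a power of two) and 4 (not a multiple of the
    # 8-byte pointer size) are rejected.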
 @lower(cuda.shared.array, types.IntegerLiteral, types.Any)
+@lower(cuda.shared.array, types.IntegerLiteral, types.Any, types.IntegerLiteral)
+@lower(cuda.shared.array, types.IntegerLiteral, types.Any, types.NoneType)
 def cuda_shared_array_integer(context, builder, sig, args):
     length = sig.args[0].literal_value
     dtype = parse_dtype(sig.args[1])
+    alignment = None
+    if len(sig.args) == 3:
+        try:
+            alignment = sig.args[2].literal_value
+            _validate_alignment(alignment)
+        except AttributeError:
+            pass
     return _generic_array(context, builder, shape=(length,), dtype=dtype,
                           symbol_name=_get_unique_smem_id('_cudapy_smem'),
                           addrspace=nvvm.ADDRSPACE_SHARED,
-                          can_dynsized=True)
+                          can_dynsized=True, alignment=alignment)
 
 
 @lower(cuda.shared.array, types.Tuple, types.Any)
 @lower(cuda.shared.array, types.UniTuple, types.Any)
+@lower(cuda.shared.array, types.Tuple, types.Any, types.IntegerLiteral)
+@lower(cuda.shared.array, types.UniTuple, types.Any, types.IntegerLiteral)
+@lower(cuda.shared.array, types.Tuple, types.Any, types.NoneType)
+@lower(cuda.shared.array, types.UniTuple, types.Any, types.NoneType)
 def cuda_shared_array_tuple(context, builder, sig, args):
     shape = [s.literal_value for s in sig.args[0]]
     dtype = parse_dtype(sig.args[1])
+    alignment = None
+    if len(sig.args) == 3:
+        try:
+            alignment = sig.args[2].literal_value
+            _validate_alignment(alignment)
+        except AttributeError:
+            pass
     return _generic_array(context, builder, shape=shape, dtype=dtype,
                           symbol_name=_get_unique_smem_id('_cudapy_smem'),
                           addrspace=nvvm.ADDRSPACE_SHARED,
-                          can_dynsized=True)
+                          can_dynsized=True, alignment=alignment)
 
 
 @lower(cuda.local.array, types.IntegerLiteral, types.Any)
+@lower(cuda.local.array, types.IntegerLiteral, types.Any, types.IntegerLiteral)
+@lower(cuda.local.array, types.IntegerLiteral, types.Any, types.NoneType)
 def cuda_local_array_integer(context, builder, sig, args):
     length = sig.args[0].literal_value
     dtype = parse_dtype(sig.args[1])
+    alignment = None
+    if len(sig.args) == 3:
+        try:
+            alignment = sig.args[2].literal_value
+            _validate_alignment(alignment)
+        except AttributeError:
+            pass
     return _generic_array(context, builder, shape=(length,), dtype=dtype,
                           symbol_name='_cudapy_lmem',
                           addrspace=nvvm.ADDRSPACE_LOCAL,
-                          can_dynsized=False)
+                          can_dynsized=False, alignment=alignment)
 
 
 @lower(cuda.local.array, types.Tuple, types.Any)
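The registrations above accept either an integer literal or None as the new third argument. Assuming the matching cuda.shared.array / cuda.local.array stubs expose that argument as an alignment parameter (the keyword spelling here is an assumption, not taken from this diff), a kernel might request aligned arrays like this:

    from numba import cuda, float32

    @cuda.jit
    def kernel(out):
        # Hypothetical usage, assuming blocks of at most 256 threads:
        # a 16-byte-aligned shared array and a 32-byte-aligned local array.
        # Passing alignment=None (or omitting the argument) keeps the
        # existing behaviour.
        smem = cuda.shared.array(256, float32, alignment=16)
        lmem = cuda.local.array(8, float32, alignment=32)
        i = cuda.grid(1)
        if i < out.size:
            lmem[0] = float32(i)
            smem[cuda.threadIdx.x] = lmem[0]
            out[i] = smem[cuda.threadIdx.x]

Both values pass _validate_alignment on a 64-bit build: they are positive, powers of two, and multiples of the 8-byte pointer size.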
@@ -954,7 +1005,7 @@ def ptx_nanosleep(context, builder, sig, args):
 
 
 def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
-                   can_dynsized=False):
+                   can_dynsized=False, alignment=None):
     elemcount = reduce(operator.mul, shape, 1)
 
     # Check for valid shape for this type of allocation.
@@ -981,17 +1032,37 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
         # NVVM is smart enough to only use local memory if no register is
         # available
         dataptr = cgutils.alloca_once(builder, laryty, name=symbol_name)
+
+        # If the caller has specified a custom alignment, just set the align
+        # attribute on the alloca IR directly. We don't do any additional
+        # hand-holding here like checking the underlying data type's
+        # alignment or rounding up to the next power of 2; those checks will
+        # have already been done by the time we see the alignment value.
+        if alignment is not None:
+            dataptr.align = alignment
     else:
         lmod = builder.module
 
         # Create global variable in the requested address space
         gvmem = cgutils.add_global_variable(lmod, laryty, symbol_name,
                                             addrspace)
-        # Specify alignment to avoid misalignment bug
-        align = context.get_abi_sizeof(lldtype)
-        # Alignment is required to be a power of 2 for shared memory. If it is
-        # not a power of 2 (e.g. for a Record array) then round up accordingly.
-        gvmem.align = 1 << (align - 1).bit_length()
+
+        # If the caller hasn't specified a custom alignment, obtain the
+        # underlying dtype alignment from the ABI and then round it up to
+        # a power of two. Otherwise, just use the caller's alignment.
+        #
+        # N.B. The caller *could* provide a valid-but-smaller-than-natural
+        # alignment here; we'll assume the caller knows what they're
+        # doing and let that through without error.
+
+        if alignment is None:
+            abi_alignment = context.get_abi_alignment(lldtype)
+            # Ensure a power of two alignment.
+            actual_alignment = 1 << (abi_alignment - 1).bit_length()
+        else:
+            actual_alignment = alignment
+
+        gvmem.align = actual_alignment
 
         if dynamic_smem:
             gvmem.linkage = 'external'
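As a standalone illustration (not part of the diff) of the fallback branch just above: the expression 1 << (n - 1).bit_length() rounds a positive integer up to the next power of two and leaves existing powers of two unchanged. The helper name below is hypothetical.

    def _round_up_to_pow2(n: int) -> int:
        # Same rounding rule applied to the ABI-derived alignment above.
        return 1 << (n - 1).bit_length()

    assert _round_up_to_pow2(1) == 1
    assert _round_up_to_pow2(4) == 4      # powers of two pass through
    assert _round_up_to_pow2(6) == 8      # non-powers of two are bumped up
    assert _round_up_to_pow2(12) == 16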
@@ -1041,7 +1112,8 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
 
     # Create array object
     ndim = len(shape)
-    aryty = types.Array(dtype=dtype, ndim=ndim, layout='C')
+    aryty = types.Array(dtype=dtype, ndim=ndim, layout='C',
+                        alignment=alignment)
     ary = context.make_array(aryty)(context, builder)
 
     context.populate_array(ary,