11from functools import reduce
22import operator
33import math
4+ import struct
45
56from llvmlite import ir
67import llvmlite .binding as ll
@@ -72,9 +73,14 @@ def dim3_z(context, builder, sig, args):
7273# -----------------------------------------------------------------------------
7374
@lower(cuda.const.array_like, types.Array)
@lower(cuda.const.array_like, types.Array, types.IntegerLiteral)
@lower(cuda.const.array_like, types.Array, types.NoneType)
def cuda_const_array_like(context, builder, sig, args):
    """
    Lower ``cuda.const.array_like`` as a no-op.

    By the time lowering runs, ``CUDATargetContext.make_constant_array``
    has already created the constant array, so the lowered value is just
    the first argument passed through unchanged.
    """
    # XXX-140: an explicit alignment argument (present when len(sig.args) > 1)
    # is currently ignored here; how alignment should apply to constant
    # arrays is an open question.
    return args[0]
7985
8086
@@ -91,35 +97,85 @@ def _get_unique_smem_id(name):
9197 return "{0}_{1}" .format (name , _unique_smem_id )
9298
9399
100+ def _validate_alignment (alignment : int ):
101+ """
102+ Ensures that *alignment*, if not None, is a) greater than zero, b) a power
103+ of two, and c) a multiple of the size of a pointer. If any of these
104+ conditions are not met, a NumbaValueError is raised. Otherwise, this
105+ function returns None, indicating that the alignment is valid.
106+ """
107+ if alignment is None :
108+ return
109+ if not isinstance (alignment , int ):
110+ raise ValueError ("Alignment must be an integer" )
111+ if alignment <= 0 :
112+ raise ValueError ("Alignment must be positive" )
113+ if (alignment & (alignment - 1 )) != 0 :
114+ raise ValueError ("Alignment must be a power of 2" )
115+ pointer_size = struct .calcsize ("P" )
116+ if (alignment % pointer_size ) != 0 :
117+ msg = f"Alignment must be a multiple of { pointer_size } "
118+ raise ValueError (msg )
119+
120+
@lower(cuda.shared.array, types.IntegerLiteral, types.Any)
@lower(cuda.shared.array, types.IntegerLiteral, types.Any, types.IntegerLiteral)
@lower(cuda.shared.array, types.IntegerLiteral, types.Any, types.NoneType)
def cuda_shared_array_integer(context, builder, sig, args):
    """
    Lower a 1D ``cuda.shared.array(length, dtype[, alignment])`` call.

    The length comes from the IntegerLiteral first argument; the optional
    third argument requests a custom alignment for the allocation.
    """
    length = sig.args[0].literal_value
    dtype = parse_dtype(sig.args[1])
    alignment = None
    if len(sig.args) == 3:
        try:
            alignment = sig.args[2].literal_value
            _validate_alignment(alignment)
        except AttributeError:
            # types.NoneType has no literal_value: no custom alignment was
            # requested; `alignment` was never reassigned and is still None.
            pass
        except ValueError:
            # BUGFIX: `alignment` was assigned before validation raised, so
            # without this reset the *invalid* value would leak through to
            # _generic_array.  Fall back to the default alignment instead.
            alignment = None
    return _generic_array(context, builder, shape=(length,), dtype=dtype,
                          symbol_name=_get_unique_smem_id('_cudapy_smem'),
                          addrspace=nvvm.ADDRSPACE_SHARED,
                          can_dynsized=True, alignment=alignment)
102138
103139
@lower(cuda.shared.array, types.Tuple, types.Any)
@lower(cuda.shared.array, types.UniTuple, types.Any)
@lower(cuda.shared.array, types.Tuple, types.Any, types.IntegerLiteral)
@lower(cuda.shared.array, types.UniTuple, types.Any, types.IntegerLiteral)
@lower(cuda.shared.array, types.Tuple, types.Any, types.NoneType)
@lower(cuda.shared.array, types.UniTuple, types.Any, types.NoneType)
def cuda_shared_array_tuple(context, builder, sig, args):
    """
    Lower an N-D ``cuda.shared.array(shape, dtype[, alignment])`` call.

    The shape comes from the literal values of the tuple first argument;
    the optional third argument requests a custom alignment.
    """
    shape = [s.literal_value for s in sig.args[0]]
    dtype = parse_dtype(sig.args[1])
    alignment = None
    if len(sig.args) == 3:
        try:
            alignment = sig.args[2].literal_value
            _validate_alignment(alignment)
        except AttributeError:
            # types.NoneType has no literal_value: no custom alignment was
            # requested; `alignment` was never reassigned and is still None.
            pass
        except ValueError:
            # BUGFIX: `alignment` was assigned before validation raised, so
            # without this reset the *invalid* value would leak through to
            # _generic_array.  Fall back to the default alignment instead.
            alignment = None
    return _generic_array(context, builder, shape=shape, dtype=dtype,
                          symbol_name=_get_unique_smem_id('_cudapy_smem'),
                          addrspace=nvvm.ADDRSPACE_SHARED,
                          can_dynsized=True, alignment=alignment)
113160
114161
@lower(cuda.local.array, types.IntegerLiteral, types.Any)
@lower(cuda.local.array, types.IntegerLiteral, types.Any, types.IntegerLiteral)
@lower(cuda.local.array, types.IntegerLiteral, types.Any, types.NoneType)
def cuda_local_array_integer(context, builder, sig, args):
    """
    Lower a 1D ``cuda.local.array(length, dtype[, alignment])`` call.

    The length comes from the IntegerLiteral first argument; the optional
    third argument requests a custom alignment for the allocation.
    """
    length = sig.args[0].literal_value
    dtype = parse_dtype(sig.args[1])
    alignment = None
    if len(sig.args) == 3:
        try:
            alignment = sig.args[2].literal_value
            _validate_alignment(alignment)
        except AttributeError:
            # types.NoneType has no literal_value: no custom alignment was
            # requested; `alignment` was never reassigned and is still None.
            pass
        except ValueError:
            # BUGFIX: `alignment` was assigned before validation raised, so
            # without this reset the *invalid* value would leak through to
            # _generic_array.  Fall back to the default alignment instead.
            alignment = None
    return _generic_array(context, builder, shape=(length,), dtype=dtype,
                          symbol_name='_cudapy_lmem',
                          addrspace=nvvm.ADDRSPACE_LOCAL,
                          can_dynsized=False, alignment=alignment)
123179
124180
125181@lower (cuda .local .array , types .Tuple , types .Any )
@@ -954,7 +1010,7 @@ def ptx_nanosleep(context, builder, sig, args):
9541010
9551011
9561012def _generic_array (context , builder , shape , dtype , symbol_name , addrspace ,
957- can_dynsized = False ):
1013+ can_dynsized = False , alignment = None ):
9581014 elemcount = reduce (operator .mul , shape , 1 )
9591015
9601016 # Check for valid shape for this type of allocation.
@@ -981,17 +1037,37 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
9811037 # NVVM is smart enough to only use local memory if no register is
9821038 # available
9831039 dataptr = cgutils .alloca_once (builder , laryty , name = symbol_name )
1040+
1041+ # If the caller has specified a custom alignment, just set the align
1042+ # attribute on the alloca IR directly. We don't do any additional
1043+ # hand-holding here like checking the underlying data type's alignment
1044+ # or rounding up to the next power of 2--those checks will have already
1045+ # been done by the time we see the alignment value.
1046+ if alignment is not None :
1047+ dataptr .align = alignment
9841048 else :
9851049 lmod = builder .module
9861050
9871051 # Create global variable in the requested address space
9881052 gvmem = cgutils .add_global_variable (lmod , laryty , symbol_name ,
9891053 addrspace )
990- # Specify alignment to avoid misalignment bug
991- align = context .get_abi_sizeof (lldtype )
992- # Alignment is required to be a power of 2 for shared memory. If it is
993- # not a power of 2 (e.g. for a Record array) then round up accordingly.
994- gvmem .align = 1 << (align - 1 ).bit_length ()
1054+
1055+ # If the caller hasn't specified a custom alignment, obtain the
1056+ # underlying dtype alignment from the ABI and then round it up to
1057+ # a power of two. Otherwise, just use the caller's alignment.
1058+ #
1059+ # N.B. The caller *could* provide a valid-but-smaller-than-natural
1060+ # alignment here; we'll assume the caller knows what they're
1061+ # doing and let that through without error.
1062+
1063+ if alignment is None :
1064+ abi_alignment = context .get_abi_alignment (lldtype )
1065+ # Ensure a power of two alignment.
1066+ actual_alignment = 1 << (abi_alignment - 1 ).bit_length ()
1067+ else :
1068+ actual_alignment = alignment
1069+
1070+ gvmem .align = actual_alignment
9951071
9961072 if dynamic_smem :
9971073 gvmem .linkage = 'external'
@@ -1041,7 +1117,8 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
10411117
10421118 # Create array object
10431119 ndim = len (shape )
1044- aryty = types .Array (dtype = dtype , ndim = ndim , layout = 'C' )
1120+ aryty = types .Array (dtype = dtype , ndim = ndim , layout = 'C' ,
1121+ alignment = alignment )
10451122 ary = context .make_array (aryty )(context , builder )
10461123
10471124 context .populate_array (ary ,
0 commit comments