@@ -72,9 +72,14 @@ def dim3_z(context, builder, sig, args):
7272# -----------------------------------------------------------------------------
7373
@lower(cuda.const.array_like, types.Array)
@lower(cuda.const.array_like, types.Array, types.IntegerLiteral)
@lower(cuda.const.array_like, types.Array, types.NoneType)
def cuda_const_array_like(context, builder, sig, args):
    """Lower ``cuda.const.array_like`` by handing back its first argument.

    No code is emitted here: CUDATargetContext.make_constant_array has
    already created the constant array, so the incoming value is returned
    unchanged.
    """
    # XXX-140: How do we handle alignment here? When a second (alignment)
    # argument is present in ``sig.args`` it is currently not acted upon.
    return args[0]
7984
8085
@@ -92,34 +97,62 @@ def _get_unique_smem_id(name):
9297
9398
@lower(cuda.shared.array, types.IntegerLiteral, types.Any)
@lower(cuda.shared.array, types.IntegerLiteral, types.Any,
       types.IntegerLiteral)
@lower(cuda.shared.array, types.IntegerLiteral, types.Any, types.NoneType)
def cuda_shared_array_integer(context, builder, sig, args):
    """Lower ``cuda.shared.array`` for a single integer-literal length.

    The length is read from the first signature argument, the dtype is
    parsed from the second, and the optional third argument supplies the
    alignment. The result of ``_generic_array`` is returned.
    """
    arg_types = sig.args
    nelem = arg_types[0].literal_value
    elem_dtype = parse_dtype(arg_types[1])

    # The third argument is either a literal carrying ``literal_value`` or
    # a type without that attribute (registered above as NoneType); the
    # latter leaves the alignment unset.
    if len(arg_types) == 3:
        align = getattr(arg_types[2], 'literal_value', None)
    else:
        align = None

    return _generic_array(
        context,
        builder,
        shape=(nelem,),
        dtype=elem_dtype,
        symbol_name=_get_unique_smem_id('_cudapy_smem'),
        addrspace=nvvm.ADDRSPACE_SHARED,
        can_dynsized=True,
        alignment=align,
    )
102115
103116
117+ # XXX-140: Should I just use types.Any for the last alignment arg?
118+
@lower(cuda.shared.array, types.Tuple, types.Any)
@lower(cuda.shared.array, types.UniTuple, types.Any)
@lower(cuda.shared.array, types.Tuple, types.Any, types.IntegerLiteral)
@lower(cuda.shared.array, types.UniTuple, types.Any, types.IntegerLiteral)
@lower(cuda.shared.array, types.Tuple, types.Any, types.NoneType)
@lower(cuda.shared.array, types.UniTuple, types.Any, types.NoneType)
def cuda_shared_array_tuple(context, builder, sig, args):
    """Lower ``cuda.shared.array`` for a tuple-typed shape argument.

    Each element of the first signature argument contributes its
    ``literal_value`` to the shape, the dtype is parsed from the second
    argument, and the optional third argument supplies the alignment.
    The result of ``_generic_array`` is returned.
    """
    arg_types = sig.args
    dims = [dim.literal_value for dim in arg_types[0]]
    elem_dtype = parse_dtype(arg_types[1])

    # The third argument is either a literal carrying ``literal_value`` or
    # a type without that attribute (registered above as NoneType); the
    # latter leaves the alignment unset.
    if len(arg_types) == 3:
        align = getattr(arg_types[2], 'literal_value', None)
    else:
        align = None

    return _generic_array(
        context,
        builder,
        shape=dims,
        dtype=elem_dtype,
        symbol_name=_get_unique_smem_id('_cudapy_smem'),
        addrspace=nvvm.ADDRSPACE_SHARED,
        can_dynsized=True,
        alignment=align,
    )
113138
114139
@lower(cuda.local.array, types.IntegerLiteral, types.Any)
@lower(cuda.local.array, types.IntegerLiteral, types.Any,
       types.IntegerLiteral)
@lower(cuda.local.array, types.IntegerLiteral, types.Any, types.NoneType)
def cuda_local_array_integer(context, builder, sig, args):
    """Lower ``cuda.local.array`` for a single integer-literal length.

    The length is read from the first signature argument, the dtype is
    parsed from the second, and the optional third argument supplies the
    alignment. The result of ``_generic_array`` is returned; dynamic
    sizing is disabled for this allocation.
    """
    arg_types = sig.args
    nelem = arg_types[0].literal_value
    elem_dtype = parse_dtype(arg_types[1])

    # The third argument is either a literal carrying ``literal_value`` or
    # a type without that attribute (registered above as NoneType); the
    # latter leaves the alignment unset.
    if len(arg_types) == 3:
        align = getattr(arg_types[2], 'literal_value', None)
    else:
        align = None

    return _generic_array(
        context,
        builder,
        shape=(nelem,),
        dtype=elem_dtype,
        symbol_name='_cudapy_lmem',
        addrspace=nvvm.ADDRSPACE_LOCAL,
        can_dynsized=False,
        alignment=align,
    )
123156
124157
125158@lower (cuda .local .array , types .Tuple , types .Any )
@@ -954,7 +987,7 @@ def ptx_nanosleep(context, builder, sig, args):
954987
955988
956989def _generic_array (context , builder , shape , dtype , symbol_name , addrspace ,
957- can_dynsized = False ):
990+ can_dynsized = False , alignment = None ):
958991 elemcount = reduce (operator .mul , shape , 1 )
959992
960993 # Check for valid shape for this type of allocation.
@@ -981,17 +1014,37 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
9811014 # NVVM is smart enough to only use local memory if no register is
9821015 # available
9831016 dataptr = cgutils .alloca_once (builder , laryty , name = symbol_name )
1017+
1018+ # If the user has specified a custom alignment, just set the align
1019+ # attribute on the alloca IR directly. We don't do any additional
1020+ # hand-holding here like checking the underlying data type's alignment
1021+ # or rounding up to the next power of 2--those checks will have already
1022+ # been done by the time we see the alignment value.
1023+ if alignment is not None :
1024+ dataptr .align = alignment
9841025 else :
9851026 lmod = builder .module
9861027
9871028 # Create global variable in the requested address space
9881029 gvmem = cgutils .add_global_variable (lmod , laryty , symbol_name ,
9891030 addrspace )
990- # Specify alignment to avoid misalignment bug
991- align = context .get_abi_sizeof (lldtype )
992- # Alignment is required to be a power of 2 for shared memory. If it is
993- # not a power of 2 (e.g. for a Record array) then round up accordingly.
994- gvmem .align = 1 << (align - 1 ).bit_length ()
1031+
1032+ # If the caller hasn't specified a custom alignment, obtain the
1033+ # underlying dtype alignment from the ABI and then round it up to
1034+ # a power of two. Otherwise, just use the caller's alignment.
1035+ #
1036+ # N.B. The caller *could* provide a valid-but-smaller-than-natural
1037+ # alignment here; we'll assume the caller knows what they're
1038+ # doing and let that through without error.
1039+
1040+ if alignment is None :
1041+ abi_alignment = context .get_abi_alignment (lldtype )
1042+ # Ensure a power of two alignment.
1043+ actual_alignment = 1 << (abi_alignment - 1 ).bit_length ()
1044+ else :
1045+ actual_alignment = alignment
1046+
1047+ gvmem .align = actual_alignment
9951048
9961049 if dynamic_smem :
9971050 gvmem .linkage = 'external'
@@ -1041,7 +1094,8 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
10411094
10421095 # Create array object
10431096 ndim = len (shape )
1044- aryty = types .Array (dtype = dtype , ndim = ndim , layout = 'C' )
1097+ aryty = types .Array (dtype = dtype , ndim = ndim , layout = 'C' ,
1098+ alignment = alignment )
10451099 ary = context .make_array (aryty )(context , builder )
10461100
10471101 context .populate_array (ary ,
0 commit comments