Commit 5796f06

Implement alignment support for local and shared arrays.

1 parent 3e9e705 commit 5796f06

File tree: 4 files changed, +234 −21 lines


numba_cuda/numba/cuda/cudadecl.py

Lines changed: 9 additions & 1 deletion
@@ -25,7 +25,7 @@
 
 class Cuda_array_decl(CallableTemplate):
     def generic(self):
-        def typer(shape, dtype):
+        def typer(shape, dtype, alignment=None):
 
             # Only integer literals and tuples of integer literals are valid
             # shapes
@@ -39,6 +39,14 @@ def typer(shape, dtype):
             else:
                 return None
 
+            # N.B. We don't do anything with alignment in this routine; it's
+            #      not part of the underlying types.Array interface, so we
+            #      don't need to pass it down the stack. The value supplied
+            #      to the array declaration will be handled in the lowering.
+            #
+            #      E.g. `cuda.local.array(..., alignment=256)` will be handled
+            #      by `cudaimpl.cuda_local_array_integer()`.
+
             ndim = parse_shape(shape)
             nb_dtype = parse_dtype(dtype)
             if nb_dtype is not None and ndim is not None:
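For context, a minimal usage sketch (not part of the commit; the kernel and
names are hypothetical) of the call site this typer now accepts. The
alignment literal is ignored at typing time and consumed during lowering:

    from numba import cuda
    import numpy as np

    @cuda.jit
    def kernel(out):
        # 256 is a valid alignment: positive, a power of two, and a
        # multiple of the pointer size.
        buf = cuda.local.array(16, dtype=np.float32, alignment=256)
        buf[0] = 1.0
        out[0] = buf[0]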

numba_cuda/numba/cuda/cudaimpl.py

Lines changed: 87 additions & 10 deletions
@@ -1,6 +1,7 @@
 from functools import reduce
 import operator
 import math
+import struct
 
 from llvmlite import ir
 import llvmlite.binding as ll
@@ -72,9 +73,14 @@ def dim3_z(context, builder, sig, args):
 # -----------------------------------------------------------------------------
 
 @lower(cuda.const.array_like, types.Array)
+@lower(cuda.const.array_like, types.Array, types.IntegerLiteral)
+@lower(cuda.const.array_like, types.Array, types.NoneType)
 def cuda_const_array_like(context, builder, sig, args):
     # This is a no-op because CUDATargetContext.make_constant_array already
     # created the constant array.
+    if len(sig.args) > 1:
+        # XXX-140: How do we handle alignment here?
+        pass
     return args[0]
 
 
@@ -91,35 +97,85 @@ def _get_unique_smem_id(name):
     return "{0}_{1}".format(name, _unique_smem_id)
 
 
+def _validate_alignment(alignment: int):
+    """
+    Ensures that *alignment*, if not None, is a) greater than zero, b) a power
+    of two, and c) a multiple of the size of a pointer. If any of these
+    conditions are not met, a ValueError is raised. Otherwise, this
+    function returns None, indicating that the alignment is valid.
+    """
+    if alignment is None:
+        return
+    if not isinstance(alignment, int):
+        raise ValueError("Alignment must be an integer")
+    if alignment <= 0:
+        raise ValueError("Alignment must be positive")
+    if (alignment & (alignment - 1)) != 0:
+        raise ValueError("Alignment must be a power of 2")
+    pointer_size = struct.calcsize("P")
+    if (alignment % pointer_size) != 0:
+        msg = f"Alignment must be a multiple of {pointer_size}"
+        raise ValueError(msg)
+
+
 @lower(cuda.shared.array, types.IntegerLiteral, types.Any)
+@lower(cuda.shared.array, types.IntegerLiteral, types.Any, types.IntegerLiteral)
+@lower(cuda.shared.array, types.IntegerLiteral, types.Any, types.NoneType)
 def cuda_shared_array_integer(context, builder, sig, args):
     length = sig.args[0].literal_value
     dtype = parse_dtype(sig.args[1])
+    alignment = None
+    if len(sig.args) == 3:
+        try:
+            alignment = sig.args[2].literal_value
+            _validate_alignment(alignment)
+        except (AttributeError, ValueError):
+            pass
     return _generic_array(context, builder, shape=(length,), dtype=dtype,
                           symbol_name=_get_unique_smem_id('_cudapy_smem'),
                           addrspace=nvvm.ADDRSPACE_SHARED,
-                          can_dynsized=True)
+                          can_dynsized=True, alignment=alignment)
 
 
 @lower(cuda.shared.array, types.Tuple, types.Any)
 @lower(cuda.shared.array, types.UniTuple, types.Any)
+@lower(cuda.shared.array, types.Tuple, types.Any, types.IntegerLiteral)
+@lower(cuda.shared.array, types.UniTuple, types.Any, types.IntegerLiteral)
+@lower(cuda.shared.array, types.Tuple, types.Any, types.NoneType)
+@lower(cuda.shared.array, types.UniTuple, types.Any, types.NoneType)
 def cuda_shared_array_tuple(context, builder, sig, args):
     shape = [ s.literal_value for s in sig.args[0] ]
     dtype = parse_dtype(sig.args[1])
+    alignment = None
+    if len(sig.args) == 3:
+        try:
+            alignment = sig.args[2].literal_value
+            _validate_alignment(alignment)
+        except (AttributeError, ValueError):
+            pass
     return _generic_array(context, builder, shape=shape, dtype=dtype,
                           symbol_name=_get_unique_smem_id('_cudapy_smem'),
                           addrspace=nvvm.ADDRSPACE_SHARED,
-                          can_dynsized=True)
+                          can_dynsized=True, alignment=alignment)
 
 
 @lower(cuda.local.array, types.IntegerLiteral, types.Any)
+@lower(cuda.local.array, types.IntegerLiteral, types.Any, types.IntegerLiteral)
+@lower(cuda.local.array, types.IntegerLiteral, types.Any, types.NoneType)
 def cuda_local_array_integer(context, builder, sig, args):
     length = sig.args[0].literal_value
     dtype = parse_dtype(sig.args[1])
+    alignment = None
+    if len(sig.args) == 3:
+        try:
+            alignment = sig.args[2].literal_value
+            _validate_alignment(alignment)
+        except (AttributeError, ValueError):
+            pass
     return _generic_array(context, builder, shape=(length,), dtype=dtype,
                           symbol_name='_cudapy_lmem',
                           addrspace=nvvm.ADDRSPACE_LOCAL,
-                          can_dynsized=False)
+                          can_dynsized=False, alignment=alignment)
 
 
 @lower(cuda.local.array, types.Tuple, types.Any)
@@ -954,7 +1010,7 @@ def ptx_nanosleep(context, builder, sig, args):
 
 
 def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
-                   can_dynsized=False):
+                   can_dynsized=False, alignment=None):
     elemcount = reduce(operator.mul, shape, 1)
 
     # Check for valid shape for this type of allocation.
@@ -981,17 +1037,37 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
         # NVVM is smart enough to only use local memory if no register is
         # available
         dataptr = cgutils.alloca_once(builder, laryty, name=symbol_name)
+
+        # If the caller has specified a custom alignment, just set the align
+        # attribute on the alloca IR directly. We don't do any additional
+        # hand-holding here like checking the underlying data type's alignment
+        # or rounding up to the next power of 2--those checks will have already
+        # been done by the time we see the alignment value.
+        if alignment is not None:
+            dataptr.align = alignment
     else:
         lmod = builder.module
 
         # Create global variable in the requested address space
         gvmem = cgutils.add_global_variable(lmod, laryty, symbol_name,
                                             addrspace)
-        # Specify alignment to avoid misalignment bug
-        align = context.get_abi_sizeof(lldtype)
-        # Alignment is required to be a power of 2 for shared memory. If it is
-        # not a power of 2 (e.g. for a Record array) then round up accordingly.
-        gvmem.align = 1 << (align - 1).bit_length()
+
+        # If the caller hasn't specified a custom alignment, obtain the
+        # underlying dtype alignment from the ABI and then round it up to
+        # a power of two. Otherwise, just use the caller's alignment.
+        #
+        # N.B. The caller *could* provide a valid-but-smaller-than-natural
+        #      alignment here; we'll assume the caller knows what they're
+        #      doing and let that through without error.
+
+        if alignment is None:
+            abi_alignment = context.get_abi_alignment(lldtype)
+            # Ensure a power of two alignment.
+            actual_alignment = 1 << (abi_alignment - 1).bit_length()
+        else:
+            actual_alignment = alignment
+
+        gvmem.align = actual_alignment
 
         if dynamic_smem:
             gvmem.linkage = 'external'
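The round-up idiom in the default path is worth calling out; a standalone
sketch (not part of the commit) of what `1 << (n - 1).bit_length()`
computes:

    def round_up_to_power_of_two(n: int) -> int:
        # For n >= 1, returns the smallest power of two >= n.
        return 1 << (n - 1).bit_length()

    assert round_up_to_power_of_two(1) == 1
    assert round_up_to_power_of_two(8) == 8
    # e.g. a Record dtype whose ABI alignment is reported as 12:
    assert round_up_to_power_of_two(12) == 16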
@@ -1041,7 +1117,8 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
 
     # Create array object
     ndim = len(shape)
-    aryty = types.Array(dtype=dtype, ndim=ndim, layout='C')
+    aryty = types.Array(dtype=dtype, ndim=ndim, layout='C',
+                        alignment=alignment)
     ary = context.make_array(aryty)(context, builder)
 
     context.populate_array(ary,

numba_cuda/numba/cuda/stubs.py

Lines changed: 24 additions & 10 deletions
@@ -116,12 +116,17 @@ class shared(Stub):
     _description_ = '<shared>'
 
     @stub_function
-    def array(shape, dtype):
+    def array(shape, dtype, alignment=None):
         '''
-        Allocate a shared array of the given *shape* and *type*. *shape* is
-        either an integer or a tuple of integers representing the array's
-        dimensions. *type* is a :ref:`Numba type <numba-types>` of the
-        elements needing to be stored in the array.
+        Allocate a shared array of the given *shape*, *type*, and, optionally,
+        *alignment*. *shape* is either an integer or a tuple of integers
+        representing the array's dimensions. *type* is a :ref:`Numba type
+        <numba-types>` of the elements needing to be stored in the array.
+        *alignment* is an optional integer specifying the byte alignment of
+        the array. When specified, it must be a power of two, and a multiple
+        of the size of a pointer (4 for 32-bit, 8 for 64-bit). When not
+        specified, the array is allocated with an alignment appropriate for
+        the supplied *dtype*.
 
         The returned array-like object can be read and written to like any
         normal device array (e.g. through indexing).
@@ -135,12 +140,21 @@ class local(Stub):
     _description_ = '<local>'
 
     @stub_function
-    def array(shape, dtype):
+    def array(shape, dtype, alignment=None):
         '''
-        Allocate a local array of the given *shape* and *type*. The array is
-        private to the current thread, and resides in global memory. An
-        array-like object is returned which can be read and written to like any
-        standard array (e.g. through indexing).
+        Allocate a local array of the given *shape*, *type*, and, optionally,
+        *alignment*. *shape* is either an integer or a tuple of integers
+        representing the array's dimensions. *type* is a :ref:`Numba type
+        <numba-types>` of the elements needing to be stored in the array.
+        *alignment* is an optional integer specifying the byte alignment of
+        the array. When specified, it must be a power of two, and a multiple
+        of the size of a pointer (4 for 32-bit, 8 for 64-bit). When not
+        specified, the array is allocated with an alignment appropriate for
+        the supplied *dtype*.
+
+        The array is private to the current thread, and resides in global
+        memory. An array-like object is returned which can be read and
+        written to like any standard array (e.g. through indexing).
         '''
Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
import itertools
import numpy as np
from numba import cuda
from numba.cuda.testing import CUDATestCase
from numba.core.errors import TypingError
import unittest


# Set to True if you want to see dots printed for each subtest.
NOISY = True


# N.B. We name the test class TestArrayAddressAlignment to avoid a name
#      conflict with the test_alignment.TestArrayAlignment class.


class TestArrayAddressAlignment(CUDATestCase):
    """
    Test cuda.local.array and cuda.shared.array support for an alignment
    keyword argument.
    """

    def test_array_alignment(self):
        shapes = (1, 3, 4, 8, 9, 50)
        dtypes = (np.uint8, np.uint16, np.uint32, np.uint64)
        alignments = (None, 8, 16, 32, 64, 128, 256)
        array_types = [(0, 'local'), (1, 'shared')]

        items = itertools.product(array_types, shapes, dtypes, alignments)

        for (which, array_type), shape, dtype, alignment in items:
            with self.subTest(array_type=array_type, shape=shape,
                              dtype=dtype, alignment=alignment):
                @cuda.jit
                def f(loc, shrd, which):
                    i = cuda.grid(1)
                    if which == 0:
                        local_array = cuda.local.array(
                            shape=shape,
                            dtype=dtype,
                            alignment=alignment,
                        )
                        if i == 0:
                            loc[0] = local_array.ctypes.data
                    else:
                        shared_array = cuda.shared.array(
                            shape=shape,
                            dtype=dtype,
                            alignment=alignment,
                        )
                        if i == 0:
                            shrd[0] = shared_array.ctypes.data

                loc = np.zeros(1, dtype=np.uint64)
                shrd = np.zeros(1, dtype=np.uint64)
                f[1, 1](loc, shrd, which)

                if alignment is not None:
                    address = loc[0] if which == 0 else shrd[0]
                    alignment_mod = int(address % alignment)
                    self.assertEqual(alignment_mod, 0)

                if NOISY:
                    print('.', end='', flush=True)

    def test_invalid_alignments(self):
        shapes = (1, 3, 4, 8, 9, 50)
        dtypes = (np.uint8, np.uint16, np.uint32, np.uint64)
        alignments = (-1, 0, 3, 5, 7, 9, 15, 17, 31, 33, 63, 65)
        array_types = [(0, 'local'), (1, 'shared')]

        items = itertools.product(array_types, shapes, dtypes, alignments)

        for (which, array_type), shape, dtype, alignment in items:
            with self.subTest(array_type=array_type, shape=shape,
                              dtype=dtype, alignment=alignment):
                @cuda.jit
                def f(loc, shrd, which):
                    i = cuda.grid(1)
                    if which == 0:
                        local_array = cuda.local.array(
                            shape=shape,
                            dtype=dtype,
                            alignment=alignment,
                        )
                        if i == 0:
                            loc[0] = local_array.ctypes.data
                    else:
                        shared_array = cuda.shared.array(
                            shape=shape,
                            dtype=dtype,
                            alignment=alignment,
                        )
                        if i == 0:
                            shrd[0] = shared_array.ctypes.data

                loc = np.zeros(1, dtype=np.uint64)
                shrd = np.zeros(1, dtype=np.uint64)

                with self.assertRaises(TypingError) as raises:
                    f[1, 1](loc, shrd, which)
                exc = str(raises.exception)
                self.assertIn("Alignment must be", exc)

                if NOISY:
                    print('.', end='', flush=True)

    def test_array_like(self):
        # XXX-140: TODO; need to flesh out the array_like stuff more.
        pass


if __name__ == '__main__':
    unittest.main()
