Skip to content

Commit 721f88c

Browse files
committed
Implement alignment support for local and shared arrays.
1 parent 7b44bd3 commit 721f88c

File tree

3 files changed

+183
-12
lines changed

3 files changed

+183
-12
lines changed

numba_cuda/numba/cuda/cudadecl.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
class Cuda_array_decl(CallableTemplate):
2727
def generic(self):
28-
def typer(shape, dtype):
28+
def typer(shape, dtype, alignment=None):
2929

3030
# Only integer literals and tuples of integer literals are valid
3131
# shapes
@@ -39,10 +39,13 @@ def typer(shape, dtype):
3939
else:
4040
return None
4141

42+
# N.B. alignment validation happens in types.Array().
43+
4244
ndim = parse_shape(shape)
4345
nb_dtype = parse_dtype(dtype)
4446
if nb_dtype is not None and ndim is not None:
45-
return types.Array(dtype=nb_dtype, ndim=ndim, layout='C')
47+
return types.Array(dtype=nb_dtype, ndim=ndim, layout='C',
48+
alignment=alignment)
4649

4750
return typer
4851

numba_cuda/numba/cuda/cudaimpl.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,14 @@ def dim3_z(context, builder, sig, args):
7272
# -----------------------------------------------------------------------------
7373

7474
@lower(cuda.const.array_like, types.Array)
75+
@lower(cuda.const.array_like, types.Array, types.IntegerLiteral)
76+
@lower(cuda.const.array_like, types.Array, types.NoneType)
7577
def cuda_const_array_like(context, builder, sig, args):
7678
# This is a no-op because CUDATargetContext.make_constant_array already
7779
# created the constant array.
80+
if len(sig.args) > 1:
81+
# XXX-140: How do we handle alignment here?
82+
pass
7883
return args[0]
7984

8085

@@ -92,34 +97,62 @@ def _get_unique_smem_id(name):
9297

9398

9499
@lower(cuda.shared.array, types.IntegerLiteral, types.Any)
100+
@lower(cuda.shared.array, types.IntegerLiteral, types.Any, types.IntegerLiteral)
101+
@lower(cuda.shared.array, types.IntegerLiteral, types.Any, types.NoneType)
95102
def cuda_shared_array_integer(context, builder, sig, args):
96103
length = sig.args[0].literal_value
97104
dtype = parse_dtype(sig.args[1])
105+
alignment = None
106+
if len(sig.args) == 3:
107+
try:
108+
alignment = sig.args[2].literal_value
109+
except AttributeError:
110+
pass
98111
return _generic_array(context, builder, shape=(length,), dtype=dtype,
99112
symbol_name=_get_unique_smem_id('_cudapy_smem'),
100113
addrspace=nvvm.ADDRSPACE_SHARED,
101-
can_dynsized=True)
114+
can_dynsized=True, alignment=alignment)
102115

103116

117+
# XXX-140: Should I just use types.Any for the last alignment arg?
118+
104119
@lower(cuda.shared.array, types.Tuple, types.Any)
105120
@lower(cuda.shared.array, types.UniTuple, types.Any)
121+
@lower(cuda.shared.array, types.Tuple, types.Any, types.IntegerLiteral)
122+
@lower(cuda.shared.array, types.UniTuple, types.Any, types.IntegerLiteral)
123+
@lower(cuda.shared.array, types.Tuple, types.Any, types.NoneType)
124+
@lower(cuda.shared.array, types.UniTuple, types.Any, types.NoneType)
106125
def cuda_shared_array_tuple(context, builder, sig, args):
107126
shape = [ s.literal_value for s in sig.args[0] ]
108127
dtype = parse_dtype(sig.args[1])
128+
alignment = None
129+
if len(sig.args) == 3:
130+
try:
131+
alignment = sig.args[2].literal_value
132+
except AttributeError:
133+
pass
109134
return _generic_array(context, builder, shape=shape, dtype=dtype,
110135
symbol_name=_get_unique_smem_id('_cudapy_smem'),
111136
addrspace=nvvm.ADDRSPACE_SHARED,
112-
can_dynsized=True)
137+
can_dynsized=True, alignment=alignment)
113138

114139

115140
@lower(cuda.local.array, types.IntegerLiteral, types.Any)
141+
@lower(cuda.local.array, types.IntegerLiteral, types.Any, types.IntegerLiteral)
142+
@lower(cuda.local.array, types.IntegerLiteral, types.Any, types.NoneType)
116143
def cuda_local_array_integer(context, builder, sig, args):
117144
length = sig.args[0].literal_value
118145
dtype = parse_dtype(sig.args[1])
146+
alignment = None
147+
if len(sig.args) == 3:
148+
try:
149+
alignment = sig.args[2].literal_value
150+
except AttributeError:
151+
pass
119152
return _generic_array(context, builder, shape=(length,), dtype=dtype,
120153
symbol_name='_cudapy_lmem',
121154
addrspace=nvvm.ADDRSPACE_LOCAL,
122-
can_dynsized=False)
155+
can_dynsized=False, alignment=alignment)
123156

124157

125158
@lower(cuda.local.array, types.Tuple, types.Any)
@@ -954,7 +987,7 @@ def ptx_nanosleep(context, builder, sig, args):
954987

955988

956989
def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
957-
can_dynsized=False):
990+
can_dynsized=False, alignment=None):
958991
elemcount = reduce(operator.mul, shape, 1)
959992

960993
# Check for valid shape for this type of allocation.
@@ -981,17 +1014,37 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
9811014
# NVVM is smart enough to only use local memory if no register is
9821015
# available
9831016
dataptr = cgutils.alloca_once(builder, laryty, name=symbol_name)
1017+
1018+
# If the user has specified a custom alignment, just set the align
1019+
# attribute on the alloca IR directly. We don't do any additional
1020+
# hand-holding here like checking the underlying data type's alignment
1021+
# or rounding up to the next power of 2--those checks will have already
1022+
# been done by the time we see the alignment value.
1023+
if alignment is not None:
1024+
dataptr.align = alignment
9841025
else:
9851026
lmod = builder.module
9861027

9871028
# Create global variable in the requested address space
9881029
gvmem = cgutils.add_global_variable(lmod, laryty, symbol_name,
9891030
addrspace)
990-
# Specify alignment to avoid misalignment bug
991-
align = context.get_abi_sizeof(lldtype)
992-
# Alignment is required to be a power of 2 for shared memory. If it is
993-
# not a power of 2 (e.g. for a Record array) then round up accordingly.
994-
gvmem.align = 1 << (align - 1 ).bit_length()
1031+
1032+
# If the caller hasn't specified a custom alignment, obtain the
1033+
# underlying dtype alignment from the ABI and then round it up to
1034+
# a power of two. Otherwise, just use the caller's alignment.
1035+
#
1036+
# N.B. The caller *could* provide a valid-but-smaller-than-natural
1037+
# alignment here; we'll assume the caller knows what they're
1038+
# doing and let that through without error.
1039+
1040+
if alignment is None:
1041+
abi_alignment = context.get_abi_alignment(lldtype)
1042+
# Ensure a power of two alignment.
1043+
actual_alignment = 1 << (abi_alignment - 1).bit_length()
1044+
else:
1045+
actual_alignment = alignment
1046+
1047+
gvmem.align = actual_alignment
9951048

9961049
if dynamic_smem:
9971050
gvmem.linkage = 'external'
@@ -1041,7 +1094,8 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
10411094

10421095
# Create array object
10431096
ndim = len(shape)
1044-
aryty = types.Array(dtype=dtype, ndim=ndim, layout='C')
1097+
aryty = types.Array(dtype=dtype, ndim=ndim, layout='C',
1098+
alignment=alignment)
10451099
ary = context.make_array(aryty)(context, builder)
10461100

10471101
context.populate_array(ary,
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import itertools
2+
import numpy as np
3+
from numba import cuda
4+
from numba.cuda.testing import CUDATestCase
5+
from numba.core.errors import TypingError
6+
import unittest
7+
8+
9+
# Set to true if you want to see dots printed for each subtest.
10+
NOISY = True
11+
12+
13+
# N.B. We name the test class TestArrayAddressAlignment to avoid name conflict
14+
# with the test_alignment.TestArrayAlignment class.
15+
16+
17+
class TestArrayAddressAlignment(CUDATestCase):
18+
"""
19+
Test cuda.local.array and cuda.shared.array support for an alignment
20+
keyword argument.
21+
"""
22+
23+
def test_array_alignment(self):
24+
shapes = (1, 3, 4, 8, 9, 50)
25+
dtypes = (np.uint8, np.uint16, np.uint32, np.uint64)
26+
alignments = (None, 8, 16, 32, 64, 128, 256)
27+
array_types = [(0, 'local'), (1, 'shared')]
28+
29+
items = itertools.product(array_types, shapes, dtypes, alignments)
30+
31+
for (which, array_type), shape, dtype, alignment in items:
32+
with self.subTest(array_type=array_type, shape=shape,
33+
dtype=dtype, alignment=alignment):
34+
@cuda.jit
35+
def f(loc, shrd, which):
36+
i = cuda.grid(1)
37+
if which == 0:
38+
local_array = cuda.local.array(
39+
shape=shape,
40+
dtype=dtype,
41+
alignment=alignment,
42+
)
43+
if i == 0:
44+
loc[0] = local_array.ctypes.data
45+
else:
46+
shared_array = cuda.shared.array(
47+
shape=shape,
48+
dtype=dtype,
49+
alignment=alignment,
50+
)
51+
if i == 0:
52+
shrd[0] = shared_array.ctypes.data
53+
54+
loc = np.zeros(1, dtype=np.uint64)
55+
shrd = np.zeros(1, dtype=np.uint64)
56+
f[1, 1](loc, shrd, which)
57+
58+
if alignment is not None:
59+
address = loc[0] if which == 0 else shrd[0]
60+
alignment_mod = int(address % alignment)
61+
self.assertEqual(alignment_mod, 0)
62+
63+
if NOISY:
64+
print('.', end='', flush=True)
65+
66+
def test_invalid_aligments(self):
67+
shapes = (1, 3, 4, 8, 9, 50)
68+
dtypes = (np.uint8, np.uint16, np.uint32, np.uint64)
69+
alignments = (-1, 0, 3, 5, 7, 9, 15, 17, 31, 33, 63, 65)
70+
array_types = [(0, 'local'), (1, 'shared')]
71+
72+
items = itertools.product(array_types, shapes, dtypes, alignments)
73+
74+
for (which, array_type), shape, dtype, alignment in items:
75+
with self.subTest(array_type=array_type, shape=shape,
76+
dtype=dtype, alignment=alignment):
77+
@cuda.jit
78+
def f(local_array, shared_array, which):
79+
i = cuda.grid(1)
80+
if which == 0:
81+
local_array = cuda.local.array(
82+
shape=shape,
83+
dtype=dtype,
84+
alignment=alignment,
85+
)
86+
if i == 0:
87+
local_array[0] = local_array.ctypes.data
88+
else:
89+
shared_array = cuda.shared.array(
90+
shape=shape,
91+
dtype=dtype,
92+
alignment=alignment,
93+
)
94+
if i == 0:
95+
shared_array[0] = shared_array.ctypes.data
96+
97+
loc = np.zeros(1, dtype=np.uint64)
98+
shrd = np.zeros(1, dtype=np.uint64)
99+
100+
with self.assertRaises(TypingError) as raises:
101+
f[1, 1](loc, shrd, which)
102+
exc = str(raises.exception)
103+
self.assertIn("Alignment must be", exc)
104+
105+
if NOISY:
106+
print('.', end='', flush=True)
107+
108+
def test_array_like(self):
109+
# XXX-140: TODO; need to flush out the array_like stuff more.
110+
pass
111+
112+
113+
if __name__ == '__main__':
114+
unittest.main()

0 commit comments

Comments
 (0)