Skip to content

Commit 02df550

Browse files
isVoidgmarkall
andauthored
Fix Invalid NVVM IR emitted when lowering shfl_sync APIs (#231)
As per the NVVM IR specification, the shuffle intrinsics require the `mode` parameter to be an IR constant. In the current Numba implementation it is passed as a variable. This can crash NVVM, because constant folding is not applied to the IR until the optimization passes are run. This PR fixes #228. --------- Co-authored-by: Graham Markall <[email protected]>
1 parent a6a6374 commit 02df550

File tree

7 files changed

+226
-166
lines changed

7 files changed

+226
-166
lines changed

numba_cuda/numba/cuda/cudadecl.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -100,45 +100,6 @@ class Cuda_syncwarp(ConcreteTemplate):
100100
cases = [signature(types.none), signature(types.none, types.i4)]
101101

102102

103-
@register
class Cuda_shfl_sync_intrinsic(ConcreteTemplate):
    """Typing template for ``cuda.shfl_sync_intrinsic``.

    Every case takes five arguments, all ``i4`` except the third, which is
    the value being shuffled. The result is a ``(value, b1)`` tuple whose
    first member has the same type as the shuffled value.
    """

    key = cuda.shfl_sync_intrinsic
    # One signature per supported value type, in the order i4, i8, f4, f8.
    cases = [
        signature(
            types.Tuple((value_type, types.b1)),
            types.i4,
            types.i4,
            value_type,
            types.i4,
            types.i4,
        )
        for value_type in (types.i4, types.i8, types.f4, types.f8)
    ]
140-
141-
142103
@register
143104
class Cuda_vote_sync_intrinsic(ConcreteTemplate):
144105
key = cuda.vote_sync_intrinsic
@@ -815,9 +776,6 @@ def resolve_threadfence_system(self, mod):
815776
def resolve_syncwarp(self, mod):
    # Type the ``syncwarp`` attribute as a function described by the
    # Cuda_syncwarp template.
    template = Cuda_syncwarp
    return types.Function(template)
817778

818-
def resolve_shfl_sync_intrinsic(self, mod):
    # Type the ``shfl_sync_intrinsic`` attribute as a function described by
    # the Cuda_shfl_sync_intrinsic template.
    template = Cuda_shfl_sync_intrinsic
    return types.Function(template)
820-
821779
def resolve_vote_sync_intrinsic(self, mod):
    # Type the ``vote_sync_intrinsic`` attribute as a function described by
    # the Cuda_vote_sync_intrinsic template.
    template = Cuda_vote_sync_intrinsic
    return types.Function(template)
823781

numba_cuda/numba/cuda/cudaimpl.py

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -204,69 +204,6 @@ def ptx_syncwarp_mask(context, builder, sig, args):
204204
return context.get_dummy_value()
205205

206206

207-
@lower(
    stubs.shfl_sync_intrinsic, types.i4, types.i4, types.i4, types.i4, types.i4
)
@lower(
    stubs.shfl_sync_intrinsic, types.i4, types.i4, types.i8, types.i4, types.i4
)
@lower(
    stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f4, types.i4, types.i4
)
@lower(
    stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f8, types.i4, types.i4
)
def ptx_shfl_sync_i32(context, builder, sig, args):
    """
    Lower ``shfl_sync_intrinsic`` onto the ``llvm.nvvm.shfl.sync.i32``
    intrinsic.

    The NVVM intrinsic only shuffles i32 values, whereas the CUDA intrinsic
    accepts 32- and 64-bit ints and floats. For feature parity:

    - floats are bitcast to an integer of the same width, shuffled, and
      bitcast back;
    - 64-bit values are split into two 32-bit halves that are shuffled
      separately and then packed back together.

    Returns a ``(value, predicate)`` struct; for 64-bit values the predicate
    is taken from the shuffle of the low half.
    """
    mask, mode, value, index, clamp = args
    value_type = sig.args[2]
    i32 = ir.IntType(32)

    # Work on the raw bits of floating point values.
    if value_type in types.real_domain:
        value = builder.bitcast(value, ir.IntType(value_type.bitwidth))

    shfl_fnty = ir.FunctionType(
        ir.LiteralStructType((i32, ir.IntType(1))),
        (i32, i32, i32, i32, i32),
    )
    shfl = cgutils.get_or_insert_function(
        builder.module, shfl_fnty, "llvm.nvvm.shfl.sync.i32"
    )

    def shuffle(word):
        # Shuffle one 32-bit word with the caller's mask/mode/index/clamp.
        return builder.call(shfl, (mask, mode, word, index, clamp))

    if value_type.bitwidth == 32:
        result = shuffle(value)
        if value_type != types.float32:
            # i32: the intrinsic's result struct is already the answer.
            return result
        shuffled = builder.extract_value(result, 0)
        predicate = builder.extract_value(result, 1)
        as_float = builder.bitcast(shuffled, ir.FloatType())
        return cgutils.make_anonymous_struct(builder, (as_float, predicate))

    # 64-bit path: split into low and high words.
    shift = context.get_constant(types.i8, 32)
    low_word = builder.trunc(value, i32)
    high_bits = builder.lshr(value, shift)
    high_word = builder.trunc(high_bits, i32)

    low_result = shuffle(low_word)
    high_result = shuffle(high_word)

    low_shuffled = builder.extract_value(low_result, 0)
    high_shuffled = builder.extract_value(high_result, 0)
    predicate = builder.extract_value(low_result, 1)

    # Recombine as (high << 32) | low.
    low_wide = builder.zext(low_shuffled, ir.IntType(64))
    high_wide = builder.zext(high_shuffled, ir.IntType(64))
    high_shifted = builder.shl(high_wide, shift)
    combined = builder.or_(high_shifted, low_wide)
    if value_type == types.float64:
        combined = builder.bitcast(combined, ir.DoubleType())
    return cgutils.make_anonymous_struct(builder, (combined, predicate))
268-
269-
270207
@lower(stubs.vote_sync_intrinsic, types.i4, types.i4, types.boolean)
271208
def ptx_vote_sync(context, builder, sig, args):
272209
fname = "llvm.nvvm.vote.sync"

numba_cuda/numba/cuda/device_init.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
local,
1414
const,
1515
atomic,
16-
shfl_sync_intrinsic,
1716
vote_sync_intrinsic,
1817
match_any_sync,
1918
match_all_sync,
@@ -40,6 +39,10 @@
4039
syncthreads_and,
4140
syncthreads_count,
4241
syncthreads_or,
42+
shfl_sync,
43+
shfl_up_sync,
44+
shfl_down_sync,
45+
shfl_xor_sync,
4346
)
4447
from .cudadrv.error import CudaSupportError
4548
from numba.cuda.cudadrv.driver import (
@@ -68,10 +71,6 @@
6871
any_sync,
6972
eq_sync,
7073
ballot_sync,
71-
shfl_sync,
72-
shfl_up_sync,
73-
shfl_down_sync,
74-
shfl_xor_sync,
7574
)
7675

7776
from .kernels import reduction

numba_cuda/numba/cuda/intrinsic_wrapper.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -36,42 +36,3 @@ def ballot_sync(mask, predicate):
3636
and are within the given mask.
3737
"""
3838
return numba.cuda.vote_sync_intrinsic(mask, 3, predicate)[0]
39-
40-
41-
@jit(device=True)
def shfl_sync(mask, value, src_lane):
    """
    Shuffle ``value`` across the masked warp and return the value held by
    ``src_lane``. If that lane is outside the warp, the given value is
    returned.
    """
    # Mode 0; only the shuffled value (element 0 of the result) is returned.
    result = numba.cuda.shfl_sync_intrinsic(mask, 0, value, src_lane, 0x1F)
    return result[0]
49-
50-
51-
@jit(device=True)
def shfl_up_sync(mask, value, delta):
    """
    Shuffle ``value`` across the masked warp and return the value held by
    lane ``(laneid - delta)``. If that lane is outside the warp, the given
    value is returned.
    """
    # Mode 1; only the shuffled value (element 0 of the result) is returned.
    result = numba.cuda.shfl_sync_intrinsic(mask, 1, value, delta, 0)
    return result[0]
59-
60-
61-
@jit(device=True)
def shfl_down_sync(mask, value, delta):
    """
    Shuffle ``value`` across the masked warp and return the value held by
    lane ``(laneid + delta)``. If that lane is outside the warp, the given
    value is returned.
    """
    # Mode 2; only the shuffled value (element 0 of the result) is returned.
    result = numba.cuda.shfl_sync_intrinsic(mask, 2, value, delta, 0x1F)
    return result[0]
69-
70-
71-
@jit(device=True)
def shfl_xor_sync(mask, value, lane_mask):
    """
    Shuffle ``value`` across the masked warp and return the value held by
    lane ``(laneid ^ lane_mask)``.
    """
    # Mode 3; only the shuffled value (element 0 of the result) is returned.
    result = numba.cuda.shfl_sync_intrinsic(mask, 3, value, lane_mask, 0x1F)
    return result[0]

numba_cuda/numba/cuda/intrinsics.py

Lines changed: 172 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from numba import cuda, types
44
from numba.core import cgutils
5-
from numba.core.errors import RequireLiteralValue
5+
from numba.core.errors import RequireLiteralValue, TypingError
66
from numba.core.typing import signature
77
from numba.core.extending import overload_attribute, overload_method
88
from numba.cuda import nvvmutils
@@ -205,3 +205,174 @@ def syncthreads_or(typingctx, predicate):
205205
@overload_method(types.Integer, "bit_count", target="cuda")
def integer_bit_count(i):
    """Implement ``bit_count`` on integers for the CUDA target via
    ``cuda.popc``."""

    def bit_count_impl(i):
        return cuda.popc(i)

    return bit_count_impl
208+
209+
210+
# -------------------------------------------------------------------------------
211+
# Warp shuffle functions
212+
#
213+
# References:
214+
#
215+
# - https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-shuffle-functions
216+
# - https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#data-movement
217+
#
218+
# Notes:
219+
#
220+
# - The public CUDA C/C++ and Numba Python APIs for these intrinsics use
221+
# different names for parameters to the NVVM IR specification. So that we
222+
# can correlate the implementation with the documentation, the @intrinsic
223+
# API functions map the public API arguments to the NVVM intrinsic
224+
# arguments.
225+
# - The NVVM IR specification requires some of the parameters (e.g. mode) to be
226+
# constants. It's therefore essential that we pass in some values to the
227+
# shfl_sync_intrinsic function (e.g. the mode and c values).
228+
# - Normally parameters for intrinsic functions in Numba would be given the
229+
# same name as used in the API, and would contain a type. However, because we
230+
# have to pass in some values and some types (and there is divergence between
231+
# the names in the intrinsic documentation and the public APIs) we instead
232+
# follow the convention of naming shfl_sync_intrinsic parameters with a
233+
# suffix of _type or _value depending on whether they contain a type or a
234+
# value.
235+
236+
237+
@intrinsic
def shfl_sync(typingctx, mask, value, src_lane):
    """
    Shuffles ``value`` across the masked warp and returns the value from
    ``src_lane``. If this is outside the warp, then the given value is
    returned.
    """
    # Map the public argument names onto the NVVM operand names: the argument
    # types are forwarded, while mode (0) and c (0x1F) are fixed values.
    return shfl_sync_intrinsic(
        typingctx,
        membermask_type=mask,
        mode_value=0,
        a_type=value,
        b_type=src_lane,
        c_value=0x1F,
    )
252+
253+
254+
@intrinsic
def shfl_up_sync(typingctx, mask, value, delta):
    """
    Shuffles ``value`` across the masked warp and returns the value from
    ``(laneid - delta)``. If this is outside the warp, then the given value is
    returned.
    """
    # Map the public argument names onto the NVVM operand names: the argument
    # types are forwarded, while mode (1) and c (0) are fixed values.
    return shfl_sync_intrinsic(
        typingctx,
        membermask_type=mask,
        mode_value=1,
        a_type=value,
        b_type=delta,
        c_value=0,
    )
269+
270+
271+
@intrinsic
def shfl_down_sync(typingctx, mask, value, delta):
    """
    Shuffles ``value`` across the masked warp and returns the value from
    ``(laneid + delta)``. If this is outside the warp, then the given value is
    returned.
    """
    # Map the public argument names onto the NVVM operand names: the argument
    # types are forwarded, while mode (2) and c (0x1F) are fixed values.
    return shfl_sync_intrinsic(
        typingctx,
        membermask_type=mask,
        mode_value=2,
        a_type=value,
        b_type=delta,
        c_value=0x1F,
    )
286+
287+
288+
@intrinsic
def shfl_xor_sync(typingctx, mask, value, lane_mask):
    """
    Shuffles ``value`` across the masked warp and returns the value from
    ``(laneid ^ lane_mask)``.
    """
    # Map the public argument names onto the NVVM operand names: the argument
    # types are forwarded, while mode (3) and c (0x1F) are fixed values.
    return shfl_sync_intrinsic(
        typingctx,
        membermask_type=mask,
        mode_value=3,
        a_type=value,
        b_type=lane_mask,
        c_value=0x1F,
    )
302+
303+
304+
def shfl_sync_intrinsic(
    typingctx,
    membermask_type,
    mode_value,
    a_type,
    b_type,
    c_value,
):
    """
    Build the signature and code generator shared by the ``shfl_*_sync``
    intrinsics.

    Parameters ending in ``_type`` are the Numba types of the call-site
    arguments; parameters ending in ``_value`` are Python integers that are
    emitted into the IR as ``i32`` constants, so that the NVVM intrinsic
    receives constant operands for them.

    :param typingctx: typing context forwarded by the ``@intrinsic`` wrapper
        (not used here).
    :param membermask_type: Numba type of the member mask argument.
    :param mode_value: integer emitted as the constant ``mode`` operand.
    :param a_type: Numba type of the value being shuffled; must be one of
        ``i4``, ``i8``, ``f4`` or ``f8``.
    :param b_type: Numba type of the ``b`` operand (source lane, delta or
        lane mask, depending on the caller).
    :param c_value: integer emitted as the constant ``c`` operand.
    :raises TypingError: if ``a_type`` is not a supported value type.
    :returns: a ``(signature, codegen)`` pair, as expected from the body of
        an ``@intrinsic`` function. The signature returns ``a_type``.
    """
    if a_type not in (types.i4, types.i8, types.f4, types.f8):
        raise TypingError(
            "shfl_sync only supports 32- and 64-bit ints and floats"
        )

    def codegen(context, builder, sig, args):
        """
        The NVVM shfl_sync intrinsic only supports i32, but the CUDA C/C++
        intrinsic supports both 32- and 64-bit ints and floats, so for feature
        parity, i32, i64, f32, and f64 are implemented. Floats by way of
        bitcasting the float to an int, then shuffling, then bitcasting
        back. 64-bit values by shuffling the low and high 32-bit halves
        separately and packing the two results back together.
        """
        membermask, a, b = args

        # Types
        mask_ty, a_ty, b_ty = sig.args
        return_type = context.get_value_type(sig.return_type)
        i32 = ir.IntType(32)
        i64 = ir.IntType(64)

        def to_i32(val, numba_type):
            # The mask and lane operands arrive with whatever width the
            # caller used. Truncation is only valid for values at least as
            # wide as i32; narrower integers must be extended instead,
            # following the signedness of their Numba type.
            llvm_type = val.type
            if isinstance(llvm_type, ir.IntType) and llvm_type.width < 32:
                if getattr(numba_type, "signed", False):
                    return builder.sext(val, i32)
                return builder.zext(val, i32)
            return builder.trunc(val, i32)

        if a_ty in types.real_domain:
            a = builder.bitcast(a, ir.IntType(a_ty.bitwidth))

        # NVVM intrinsic definition
        arg_types = (i32, i32, i32, i32, i32)
        shfl_return_type = ir.LiteralStructType((i32, ir.IntType(1)))
        fnty = ir.FunctionType(shfl_return_type, arg_types)

        fname = "llvm.nvvm.shfl.sync.i32"
        shfl_sync = cgutils.get_or_insert_function(builder.module, fnty, fname)

        # Intrinsic arguments. mode and c are emitted as IR constants rather
        # than passed through as runtime values.
        mode = ir.Constant(i32, mode_value)
        c = ir.Constant(i32, c_value)
        membermask = to_i32(membermask, mask_ty)
        b = to_i32(b, b_ty)

        if a_ty.bitwidth == 32:
            # a is already an i32 here (ints as-is, floats after the bitcast
            # above), so it can be passed straight to the intrinsic.
            ret = builder.call(shfl_sync, (membermask, mode, a, b, c))
            d = builder.extract_value(ret, 0)
        else:
            # Handle 64-bit values by shuffling as two 32-bit values and
            # packing the result into 64 bits.

            # Extract high and low parts
            lo = builder.trunc(a, i32)
            a_lshr = builder.lshr(a, ir.Constant(i64, 32))
            hi = builder.trunc(a_lshr, i32)

            # Shuffle individual parts
            ret_lo = builder.call(shfl_sync, (membermask, mode, lo, b, c))
            ret_hi = builder.call(shfl_sync, (membermask, mode, hi, b, c))

            # Combine individual result parts into a 64-bit result
            d_lo = builder.extract_value(ret_lo, 0)
            d_hi = builder.extract_value(ret_hi, 0)
            d_lo_64 = builder.zext(d_lo, i64)
            d_hi_64 = builder.zext(d_hi, i64)
            d_shl = builder.shl(d_hi_64, ir.Constant(i64, 32))
            d = builder.or_(d_shl, d_lo_64)

        # Reinterpret the shuffled bits as the declared return type; this
        # restores floats that were bitcast to integers above.
        return builder.bitcast(d, return_type)

    sig = signature(a_type, membermask_type, a_type, b_type)

    return sig, codegen

0 commit comments

Comments
 (0)