Skip to content

Commit b06e183

Browse files
authored
fix: warp vote operations must use a constant int for the mode parameter (#606)
Fixes #592. Followed similar pattern as #231.
1 parent 9212523 commit b06e183

File tree

7 files changed

+270
-86
lines changed

7 files changed

+270
-86
lines changed

numba_cuda/numba/cuda/cudadecl.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -103,16 +103,6 @@ class Cuda_syncwarp(ConcreteTemplate):
103103
cases = [signature(types.none), signature(types.none, types.i4)]
104104

105105

106-
@register
107-
class Cuda_vote_sync_intrinsic(ConcreteTemplate):
108-
key = cuda.vote_sync_intrinsic
109-
cases = [
110-
signature(
111-
types.Tuple((types.i4, types.b1)), types.i4, types.i4, types.b1
112-
)
113-
]
114-
115-
116106
@register
117107
class Cuda_match_any_sync(ConcreteTemplate):
118108
key = cuda.match_any_sync
@@ -522,9 +512,6 @@ def resolve_threadfence_system(self, mod):
522512
def resolve_syncwarp(self, mod):
523513
return types.Function(Cuda_syncwarp)
524514

525-
def resolve_vote_sync_intrinsic(self, mod):
526-
return types.Function(Cuda_vote_sync_intrinsic)
527-
528515
def resolve_match_any_sync(self, mod):
529516
return types.Function(Cuda_match_any_sync)
530517

numba_cuda/numba/cuda/cudaimpl.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -280,18 +280,6 @@ def ptx_syncwarp_mask(context, builder, sig, args):
280280
return context.get_dummy_value()
281281

282282

283-
@lower(stubs.vote_sync_intrinsic, types.i4, types.i4, types.boolean)
284-
def ptx_vote_sync(context, builder, sig, args):
285-
fname = "llvm.nvvm.vote.sync"
286-
lmod = builder.module
287-
fnty = ir.FunctionType(
288-
ir.LiteralStructType((ir.IntType(32), ir.IntType(1))),
289-
(ir.IntType(32), ir.IntType(32), ir.IntType(1)),
290-
)
291-
func = cgutils.get_or_insert_function(lmod, fnty, fname)
292-
return builder.call(func, args)
293-
294-
295283
@lower(stubs.match_any_sync, types.i4, types.i4)
296284
@lower(stubs.match_any_sync, types.i4, types.i8)
297285
@lower(stubs.match_any_sync, types.i4, types.f4)

numba_cuda/numba/cuda/device_init.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
local,
2828
const,
2929
atomic,
30-
vote_sync_intrinsic,
3130
match_any_sync,
3231
match_all_sync,
3332
threadfence_block,
@@ -56,6 +55,10 @@
5655
shfl_up_sync,
5756
shfl_down_sync,
5857
shfl_xor_sync,
58+
all_sync,
59+
any_sync,
60+
eq_sync,
61+
ballot_sync,
5962
)
6063
from .cudadrv.error import CudaSupportError
6164
from numba.cuda.cudadrv.driver import (
@@ -79,12 +82,6 @@
7982
from .api import _auto_device
8083
from .args import In, Out, InOut
8184

82-
from .intrinsic_wrapper import (
83-
all_sync,
84-
any_sync,
85-
eq_sync,
86-
ballot_sync,
87-
)
8885

8986
from .kernels import reduction
9087
from numba.cuda.cudadrv.linkable_code import (

numba_cuda/numba/cuda/intrinsic_wrapper.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

numba_cuda/numba/cuda/intrinsics.py

Lines changed: 150 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
from numba import cuda
77
from numba.cuda import types
88
from numba.cuda import cgutils
9-
from numba.cuda.core.errors import RequireLiteralValue, TypingError
9+
from numba.cuda.core.errors import (
10+
RequireLiteralValue,
11+
TypingError,
12+
NumbaTypeError,
13+
)
1014
from numba.cuda.typing import signature
1115
from numba.cuda.extending import overload_attribute, overload_method
1216
from numba.cuda import nvvmutils
@@ -380,3 +384,148 @@ def codegen(context, builder, sig, args):
380384
sig = signature(a_type, membermask_type, a_type, b_type)
381385

382386
return sig, codegen
387+
388+
389+
# -------------------------------------------------------------------------------
390+
# Warp vote functions
391+
#
392+
# References:
393+
#
394+
# - https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-vote-functions
395+
# - https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html?highlight=data%2520movement#vote
396+
#
397+
# Notes:
398+
#
399+
# - The NVVM IR specification requires the mode parameter of the vote
400+
# intrinsic to be a constant. It's therefore essential that each wrapper
401+
# passes its mode to vote_sync_intrinsic as a fixed Python integer.
402+
403+
404+
@intrinsic
def all_sync(typingctx, mask_type, predicate_type):
    """
    Warp vote, "all" mode.

    Evaluates to a non-zero (true) value when the predicate holds for every
    thread in the masked warp, and to 0 (false) otherwise.
    """
    # Mode 0 selects the "all" vote. It is a plain Python int here so that
    # the helper can emit it as a constant operand.
    vote_sig, emit_vote = vote_sync_intrinsic(
        typingctx, mask_type, 0, predicate_type
    )

    def codegen(context, builder, outer_sig, args):
        # The vote call yields a two-field struct; field 1 holds the
        # boolean outcome of the vote.
        vote_result = emit_vote(context, builder, vote_sig, args)
        return builder.extract_value(vote_result, 1)

    return signature(types.b1, mask_type, predicate_type), codegen
422+
423+
424+
@intrinsic
def any_sync(typingctx, mask_type, predicate_type):
    """
    Warp vote, "any" mode.

    Evaluates to a non-zero (true) value when the predicate holds for at
    least one thread in the masked warp, and to 0 (false) otherwise.
    """
    # Mode 1 selects the "any" vote. It is a plain Python int here so that
    # the helper can emit it as a constant operand.
    vote_sig, emit_vote = vote_sync_intrinsic(
        typingctx, mask_type, 1, predicate_type
    )

    def codegen(context, builder, outer_sig, args):
        # Field 1 of the returned struct is the boolean vote outcome.
        vote_result = emit_vote(context, builder, vote_sig, args)
        return builder.extract_value(vote_result, 1)

    return signature(types.b1, mask_type, predicate_type), codegen
441+
442+
443+
@intrinsic
def eq_sync(typingctx, mask_type, predicate_type):
    """
    Warp vote, "eq" mode.

    Evaluates to a non-zero (true) value when the boolean predicate has the
    same value for every thread in the masked warp, and to 0 (false)
    otherwise.
    """
    # Mode 2 selects the "eq" vote. It is a plain Python int here so that
    # the helper can emit it as a constant operand.
    vote_sig, emit_vote = vote_sync_intrinsic(
        typingctx, mask_type, 2, predicate_type
    )

    def codegen(context, builder, outer_sig, args):
        # Field 1 of the returned struct is the boolean vote outcome.
        vote_result = emit_vote(context, builder, vote_sig, args)
        return builder.extract_value(vote_result, 1)

    return signature(types.b1, mask_type, predicate_type), codegen
460+
461+
462+
@intrinsic
def ballot_sync(typingctx, mask_type, predicate_type):
    """
    Warp vote, "ballot" mode.

    Produces a 32-bit mask of the threads in the warp whose predicate is
    true and which are within the given mask.
    """
    # Mode 3 selects the "ballot" vote. It is a plain Python int here so
    # that the helper can emit it as a constant operand.
    vote_sig, emit_vote = vote_sync_intrinsic(
        typingctx, mask_type, 3, predicate_type
    )

    def codegen(context, builder, outer_sig, args):
        # Unlike the boolean votes, the ballot lives in field 0 (the i32
        # member) of the returned struct.
        vote_result = emit_vote(context, builder, vote_sig, args)
        return builder.extract_value(vote_result, 0)

    return signature(types.i4, mask_type, predicate_type), codegen
481+
482+
483+
def vote_sync_intrinsic(typingctx, mask_type, mode_value, predicate_type):
    """
    Build the signature and code generator for ``llvm.nvvm.vote.sync``.

    This is a typing-time helper rather than an ``@intrinsic`` itself: the
    warp vote intrinsics call it with a fixed Python integer ``mode_value``
    so that the mode is always emitted as an LLVM constant.

    Parameters
    ----------
    typingctx
        Typing context forwarded by the calling intrinsic; not used here.
    mask_type
        Numba type of the warp mask. Must be an integer type.
    mode_value : int
        Vote mode: 0 (all), 1 (any), 2 (eq), or 3 (ballot).
    predicate_type
        Numba type of the predicate. Must be an integer or boolean type.

    Returns
    -------
    tuple
        ``(sig, codegen)`` where ``sig`` types the call as
        ``Tuple((i4, b1))(mask_type, predicate_type)`` and
        ``codegen(context, builder, sig, args)`` emits the call and returns
        the resulting ``{i32, i1}`` struct value.

    Raises
    ------
    ValueError
        If ``mode_value`` is not one of 0, 1, 2, 3.
    NumbaTypeError
        If the mask is not an integer, or the predicate is neither an
        integer nor a boolean.
    """
    # Validate mode value
    if mode_value not in (0, 1, 2, 3):
        raise ValueError("Mode must be 0 (all), 1 (any), 2 (eq), or 3 (ballot)")

    mask_numba_type = types.unliteral(mask_type)
    if mask_numba_type not in types.integer_domain:
        raise NumbaTypeError(f"Mask type must be an integer. Got {mask_type}")

    predicate_types = types.integer_domain | {types.boolean}
    if types.unliteral(predicate_type) not in predicate_types:
        raise NumbaTypeError(
            f"Predicate must be an integer or boolean. Got {predicate_type}"
        )

    # Needed to pick sign- vs zero-extension for masks narrower than 32 bits.
    mask_is_signed = mask_numba_type.signed

    def codegen(context, builder, sig, args):
        mask, predicate = args

        # Types
        i1 = ir.IntType(1)
        i32 = ir.IntType(32)

        # NVVM intrinsic definition: {i32, i1} vote.sync(i32 mask, i32 mode, i1 pred)
        arg_types = (i32, i32, i1)
        vote_return_type = ir.LiteralStructType((i32, i1))
        fnty = ir.FunctionType(vote_return_type, arg_types)

        fname = "llvm.nvvm.vote.sync"
        lmod = builder.module
        vote_sync = cgutils.get_or_insert_function(lmod, fnty, fname)

        # The mode is baked in as a constant operand.
        mode = ir.Constant(i32, mode_value)

        # Bring the mask to exactly 32 bits. LLVM `trunc` is only valid when
        # the source is wider than the destination, so narrower masks
        # (8/16-bit integers pass the typing check above) must be extended.
        mask_width = mask.type.width
        if mask_width > 32:
            mask_i32 = builder.trunc(mask, i32)
        elif mask_width < 32:
            if mask_is_signed:
                mask_i32 = builder.sext(mask, i32)
            else:
                mask_i32 = builder.zext(mask, i32)
        else:
            mask_i32 = mask

        # Convert predicate to i1: any non-zero integer counts as true.
        if predicate.type != i1:
            predicate_bool = builder.icmp_signed(
                "!=", predicate, ir.Constant(predicate.type, 0)
            )
        else:
            predicate_bool = predicate

        return builder.call(vote_sync, [mask_i32, mode, predicate_bool])

    sig = signature(
        types.Tuple((types.i4, types.b1)), mask_type, predicate_type
    )

    return sig, codegen

numba_cuda/numba/cuda/stubs.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -200,17 +200,6 @@ class syncwarp(Stub):
200200
_description_ = "<warp_sync()>"
201201

202202

203-
class vote_sync_intrinsic(Stub):
204-
"""
205-
vote_sync_intrinsic(mask, mode, predictate)
206-
207-
Nvvm intrinsic for performing a reduce and broadcast across a warp
208-
docs.nvidia.com/cuda/nvvm-ir-spec/index.html#nvvm-intrin-warp-level-vote
209-
"""
210-
211-
_description_ = "<vote_sync()>"
212-
213-
214203
class match_any_sync(Stub):
215204
"""
216205
match_any_sync(mask, value)

0 commit comments

Comments
 (0)