|
6 | 6 | from numba import cuda |
7 | 7 | from numba.cuda import types |
8 | 8 | from numba.cuda import cgutils |
9 | | -from numba.cuda.core.errors import RequireLiteralValue, TypingError |
| 9 | +from numba.cuda.core.errors import ( |
| 10 | + RequireLiteralValue, |
| 11 | + TypingError, |
| 12 | + NumbaTypeError, |
| 13 | +) |
10 | 14 | from numba.cuda.typing import signature |
11 | 15 | from numba.cuda.extending import overload_attribute, overload_method |
12 | 16 | from numba.cuda import nvvmutils |
@@ -380,3 +384,148 @@ def codegen(context, builder, sig, args): |
380 | 384 | sig = signature(a_type, membermask_type, a_type, b_type) |
381 | 385 |
|
382 | 386 | return sig, codegen |
| 387 | + |
| 388 | + |
| 389 | +# ------------------------------------------------------------------------------- |
| 390 | +# Warp vote functions |
| 391 | +# |
| 392 | +# References: |
| 393 | +# |
| 394 | +# - https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-vote-functions |
| 395 | +# - https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html?highlight=data%2520movement#vote |
| 396 | +# |
| 397 | +# Notes: |
| 398 | +# |
| 399 | +# - The NVVM IR specification requires some of the mode parameter to be |
| 400 | +# constants. It's therefore essential that we pass in mode values to the |
| 401 | +# vote_sync_intrinsic. |
| 402 | + |
| 403 | + |
| 404 | +@intrinsic |
| 405 | +def all_sync(typingctx, mask_type, predicate_type): |
| 406 | + """ |
| 407 | + If for all threads in the masked warp the predicate is true, then |
| 408 | + a non-zero value is returned, otherwise 0 is returned. |
| 409 | + """ |
| 410 | + mode_value = 0 |
| 411 | + sig, codegen_inner = vote_sync_intrinsic( |
| 412 | + typingctx, mask_type, mode_value, predicate_type |
| 413 | + ) |
| 414 | + |
| 415 | + def codegen(context, builder, sig_outer, args): |
| 416 | + # Call vote_sync_intrinsic and extract the boolean result (index 1) |
| 417 | + result_tuple = codegen_inner(context, builder, sig, args) |
| 418 | + return builder.extract_value(result_tuple, 1) |
| 419 | + |
| 420 | + sig_outer = signature(types.b1, mask_type, predicate_type) |
| 421 | + return sig_outer, codegen |
| 422 | + |
| 423 | + |
| 424 | +@intrinsic |
| 425 | +def any_sync(typingctx, mask_type, predicate_type): |
| 426 | + """ |
| 427 | + If for any thread in the masked warp the predicate is true, then |
| 428 | + a non-zero value is returned, otherwise 0 is returned. |
| 429 | + """ |
| 430 | + mode_value = 1 |
| 431 | + sig, codegen_inner = vote_sync_intrinsic( |
| 432 | + typingctx, mask_type, mode_value, predicate_type |
| 433 | + ) |
| 434 | + |
| 435 | + def codegen(context, builder, sig_outer, args): |
| 436 | + result_tuple = codegen_inner(context, builder, sig, args) |
| 437 | + return builder.extract_value(result_tuple, 1) |
| 438 | + |
| 439 | + sig_outer = signature(types.b1, mask_type, predicate_type) |
| 440 | + return sig_outer, codegen |
| 441 | + |
| 442 | + |
| 443 | +@intrinsic |
| 444 | +def eq_sync(typingctx, mask_type, predicate_type): |
| 445 | + """ |
| 446 | + If for all threads in the masked warp the boolean predicate is the same, |
| 447 | + then a non-zero value is returned, otherwise 0 is returned. |
| 448 | + """ |
| 449 | + mode_value = 2 |
| 450 | + sig, codegen_inner = vote_sync_intrinsic( |
| 451 | + typingctx, mask_type, mode_value, predicate_type |
| 452 | + ) |
| 453 | + |
| 454 | + def codegen(context, builder, sig_outer, args): |
| 455 | + result_tuple = codegen_inner(context, builder, sig, args) |
| 456 | + return builder.extract_value(result_tuple, 1) |
| 457 | + |
| 458 | + sig_outer = signature(types.b1, mask_type, predicate_type) |
| 459 | + return sig_outer, codegen |
| 460 | + |
| 461 | + |
| 462 | +@intrinsic |
| 463 | +def ballot_sync(typingctx, mask_type, predicate_type): |
| 464 | + """ |
| 465 | + Returns a mask of all threads in the warp whose predicate is true, |
| 466 | + and are within the given mask. |
| 467 | + """ |
| 468 | + mode_value = 3 |
| 469 | + sig, codegen_inner = vote_sync_intrinsic( |
| 470 | + typingctx, mask_type, mode_value, predicate_type |
| 471 | + ) |
| 472 | + |
| 473 | + def codegen(context, builder, sig_outer, args): |
| 474 | + result_tuple = codegen_inner(context, builder, sig, args) |
| 475 | + return builder.extract_value( |
| 476 | + result_tuple, 0 |
| 477 | + ) # Extract ballot result (index 0) |
| 478 | + |
| 479 | + sig_outer = signature(types.i4, mask_type, predicate_type) |
| 480 | + return sig_outer, codegen |
| 481 | + |
| 482 | + |
| 483 | +def vote_sync_intrinsic(typingctx, mask_type, mode_value, predicate_type): |
| 484 | + # Validate mode value |
| 485 | + if mode_value not in (0, 1, 2, 3): |
| 486 | + raise ValueError("Mode must be 0 (all), 1 (any), 2 (eq), or 3 (ballot)") |
| 487 | + |
| 488 | + if types.unliteral(mask_type) not in types.integer_domain: |
| 489 | + raise NumbaTypeError(f"Mask type must be an integer. Got {mask_type}") |
| 490 | + predicate_types = types.integer_domain | {types.boolean} |
| 491 | + |
| 492 | + if types.unliteral(predicate_type) not in predicate_types: |
| 493 | + raise NumbaTypeError( |
| 494 | + f"Predicate must be an integer or boolean. Got {predicate_type}" |
| 495 | + ) |
| 496 | + |
| 497 | + def codegen(context, builder, sig, args): |
| 498 | + mask, predicate = args |
| 499 | + |
| 500 | + # Types |
| 501 | + i1 = ir.IntType(1) |
| 502 | + i32 = ir.IntType(32) |
| 503 | + |
| 504 | + # NVVM intrinsic definition |
| 505 | + arg_types = (i32, i32, i1) |
| 506 | + vote_return_type = ir.LiteralStructType((i32, i1)) |
| 507 | + fnty = ir.FunctionType(vote_return_type, arg_types) |
| 508 | + |
| 509 | + fname = "llvm.nvvm.vote.sync" |
| 510 | + lmod = builder.module |
| 511 | + vote_sync = cgutils.get_or_insert_function(lmod, fnty, fname) |
| 512 | + |
| 513 | + # Intrinsic arguments |
| 514 | + mode = ir.Constant(i32, mode_value) |
| 515 | + mask_i32 = builder.trunc(mask, i32) |
| 516 | + |
| 517 | + # Convert predicate to i1 |
| 518 | + if predicate.type != ir.IntType(1): |
| 519 | + predicate_bool = builder.icmp_signed( |
| 520 | + "!=", predicate, ir.Constant(predicate.type, 0) |
| 521 | + ) |
| 522 | + else: |
| 523 | + predicate_bool = predicate |
| 524 | + |
| 525 | + return builder.call(vote_sync, [mask_i32, mode, predicate_bool]) |
| 526 | + |
| 527 | + sig = signature( |
| 528 | + types.Tuple((types.i4, types.b1)), mask_type, predicate_type |
| 529 | + ) |
| 530 | + |
| 531 | + return sig, codegen |
0 commit comments