Skip to content

Commit 849cc33

Browse files
brycelelbach authored and gmarkall committed
Add device-side support for int.bit_count (which just lowers to cuda.popc).
Expand tests for cuda.popc to include smaller integer types.
1 parent 9ed01c5 commit 849cc33

File tree

2 files changed

+50
-5
lines changed

2 files changed

+50
-5
lines changed

numba_cuda/numba/cuda/intrinsics.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from numba.core import cgutils
55
from numba.core.errors import RequireLiteralValue
66
from numba.core.typing import signature
7-
from numba.core.extending import overload_attribute
7+
from numba.core.extending import overload_attribute, overload_method
88
from numba.cuda import nvvmutils
99
from numba.cuda.extending import intrinsic
1010

@@ -196,3 +196,8 @@ def syncthreads_or(typingctx, predicate):
196196
'''
197197
fname = 'llvm.nvvm.barrier0.or'
198198
return _syncthreads_predicate(typingctx, predicate, fname)
199+
200+
201+
@overload_method(types.Integer, 'bit_count', target='cuda')
def integer_bit_count(i):
    """Provide ``int.bit_count()`` for integer values on the CUDA target.

    The result is computed with ``cuda.popc``. Python defines
    ``int.bit_count`` as the number of ones in the binary representation
    of the *absolute value*, whereas a raw population count of a negative
    two's-complement value also counts the sign-extension bits. Signed
    operands are therefore passed through ``abs`` first; unsigned operands
    are counted directly.

    Parameters
    ----------
    i : types.Integer
        The Numba type of the receiver (at typing time).

    Returns
    -------
    callable
        The implementation to compile for the given integer type.
    """
    if i.signed:
        # NOTE(review): for the most negative value abs() is expected to
        # wrap back to the same bit pattern, whose population count is 1 --
        # the same answer Python gives. Confirm on the target.
        return lambda i: cuda.popc(abs(i))
    return lambda i: cuda.popc(i)

numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ def simple_popc(ary, c):
6868
ary[0] = cuda.popc(c)
6969

7070

71+
def simple_bit_count(ary, c):
72+
ary[0] = c.bit_count()
73+
74+
7175
def simple_fma(ary, a, b, c):
    """Store the result of ``cuda.fma(a, b, c)`` in the first element of ``ary``."""
    fused = cuda.fma(a, b, c)
    ary[0] = fused
7377

@@ -550,17 +554,53 @@ def foo(out):
550554

551555
self.assertTrue(np.all(arr))
552556

557+
def test_popc_u1(self):
    """Check ``cuda.popc`` on a ``uint8`` operand with every bit set."""
    compiled = cuda.jit("void(int32[:], uint8)")(simple_popc)
    # The output buffer dtype must match the kernel's ``int32[:]``
    # parameter; a narrower buffer would be written past its end.
    ary = np.zeros(1, dtype=np.int32)
    compiled[1, 1](ary, np.uint8(0xFF))
    self.assertEqual(ary[0], 8)
562+
563+
def test_popc_u2(self):
    """Check ``cuda.popc`` on a ``uint16`` operand with every bit set."""
    compiled = cuda.jit("void(int32[:], uint16)")(simple_popc)
    # The output buffer dtype must match the kernel's ``int32[:]``
    # parameter; a narrower buffer would be written past its end.
    ary = np.zeros(1, dtype=np.int32)
    compiled[1, 1](ary, np.uint16(0xFFFF))
    self.assertEqual(ary[0], 16)
568+
553569
def test_popc_u4(self):
    """Check ``cuda.popc`` on ``uint32`` for a sparse and a full bit pattern."""
    compiled = cuda.jit("void(int32[:], uint32)")(simple_popc)
    ary = np.zeros(1, dtype=np.int32)

    # Sparse pattern: an all-ones input alone cannot distinguish a real
    # population count from one that just reports the operand width.
    compiled[1, 1](ary, np.uint32(0xF0))
    self.assertEqual(ary[0], 4)

    # Full pattern: every bit of the 32-bit operand is set.
    compiled[1, 1](ary, np.uint32(0xFFFFFFFF))
    self.assertEqual(ary[0], 32)
558574

559575
def test_popc_u8(self):
    """Check ``cuda.popc`` on ``uint64`` for a sparse and a full bit pattern."""
    compiled = cuda.jit("void(int32[:], uint64)")(simple_popc)
    ary = np.zeros(1, dtype=np.int32)

    # Sparse pattern with the set bits above bit 31, so the upper half of
    # the 64-bit operand is exercised as well.
    compiled[1, 1](ary, np.uint64(0xF00000000000))
    self.assertEqual(ary[0], 4)

    # Full pattern: every bit of the 64-bit operand is set.
    compiled[1, 1](ary, np.uint64(0xFFFFFFFFFFFFFFFF))
    self.assertEqual(ary[0], 64)
580+
581+
def test_bit_count_u1(self):
    """Check ``bit_count()`` on a ``uint8`` operand with every bit set."""
    compiled = cuda.jit("void(int32[:], uint8)")(simple_bit_count)
    # The output buffer dtype must match the kernel's ``int32[:]``
    # parameter; a narrower buffer would be written past its end.
    ary = np.zeros(1, dtype=np.int32)
    compiled[1, 1](ary, np.uint8(0xFF))
    self.assertEqual(ary[0], 8)
586+
587+
def test_bit_count_u2(self):
    """Check ``bit_count()`` on a ``uint16`` operand with every bit set."""
    compiled = cuda.jit("void(int32[:], uint16)")(simple_bit_count)
    # The output buffer dtype must match the kernel's ``int32[:]``
    # parameter; a narrower buffer would be written past its end.
    ary = np.zeros(1, dtype=np.int32)
    compiled[1, 1](ary, np.uint16(0xFFFF))
    self.assertEqual(ary[0], 16)
592+
593+
def test_bit_count_u4(self):
    """Check ``bit_count()`` on a ``uint32`` operand with every bit set."""
    kernel = cuda.jit("void(int32[:], uint32)")(simple_bit_count)
    out = np.zeros(1, dtype=np.int32)
    all_ones = np.uint32(0xFFFFFFFF)
    kernel[1, 1](out, all_ones)
    self.assertEqual(out[0], 32)
598+
599+
def test_bit_count_u8(self):
    """Check ``bit_count()`` on a ``uint64`` operand with every bit set."""
    kernel = cuda.jit("void(int32[:], uint64)")(simple_bit_count)
    out = np.zeros(1, dtype=np.int32)
    all_ones = np.uint64(0xFFFFFFFFFFFFFFFF)
    kernel[1, 1](out, all_ones)
    self.assertEqual(out[0], 64)
564604

565605
def test_fma_f4(self):
566606
compiled = cuda.jit("void(f4[:], f4, f4, f4)")(simple_fma)

0 commit comments

Comments
 (0)