Skip to content

Commit c2e1b8e

Browse files
committed
Add support for printing bools
1 parent 5e03f15 commit c2e1b8e

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

numba_cuda/numba/cuda/printimpl.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,17 @@ def dim3_print_impl(ty, context, builder, val):
6363
return rawfmt, [x, y, z]
6464

6565

66+
@print_item.register(types.Boolean)
67+
def bool_print_impl(ty, context, builder, val):
68+
true_string = context.insert_string_const_addrspace(builder, "True")
69+
false_string = context.insert_string_const_addrspace(builder, "False")
70+
res_ptr = cgutils.alloca_once_value(builder, false_string)
71+
with builder.if_then(val):
72+
builder.store(true_string, res_ptr)
73+
rawfmt = "%s"
74+
return rawfmt, [builder.load(res_ptr)]
75+
76+
6677
@lower(print, types.VarArg(types.Any))
6778
def print_varargs(context, builder, sig, args):
6879
"""This function is a generic 'print' wrapper for arbitrary types.

numba_cuda/numba/cuda/tests/cudapy/test_print.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,21 @@ def printfloat():
3232
"""
3333

3434

35+
printbool_usecase = """\
36+
from numba import cuda
37+
38+
@cuda.jit
39+
def printbool(x):
40+
print(True)
41+
print(False)
42+
print(x == 0)
43+
44+
printbool[1, 1](0)
45+
printbool[1, 1](1)
46+
cuda.synchronize()
47+
"""
48+
49+
3550
printstring_usecase = """\
3651
from numba import cuda
3752
@@ -109,6 +124,11 @@ def test_printfloat(self):
109124
expected_cases = ["0 23 34.750000 321", "0 23 34.75 321"]
110125
self.assertIn(output.strip(), expected_cases)
111126

127+
def test_bool(self):
128+
output, _ = self.run_code(printbool_usecase)
129+
expected = "True\nFalse\nTrue\nTrue\nFalse\nFalse"
130+
self.assertEqual(output.strip(), expected)
131+
112132
def test_printempty(self):
113133
output, _ = self.run_code(printempty_usecase)
114134
self.assertEqual(output.strip(), "")

0 commit comments

Comments
 (0)