Skip to content

Commit 9df5748

Browse files
Add printing support for Dim3s. (#14)
--------- Co-authored-by: Graham Markall <[email protected]>
1 parent 173e000 commit 9df5748

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

numba_cuda/numba/cuda/printimpl.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from numba.core.errors import NumbaWarning
55
from numba.core.imputils import Registry
66
from numba.cuda import nvvmutils
7+
from numba.cuda.types import Dim3
78
from warnings import warn
89

910
registry = Registry()
@@ -53,6 +54,15 @@ def const_print_impl(ty, context, builder, sigval):
5354
return rawfmt, [val]
5455

5556

57+
@print_item.register(Dim3)
58+
def dim3_print_impl(ty, context, builder, val):
59+
rawfmt = "(%d, %d, %d)"
60+
x = builder.extract_value(val, 0)
61+
y = builder.extract_value(val, 1)
62+
z = builder.extract_value(val, 2)
63+
return rawfmt, [x, y, z]
64+
65+
5666
@lower(print, types.VarArg(types.Any))
5767
def print_varargs(context, builder, sig, args):
5868
"""This function is a generic 'print' wrapper for arbitrary types.

numba_cuda/numba/cuda/tests/cudapy/test_print.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from numba.cuda.testing import CUDATestCase, skip_on_cudasim
2+
import numpy as np
23
import subprocess
34
import sys
45
import unittest
@@ -43,6 +44,19 @@ def printstring():
4344
cuda.synchronize()
4445
"""
4546

47+
48+
printdim3_usecase = """\
49+
from numba import cuda
50+
51+
@cuda.jit
52+
def printdim3():
53+
print(cuda.threadIdx)
54+
55+
printdim3[1, (2, 2, 2)]()
56+
cuda.synchronize()
57+
"""
58+
59+
4660
printempty_usecase = """\
4761
from numba import cuda
4862
@@ -105,6 +119,12 @@ def test_string(self):
105119
expected = ['%d hop! 999' % i for i in range(3)]
106120
self.assertEqual(sorted(lines), expected)
107121

122+
def test_dim3(self):
123+
output, _ = self.run_code(printdim3_usecase)
124+
lines = [line.strip() for line in output.splitlines(True)]
125+
expected = [str(i) for i in np.ndindex(2, 2, 2)]
126+
self.assertEqual(sorted(lines), expected)
127+
108128
@skip_on_cudasim('cudasim can print unlimited output')
109129
def test_too_many_args(self):
110130
# Tests that we emit the format string and warn when there are more

0 commit comments

Comments
 (0)