Skip to content

Commit 5aa6b2e

Browse files
authored
Fix registration with Numba, vendor MakeFunctionToJITFunction tests (#566)
PR #555 removed dependencies on the target extension, but it also removed registration in Numba's target registry. This prevents handling of closures, because closure handling uses a process that is equivalent to wrapping closure functions in a `@jit` decorator, where the appropriate `@jit` decorator for the target is looked up in the jit registry.

This commit restores the registration when Numba is present, and vendors / ports `test_make_function_to_jit_function`. This ensures that the target registration continues to work, and also exercises the `MakeFunctionToJITFunction` pass, which previously had no effect on any code in the test suite.

In order to preserve the form of the original tests, the `@njit` decorators in the original test code are replaced with a wrapper function that generates an appropriate kernel, temporary storage, and launch code. Some changes had to be made to avoid refcounting where arrays would have been used. Depending on the case, these have been replaced by:

- Summation of all the values that would have been array elements, or
- Carefully constructing tuples to return and then packing / unpacking them at function boundaries.
1 parent eeac024 commit 5aa6b2e

File tree

2 files changed

+381
-0
lines changed

2 files changed

+381
-0
lines changed

numba_cuda/numba/cuda/initialize.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,20 @@
55
def initialize_all():
    """Register the CUDA data models and, if Numba is installed, the target.

    The data models are always registered by importing ``numba.cuda.models``.
    When ``HAS_NUMBA`` is true, the CUDA ``jit`` decorator and
    ``CUDADispatcher`` are additionally stored in Numba's ``jit_registry``
    and ``dispatcher_registry`` under the ``"cuda"`` entry of
    ``target_registry``. When ``HAS_NUMBA`` is false the function returns
    right after the model import.
    """
    # Imported only for its side effect of registering the models with the
    # data model manager.
    import numba.cuda.models  # noqa: F401

    from numba.cuda import HAS_NUMBA

    if not HAS_NUMBA:
        # Without Numba there are no registries to populate.
        return

    from numba.cuda.decorators import jit
    from numba.cuda.dispatcher import CUDADispatcher
    from numba.core import target_extension

    target = target_extension.target_registry["cuda"]
    target_extension.jit_registry[target] = jit
    target_extension.dispatcher_registry[target] = CUDADispatcher
Lines changed: 364 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,364 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: BSD-2-Clause
3+
4+
from numba import cuda
5+
from numba.cuda.core import errors
6+
from numba.cuda.extending import overload
7+
from numba.cuda.testing import skip_on_cudasim
8+
import numpy as np
9+
10+
import unittest
11+
12+
13+
@cuda.jit
def consumer(func, *args):
    """Call ``func`` with ``args`` and return whatever it returns."""
    result = func(*args)
    return result
16+
17+
18+
@cuda.jit
def consumer2arg(func1, func2):
    """Pass ``func1`` as the sole argument to ``func2`` and return the result."""
    result = func2(func1)
    return result
21+
22+
23+
def wrap_with_kernel_noarg(func):
    """Return a host callable that runs zero-argument ``func`` in a kernel.

    ``func`` is compiled with ``cuda.jit`` and called from a kernel launched
    with the configuration ``[1, 1]``. The kernel stores the value in element
    0 of a one-element ``int64`` array, which the returned callable hands
    back.
    """
    device_func = cuda.jit(func)

    @cuda.jit
    def kernel(result):
        result[0] = device_func()

    def runner():
        # Temporary storage used to carry the value out of the kernel.
        host_buf = np.zeros(1, dtype=np.int64)
        kernel[1, 1](host_buf)
        return host_buf[0]

    return runner
36+
37+
38+
def wrap_with_kernel_one_arg(func):
    """Return a host callable that runs one-argument ``func`` in a kernel.

    ``func`` is compiled with ``cuda.jit`` and called from a kernel launched
    with the configuration ``[1, 1]``. The returned callable forwards its
    single argument to the kernel and hands back the value the kernel stored
    in element 0 of a one-element ``int64`` array.
    """
    device_func = cuda.jit(func)

    @cuda.jit
    def kernel(result, arg):
        result[0] = device_func(arg)

    def runner(in1):
        # Temporary storage used to carry the value out of the kernel.
        host_buf = np.zeros(1, dtype=np.int64)
        kernel[1, 1](host_buf, in1)
        return host_buf[0]

    return runner
51+
52+
53+
def wrap_with_kernel_two_args(func):
    """Return a host callable that runs two-argument ``func`` in a kernel.

    ``func`` is compiled with ``cuda.jit`` and called from a kernel launched
    with the configuration ``[1, 1]``. The returned callable forwards both of
    its arguments to the kernel and hands back the value the kernel stored in
    element 0 of a one-element ``int64`` array.
    """
    device_func = cuda.jit(func)

    @cuda.jit
    def kernel(result, first, second):
        result[0] = device_func(first, second)

    def runner(in1, in2):
        # Temporary storage used to carry the value out of the kernel.
        host_buf = np.zeros(1, dtype=np.int64)
        kernel[1, 1](host_buf, in1, in2)
        return host_buf[0]

    return runner
66+
67+
68+
def wrap_with_kernel_noarg_tuple_return(func):
    """Return a host callable that runs ``func`` and yields its 4-tuple.

    ``func`` takes no arguments and its result is unpacked into four values
    inside a kernel launched with the configuration ``[1, 1]``. The four
    values are written to a four-element ``int64`` array and returned from
    the host callable as a tuple.
    """
    device_func = cuda.jit(func)

    @cuda.jit
    def kernel(result):
        # Unpack on the device; each element is stored individually.
        result[0], result[1], result[2], result[3] = device_func()

    def runner():
        # Temporary storage with one slot per returned element.
        host_buf = np.zeros(4, dtype=np.int64)
        kernel[1, 1](host_buf)
        return tuple(host_buf)

    return runner
81+
82+
83+
# Module-level value read by the `callinner` closures in the tests below, so
# those closures reference a global in addition to enclosing-scope variables.
_global = 123
84+
85+
86+
class TestMakeFunctionToJITFunction(unittest.TestCase):
    """
    This tests the pass that converts ir.Expr.op == make_function (i.e. closure)
    into a JIT function.
    """

    # NOTE: testing this is a bit tricky. The function receiving a JIT'd closure
    # must also be under JIT control so as to handle the JIT'd closure
    # correctly, however, in the case of running the test implementations in the
    # interpreter, the receiving function cannot be JIT'd else it will receive
    # the Python closure and then complain about pyobjects as arguments.
    # The way around this is to use a factory function to close over either the
    # jitted or standard python function as the consumer depending on context.

    # A closure returning a constant is handed to the consumer; the value
    # computed through the kernel wrapper must equal the pure-Python value.
    def test_escape(self):
        def impl_factory(consumer_func):
            def impl():
                def inner():
                    return 10

                return consumer_func(inner)

            return impl

        cfunc = wrap_with_kernel_noarg(impl_factory(consumer))
        impl = impl_factory(consumer.py_func)

        self.assertEqual(impl(), cfunc())

    # Two closures escape at once: `innerinner` receives `inner` as its
    # argument (via `consumer2arg`) and calls it.
    def test_nested_escape(self):
        def impl_factory(consumer_func):
            def impl():
                def inner():
                    return 10

                def innerinner(x):
                    return x()

                return consumer_func(inner, innerinner)

            return impl

        cfunc = wrap_with_kernel_noarg(impl_factory(consumer2arg))
        impl = impl_factory(consumer2arg.py_func)

        self.assertEqual(impl(), cfunc())

    # The escaping closure `callinner` itself defines and calls a nested
    # closure.
    def test_closure_in_escaper(self):
        def impl_factory(consumer_func):
            def impl():
                def callinner():
                    def inner():
                        return 10

                    return inner()

                return consumer_func(callinner)

            return impl

        cfunc = wrap_with_kernel_noarg(impl_factory(consumer))
        impl = impl_factory(consumer.py_func)

        self.assertEqual(impl(), cfunc())

    # The escaping closure reads the enclosing-scope constant `y` and the
    # module-level `_global`; its argument is a literal.
    def test_close_over_consts(self):
        def impl_factory(consumer_func):
            def impl():
                y = 10

                def callinner(z):
                    return y + z + _global

                return consumer_func(callinner, 6)

            return impl

        cfunc = wrap_with_kernel_noarg(impl_factory(consumer))
        impl = impl_factory(consumer.py_func)

        self.assertEqual(impl(), cfunc())

    # Same as test_close_over_consts, except the closure's argument is the
    # argument `x` of `impl` rather than a literal.
    def test_close_over_consts_w_args(self):
        def impl_factory(consumer_func):
            def impl(x):
                y = 10

                def callinner(z):
                    return y + z + _global

                return consumer_func(callinner, x)

            return impl

        cfunc = wrap_with_kernel_one_arg(impl_factory(consumer))
        impl = impl_factory(consumer.py_func)

        a = 5
        self.assertEqual(impl(a), cfunc(a))

    # The closure escapes into `foo`, which has a pure-Python body and an
    # `@overload` implementation selected by the number of extra arguments.
    def test_with_overload(self):
        def foo(func, *args):
            nargs = len(args)
            if nargs == 1:
                return func(*args)
            elif nargs == 2:
                return func(func(*args))

        @overload(foo)
        def foo_ol(func, *args):
            # specialise on the number of args, as per `foo`
            nargs = len(args)
            if nargs == 1:

                def impl(func, *args):
                    return func(*args)

                return impl
            elif nargs == 2:

                def impl(func, *args):
                    return func(func(*args))

                return impl

        # NOTE: `consumer_func` is accepted for symmetry with the other tests
        # but is not used by `impl`, which calls `foo` directly.
        def impl_factory(consumer_func):
            def impl(x):
                y = 10

                def callinner(*z):
                    if len(z) == 1:
                        tmp = z[0]
                    elif len(z) == 2:
                        tmp = z[0] + z[1]
                    return y + tmp + _global

                # run both specialisations, 1 arg, and 2 arg.
                return foo(callinner, x) + foo(callinner, x, x)

            return impl

        cfunc = wrap_with_kernel_one_arg(impl_factory(consumer))
        impl = impl_factory(consumer.py_func)

        a = 5
        self.assertEqual(impl(a), cfunc(a))

    # A closure is passed as the second argument of an overloaded
    # `apply(arg, func)`. No factory is needed because `impl` does not call a
    # jitted consumer: it is run as-is in the interpreter and via the kernel
    # wrapper.
    def test_basic_apply_like_case(self):
        def apply(arg, func):
            return func(arg)

        @overload(apply)
        def ov_apply(arg, func):
            return lambda arg, func: func(arg)

        def impl(arg):
            def mul10(x):
                return x * 10

            return apply(arg, mul10)

        cfunc = wrap_with_kernel_one_arg(impl)

        a = 10
        np.testing.assert_allclose(impl(a), cfunc(a))

    # this needs true SSA to be able to work correctly, check error for now
    # The free variable `x` is assigned in both branches before being
    # captured, and compilation is expected to fail with a TypingError.
    @skip_on_cudasim("Simulator will not raise a typing error")
    def test_multiply_defined_freevar(self):
        def impl(c):
            if c:
                x = 3

                def inner(y):
                    return y + x

                r = consumer(inner, 1)
            else:
                x = 6

                def inner(y):
                    return y + x

                r = consumer(inner, 2)
            return r

        # An explicit signature is passed to cuda.jit so that the typing
        # error surfaces here rather than at a later call.
        with self.assertRaises(errors.TypingError) as e:
            cuda.jit("void(int64)")(impl)

        self.assertIn(
            "Cannot capture a constant value for variable", str(e.exception)
        )

    # The closure captures `z`, the result of `np.arange(x)`, which is not a
    # constant; compilation is expected to fail with a TypingError.
    @skip_on_cudasim("Simulator will not raise a typing error")
    def test_non_const_in_escapee(self):
        def impl(x):
            z = np.arange(x)

            def inner(val):
                return 1 + z + val  # z is non-const freevar

            return consumer(inner, x)

        with self.assertRaises(errors.TypingError) as e:
            cuda.jit("void(int64)")(impl)

        self.assertIn(
            "Cannot capture the non-constant value associated", str(e.exception)
        )

    # The closure has two defaulted parameters. It is called directly and
    # through the consumer, with and without the defaults supplied, and the
    # four results are returned together as a tuple.
    def test_escape_with_kwargs(self):
        def impl_factory(consumer_func):
            def impl():
                t = 12

                def inner(a, b, c, mydefault1=123, mydefault2=456):
                    z = 4
                    return mydefault1 + mydefault2 + z + t + a + b + c

                # this is awkward, top and tail closure inlining with
                # escapees in the middle that do/don't have defaults.
                return (
                    inner(1, 2, 5, 91, 53),
                    consumer_func(inner, 1, 2, 3, 73),
                    consumer_func(
                        inner,
                        1,
                        2,
                        3,
                    ),
                    inner(1, 2, 4),
                )

            return impl

        # The tuple-return wrapper carries all four values out of the kernel.
        cfunc = wrap_with_kernel_noarg_tuple_return(impl_factory(consumer))
        impl = impl_factory(consumer.py_func)

        np.testing.assert_allclose(impl(), cfunc())

    # As test_escape_with_kwargs, but the consumer itself overrides the
    # closure's defaulted parameters by keyword.
    def test_escape_with_kwargs_override_kwargs(self):
        # Calls `func` three times, overriding `mydefault1`, `mydefault2`,
        # and then both, and returns the sum of the three results.
        @cuda.jit
        def specialised_consumer(func, *args):
            x, y, z = args  # unpack to avoid `CALL_FUNCTION_EX`
            a = func(x, y, z, mydefault1=1000)
            b = func(x, y, z, mydefault2=1000)
            c = func(x, y, z, mydefault1=1000, mydefault2=1000)
            return a + b + c

        def impl_factory(consumer_func):
            def impl():
                t = 12

                def inner(a, b, c, mydefault1=123, mydefault2=456):
                    z = 4
                    return mydefault1 + mydefault2 + z + t + a + b + c

                # this is awkward, top and tail closure inlining with
                # escapees in the middle that get defaults specified in the
                # consumer
                return (
                    inner(1, 2, 5, 91, 53),
                    consumer_func(inner, 1, 2, 11),
                    consumer_func(
                        inner,
                        1,
                        2,
                        3,
                    ),
                    inner(1, 2, 4),
                )

            return impl

        cfunc = wrap_with_kernel_noarg_tuple_return(
            impl_factory(specialised_consumer)
        )
        impl = impl_factory(specialised_consumer.py_func)

        np.testing.assert_allclose(impl(), cfunc())

0 commit comments

Comments
 (0)