Skip to content

Commit 1996fac

Browse files
committed
Fix registration with Numba, vendor MakeFunctionToJITFunction tests
PR #555 removed dependencies on the target extension, but it also removed registration in Numba's target registry. This prevents handling of closures, because closure handling uses a process that is equivalent to wrapping closure functions in a `@jit` decorator, where the appropriate `@jit` decorator for the target is looked up in the jit registry. This commit restores the registration when Numba is present, and vendors / ports `test_make_function_to_jit_function`. This ensures that the target registration continues to work, and also exercises the `MakeFunctionToJITFunction` pass, which previously had no effect on any code in the test suite. In order to preserve the form of the original tests, the `@njit` decorators in the original test code have been replaced with a wrapper function that generates an appropriate kernel, temporary storage, and launch code. Some changes had to be made to avoid refcounting in places where arrays would have been used. Depending on the case, the arrays have been replaced by either (a) a summation of all the values that would have been array elements, or (b) carefully constructed tuples that are returned and then packed / unpacked at function boundaries.
1 parent 5aeb63c commit 1996fac

File tree

2 files changed

+375
-0
lines changed

2 files changed

+375
-0
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,21 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
3+
from importlib.util import find_spec
34

45

56
def initialize_all() -> None:
    """Perform the registrations needed to set up the CUDA target.

    Two things happen here:

    * ``numba.cuda.models`` is imported purely for its side effect of
      registering the models with the data model manager.
    * When a ``numba`` package can be located, the CUDA ``jit`` decorator and
      ``CUDADispatcher`` are stored in Numba's ``jit_registry`` and
      ``dispatcher_registry`` under the ``"cuda"`` entry of
      ``target_registry``. Closure handling looks up the ``@jit`` decorator
      for the target in the jit registry, so closures cannot be handled
      without this registration.
    """
    # Import models to register them with the data model manager
    import numba.cuda.models  # noqa: F401

    # Only touch Numba's registries when Numba can be found.
    # NOTE(review): find_spec("numba") only checks that a top-level ``numba``
    # package is locatable - confirm that this is sufficient to guarantee
    # that ``numba.core.target_extension`` is importable.
    if find_spec("numba"):
        # Imported here, inside the guard, rather than at module level so
        # they are only attempted when the check above succeeds.
        from numba.cuda.decorators import jit
        from numba.cuda.dispatcher import CUDADispatcher
        from numba.core.target_extension import (
            target_registry,
            dispatcher_registry,
            jit_registry,
        )

        # The "cuda" entry of the target registry is the key under which the
        # jit decorator and the dispatcher class are registered.
        cuda_target = target_registry["cuda"]
        jit_registry[cuda_target] = jit
        dispatcher_registry[cuda_target] = CUDADispatcher
Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: BSD-2-Clause
3+
4+
from numba import cuda
5+
from numba.core import errors
6+
from numba.cuda.extending import overload
7+
import numpy as np
8+
9+
import unittest
10+
11+
12+
@cuda.jit
def consumer(func, *args):
    """Call ``func`` with ``args`` and return the result.

    The tests pass closures to this function. The undecorated Python
    implementation is reachable as ``consumer.py_func`` and is used to
    compute the expected values in the interpreter.
    """
    return func(*args)
15+
16+
17+
@cuda.jit
def consumer2arg(func1, func2):
    """Call ``func2`` with ``func1`` as its only argument and return the result.

    Lets a test hand one closure to another closure. The undecorated Python
    implementation is reachable as ``consumer2arg.py_func``.
    """
    return func2(func1)
20+
21+
22+
def wrap_with_kernel_noarg(func):
    """Build a host-side callable that evaluates ``func()`` inside a kernel.

    ``func`` is compiled with ``cuda.jit`` and called from a kernel that
    stores the returned value in element 0 of an output array. The callable
    returned from here allocates that array, launches the kernel with a
    ``[1, 1]`` configuration and returns the stored ``int64`` value.
    """
    device_func = cuda.jit(func)

    @cuda.jit
    def kernel(result):
        result[0] = device_func()

    def runner():
        # Single-element buffer receiving the value produced on the device.
        result = np.zeros(1, dtype=np.int64)
        kernel[1, 1](result)
        return result[0]

    return runner
35+
36+
37+
def wrap_with_kernel_one_arg(func):
    """Build a host-side callable that evaluates ``func(in1)`` inside a kernel.

    ``func`` is compiled with ``cuda.jit`` and called from a kernel that
    forwards its single input and stores the returned value in element 0 of
    an output array. The callable returned from here allocates that array,
    launches the kernel with a ``[1, 1]`` configuration and returns the
    stored ``int64`` value.
    """
    device_func = cuda.jit(func)

    @cuda.jit
    def kernel(result, arg):
        result[0] = device_func(arg)

    def runner(in1):
        # Single-element buffer receiving the value produced on the device.
        result = np.zeros(1, dtype=np.int64)
        kernel[1, 1](result, in1)
        return result[0]

    return runner
50+
51+
52+
def wrap_with_kernel_two_args(func):
    """Build a host-side callable that evaluates ``func(in1, in2)`` in a kernel.

    ``func`` is compiled with ``cuda.jit`` and called from a kernel that
    forwards its two inputs and stores the returned value in element 0 of an
    output array. The callable returned from here allocates that array,
    launches the kernel with a ``[1, 1]`` configuration and returns the
    stored ``int64`` value.
    """
    device_func = cuda.jit(func)

    @cuda.jit
    def kernel(result, first, second):
        result[0] = device_func(first, second)

    def runner(in1, in2):
        # Single-element buffer receiving the value produced on the device.
        result = np.zeros(1, dtype=np.int64)
        kernel[1, 1](result, in1, in2)
        return result[0]

    return runner
65+
66+
67+
def wrap_with_kernel_noarg_tuple_return(func):
    """Build a host-side callable for a zero-argument ``func`` returning 4 values.

    ``func`` is compiled with ``cuda.jit`` and called from a kernel that
    unpacks the four returned values into elements 0-3 of an output array.
    The callable returned from here allocates that array, launches the
    kernel with a ``[1, 1]`` configuration and returns the four stored
    ``int64`` values as a tuple.
    """
    device_func = cuda.jit(func)

    @cuda.jit
    def kernel(result):
        # Unpack on the device, then store each element individually.
        first, second, third, fourth = device_func()
        result[0] = first
        result[1] = second
        result[2] = third
        result[3] = fourth

    def runner():
        # Four-element buffer, one slot per value returned by ``func``.
        result = np.zeros(4, dtype=np.int64)
        kernel[1, 1](result)
        return tuple(result[i] for i in range(4))

    return runner
80+
81+
82+
# Module-level value read from inside the `callinner` closures in the tests
# below, in addition to the local variables they close over.
_global = 123
83+
84+
85+
class TestMakeFunctionToJITFunction(unittest.TestCase):
    """
    This tests the pass that converts ir.Expr.op == make_function (i.e. closure)
    into a JIT function.

    Most tests build the same implementation twice through a factory: once
    closing over a jitted consumer and run on the device through one of the
    ``wrap_with_kernel_*`` helpers (``cfunc``), and once closing over the
    consumer's ``py_func`` and run in the interpreter (``impl``). The two
    results are then compared.
    """

    # NOTE: testing this is a bit tricky. The function receiving a JIT'd closure
    # must also be under JIT control so as to handle the JIT'd closure
    # correctly, however, in the case of running the test implementations in the
    # interpreter, the receiving function cannot be JIT'd else it will receive
    # the Python closure and then complain about pyobjects as arguments.
    # The way around this is to use a factory function to close over either the
    # jitted or standard python function as the consumer depending on context.

    # A closure taking no arguments is passed to the consumer, which calls it.
    def test_escape(self):
        def impl_factory(consumer_func):
            def impl():
                def inner():
                    return 10

                return consumer_func(inner)

            return impl

        cfunc = wrap_with_kernel_noarg(impl_factory(consumer))
        impl = impl_factory(consumer.py_func)

        self.assertEqual(impl(), cfunc())

    # Two closures are passed to the consumer: the second one receives the
    # first as its argument and calls it.
    def test_nested_escape(self):
        def impl_factory(consumer_func):
            def impl():
                def inner():
                    return 10

                def innerinner(x):
                    return x()

                return consumer_func(inner, innerinner)

            return impl

        cfunc = wrap_with_kernel_noarg(impl_factory(consumer2arg))
        impl = impl_factory(consumer2arg.py_func)

        self.assertEqual(impl(), cfunc())

    # The closure passed to the consumer itself defines and calls another
    # closure.
    def test_closure_in_escaper(self):
        def impl_factory(consumer_func):
            def impl():
                def callinner():
                    def inner():
                        return 10

                    return inner()

                return consumer_func(callinner)

            return impl

        cfunc = wrap_with_kernel_noarg(impl_factory(consumer))
        impl = impl_factory(consumer.py_func)

        self.assertEqual(impl(), cfunc())

    # The closure reads a constant local (`y`) of the enclosing function and
    # the module-level `_global`; its argument is a literal.
    def test_close_over_consts(self):
        def impl_factory(consumer_func):
            def impl():
                y = 10

                def callinner(z):
                    return y + z + _global

                return consumer_func(callinner, 6)

            return impl

        cfunc = wrap_with_kernel_noarg(impl_factory(consumer))
        impl = impl_factory(consumer.py_func)

        self.assertEqual(impl(), cfunc())

    # As test_close_over_consts, but the value handed to the closure is an
    # argument of the enclosing function rather than a literal.
    def test_close_over_consts_w_args(self):
        def impl_factory(consumer_func):
            def impl(x):
                y = 10

                def callinner(z):
                    return y + z + _global

                return consumer_func(callinner, x)

            return impl

        cfunc = wrap_with_kernel_one_arg(impl_factory(consumer))
        impl = impl_factory(consumer.py_func)

        a = 5
        self.assertEqual(impl(a), cfunc(a))

    # The closure is passed to a function (`foo`) that is made available to
    # jitted code through `@overload`, with one implementation per number of
    # extra arguments.
    def test_with_overload(self):
        # Pure Python reference implementation of `foo`.
        def foo(func, *args):
            nargs = len(args)
            if nargs == 1:
                return func(*args)
            elif nargs == 2:
                return func(func(*args))

        @overload(foo)
        def foo_ol(func, *args):
            # specialise on the number of args, as per `foo`
            nargs = len(args)
            if nargs == 1:

                def impl(func, *args):
                    return func(*args)

                return impl
            elif nargs == 2:

                def impl(func, *args):
                    return func(func(*args))

                return impl

        # NOTE: `consumer_func` is not used by `impl` here; the closure is
        # handed to `foo` instead.
        def impl_factory(consumer_func):
            def impl(x):
                y = 10

                def callinner(*z):
                    if len(z) == 1:
                        tmp = z[0]
                    elif len(z) == 2:
                        tmp = z[0] + z[1]
                    return y + tmp + _global

                # run both specialisations, 1 arg, and 2 arg.
                return foo(callinner, x) + foo(callinner, x, x)

            return impl

        cfunc = wrap_with_kernel_one_arg(impl_factory(consumer))
        impl = impl_factory(consumer.py_func)

        a = 5
        self.assertEqual(impl(a), cfunc(a))

    # The closure is passed as the second argument to an overloaded `apply`
    # function. No factory is used: `impl` is run directly in the interpreter
    # and through the kernel wrapper.
    def test_basic_apply_like_case(self):
        def apply(arg, func):
            return func(arg)

        @overload(apply)
        def ov_apply(arg, func):
            return lambda arg, func: func(arg)

        def impl(arg):
            def mul10(x):
                return x * 10

            return apply(arg, mul10)

        cfunc = wrap_with_kernel_one_arg(impl)

        a = 10
        np.testing.assert_allclose(impl(a), cfunc(a))

    # this needs true SSA to be able to work correctly, check error for now
    def test_multiply_defined_freevar(self):
        # `x` is assigned in both branches and captured by a closure in each.
        def impl(c):
            if c:
                x = 3

                def inner(y):
                    return y + x

                r = consumer(inner, 1)
            else:
                x = 6

                def inner(y):
                    return y + x

                r = consumer(inner, 2)
            return r

        # Supplying a signature to `cuda.jit` is expected to raise the typing
        # error here, without launching anything.
        with self.assertRaises(errors.TypingError) as e:
            cuda.jit("void(int64)")(impl)

        self.assertIn(
            "Cannot capture a constant value for variable", str(e.exception)
        )

    # A closure that captures a non-constant value (an array) is expected to
    # be rejected with a typing error.
    def test_non_const_in_escapee(self):
        def impl(x):
            z = np.arange(x)

            def inner(val):
                return 1 + z + val  # z is non-const freevar

            return consumer(inner, x)

        with self.assertRaises(errors.TypingError) as e:
            cuda.jit("void(int64)")(impl)

        self.assertIn(
            "Cannot capture the non-constant value associated", str(e.exception)
        )

    # The closure has default arguments. It is called directly and through the
    # consumer, both with and without overriding a default positionally. The
    # four results are returned as a tuple, which the kernel wrapper unpacks
    # into its output array.
    def test_escape_with_kwargs(self):
        def impl_factory(consumer_func):
            def impl():
                t = 12

                def inner(a, b, c, mydefault1=123, mydefault2=456):
                    z = 4
                    return mydefault1 + mydefault2 + z + t + a + b + c

                # this is awkward, top and tail closure inlining with escapees
                # in the middle that do/don't have defaults.
                return (
                    inner(1, 2, 5, 91, 53),
                    consumer_func(inner, 1, 2, 3, 73),
                    consumer_func(
                        inner,
                        1,
                        2,
                        3,
                    ),
                    inner(1, 2, 4),
                )

            return impl

        cfunc = wrap_with_kernel_noarg_tuple_return(impl_factory(consumer))
        impl = impl_factory(consumer.py_func)

        np.testing.assert_allclose(impl(), cfunc())

    # As test_escape_with_kwargs, but the consumer is the one that overrides
    # the closure's defaults, by keyword.
    def test_escape_with_kwargs_override_kwargs(self):
        # Calls `func` three times, overriding the first default, the second
        # default, and both, then returns the sum of the three results.
        @cuda.jit
        def specialised_consumer(func, *args):
            x, y, z = args  # unpack to avoid `CALL_FUNCTION_EX`
            a = func(x, y, z, mydefault1=1000)
            b = func(x, y, z, mydefault2=1000)
            c = func(x, y, z, mydefault1=1000, mydefault2=1000)
            return a + b + c

        def impl_factory(consumer_func):
            def impl():
                t = 12

                def inner(a, b, c, mydefault1=123, mydefault2=456):
                    z = 4
                    return mydefault1 + mydefault2 + z + t + a + b + c

                # this is awkward, top and tail closure inlining with escapees
                # in the middle that get defaults specified in the consumer
                return (
                    inner(1, 2, 5, 91, 53),
                    consumer_func(inner, 1, 2, 11),
                    consumer_func(
                        inner,
                        1,
                        2,
                        3,
                    ),
                    inner(1, 2, 4),
                )

            return impl

        cfunc = wrap_with_kernel_noarg_tuple_return(
            impl_factory(specialised_consumer)
        )
        impl = impl_factory(specialised_consumer.py_func)

        np.testing.assert_allclose(impl(), cfunc())

0 commit comments

Comments
 (0)