Skip to content

Commit ddfda6d

Browse files
committed
Add support for default overload values
1 parent df6d56a commit ddfda6d

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -968,6 +968,10 @@ def get_call_template(self, args, kws):
968968
969969
A (template, pysig, args, kws) tuple is returned.
970970
"""
971+
# Fold keyword arguments and resolve default values
972+
pysig, args = self._compiler.fold_argument_types(args, kws)
973+
kws = {}
974+
971975
# Ensure an exactly-matching overload is available if we can
972976
# compile. We proceed with the typing even if we can't compile
973977
# because we may be able to force a cast on the caller side.
Lines changed: 55 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,55 @@
1+
"""
2+
Test problems in nested calls.
3+
Usually due to invalid type conversion between function boundaries.
4+
"""
5+
6+
7+
from numba import cuda
8+
from numba.core import types
9+
from numba.cuda.testing import CUDATestCase
10+
from numba.extending import overload
11+
import unittest
12+
import numpy as np
13+
14+
15+
def generated_inner(out, x, y=5, z=6):
    """
    Pure-Python implementation of the function overloaded below.

    Writes two results into ``out``: ``out[0]`` receives ``x + y`` when
    ``x`` is complex and ``x - y`` otherwise; ``out[1]`` receives ``z``.
    Nothing is returned.
    """
    # Provide implementation for the simulation.
    # NumPy complex scalars are checked explicitly because np.complex64 is
    # not a subclass of the builtin ``complex``; this keeps the branch in
    # step with the ``types.Complex`` test used by the overload.
    if isinstance(x, (complex, np.complexfloating)):
        out[0], out[1] = x + y, z
    else:
        out[0], out[1] = x - y, z
21+
22+
23+
@overload(generated_inner)
def ol_generated_inner(out, x, y=5, z=6):
    """
    Typing-time overload of ``generated_inner``.

    Selects the implementation from the type of ``x``: complex types get
    the variant that stores ``x + y``, all other types get the variant
    that stores ``x - y``. Both variants store ``z`` in ``out[1]`` and keep
    the same defaults (``y=5``, ``z=6``) as the overload signature.
    """
    if isinstance(x, types.Complex):
        # Complex argument: add the offset.
        def impl(out, x, y=5, z=6):
            out[0], out[1] = x + y, z
    else:
        # Any other argument type: subtract the offset.
        def impl(out, x, y=5, z=6):
            out[0], out[1] = x - y, z
    return impl
32+
33+
34+
def call_generated(a, b, out):
    """
    Call ``generated_inner`` with ``z`` passed by keyword.

    ``y`` is deliberately omitted so that the callee's default value has
    to be filled in when the call is resolved.
    """
    generated_inner(out, a, z=b)
36+
37+
38+
class TestNestedCall(CUDATestCase):
    def test_call_generated(self):
        """
        Test a nested call from a kernel to an overloaded function that is
        invoked with a keyword argument and a defaulted argument.
        """
        cfunc = cuda.jit(call_generated)

        # Non-complex path: out[0] = x - y = 1 - 5, out[1] = z = 2.
        out = np.empty(2, dtype=np.int64)
        cfunc[1, 1](1, 2, out)
        self.assertPreciseEqual(tuple(out), (-4, 2))

        # Complex path: out[0] = x + y = 1j + 5, out[1] = z = 2.
        out = np.empty(2, dtype=np.complex64)
        cfunc[1, 1](1j, 2, out)
        self.assertPreciseEqual(tuple(map(complex, out)), (5 + 1j, 2 + 0j))
52+
53+
54+
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)