Skip to content

Commit 4fd1bae

Browse files
committed
Clean up overload tests
1 parent ddfda6d commit 4fd1bae

File tree

2 files changed

+23
-55
lines changed

2 files changed

+23
-55
lines changed

numba_cuda/numba/cuda/tests/cudapy/test_nested_calls.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

numba_cuda/numba/cuda/tests/cudapy/test_overload.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ def target_overloaded_calls_target_overloaded():
5656
pass
5757

5858

59+
def default_values_and_kwargs():
60+
pass
61+
62+
5963
# To recognise which functions are resolved for a call, we identify each with a
6064
# prime number. Each function called multiplies a value by its prime (starting
6165
# with the value 1), and we can check that the result is as expected based on
@@ -185,6 +189,13 @@ def impl(x):
185189
return impl
186190

187191

192+
@overload(default_values_and_kwargs)
193+
def ol_default_values_and_kwargs(out, x, y=5, z=6):
194+
def impl(out, x, y=5, z=6):
195+
out[0], out[1] = x + y, z
196+
return impl
197+
198+
188199
@skip_on_cudasim('Overloading not supported in cudasim')
189200
class TestOverload(CUDATestCase):
190201
def check_overload(self, kernel, expected):
@@ -330,6 +341,18 @@ def illegal_target_attr_use(x):
330341
def cuda_target_attr_use(res, dummy):
331342
res[0] = dummy.cuda_only
332343

344+
def test_default_values_and_kwargs(self):
345+
"""
346+
Test default values and kwargs.
347+
"""
348+
@cuda.jit()
349+
def kernel(a, b, out):
350+
default_values_and_kwargs(out, a, z=b)
351+
352+
out = np.empty(2, dtype=np.int64)
353+
kernel[1,1](1, 2, out)
354+
self.assertEqual(tuple(out), (6, 2))
355+
333356

334357
if __name__ == '__main__':
335358
unittest.main()

0 commit comments

Comments
 (0)