Skip to content

Commit 9f0d154

Browse files
gmarkallisVoidKyleFromNVIDIA
authored
Fix linking of external code from callees (#137)
The original linking implementation for linkable code in device declarations did not consider calls inside callees; this change recurses through the typing to find all calls requiring linkable code. --------- Co-authored-by: isVoid <[email protected]> Co-authored-by: Kyle Edwards <[email protected]>
1 parent 332eb7c commit 9f0d154

File tree

2 files changed

+178
-10
lines changed

2 files changed

+178
-10
lines changed

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import sys
55
import ctypes
66
import functools
7+
from collections import defaultdict
78

8-
from numba.core import config, serialize, sigutils, types, typing, utils
9+
from numba.core import config, ir, serialize, sigutils, types, typing, utils
910
from numba.core.caching import Cache, CacheImpl
1011
from numba.core.compiler_lock import global_compiler_lock
1112
from numba.core.dispatcher import Dispatcher
@@ -42,6 +43,55 @@
4243
reshape_funcs = ['nocopy_empty_reshape', 'numba_attempt_nocopy_reshape']
4344

4445

46+
def get_cres_link_objects(cres):
47+
"""Given a compile result, return a set of all linkable code objects that
48+
are required for it to be fully linked."""
49+
50+
link_objects = set()
51+
52+
# List of calls into declared device functions
53+
device_func_calls = [
54+
(name, v) for name, v in cres.fndesc.typemap.items() if (
55+
isinstance(v, cuda_types.CUDADispatcher)
56+
)
57+
]
58+
59+
# List of tuples with SSA name of calls and corresponding signature
60+
call_signatures = [
61+
(call.func.name, sig)
62+
for call, sig in cres.fndesc.calltypes.items() if (
63+
isinstance(call, ir.Expr) and call.op == 'call'
64+
)
65+
]
66+
67+
# Map SSA names to all invoked signatures
68+
call_signature_d = defaultdict(list)
69+
for name, sig in call_signatures:
70+
call_signature_d[name].append(sig)
71+
72+
# Add the link objects from the current function's callees
73+
for name, v in device_func_calls:
74+
for sig in call_signature_d.get(name, []):
75+
called_cres = v.dispatcher.overloads[sig.args]
76+
called_link_objects = get_cres_link_objects(called_cres)
77+
link_objects.update(called_link_objects)
78+
79+
# From this point onwards, we are only interested in ExternFunction
80+
# declarations - these are the calls made directly in this function to
81+
# them.
82+
for name, v in cres.fndesc.typemap.items():
83+
if not isinstance(v, Function):
84+
continue
85+
86+
if not isinstance(v.typing_key, ExternFunction):
87+
continue
88+
89+
for obj in v.typing_key.link:
90+
link_objects.add(obj)
91+
92+
return link_objects
93+
94+
4595
class _Kernel(serialize.ReduceMixin):
4696
'''
4797
CUDA Kernel specialized for a given set of argument types. When called, this
@@ -159,15 +209,8 @@ def link_to_library_functions(library_functions, library_path,
159209

160210
self.maybe_link_nrt(link, tgt_ctx, asm)
161211

162-
for k, v in cres.fndesc.typemap.items():
163-
if not isinstance(v, Function):
164-
continue
165-
166-
if not isinstance(v.typing_key, ExternFunction):
167-
continue
168-
169-
for obj in v.typing_key.link:
170-
lib.add_linking_file(obj)
212+
for obj in get_cres_link_objects(cres):
213+
lib.add_linking_file(obj)
171214

172215
for filepath in link:
173216
lib.add_linking_file(filepath)

numba_cuda/numba/cuda/tests/cudapy/test_device_func.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,14 @@ def rgba_caller(x, channels):
205205
}
206206
""")
207207

208+
times3_cu = cuda.CUSource("""
209+
extern "C" __device__
210+
int times3(int *out, int a)
211+
{
212+
*out = a * 3;
213+
return 0;
214+
}
215+
""")
208216

209217
times4_cu = cuda.CUSource("""
210218
extern "C" __device__
@@ -351,6 +359,123 @@ def kernel(x, seed):
351359
kernel[1, 1](x, 1)
352360
np.testing.assert_equal(x[0], 323845807)
353361

362+
def test_declared_in_called_function(self):
363+
times2 = cuda.declare_device('times2', 'int32(int32)', link=times2_cu)
364+
365+
@cuda.jit
366+
def device_func(x):
367+
return times2(x)
368+
369+
@cuda.jit
370+
def kernel(r, x):
371+
i = cuda.grid(1)
372+
if i < len(r):
373+
r[i] = device_func(x[i])
374+
375+
x = np.arange(10, dtype=np.int32)
376+
r = np.empty_like(x)
377+
378+
kernel[1, 32](r, x)
379+
380+
np.testing.assert_equal(r, x * 2)
381+
382+
def test_declared_in_called_function_twice(self):
383+
times2 = cuda.declare_device('times2', 'int32(int32)', link=times2_cu)
384+
385+
@cuda.jit
386+
def device_func_1(x):
387+
return times2(x)
388+
389+
@cuda.jit
390+
def device_func_2(x):
391+
return device_func_1(x)
392+
393+
@cuda.jit
394+
def kernel(r, x):
395+
i = cuda.grid(1)
396+
if i < len(r):
397+
r[i] = device_func_2(x[i])
398+
399+
x = np.arange(10, dtype=np.int32)
400+
r = np.empty_like(x)
401+
402+
kernel[1, 32](r, x)
403+
404+
np.testing.assert_equal(r, x * 2)
405+
406+
def test_declared_in_called_function_two_calls(self):
407+
times2 = cuda.declare_device('times2', 'int32(int32)', link=times2_cu)
408+
409+
@cuda.jit
410+
def device_func(x):
411+
return times2(x)
412+
413+
@cuda.jit
414+
def kernel(r, x):
415+
i = cuda.grid(1)
416+
if i < len(r):
417+
r[i] = device_func(x[i]) + device_func(x[i] + i)
418+
419+
x = np.arange(10, dtype=np.int32)
420+
r = np.empty_like(x)
421+
422+
kernel[1, 32](r, x)
423+
424+
np.testing.assert_equal(r, x * 6)
425+
426+
def test_call_declared_function_twice(self):
427+
times2 = cuda.declare_device('times2', 'int32(int32)', link=times2_cu)
428+
429+
@cuda.jit
430+
def kernel(r, x):
431+
i = cuda.grid(1)
432+
if i < len(r):
433+
r[i] = times2(x[i]) + times2(x[i] + i)
434+
435+
x = np.arange(10, dtype=np.int32)
436+
r = np.empty_like(x)
437+
438+
kernel[1, 32](r, x)
439+
440+
np.testing.assert_equal(r, x * 6)
441+
442+
def test_declared_in_called_function_and_parent(self):
443+
times2 = cuda.declare_device('times2', 'int32(int32)', link=times2_cu)
444+
445+
@cuda.jit
446+
def device_func(x):
447+
return times2(x)
448+
449+
@cuda.jit
450+
def kernel(r, x):
451+
i = cuda.grid(1)
452+
if i < len(r):
453+
r[i] = device_func(x[i]) + times2(x[i])
454+
455+
x = np.arange(10, dtype=np.int32)
456+
r = np.empty_like(x)
457+
458+
kernel[1, 32](r, x)
459+
460+
np.testing.assert_equal(r, x * 4)
461+
462+
def test_call_two_different_declared_functions(self):
463+
times2 = cuda.declare_device('times2', 'int32(int32)', link=times2_cu)
464+
times3 = cuda.declare_device('times3', 'int32(int32)', link=times3_cu)
465+
466+
@cuda.jit
467+
def kernel(r, x):
468+
i = cuda.grid(1)
469+
if i < len(r):
470+
r[i] = times2(x[i]) + times3(x[i])
471+
472+
x = np.arange(10, dtype=np.int32)
473+
r = np.empty_like(x)
474+
475+
kernel[1, 32](r, x)
476+
477+
np.testing.assert_equal(r, x * 5)
478+
354479

355480
if __name__ == '__main__':
356481
unittest.main()

0 commit comments

Comments
 (0)