|
4 | 4 | import sys |
5 | 5 | import ctypes |
6 | 6 | import functools |
| 7 | +from collections import defaultdict |
7 | 8 |
|
8 | | -from numba.core import config, serialize, sigutils, types, typing, utils |
| 9 | +from numba.core import config, ir, serialize, sigutils, types, typing, utils |
9 | 10 | from numba.core.caching import Cache, CacheImpl |
10 | 11 | from numba.core.compiler_lock import global_compiler_lock |
11 | 12 | from numba.core.dispatcher import Dispatcher |
|
42 | 43 | reshape_funcs = ['nocopy_empty_reshape', 'numba_attempt_nocopy_reshape'] |
43 | 44 |
|
44 | 45 |
|
| 46 | +def get_cres_link_objects(cres): |
| 47 | + """Given a compile result, return a set of all linkable code objects that |
| 48 | + are required for it to be fully linked.""" |
| 49 | + |
| 50 | + link_objects = set() |
| 51 | + |
| 52 | + # List of calls into declared device functions |
| 53 | + device_func_calls = [ |
| 54 | + (name, v) for name, v in cres.fndesc.typemap.items() if ( |
| 55 | + isinstance(v, cuda_types.CUDADispatcher) |
| 56 | + ) |
| 57 | + ] |
| 58 | + |
| 59 | + # List of tuples with SSA name of calls and corresponding signature |
| 60 | + call_signatures = [ |
| 61 | + (call.func.name, sig) |
| 62 | + for call, sig in cres.fndesc.calltypes.items() if ( |
| 63 | + isinstance(call, ir.Expr) and call.op == 'call' |
| 64 | + ) |
| 65 | + ] |
| 66 | + |
| 67 | + # Map SSA names to all invoked signatures |
| 68 | + call_signature_d = defaultdict(list) |
| 69 | + for name, sig in call_signatures: |
| 70 | + call_signature_d[name].append(sig) |
| 71 | + |
| 72 | + # Add the link objects from the current function's callees |
| 73 | + for name, v in device_func_calls: |
| 74 | + for sig in call_signature_d.get(name, []): |
| 75 | + called_cres = v.dispatcher.overloads[sig.args] |
| 76 | + called_link_objects = get_cres_link_objects(called_cres) |
| 77 | + link_objects.update(called_link_objects) |
| 78 | + |
| 79 | + # From this point onwards, we are only interested in ExternFunction |
| 80 | + # declarations - these are the calls made directly in this function to |
| 81 | + # them. |
| 82 | + for name, v in cres.fndesc.typemap.items(): |
| 83 | + if not isinstance(v, Function): |
| 84 | + continue |
| 85 | + |
| 86 | + if not isinstance(v.typing_key, ExternFunction): |
| 87 | + continue |
| 88 | + |
| 89 | + for obj in v.typing_key.link: |
| 90 | + link_objects.add(obj) |
| 91 | + |
| 92 | + return link_objects |
| 93 | + |
| 94 | + |
45 | 95 | class _Kernel(serialize.ReduceMixin): |
46 | 96 | ''' |
47 | 97 | CUDA Kernel specialized for a given set of argument types. When called, this |
@@ -159,15 +209,8 @@ def link_to_library_functions(library_functions, library_path, |
159 | 209 |
|
160 | 210 | self.maybe_link_nrt(link, tgt_ctx, asm) |
161 | 211 |
|
162 | | - for k, v in cres.fndesc.typemap.items(): |
163 | | - if not isinstance(v, Function): |
164 | | - continue |
165 | | - |
166 | | - if not isinstance(v.typing_key, ExternFunction): |
167 | | - continue |
168 | | - |
169 | | - for obj in v.typing_key.link: |
170 | | - lib.add_linking_file(obj) |
| 212 | + for obj in get_cres_link_objects(cres): |
| 213 | + lib.add_linking_file(obj) |
171 | 214 |
|
172 | 215 | for filepath in link: |
173 | 216 | lib.add_linking_file(filepath) |
|
0 commit comments