Skip to content

Commit e17b6fc

Browse files
committed
Support linking code for device functions in declaration
1 parent bf487d7 commit e17b6fc

File tree

5 files changed

+144
-23
lines changed

5 files changed

+144
-23
lines changed

numba_cuda/numba/cuda/compiler.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -570,16 +570,16 @@ def compile_ptx_for_current_device(pyfunc, sig, debug=None, lineinfo=False,
570570
abi=abi, abi_info=abi_info)
571571

572572

573-
def declare_device_function(name, restype, argtypes):
574-
return declare_device_function_template(name, restype, argtypes).key
573+
def declare_device_function(name, restype, argtypes, link=()):
    """
    Declare an external device function and return its typing key.

    :param name: Name of the external function.
    :param restype: Return type used to build the function's signature.
    :param argtypes: Argument types used to build the function's signature.
    :param link: Collection of items to link when the declared function is
                 used; forwarded unchanged to the template factory.
                 Defaults to an empty tuple so that callers written against
                 the earlier three-argument form keep working.
    :return: The ``key`` attribute of the generated typing template.
    """
    return declare_device_function_template(name, restype, argtypes, link).key
575575

576576

577-
def declare_device_function_template(name, restype, argtypes):
577+
def declare_device_function_template(name, restype, argtypes, link):
578578
from .descriptor import cuda_target
579579
typingctx = cuda_target.typing_context
580580
targetctx = cuda_target.target_context
581581
sig = typing.signature(restype, *argtypes)
582-
extfn = ExternFunction(name, sig)
582+
extfn = ExternFunction(name, sig, link)
583583

584584
class device_function_template(ConcreteTemplate):
585585
key = extfn
@@ -593,7 +593,8 @@ class device_function_template(ConcreteTemplate):
593593
return device_function_template
594594

595595

596-
class ExternFunction(object):
597-
def __init__(self, name, sig):
596+
class ExternFunction:
    """
    Record describing an externally-defined function.

    Stores the function's name, its signature, and the collection of items
    that should be linked when the function is used.
    """

    def __init__(self, name, sig, link=()):
        """
        :param name: Name of the external function.
        :param sig: Signature of the external function.
        :param link: Collection of items to link; stored exactly as given.
                     Defaults to an empty tuple so that two-argument
                     construction (``ExternFunction(name, sig)``) still
                     works.
        """
        self.name = name
        self.sig = sig
        self.link = link

numba_cuda/numba/cuda/cudadecl.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,16 +403,20 @@ def _genfp16_binary_operator(op):
403403

404404

405405
def _resolve_wrapped_unary(fname):
    """
    Build the function type for the float16 unary wrapper of *fname*.

    Declares ``__numba_wrapper_<fname>`` as a device function that takes one
    float16 and returns a float16, with an empty collection of link items.
    """
    wrapper_name = f'__numba_wrapper_{fname}'
    template = declare_device_function_template(
        wrapper_name,
        types.float16,
        (types.float16,),
        (),  # nothing to link for the wrapper declaration
    )
    return types.Function(template)
410412

411413

412414
def _resolve_wrapped_binary(fname):
    """
    Build the function type for the float16 binary wrapper of *fname*.

    Declares ``__numba_wrapper_<fname>`` as a device function that takes two
    float16 values and returns a float16, with an empty collection of link
    items.
    """
    wrapper_name = f'__numba_wrapper_{fname}'
    template = declare_device_function_template(
        wrapper_name,
        types.float16,
        (types.float16, types.float16,),
        (),  # nothing to link for the wrapper declaration
    )
    return types.Function(template)
417421

418422

numba_cuda/numba/cuda/decorators.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def autojitwrapper(func):
173173
return disp
174174

175175

176-
def declare_device(name, sig):
176+
def declare_device(name, sig, link=None):
177177
"""
178178
Declare the signature of a foreign function. Returns a descriptor that can
179179
be used to call the function from a Python kernel.
@@ -182,9 +182,15 @@ def declare_device(name, sig):
182182
:type name: str
183183
:param sig: The Numba signature of the function.
184184
"""
185+
if link is None:
186+
link = tuple()
187+
else:
188+
if not isinstance(link, (list, tuple, set)):
189+
link = (link,)
190+
185191
argtypes, restype = sigutils.normalize_signature(sig)
186192
if restype is None:
187193
msg = 'Return type must be provided for device declarations'
188194
raise TypeError(msg)
189195

190-
return declare_device_function(name, restype, argtypes)
196+
return declare_device_function(name, restype, argtypes, link)

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
from numba.core.dispatcher import Dispatcher
1212
from numba.core.errors import NumbaPerformanceWarning
1313
from numba.core.typing.typeof import Purpose, typeof
14-
14+
from numba.core.types.functions import Function
1515
from numba.cuda.api import get_current_device
1616
from numba.cuda.args import wrap_arg
17-
from numba.cuda.compiler import compile_cuda, CUDACompiler, kernel_fixup
17+
from numba.cuda.compiler import (compile_cuda, CUDACompiler, kernel_fixup,
18+
ExternFunction)
1819
from numba.cuda.cudadrv import driver
1920
from numba.cuda.cudadrv.devices import get_context
2021
from numba.cuda.descriptor import cuda_target
@@ -158,6 +159,16 @@ def link_to_library_functions(library_functions, library_path,
158159

159160
self.maybe_link_nrt(link, tgt_ctx, asm)
160161

162+
for k, v in cres.fndesc.typemap.items():
163+
if not isinstance(v, Function):
164+
continue
165+
166+
if not isinstance(v.typing_key, ExternFunction):
167+
continue
168+
169+
for obj in v.typing_key.link:
170+
lib.add_linking_file(obj)
171+
161172
for filepath in link:
162173
lib.add_linking_file(filepath)
163174

numba_cuda/numba/cuda/tests/cudapy/test_device_func.py

Lines changed: 110 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import re
2-
import types
2+
import cffi
33

44
import numpy as np
55

6-
from numba.cuda.testing import unittest, skip_on_cudasim, CUDATestCase
7-
from numba import cuda, jit, float32, int32
6+
from numba.cuda.testing import (skip_on_cudasim, test_data_dir, unittest,
7+
CUDATestCase)
8+
from numba import cuda, jit, float32, int32, types
89
from numba.core.errors import TypingError
10+
from types import ModuleType
911

1012

1113
class TestDeviceFunc(CUDATestCase):
@@ -92,7 +94,7 @@ def test_cpu_dispatcher_other_module(self):
9294
def add(a, b):
9395
return a + b
9496

95-
mymod = types.ModuleType(name='mymod')
97+
mymod = ModuleType(name='mymod')
9698
mymod.add = add
9799
del add
98100

@@ -192,31 +194,128 @@ def rgba_caller(x, channels):
192194

193195
self.assertEqual(0x04010203, x[0])
194196

195-
def _test_declare_device(self, decl):
197+
198+
times2_cu = cuda.CUSource("""
199+
extern "C" __device__
200+
int times2(int *out, int a)
201+
{
202+
*out = a * 2;
203+
return 0;
204+
}
205+
""")
206+
207+
208+
times4_cu = cuda.CUSource("""
209+
extern "C" __device__
210+
int times2(int *out, int a);
211+
212+
extern "C" __device__
213+
int times4(int *out, int a)
214+
{
215+
int tmp;
216+
times2(&tmp, a);
217+
*out = tmp * 2;
218+
return 0;
219+
}
220+
""")
221+
222+
jitlink_user_cu = cuda.CUSource("""
223+
extern "C" __device__
224+
int array_mutator(void *out, int *a);
225+
226+
extern "C" __device__
227+
int use_array_mutator(void *out, int *a) {
228+
array_mutator(out, a);
229+
return 0;
230+
}
231+
""")
232+
233+
234+
@skip_on_cudasim('External functions unsupported in the simulator')
class TestDeclareDevice(CUDATestCase):
    # Exercises cuda.declare_device: the attributes of the returned
    # declaration, rejection of signatures lacking a return type, and the
    # ``link`` keyword with one or several sources.

    def check_api(self, decl):
        # Shared assertions for a declaration made as 'f1' with the
        # signature int32(float32[:]).
        self.assertEqual(decl.name, 'f1')
        declared_sig = decl.sig
        self.assertEqual(declared_sig.args, (float32[:],))
        self.assertEqual(declared_sig.return_type, int32)

    def test_declare_device_signature(self):
        # Signature supplied as a signature object.
        self.check_api(cuda.declare_device('f1', int32(float32[:])))

    def test_declare_device_string(self):
        # Signature supplied in string form.
        self.check_api(cuda.declare_device('f1', 'int32(float32[:])'))

    def test_bad_declare_device_tuple(self):
        # A bare tuple of argument types carries no return type.
        with self.assertRaisesRegex(TypeError, 'Return type'):
            cuda.declare_device('f1', (float32[:],))

    def test_bad_declare_device_string(self):
        # Same as above, with the argument tuple given as a string.
        with self.assertRaisesRegex(TypeError, 'Return type'):
            cuda.declare_device('f1', '(float32[:],)')

    def test_link_cu_source(self):
        # A single CUSource object passed directly (not wrapped in a
        # container) as ``link``.
        times2 = cuda.declare_device('times2', 'int32(int32)', link=times2_cu)

        @cuda.jit
        def kernel(out, inp):
            idx = cuda.grid(1)
            if idx < len(out):
                out[idx] = times2(inp[idx])

        values = np.arange(10, dtype=np.int32)
        results = np.empty_like(values)

        kernel[1, 32](results, values)

        np.testing.assert_equal(results, values * 2)

    def _test_link_multiple_sources(self, link_type):
        # ``link_type`` is the container class (set, tuple or list) used to
        # hold the two sources; times4 calls times2, so both are needed.
        sources = link_type([times2_cu, times4_cu])
        times4 = cuda.declare_device('times4', 'int32(int32)', link=sources)

        @cuda.jit
        def kernel(out, inp):
            idx = cuda.grid(1)
            if idx < len(out):
                out[idx] = times4(inp[idx])

        values = np.arange(10, dtype=np.int32)
        results = np.empty_like(values)

        kernel[1, 32](results, values)

        np.testing.assert_equal(results, values * 4)

    def test_link_multiple_sources_set(self):
        self._test_link_multiple_sources(set)

    def test_link_multiple_sources_tuple(self):
        self._test_link_multiple_sources(tuple)

    def test_link_multiple_sources_list(self):
        self._test_link_multiple_sources(list)

    def test_link_sources_in_memory_and_on_disk(self):
        # Combines a source given as a file path with one held in a
        # CUSource object. NOTE(review): jitlink.cu presumably provides
        # array_mutator, which jitlink_user_cu only declares - confirm
        # against the test data directory.
        on_disk_source = str(test_data_dir / "jitlink.cu")
        sources = [on_disk_source, jitlink_user_cu]
        sig = types.void(types.CPointer(types.int32))
        ext_fn = cuda.declare_device("use_array_mutator", sig,
                                     link=sources)

        ffi = cffi.FFI()

        @cuda.jit
        def kernel(x):
            ptr = ffi.from_buffer(x)
            ext_fn(ptr)

        data = np.arange(2, dtype=np.int32)
        kernel[1, 1](data)

        np.testing.assert_equal(data, np.ones(2, dtype=np.int32))
318+
220319

221320
if __name__ == '__main__':
222321
unittest.main()

0 commit comments

Comments
 (0)