Skip to content

Commit a58f3fc

Browse files
authored
Support linking code for device functions in declaration (#124)
Allows specifying files to link in the `cuda.declare_device()` declaration, so that it's no longer required for the user to know which files to link. Changes consist of: - Adding the `link` kwarg to the `declare_device` function, and automatically linking in any linkable items when the declared function is used. - Updating the documentation to describe this mechanism, and reflect that it's the recommended way to specify what to link. - Documents the `LinkableCode` classes, which were previously undocumented. - Removes some obsolete notices about needing the NVIDIA bindings for linking C/C++ code. - Adds cffi to the test environment, as it's used by one of the new tests (it should have already been present, really). I decided to not tackle #67 in its entirety, which also requests that a callback function can be used to generate the implementation, for a couple of reasons: - I think the existing implementation is of immediate value for Numbast, and all other FFI-calling implementations. - There is some thought needed about how to handle typing when a callback function is used - for example, whether it's necessary to generalize the typing beyond just the single signature that `declare_device()` presently accepts.
1 parent e7c2d66 commit a58f3fc

File tree

13 files changed

+281
-57
lines changed

13 files changed

+281
-57
lines changed

ci/test_conda.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ set -euo pipefail
88
if [ "${CUDA_VER%.*.*}" = "11" ]; then
99
CTK_PACKAGES="cudatoolkit"
1010
else
11-
CTK_PACKAGES="cuda-cccl cuda-nvcc-impl cuda-nvrtc"
11+
CTK_PACKAGES="cuda-cccl cuda-nvcc-impl cuda-nvrtc libcurand-dev"
1212
fi
1313

1414
rapids-logger "Install testing dependencies"
@@ -22,6 +22,7 @@ rapids-mamba-retry create -n test \
2222
make \
2323
psutil \
2424
pytest \
25+
cffi \
2526
python=${RAPIDS_PY_VERSION}
2627

2728
# Temporarily allow unbound variables for conda activation.

ci/test_conda_pynvjitlink.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ set -euo pipefail
88
if [ "${CUDA_VER%.*.*}" = "11" ]; then
99
CTK_PACKAGES="cudatoolkit"
1010
else
11-
CTK_PACKAGES="cuda-nvcc-impl cuda-nvrtc cuda-cuobjdump"
11+
CTK_PACKAGES="cuda-nvcc-impl cuda-nvrtc cuda-cuobjdump libcurand-dev"
1212
fi
1313

1414
rapids-logger "Install testing dependencies"
@@ -22,6 +22,7 @@ rapids-mamba-retry create -n test \
2222
make \
2323
psutil \
2424
pytest \
25+
cffi \
2526
python=${RAPIDS_PY_VERSION}
2627

2728
# Temporarily allow unbound variables for conda activation.

ci/test_wheel.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ rapids-logger "Install testing dependencies"
77
# TODO: Replace with rapids-dependency-file-generator
88
python -m pip install \
99
psutil \
10+
cffi \
1011
cuda-python \
1112
nvidia-cuda-cccl-cu12 \
13+
nvidia-curand-cu12 \
1214
pytest
1315

1416
rapids-logger "Install wheel"

ci/test_wheel_pynvjitlink.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ rapids-logger "Install testing dependencies"
77
# TODO: Replace with rapids-dependency-file-generator
88
python -m pip install \
99
psutil \
10+
cffi \
1011
cuda-python \
12+
nvidia-curand-cu12 \
1113
pytest
1214

1315
rapids-logger "Install pynvjitlink"

docs/source/user/cuda_ffi.rst

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ of a Python kernel call to a foreign device function are:
1111

1212
- The device function implementation in a foreign language (e.g. CUDA C).
1313
- A declaration of the device function in Python.
14-
- A kernel that links with and calls the foreign function.
14+
- A kernel that calls the foreign function.
1515

1616
.. _device-function-abi:
1717

@@ -83,7 +83,7 @@ For example, when:
8383

8484
.. code::
8585
86-
mul = cuda.declare_device('mul_f32_f32', 'float32(float32, float32)')
86+
mul = cuda.declare_device('mul_f32_f32', 'float32(float32, float32)' , link="functions.cu")
8787
8888
is declared, calling ``mul(a, b)`` inside a kernel will translate into a call to
8989
``mul_f32_f32(a, b)`` in the compiled code.
@@ -134,15 +134,63 @@ where ``result`` and ``array`` are both arrays of ``float32`` data.
134134
Linking and Calling functions
135135
-----------------------------
136136

137-
The ``link`` keyword argument of the :func:`@cuda.jit <numba.cuda.jit>`
138-
decorator accepts a list of file names specified by absolute path or a path
139-
relative to the current working directory. Files whose name ends in ``.cu``
140-
will be compiled with the `NVIDIA Runtime Compiler (NVRTC)
141-
<https://docs.nvidia.com/cuda/nvrtc/index.html>`_ and linked into the kernel as
142-
PTX; other files will be passed directly to the CUDA Linker.
137+
The ``link`` keyword argument to the :func:`declare_device
138+
<numba.cuda.declare_device>` function accepts *Linkable Code* items. Either a
139+
single Linkable Code item can be passed, or multiple items in a list, tuple, or
140+
set.
141+
142+
A Linkable Code item is either:
143+
144+
* A string indicating the location of a file in the filesystem, or
145+
* A :class:`LinkableCode <numba.cuda.LinkableCode>` object, for linking code
146+
that exists in memory.
147+
148+
Suported code formats that can be linked are:
149+
150+
* PTX source code (``*.ptx``)
151+
* CUDA C/C++ source code (``*.cu``)
152+
* CUDA ELF Fat Binaries (``*.fatbin``)
153+
* CUDA ELF Cubins (``*.cubin``)
154+
* CUDA ELF archives (``*.a``)
155+
* CUDA Object files (``*.o``)
156+
* CUDA LTOIR files (``*.ltoir``)
157+
158+
CUDA C/C++ source code will be compiled with the `NVIDIA Runtime Compiler
159+
(NVRTC) <https://docs.nvidia.com/cuda/nvrtc/index.html>`_ and linked into the
160+
kernel as either PTX or LTOIR, depending on whether LTO is enabled. Other files
161+
will be passed directly to the CUDA Linker.
162+
163+
:class:`LinkableCode <numba.cuda.LinkableCode>` objects are initialized using
164+
the parameters of their base class:
143165

144-
For example, the following kernel calls the ``mul()`` function declared above
145-
with the implementation ``mul_f32_f32()`` in a file called ``functions.cu``:
166+
.. autoclass:: numba.cuda.LinkableCode
167+
168+
However, one should instantiate an instance of the class that represents the
169+
type of item being linked:
170+
171+
.. autoclass:: numba.cuda.PTXSource
172+
.. autoclass:: numba.cuda.CUSource
173+
.. autoclass:: numba.cuda.Fatbin
174+
.. autoclass:: numba.cuda.Cubin
175+
.. autoclass:: numba.cuda.Archive
176+
.. autoclass:: numba.cuda.Object
177+
.. autoclass:: numba.cuda.LTOIR
178+
179+
Legacy ``@cuda.jit`` decorator ``link`` support
180+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
181+
182+
The ``link`` keyword argument of the :func:`@cuda.jit <numba.cuda.jit>`
183+
decorator also accepts a list of Linkable Code items, which will then be linked
184+
into the kernel. This facility is provided for backwards compatibility; it is
185+
recommended that Linkable Code items are always specified in the
186+
:func:`declare_device <numba.cuda.declare_device>` call, so that the user of the
187+
declared API is not burdened with specifying the items to link themselves when
188+
writing a kernel.
189+
190+
As an example of how this legacy mechanism looked at the point of use: the
191+
following kernel calls the ``mul()`` function declared above with the
192+
implementation ``mul_f32_f32()`` as if it were in a file called ``functions.cu``
193+
that had not been declared as part of the ``link`` argument in the declaration:
146194

147195
.. code::
148196
@@ -153,17 +201,13 @@ with the implementation ``mul_f32_f32()`` in a file called ``functions.cu``:
153201
if i < len(r):
154202
r[i] = mul(x[i], y[i])
155203
156-
157204
C/C++ Support
158205
-------------
159206

160207
Support for compiling and linking of CUDA C/C++ code is provided through the use
161208
of NVRTC subject to the following considerations:
162209

163-
- It is only available when using the NVIDIA Bindings. See
164-
:envvar:`NUMBA_CUDA_USE_NVIDIA_BINDING`.
165-
- A suitable version of the NVRTC library for the installed version of the
166-
NVIDIA CUDA Bindings must be available.
210+
- A suitable version of the NVRTC library must be available.
167211
- The CUDA include path is assumed by default to be ``/usr/local/cuda/include``
168212
on Linux and ``$env:CUDA_PATH\include`` on Windows. It can be modified using
169213
the environment variable :envvar:`NUMBA_CUDA_INCLUDE_PATH`.

numba_cuda/numba/cuda/compiler.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -570,16 +570,16 @@ def compile_ptx_for_current_device(pyfunc, sig, debug=None, lineinfo=False,
570570
abi=abi, abi_info=abi_info)
571571

572572

573-
def declare_device_function(name, restype, argtypes):
574-
return declare_device_function_template(name, restype, argtypes).key
573+
def declare_device_function(name, restype, argtypes, link):
574+
return declare_device_function_template(name, restype, argtypes, link).key
575575

576576

577-
def declare_device_function_template(name, restype, argtypes):
577+
def declare_device_function_template(name, restype, argtypes, link):
578578
from .descriptor import cuda_target
579579
typingctx = cuda_target.typing_context
580580
targetctx = cuda_target.target_context
581581
sig = typing.signature(restype, *argtypes)
582-
extfn = ExternFunction(name, sig)
582+
extfn = ExternFunction(name, sig, link)
583583

584584
class device_function_template(ConcreteTemplate):
585585
key = extfn
@@ -593,7 +593,8 @@ class device_function_template(ConcreteTemplate):
593593
return device_function_template
594594

595595

596-
class ExternFunction(object):
597-
def __init__(self, name, sig):
596+
class ExternFunction:
597+
def __init__(self, name, sig, link):
598598
self.name = name
599599
self.sig = sig
600+
self.link = link

numba_cuda/numba/cuda/cudadecl.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,16 +403,20 @@ def _genfp16_binary_operator(op):
403403

404404

405405
def _resolve_wrapped_unary(fname):
406+
link = tuple()
406407
decl = declare_device_function_template(f'__numba_wrapper_{fname}',
407408
types.float16,
408-
(types.float16,))
409+
(types.float16,),
410+
link)
409411
return types.Function(decl)
410412

411413

412414
def _resolve_wrapped_binary(fname):
415+
link = tuple()
413416
decl = declare_device_function_template(f'__numba_wrapper_{fname}',
414417
types.float16,
415-
(types.float16, types.float16,))
418+
(types.float16, types.float16,),
419+
link)
416420
return types.Function(decl)
417421

418422

numba_cuda/numba/cuda/cudadrv/linkable_code.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22

33

44
class LinkableCode:
5-
"""An object that can be passed in the `link` list argument to `@cuda.jit`
6-
kernels to supply code to be linked from memory."""
5+
"""An object that holds code to be linked from memory.
6+
7+
:param data: A buffer containing the data to link.
8+
:param name: The name of the file to be referenced in any compilation or
9+
linking errors that may be produced.
10+
"""
711

812
def __init__(self, data, name=None):
913
self.data = data
@@ -15,49 +19,49 @@ def name(self):
1519

1620

1721
class PTXSource(LinkableCode):
18-
"""PTX Source code in memory"""
22+
"""PTX source code in memory."""
1923

2024
kind = FILE_EXTENSION_MAP["ptx"]
2125
default_name = "<unnamed-ptx>"
2226

2327

2428
class CUSource(LinkableCode):
25-
"""CUDA C/C++ Source code in memory"""
29+
"""CUDA C/C++ source code in memory."""
2630

2731
kind = "cu"
2832
default_name = "<unnamed-cu>"
2933

3034

3135
class Fatbin(LinkableCode):
32-
"""A fatbin ELF in memory"""
36+
"""An ELF Fatbin in memory."""
3337

3438
kind = FILE_EXTENSION_MAP["fatbin"]
3539
default_name = "<unnamed-fatbin>"
3640

3741

3842
class Cubin(LinkableCode):
39-
"""A cubin ELF in memory"""
43+
"""An ELF Cubin in memory."""
4044

4145
kind = FILE_EXTENSION_MAP["cubin"]
4246
default_name = "<unnamed-cubin>"
4347

4448

4549
class Archive(LinkableCode):
46-
"""An archive of objects in memory"""
50+
"""An archive of objects in memory."""
4751

4852
kind = FILE_EXTENSION_MAP["a"]
4953
default_name = "<unnamed-archive>"
5054

5155

5256
class Object(LinkableCode):
53-
"""An object file in memory"""
57+
"""An object file in memory."""
5458

5559
kind = FILE_EXTENSION_MAP["o"]
5660
default_name = "<unnamed-object>"
5761

5862

5963
class LTOIR(LinkableCode):
60-
"""An LTOIR file in memory"""
64+
"""An LTOIR file in memory."""
6165

6266
kind = "ltoir"
6367
default_name = "<unnamed-ltoir>"

numba_cuda/numba/cuda/decorators.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,18 +173,25 @@ def autojitwrapper(func):
173173
return disp
174174

175175

176-
def declare_device(name, sig):
176+
def declare_device(name, sig, link=None):
177177
"""
178178
Declare the signature of a foreign function. Returns a descriptor that can
179179
be used to call the function from a Python kernel.
180180
181181
:param name: The name of the foreign function.
182182
:type name: str
183183
:param sig: The Numba signature of the function.
184+
:param link: External code to link when calling the function.
184185
"""
186+
if link is None:
187+
link = tuple()
188+
else:
189+
if not isinstance(link, (list, tuple, set)):
190+
link = (link,)
191+
185192
argtypes, restype = sigutils.normalize_signature(sig)
186193
if restype is None:
187194
msg = 'Return type must be provided for device declarations'
188195
raise TypeError(msg)
189196

190-
return declare_device_function(name, restype, argtypes)
197+
return declare_device_function(name, restype, argtypes, link)

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
from numba.core.dispatcher import Dispatcher
1212
from numba.core.errors import NumbaPerformanceWarning
1313
from numba.core.typing.typeof import Purpose, typeof
14-
14+
from numba.core.types.functions import Function
1515
from numba.cuda.api import get_current_device
1616
from numba.cuda.args import wrap_arg
17-
from numba.cuda.compiler import compile_cuda, CUDACompiler, kernel_fixup
17+
from numba.cuda.compiler import (compile_cuda, CUDACompiler, kernel_fixup,
18+
ExternFunction)
1819
from numba.cuda.cudadrv import driver
1920
from numba.cuda.cudadrv.devices import get_context
2021
from numba.cuda.descriptor import cuda_target
@@ -158,6 +159,16 @@ def link_to_library_functions(library_functions, library_path,
158159

159160
self.maybe_link_nrt(link, tgt_ctx, asm)
160161

162+
for k, v in cres.fndesc.typemap.items():
163+
if not isinstance(v, Function):
164+
continue
165+
166+
if not isinstance(v.typing_key, ExternFunction):
167+
continue
168+
169+
for obj in v.typing_key.link:
170+
lib.add_linking_file(obj)
171+
161172
for filepath in link:
162173
lib.add_linking_file(filepath)
163174

0 commit comments

Comments
 (0)