
Commit fea9f79

Add support for pre-kernel-launch callbacks to launch config.
This is required by cuda.coop in order to pass two-phase primitive instances as kernel parameters without having to call the @cuda.jit decorator with extensions=[...] up-front.
1 parent: 6a88b46

File tree

numba_cuda/numba/cuda/dispatcher.py
numba_cuda/numba/cuda/launchconfig.py

2 files changed: +35 −7 lines changed
numba_cuda/numba/cuda/dispatcher.py

Lines changed: 4 additions & 1 deletion
@@ -1024,12 +1024,15 @@ def call(self, args, griddim, blockdim, stream, sharedmem):
             blockdim=blockdim,
             stream=stream,
             sharedmem=sharedmem,
-        ):
+        ) as launch_config:
             if self.specialized:
                 kernel = next(iter(self.overloads.values()))
             else:
                 kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
 
+            for callback in launch_config.pre_launch_callbacks:
+                callback(kernel, launch_config)
+
             kernel.launch(args, griddim, blockdim, stream, sharedmem)
 
     def _compile_for_args(self, *args, **kws):
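
From a caller's perspective, a pre-launch callback is just a callable taking the resolved kernel and the active launch config. Below is a minimal sketch of one way to build such a callback; it is not part of the commit, `make_extension_callback` and `extension` are hypothetical names used only for illustration, and it assumes that _Kernel keeps its lowering extensions in a mutable `extensions` list attribute.

def make_extension_callback(extension):
    """Return a pre-launch callback that appends `extension` to the kernel.

    Hypothetical helper for illustration; `extension` stands for whatever
    lowering-extension object should be active for this launch.
    """

    def callback(kernel, launch_config):
        # Runs inside CUDADispatcher.call() after the kernel has been
        # resolved and immediately before kernel.launch(...).
        # Assumes _Kernel exposes its extensions as a mutable list.
        if extension not in kernel.extensions:
            kernel.extensions.append(extension)

    return callback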

numba_cuda/numba/cuda/launchconfig.py

Lines changed: 31 additions & 6 deletions
@@ -3,10 +3,17 @@
 from dataclasses import dataclass
 from contextvars import ContextVar
 from contextlib import contextmanager
-from typing import Any, Tuple, Optional, TYPE_CHECKING
+from typing import (
+    Any,
+    Callable,
+    List,
+    Tuple,
+    Optional,
+    TYPE_CHECKING,
+)
 
 if TYPE_CHECKING:
-    from numba.cuda.dispatcher import CUDADispatcher
+    from numba.cuda.dispatcher import CUDADispatcher, _Kernel
 
 
 @dataclass(frozen=True, slots=True)
@@ -22,14 +29,24 @@ class LaunchConfig:
     blockdim: Tuple[int, int, int]
     stream: Any
     sharedmem: int
+    pre_launch_callbacks: List[Callable[["_Kernel", "LaunchConfig"], None]]
+    """
+    List of functions to call before launching a kernel. The functions are
+    called with the kernel and the launch config as arguments. This enables
+    just-in-time modifications to the kernel's configuration prior to launch,
+    such as appending extensions for dynamic types that were created after the
+    @cuda.jit decorator appeared (i.e. as part of rewriting).
+    """
 
     def __str__(self) -> str:
         a = ", ".join(map(str, self.args))
         g = "×".join(map(str, self.griddim))
         b = "×".join(map(str, self.blockdim))
+        cb = ", ".join(map(str, self.pre_launch_callbacks))
         return (
             f"<LaunchConfig args=[{a}], grid={g}, block={b}, "
-            f"stream={self.stream}, smem={self.sharedmem}B>"
+            f"stream={self.stream}, smem={self.sharedmem}B, "
+            f"pre_launch_callbacks=[{cb}]>"
         )
 
 
@@ -72,10 +89,18 @@ def launch_config_ctx(
     Install a LaunchConfig for the dynamic extent of the with-block.
     The previous value (if any) is restored automatically.
     """
-    token = _launch_config_var.set(
-        LaunchConfig(dispatcher, args, griddim, blockdim, stream, sharedmem)
+    pre_launch_callbacks = []
+    launch_config = LaunchConfig(
+        dispatcher,
+        args,
+        griddim,
+        blockdim,
+        stream,
+        sharedmem,
+        pre_launch_callbacks,
     )
+    token = _launch_config_var.set(launch_config)
     try:
-        yield
+        yield launch_config
     finally:
         _launch_config_var.reset(token)
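
To show how the two changes fit together, here is a self-contained stand-in (plain Python, no numba or GPU required) that mirrors the mechanism: a context manager installs a config object carrying a pre_launch_callbacks list and yields it, and a dispatcher-like function invokes every callback with the resolved kernel and the config just before launching. All Fake* names are placeholders for illustration, not the real numba-cuda API.

from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from typing import Any, Callable, List, Tuple


@dataclass
class FakeKernel:
    # Stand-in for _Kernel; real kernels carry compiled code and extensions.
    extensions: List[Any] = field(default_factory=list)

    def launch(self, args):
        print(f"launch(args={args}) with extensions={self.extensions}")


@dataclass
class FakeLaunchConfig:
    # Stand-in for LaunchConfig, reduced to the fields the sketch needs.
    args: Tuple[Any, ...]
    pre_launch_callbacks: List[
        Callable[[FakeKernel, "FakeLaunchConfig"], None]
    ] = field(default_factory=list)


_cfg_var: ContextVar[FakeLaunchConfig] = ContextVar("fake_launch_config")


@contextmanager
def fake_launch_config_ctx(args):
    # Mirrors launch_config_ctx: install the config and yield it so code
    # running inside the with-block can append pre-launch callbacks.
    cfg = FakeLaunchConfig(args)
    token = _cfg_var.set(cfg)
    try:
        yield cfg
    finally:
        _cfg_var.reset(token)


def dispatch(args):
    # Mirrors CUDADispatcher.call(): resolve the kernel, run callbacks, launch.
    with fake_launch_config_ctx(args) as launch_config:
        kernel = FakeKernel()
        # Anything executing inside the with-block can register work to do
        # just before launch, e.g. attaching a late-bound extension.
        launch_config.pre_launch_callbacks.append(
            lambda k, cfg: k.extensions.append("late-bound extension")
        )
        for callback in launch_config.pre_launch_callbacks:
            callback(kernel, launch_config)
        kernel.launch(args)


dispatch((1, 2, 3))
# -> launch(args=(1, 2, 3)) with extensions=['late-bound extension']

Yielding the launch config from the context manager is the key design point: the dispatcher, and anything that runs inside the with-block (for example compilation triggered while resolving the kernel), can append callbacks without extra plumbing, and the dispatcher drains the list immediately before kernel.launch().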
