From b229ca7d5db0a5c0bed795e58d3c8dc062d72ee2 Mon Sep 17 00:00:00 2001
From: Trent Nelson
Date: Tue, 10 Jun 2025 15:43:07 -0700
Subject: [PATCH 1/4] Implement a thread-local means to access kernel launch config.

This allows downstream passes, such as rewriting, to access information
about the kernel launch in which they have been enlisted to participate.
---
 numba_cuda/numba/cuda/dispatcher.py   | 15 +++++--
 numba_cuda/numba/cuda/launchconfig.py | 62 +++++++++++++++++++++++++++
 2 files changed, 73 insertions(+), 4 deletions(-)
 create mode 100644 numba_cuda/numba/cuda/launchconfig.py

diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py
index d129495a4..ace9d7222 100644
--- a/numba_cuda/numba/cuda/dispatcher.py
+++ b/numba_cuda/numba/cuda/dispatcher.py
@@ -41,6 +41,7 @@
     missing_launch_config_msg,
     normalize_kernel_dimensions,
 )
+from numba.cuda.launchconfig import launch_config_ctx
 from numba.cuda.typing.templates import fold_arguments
 from numba.cuda.cudadrv.linkable_code import LinkableCode
 from numba.cuda.cudadrv.devices import get_context
@@ -1016,10 +1017,16 @@ def call(self, args, griddim, blockdim, stream, sharedmem):
         """
         Compile if necessary and invoke this kernel with *args*.
         """
-        if self.specialized:
-            kernel = next(iter(self.overloads.values()))
-        else:
-            kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
+        with launch_config_ctx(
+            griddim=griddim,
+            blockdim=blockdim,
+            stream=stream,
+            sharedmem=sharedmem,
+        ):
+            if self.specialized:
+                kernel = next(iter(self.overloads.values()))
+            else:
+                kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
 
         kernel.launch(args, griddim, blockdim, stream, sharedmem)
 
diff --git a/numba_cuda/numba/cuda/launchconfig.py b/numba_cuda/numba/cuda/launchconfig.py
new file mode 100644
index 000000000..568277edb
--- /dev/null
+++ b/numba_cuda/numba/cuda/launchconfig.py
@@ -0,0 +1,62 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from contextvars import ContextVar
+from contextlib import contextmanager
+from typing import Any, Tuple, Optional
+
+
+@dataclass(frozen=True, slots=True)
+class LaunchConfig:
+    """
+    Helper class used to encapsulate kernel launch configuration for storing
+    and retrieving from a thread-local ContextVar.
+    """
+
+    griddim: Tuple[int, int, int]
+    blockdim: Tuple[int, int, int]
+    stream: Any
+    sharedmem: int
+
+    def __str__(self) -> str:
+        g = "×".join(map(str, self.griddim))
+        b = "×".join(map(str, self.blockdim))
+        return (
+            f"<LaunchConfig grid={g}, block={b}, "
+            f"stream={self.stream}, smem={self.sharedmem}B>"
+        )
+
+
+_launch_config_var: ContextVar[Optional[LaunchConfig]] = ContextVar(
+    "_launch_config_var",
+    default=None,
+)
+
+
+def current_launch_config() -> Optional[LaunchConfig]:
+    """
+    Read the launch config visible in *this* thread/asyncio task.
+    Returns None if no launch config is set.
+    """
+    return _launch_config_var.get()
+
+
+@contextmanager
+def launch_config_ctx(
+    *,
+    griddim: Tuple[int, int, int],
+    blockdim: Tuple[int, int, int],
+    stream: Any,
+    sharedmem: int,
+):
+    """
+    Install a LaunchConfig for the dynamic extent of the with-block.
+    The previous value (if any) is restored automatically.
+    """
+    token = _launch_config_var.set(
+        LaunchConfig(griddim, blockdim, stream, sharedmem)
+    )
+    try:
+        yield
+    finally:
+        _launch_config_var.reset(token)
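
A minimal usage sketch (not part of the patch; the helper function below is
hypothetical): once patch 1 is applied, code that runs on the dispatching
thread during a kernel launch, such as a rewriting pass, can read the active
configuration via current_launch_config():

    from numba.cuda.launchconfig import current_launch_config

    def describe_active_launch():
        # Hypothetical helper: report the launch configuration, if any,
        # visible on this thread/asyncio task during a kernel dispatch.
        cfg = current_launch_config()
        if cfg is None:
            return "no kernel launch in progress"
        return f"grid={cfg.griddim}, block={cfg.blockdim}, smem={cfg.sharedmem}"

Outside a dispatch this returns the "no kernel launch in progress" string;
inside CUDADispatcher.call() it sees the dimensions passed to
kernel[griddim, blockdim](...).
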
From 2faf9cd1fe222d780ce133385382f06db140deb5 Mon Sep 17 00:00:00 2001
From: Trent Nelson
Date: Tue, 10 Jun 2025 23:56:46 -0700
Subject: [PATCH 2/4] Add args and dispatcher to LaunchConfig.

---
 numba_cuda/numba/cuda/dispatcher.py   |  4 +++-
 numba_cuda/numba/cuda/launchconfig.py | 14 +++++++++++---
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py
index ace9d7222..b68c1078a 100644
--- a/numba_cuda/numba/cuda/dispatcher.py
+++ b/numba_cuda/numba/cuda/dispatcher.py
@@ -1018,6 +1018,8 @@ def call(self, args, griddim, blockdim, stream, sharedmem):
         Compile if necessary and invoke this kernel with *args*.
         """
         with launch_config_ctx(
+            dispatcher=self,
+            args=args,
             griddim=griddim,
             blockdim=blockdim,
             stream=stream,
@@ -1028,7 +1030,7 @@ def call(self, args, griddim, blockdim, stream, sharedmem):
             else:
                 kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
 
-        kernel.launch(args, griddim, blockdim, stream, sharedmem)
+        kernel.launch(args, griddim, blockdim, stream, sharedmem)
 
     def _compile_for_args(self, *args, **kws):
         # Based on _DispatcherBase._compile_for_args.
diff --git a/numba_cuda/numba/cuda/launchconfig.py b/numba_cuda/numba/cuda/launchconfig.py
index 568277edb..365f9f234 100644
--- a/numba_cuda/numba/cuda/launchconfig.py
+++ b/numba_cuda/numba/cuda/launchconfig.py
@@ -3,7 +3,10 @@
 from dataclasses import dataclass
 from contextvars import ContextVar
 from contextlib import contextmanager
-from typing import Any, Tuple, Optional
+from typing import Any, Tuple, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from numba.cuda.dispatcher import CUDADispatcher
 
 
 @dataclass(frozen=True, slots=True)
@@ -13,16 +16,19 @@ class LaunchConfig:
     and retrieving from a thread-local ContextVar.
     """
 
+    dispatcher: "CUDADispatcher"
+    args: Tuple[Any, ...]
     griddim: Tuple[int, int, int]
     blockdim: Tuple[int, int, int]
     stream: Any
     sharedmem: int
 
     def __str__(self) -> str:
+        a = ", ".join(map(str, self.args))
         g = "×".join(map(str, self.griddim))
         b = "×".join(map(str, self.blockdim))
         return (
-            f"<LaunchConfig grid={g}, block={b}, "
+            f"<LaunchConfig args={a}, grid={g}, block={b}, "
             f"stream={self.stream}, smem={self.sharedmem}B>"
         )
 
@@ -44,6 +50,8 @@ def current_launch_config() -> Optional[LaunchConfig]:
 @contextmanager
 def launch_config_ctx(
     *,
+    dispatcher: "CUDADispatcher",
+    args: Tuple[Any, ...],
    griddim: Tuple[int, int, int],
     blockdim: Tuple[int, int, int],
     stream: Any,
@@ -54,7 +62,7 @@ def launch_config_ctx(
     The previous value (if any) is restored automatically.
     """
     token = _launch_config_var.set(
-        LaunchConfig(griddim, blockdim, stream, sharedmem)
+        LaunchConfig(dispatcher, args, griddim, blockdim, stream, sharedmem)
     )
     try:
         yield

From 6a88b46655aed206e55beadf261c93e128082adf Mon Sep 17 00:00:00 2001
From: Trent Nelson
Date: Sun, 13 Jul 2025 19:21:43 -0700
Subject: [PATCH 3/4] Implement ensure_current_launch_config().

This routine raises an error if no launch config is set. It is expected
to become the preferred way of obtaining the current launch config.
---
 numba_cuda/numba/cuda/launchconfig.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/numba_cuda/numba/cuda/launchconfig.py b/numba_cuda/numba/cuda/launchconfig.py
index 365f9f234..0497fc951 100644
--- a/numba_cuda/numba/cuda/launchconfig.py
+++ b/numba_cuda/numba/cuda/launchconfig.py
@@ -47,6 +47,17 @@ def current_launch_config() -> Optional[LaunchConfig]:
     return _launch_config_var.get()
 
 
+def ensure_current_launch_config() -> LaunchConfig:
+    """
+    Ensure that a launch config is set for *this* thread/asyncio task.
+    Returns the launch config. Raises RuntimeError if no launch config is set.
+    """
+    launch_config = current_launch_config()
+    if launch_config is None:
+        raise RuntimeError("No launch config set for this thread/asyncio task")
+    return launch_config
+
+
 @contextmanager
 def launch_config_ctx(
     *,
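
A minimal sketch of the distinction between the two accessors (not part of
the patches; the helper below is hypothetical): ensure_current_launch_config()
is the fail-fast variant for code that only makes sense while a launch is
being dispatched, whereas current_launch_config() silently returns None in
that case:

    from numba.cuda.launchconfig import ensure_current_launch_config

    def block_volume():
        # Hypothetical helper: raises RuntimeError when called outside
        # kernel[griddim, blockdim](...) dispatch, instead of failing later
        # with an AttributeError on None.
        cfg = ensure_current_launch_config()
        bx, by, bz = cfg.blockdim
        return bx * by * bz
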
From fea9f79f93020e60986517a001ca2895076af3e9 Mon Sep 17 00:00:00 2001
From: Trent Nelson
Date: Sun, 13 Jul 2025 19:32:48 -0700
Subject: [PATCH 4/4] Add support for pre-kernel-launch callbacks to launch config.

This is required by cuda.coop in order to pass two-phase primitive
instances as kernel parameters without having to call the @cuda.jit
decorator with extensions=[...] up-front.
---
 numba_cuda/numba/cuda/dispatcher.py   |  5 +++-
 numba_cuda/numba/cuda/launchconfig.py | 37 ++++++++++++++++++++++-----
 2 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/numba_cuda/numba/cuda/dispatcher.py b/numba_cuda/numba/cuda/dispatcher.py
index b68c1078a..1d9a2ff6e 100644
--- a/numba_cuda/numba/cuda/dispatcher.py
+++ b/numba_cuda/numba/cuda/dispatcher.py
@@ -1024,12 +1024,15 @@ def call(self, args, griddim, blockdim, stream, sharedmem):
             blockdim=blockdim,
             stream=stream,
             sharedmem=sharedmem,
-        ):
+        ) as launch_config:
             if self.specialized:
                 kernel = next(iter(self.overloads.values()))
             else:
                 kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
 
+            for callback in launch_config.pre_launch_callbacks:
+                callback(kernel, launch_config)
+
         kernel.launch(args, griddim, blockdim, stream, sharedmem)
 
     def _compile_for_args(self, *args, **kws):
diff --git a/numba_cuda/numba/cuda/launchconfig.py b/numba_cuda/numba/cuda/launchconfig.py
index 0497fc951..8ef355f1d 100644
--- a/numba_cuda/numba/cuda/launchconfig.py
+++ b/numba_cuda/numba/cuda/launchconfig.py
@@ -3,10 +3,17 @@
 from dataclasses import dataclass
 from contextvars import ContextVar
 from contextlib import contextmanager
-from typing import Any, Tuple, Optional, TYPE_CHECKING
+from typing import (
+    Any,
+    Callable,
+    List,
+    Tuple,
+    Optional,
+    TYPE_CHECKING,
+)
 
 if TYPE_CHECKING:
-    from numba.cuda.dispatcher import CUDADispatcher
+    from numba.cuda.dispatcher import CUDADispatcher, _Kernel
 
 
 @dataclass(frozen=True, slots=True)
@@ -22,14 +29,24 @@ class LaunchConfig:
     blockdim: Tuple[int, int, int]
     stream: Any
     sharedmem: int
+    pre_launch_callbacks: List[Callable[["_Kernel", "LaunchConfig"], None]]
+    """
+    List of functions to call before launching a kernel. The functions are
+    called with the kernel and the launch config as arguments. This enables
+    just-in-time modifications to the kernel's configuration prior to launch,
+    such as appending extensions for dynamic types that were created after the
+    @cuda.jit decorator appeared (i.e. as part of rewriting).
+    """
 
     def __str__(self) -> str:
         a = ", ".join(map(str, self.args))
         g = "×".join(map(str, self.griddim))
         b = "×".join(map(str, self.blockdim))
+        cb = ", ".join(map(str, self.pre_launch_callbacks))
         return (
             f"<LaunchConfig args={a}, grid={g}, block={b}, "
-            f"stream={self.stream}, smem={self.sharedmem}B>"
+            f"stream={self.stream}, smem={self.sharedmem}B, "
+            f"pre_launch_callbacks=[{cb}]>"
         )
 
 
@@ -72,10 +89,18 @@ def launch_config_ctx(
     Install a LaunchConfig for the dynamic extent of the with-block.
     The previous value (if any) is restored automatically.
     """
-    token = _launch_config_var.set(
-        LaunchConfig(dispatcher, args, griddim, blockdim, stream, sharedmem)
+    pre_launch_callbacks = []
+    launch_config = LaunchConfig(
+        dispatcher,
+        args,
+        griddim,
+        blockdim,
+        stream,
+        sharedmem,
+        pre_launch_callbacks,
     )
+    token = _launch_config_var.set(launch_config)
     try:
-        yield
+        yield launch_config
     finally:
         _launch_config_var.reset(token)
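
A minimal sketch of how a consumer such as cuda.coop might use the pre-launch
hook (not part of the patches; the registration helper, the extension object,
and the use of kernel.extensions are assumptions, shown only as one plausible
use of the hook):

    from numba.cuda.launchconfig import ensure_current_launch_config

    def register_pre_launch_extension(extension):
        # Hypothetical helper, e.g. invoked from a rewriting pass that runs
        # while CUDADispatcher.call() has the launch config installed.
        cfg = ensure_current_launch_config()

        def add_extension(kernel, launch_config):
            # Runs after compilation, immediately before kernel.launch().
            # Appending to kernel.extensions is an assumed use; the patches
            # themselves only define the callback hook.
            if extension not in kernel.extensions:
                kernel.extensions.append(extension)

        cfg.pre_launch_callbacks.append(add_extension)
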