22 changes: 17 additions & 5 deletions numba_cuda/numba/cuda/dispatcher.py
@@ -41,6 +41,7 @@
    missing_launch_config_msg,
    normalize_kernel_dimensions,
)
+from numba.cuda.launchconfig import launch_config_ctx
from numba.cuda.typing.templates import fold_arguments
from numba.cuda.cudadrv.linkable_code import LinkableCode
from numba.cuda.cudadrv.devices import get_context
@@ -1016,12 +1017,23 @@ def call(self, args, griddim, blockdim, stream, sharedmem):
"""
Compile if necessary and invoke this kernel with *args*.
"""
if self.specialized:
kernel = next(iter(self.overloads.values()))
else:
kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
with launch_config_ctx(
dispatcher=self,
args=args,
griddim=griddim,
blockdim=blockdim,
stream=stream,
sharedmem=sharedmem,
) as launch_config:
if self.specialized:
Reviewer comment (Contributor): Specialized kernels cannot be recompiled, so a new launch configuration would not be able to affect the compilation of a new version; this check could therefore be kept outside the context manager. (A sketch of this suggestion follows the hunk below.)

+                kernel = next(iter(self.overloads.values()))
+            else:
+                kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
+
+            for callback in launch_config.pre_launch_callbacks:
+                callback(kernel, launch_config)

-        kernel.launch(args, griddim, blockdim, stream, sharedmem)
+            kernel.launch(args, griddim, blockdim, stream, sharedmem)

    def _compile_for_args(self, *args, **kws):
        # Based on _DispatcherBase._compile_for_args.
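To illustrate the reviewer's suggestion above (a sketch only, not part of the PR): since specialized kernels are never recompiled, the specialization lookup could sit outside launch_config_ctx, while the compiling branch stays inside so that compilation can still see the launch configuration. Assuming the same surrounding call method as in the hunk:

    def call(self, args, griddim, blockdim, stream, sharedmem):
        """
        Compile if necessary and invoke this kernel with *args*.
        """
        kernel = None
        if self.specialized:
            # Specialized kernels cannot be recompiled, so the lookup does
            # not need to happen under the launch-config context.
            kernel = next(iter(self.overloads.values()))

        with launch_config_ctx(
            dispatcher=self,
            args=args,
            griddim=griddim,
            blockdim=blockdim,
            stream=stream,
            sharedmem=sharedmem,
        ) as launch_config:
            if kernel is None:
                # Compilation may consult the launch config, so it stays here.
                kernel = _dispatcher.Dispatcher._cuda_call(self, *args)

            for callback in launch_config.pre_launch_callbacks:
                callback(kernel, launch_config)

            kernel.launch(args, griddim, blockdim, stream, sharedmem)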
106 changes: 106 additions & 0 deletions numba_cuda/numba/cuda/launchconfig.py
@@ -0,0 +1,106 @@
from __future__ import annotations

from dataclasses import dataclass
from contextvars import ContextVar
from contextlib import contextmanager
from typing import (
    Any,
    Callable,
    List,
    Tuple,
    Optional,
    TYPE_CHECKING,
)

if TYPE_CHECKING:
    from numba.cuda.dispatcher import CUDADispatcher, _Kernel


@dataclass(frozen=True, slots=True)
class LaunchConfig:
Reviewer comment (Contributor): There seems to be quite some overlap between this class and dispatcher._LaunchConfiguration (an observation at this point; I don't know whether it makes sense to combine them).

"""
Helper class used to encapsulate kernel launch configuration for storing
and retrieving from a thread-local ContextVar.
"""

dispatcher: "CUDADispatcher"
args: Tuple[Any, ...]
griddim: Tuple[int, int, int]
blockdim: Tuple[int, int, int]
stream: Any
sharedmem: int
pre_launch_callbacks: List[Callable[["_Kernel", "LaunchConfig"], None]]
"""
List of functions to call before launching a kernel. The functions are
called with the kernel and the launch config as arguments. This enables
just-in-time modifications to the kernel's configuration prior to launch,
such as appending extensions for dynamic types that were created after the
@cuda.jit decorator appeared (i.e. as part of rewriting).
"""

    def __str__(self) -> str:
        a = ", ".join(map(str, self.args))
        g = "×".join(map(str, self.griddim))
        b = "×".join(map(str, self.blockdim))
        cb = ", ".join(map(str, self.pre_launch_callbacks))
        return (
            f"<LaunchConfig args=[{a}], grid={g}, block={b}, "
            f"stream={self.stream}, smem={self.sharedmem}B, "
            f"pre_launch_callbacks=[{cb}]>"
        )


_launch_config_var: ContextVar[Optional[LaunchConfig]] = ContextVar(
"_launch_config_var",
default=None,
)


def current_launch_config() -> Optional[LaunchConfig]:
"""
Read the launch config visible in *this* thread/asyncio task.
Returns None if no launch config is set.
"""
return _launch_config_var.get()


def ensure_current_launch_config() -> LaunchConfig:
"""
Ensure that a launch config is set for *this* thread/asyncio task.
Returns the launch config. Raises RuntimeError if no launch config is set.
"""
launch_config = current_launch_config()
if launch_config is None:
raise RuntimeError("No launch config set for this thread/asyncio task")
return launch_config


@contextmanager
def launch_config_ctx(
    *,
    dispatcher: "CUDADispatcher",
    args: Tuple[Any, ...],
    griddim: Tuple[int, int, int],
    blockdim: Tuple[int, int, int],
    stream: Any,
    sharedmem: int,
):
"""
Install a LaunchConfig for the dynamic extent of the with-block.
The previous value (if any) is restored automatically.
"""
pre_launch_callbacks = []
launch_config = LaunchConfig(
dispatcher,
args,
griddim,
blockdim,
stream,
sharedmem,
pre_launch_callbacks,
)
token = _launch_config_var.set(launch_config)
try:
yield launch_config
finally:
_launch_config_var.reset(token)
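
For context, a hedged sketch of how client code might consume this module during the dynamic extent of launch_config_ctx (i.e. while CUDADispatcher.call is compiling and launching a kernel). The helpers maybe_register_pre_launch_hook and require_launch_dimensions are hypothetical names used only for illustration; just current_launch_config, ensure_current_launch_config, and LaunchConfig.pre_launch_callbacks come from the file above.

    from numba.cuda.launchconfig import (
        current_launch_config,
        ensure_current_launch_config,
    )


    def maybe_register_pre_launch_hook():
        # Hypothetical helper: if a kernel launch is in flight on this
        # thread/asyncio task, arrange for extra work to run just before
        # the kernel is launched.
        config = current_launch_config()
        if config is None:
            return  # Not inside launch_config_ctx; nothing to do.

        def _hook(kernel, launch_config):
            # Invoked by CUDADispatcher.call with the compiled kernel and
            # the active LaunchConfig, immediately before kernel.launch().
            print(f"about to launch with grid={launch_config.griddim}")

        config.pre_launch_callbacks.append(_hook)


    def require_launch_dimensions():
        # Hypothetical helper for code that must only run during a launch;
        # ensure_current_launch_config raises RuntimeError otherwise.
        config = ensure_current_launch_config()
        return config.griddim, config.blockdim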