33from dataclasses import dataclass
44from contextvars import ContextVar
55from contextlib import contextmanager
6- from typing import Any , Tuple , Optional , TYPE_CHECKING
6+ from typing import (
7+ Any ,
8+ Callable ,
9+ List ,
10+ Tuple ,
11+ Optional ,
12+ TYPE_CHECKING ,
13+ )
714
815if TYPE_CHECKING :
9- from numba .cuda .dispatcher import CUDADispatcher
16+ from numba .cuda .dispatcher import CUDADispatcher , _Kernel
1017
1118
1219@dataclass (frozen = True , slots = True )
@@ -22,14 +29,24 @@ class LaunchConfig:
2229 blockdim : Tuple [int , int , int ]
2330 stream : Any
2431 sharedmem : int
32+ pre_launch_callbacks : List [Callable [["_Kernel" , "LaunchConfig" ], None ]]
33+ """
34+ List of functions to call before launching a kernel. The functions are
35+ called with the kernel and the launch config as arguments. This enables
36+ just-in-time modifications to the kernel's configuration prior to launch,
37+ such as appending extensions for dynamic types that were created after the
38+ @cuda.jit decorator appeared (i.e. as part of rewriting).
39+ """
2540
def __str__(self) -> str:
    """
    Return a compact one-line summary of this launch configuration.

    Kernel arguments and pre-launch callbacks are each converted with
    ``str`` and joined with ", "; grid and block dimensions are joined
    with "×". The stream and shared-memory size are interpolated as-is.
    """
    a = ", ".join(map(str, self.args))
    g = "×".join(map(str, self.griddim))
    b = "×".join(map(str, self.blockdim))
    cb = ", ".join(map(str, self.pre_launch_callbacks))
    return (
        f"<LaunchConfig args=[{a}], grid={g}, block={b}, "
        f"stream={self.stream}, smem={self.sharedmem}B, "
        f"pre_launch_callbacks=[{cb}]>"
    )
3451
3552
@@ -72,10 +89,18 @@ def launch_config_ctx(
7289 Install a LaunchConfig for the dynamic extent of the with-block.
7390 The previous value (if any) is restored automatically.
7491 """
75- token = _launch_config_var .set (
76- LaunchConfig (dispatcher , args , griddim , blockdim , stream , sharedmem )
92+ pre_launch_callbacks = []
93+ launch_config = LaunchConfig (
94+ dispatcher ,
95+ args ,
96+ griddim ,
97+ blockdim ,
98+ stream ,
99+ sharedmem ,
100+ pre_launch_callbacks ,
77101 )
102+ token = _launch_config_var .set (launch_config )
78103 try :
79- yield
104+ yield launch_config
80105 finally :
81106 _launch_config_var .reset (token )
0 commit comments