1010from numba .cuda .utils import _readenv
1111
1212
# Resolve whether NRT statistics are enabled: the NUMBA_CUDA_NRT_STATS
# environment variable is consulted first, then the config attribute of the
# same name.
NRT_STATS = _readenv("NUMBA_CUDA_NRT_STATS", bool, False)
if not NRT_STATS:
    NRT_STATS = getattr(config, "NUMBA_CUDA_NRT_STATS", False)

# Publish the resolved value on the config module.
# NOTE(review): the guard tests "NUMBA_CUDA_NRT_STATS" while the assignment
# writes `CUDA_NRT_STATS` -- confirm which attribute name is intended.
if not hasattr(config, "NUMBA_CUDA_NRT_STATS"):
    config.CUDA_NRT_STATS = NRT_STATS
1920
21+
22+ # Check environment variable or config for NRT enablement
2023ENABLE_NRT = (
2124 _readenv ("NUMBA_CUDA_ENABLE_NRT" , bool , False ) or
2225 getattr (config , "NUMBA_CUDA_ENABLE_NRT" , False )
2528 config .CUDA_ENABLE_NRT = ENABLE_NRT
2629
2730
31+ # Protect method to ensure NRT memory allocation and initialization
2832def _alloc_init_guard (method ):
33+ """
34+ Ensure NRT memory allocation and initialization before running the method
35+ """
2936 @wraps (method )
3037 def wrapper (self , * args , ** kwargs ):
3138 self .ensure_allocated ()
@@ -35,6 +42,7 @@ def wrapper(self, *args, **kwargs):
3542
3643
3744class _Runtime :
45+ """Singleton class for Numba CUDA runtime"""
3846 _instance = None
3947
4048 def __new__ (cls , * args , ** kwargs ):
@@ -43,47 +51,64 @@ def __new__(cls, *args, **kwargs):
4351 return cls ._instance
4452
def __init__(self):
    """
    Reset the runtime's bookkeeping: no memsys module, no memsys buffer,
    and not yet initialized.

    NOTE(review): ``__new__`` returns the shared instance, so this presumably
    runs again on every ``_Runtime()`` call and clears these fields -- confirm.
    """
    self._memsys_module = self._memsys = None
    self._initialized = False
5058
def _compile_memsys_module(self):
    """
    Build ``memsys.cu`` (located next to this file) into a module created in
    the context returned by ``devices.get_context()``, and store the result
    as ``self._memsys_module``.
    """
    # memsys.cu sits in the same directory as this source file.
    here = os.path.dirname(os.path.abspath(__file__))
    cu_path = os.path.join(here, "memsys.cu")

    # Link the CUDA source for the current device's compute capability.
    linker = Linker.new(cc=get_current_device().compute_capability)
    linker.add_cu_file(cu_path)
    cubin = linker.complete()

    # Turn the linker output into a module and remember it.
    self._memsys_module = devices.get_context().create_module_image(cubin)
6681
def ensure_allocated(self, stream=None):
    """
    Call ``allocate(stream)`` unless ``self._memsys`` is already set, in
    which case nothing happens.
    """
    if self._memsys is None:
        self.allocate(stream)
7291
def allocate(self, stream=None):
    """
    Allocate the device array backing memsys and pass it to
    ``set_memsys_to_module``, compiling the memsys module first if it does
    not exist yet.
    """
    from numba.cuda import device_array

    # The module is compiled lazily, the first time an allocation is needed.
    if self._memsys_module is None:
        self._compile_memsys_module()

    # Reserve the bytes for NRT_MemSys.
    # TODO: determine the size of NRT_MemSys at runtime
    memsys_nbytes = 40
    self._memsys = device_array((memsys_nbytes,), dtype="i1", stream=stream)

    # Hand the freshly allocated buffer to the module.
    self.set_memsys_to_module(self._memsys_module, stream=stream)
85107
86108 def _single_thread_launch (self , module , stream , name , params = ()):
109+ """
110+ Launch the specified kernel with only 1 thread
111+ """
87112 if stream is None :
88113 stream = cuda .default_stream ()
89114
@@ -99,33 +124,45 @@ def _single_thread_launch(self, module, stream, name, params=()):
99124 )
100125
def ensure_initialized(self, stream=None):
    """
    Call ``initialize(stream)`` unless ``self._initialized`` is already
    truthy, in which case nothing happens.
    """
    if not self._initialized:
        self.initialize(stream)
106135
def initialize(self, stream=None):
    """
    Launch the ``NRT_MemSys_init`` kernel from the memsys module and mark
    the runtime as initialized.
    """
    init_kernel = "NRT_MemSys_init"
    self._single_thread_launch(self._memsys_module, stream, init_kernel)
    self._initialized = True
113143
114- if NRT_STATS :
115- self .memsys_enable_stats (stream )
116-
@_alloc_init_guard
def memsys_enable_stats(self, stream=None):
    """
    Enable memsys statistics by launching ``NRT_MemSys_enable_stats``.
    """
    kernel = "NRT_MemSys_enable_stats"
    self._single_thread_launch(self._memsys_module, stream, kernel)
121151
@_alloc_init_guard
def memsys_disable_stats(self, stream=None):
    """
    Disable memsys statistics by launching ``NRT_MemSys_disable_stats``.
    """
    kernel = "NRT_MemSys_disable_stats"
    self._single_thread_launch(self._memsys_module, stream, kernel)
126159
127160 @_alloc_init_guard
128161 def memsys_stats_enabled (self , stream = None ):
162+ """
163+ Return a boolean indicating whether memsys is enabled. Synchronizes
164+ context
165+ """
129166 enabled_ar = cuda .managed_array (1 , np .uint8 )
130167
131168 self ._single_thread_launch (
@@ -140,8 +177,9 @@ def memsys_stats_enabled(self, stream=None):
140177
141178 @_alloc_init_guard
142179 def _copy_memsys_to_host (self , stream ):
143-
144- # Q: What stream should we execute this on?
180+ """
181+ Copy all statistics of memsys to the host
182+ """
145183 dt = np .dtype ([
146184 ('alloc' , np .uint64 ),
147185 ('free' , np .uint64 ),
@@ -163,6 +201,9 @@ def _copy_memsys_to_host(self, stream):
163201
164202 @_alloc_init_guard
165203 def get_allocation_stats (self , stream = None ):
204+ """
205+ Get the allocation statistics
206+ """
166207 enabled = self .memsys_stats_enabled (stream )
167208 if not enabled :
168209 raise RuntimeError ("NRT stats are disabled." )
@@ -176,6 +217,9 @@ def get_allocation_stats(self, stream=None):
176217
177218 @_alloc_init_guard
178219 def _get_single_stat (self , stat , stream = None ):
220+ """
221+ Get a single stat from the memsys
222+ """
179223 got = cuda .managed_array (1 , np .uint64 )
180224 self ._single_thread_launch (
181225 self ._memsys_module ,
@@ -189,6 +233,9 @@ def _get_single_stat(self, stat, stream=None):
189233
190234 @_alloc_init_guard
191235 def memsys_get_stats_alloc (self , stream = None ):
236+ """
237+ Get the allocation statistic
238+ """
192239 enabled = self .memsys_stats_enabled (stream )
193240 if not enabled :
194241 raise RuntimeError ("NRT stats are disabled." )
@@ -197,6 +244,9 @@ def memsys_get_stats_alloc(self, stream=None):
197244
198245 @_alloc_init_guard
199246 def memsys_get_stats_free (self , stream = None ):
247+ """
248+ Get the free statistic
249+ """
200250 enabled = self .memsys_stats_enabled (stream )
201251 if not enabled :
202252 raise RuntimeError ("NRT stats are disabled." )
@@ -205,6 +255,9 @@ def memsys_get_stats_free(self, stream=None):
205255
206256 @_alloc_init_guard
207257 def memsys_get_stats_mi_alloc (self , stream = None ):
258+ """
259+ Get the mi alloc statistic
260+ """
208261 enabled = self .memsys_stats_enabled (stream )
209262 if not enabled :
210263 raise RuntimeError ("NRT stats are disabled." )
@@ -213,18 +266,24 @@ def memsys_get_stats_mi_alloc(self, stream=None):
213266
@_alloc_init_guard
def memsys_get_stats_mi_free(self, stream=None):
    """
    Return the ``mi_free`` statistic read from memsys.

    Raises ``RuntimeError`` when NRT statistics are disabled.
    """
    if not self.memsys_stats_enabled(stream):
        raise RuntimeError("NRT stats are disabled.")

    # NOTE(review): ``stream`` is not forwarded here, so the read uses
    # ``_get_single_stat``'s default ``stream=None`` -- confirm this is intended.
    return self._get_single_stat("mi_free")
221277
222278 def set_memsys_to_module (self , module , stream = None ):
279+ """
280+ Set the memsys module. The module must contain `NRT_MemSys_set` kernel,
281+ and declare a pointer to NRT_MemSys structure.
282+ """
223283 if self ._memsys is None :
224284 raise RuntimeError (
225285 "Please allocate NRT Memsys first before initializing." )
226286
227- print (f"Setting { self ._memsys .device_ctypes_pointer } to { module } " )
228287 self ._single_thread_launch (
229288 module ,
230289 stream ,
@@ -234,7 +293,9 @@ def set_memsys_to_module(self, module, stream=None):
234293
235294 @_alloc_init_guard
236295 def print_memsys (self , stream = None ):
237- """Print the current statistics of memsys, for debugging purpose."""
296+ """
297+ Print the current statistics of memsys, for debugging purposes
298+ """
238299 cuda .synchronize ()
239300 self ._single_thread_launch (
240301 self ._memsys_module ,
@@ -243,4 +304,5 @@ def print_memsys(self, stream=None):
243304 )
244305
245306
# Module-level instance of the runtime, created when this module is imported
rtsys = _Runtime()
0 commit comments