Skip to content

Commit bb1cf0f

Browse files
committed
comments and clean up
1 parent 5273e4a commit bb1cf0f

File tree

1 file changed

+75
-13
lines changed
  • numba_cuda/numba/cuda/runtime

1 file changed

+75
-13
lines changed

numba_cuda/numba/cuda/runtime/nrt.py

Lines changed: 75 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,16 @@
1010
from numba.cuda.utils import _readenv
1111

1212

13+
# Check environment variable or config for NRT statistics enablement
1314
NRT_STATS = (
1415
_readenv("NUMBA_CUDA_NRT_STATS", bool, False) or
1516
getattr(config, "NUMBA_CUDA_NRT_STATS", False)
1617
)
1718
if not hasattr(config, "NUMBA_CUDA_NRT_STATS"):
1819
config.CUDA_NRT_STATS = NRT_STATS
1920

21+
22+
# Check environment variable or config for NRT enablement
2023
ENABLE_NRT = (
2124
_readenv("NUMBA_CUDA_ENABLE_NRT", bool, False) or
2225
getattr(config, "NUMBA_CUDA_ENABLE_NRT", False)
@@ -25,7 +28,11 @@
2528
config.CUDA_ENABLE_NRT = ENABLE_NRT
2629

2730

31+
# Decorator that ensures NRT memory is allocated and initialized before the
# wrapped method runs
2832
def _alloc_init_guard(method):
33+
"""
34+
Ensure NRT memory allocation and initialization before running the method
35+
"""
2936
@wraps(method)
3037
def wrapper(self, *args, **kwargs):
3138
self.ensure_allocated()
@@ -35,6 +42,7 @@ def wrapper(self, *args, **kwargs):
3542

3643

3744
class _Runtime:
45+
"""Singleton class for Numba CUDA runtime"""
3846
_instance = None
3947

4048
def __new__(cls, *args, **kwargs):
@@ -43,47 +51,64 @@ def __new__(cls, *args, **kwargs):
4351
return cls._instance
4452

4553
def __init__(self):
54+
"""Initialize memsys module and variable"""
4655
self._memsys_module = None
4756
self._memsys = None
48-
4957
self._initialized = False
5058

5159
def _compile_memsys_module(self):
60+
"""
61+
Compile memsys.cu and create a module from it in the current context
62+
"""
63+
# Define the path for memsys.cu
5264
memsys_mod = os.path.join(
5365
os.path.dirname(os.path.abspath(__file__)),
5466
"memsys.cu"
5567
)
5668
cc = get_current_device().compute_capability
5769

70+
# Create a new linker instance and add the cu file
5871
linker = Linker.new(cc=cc)
5972
linker.add_cu_file(memsys_mod)
60-
cubin = linker.complete()
6173

74+
# Complete the linker and create a module from it
75+
cubin = linker.complete()
6276
ctx = devices.get_context()
6377
module = ctx.create_module_image(cubin)
6478

79+
# Set the memsys module
6580
self._memsys_module = module
6681

6782
def ensure_allocated(self, stream=None):
83+
"""
84+
If memsys is not allocated, allocate it; otherwise, perform a no-op
85+
"""
6886
if self._memsys is not None:
6987
return
7088

89+
# Allocate the memsys
7190
self.allocate(stream)
7291

7392
def allocate(self, stream=None):
93+
"""
94+
Allocate memsys on global memory
95+
"""
7496
from numba.cuda import device_array
7597

98+
# Check if memsys module is defined
7699
if self._memsys_module is None:
100+
# Compile the memsys module if not defined
77101
self._compile_memsys_module()
78102

79103
# Allocate space for NRT_MemSys
80104
# TODO: determine the size of NRT_MemSys at runtime
81105
self._memsys = device_array((40,), dtype="i1", stream=stream)
82-
# TODO: Memsys module needs a stream that's consistent with the
83-
# system's stream.
84106
self.set_memsys_to_module(self._memsys_module, stream=stream)
85107

86108
def _single_thread_launch(self, module, stream, name, params=()):
109+
"""
110+
Launch the specified kernel with only 1 thread
111+
"""
87112
if stream is None:
88113
stream = cuda.default_stream()
89114

@@ -99,33 +124,45 @@ def _single_thread_launch(self, module, stream, name, params=()):
99124
)
100125

101126
def ensure_initialized(self, stream=None):
127+
"""
128+
If memsys is not initialized, initialize memsys
129+
"""
102130
if self._initialized:
103131
return
104132

133+
# Initialize the memsys
105134
self.initialize(stream)
106135

107136
def initialize(self, stream=None):
108-
self.ensure_allocated(stream)
109-
137+
"""
138+
Launch memsys initialization kernel
139+
"""
110140
self._single_thread_launch(
111141
self._memsys_module, stream, "NRT_MemSys_init")
112142
self._initialized = True
113143

114-
if NRT_STATS:
115-
self.memsys_enable_stats(stream)
116-
117144
@_alloc_init_guard
118145
def memsys_enable_stats(self, stream=None):
146+
"""
147+
Enable memsys statistics
148+
"""
119149
self._single_thread_launch(
120150
self._memsys_module, stream, "NRT_MemSys_enable_stats")
121151

122152
@_alloc_init_guard
123153
def memsys_disable_stats(self, stream=None):
154+
"""
155+
Disable memsys statistics
156+
"""
124157
self._single_thread_launch(
125158
self._memsys_module, stream, "NRT_MemSys_disable_stats")
126159

127160
@_alloc_init_guard
128161
def memsys_stats_enabled(self, stream=None):
162+
"""
163+
Return a boolean indicating whether memsys statistics are enabled.
164+
Synchronizes the context
165+
"""
129166
enabled_ar = cuda.managed_array(1, np.uint8)
130167

131168
self._single_thread_launch(
@@ -140,8 +177,9 @@ def memsys_stats_enabled(self, stream=None):
140177

141178
@_alloc_init_guard
142179
def _copy_memsys_to_host(self, stream):
143-
144-
# Q: What stream should we execute this on?
180+
"""
181+
Copy all statistics of memsys to the host
182+
"""
145183
dt = np.dtype([
146184
('alloc', np.uint64),
147185
('free', np.uint64),
@@ -163,6 +201,9 @@ def _copy_memsys_to_host(self, stream):
163201

164202
@_alloc_init_guard
165203
def get_allocation_stats(self, stream=None):
204+
"""
205+
Get the allocation statistics
206+
"""
166207
enabled = self.memsys_stats_enabled(stream)
167208
if not enabled:
168209
raise RuntimeError("NRT stats are disabled.")
@@ -176,6 +217,9 @@ def get_allocation_stats(self, stream=None):
176217

177218
@_alloc_init_guard
178219
def _get_single_stat(self, stat, stream=None):
220+
"""
221+
Get a single stat from the memsys
222+
"""
179223
got = cuda.managed_array(1, np.uint64)
180224
self._single_thread_launch(
181225
self._memsys_module,
@@ -189,6 +233,9 @@ def _get_single_stat(self, stat, stream=None):
189233

190234
@_alloc_init_guard
191235
def memsys_get_stats_alloc(self, stream=None):
236+
"""
237+
Get the allocation statistic
238+
"""
192239
enabled = self.memsys_stats_enabled(stream)
193240
if not enabled:
194241
raise RuntimeError("NRT stats are disabled.")
@@ -197,6 +244,9 @@ def memsys_get_stats_alloc(self, stream=None):
197244

198245
@_alloc_init_guard
199246
def memsys_get_stats_free(self, stream=None):
247+
"""
248+
Get the free statistic
249+
"""
200250
enabled = self.memsys_stats_enabled(stream)
201251
if not enabled:
202252
raise RuntimeError("NRT stats are disabled.")
@@ -205,6 +255,9 @@ def memsys_get_stats_free(self, stream=None):
205255

206256
@_alloc_init_guard
207257
def memsys_get_stats_mi_alloc(self, stream=None):
258+
"""
259+
Get the mi alloc statistic
260+
"""
208261
enabled = self.memsys_stats_enabled(stream)
209262
if not enabled:
210263
raise RuntimeError("NRT stats are disabled.")
@@ -213,18 +266,24 @@ def memsys_get_stats_mi_alloc(self, stream=None):
213266

214267
@_alloc_init_guard
215268
def memsys_get_stats_mi_free(self, stream=None):
269+
"""
270+
Get the `mi_free` statistic; raises RuntimeError if stats are disabled
271+
"""
216272
enabled = self.memsys_stats_enabled(stream)
217273
if not enabled:
218274
raise RuntimeError("NRT stats are disabled.")
219275

220276
return self._get_single_stat("mi_free")
221277

222278
def set_memsys_to_module(self, module, stream=None):
279+
"""
280+
Set the memsys on `module`, which must contain the `NRT_MemSys_set`
281+
kernel and declare a pointer to the NRT_MemSys structure.
282+
"""
223283
if self._memsys is None:
224284
raise RuntimeError(
225285
"Please allocate NRT Memsys first before initializing.")
226286

227-
print(f"Setting {self._memsys.device_ctypes_pointer} to {module}")
228287
self._single_thread_launch(
229288
module,
230289
stream,
@@ -234,7 +293,9 @@ def set_memsys_to_module(self, module, stream=None):
234293

235294
@_alloc_init_guard
236295
def print_memsys(self, stream=None):
237-
"""Print the current statistics of memsys, for debugging purpose."""
296+
"""
297+
Print the current statistics of memsys, for debugging purposes
298+
"""
238299
cuda.synchronize()
239300
self._single_thread_launch(
240301
self._memsys_module,
@@ -243,4 +304,5 @@ def print_memsys(self, stream=None):
243304
)
244305

245306

307+
# Create an instance of the runtime
246308
rtsys = _Runtime()

0 commit comments

Comments
 (0)