Skip to content

Commit d05f76b

Browse files
committed
add kernel finalizer
1 parent 51dd293 commit d05f76b

File tree

3 files changed

+28
-10
lines changed

3 files changed

+28
-10
lines changed

numba_cuda/numba/cuda/cudadrv/managed_module.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
import weakref
22

3-
from numba import config
43
from . import devices
5-
from .driver import CtypesModule
4+
from .driver import CtypesModule, USE_NV_BINDING
65

7-
USE_NV_BINDING = config.CUDA_USE_NVIDIA_BINDING
86

9-
10-
class CuFuncProxy:
7+
class _CuFuncProxy:
118
def __init__(self, module, cufunc):
129
self._module = module
1310
self._cufunc = cufunc
@@ -73,7 +70,7 @@ def lazy_callback(callbacks, module, stream):
7370

7471
def get_function(self, name):
7572
ctypesfunc = self._module.get_function(name)
76-
return CuFuncProxy(self, ctypesfunc)
73+
return _CuFuncProxy(self, ctypesfunc)
7774

7875
def __getattr__(self, name):
7976
if name == "get_function":

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import weakref
23
import os
34
import re
45
import sys
@@ -247,6 +248,7 @@ def _rebuild(cls, cooperative, name, signature, codelibrary,
247248
instance.lineinfo = lineinfo
248249
instance.call_helper = call_helper
249250
instance.extensions = extensions
251+
instance.initialized = False
250252
return instance
251253

252254
def _reduce_states(self):
@@ -1023,6 +1025,7 @@ def compile(self, sig):
10231025
raise RuntimeError("Compilation disabled")
10241026

10251027
kernel = _Kernel(self.py_func, argtypes, **self.targetoptions)
1028+
weakref.finalize(kernel, _kernel_finalize_callback, kernel)
10261029
# We call bind to force codegen, so that there is a cubin to cache
10271030
kernel.bind()
10281031
self._cache.save_overload(sig, kernel)
@@ -1148,3 +1151,18 @@ def _reduce_states(self):
11481151
"""
11491152
return dict(py_func=self.py_func,
11501153
targetoptions=self.targetoptions)
1154+
1155+
1156+
def _kernel_finalize_callback(kernel):
1157+
module = kernel.library.get_cufunc().module
1158+
try:
1159+
if driver.USE_NV_BINDING:
1160+
key = module.handle
1161+
else:
1162+
key = module.handle.value
1163+
except ReferenceError:
1164+
return
1165+
1166+
ctx = cuda.current_context()
1167+
if key in ctx.modules:
1168+
del ctx.modules[key]

numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import gc
2+
13
import numpy as np
24

35
from numba import cuda
@@ -31,10 +33,11 @@ def kernel():
3133
self.assertEqual(counter, 1)
3234
kernel[1, 1]() # cached
3335
self.assertEqual(counter, 1)
34-
# del kernel
35-
# gc.collect()
36-
# cuda.current_context().deallocations.clear()
37-
# self.assertEqual(counter, 0)
36+
breakpoint()
37+
del kernel
38+
gc.collect()
39+
cuda.current_context().deallocations.clear()
40+
self.assertEqual(counter, 0)
3841
# We don't have a way to explicitly evict kernel and its modules at
3942
# the moment.
4043

0 commit comments

Comments
 (0)