Skip to content

Commit b329831

Browse files
committed
removing kernel finalizers
1 parent d05f76b commit b329831

File tree

2 files changed

+17
-39
lines changed

2 files changed

+17
-39
lines changed

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
import weakref
32
import os
43
import re
54
import sys
@@ -1025,7 +1024,6 @@ def compile(self, sig):
10251024
raise RuntimeError("Compilation disabled")
10261025

10271026
kernel = _Kernel(self.py_func, argtypes, **self.targetoptions)
1028-
weakref.finalize(kernel, _kernel_finalize_callback, kernel)
10291027
# We call bind to force codegen, so that there is a cubin to cache
10301028
kernel.bind()
10311029
self._cache.save_overload(sig, kernel)
@@ -1151,18 +1149,3 @@ def _reduce_states(self):
11511149
"""
11521150
return dict(py_func=self.py_func,
11531151
targetoptions=self.targetoptions)
1154-
1155-
1156-
def _kernel_finalize_callback(kernel):
1157-
module = kernel.library.get_cufunc().module
1158-
try:
1159-
if driver.USE_NV_BINDING:
1160-
key = module.handle
1161-
else:
1162-
key = module.handle.value
1163-
except ReferenceError:
1164-
return
1165-
1166-
ctx = cuda.current_context()
1167-
if key in ctx.modules:
1168-
del ctx.modules[key]

numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
1-
import gc
1+
import unittest
22

33
import numpy as np
44

5-
from numba import cuda
5+
from numba import cuda, config
66
from numba.cuda.cudadrv.linkable_code import CUSource
77
from numba.cuda.testing import CUDATestCase
88

99
from cuda.bindings.driver import cuModuleGetGlobal, cuMemcpyHtoD
1010

1111

12+
def wipe_all_modules_in_context():
13+
ctx = cuda.current_context()
14+
ctx.modules.clear()
15+
16+
17+
@unittest.skipIf(
18+
config.CUDA_USE_NVIDIA_BINDING,
19+
"NV binding support superceded by cuda.bindings."
20+
)
1221
class TestModuleCallbacksBasic(CUDATestCase):
1322

1423
def test_basic(self):
@@ -33,13 +42,9 @@ def kernel():
3342
self.assertEqual(counter, 1)
3443
kernel[1, 1]() # cached
3544
self.assertEqual(counter, 1)
36-
breakpoint()
37-
del kernel
38-
gc.collect()
39-
cuda.current_context().deallocations.clear()
45+
46+
wipe_all_modules_in_context()
4047
self.assertEqual(counter, 0)
41-
# We don't have a way to explicitly evict kernel and its modules at
42-
# the moment.
4348

4449
def test_different_argtypes(self):
4550
counter = 0
@@ -66,11 +71,8 @@ def kernel(arg):
6671
kernel[1, 1](3.14) # (float64)->() : module 2
6772
self.assertEqual(counter, 2)
6873

69-
# del kernel
70-
# gc.collect()
71-
# cuda.current_context().deallocations.clear()
72-
# self.assertEqual(counter, 0) # We don't have a way to explicitly
73-
# evict kernel and its modules at the moment.
74+
wipe_all_modules_in_context()
75+
self.assertEqual(counter, 0)
7476

7577
def test_two_kernels(self):
7678
counter = 0
@@ -98,11 +100,8 @@ def kernel2():
98100
kernel2[1, 1]()
99101
self.assertEqual(counter, 2)
100102

101-
# del kernel
102-
# gc.collect()
103-
# cuda.current_context().deallocations.clear()
104-
# self.assertEqual(counter, 0) # We don't have a way to explicitly
105-
# evict kernel and its modules at the moment.
103+
wipe_all_modules_in_context()
104+
self.assertEqual(counter, 0)
106105

107106

108107
class TestModuleCallbacks(CUDATestCase):
@@ -137,10 +136,6 @@ def teardown(mod, stream):
137136
self.lib = CUSource(
138137
module, setup_callback=set_forty_two, teardown_callback=teardown)
139138

140-
def tearDown(self):
141-
super().tearDown()
142-
del self.lib
143-
144139
def test_decldevice_arg(self):
145140
get_num = cuda.declare_device("get_num", "int32()", link=[self.lib])
146141

0 commit comments

Comments
 (0)