1- import gc
1+ import unittest
22
33import numpy as np
44
5- from numba import cuda
5+ from numba import cuda , config
66from numba .cuda .cudadrv .linkable_code import CUSource
77from numba .cuda .testing import CUDATestCase
88
99from cuda .bindings .driver import cuModuleGetGlobal , cuMemcpyHtoD
1010
1111
12+ def wipe_all_modules_in_context ():
13+ ctx = cuda .current_context ()
14+ ctx .modules .clear ()
15+
16+
17+ @unittest .skipIf (
18+ config .CUDA_USE_NVIDIA_BINDING ,
19+ "NV binding support superceded by cuda.bindings."
20+ )
1221class TestModuleCallbacksBasic (CUDATestCase ):
1322
1423 def test_basic (self ):
@@ -33,13 +42,9 @@ def kernel():
3342 self .assertEqual (counter , 1 )
3443 kernel [1 , 1 ]() # cached
3544 self .assertEqual (counter , 1 )
36- breakpoint ()
37- del kernel
38- gc .collect ()
39- cuda .current_context ().deallocations .clear ()
45+
46+ wipe_all_modules_in_context ()
4047 self .assertEqual (counter , 0 )
41- # We don't have a way to explicitly evict kernel and its modules at
42- # the moment.
4348
4449 def test_different_argtypes (self ):
4550 counter = 0
@@ -66,11 +71,8 @@ def kernel(arg):
6671 kernel [1 , 1 ](3.14 ) # (float64)->() : module 2
6772 self .assertEqual (counter , 2 )
6873
69- # del kernel
70- # gc.collect()
71- # cuda.current_context().deallocations.clear()
72- # self.assertEqual(counter, 0) # We don't have a way to explicitly
73- # evict kernel and its modules at the moment.
74+ wipe_all_modules_in_context ()
75+ self .assertEqual (counter , 0 )
7476
7577 def test_two_kernels (self ):
7678 counter = 0
@@ -98,11 +100,8 @@ def kernel2():
98100 kernel2 [1 , 1 ]()
99101 self .assertEqual (counter , 2 )
100102
101- # del kernel
102- # gc.collect()
103- # cuda.current_context().deallocations.clear()
104- # self.assertEqual(counter, 0) # We don't have a way to explicitly
105- # evict kernel and its modules at the moment.
103+ wipe_all_modules_in_context ()
104+ self .assertEqual (counter , 0 )
106105
107106
108107class TestModuleCallbacks (CUDATestCase ):
@@ -137,10 +136,6 @@ def teardown(mod, stream):
137136 self .lib = CUSource (
138137 module , setup_callback = set_forty_two , teardown_callback = teardown )
139138
140- def tearDown (self ):
141- super ().tearDown ()
142- del self .lib
143-
144139 def test_decldevice_arg (self ):
145140 get_num = cuda .declare_device ("get_num" , "int32()" , link = [self .lib ])
146141
0 commit comments