Skip to content

Commit 3d31059

Browse files
committed
fix: remove unused unload API from Module and decouple from Context to fix weakref lifetime issue
1 parent 2309b7d commit 3d31059

File tree

3 files changed

+43
-11
lines changed

3 files changed

+43
-11
lines changed

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,7 +1321,7 @@ def create_module_image(
13211321
)
13221322
key = module.handle
13231323
self.modules[key] = module
1324-
return weakref.proxy(module)
1324+
return module
13251325

13261326
def unload_module(self, module):
13271327
key = module.handle
@@ -1432,7 +1432,6 @@ def load_module_image_ctypes(
14321432
info_log = jitinfo.value
14331433

14341434
return CtypesModule(
1435-
weakref.proxy(context),
14361435
handle,
14371436
info_log,
14381437
_module_finalizer(context, handle),
@@ -1476,7 +1475,6 @@ def load_module_image_cuda_python(
14761475
info_log = jitinfo.decode("utf-8")
14771476

14781477
return CudaPythonModule(
1479-
weakref.proxy(context),
14801478
handle,
14811479
info_log,
14821480
_module_finalizer(context, handle),
@@ -2251,14 +2249,12 @@ class Module(metaclass=ABCMeta):
22512249

22522250
def __init__(
22532251
self,
2254-
context,
22552252
handle,
22562253
info_log,
22572254
finalizer=None,
22582255
setup_callbacks=None,
22592256
teardown_callbacks=None,
22602257
):
2261-
self.context = context
22622258
self.handle = handle
22632259
self.info_log = info_log
22642260
if finalizer is not None:
@@ -2270,10 +2266,6 @@ def __init__(
22702266

22712267
self._set_finalizers()
22722268

2273-
def unload(self):
2274-
"""Unload this module from the context"""
2275-
self.context.unload_module(self)
2276-
22772269
@abstractmethod
22782270
def get_function(self, name):
22792271
"""Returns a Function object encapsulating the named function"""

numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,6 @@ def test_cuda_driver_basic(self):
119119
for i, v in enumerate(array):
120120
self.assertEqual(i, v)
121121

122-
module.unload()
123-
124122
def test_cuda_driver_stream_operations(self):
125123
module = self.context.create_module_ptx(self.ptx)
126124
function = module.get_function("_Z10helloworldPi")
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: BSD-2-Clause
3+
4+
import pytest
5+
6+
from numba.cuda.cudadrv import devices
7+
from numba.cuda.memory_management.nrt import rtsys
8+
9+
10+
@pytest.fixture
11+
def alloc_init():
12+
rtsys.ensure_allocated()
13+
rtsys.ensure_initialized()
14+
15+
16+
@pytest.fixture
17+
def ctx(alloc_init):
18+
ctx = devices.get_context()
19+
yield ctx
20+
ctx.reset()
21+
22+
# before this fix, this raised a ReferenceError (the proxy's referent was already collected)
23+
str(rtsys._memsys_module)
24+
25+
26+
def test_nothing(ctx):
27+
"""
28+
This test _must_ remain in a separate module; otherwise the failure is not reproducible.
29+
30+
The sequence that previously caused the failure is:
31+
32+
1. the alloc_init fixture creates a weakref proxy to a CudaPythonModule
33+
(rtsys._memsys_module) from devices.get_context()
34+
2. ctx.reset() calls dict.clear() on the context's module dict
35+
3. since the context's module dict held the only strong ref to the module behind `rtsys._memsys_module`, it gets collected
36+
during the clear
37+
4. All ops on the module (a weakref proxy) fail, because it is now
38+
referencing a dead object
39+
40+
The solution was to decouple `Context` from `Module`, allowing a strong
41+
reference to `Module` to exist, cleaned up with the usual mechanisms.
42+
"""

0 commit comments

Comments
 (0)