Skip to content

Commit df8f261

Browse files
authored
refactor: decouple Context from Stream and Event objects (#579)
Remove some unnecessary weak reference usage by decoupling `Context` objects from `Stream` and `Event`.
1 parent 5389798 commit df8f261

File tree

2 files changed

+17
-31
lines changed

2 files changed

+17
-31
lines changed

ci/test_conda.sh

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,6 @@
44

55
set -euo pipefail
66

7-
DISTRO=`cat /etc/os-release | grep "^ID=" | awk 'BEGIN {FS="="} { print $2 }'`
8-
9-
if [ "$DISTRO" = "ubuntu" ]; then
10-
apt-get update
11-
apt remove --purge `dpkg --get-selections | grep cuda-nvvm | awk '{print $1}'` -y
12-
apt remove --purge `dpkg --get-selections | grep cuda-nvrtc | awk '{print $1}'` -y
13-
fi
14-
157
# Constrain oldest supported dependencies for testing
168
if [ "${NUMBA_VERSION:-*}" != "*" ]; then
179
# add to the default environment's dependencies

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,19 +1305,19 @@ def unload_module(self, module):
13051305

13061306
def get_default_stream(self):
13071307
handle = drvapi.cu_stream(int(binding.CUstream(CU_STREAM_DEFAULT)))
1308-
return Stream(weakref.proxy(self), handle, None)
1308+
return Stream(handle)
13091309

13101310
def get_legacy_default_stream(self):
13111311
handle = drvapi.cu_stream(
13121312
int(binding.CUstream(binding.CU_STREAM_LEGACY))
13131313
)
1314-
return Stream(weakref.proxy(self), handle, None)
1314+
return Stream(handle)
13151315

13161316
def get_per_thread_default_stream(self):
13171317
handle = drvapi.cu_stream(
13181318
int(binding.CUstream(binding.CU_STREAM_PER_THREAD))
13191319
)
1320-
return Stream(weakref.proxy(self), handle, None)
1320+
return Stream(handle)
13211321

13221322
def create_stream(self):
13231323
# The default stream creation flag, specifying that the created
@@ -1327,26 +1327,22 @@ def create_stream(self):
13271327
flags = binding.CUstream_flags.CU_STREAM_DEFAULT.value
13281328
handle = drvapi.cu_stream(int(driver.cuStreamCreate(flags)))
13291329
return Stream(
1330-
weakref.proxy(self),
1331-
handle,
1332-
_stream_finalizer(self.deallocations, handle),
1330+
handle, finalizer=_stream_finalizer(self.deallocations, handle)
13331331
)
13341332

13351333
def create_external_stream(self, ptr):
13361334
if not isinstance(ptr, int):
13371335
raise TypeError("ptr for external stream must be an int")
13381336
handle = drvapi.cu_stream(int(binding.CUstream(ptr)))
1339-
return Stream(weakref.proxy(self), handle, None, external=True)
1337+
return Stream(handle, external=True)
13401338

13411339
def create_event(self, timing=True):
13421340
flags = 0
13431341
if not timing:
13441342
flags |= enums.CU_EVENT_DISABLE_TIMING
13451343
handle = drvapi.cu_event(int(driver.cuEventCreate(flags)))
13461344
return Event(
1347-
weakref.proxy(self),
1348-
handle,
1349-
finalizer=_event_finalizer(self.deallocations, handle),
1345+
handle, finalizer=_event_finalizer(self.deallocations, handle)
13501346
)
13511347

13521348
def synchronize(self):
@@ -1359,7 +1355,7 @@ def defer_cleanup(self):
13591355
yield
13601356

13611357
def __repr__(self):
1362-
return "<CUDA context %s of device %d>" % (self.handle, self.device.id)
1358+
return f"<CUDA context {self.handle} of device {self.device.id:d}>"
13631359

13641360
def __eq__(self, other):
13651361
if isinstance(other, Context):
@@ -2034,9 +2030,8 @@ class ManagedOwnedPointer(OwnedPointer, mviewbuf.MemAlloc):
20342030
pass
20352031

20362032

2037-
class Stream(object):
2038-
def __init__(self, context, handle, finalizer, external=False):
2039-
self.context = context
2033+
class Stream:
2034+
def __init__(self, handle, finalizer=None, external=False):
20402035
self.handle = handle
20412036
self.external = external
20422037
if finalizer is not None:
@@ -2053,18 +2048,18 @@ def __cuda_stream__(self):
20532048

20542049
def __repr__(self):
20552050
default_streams = {
2056-
drvapi.CU_STREAM_DEFAULT: "<Default CUDA stream on %s>",
2057-
drvapi.CU_STREAM_LEGACY: "<Legacy default CUDA stream on %s>",
2058-
drvapi.CU_STREAM_PER_THREAD: "<Per-thread default CUDA stream on %s>",
2051+
drvapi.CU_STREAM_DEFAULT: "<Default CUDA stream>",
2052+
drvapi.CU_STREAM_LEGACY: "<Legacy default CUDA stream>",
2053+
drvapi.CU_STREAM_PER_THREAD: "<Per-thread default CUDA stream>",
20592054
}
20602055
ptr = self.handle.value or drvapi.CU_STREAM_DEFAULT
20612056

20622057
if ptr in default_streams:
2063-
return default_streams[ptr] % self.context
2058+
return default_streams[ptr]
20642059
elif self.external:
2065-
return "<External CUDA stream %d on %s>" % (ptr, self.context)
2060+
return f"<External CUDA stream {ptr:d}>"
20662061
else:
2067-
return "<CUDA stream %d on %s>" % (ptr, self.context)
2062+
return f"<CUDA stream {ptr:d}>"
20682063

20692064
def synchronize(self):
20702065
"""
@@ -2166,9 +2161,8 @@ def callback(stream, status, future):
21662161
return future
21672162

21682163

2169-
class Event(object):
2170-
def __init__(self, context, handle, finalizer=None):
2171-
self.context = context
2164+
class Event:
2165+
def __init__(self, handle, finalizer=None):
21722166
self.handle = handle
21732167
if finalizer is not None:
21742168
weakref.finalize(self, finalizer)

0 commit comments

Comments
 (0)