Skip to content

Commit 548b353

Browse files
committed
Use memory resource in DeviceBuffer construtor
Update the Python constructor to take and handle a `DeviceMemoryResource` argument. Also pass this through to `device_buffer` constructors.
1 parent fb75182 commit 548b353

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

python/rmm/_lib/device_buffer.pyx

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ cdef class DeviceBuffer:
4747
def __cinit__(self, *,
4848
uintptr_t ptr=0,
4949
size_t size=0,
50-
Stream stream=DEFAULT_STREAM):
50+
Stream stream=DEFAULT_STREAM,
51+
DeviceMemoryResource mr=None):
5152
"""Construct a ``DeviceBuffer`` with optional size and data pointer
5253
5354
Parameters
@@ -64,6 +65,9 @@ cdef class DeviceBuffer:
6465
scope while the DeviceBuffer is in use. Destroying the
6566
underlying stream while the DeviceBuffer is in use will
6667
result in undefined behavior.
68+
mr : optional
69+
Memory resource to use to allocate memory for the underlying
70+
``device_buffer``.
6771
6872
Note
6973
----
@@ -77,22 +81,31 @@ cdef class DeviceBuffer:
7781
>>> db = rmm.DeviceBuffer(size=5)
7882
"""
7983
cdef const void* c_ptr
84+
cdef device_memory_resource* c_mr
85+
86+
# Use default memory resource if none is specified.
87+
# Also get C++ representation to call constructor below.
88+
if mr is None:
89+
mr = get_current_device_resource()
90+
c_mr = mr.get_mr()
8091

8192
with nogil:
8293
c_ptr = <const void*>ptr
8394

8495
if size == 0:
85-
self.c_obj.reset(new device_buffer())
96+
self.c_obj.reset(new device_buffer(c_mr))
8697
elif c_ptr == NULL:
87-
self.c_obj.reset(new device_buffer(size, stream.view()))
98+
self.c_obj.reset(new device_buffer(size, stream.view(), c_mr))
8899
else:
89-
self.c_obj.reset(new device_buffer(c_ptr, size, stream.view()))
100+
self.c_obj.reset(
101+
new device_buffer(c_ptr, size, stream.view(), c_mr)
102+
)
90103

91104
if stream.c_is_default():
92105
stream.c_synchronize()
93106

94107
# Save a reference to the MR and stream used for allocation
95-
self.mr = get_current_device_resource()
108+
self.mr = mr
96109
self.stream = stream
97110

98111
def __len__(self):

0 commit comments

Comments
 (0)