@@ -47,7 +47,8 @@ cdef class DeviceBuffer:
4747 def __cinit__ (self , *,
4848 uintptr_t ptr = 0 ,
4949 size_t size = 0 ,
50- Stream stream = DEFAULT_STREAM):
50+ Stream stream = DEFAULT_STREAM,
51+ DeviceMemoryResource mr = None ):
5152 """ Construct a ``DeviceBuffer`` with optional size and data pointer
5253
5354 Parameters
@@ -64,6 +65,9 @@ cdef class DeviceBuffer:
6465 scope while the DeviceBuffer is in use. Destroying the
6566 underlying stream while the DeviceBuffer is in use will
6667 result in undefined behavior.
68+ mr : optional
69+ Memory resource to use to allocate memory for the underlying
70+ ``device_buffer``.
6771
6872 Note
6973 ----
@@ -77,22 +81,31 @@ cdef class DeviceBuffer:
7781 >>> db = rmm.DeviceBuffer(size=5)
7882 """
7983 cdef const void * c_ptr
84+ cdef device_memory_resource* c_mr
85+
86+ # Use default memory resource if none is specified.
87+ # Also get C++ representation to call constructor below.
88+ if mr is None :
89+ mr = get_current_device_resource()
90+ c_mr = mr.get_mr()
8091
8192 with nogil:
8293 c_ptr = < const void * > ptr
8394
8495 if size == 0 :
85- self .c_obj.reset(new device_buffer())
96+ self .c_obj.reset(new device_buffer(c_mr ))
8697 elif c_ptr == NULL :
87- self .c_obj.reset(new device_buffer(size, stream.view()))
98+ self .c_obj.reset(new device_buffer(size, stream.view(), c_mr ))
8899 else :
89- self .c_obj.reset(new device_buffer(c_ptr, size, stream.view()))
100+ self .c_obj.reset(
101+ new device_buffer(c_ptr, size, stream.view(), c_mr)
102+ )
90103
91104 if stream.c_is_default():
92105 stream.c_synchronize()
93106
94107 # Save a reference to the MR and stream used for allocation
95- self .mr = get_current_device_resource()
108+ self .mr = mr
96109 self .stream = stream
97110
98111 def __len__ (self ):
0 commit comments