Skip to content

Commit 8ee39ad

Browse files
authored
Allow specifying mr in DeviceBuffer construction, and document ownership requirements in Python/C++ interfacing (#1552)
On the C++ side, device_buffers store raw pointers for the memory resource that was used in their allocation. Consequently, it is unsafe to take ownership of a device_buffer in Python unless we controlled the provenance of the memory resource that was used in its allocation. The only way to do that is to pass the memory resource from Python into C++ and then use it when constructing the DeviceBuffer. Add discussion of this with some examples and a section on pitfalls if only relying on get_current_device_resource and set_current_device_resource. To allow Python users of `DeviceBuffer` objects to follow best practices, introduce explicit (defaulting to the current device resource) `mr` arguments in both `c_from_unique_ptr` and the `DeviceBuffer` constructor. - Closes #1492 Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - Mark Harris (https://github.com/harrism) - Vyas Ramasubramani (https://github.com/vyasr) URL: #1552
1 parent 32cd537 commit 8ee39ad

File tree

4 files changed

+134
-9
lines changed

4 files changed

+134
-9
lines changed

README.md

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ Python requirements:
7373
* `cuda-python`
7474
* `cython`
7575

76-
For more details, see [pyproject.toml](python/pyproject.toml)
76+
For more details, see [pyproject.toml](python/rmm/pyproject.toml)
7777

7878

7979
### Script to build RMM from source
@@ -855,3 +855,94 @@ Out[6]:
855855
'total_bytes': 16,
856856
'total_count': 1}
857857
```
858+
859+
## Taking ownership of C++ objects from Python.
860+
861+
When interacting with a C++ library that uses RMM from Python, one
862+
must be careful when taking ownership of `rmm::device_buffer` objects
863+
on the Python side. The `rmm::device_buffer` does not contain an
864+
owning reference to the memory resource used for its allocation (only
865+
a `device_async_resource_ref`), and the allocating user is expected to
866+
keep this memory resource alive for at least the lifetime of the
867+
buffer. When taking ownership of such a buffer in Python, we have no
868+
way (in the general case) of ensuring that the memory resource will
869+
outlive the buffer we are now holding.
870+
871+
To avoid any issues, we need two things:
872+
873+
1. The C++ library we are interfacing with should accept a memory
874+
resource that is used for allocations that are returned to the
875+
user.
876+
2. When calling into the library from python, we should provide a
877+
memory resource whose lifetime we control. This memory resource
878+
should then be provided when we take ownership of any allocated
879+
`rmm::device_buffer`s.
880+
881+
For example, suppose we have a C++ function that allocates
882+
`device_buffer`s, which has a utility overload that defaults the
883+
memory resource to the current device resource:
884+
885+
```c++
886+
std::unique_ptr<rmm::device_buffer> allocate(
887+
std::size_t size,
888+
rmm::mr::device_async_resource_ref mr = get_current_device_resource())
889+
{
890+
return std::make_unique<rmm::device_buffer>(size, rmm::cuda_stream_default, mr);
891+
}
892+
```
893+
894+
The Python `DeviceBuffer` class has a convenience Cython function,
895+
`c_from_unique_ptr` to construct a `DeviceBuffer` from a
896+
`unique_ptr<rmm::device_buffer>`, taking ownership of it. To do this
897+
safely, we must ensure that the allocation that was done on the C++
898+
side uses a memory resource we control. So:
899+
900+
```cython
901+
# Bad, doesn't control lifetime
902+
buffer_bad = DeviceBuffer.c_from_unique_ptr(allocate(10))
903+
904+
# Good, allocation happens with a memory resource we control
905+
# mr is a DeviceMemoryResource
906+
buffer_good = DeviceBuffer.c_from_unique_ptr(
907+
allocate(10, mr.get_mr()),
908+
mr=mr,
909+
)
910+
```
911+
912+
Note two differences between the bad and good cases:
913+
914+
1. In the good case we pass the memory resource to the allocation
915+
function.
916+
2. In the good case, we pass _the same_ memory resource to the
917+
`DeviceBuffer` constructor so that its lifetime is tied to the
918+
lifetime of the buffer.
919+
920+
### Potential pitfalls of relying on `get_current_device_resource`
921+
922+
Functions in both the C++ and Python APIs that perform allocation
923+
typically default the memory resource argument to the value of
924+
`get_current_device_resource`. This is to simplify the interface for
925+
callers. When using a C++ library from Python, this defaulting is
926+
safe, _as long as_ it is only the Python process that ever calls
927+
`set_current_device_resource`.
928+
929+
This is because the current device resource on the C++ side has a
930+
lifetime which is expected to be managed by the user. The resources
931+
set by `rmm::mr::set_current_device_resource` are stored in a static
932+
`std::map` whose keys are device ids and values are raw pointers to
933+
the memory resources. Consequently,
934+
`rmm::mr::get_current_device_resource` returns an object with no
935+
lifetime provenance. This is, for the reasons discussed above, not
936+
usable from Python. To handle this on the Python side, the
937+
Python-level `set_current_device_resource` sets the C++ resource _and_
938+
stores the Python object in a static global dictionary. The Python
939+
`get_current_device_resource` then _does not use_
940+
`rmm::mr::get_current_device_resource` and instead looks up the
941+
current device resource in this global dictionary.
942+
943+
Hence, if the C++ library we are interfacing with calls
944+
`rmm::mr::set_current_device_resource`, the C++ and Python sides of
945+
the program can disagree on what `get_current_device_resource`
946+
returns. The only safe thing to do if using the simplified interfaces
947+
is therefore to ensure that `set_current_device_resource` is only ever
948+
called on the Python side.

python/rmm/rmm/_lib/device_buffer.pxd

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2019-2020, NVIDIA CORPORATION.
1+
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -65,7 +65,8 @@ cdef class DeviceBuffer:
6565
@staticmethod
6666
cdef DeviceBuffer c_from_unique_ptr(
6767
unique_ptr[device_buffer] ptr,
68-
Stream stream=*
68+
Stream stream=*,
69+
DeviceMemoryResource mr=*,
6970
)
7071

7172
@staticmethod

python/rmm/rmm/_lib/device_buffer.pyx

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2019-2020, NVIDIA CORPORATION.
1+
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -33,6 +33,7 @@ from cuda.ccudart cimport (
3333
)
3434

3535
from rmm._lib.memory_resource cimport (
36+
DeviceMemoryResource,
3637
device_memory_resource,
3738
get_current_device_resource,
3839
)
@@ -48,7 +49,8 @@ cdef class DeviceBuffer:
4849
def __cinit__(self, *,
4950
uintptr_t ptr=0,
5051
size_t size=0,
51-
Stream stream=DEFAULT_STREAM):
52+
Stream stream=DEFAULT_STREAM,
53+
DeviceMemoryResource mr=None):
5254
"""Construct a ``DeviceBuffer`` with optional size and data pointer
5355
5456
Parameters
@@ -65,6 +67,9 @@ cdef class DeviceBuffer:
6567
scope while the DeviceBuffer is in use. Destroying the
6668
underlying stream while the DeviceBuffer is in use will
6769
result in undefined behavior.
70+
mr : optional
71+
DeviceMemoryResource for the allocation, if not provided
72+
defaults to the current device resource.
6873
6974
Note
7075
----
@@ -80,7 +85,7 @@ cdef class DeviceBuffer:
8085
cdef const void* c_ptr
8186
cdef device_memory_resource * mr_ptr
8287
# Save a reference to the MR and stream used for allocation
83-
self.mr = get_current_device_resource()
88+
self.mr = get_current_device_resource() if mr is None else mr
8489
self.stream = stream
8590

8691
mr_ptr = self.mr.get_mr()
@@ -162,13 +167,14 @@ cdef class DeviceBuffer:
162167
@staticmethod
163168
cdef DeviceBuffer c_from_unique_ptr(
164169
unique_ptr[device_buffer] ptr,
165-
Stream stream=DEFAULT_STREAM
170+
Stream stream=DEFAULT_STREAM,
171+
DeviceMemoryResource mr=None,
166172
):
167173
cdef DeviceBuffer buf = DeviceBuffer.__new__(DeviceBuffer)
168174
if stream.c_is_default():
169175
stream.c_synchronize()
170176
buf.c_obj = move(ptr)
171-
buf.mr = get_current_device_resource()
177+
buf.mr = get_current_device_resource() if mr is None else mr
172178
buf.stream = stream
173179
return buf
174180

python/rmm/rmm/tests/test_rmm.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
1+
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import copy
16+
import functools
1617
import gc
1718
import os
1819
import pickle
@@ -498,6 +499,32 @@ def test_mr_devicebuffer_lifetime():
498499
del a
499500

500501

502+
def test_device_buffer_with_mr():
503+
allocations = []
504+
base = rmm.mr.CudaMemoryResource()
505+
rmm.mr.set_current_device_resource(base)
506+
507+
def alloc_cb(size, stream, *, base):
508+
allocations.append(size)
509+
return base.allocate(size, stream)
510+
511+
def dealloc_cb(ptr, size, stream, *, base):
512+
return base.deallocate(ptr, size, stream)
513+
514+
cb_mr = rmm.mr.CallbackMemoryResource(
515+
functools.partial(alloc_cb, base=base),
516+
functools.partial(dealloc_cb, base=base),
517+
)
518+
rmm.DeviceBuffer(size=10)
519+
assert len(allocations) == 0
520+
buf = rmm.DeviceBuffer(size=256, mr=cb_mr)
521+
assert len(allocations) == 1
522+
assert allocations[0] == 256
523+
del cb_mr
524+
gc.collect()
525+
del buf
526+
527+
501528
def test_mr_upstream_lifetime():
502529
# Simple test to ensure upstream MRs are deallocated before downstream MR
503530
cuda_mr = rmm.mr.CudaMemoryResource()

0 commit comments

Comments
 (0)