Skip to content

Commit 6c3adfa

Browse files
authored
OOM protection by fallback to managed memory. (#287)
Implement out-of-memory protection by using an RMM resource `RmmFallbackResource` based on rapidsai/rmm#1665. The idea is to use managed memory when the RMM pool raises an OOM error. Authors: - Mads R. B. Kristensen (https://github.com/madsbk) Approvers: - Peter Andreas Entschev (https://github.com/pentschev) URL: #287
1 parent 55346a0 commit 6c3adfa

File tree

7 files changed

+407
-2
lines changed

7 files changed

+407
-2
lines changed
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
/**
2+
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
#pragma once
7+
8+
#include <cstddef>
9+
#include <mutex>
10+
#include <unordered_set>
11+
12+
#include <rmm/error.hpp>
13+
#include <rmm/mr/device/device_memory_resource.hpp>
14+
#include <rmm/resource_ref.hpp>
15+
16+
namespace rapidsmpf {
17+
18+
19+
/**
20+
* @brief A device memory resource that uses an alternate upstream resource when the
21+
* primary upstream resource throws `rmm::out_of_memory`.
22+
*
23+
* An instance of this resource must be constructed with two upstream resources to satisfy
24+
* allocation requests.
25+
*
26+
*/
27+
class RmmFallbackResource final : public rmm::mr::device_memory_resource {
28+
public:
29+
/**
30+
* @brief Construct a new `RmmFallbackResource` that uses `primary_upstream`
31+
* to satisfy allocation requests and if that fails with `rmm::out_of_memory`,
32+
* uses `alternate_upstream`.
33+
*
34+
* @param primary_upstream The primary resource used for allocating/deallocating
35+
* device memory
36+
* @param alternate_upstream The alternate resource used for allocating/deallocating
37+
* device memory memory
38+
*/
39+
RmmFallbackResource(
40+
rmm::device_async_resource_ref primary_upstream,
41+
rmm::device_async_resource_ref alternate_upstream
42+
)
43+
: primary_upstream_{primary_upstream}, alternate_upstream_{alternate_upstream} {}
44+
45+
RmmFallbackResource() = delete;
46+
~RmmFallbackResource() override = default;
47+
48+
/**
49+
* @brief Get a reference to the primary upstream resource.
50+
*
51+
* @return Reference to the RMM memory resource.
52+
*/
53+
[[nodiscard]] rmm::device_async_resource_ref get_upstream_resource() const noexcept {
54+
return primary_upstream_;
55+
}
56+
57+
/**
58+
* @brief Get a reference to the alternative upstream resource.
59+
*
60+
* This resource is used when primary upstream resource throws `rmm::out_of_memory`.
61+
*
62+
* @return Reference to the RMM memory resource.
63+
*/
64+
[[nodiscard]] rmm::device_async_resource_ref get_alternate_upstream_resource(
65+
) const noexcept {
66+
return alternate_upstream_;
67+
}
68+
69+
private:
70+
/**
71+
* @brief Allocates memory of size at least `bytes` using the upstream
72+
* resource.
73+
*
74+
* @throws any exceptions thrown from the upstream resources, only
75+
* `rmm::out_of_memory` thrown by the primary upstream is caught.
76+
*
77+
* @param bytes The size, in bytes, of the allocation
78+
* @param stream Stream on which to perform the allocation
79+
* @return void* Pointer to the newly allocated memory
80+
*/
81+
void* do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override {
82+
void* ret{};
83+
try {
84+
ret = primary_upstream_.allocate_async(bytes, stream);
85+
} catch (rmm::out_of_memory const& e) {
86+
ret = alternate_upstream_.allocate_async(bytes, stream);
87+
std::lock_guard<std::mutex> lock(mutex_);
88+
alternate_allocations_.insert(ret);
89+
}
90+
return ret;
91+
}
92+
93+
/**
94+
* @brief Free allocation of size `bytes` pointed to by `ptr`
95+
*
96+
* @param ptr Pointer to be deallocated
97+
* @param bytes Size of the allocation
98+
* @param stream Stream on which to perform the deallocation
99+
*/
100+
void do_deallocate(void* ptr, std::size_t bytes, rmm::cuda_stream_view stream)
101+
override {
102+
std::size_t count{0};
103+
{
104+
std::lock_guard<std::mutex> lock(mutex_);
105+
count = alternate_allocations_.erase(ptr);
106+
}
107+
if (count > 0) {
108+
alternate_upstream_.deallocate_async(ptr, bytes, stream);
109+
} else {
110+
primary_upstream_.deallocate_async(ptr, bytes, stream);
111+
}
112+
}
113+
114+
/**
115+
* @brief Compare the resource to another.
116+
*
117+
* @param other The other resource to compare to
118+
* @return true If the two resources are equivalent
119+
* @return false If the two resources are not equal
120+
*/
121+
[[nodiscard]] bool do_is_equal(rmm::mr::device_memory_resource const& other
122+
) const noexcept override {
123+
if (this == &other) {
124+
return true;
125+
}
126+
auto cast = dynamic_cast<RmmFallbackResource const*>(&other);
127+
if (cast == nullptr) {
128+
return false;
129+
}
130+
return get_upstream_resource() == cast->get_upstream_resource()
131+
&& get_alternate_upstream_resource()
132+
== cast->get_alternate_upstream_resource();
133+
}
134+
135+
std::mutex mutex_;
136+
rmm::device_async_resource_ref primary_upstream_;
137+
rmm::device_async_resource_ref alternate_upstream_;
138+
std::unordered_set<void*> alternate_allocations_;
139+
};
140+
141+
142+
} // namespace rapidsmpf
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/**
2+
* SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
7+
#include <cstddef>
8+
#include <stdexcept>
9+
#include <unordered_set>
10+
11+
#include <gmock/gmock.h>
12+
#include <gtest/gtest.h>
13+
14+
#include <rmm/cuda_stream_view.hpp>
15+
#include <rmm/detail/error.hpp>
16+
#include <rmm/device_buffer.hpp>
17+
18+
#include <rapidsmpf/buffer/rmm_fallback_resource.hpp>
19+
20+
21+
using namespace rapidsmpf;
22+
23+
template <typename ExceptionType>
24+
struct throw_at_limit_resource final : public rmm::mr::device_memory_resource {
25+
throw_at_limit_resource(std::size_t limit) : limit{limit} {}
26+
27+
void* do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override {
28+
if (bytes > limit) {
29+
throw ExceptionType{"foo"};
30+
}
31+
void* ptr{nullptr};
32+
RMM_CUDA_TRY_ALLOC(cudaMallocAsync(&ptr, bytes, stream));
33+
allocs.insert(ptr);
34+
return ptr;
35+
}
36+
37+
void do_deallocate(void* ptr, std::size_t, rmm::cuda_stream_view) override {
38+
RMM_ASSERT_CUDA_SUCCESS(cudaFree(ptr));
39+
allocs.erase(ptr);
40+
}
41+
42+
[[nodiscard]] bool do_is_equal(rmm::mr::device_memory_resource const& other
43+
) const noexcept override {
44+
return this == &other;
45+
}
46+
47+
const std::size_t limit;
48+
std::unordered_set<void*> allocs{};
49+
};
50+
51+
// Exercises the routing of allocations: requests that fit the primary limit must be
// served (and freed) by the primary resource, requests that only fit the alternate
// limit by the alternate resource, and requests that fit neither must throw.
TEST(FailureAlternateTest, TrackBothUpstreams) {
    using PtrSet = std::unordered_set<void*>;

    throw_at_limit_resource<rmm::out_of_memory> primary_mr{100};
    throw_at_limit_resource<rmm::out_of_memory> alternate_mr{1000};
    RmmFallbackResource mr{primary_mr, alternate_mr};

    // A small request is below the primary limit: only the primary tracks it.
    {
        void* small = mr.allocate(10);
        EXPECT_EQ(primary_mr.allocs, PtrSet{small});
        EXPECT_EQ(alternate_mr.allocs, PtrSet{});

        mr.deallocate(small, 10);
        EXPECT_EQ(primary_mr.allocs, PtrSet{});
        EXPECT_EQ(alternate_mr.allocs, PtrSet{});
    }

    // A large request exceeds the primary limit: only the alternate tracks it.
    {
        void* large = mr.allocate(200);
        EXPECT_EQ(primary_mr.allocs, PtrSet{});
        EXPECT_EQ(alternate_mr.allocs, PtrSet{large});

        mr.deallocate(large, 200);
        EXPECT_EQ(primary_mr.allocs, PtrSet{});
        EXPECT_EQ(alternate_mr.allocs, PtrSet{});
    }

    // A request exceeding both limits surfaces the alternate's out-of-memory error.
    EXPECT_THROW(mr.allocate(2000), rmm::out_of_memory);
}
80+
81+
// A primary resource that fails with something other than `rmm::out_of_memory` must
// not trigger the fallback; the exception has to reach the caller unchanged.
TEST(FailureAlternateTest, DifferentExceptionTypes) {
    throw_at_limit_resource<std::invalid_argument> primary_mr{100};
    throw_at_limit_resource<rmm::out_of_memory> alternate_mr{1000};
    RmmFallbackResource fallback_mr{primary_mr, alternate_mr};

    // 200 bytes exceeds the primary limit, so the primary throws
    // `std::invalid_argument`, which `RmmFallbackResource` does not catch.
    EXPECT_THROW(fallback_mr.allocate(200), std::invalid_argument);
}

python/rapidsmpf/rapidsmpf/buffer/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
# SPDX-License-Identifier: Apache-2.0
44
# =================================================================================
55

6-
set(cython_modules buffer.pyx packed_data.pyx resource.pyx spill_manager.pyx)
6+
set(cython_modules buffer.pyx packed_data.pyx resource.pyx spill_manager.pyx
7+
rmm_fallback_resource.pyx
8+
)
79

810
rapids_cython_create_modules(
911
CXX
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from rmm.pylibrmm.memory_resource import DeviceMemoryResource
5+
from rmm.pylibrmm.stream import Stream
6+
7+
class RmmFallbackResource:
8+
def __init__(
9+
self,
10+
upstream_mr: DeviceMemoryResource,
11+
alternate_upstream_mr: DeviceMemoryResource,
12+
): ...
13+
@property
14+
def get_upstream(self) -> DeviceMemoryResource: ...
15+
@property
16+
def get_alternate_upstream(self) -> DeviceMemoryResource: ...
17+
def allocate(self, nbytes: int, stream: Stream = ...) -> int: ...
18+
def deallocate(self, ptr: int, nbytes: int, stream: Stream = ...) -> None: ...
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from rmm.pylibrmm.memory_resource cimport (DeviceMemoryResource,
5+
UpstreamResourceAdaptor,
6+
device_memory_resource)
7+
8+
9+
# Cython declaration of the C++ class `rapidsmpf::RmmFallbackResource`, exposed to
# this module under the name `cpp_RmmFallbackResource`.
cdef extern from "<rapidsmpf/buffer/rmm_fallback_resource.hpp>" nogil:
    cdef cppclass cpp_RmmFallbackResource"rapidsmpf::RmmFallbackResource"(
        device_memory_resource
    ):
        # Notice, `RmmFallbackResource` takes `device_async_resource_ref` as
        # upstream arguments but we define them here as `device_memory_resource*`
        # and rely on implicit type conversion.
        # `except +` lets Cython translate a C++ exception thrown by the
        # constructor into a Python exception.
        cpp_RmmFallbackResource(
            device_memory_resource* upstream_mr,
            device_memory_resource* alternate_upstream_mr,
        ) except +
20+
21+
22+
cdef class RmmFallbackResource(UpstreamResourceAdaptor):
    # Python reference to the alternate upstream, readable from Python code.
    cdef readonly DeviceMemoryResource alternate_upstream_mr

    def __cinit__(
        self,
        DeviceMemoryResource upstream_mr,
        DeviceMemoryResource alternate_upstream_mr,
    ):
        # Note, `upstream_mr is None` is checked by `UpstreamResourceAdaptor`.
        if alternate_upstream_mr is None:
            raise Exception("Argument `alternate_upstream_mr` must not be None")
        # Hold a reference to the alternate upstream on this object; the C++
        # resource created below only receives its raw pointer.
        self.alternate_upstream_mr = alternate_upstream_mr

        self.c_obj.reset(
            new cpp_RmmFallbackResource(
                upstream_mr.get_mr(),
                alternate_upstream_mr.get_mr(),
            )
        )

    def __init__(
        self,
        DeviceMemoryResource upstream_mr,
        DeviceMemoryResource alternate_upstream_mr,
    ):
        """
        A memory resource that uses an alternate resource when the primary
        resource raises an out-of-memory error.

        Parameters
        ----------
        upstream_mr : DeviceMemoryResource
            The primary resource used for allocating/deallocating device memory.
        alternate_upstream_mr : DeviceMemoryResource
            The alternate resource used when the primary resource is out of
            memory. Must not be None.
        """
        pass

    cpdef DeviceMemoryResource get_alternate_upstream(self):
        """
        Get the alternate upstream memory resource.

        Returns
        -------
        DeviceMemoryResource
            The resource passed as ``alternate_upstream_mr`` at construction.
        """
        return self.alternate_upstream_mr

python/rapidsmpf/rapidsmpf/integrations/dask/core.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from rapidsmpf.buffer.buffer import MemoryType
2121
from rapidsmpf.buffer.resource import BufferResource, LimitAvailableMemory
22+
from rapidsmpf.buffer.rmm_fallback_resource import RmmFallbackResource
2223
from rapidsmpf.buffer.spill_collection import SpillCollection
2324
from rapidsmpf.communicator.ucxx import barrier, get_root_ucxx_address, new_communicator
2425
from rapidsmpf.integrations.dask import _compat
@@ -187,6 +188,7 @@ def rmpf_worker_setup(
187188
*,
188189
spill_device: float,
189190
periodic_spill_check: float,
191+
oom_protection: bool,
190192
enable_statistics: bool,
191193
) -> None:
192194
"""
@@ -204,6 +206,9 @@ def rmpf_worker_setup(
204206
by the buffer resource. The value of ``periodic_spill_check`` is used as
205207
the pause between checks (in seconds). If None, no periodic spill check
206208
is performed.
209+
oom_protection
210+
Enable out-of-memory protection by using managed memory when the device
211+
memory pool raises OOM errors.
207212
enable_statistics
208213
Whether to track shuffler statistics.
209214
@@ -236,9 +241,13 @@ def rmpf_worker_setup(
236241
assert ctx.comm is not None
237242
ctx.progress_thread = ProgressThread(ctx.comm, ctx.statistics)
238243

244+
mr = rmm.mr.get_current_device_resource()
245+
if oom_protection:
246+
mr = RmmFallbackResource(mr, rmm.mr.ManagedMemoryResource())
247+
239248
# Setup a buffer_resource.
240249
# Wrap the current RMM resource in statistics adaptor.
241-
mr = rmm.mr.StatisticsResourceAdaptor(rmm.mr.get_current_device_resource())
250+
mr = rmm.mr.StatisticsResourceAdaptor(mr)
242251
rmm.mr.set_current_device_resource(mr)
243252
total_memory = rmm.mr.available_device_memory()[1]
244253
memory_available = {
@@ -307,6 +316,7 @@ def bootstrap_dask_cluster(
307316
*,
308317
spill_device: float = 0.50,
309318
periodic_spill_check: float | None = 1e-3,
319+
oom_protection: bool = False,
310320
enable_statistics: bool = True,
311321
) -> None:
312322
"""
@@ -324,6 +334,9 @@ def bootstrap_dask_cluster(
324334
by the buffer resource. The value of ``periodic_spill_check`` is used as
325335
the pause between checks (in seconds). If None, no periodic spill
326336
check is performed.
337+
oom_protection
338+
Enable out-of-memory protection by using managed memory when the device
339+
memory pool raises OOM errors.
327340
enable_statistics
328341
Whether to track shuffler statistics.
329342
@@ -383,6 +396,7 @@ def bootstrap_dask_cluster(
383396
rmpf_worker_setup,
384397
spill_device=spill_device,
385398
periodic_spill_check=periodic_spill_check,
399+
oom_protection=oom_protection,
386400
enable_statistics=enable_statistics,
387401
)
388402

0 commit comments

Comments
 (0)