Skip to content

Commit 670068b

Browse files
committed
add spilling
Signed-off-by: niranda perera <[email protected]>
1 parent e0b7467 commit 670068b

File tree

4 files changed

+69
-72
lines changed

4 files changed

+69
-72
lines changed

cpp/include/rapidsmpf/shuffler/chunk.hpp

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -469,44 +469,3 @@ inline std::ostream& operator<<(std::ostream& os, ReadyForDataMessage const& obj
469469

470470
} // namespace detail
471471
} // namespace rapidsmpf::shuffler
472-
473-
// Custom hash function for Chunk that uses chunk ID
474-
namespace std {
475-
template <>
476-
/**
477-
* @brief Hash function for Chunk.
478-
*/
479-
struct hash<rapidsmpf::shuffler::detail::Chunk> {
480-
/**
481-
* @brief Hash function for Chunk that uses chunk ID.
482-
*
483-
* @param chunk The chunk to hash.
484-
* @return The hash of the chunk.
485-
*/
486-
std::size_t operator()(
487-
rapidsmpf::shuffler::detail::Chunk const& chunk
488-
) const noexcept {
489-
return std::hash<rapidsmpf::shuffler::detail::ChunkID>{}(chunk.chunk_id());
490-
}
491-
};
492-
493-
template <>
494-
/**
495-
* @brief Equality operator for Chunk.
496-
*/
497-
struct equal_to<rapidsmpf::shuffler::detail::Chunk> {
498-
/**
499-
* @brief Equality operator for Chunk that uses chunk ID.
500-
*
501-
* @param lhs The left chunk.
502-
* @param rhs The right chunk.
503-
* @return True if the chunks are equal, false otherwise.
504-
*/
505-
bool operator()(
506-
rapidsmpf::shuffler::detail::Chunk const& lhs,
507-
rapidsmpf::shuffler::detail::Chunk const& rhs
508-
) const noexcept {
509-
return lhs.chunk_id() == rhs.chunk_id();
510-
}
511-
};
512-
} // namespace std

cpp/include/rapidsmpf/shuffler/postbox.hpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include <ranges>
1111
#include <string>
1212
#include <unordered_map>
13-
#include <unordered_set>
1413
#include <vector>
1514

1615
#include <rapidsmpf/error.hpp>
@@ -115,20 +114,19 @@ class PostBox {
115114
* @brief Spills the specified amount of data from the PostBox.
116115
*
117116
* @param br Buffer resource to use for spilling.
117+
* @param log Logger to use for logging.
118118
* @param amount The amount of data to spill.
119119
* @return The amount of data spilled.
120120
*/
121-
size_t spill(BufferResource* br, size_t amount);
121+
size_t spill(BufferResource* br, Communicator::Logger& log, size_t amount);
122122

123123
private:
124124
/**
125125
* @brief Map value for the PostBox.
126-
*
127-
* @note The mutex is used to protect the chunks set.
128126
*/
129127
struct MapValue {
130128
mutable std::mutex mutex; ///< Mutex to protect each key
131-
std::unordered_set<Chunk> chunks; ///< Set of chunks for the key
129+
std::vector<Chunk> chunks; ///< Vector of chunks for the key
132130
};
133131

134132
std::function<key_type(PartID)>

cpp/src/shuffler/postbox.cpp

Lines changed: 65 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6+
#include <algorithm>
7+
#include <ranges>
68
#include <sstream>
79

810
#include <rapidsmpf/communicator/communicator.hpp>
11+
#include <rapidsmpf/nvtx.hpp>
912
#include <rapidsmpf/shuffler/chunk.hpp>
1013
#include <rapidsmpf/shuffler/postbox.hpp>
1114
#include <rapidsmpf/utils.hpp>
@@ -32,10 +35,7 @@ void PostBox<KeyType>::insert(Chunk&& chunk) {
3235
"PostBox.insert(): n_non_empty_keys_ is already at the maximum"
3336
);
3437
}
35-
RAPIDSMPF_EXPECTS(
36-
map_value.chunks.emplace(std::move(chunk)).second,
37-
"PostBox.insert(): chunk already exist"
38-
);
38+
map_value.chunks.push_back(std::move(chunk));
3939
}
4040

4141
template <typename KeyType>
@@ -53,15 +53,11 @@ std::vector<Chunk> PostBox<KeyType>::extract(PartID pid) {
5353
template <typename KeyType>
5454
std::vector<Chunk> PostBox<KeyType>::extract_by_key(KeyType key) {
5555
auto& map_value = pigeonhole_.at(key);
56-
std::vector<Chunk> ret;
5756
std::lock_guard lock(map_value.mutex);
5857
RAPIDSMPF_EXPECTS(!map_value.chunks.empty(), "PostBox.extract(): partition is empty");
59-
ret.reserve(map_value.chunks.size());
6058

61-
for (auto it = map_value.chunks.begin(); it != map_value.chunks.end();) {
62-
auto node = map_value.chunks.extract(it++);
63-
ret.emplace_back(std::move(node.value()));
64-
}
59+
std::vector<Chunk> ret = std::move(map_value.chunks);
60+
map_value.chunks.clear();
6561

6662
RAPIDSMPF_EXPECTS(
6763
n_non_empty_keys_.fetch_sub(1, std::memory_order_relaxed) > 0,
@@ -77,24 +73,33 @@ std::vector<Chunk> PostBox<KeyType>::extract_all_ready() {
7773
// Iterate through the outer map
7874
for (auto& [key, map_value] : pigeonhole_) {
7975
std::lock_guard lock(map_value.mutex);
80-
bool chunks_available = !map_value.chunks.empty();
81-
auto chunk_it = map_value.chunks.begin();
82-
while (chunk_it != map_value.chunks.end()) {
83-
if (chunk_it->is_ready()) {
84-
auto node = map_value.chunks.extract(chunk_it++);
85-
ret.emplace_back(std::move(node.value()));
86-
} else {
87-
++chunk_it;
88-
}
89-
}
9076

91-
// if the chunks were available and are now empty, its fully extracted
92-
if (chunks_available && map_value.chunks.empty()) {
77+
// Partition: non-ready chunks first, ready chunks at the end
78+
auto partition_point =
79+
std::ranges::partition(map_value.chunks, [](const Chunk& c) {
80+
return !c.is_ready();
81+
}).begin();
82+
83+
// if the chunks are available and all are ready, then all chunks will be
84+
// extracted
85+
if (map_value.chunks.begin() == partition_point
86+
&& partition_point != map_value.chunks.end())
87+
{
9388
RAPIDSMPF_EXPECTS(
9489
n_non_empty_keys_.fetch_sub(1, std::memory_order_relaxed) > 0,
9590
"PostBox.extract_all_ready(): n_non_empty_keys_ is already 0"
9691
);
9792
}
93+
94+
// Move ready chunks to result
95+
ret.insert(
96+
ret.end(),
97+
std::make_move_iterator(partition_point),
98+
std::make_move_iterator(map_value.chunks.end())
99+
);
100+
101+
// Remove ready chunks from the vector
102+
map_value.chunks.erase(partition_point, map_value.chunks.end());
98103
}
99104
return ret;
100105
}
@@ -105,9 +110,44 @@ bool PostBox<KeyType>::empty() const {
105110
}
106111

107112
template <typename KeyType>
108-
size_t PostBox<KeyType>::spill(BufferResource* /* br */, size_t /* amount */) {
109-
// TODO: implement spill
110-
return 0;
113+
size_t PostBox<KeyType>::spill(
    BufferResource* br, Communicator::Logger& log, size_t amount
) {
    // Moves device-resident chunk data to host memory until at least `amount`
    // bytes have been moved, and returns the number of bytes actually moved.
    //
    // Each key is locked individually with try_to_lock; a key whose mutex is
    // currently held elsewhere is skipped rather than waited on.
    RAPIDSMPF_NVTX_FUNC_RANGE();

    size_t total_spilled = 0;
    for (auto& [key, map_value] : pigeonhole_) {
        // Stop visiting keys once the target is met. The `break` inside the chunk
        // loop below only leaves that inner loop, so without this check every
        // remaining key would still have one more chunk spilled.
        if (total_spilled >= amount) {
            break;
        }

        std::unique_lock lock(map_value.mutex, std::try_to_lock);
        if (!lock) {
            continue;  // key is busy, skip it
        }

        // The key's mutex is held from here on, so its chunks can be modified.
        for (auto& chunk : map_value.chunks) {
            // Only chunks that carry a data buffer in device memory are candidates.
            if (!chunk.is_data_buffer_set()
                || chunk.data_memory_type() != MemoryType::DEVICE)
            {
                continue;
            }

            size_t const size = chunk.concat_data_size();
            // NOTE(review): the trailing `true` presumably allows overbooking;
            // confirm against BufferResource::reserve.
            auto [host_reservation, host_overbooking] =
                br->reserve(MemoryType::HOST, size, true);
            if (host_overbooking > 0) {
                log.warn(
                    "Cannot spill to host because of host memory overbooking: ",
                    format_nbytes(host_overbooking)
                );
                continue;
            }

            // Swap the chunk's data buffer for the one returned by br->move().
            chunk.set_data_buffer(
                br->move(chunk.release_data_buffer(), host_reservation)
            );
            total_spilled += size;
            if (total_spilled >= amount) {
                break;
            }
        }
    }

    return total_spilled;
}
112152

113153
template <typename KeyType>

cpp/src/shuffler/shuffler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ std::size_t Shuffler::spill(std::optional<std::size_t> amount) {
753753
}
754754
std::size_t spilled{0};
755755
if (spill_need > 0) {
756-
spilled = ready_postbox_.spill(br_, spill_need);
756+
spilled = ready_postbox_.spill(br_, comm_->logger(), spill_need);
757757
}
758758
return spilled;
759759
}

0 commit comments

Comments
 (0)