Skip to content

Commit e0b7467

Browse files
committed
finer locking without spilling
Signed-off-by: niranda perera <[email protected]>
1 parent c7e7ea2 commit e0b7467

File tree

6 files changed

+156
-171
lines changed

6 files changed

+156
-171
lines changed

cpp/include/rapidsmpf/shuffler/chunk.hpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,3 +469,44 @@ inline std::ostream& operator<<(std::ostream& os, ReadyForDataMessage const& obj
469469

470470
} // namespace detail
471471
} // namespace rapidsmpf::shuffler
472+
473+
// Custom hash function for Chunk that uses chunk ID
474+
namespace std {
475+
template <>
476+
/**
477+
* @brief Hash function for Chunk.
478+
*/
479+
struct hash<rapidsmpf::shuffler::detail::Chunk> {
480+
/**
481+
* @brief Hash function for Chunk that uses chunk ID.
482+
*
483+
* @param chunk The chunk to hash.
484+
* @return The hash of the chunk.
485+
*/
486+
std::size_t operator()(
487+
rapidsmpf::shuffler::detail::Chunk const& chunk
488+
) const noexcept {
489+
return std::hash<rapidsmpf::shuffler::detail::ChunkID>{}(chunk.chunk_id());
490+
}
491+
};
492+
493+
template <>
494+
/**
495+
* @brief Equality operator for Chunk.
496+
*/
497+
struct equal_to<rapidsmpf::shuffler::detail::Chunk> {
498+
/**
499+
* @brief Equality operator for Chunk that uses chunk ID.
500+
*
501+
* @param lhs The left chunk.
502+
* @param rhs The right chunk.
503+
* @return True if the chunks are equal, false otherwise.
504+
*/
505+
bool operator()(
506+
rapidsmpf::shuffler::detail::Chunk const& lhs,
507+
rapidsmpf::shuffler::detail::Chunk const& rhs
508+
) const noexcept {
509+
return lhs.chunk_id() == rhs.chunk_id();
510+
}
511+
};
512+
} // namespace std

cpp/include/rapidsmpf/shuffler/postbox.hpp

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@
44
*/
55
#pragma once
66

7+
#include <atomic>
78
#include <functional>
89
#include <mutex>
10+
#include <ranges>
911
#include <string>
1012
#include <unordered_map>
13+
#include <unordered_set>
1114
#include <vector>
1215

1316
#include <rapidsmpf/error.hpp>
17+
#include <rapidsmpf/memory/buffer_resource.hpp>
1418
#include <rapidsmpf/shuffler/chunk.hpp>
1519

1620
namespace rapidsmpf::shuffler::detail {
@@ -31,16 +35,21 @@ class PostBox {
3135
*
3236
* @tparam Fn The type of the function that maps a partition ID to a key.
3337
* @param key_map_fn A function that maps a partition ID to a key.
34-
* @param num_keys_hint The number of keys to reserve space for.
38+
* @param keys The keys expected to be used in the PostBox.
3539
*
3640
* @note The `key_map_fn` must be convertible to a function that takes a `PartID` and
3741
* returns a `KeyType`.
3842
*/
39-
template <typename Fn>
40-
PostBox(Fn&& key_map_fn, size_t num_keys_hint = 0)
41-
: key_map_fn_(std::move(key_map_fn)) {
42-
if (num_keys_hint > 0) {
43-
pigeonhole_.reserve(num_keys_hint);
43+
template <typename Fn, std::ranges::input_range Range>
44+
requires std::convertible_to<std::ranges::range_value_t<Range>, KeyType>
45+
PostBox(Fn&& key_map_fn, Range&& keys) : key_map_fn_(std::move(key_map_fn)) {
46+
pigeonhole_.reserve(std::ranges::size(keys));
47+
for (const auto& key : keys) {
48+
pigeonhole_.emplace(
49+
std::piecewise_construct,
50+
std::forward_as_tuple(key),
51+
std::forward_as_tuple()
52+
);
4453
}
4554
}
4655

@@ -62,17 +71,6 @@ class PostBox {
6271
*/
6372
bool is_empty(PartID pid) const;
6473

65-
/**
66-
* @brief Extracts a specific chunk from the PostBox.
67-
*
68-
* @param pid The ID of the partition containing the chunk.
69-
* @param cid The ID of the chunk to be accessed.
70-
* @return The extracted chunk.
71-
*
72-
* @throws std::out_of_range If the chunk is not found.
73-
*/
74-
[[nodiscard]] Chunk extract(PartID pid, ChunkID cid);
75-
7674
/**
7775
* @brief Extracts all chunks associated with a specific partition.
7876
*
@@ -81,7 +79,7 @@ class PostBox {
8179
*
8280
* @throws std::out_of_range If the partition is not found.
8381
*/
84-
std::unordered_map<ChunkID, Chunk> extract(PartID pid);
82+
std::vector<Chunk> extract(PartID pid);
8583

8684
/**
8785
* @brief Extracts all chunks associated with a specific key.
@@ -91,7 +89,7 @@ class PostBox {
9189
*
9290
* @throws std::out_of_range If the key is not found.
9391
*/
94-
std::unordered_map<ChunkID, Chunk> extract_by_key(KeyType key);
92+
std::vector<Chunk> extract_by_key(KeyType key);
9593

9694
/**
9795
* @brief Extracts all ready chunks from the PostBox.
@@ -107,30 +105,36 @@ class PostBox {
107105
*/
108106
[[nodiscard]] bool empty() const;
109107

110-
/**
111-
* @brief Searches for chunks of the specified memory type.
112-
*
113-
* @param mem_type The type of memory to search within.
114-
* @return A vector of tuples, where each tuple contains: PartID, ChunkID, and the
115-
* size of the chunk.
116-
*/
117-
[[nodiscard]] std::vector<std::tuple<key_type, ChunkID, std::size_t>> search(
118-
MemoryType mem_type
119-
) const;
120-
121108
/**
122109
* @brief Returns a description of this instance.
123110
* @return The description.
124111
*/
125112
[[nodiscard]] std::string str() const;
126113

114+
/**
115+
* @brief Spills the specified amount of data from the PostBox.
116+
*
117+
* @param br Buffer resource to use for spilling.
118+
* @param amount The amount of data to spill.
119+
* @return The amount of data spilled.
120+
*/
121+
size_t spill(BufferResource* br, size_t amount);
122+
127123
private:
128-
// TODO: more fine-grained locking e.g. by locking each partition individually.
129-
mutable std::mutex mutex_;
124+
/**
125+
* @brief Map value for the PostBox.
126+
*
127+
* @note The mutex is used to protect the chunks set.
128+
*/
129+
struct MapValue {
130+
mutable std::mutex mutex; ///< Mutex to protect each key
131+
std::unordered_set<Chunk> chunks; ///< Set of chunks for the key
132+
};
133+
130134
std::function<key_type(PartID)>
131135
key_map_fn_; ///< Function to map partition IDs to keys.
132-
std::unordered_map<key_type, std::unordered_map<ChunkID, Chunk>>
133-
pigeonhole_; ///< Storage for chunks, organized by a key and chunk ID.
136+
std::unordered_map<key_type, MapValue> pigeonhole_; ///< Storage for chunks
137+
std::atomic<size_t> n_non_empty_keys_{0};
134138
};
135139

136140
/**

cpp/include/rapidsmpf/shuffler/shuffler.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -341,29 +341,26 @@ class Shuffler {
341341

342342
private:
343343
BufferResource* br_;
344+
std::shared_ptr<Communicator> comm_;
344345
std::atomic<bool> active_{true};
346+
std::vector<PartID> const local_partitions_;
347+
345348
detail::PostBox<Rank> outgoing_postbox_; ///< Postbox for outgoing chunks, that are
346349
///< ready to be sent to other ranks.
347350
detail::PostBox<PartID> ready_postbox_; ///< Postbox for received chunks, that are
348351
///< ready to be extracted by the user.
349352

350-
std::shared_ptr<Communicator> comm_;
351353
std::shared_ptr<ProgressThread> progress_thread_;
352354
ProgressThread::FunctionID progress_thread_function_id_;
353355
OpID const op_id_;
354356

355357
SpillManager::SpillFunctionID spill_function_id_;
356358

357-
std::vector<PartID> const local_partitions_;
358359

359360
detail::FinishCounter finish_counter_;
360361
std::unordered_map<PartID, detail::ChunkID> outbound_chunk_counter_;
361362
mutable std::mutex outbound_chunk_counter_mutex_;
362363

363-
// We protect ready_postbox extraction to avoid returning a chunk that is in the
364-
// process of being spilled by `Shuffler::spill`.
365-
mutable std::mutex ready_postbox_spilling_mutex_;
366-
367364
std::atomic<detail::ChunkID> chunk_id_counter_{0};
368365

369366
std::shared_ptr<Statistics> statistics_;

cpp/src/shuffler/postbox.cpp

Lines changed: 54 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -22,88 +22,92 @@ void PostBox<KeyType>::insert(Chunk&& chunk) {
2222
"PostBox.insert(): all messages in the chunk must map to the same key"
2323
);
2424
}
25-
std::lock_guard const lock(mutex_);
25+
// std::lock_guard const lock(mutex_);
26+
auto& map_value = pigeonhole_.at(key);
27+
std::lock_guard lock(map_value.mutex);
28+
if (map_value.chunks.empty()) {
29+
RAPIDSMPF_EXPECTS(
30+
n_non_empty_keys_.fetch_add(1, std::memory_order_relaxed) + 1
31+
<= pigeonhole_.size(),
32+
"PostBox.insert(): n_non_empty_keys_ is already at the maximum"
33+
);
34+
}
2635
RAPIDSMPF_EXPECTS(
27-
pigeonhole_[key].emplace(chunk.chunk_id(), std::move(chunk)).second,
36+
map_value.chunks.emplace(std::move(chunk)).second,
2837
"PostBox.insert(): chunk already exist"
2938
);
3039
}
3140

3241
template <typename KeyType>
3342
bool PostBox<KeyType>::is_empty(PartID pid) const {
34-
std::lock_guard const lock(mutex_);
35-
return !pigeonhole_.contains(key_map_fn_(pid));
43+
auto& map_value = pigeonhole_.at(key_map_fn_(pid));
44+
std::lock_guard lock(map_value.mutex);
45+
return map_value.chunks.empty();
3646
}
3747

3848
template <typename KeyType>
39-
Chunk PostBox<KeyType>::extract(PartID pid, ChunkID cid) {
40-
std::lock_guard const lock(mutex_);
41-
return extract_item(pigeonhole_[key_map_fn_(pid)], cid).second;
49+
std::vector<Chunk> PostBox<KeyType>::extract(PartID pid) {
50+
return extract_by_key(key_map_fn_(pid));
4251
}
4352

4453
template <typename KeyType>
45-
std::unordered_map<ChunkID, Chunk> PostBox<KeyType>::extract(PartID pid) {
46-
std::lock_guard const lock(mutex_);
47-
return extract_value(pigeonhole_, key_map_fn_(pid));
48-
}
54+
std::vector<Chunk> PostBox<KeyType>::extract_by_key(KeyType key) {
55+
auto& map_value = pigeonhole_.at(key);
56+
std::vector<Chunk> ret;
57+
std::lock_guard lock(map_value.mutex);
58+
RAPIDSMPF_EXPECTS(!map_value.chunks.empty(), "PostBox.extract(): partition is empty");
59+
ret.reserve(map_value.chunks.size());
4960

50-
template <typename KeyType>
51-
std::unordered_map<ChunkID, Chunk> PostBox<KeyType>::extract_by_key(KeyType key) {
52-
std::lock_guard const lock(mutex_);
53-
return extract_value(pigeonhole_, key);
61+
for (auto it = map_value.chunks.begin(); it != map_value.chunks.end();) {
62+
auto node = map_value.chunks.extract(it++);
63+
ret.emplace_back(std::move(node.value()));
64+
}
65+
66+
RAPIDSMPF_EXPECTS(
67+
n_non_empty_keys_.fetch_sub(1, std::memory_order_relaxed) > 0,
68+
"PostBox.extract(): n_non_empty_keys_ is already 0"
69+
);
70+
return ret;
5471
}
5572

5673
template <typename KeyType>
5774
std::vector<Chunk> PostBox<KeyType>::extract_all_ready() {
58-
std::lock_guard const lock(mutex_);
5975
std::vector<Chunk> ret;
6076

6177
// Iterate through the outer map
62-
auto pid_it = pigeonhole_.begin();
63-
while (pid_it != pigeonhole_.end()) {
64-
// Iterate through the inner map
65-
auto& chunks = pid_it->second;
66-
auto chunk_it = chunks.begin();
67-
while (chunk_it != chunks.end()) {
68-
if (chunk_it->second.is_ready()) {
69-
ret.emplace_back(std::move(chunk_it->second));
70-
chunk_it = chunks.erase(chunk_it);
78+
for (auto& [key, map_value] : pigeonhole_) {
79+
std::lock_guard lock(map_value.mutex);
80+
bool chunks_available = !map_value.chunks.empty();
81+
auto chunk_it = map_value.chunks.begin();
82+
while (chunk_it != map_value.chunks.end()) {
83+
if (chunk_it->is_ready()) {
84+
auto node = map_value.chunks.extract(chunk_it++);
85+
ret.emplace_back(std::move(node.value()));
7186
} else {
7287
++chunk_it;
7388
}
7489
}
7590

76-
// Remove the pid entry if its chunks map is empty
77-
if (chunks.empty()) {
78-
pid_it = pigeonhole_.erase(pid_it);
79-
} else {
80-
++pid_it;
91+
// if the chunks were available and are now empty, its fully extracted
92+
if (chunks_available && map_value.chunks.empty()) {
93+
RAPIDSMPF_EXPECTS(
94+
n_non_empty_keys_.fetch_sub(1, std::memory_order_relaxed) > 0,
95+
"PostBox.extract_all_ready(): n_non_empty_keys_ is already 0"
96+
);
8197
}
8298
}
83-
8499
return ret;
85100
}
86101

87102
template <typename KeyType>
88103
bool PostBox<KeyType>::empty() const {
89-
std::lock_guard const lock(mutex_);
90-
return pigeonhole_.empty();
104+
return n_non_empty_keys_.load(std::memory_order_acquire) == 0;
91105
}
92106

93107
template <typename KeyType>
94-
std::vector<std::tuple<KeyType, ChunkID, std::size_t>> PostBox<KeyType>::search(
95-
MemoryType mem_type
96-
) const {
97-
std::lock_guard const lock(mutex_);
98-
std::vector<std::tuple<KeyType, ChunkID, std::size_t>> ret;
99-
for (auto& [key, chunks] : pigeonhole_) {
100-
for (auto& [cid, chunk] : chunks) {
101-
if (!chunk.is_control_message(0) && chunk.data_memory_type() == mem_type) {
102-
ret.emplace_back(key, cid, chunk.concat_data_size());
103-
}
104-
}
105-
}
106-
return ret;
108+
size_t PostBox<KeyType>::spill(BufferResource* /* br */, size_t /* amount */) {
109+
// TODO: implement spill
110+
return 0;
107111
}
108112

109113
template <typename KeyType>
@@ -113,14 +117,14 @@ std::string PostBox<KeyType>::str() const {
113117
}
114118
std::stringstream ss;
115119
ss << "PostBox(";
116-
for (auto const& [key, chunks] : pigeonhole_) {
120+
for (auto const& [key, map_value] : pigeonhole_) {
117121
ss << "k=" << key << ": [";
118-
for (auto const& [cid, chunk] : chunks) {
119-
assert(cid == chunk.chunk_id());
122+
for (auto const& chunk : map_value.chunks) {
123+
// assert(cid == chunk.chunk_id());
120124
if (chunk.is_control_message(0)) {
121125
ss << "EOP" << chunk.expected_num_chunks(0) << ", ";
122126
} else {
123-
ss << cid << ", ";
127+
ss << chunk.chunk_id() << ", ";
124128
}
125129
}
126130
ss << "\b\b], ";

0 commit comments

Comments
 (0)