33 * SPDX-License-Identifier: Apache-2.0
44 */
55
6+ #include < algorithm>
7+ #include < ranges>
68#include < sstream>
79
810#include < rapidsmpf/communicator/communicator.hpp>
11+ #include < rapidsmpf/nvtx.hpp>
912#include < rapidsmpf/shuffler/chunk.hpp>
1013#include < rapidsmpf/shuffler/postbox.hpp>
1114#include < rapidsmpf/utils.hpp>
@@ -32,10 +35,7 @@ void PostBox<KeyType>::insert(Chunk&& chunk) {
3235 " PostBox.insert(): n_non_empty_keys_ is already at the maximum"
3336 );
3437 }
35- RAPIDSMPF_EXPECTS (
36- map_value.chunks .emplace (std::move (chunk)).second ,
37- " PostBox.insert(): chunk already exist"
38- );
38+ map_value.chunks .push_back (std::move (chunk));
3939}
4040
4141template <typename KeyType>
@@ -53,15 +53,11 @@ std::vector<Chunk> PostBox<KeyType>::extract(PartID pid) {
5353template <typename KeyType>
5454std::vector<Chunk> PostBox<KeyType>::extract_by_key(KeyType key) {
5555 auto & map_value = pigeonhole_.at (key);
56- std::vector<Chunk> ret;
5756 std::lock_guard lock (map_value.mutex );
5857 RAPIDSMPF_EXPECTS (!map_value.chunks .empty (), " PostBox.extract(): partition is empty" );
59- ret.reserve (map_value.chunks .size ());
6058
61- for (auto it = map_value.chunks .begin (); it != map_value.chunks .end ();) {
62- auto node = map_value.chunks .extract (it++);
63- ret.emplace_back (std::move (node.value ()));
64- }
59+ std::vector<Chunk> ret = std::move (map_value.chunks );
60+ map_value.chunks .clear ();
6561
6662 RAPIDSMPF_EXPECTS (
6763 n_non_empty_keys_.fetch_sub (1 , std::memory_order_relaxed) > 0 ,
@@ -77,24 +73,33 @@ std::vector<Chunk> PostBox<KeyType>::extract_all_ready() {
7773 // Iterate through the outer map
7874 for (auto & [key, map_value] : pigeonhole_) {
7975 std::lock_guard lock (map_value.mutex );
80- bool chunks_available = !map_value.chunks .empty ();
81- auto chunk_it = map_value.chunks .begin ();
82- while (chunk_it != map_value.chunks .end ()) {
83- if (chunk_it->is_ready ()) {
84- auto node = map_value.chunks .extract (chunk_it++);
85- ret.emplace_back (std::move (node.value ()));
86- } else {
87- ++chunk_it;
88- }
89- }
9076
91- // if the chunks were available and are now empty, its fully extracted
92- if (chunks_available && map_value.chunks .empty ()) {
77+ // Partition: non-ready chunks first, ready chunks at the end
78+ auto partition_point =
79+ std::ranges::partition (map_value.chunks , [](const Chunk& c) {
80+ return !c.is_ready ();
81+ }).begin ();
82+
83+ // if the chunks are available and all are ready, then all chunks will be
84+ // extracted
85+ if (map_value.chunks .begin () == partition_point
86+ && partition_point != map_value.chunks .end ())
87+ {
9388 RAPIDSMPF_EXPECTS (
9489 n_non_empty_keys_.fetch_sub (1 , std::memory_order_relaxed) > 0 ,
9590 " PostBox.extract_all_ready(): n_non_empty_keys_ is already 0"
9691 );
9792 }
93+
94+ // Move ready chunks to result
95+ ret.insert (
96+ ret.end (),
97+ std::make_move_iterator (partition_point),
98+ std::make_move_iterator (map_value.chunks .end ())
99+ );
100+
101+ // Remove ready chunks from the vector
102+ map_value.chunks .erase (partition_point, map_value.chunks .end ());
98103 }
99104 return ret;
100105}
@@ -105,9 +110,44 @@ bool PostBox<KeyType>::empty() const {
105110}
106111
107112template <typename KeyType>
108- size_t PostBox<KeyType>::spill(BufferResource* /* br */ , size_t /* amount */ ) {
109- // TODO: implement spill
110- return 0 ;
113+ size_t PostBox<KeyType>::spill(
114+ BufferResource* br, Communicator::Logger& log, size_t amount
115+ ) {
116+ RAPIDSMPF_NVTX_FUNC_RANGE ();
117+
118+ // individually lock each key and spill the chunks in it. If we are unable to lock the
119+ // key, then it will be skipped.
120+ size_t total_spilled = 0 ;
121+ for (auto & [key, map_value] : pigeonhole_) {
122+ std::unique_lock lock (map_value.mutex , std::try_to_lock);
123+ if (lock) { // now all chunks in this key are locked
124+ for (auto & chunk : map_value.chunks ) {
125+ if (chunk.is_data_buffer_set ()
126+ && chunk.data_memory_type () == MemoryType::DEVICE)
127+ {
128+ size_t size = chunk.concat_data_size ();
129+ auto [host_reservation, host_overbooking] =
130+ br->reserve (MemoryType::HOST, size, true );
131+ if (host_overbooking > 0 ) {
132+ log.warn (
133+ " Cannot spill to host because of host memory overbooking: " ,
134+ format_nbytes (host_overbooking)
135+ );
136+ continue ;
137+ }
138+ chunk.set_data_buffer (
139+ br->move (chunk.release_data_buffer (), host_reservation)
140+ );
141+ total_spilled += size;
142+ if (total_spilled >= amount) {
143+ break ;
144+ }
145+ }
146+ }
147+ }
148+ }
149+
150+ return total_spilled;
111151}
112152
113153template <typename KeyType>
0 commit comments