Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 51 additions & 35 deletions cpp/include/rapidsmpf/shuffler/postbox.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@
*/
#pragma once

#include <atomic>
#include <functional>
#include <list>
#include <mutex>
#include <random>
#include <ranges>
#include <string>
#include <unordered_map>
#include <vector>

#include <rapidsmpf/error.hpp>
#include <rapidsmpf/memory/buffer_resource.hpp>
#include <rapidsmpf/shuffler/chunk.hpp>

namespace rapidsmpf::shuffler::detail {
Expand All @@ -31,17 +36,24 @@ class PostBox {
*
* @tparam Fn The type of the function that maps a partition ID to a key.
* @param key_map_fn A function that maps a partition ID to a key.
* @param num_keys_hint The number of keys to reserve space for.
* @param keys The keys expected to be used in the PostBox.
*
* @note The `key_map_fn` must be convertible to a function that takes a `PartID` and
* returns a `KeyType`.
*/
template <typename Fn>
PostBox(Fn&& key_map_fn, size_t num_keys_hint = 0)
: key_map_fn_(std::move(key_map_fn)) {
if (num_keys_hint > 0) {
pigeonhole_.reserve(num_keys_hint);
template <typename Fn, std::ranges::input_range Range>
requires std::convertible_to<std::ranges::range_value_t<Range>, KeyType>
PostBox(Fn&& key_map_fn, Range&& keys) : key_map_fn_(std::move(key_map_fn)) {
pigeonhole_.reserve(std::ranges::size(keys));
for (const auto& key : keys) {
pigeonhole_.emplace(
std::piecewise_construct,
std::forward_as_tuple(key),
std::forward_as_tuple()
);
}
rng_ = std::mt19937(std::random_device{}());
dist_ = std::uniform_int_distribution<size_t>(0, keys.size() - 1);
}

/**
Expand All @@ -60,18 +72,7 @@ class PostBox {
* @note The result reflects a snapshot at the time of the call and may change
* immediately afterward.
*/
bool is_empty(PartID pid) const;

/**
* @brief Extracts a specific chunk from the PostBox.
*
* @param pid The ID of the partition containing the chunk.
* @param cid The ID of the chunk to be accessed.
* @return The extracted chunk.
*
* @throws std::out_of_range If the chunk is not found.
*/
[[nodiscard]] Chunk extract(PartID pid, ChunkID cid);
[[nodiscard]] bool is_empty(PartID pid) const;

/**
* @brief Extracts all chunks associated with a specific partition.
Expand All @@ -81,7 +82,7 @@ class PostBox {
*
* @throws std::out_of_range If the partition is not found.
*/
std::unordered_map<ChunkID, Chunk> extract(PartID pid);
std::vector<Chunk> extract(PartID pid);

/**
* @brief Extracts all chunks associated with a specific key.
Expand All @@ -91,7 +92,7 @@ class PostBox {
*
* @throws std::out_of_range If the key is not found.
*/
std::unordered_map<ChunkID, Chunk> extract_by_key(KeyType key);
std::vector<Chunk> extract_by_key(KeyType key);

/**
* @brief Extracts all ready chunks from the PostBox.
Expand All @@ -107,30 +108,45 @@ class PostBox {
*/
[[nodiscard]] bool empty() const;

/**
* @brief Searches for chunks of the specified memory type.
*
* @param mem_type The type of memory to search within.
* @return A vector of tuples, where each tuple contains: PartID, ChunkID, and the
* size of the chunk.
*/
[[nodiscard]] std::vector<std::tuple<key_type, ChunkID, std::size_t>> search(
MemoryType mem_type
) const;

/**
* @brief Returns a description of this instance.
* @return The description.
*/
[[nodiscard]] std::string str() const;

/**
* @brief Spills the specified amount of data from the PostBox.
*
* @param br Buffer resource to use for spilling.
* @param log Logger to use for logging.
* @param amount The amount of data to spill.
* @return The amount of data spilled.
*/
size_t spill(BufferResource* br, Communicator::Logger& log, size_t amount);

private:
// TODO: more fine-grained locking e.g. by locking each partition individually.
mutable std::mutex mutex_;
/**
* @brief Map value for the PostBox.
*/
struct MapValue {
mutable std::mutex mutex; ///< Mutex to protect each key
std::list<Chunk> ready_chunks; ///< Vector of chunks for the key
size_t n_spilling_chunks{0}; ///< Number of chunks that are being spilled

[[nodiscard]] bool is_empty_unsafe() const noexcept {
return ready_chunks.empty() && n_spilling_chunks == 0;
}
};

std::function<key_type(PartID)>
key_map_fn_; ///< Function to map partition IDs to keys.
std::unordered_map<key_type, std::unordered_map<ChunkID, Chunk>>
pigeonhole_; ///< Storage for chunks, organized by a key and chunk ID.
std::unordered_map<key_type, MapValue> pigeonhole_; ///< Storage for chunks
std::atomic<size_t> n_chunks{0
}; ///< Number of chunks in the PostBox. Since the pigenhole map is not extracted,
///< this count will be used to check the emptiness
std::mt19937 rng_; ///< Random number generator
std::uniform_int_distribution<size_t>
dist_; ///< Distribution for selecting a random key
};

/**
Expand Down
9 changes: 3 additions & 6 deletions cpp/include/rapidsmpf/shuffler/shuffler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,29 +341,26 @@ class Shuffler {

private:
BufferResource* br_;
std::shared_ptr<Communicator> comm_;
std::atomic<bool> active_{true};
std::vector<PartID> const local_partitions_;

detail::PostBox<Rank> outgoing_postbox_; ///< Postbox for outgoing chunks, that are
///< ready to be sent to other ranks.
detail::PostBox<PartID> ready_postbox_; ///< Postbox for received chunks, that are
///< ready to be extracted by the user.

std::shared_ptr<Communicator> comm_;
std::shared_ptr<ProgressThread> progress_thread_;
ProgressThread::FunctionID progress_thread_function_id_;
OpID const op_id_;

SpillManager::SpillFunctionID spill_function_id_;

std::vector<PartID> const local_partitions_;

detail::FinishCounter finish_counter_;
std::unordered_map<PartID, detail::ChunkID> outbound_chunk_counter_;
mutable std::mutex outbound_chunk_counter_mutex_;

// We protect ready_postbox extraction to avoid returning a chunk that is in the
// process of being spilled by `Shuffler::spill`.
mutable std::mutex ready_postbox_spilling_mutex_;

std::atomic<detail::ChunkID> chunk_id_counter_{0};

std::shared_ptr<Statistics> statistics_;
Expand Down
14 changes: 11 additions & 3 deletions cpp/src/memory/spill_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,22 @@ std::size_t SpillManager::spill(std::size_t amount) {
std::size_t spilled{0};
std::unique_lock<std::mutex> lock(mutex_);
auto const t0_elapsed = Clock::now();
for (auto const [_, fid] : spill_function_priorities_) {
// for (auto const [_, fid] : spill_function_priorities_) {
// if (spilled >= amount) {
// break;
// }
// spilled += spill_functions_.at(fid)(amount - spilled);
// }
auto spill_functions_cp = spill_functions_;
lock.unlock();

for (auto& [id, fn] : spill_functions_cp) {
if (spilled >= amount) {
break;
}
spilled += spill_functions_.at(fid)(amount - spilled);
spilled += fn(amount - spilled);
}
auto const t1_elapsed = Clock::now();
lock.unlock();
auto& stats = *br_->statistics();
stats.add_duration_stat("spill-time-device-to-host", t1_elapsed - t0_elapsed);
stats.add_bytes_stat("spill-bytes-device-to-host", spilled);
Expand Down
Loading