|
| 1 | +/** |
| 2 | + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + */ |
| 5 | +#pragma once |
| 6 | + |
| 7 | +#include <algorithm> |
| 8 | +#include <atomic> |
| 9 | +#include <chrono> |
| 10 | +#include <cstdint> |
| 11 | +#include <cstring> |
| 12 | +#include <functional> |
| 13 | +#include <memory> |
| 14 | +#include <type_traits> |
| 15 | +#include <utility> |
| 16 | +#include <vector> |
| 17 | + |
| 18 | +#include <rapidsmpf/allgather/allgather.hpp> |
| 19 | +#include <rapidsmpf/communicator/communicator.hpp> |
| 20 | +#include <rapidsmpf/error.hpp> |
| 21 | +#include <rapidsmpf/memory/buffer.hpp> |
| 22 | +#include <rapidsmpf/memory/buffer_resource.hpp> |
| 23 | +#include <rapidsmpf/memory/packed_data.hpp> |
| 24 | +#include <rapidsmpf/progress_thread.hpp> |
| 25 | +#include <rapidsmpf/statistics.hpp> |
| 26 | + |
| 27 | +namespace rapidsmpf::allreduce { |
| 28 | + |
| 29 | +/** |
| 30 | + * @brief Reduction operators supported by `AllReduce`. |
| 31 | + * |
| 32 | + * These closely mirror the predefined reduction operators from MPI |
| 32 | + * (`MPI_SUM`, `MPI_PROD`, `MPI_MIN`, and `MPI_MAX`). |
| 33 | + */ |
| 34 | +enum class ReduceOp : std::uint8_t { |
| 35 | +    SUM, ///< Sum / addition |
| 36 | +    PROD, ///< Product / multiplication |
| 37 | +    MIN, ///< Minimum |
| 38 | +    MAX, ///< Maximum |
| 39 | +}; |
| 40 | + |
| 41 | +/** |
| 42 | + * @brief Type-erased reduction kernel used by `AllReduce`. |
| 43 | + * |
| 44 | + * The kernel must implement an associative binary operation over the contents of |
| 45 | + * two `PackedData` objects and accumulate the result into @p accum. |
| 46 | + * |
| 47 | + * Implementations must: |
| 48 | + * - Treat @p accum as the running partial result. |
| 49 | + * - Combine @p incoming into @p accum in-place. |
| 50 | + * - Leave @p incoming in a valid but unspecified state after the call. |
| 51 | + * |
| 52 | + * The kernel is responsible for interpreting `PackedData::metadata` and |
| 53 | + * `PackedData::data` consistently across all ranks. |
| 53 | + * |
| 53 | + * NOTE(review): only associativity is required above; whether the operation |
| 53 | + * must also be commutative (i.e. whether contributions may be combined in |
| 53 | + * arbitrary rank order) is not specified in this header — confirm against |
| 53 | + * the order in which the implementation invokes the kernel. |
| 54 | + */ |
| 55 | +using ReduceKernel = std::function<void(PackedData& accum, PackedData&& incoming)>; |
| 56 | + |
| 57 | +/** |
| 58 | + * @brief AllReduce collective. |
| 59 | + * |
| 60 | + * The current implementation is built using `allgather::AllGather` and performs |
| 61 | + * the reduction locally after allgather completes. Considering `R` is the number of |
| 62 | + * ranks, and `N` is the number of bytes of data, per rank this incurs `O(R * N)` bytes of |
| 63 | + * memory consumption and `O(R)` communication operations. |
| 64 | + * |
| 65 | + * Semantics: |
| 66 | + * - Each rank may call `insert` any number of times with a local sequence number. |
| 67 | + * - Conceptually, the *k*-th insertion on each rank participates in a single |
| 68 | + *   global reduction. That is, insertions are paired across ranks by their |
| 69 | + *   local insertion order, not by sequence number values. |
| 70 | + * - Once all ranks call `insert_finished`, `wait_and_extract` returns one |
| 71 | + *   globally-reduced `PackedData` per local insertion on this rank. |
| 72 | + * |
| 73 | + * The actual reduction is implemented via a type-erased `ReduceKernel` that is |
| 74 | + * supplied at construction time. Helper factories such as |
| 75 | + * `detail::make_reduce_kernel` can be used to build element-wise |
| 76 | + * reductions over contiguous arrays in device memory. |
| 76 | + * |
| 76 | + * NOTE(review): no thread-safety contract is documented here. The atomic |
| 76 | + * members hint at cross-thread progress (e.g. a progress thread), but treat |
| 76 | + * the public API as externally synchronized until the implementation |
| 76 | + * confirms otherwise. |
| 77 | + */ |
| 78 | +class AllReduce { |
| 79 | +  public: |
| 80 | +    /** |
| 81 | +     * @brief Construct a new AllReduce operation. |
| 82 | +     * |
| 83 | +     * @param comm The communicator for communication. |
| 84 | +     * @param progress_thread The progress thread used by the underlying AllGather. |
| 85 | +     * @param op_id Unique operation identifier for this allreduce. |
| 86 | +     * @param br Buffer resource for memory allocation. |
| 87 | +     * @param statistics Statistics collection instance (disabled by default). |
| 88 | +     * @param reduce_kernel Type-erased reduction kernel to use. Defaults to an |
| 88 | +     *        empty (default-constructed) kernel; how an empty kernel is |
| 88 | +     *        handled is not visible in this header — TODO confirm in the |
| 88 | +     *        implementation. |
| 89 | +     * @param finished_callback Optional callback run once locally when the allreduce |
| 90 | +     *        is finished and results are ready for extraction. |
| 91 | +     * |
| 92 | +     * @note This constructor internally creates an `allgather::AllGather` instance |
| 93 | +     *       that uses the same communicator, progress thread, and buffer resource. |
| 94 | +     */ |
| 95 | +    AllReduce( |
| 96 | +        std::shared_ptr<Communicator> comm, |
| 97 | +        std::shared_ptr<ProgressThread> progress_thread, |
| 98 | +        OpID op_id, |
| 99 | +        BufferResource* br, |
| 100 | +        std::shared_ptr<Statistics> statistics = Statistics::disabled(), |
| 101 | +        ReduceKernel reduce_kernel = {}, |
| 102 | +        std::function<void(void)> finished_callback = nullptr |
| 103 | +    ); |
| 104 | + |
| 105 | +    AllReduce(AllReduce const&) = delete; |
| 106 | +    AllReduce& operator=(AllReduce const&) = delete; |
| 107 | +    AllReduce(AllReduce&&) = delete; |
| 108 | +    AllReduce& operator=(AllReduce&&) = delete; |
| 109 | + |
| 110 | +    /** |
| 111 | +     * @brief Destructor. |
| 112 | +     * |
| 113 | +     * @note This operation is logically collective. If an `AllReduce` is locally |
| 114 | +     *       destructed before `wait_and_extract` is called, there is no guarantee |
| 115 | +     *       that in-flight communication will be completed. |
| 116 | +     */ |
| 117 | +    ~AllReduce(); |
| 118 | + |
| 119 | +    /** |
| 120 | +     * @brief Insert packed data into the allreduce operation. |
| 121 | +     * |
| 122 | +     * @param sequence_number Local ordered sequence number of the data. |
| 123 | +     * @param packed_data The data to contribute to the allreduce. |
| 124 | +     * |
| 125 | +     * The caller promises that: |
| 126 | +     * - `sequence_number`s are non-decreasing on each rank. |
| 127 | +     * - The *k*-th call to `insert` on each rank corresponds to the same logical |
| 128 | +     *   reduction across all ranks (i.e., same element type and shape). |
| 129 | +     */ |
| 130 | +    void insert(std::uint64_t sequence_number, PackedData&& packed_data); |
| 131 | + |
| 132 | +    /** |
| 133 | +     * @brief Mark that this rank has finished contributing data. |
| 133 | +     * |
| 133 | +     * Every rank must make this call before the collective can complete |
| 133 | +     * (see the class-level semantics). |
| 134 | +     */ |
| 135 | +    void insert_finished(); |
| 136 | + |
| 137 | +    /** |
| 138 | +     * @brief Check if the allreduce operation has completed. |
| 139 | +     * |
| 140 | +     * @return True if all data and finish messages from all ranks have been |
| 141 | +     *         received and locally reduced. |
| 142 | +     */ |
| 143 | +    [[nodiscard]] bool finished() const noexcept; |
| 144 | + |
| 145 | +    /** |
| 146 | +     * @brief Wait for completion and extract all reduced data. |
| 147 | +     * |
| 148 | +     * Blocks until the allreduce operation completes and returns all locally |
| 149 | +     * reduced results, ordered by local insertion order. |
| 150 | +     * |
| 151 | +     * @param timeout Optional maximum duration to wait. Negative values mean |
| 152 | +     *        no timeout. |
| 153 | +     * |
| 154 | +     * @return A vector containing reduced packed data, one entry per local |
| 155 | +     *         insertion on this rank. |
| 156 | +     * @throws std::runtime_error If the timeout is reached. |
| 157 | +     */ |
| 158 | +    [[nodiscard]] std::vector<PackedData> wait_and_extract( |
| 159 | +        std::chrono::milliseconds timeout = std::chrono::milliseconds{-1} |
| 160 | +    ); |
| 161 | + |
| 162 | +    /** |
| 163 | +     * @brief Check if reduced results are ready for extraction. |
| 164 | +     * |
| 165 | +     * This returns true once the underlying allgather has completed and, if |
| 166 | +     * `wait_and_extract` has not yet been called, indicates that calling it |
| 167 | +     * would not block. |
| 168 | +     * |
| 169 | +     * @return True if the allreduce operation has completed and results are ready for |
| 170 | +     *         extraction, false otherwise. |
| 171 | +     */ |
| 172 | +    [[nodiscard]] bool is_ready() const noexcept; |
| 173 | + |
| 174 | +  private: |
| 175 | +    /// @brief Perform the reduction across all ranks for all gathered contributions. |
| 175 | +    /// |
| 175 | +    /// @param gathered All contributions collected by the underlying allgather, |
| 175 | +    ///        taken by rvalue reference and consumed. |
| 175 | +    /// @return One reduced `PackedData` per local insertion on this rank. |
| 176 | +    [[nodiscard]] std::vector<PackedData> reduce_all(std::vector<PackedData>&& gathered); |
| 177 | + |
| 178 | +    std::shared_ptr<Communicator> comm_; ///< Communicator |
| 179 | +    std::shared_ptr<ProgressThread> |
| 180 | +        progress_thread_; ///< Progress thread (unused directly). |
| 181 | +    BufferResource* br_; ///< Buffer resource |
| 182 | +    std::shared_ptr<Statistics> statistics_; ///< Statistics collection instance |
| 183 | +    ReduceKernel reduce_kernel_; ///< Type-erased reduction kernel |
| 184 | +    std::function<void(void)> finished_callback_; ///< Optional finished callback |
| 185 | + |
| 186 | +    allgather::AllGather gatherer_; ///< Underlying allgather primitive |
| 187 | + |
| 188 | +    std::atomic<std::uint32_t> nlocal_insertions_{0}; ///< Number of local inserts |
| 189 | +    std::atomic<bool> reduced_computed_{ |
| 190 | +        false |
| 191 | +    }; ///< Whether the reduction has been computed |
| 192 | +    std::vector<PackedData> reduced_results_; ///< Cached reduced results |
| 193 | +}; |
| 194 | + |
| 195 | +namespace detail { |
| 196 | +/** |
| 197 | + * @brief Create a device-based element-wise reduction kernel for a given (T, Op). |
| 198 | + * |
| 199 | + * This kernel expects both `PackedData::data` buffers to reside in device memory. |
| 200 | + * Implementations are provided in `device_kernels.cu` for a subset of (T, Op) |
| 201 | + * combinations. |
| 201 | + * |
| 201 | + * @tparam T Element type of the contiguous arrays being reduced. |
| 201 | + * @tparam Op Reduction operator to apply element-wise. |
| 201 | + * @return A `ReduceKernel` implementing the element-wise (T, Op) reduction. |
| 201 | + * |
| 201 | + * NOTE(review): since only a subset of (T, Op) combinations is instantiated |
| 201 | + * in `device_kernels.cu`, an unsupported combination presumably fails at |
| 201 | + * link time — confirm. |
| 202 | + */ |
| 203 | +template <typename T, ReduceOp Op> |
| 204 | +ReduceKernel make_reduce_kernel(); |
| 205 | +} // namespace detail |
| 206 | + |
| 207 | +} // namespace rapidsmpf::allreduce |
0 commit comments