Skip to content

Commit 0eba9ce

Browse files
committed
Add AllReduce class
1 parent e8e6899 commit 0eba9ce

File tree

8 files changed

+1073
-0
lines changed

8 files changed

+1073
-0
lines changed

cpp/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ target_link_options(maybe_asan INTERFACE "$<$<BOOL:${RAPIDSMPF_ASAN}>:-fsanitize
157157
add_library(
158158
rapidsmpf
159159
src/allgather/allgather.cpp
160+
src/allreduce/allreduce.cpp
161+
src/allreduce/device_kernels.cu
160162
src/bootstrap/bootstrap.cpp
161163
src/bootstrap/file_backend.cpp
162164
src/communicator/communicator.cpp
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
/**
2+
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
#pragma once
6+
7+
#include <algorithm>
8+
#include <atomic>
9+
#include <chrono>
10+
#include <cstdint>
11+
#include <cstring>
12+
#include <functional>
13+
#include <memory>
14+
#include <type_traits>
15+
#include <utility>
16+
#include <vector>
17+
18+
#include <rapidsmpf/allgather/allgather.hpp>
19+
#include <rapidsmpf/communicator/communicator.hpp>
20+
#include <rapidsmpf/error.hpp>
21+
#include <rapidsmpf/memory/buffer.hpp>
22+
#include <rapidsmpf/memory/buffer_resource.hpp>
23+
#include <rapidsmpf/memory/packed_data.hpp>
24+
#include <rapidsmpf/progress_thread.hpp>
25+
#include <rapidsmpf/statistics.hpp>
26+
27+
namespace rapidsmpf::allreduce {
28+
29+
/**
30+
* @brief Reduction operators supported by `AllReduce`.
31+
*
32+
* These closely mirror the reduction operators from MPI.
33+
*/
34+
enum class ReduceOp : std::uint8_t {
35+
SUM, ///< Sum / addition
36+
PROD, ///< Product / multiplication
37+
MIN, ///< Minimum
38+
MAX, ///< Maximum
39+
};
40+
41+
/**
 * @brief Type-erased reduction kernel used by `AllReduce`.
 *
 * The kernel must implement an associative binary operation over the contents of
 * two `PackedData` objects and accumulate the result into @p accum.
 *
 * Implementations must:
 *  - Treat @p accum as the running partial result.
 *  - Combine @p incoming into @p accum in-place.
 *  - Leave @p incoming in a valid but unspecified state after the call.
 *
 * The kernel is responsible for interpreting `PackedData::metadata` and
 * `PackedData::data` consistently across all ranks. Since the reduction is
 * applied pairwise in rank order (see `AllReduce::reduce_all`), a
 * non-associative kernel would produce rank-order-dependent results.
 */
using ReduceKernel = std::function<void(PackedData& accum, PackedData&& incoming)>;
56+
57+
/**
 * @brief AllReduce collective.
 *
 * The current implementation is built using `allgather::AllGather` and performs
 * the reduction locally after allgather completes. Considering `R` is the number of
 * ranks, and `N` is the number of bytes of data, per rank this incurs `O(R * N)` bytes of
 * memory consumption and `O(R)` communication operations.
 *
 * Semantics:
 *  - Each rank may call `insert` any number of times with a local sequence number.
 *  - Conceptually, the *k*-th insertion on each rank participates in a single
 *    global reduction. That is, insertions are paired across ranks by their
 *    local insertion order, not by sequence number values.
 *  - Once all ranks call `insert_finished`, `wait_and_extract` returns one
 *    globally-reduced `PackedData` per local insertion on this rank.
 *
 * The actual reduction is implemented via a type-erased `ReduceKernel` that is
 * supplied at construction time. Helper factories such as
 * `detail::make_reduce_kernel` can be used to build element-wise
 * reductions over contiguous arrays in device memory.
 */
class AllReduce {
  public:
    /**
     * @brief Construct a new AllReduce operation.
     *
     * @param comm The communicator for communication.
     * @param progress_thread The progress thread used by the underlying AllGather.
     * @param op_id Unique operation identifier for this allreduce.
     * @param br Buffer resource for memory allocation.
     * @param statistics Statistics collection instance (disabled by default).
     * @param reduce_kernel Type-erased reduction kernel to use. Although this
     * parameter is syntactically defaulted to an empty kernel, the constructor
     * rejects empty kernels, so callers must always supply a valid one.
     * @param finished_callback Optional callback run once locally when the allreduce
     * is finished and results are ready for extraction.
     *
     * @throws An exception (via `RAPIDSMPF_EXPECTS`) if @p reduce_kernel is empty.
     *
     * @note This constructor internally creates an `allgather::AllGather` instance
     * that uses the same communicator, progress thread, and buffer resource.
     */
    AllReduce(
        std::shared_ptr<Communicator> comm,
        std::shared_ptr<ProgressThread> progress_thread,
        OpID op_id,
        BufferResource* br,
        std::shared_ptr<Statistics> statistics = Statistics::disabled(),
        ReduceKernel reduce_kernel = {},
        std::function<void(void)> finished_callback = nullptr
    );

    // Non-copyable and non-movable: the instance is registered with an
    // underlying allgather/progress thread, so its address must stay stable.
    AllReduce(AllReduce const&) = delete;
    AllReduce& operator=(AllReduce const&) = delete;
    AllReduce(AllReduce&&) = delete;
    AllReduce& operator=(AllReduce&&) = delete;

    /**
     * @brief Destructor.
     *
     * @note This operation is logically collective. If an `AllReduce` is locally
     * destructed before `wait_and_extract` is called, there is no guarantee
     * that in-flight communication will be completed.
     */
    ~AllReduce();

    /**
     * @brief Insert packed data into the allreduce operation.
     *
     * @param sequence_number Local ordered sequence number of the data.
     * @param packed_data The data to contribute to the allreduce.
     *
     * The caller promises that:
     *  - `sequence_number`s are non-decreasing on each rank.
     *  - The *k*-th call to `insert` on each rank corresponds to the same logical
     *    reduction across all ranks (i.e., same element type and shape).
     */
    void insert(std::uint64_t sequence_number, PackedData&& packed_data);

    /**
     * @brief Mark that this rank has finished contributing data.
     */
    void insert_finished();

    /**
     * @brief Check if the allreduce operation has completed.
     *
     * @return True if all data and finish messages from all ranks have been
     * received and locally reduced.
     */
    [[nodiscard]] bool finished() const noexcept;

    /**
     * @brief Wait for completion and extract all reduced data.
     *
     * Blocks until the allreduce operation completes and returns all locally
     * reduced results, ordered by local insertion order.
     *
     * @param timeout Optional maximum duration to wait. Negative values mean
     * no timeout.
     *
     * @return A vector containing reduced packed data, one entry per local
     * insertion on this rank. The results are moved out of the instance, so a
     * second call returns an empty vector.
     * @throws std::runtime_error If the timeout is reached.
     */
    [[nodiscard]] std::vector<PackedData> wait_and_extract(
        std::chrono::milliseconds timeout = std::chrono::milliseconds{-1}
    );

    /**
     * @brief Check if reduced results are ready for extraction.
     *
     * This returns true once the underlying allgather has completed and, if
     * `wait_and_extract` has not yet been called, indicates that calling it
     * would not block.
     *
     * @return True if the allreduce operation has completed and results are ready for
     * extraction, false otherwise.
     */
    [[nodiscard]] bool is_ready() const noexcept;

  private:
    /// @brief Perform the reduction across all ranks for all gathered contributions.
    [[nodiscard]] std::vector<PackedData> reduce_all(std::vector<PackedData>&& gathered);

    std::shared_ptr<Communicator> comm_;  ///< Communicator
    std::shared_ptr<ProgressThread>
        progress_thread_;  ///< Progress thread (unused directly).
    BufferResource* br_;  ///< Buffer resource
    std::shared_ptr<Statistics> statistics_;  ///< Statistics collection instance
    ReduceKernel reduce_kernel_;  ///< Type-erased reduction kernel
    std::function<void(void)> finished_callback_;  ///< Optional finished callback

    // Declared after the members above so it is constructed after them: the
    // constructor's member-initializer list builds the gatherer from `comm_`,
    // `progress_thread_`, `br_`, and `statistics_` (not from the moved-from
    // constructor parameters).
    allgather::AllGather gatherer_;  ///< Underlying allgather primitive

    std::atomic<std::uint32_t> nlocal_insertions_{0};  ///< Number of local inserts
    std::atomic<bool> reduced_computed_{
        false
    };  ///< Whether the reduction has been computed
    std::vector<PackedData> reduced_results_;  ///< Cached reduced results
};
194+
195+
namespace detail {
/**
 * @brief Create a device-based element-wise reduction kernel for a given (T, Op).
 *
 * This kernel expects both `PackedData::data` buffers to reside in device memory.
 * Implementations are provided in `device_kernels.cu` for a subset of (T, Op)
 * combinations; instantiating an unsupported combination will fail at link
 * time, since only explicit instantiations exist. // NOTE(review): confirm —
 * the set of instantiations lives in device_kernels.cu, outside this view.
 */
template <typename T, ReduceOp Op>
ReduceKernel make_reduce_kernel();
}  // namespace detail
206+
207+
} // namespace rapidsmpf::allreduce

cpp/src/allreduce/allreduce.cpp

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/**
2+
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
#include <algorithm>
7+
#include <chrono>
8+
#include <cstddef>
9+
#include <cstdint>
10+
#include <utility>
11+
#include <vector>
12+
13+
#include <rapidsmpf/allreduce/allreduce.hpp>
14+
#include <rapidsmpf/error.hpp>
15+
16+
namespace rapidsmpf::allreduce {
17+
18+
AllReduce::AllReduce(
    std::shared_ptr<Communicator> comm,
    std::shared_ptr<ProgressThread> progress_thread,
    OpID op_id,
    BufferResource* br,
    std::shared_ptr<Statistics> statistics,
    ReduceKernel reduce_kernel,
    std::function<void(void)> finished_callback
)
    // `gatherer_` is initialized last (it is declared last in the class), so
    // it can safely read `comm_`, `progress_thread_`, `br_`, and `statistics_`
    // after the constructor parameters have been moved into them. Reading the
    // parameters here instead would read moved-from objects.
    : comm_{std::move(comm)},
      progress_thread_{std::move(progress_thread)},
      br_{br},
      statistics_{std::move(statistics)},
      reduce_kernel_{std::move(reduce_kernel)},
      finished_callback_{std::move(finished_callback)},
      gatherer_{comm_, progress_thread_, op_id, br_, statistics_} {
    // The header defaults `reduce_kernel` to an empty std::function, but an
    // allreduce cannot operate without one: fail eagerly at construction
    // rather than at reduction time.
    RAPIDSMPF_EXPECTS(
        static_cast<bool>(reduce_kernel_),
        "AllReduce requires a valid ReduceKernel at construction time"
    );
}
39+
40+
// Defaulted: member destructors tear down the underlying allgather. Per the
// header note, destruction before wait_and_extract() gives no guarantee that
// in-flight communication completes.
AllReduce::~AllReduce() = default;
42+
void AllReduce::insert(std::uint64_t sequence_number, PackedData&& packed_data) {
    // Count the local contribution before handing the data to the allgather.
    // The counter is only read later, in reduce_all(), to cross-check that the
    // local insertion count matches the per-rank contribution implied by the
    // gathered results — relaxed ordering suffices for a plain counter.
    nlocal_insertions_.fetch_add(1, std::memory_order_relaxed);
    gatherer_.insert(sequence_number, std::move(packed_data));
}
46+
47+
void AllReduce::insert_finished() {
    // Pure forwarder: end-of-insertion tracking is owned entirely by the
    // underlying allgather; AllReduce keeps no extra completion state here.
    gatherer_.insert_finished();
}
50+
51+
bool AllReduce::finished() const noexcept {
    // Completion is delegated to the underlying allgather: once every rank's
    // data and finish messages have arrived, the local reduction can proceed.
    return gatherer_.finished();
}
54+
55+
std::vector<PackedData> AllReduce::wait_and_extract(std::chrono::milliseconds timeout) {
56+
// Block until the underlying allgather completes, then perform the reduction locally
57+
// (exactly once).
58+
if (!reduced_computed_.load(std::memory_order_acquire)) {
59+
auto gathered = gatherer_.wait_and_extract(
60+
allgather::AllGather::Ordered::YES, std::move(timeout)
61+
);
62+
reduced_results_ = reduce_all(std::move(gathered));
63+
reduced_computed_.store(true, std::memory_order_release);
64+
if (finished_callback_) {
65+
finished_callback_();
66+
}
67+
}
68+
return std::move(reduced_results_);
69+
}
70+
71+
bool AllReduce::is_ready() const noexcept {
72+
return reduced_computed_.load(std::memory_order_acquire) || gatherer_.finished();
73+
}
74+
75+
std::vector<PackedData> AllReduce::reduce_all(std::vector<PackedData>&& gathered) {
76+
auto const nranks = static_cast<std::size_t>(comm_->nranks());
77+
auto const total = gathered.size();
78+
79+
if (total == 0) {
80+
return {};
81+
}
82+
83+
RAPIDSMPF_EXPECTS(
84+
nranks > 0, "AllReduce requires a positive number of ranks", std::runtime_error
85+
);
86+
RAPIDSMPF_EXPECTS(
87+
total % nranks == 0,
88+
"AllReduce expects each rank to contribute the same number of messages",
89+
std::runtime_error
90+
);
91+
92+
auto const n_local =
93+
static_cast<std::size_t>(nlocal_insertions_.load(std::memory_order_acquire));
94+
auto const n_per_rank = total / nranks;
95+
96+
// We allow non-uniform insertion counts across ranks but require that the local
97+
// insertion count matches the per-rank contribution implied by the gather.
98+
RAPIDSMPF_EXPECTS(
99+
n_local == 0 || n_local == n_per_rank,
100+
"AllReduce local insertion count does not match gathered contributions per rank",
101+
std::runtime_error
102+
);
103+
104+
std::vector<PackedData> results;
105+
results.reserve(n_per_rank);
106+
107+
// Conceptually, the k-th insertion on each rank participates in a single
108+
// reduction. With ordered allgather results, entries are laid out as:
109+
// [rank0:0..n_per_rank-1][rank1:0..n_per_rank-1]...[rankP-1:0..n_per_rank-1]
110+
for (std::size_t k = 0; k < n_per_rank; ++k) {
111+
// Start from rank 0's contribution for this logical insertion.
112+
auto accum = std::move(gathered[k]);
113+
for (std::size_t r = 1; r < nranks; ++r) {
114+
auto idx = r * n_per_rank + k;
115+
reduce_kernel_(accum, std::move(gathered[idx]));
116+
}
117+
results.emplace_back(std::move(accum));
118+
}
119+
120+
return results;
121+
}
122+
123+
} // namespace rapidsmpf::allreduce

0 commit comments

Comments
 (0)