Skip to content

Commit b353aa3

Browse files
committed
Remove support for multiple PackedData inserts
1 parent 70145ee commit b353aa3

File tree

3 files changed

+102
-165
lines changed

3 files changed

+102
-165
lines changed

cpp/include/rapidsmpf/coll/allreduce.hpp

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,9 @@ using ReduceKernel = std::function<void(PackedData& accum, PackedData&& incoming
6363
* memory consumption and `O(R)` communication operations.
6464
*
6565
* Semantics:
66-
* - Each rank may call `insert` any number of times with a local sequence number.
67-
* - Conceptually, the *k*-th insertion on each rank participates in a single
68-
* global reduction. That is, insertions are paired across ranks by their
69-
* local insertion order, not by sequence number values.
70-
* - Once all ranks call `insert_finished`, `wait_and_extract` returns one
71-
* globally-reduced `PackedData` per local insertion on this rank.
66+
* - Each rank calls `insert` exactly once to contribute data to the reduction.
67+
* - Once all ranks call `insert_finished`, `wait_and_extract` returns the
68+
* globally-reduced `PackedData`.
7269
*
7370
* The actual reduction is implemented via a type-erased `ReduceKernel` that is
7471
* supplied at construction time. Helper factories such as
@@ -119,15 +116,11 @@ class AllReduce {
119116
/**
120117
* @brief Insert packed data into the allreduce operation.
121118
*
122-
* @param sequence_number Local ordered sequence number of the data.
123119
* @param packed_data The data to contribute to the allreduce.
124120
*
125-
* The caller promises that:
126-
* - `sequence_number`s are non-decreasing on each rank.
127-
* - The *k*-th call to `insert` on each rank corresponds to the same logical
128-
* reduction across all ranks (i.e., same element type and shape).
121+
* @throws std::runtime_error If insert has already been called on this instance.
129122
*/
130-
void insert(std::uint64_t sequence_number, PackedData&& packed_data);
123+
void insert(PackedData&& packed_data);
131124

132125
/**
133126
* @brief Mark that this rank has finished contributing data.
@@ -143,19 +136,18 @@ class AllReduce {
143136
[[nodiscard]] bool finished() const noexcept;
144137

145138
/**
146-
* @brief Wait for completion and extract all reduced data.
139+
* @brief Wait for completion and extract the reduced data.
147140
*
148-
* Blocks until the allreduce operation completes and returns all locally
149-
* reduced results, ordered by local insertion order.
141+
* Blocks until the allreduce operation completes and returns the
142+
* globally reduced result.
150143
*
151144
* @param timeout Optional maximum duration to wait. Negative values mean
152145
* no timeout.
153146
*
154-
* @return A vector containing reduced packed data, one entry per local
155-
* insertion on this rank.
147+
* @return The reduced packed data.
156148
* @throws std::runtime_error If the timeout is reached.
157149
*/
158-
[[nodiscard]] std::vector<PackedData> wait_and_extract(
150+
[[nodiscard]] PackedData wait_and_extract(
159151
std::chrono::milliseconds timeout = std::chrono::milliseconds{-1}
160152
);
161153

@@ -172,16 +164,16 @@ class AllReduce {
172164
[[nodiscard]] bool is_ready() const noexcept;
173165

174166
private:
175-
/// @brief Perform the reduction across all ranks for all gathered contributions.
176-
[[nodiscard]] std::vector<PackedData> reduce_all(std::vector<PackedData>&& gathered);
167+
/// @brief Perform the reduction across all ranks for the gathered contributions.
168+
[[nodiscard]] PackedData reduce_all(std::vector<PackedData>&& gathered);
177169

178170
ReduceKernel reduce_kernel_; ///< Type-erased reduction kernel
179171
std::function<void(void)> finished_callback_; ///< Optional finished callback
180172

181173
Rank nranks_; ///< Number of ranks in the communicator
182174
AllGather gatherer_; ///< Underlying allgather primitive
183175

184-
std::atomic<std::uint32_t> nlocal_insertions_{0}; ///< Number of local inserts
176+
bool inserted_{false}; ///< Whether insert has been called
185177
};
186178

187179
namespace detail {

cpp/src/coll/allreduce.cpp

Lines changed: 19 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ AllReduce::AllReduce(
2626
)
2727
: reduce_kernel_{std::move(reduce_kernel)},
2828
finished_callback_{std::move(finished_callback)},
29+
nranks_{comm->nranks()},
2930
gatherer_{
3031
std::move(comm), std::move(progress_thread), op_id, br, std::move(statistics)
3132
} {
@@ -37,9 +38,14 @@ AllReduce::AllReduce(
3738

3839
AllReduce::~AllReduce() = default;
3940

40-
void AllReduce::insert(std::uint64_t sequence_number, PackedData&& packed_data) {
41-
nlocal_insertions_.fetch_add(1, std::memory_order_relaxed);
42-
gatherer_.insert(sequence_number, std::move(packed_data));
41+
void AllReduce::insert(PackedData&& packed_data) {
42+
RAPIDSMPF_EXPECTS(
43+
!inserted_,
44+
"AllReduce::insert can only be called once per instance",
45+
std::runtime_error
46+
);
47+
inserted_ = true;
48+
gatherer_.insert(0, std::move(packed_data));
4349
}
4450

4551
void AllReduce::insert_finished() {
@@ -50,7 +56,7 @@ bool AllReduce::finished() const noexcept {
5056
return gatherer_.finished();
5157
}
5258

53-
std::vector<PackedData> AllReduce::wait_and_extract(std::chrono::milliseconds timeout) {
59+
PackedData AllReduce::wait_and_extract(std::chrono::milliseconds timeout) {
5460
// Block until the underlying allgather completes, then perform the reduction locally
5561
// (exactly once).
5662
auto gathered =
@@ -62,48 +68,24 @@ bool AllReduce::is_ready() const noexcept {
6268
return gatherer_.finished();
6369
}
6470

65-
std::vector<PackedData> AllReduce::reduce_all(std::vector<PackedData>&& gathered) {
71+
PackedData AllReduce::reduce_all(std::vector<PackedData>&& gathered) {
6672
auto const total = gathered.size();
6773

68-
if (total == 0) {
69-
return {};
70-
}
71-
7274
RAPIDSMPF_EXPECTS(
73-
total % nranks_ == 0,
74-
"AllReduce expects each rank to contribute the same number of messages",
75+
total == static_cast<std::size_t>(nranks_),
76+
"AllReduce expects exactly one contribution from each rank",
7577
std::runtime_error
7678
);
7779

78-
auto const n_local =
79-
static_cast<std::size_t>(nlocal_insertions_.load(std::memory_order_acquire));
80-
auto const n_per_rank = total / nranks_;
81-
82-
// We allow non-uniform insertion counts across ranks but require that the local
83-
// insertion count matches the per-rank contribution implied by the gather.
84-
RAPIDSMPF_EXPECTS(
85-
n_local == 0 || n_local == n_per_rank,
86-
"AllReduce local insertion count does not match gathered contributions per rank",
87-
std::runtime_error
88-
);
80+
// Start with rank 0's contribution as the accumulator
81+
auto accum = std::move(gathered[0]);
8982

90-
std::vector<PackedData> results;
91-
results.reserve(n_per_rank);
92-
93-
// Conceptually, the k-th insertion on each rank participates in a single
94-
// reduction. With ordered allgather results, entries are laid out as:
95-
// [rank0:0..n_per_rank-1][rank1:0..n_per_rank-1]...[rankP-1:0..n_per_rank-1]
96-
for (std::size_t k = 0; k < n_per_rank; ++k) {
97-
// Start from rank 0's contribution for this logical insertion.
98-
auto accum = std::move(gathered[k]);
99-
for (std::size_t r = 1; r < nranks_; ++r) {
100-
auto idx = r * n_per_rank + k;
101-
reduce_kernel_(accum, std::move(gathered[idx]));
102-
}
103-
results.emplace_back(std::move(accum));
83+
// Reduce contributions from all other ranks into the accumulator
84+
for (std::size_t r = 1; r < static_cast<std::size_t>(nranks_); ++r) {
85+
reduce_kernel_(accum, std::move(gathered[r]));
10486
}
10587

106-
return results;
88+
return accum;
10789
}
10890

10991
} // namespace rapidsmpf::coll

0 commit comments

Comments
 (0)