@@ -26,6 +26,7 @@ AllReduce::AllReduce(
 )
     : reduce_kernel_{std::move(reduce_kernel)},
       finished_callback_{std::move(finished_callback)},
+      nranks_{comm->nranks()},
       gatherer_{
           std::move(comm), std::move(progress_thread), op_id, br, std::move(statistics)
       } {
@@ -37,9 +38,14 @@ AllReduce::AllReduce(
 
 AllReduce::~AllReduce() = default;
 
-void AllReduce::insert(std::uint64_t sequence_number, PackedData&& packed_data) {
-    nlocal_insertions_.fetch_add(1, std::memory_order_relaxed);
-    gatherer_.insert(sequence_number, std::move(packed_data));
+void AllReduce::insert(PackedData&& packed_data) {
+    RAPIDSMPF_EXPECTS(
+        !inserted_,
+        "AllReduce::insert can only be called once per instance",
+        std::runtime_error
+    );
+    inserted_ = true;
+    gatherer_.insert(0, std::move(packed_data));
 }
 
 void AllReduce::insert_finished() {
@@ -50,7 +56,7 @@ bool AllReduce::finished() const noexcept {
     return gatherer_.finished();
 }
 
-std::vector<PackedData> AllReduce::wait_and_extract(std::chrono::milliseconds timeout) {
+PackedData AllReduce::wait_and_extract(std::chrono::milliseconds timeout) {
     // Block until the underlying allgather completes, then perform the reduction locally
     // (exactly once).
     auto gathered =
@@ -62,48 +68,24 @@ bool AllReduce::is_ready() const noexcept {
     return gatherer_.finished();
 }
 
-std::vector<PackedData> AllReduce::reduce_all(std::vector<PackedData>&& gathered) {
+PackedData AllReduce::reduce_all(std::vector<PackedData>&& gathered) {
     auto const total = gathered.size();
 
-    if (total == 0) {
-        return {};
-    }
-
     RAPIDSMPF_EXPECTS(
-        total % nranks_ == 0,
-        "AllReduce expects each rank to contribute the same number of messages",
+        total == static_cast<std::size_t>(nranks_),
+        "AllReduce expects exactly one contribution from each rank",
         std::runtime_error
     );
 
-    auto const n_local =
-        static_cast<std::size_t>(nlocal_insertions_.load(std::memory_order_acquire));
-    auto const n_per_rank = total / nranks_;
-
-    // We allow non-uniform insertion counts across ranks but require that the local
-    // insertion count matches the per-rank contribution implied by the gather.
-    RAPIDSMPF_EXPECTS(
-        n_local == 0 || n_local == n_per_rank,
-        "AllReduce local insertion count does not match gathered contributions per rank",
-        std::runtime_error
-    );
+    // Start with rank 0's contribution as the accumulator.
+    auto accum = std::move(gathered[0]);
 
-    std::vector<PackedData> results;
-    results.reserve(n_per_rank);
-
-    // Conceptually, the k-th insertion on each rank participates in a single
-    // reduction. With ordered allgather results, entries are laid out as:
-    // [rank0:0..n_per_rank-1][rank1:0..n_per_rank-1]...[rankP-1:0..n_per_rank-1]
-    for (std::size_t k = 0; k < n_per_rank; ++k) {
-        // Start from rank 0's contribution for this logical insertion.
-        auto accum = std::move(gathered[k]);
-        for (std::size_t r = 1; r < nranks_; ++r) {
-            auto idx = r * n_per_rank + k;
-            reduce_kernel_(accum, std::move(gathered[idx]));
-        }
-        results.emplace_back(std::move(accum));
+    // Reduce contributions from all other ranks into the accumulator.
+    for (std::size_t r = 1; r < static_cast<std::size_t>(nranks_); ++r) {
+        reduce_kernel_(accum, std::move(gathered[r]));
     }
 
-    return results;
+    return accum;
 }
 
 } // namespace rapidsmpf::coll
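
For reviewers, a minimal usage sketch of the single-shot API after this change. It assumes an already-constructed `AllReduce allreduce` (construction arguments are not shown in this diff, so they are omitted); `make_local_input()` is a hypothetical stand-in for whatever produces the rank-local `PackedData`, and the 1000 ms timeout is illustrative:

    // Each rank contributes exactly one PackedData; a second insert() now throws.
    allreduce.insert(make_local_input());  // make_local_input() is hypothetical
    allreduce.insert_finished();

    // Blocks until the underlying allgather completes, then applies the
    // reduce kernel (signature: void(PackedData& accum, PackedData&& other))
    // over the nranks contributions and returns the single reduced result.
    PackedData result = allreduce.wait_and_extract(std::chrono::milliseconds{1000});

Since the reduction is performed identically on every rank after the allgather, all ranks end up with the same reduced value without a separate broadcast step.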