@@ -24,13 +24,11 @@ AllReduce::AllReduce(
     ReduceKernel reduce_kernel,
     std::function<void(void)> finished_callback
 )
-    : comm_{std::move(comm)},
-      progress_thread_{std::move(progress_thread)},
-      br_{br},
-      statistics_{std::move(statistics)},
-      reduce_kernel_{std::move(reduce_kernel)},
+    : reduce_kernel_{std::move(reduce_kernel)},
       finished_callback_{std::move(finished_callback)},
-      gatherer_{comm_, progress_thread_, op_id, br_, statistics_} {
+      gatherer_{
+          std::move(comm), std::move(progress_thread), op_id, br, std::move(statistics)
+      } {
     RAPIDSMPF_EXPECTS(
         static_cast<bool>(reduce_kernel_),
         "AllReduce requires a valid ReduceKernel at construction time"
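This hunk stops storing comm_, progress_thread_, br_, and statistics_ as AllReduce members and instead moves the constructor arguments straight into the embedded gatherer_; the later hunks then rely on a cached nranks_ where reduce_all previously called comm_->nranks(). How nranks_ is populated is not shown in this diff, so the following is only a minimal sketch of the forwarding pattern, with placeholder types (Communicator, Gatherer, Reducer) standing in for the real rapidsmpf ones and the rank count assumed to be read from comm before the move:

#include <cstddef>
#include <memory>
#include <utility>

// Placeholder stand-ins for the rapidsmpf types; only the forwarding shape matters.
struct Communicator {
    int nranks() const { return 2; }
};

struct Gatherer {
    explicit Gatherer(std::shared_ptr<Communicator> comm) : comm_{std::move(comm)} {}
    std::shared_ptr<Communicator> comm_;
};

class Reducer {
  public:
    // The outer class keeps no copy of `comm`: the argument is moved exactly once,
    // directly into the member subobject, and only a plain rank count is cached.
    explicit Reducer(std::shared_ptr<Communicator> comm)
        : nranks_{static_cast<std::size_t>(comm->nranks())},
          gatherer_{std::move(comm)} {}

    std::size_t nranks() const { return nranks_; }

  private:
    std::size_t nranks_;  // read from `comm` before it is moved into gatherer_
    Gatherer gatherer_;   // members are initialized in declaration order
};

int main() {
    Reducer r{std::make_shared<Communicator>()};
    return r.nranks() == 2 ? 0 : 1;
}

One point worth keeping in mind with this style: members are initialized in declaration order, so anything that reads `comm` (such as a cached rank count) must be declared before the member that consumes the moved-from argument.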
@@ -55,42 +53,31 @@ bool AllReduce::finished() const noexcept {
 std::vector<PackedData> AllReduce::wait_and_extract(std::chrono::milliseconds timeout) {
     // Block until the underlying allgather completes, then perform the reduction locally
     // (exactly once).
-    if (!reduced_computed_.load(std::memory_order_acquire)) {
-        auto gathered =
-            gatherer_.wait_and_extract(AllGather::Ordered::YES, std::move(timeout));
-        reduced_results_ = reduce_all(std::move(gathered));
-        reduced_computed_.store(true, std::memory_order_release);
-        if (finished_callback_) {
-            finished_callback_();
-        }
-    }
-    return std::move(reduced_results_);
+    auto gathered =
+        gatherer_.wait_and_extract(AllGather::Ordered::YES, std::move(timeout));
+    return reduce_all(std::move(gathered));
 }
 
 bool AllReduce::is_ready() const noexcept {
-    return reduced_computed_.load(std::memory_order_acquire) || gatherer_.finished();
+    return gatherer_.finished();
 }
 
 std::vector<PackedData> AllReduce::reduce_all(std::vector<PackedData>&& gathered) {
-    auto const nranks = static_cast<std::size_t>(comm_->nranks());
     auto const total = gathered.size();
 
     if (total == 0) {
         return {};
     }
 
     RAPIDSMPF_EXPECTS(
-        nranks > 0, "AllReduce requires a positive number of ranks", std::runtime_error
-    );
-    RAPIDSMPF_EXPECTS(
-        total % nranks == 0,
+        total % nranks_ == 0,
         "AllReduce expects each rank to contribute the same number of messages",
         std::runtime_error
     );
 
     auto const n_local =
         static_cast<std::size_t>(nlocal_insertions_.load(std::memory_order_acquire));
-    auto const n_per_rank = total / nranks;
+    auto const n_per_rank = total / nranks_;
 
     // We allow non-uniform insertion counts across ranks but require that the local
     // insertion count matches the per-rank contribution implied by the gather.
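The validation that survives this hunk boils down to: the flattened gather must split evenly across ranks, and (per the comment above, with the actual check outside the quoted lines) the local insertion counter must match that per-rank share. A small self-contained sketch of that arithmetic, with plain exceptions standing in for RAPIDSMPF_EXPECTS and illustrative names rather than the rapidsmpf API:

#include <cstddef>
#include <stdexcept>

// Validate the shape of an allgather result before reducing it: `total` messages
// arrived from `nranks` ranks, so every rank must have contributed total / nranks
// of them, and this rank's own insertion count has to agree with that figure.
std::size_t per_rank_contribution(
    std::size_t total, std::size_t nranks, std::size_t n_local
) {
    if (total % nranks != 0) {
        throw std::runtime_error(
            "AllReduce expects each rank to contribute the same number of messages"
        );
    }
    auto const n_per_rank = total / nranks;
    if (n_local != n_per_rank) {
        throw std::runtime_error(
            "local insertion count does not match the per-rank contribution"
        );
    }
    return n_per_rank;
}

int main() {
    // 6 gathered messages from 3 ranks with 2 local insertions: consistent, yields 2.
    return per_rank_contribution(6, 3, 2) == 2 ? 0 : 1;
}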
@@ -109,7 +96,7 @@ std::vector<PackedData> AllReduce::reduce_all(std::vector<PackedData>&& gathered
     for (std::size_t k = 0; k < n_per_rank; ++k) {
         // Start from rank 0's contribution for this logical insertion.
         auto accum = std::move(gathered[k]);
-        for (std::size_t r = 1; r < nranks; ++r) {
+        for (std::size_t r = 1; r < nranks_; ++r) {
             auto idx = r * n_per_rank + k;
             reduce_kernel_(accum, std::move(gathered[idx]));
         }
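For reference, the indexing in this loop assumes the gathered vector is rank-major: rank r's k-th insertion sits at r * n_per_rank + k, rank 0's entry seeds the accumulator, and the kernel folds in the remaining ranks. A standalone sketch of the same fold, with an int payload and a sum kernel standing in for PackedData and the real ReduceKernel:

#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

using Payload = int;
using Kernel = std::function<void(Payload& accum, Payload&& next)>;

// Fold a rank-major gather result down to n_per_rank reduced values, mirroring the
// loop in AllReduce::reduce_all: gathered[r * n_per_rank + k] holds rank r's k-th
// insertion, and rank 0's contribution seeds the accumulator for each k.
std::vector<Payload> reduce_all_sketch(
    std::vector<Payload>&& gathered, std::size_t nranks, Kernel const& kernel
) {
    std::size_t const n_per_rank = gathered.size() / nranks;
    std::vector<Payload> out;
    out.reserve(n_per_rank);
    for (std::size_t k = 0; k < n_per_rank; ++k) {
        Payload accum = std::move(gathered[k]);  // rank 0's k-th contribution
        for (std::size_t r = 1; r < nranks; ++r) {
            kernel(accum, std::move(gathered[r * n_per_rank + k]));
        }
        out.push_back(std::move(accum));
    }
    return out;
}

int main() {
    // Two ranks, two insertions each, summed element-wise: expect {1+3, 2+4}.
    std::vector<Payload> gathered{1, 2, 3, 4};
    auto out = reduce_all_sketch(
        std::move(gathered), 2, [](Payload& a, Payload&& b) { a += b; }
    );
    return out == std::vector<Payload>{4, 6} ? 0 : 1;
}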