Skip to content

Commit 91f0772

Browse files
wdeconincksimondsmart
authored andcommitted
Add support in eckit::mpi::Comm for std::vector arguments with non-standard allocator
1 parent c92b034 commit 91f0772

File tree

3 files changed

+171
-126
lines changed

3 files changed

+171
-126
lines changed

src/eckit/mpi/Buffer.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,16 @@ namespace eckit::mpi {
2020

2121
/// Buffer handles colleciton of vector pieces into a larger vector
2222

23-
template <typename DATA_TYPE>
23+
template <typename DATA_TYPE, typename Allocator = std::allocator<DATA_TYPE>>
2424
struct Buffer {
2525
typedef DATA_TYPE value_type;
26-
typedef typename std::vector<DATA_TYPE>::iterator iterator;
26+
typedef typename std::vector<DATA_TYPE, Allocator>::iterator iterator;
2727

2828
int cnt;
2929

3030
std::vector<int> counts;
3131
std::vector<int> displs;
32-
std::vector<DATA_TYPE> buffer;
32+
std::vector<DATA_TYPE, Allocator> buffer;
3333

3434
Buffer(size_t size) {
3535
counts.resize(size);

src/eckit/mpi/Comm.h

Lines changed: 70 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <iterator>
1616
#include <string>
1717
#include <string_view>
18+
#include <type_traits>
1819
#include <vector>
1920

2021
#include "eckit/filesystem/PathName.h"
@@ -71,6 +72,12 @@ namespace detail {
7172
/// Assertions for eckit::mpi code
7273
/// Don't use directly in client code
7374
void Assert(int code, const char* msg, const char* file, int line, const char* func);
75+
76+
template <typename Type>
77+
struct is_std_vector : std::false_type {};
78+
79+
template <typename T, typename A>
80+
struct is_std_vector<std::vector<T, A> > : std::true_type {};
7481
} // namespace detail
7582

7683
//----------------------------------------------------------------------------------------------------------------------
@@ -164,8 +171,8 @@ class Comm : private eckit::NonCopyable {
164171
template <typename T>
165172
void broadcast(T buffer[], size_t count, size_t root) const;
166173

167-
template <typename T>
168-
void broadcast(typename std::vector<T>& v, size_t root) const;
174+
template <typename T, typename A>
175+
void broadcast(typename std::vector<T, A>& v, size_t root) const;
169176

170177
template <class Iter>
171178
void broadcast(Iter first, Iter last, size_t root) const;
@@ -177,11 +184,11 @@ class Comm : private eckit::NonCopyable {
177184
template <class CIter, class Iter>
178185
void gather(CIter first, CIter last, Iter rfirst, Iter rlast, size_t root) const;
179186

180-
template <typename T>
181-
void gather(const T send, std::vector<T>& recv, size_t root) const;
187+
template <typename T, typename A>
188+
void gather(const T send, std::vector<T, A>& recv, size_t root) const;
182189

183-
template <typename T>
184-
void gather(const std::vector<T>& send, std::vector<T>& recv, size_t root) const;
190+
template <typename T, typename A1, typename A2>
191+
void gather(const std::vector<T, A1>& send, std::vector<T, A2>& recv, size_t root) const;
185192

186193
///
187194
/// Gather methods to one root, variable data sizes per rank
@@ -195,23 +202,23 @@ class Comm : private eckit::NonCopyable {
195202
void gatherv(CIter first, CIter last, Iter rfirst, Iter rlast, const int recvcounts[], const int displs[],
196203
size_t root) const;
197204

198-
template <class CIter, class Iter>
199-
void gatherv(CIter first, CIter last, Iter rfirst, Iter rlast, const std::vector<int>& recvcounts,
200-
const std::vector<int>& displs, size_t root) const;
205+
template <class CIter, class Iter, typename A1, typename A2>
206+
void gatherv(CIter first, CIter last, Iter rfirst, Iter rlast, const std::vector<int, A1>& recvcounts,
207+
const std::vector<int, A2>& displs, size_t root) const;
201208

202-
template <typename T>
203-
void gatherv(const std::vector<T>& send, std::vector<T>& recv, const std::vector<int>& recvcounts,
204-
const std::vector<int>& displs, size_t root) const;
209+
template <typename T, typename A1, typename A2, typename A3, typename A4>
210+
void gatherv(const std::vector<T, A1>& send, std::vector<T, A2>& recv, const std::vector<int, A3>& recvcounts,
211+
const std::vector<int, A4>& displs, size_t root) const;
205212

206213
///
207214
/// Scatter methods from one root
208215
///
209216

210-
template <typename T>
211-
void scatter(const std::vector<T>& send, T& recv, size_t root) const;
217+
template <typename T, typename A>
218+
void scatter(const std::vector<T, A>& send, T& recv, size_t root) const;
212219

213-
template <typename T>
214-
void scatter(const std::vector<T>& send, std::vector<T>& recv, size_t root) const;
220+
template <typename T, typename A1, typename A2>
221+
void scatter(const std::vector<T, A1>& send, std::vector<T, A2>& recv, size_t root) const;
215222

216223
///
217224
/// Scatter methods from one root, variable data sizes per rank, pointer to data (also covers
@@ -226,8 +233,8 @@ class Comm : private eckit::NonCopyable {
226233
void scatterv(CIter first, CIter last, const int sendcounts[], const int displs[], Iter rfirst, Iter rlast,
227234
size_t root) const;
228235

229-
template <class CIter, class Iter>
230-
void scatterv(CIter first, CIter last, const std::vector<int>& sendcounts, const std::vector<int>& displs,
236+
template <class CIter, class Iter, typename A1, typename A2>
237+
void scatterv(CIter first, CIter last, const std::vector<int, A1>& sendcounts, const std::vector<int, A2>& displs,
231238
Iter rfirst, Iter rlast, size_t root) const;
232239

233240
///
@@ -240,8 +247,8 @@ class Comm : private eckit::NonCopyable {
240247
template <typename T>
241248
void reduce(const T* send, T* recv, size_t count, Operation::Code op, size_t root) const;
242249

243-
template <typename T>
244-
void reduce(const std::vector<T>& send, std::vector<T>& recv, Operation::Code op, size_t root) const;
250+
template <typename T, typename A1, typename A2>
251+
void reduce(const std::vector<T, A1>& send, std::vector<T, A2>& recv, Operation::Code op, size_t root) const;
245252

246253
///
247254
/// Reduce operations, in place buffer
@@ -260,14 +267,14 @@ class Comm : private eckit::NonCopyable {
260267
/// All reduce operations, separate buffers
261268
///
262269

263-
template <typename T>
270+
template <typename T, std::enable_if_t<!eckit::mpi::detail::is_std_vector<T>::value>* = nullptr>
264271
void allReduce(const T send, T& recv, Operation::Code op) const;
265272

266273
template <typename T>
267274
void allReduce(const T* send, T* recv, size_t count, Operation::Code op) const;
268275

269-
template <typename T>
270-
void allReduce(const std::vector<T>& send, std::vector<T>& recv, Operation::Code op) const;
276+
template <typename T, typename A1, typename A2>
277+
void allReduce(const std::vector<T, A1>& send, std::vector<T, A2>& recv, Operation::Code op) const;
271278

272279
///
273280
/// All reduce operations, in place buffer
@@ -296,15 +303,15 @@ class Comm : private eckit::NonCopyable {
296303
template <typename CIter, typename Iter>
297304
void allGatherv(CIter first, CIter last, Iter recvbuf, const int recvcounts[], const int displs[]) const;
298305

299-
template <typename T, typename CIter>
300-
void allGatherv(CIter first, CIter last, mpi::Buffer<T>& recv) const;
306+
template <typename T, typename A, typename CIter>
307+
void allGatherv(CIter first, CIter last, mpi::Buffer<T, A>& recv) const;
301308

302309
///
303310
/// All to all methods, fixed data size
304311
///
305312

306-
template <typename T>
307-
void allToAll(const std::vector<T>& send, std::vector<T>& recv) const;
313+
template <typename T, typename A1, typename A2, std::enable_if_t<!detail::is_std_vector<T>::value>* = nullptr>
314+
void allToAll(const std::vector<T, A1>& send, std::vector<T, A2>& recv) const;
308315

309316
///
310317
/// All to All, variable data size
@@ -379,8 +386,9 @@ class Comm : private eckit::NonCopyable {
379386
/// All to all of vector< vector<> >
380387
///
381388

382-
template <typename T>
383-
void allToAll(const std::vector<std::vector<T> >& sendvec, std::vector<std::vector<T> >& recvvec) const;
389+
template <typename T, typename A1, typename A2, typename A3, typename A4>
390+
void allToAll(const std::vector<std::vector<T, A1>, A3>& sendvec,
391+
std::vector<std::vector<T, A2>, A4>& recvvec) const;
384392

385393
///
386394
/// Read file on one rank, and broadcast
@@ -555,8 +563,8 @@ void eckit::mpi::Comm::broadcast(T buffer[], size_t count, size_t root) const {
555563
broadcast(buffer, count, Data::Type<T>::code(), root);
556564
}
557565

558-
template <typename T>
559-
void eckit::mpi::Comm::broadcast(typename std::vector<T>& v, size_t root) const {
566+
template <typename T, typename A>
567+
void eckit::mpi::Comm::broadcast(typename std::vector<T, A>& v, size_t root) const {
560568
size_t commsize = size();
561569
ECKIT_MPI_ASSERT(root < commsize);
562570

@@ -603,8 +611,8 @@ void eckit::mpi::Comm::gather(CIter first, CIter last, Iter rfirst, Iter rlast,
603611
gather(sendbuf, sendcount, recvbuf, recvcount, type, root);
604612
}
605613

606-
template <typename T>
607-
void eckit::mpi::Comm::gather(const T send, std::vector<T>& recv, size_t root) const {
614+
template <typename T, typename A>
615+
void eckit::mpi::Comm::gather(const T send, std::vector<T, A>& recv, size_t root) const {
608616
size_t commsize = size();
609617
ECKIT_MPI_ASSERT(commsize > 0);
610618
ECKIT_MPI_ASSERT(root < commsize);
@@ -616,8 +624,8 @@ void eckit::mpi::Comm::gather(const T send, std::vector<T>& recv, size_t root) c
616624
gather(&send, sendcount, recv.data(), recvcount, Data::Type<T>::code(), root);
617625
}
618626

619-
template <typename T>
620-
void eckit::mpi::Comm::gather(const std::vector<T>& send, std::vector<T>& recv, size_t root) const {
627+
template <typename T, typename A1, typename A2>
628+
void eckit::mpi::Comm::gather(const std::vector<T, A1>& send, std::vector<T, A2>& recv, size_t root) const {
621629
size_t commsize = size();
622630
ECKIT_MPI_ASSERT(commsize > 0);
623631
ECKIT_MPI_ASSERT(root < commsize);
@@ -657,9 +665,9 @@ void eckit::mpi::Comm::gatherv(CIter first, CIter last, Iter rfirst, Iter rlast,
657665
gatherv(sendbuf, sendcount, recvbuf, recvcounts, displs, type, root);
658666
}
659667

660-
template <class CIter, class Iter>
661-
void eckit::mpi::Comm::gatherv(CIter first, CIter last, Iter rfirst, Iter rlast, const std::vector<int>& recvcounts,
662-
const std::vector<int>& displs, size_t root) const {
668+
template <class CIter, class Iter, typename A1, typename A2>
669+
void eckit::mpi::Comm::gatherv(CIter first, CIter last, Iter rfirst, Iter rlast, const std::vector<int, A1>& recvcounts,
670+
const std::vector<int, A2>& displs, size_t root) const {
663671
size_t commsize = size();
664672
ECKIT_MPI_ASSERT(root < commsize);
665673
ECKIT_MPI_ASSERT(recvcounts.size() == commsize);
@@ -668,9 +676,10 @@ void eckit::mpi::Comm::gatherv(CIter first, CIter last, Iter rfirst, Iter rlast,
668676
gatherv(first, last, rfirst, rlast, recvcounts.data(), displs.data(), root);
669677
}
670678

671-
template <typename T>
672-
void eckit::mpi::Comm::gatherv(const std::vector<T>& send, std::vector<T>& recv, const std::vector<int>& recvcounts,
673-
const std::vector<int>& displs, size_t root) const {
679+
template <typename T, typename A1, typename A2, typename A3, typename A4>
680+
void eckit::mpi::Comm::gatherv(const std::vector<T, A1>& send, std::vector<T, A2>& recv,
681+
const std::vector<int, A3>& recvcounts, const std::vector<int, A4>& displs,
682+
size_t root) const {
674683
size_t commsize = size();
675684
ECKIT_MPI_ASSERT(root < commsize);
676685
if (rank() == root) {
@@ -686,8 +695,8 @@ void eckit::mpi::Comm::gatherv(const std::vector<T>& send, std::vector<T>& recv,
686695
/// Scatter methods from one root
687696
///
688697

689-
template <typename T>
690-
void eckit::mpi::Comm::scatter(const std::vector<T>& send, T& recv, size_t root) const {
698+
template <typename T, typename A>
699+
void eckit::mpi::Comm::scatter(const std::vector<T, A>& send, T& recv, size_t root) const {
691700
size_t commsize = size();
692701
ECKIT_MPI_ASSERT(commsize > 0);
693702
ECKIT_MPI_ASSERT(root < commsize);
@@ -699,8 +708,8 @@ void eckit::mpi::Comm::scatter(const std::vector<T>& send, T& recv, size_t root)
699708
scatter(send.data(), sendcount, &recv, recvcount, Data::Type<T>::code(), root);
700709
}
701710

702-
template <typename T>
703-
void eckit::mpi::Comm::scatter(const std::vector<T>& send, std::vector<T>& recv, size_t root) const {
711+
template <typename T, typename A1, typename A2>
712+
void eckit::mpi::Comm::scatter(const std::vector<T, A1>& send, std::vector<T, A2>& recv, size_t root) const {
704713
size_t commsize = size();
705714
ECKIT_MPI_ASSERT(commsize > 0);
706715
ECKIT_MPI_ASSERT(root < commsize);
@@ -745,9 +754,9 @@ void eckit::mpi::Comm::scatterv(CIter first, CIter last, const int sendcounts[],
745754
scatterv(sendbuf, sendcounts, displs, recvbuf, recvcounts, type, root);
746755
}
747756

748-
template <class CIter, class Iter>
749-
void eckit::mpi::Comm::scatterv(CIter first, CIter last, const std::vector<int>& sendcounts,
750-
const std::vector<int>& displs, Iter rfirst, Iter rlast, size_t root) const {
757+
template <class CIter, class Iter, typename A1, typename A2>
758+
void eckit::mpi::Comm::scatterv(CIter first, CIter last, const std::vector<int, A1>& sendcounts,
759+
const std::vector<int, A2>& displs, Iter rfirst, Iter rlast, size_t root) const {
751760
size_t commsize = size();
752761
ECKIT_MPI_ASSERT(root < commsize);
753762
ECKIT_MPI_ASSERT(sendcounts.size() == commsize);
@@ -770,8 +779,9 @@ void eckit::mpi::Comm::reduce(const T* send, T* recv, size_t count, Operation::C
770779
reduce(send, recv, count, Data::Type<T>::code(), op, root);
771780
}
772781

773-
template <typename T>
774-
void eckit::mpi::Comm::reduce(const std::vector<T>& send, std::vector<T>& recv, Operation::Code op, size_t root) const {
782+
template <typename T, typename A1, typename A2>
783+
void eckit::mpi::Comm::reduce(const std::vector<T, A1>& send, std::vector<T, A2>& recv, Operation::Code op,
784+
size_t root) const {
775785
ECKIT_MPI_ASSERT(send.size() == recv.size());
776786
reduce(send.data(), recv.data(), send.size(), Data::Type<T>::code(), op, root);
777787
}
@@ -801,7 +811,7 @@ void eckit::mpi::Comm::reduceInPlace(Iter first, Iter last, Operation::Code op,
801811
/// All reduce operations, separate buffers
802812
///
803813

804-
template <typename T>
814+
template <typename T, std::enable_if_t<!eckit::mpi::detail::is_std_vector<T>::value>*>
805815
void eckit::mpi::Comm::allReduce(const T send, T& recv, Operation::Code op) const {
806816
allReduce(&send, &recv, 1, Data::Type<T>::code(), op);
807817
}
@@ -811,8 +821,8 @@ void eckit::mpi::Comm::allReduce(const T* send, T* recv, size_t count, Operation
811821
allReduce(send, recv, count, Data::Type<T>::code(), op);
812822
}
813823

814-
template <typename T>
815-
void eckit::mpi::Comm::allReduce(const std::vector<T>& send, std::vector<T>& recv, Operation::Code op) const {
824+
template <typename T, typename A1, typename A2>
825+
void eckit::mpi::Comm::allReduce(const std::vector<T, A1>& send, std::vector<T, A2>& recv, Operation::Code op) const {
816826
ECKIT_MPI_ASSERT(send.size() == recv.size());
817827
allReduce(send.data(), recv.data(), send.size(), Data::Type<T>::code(), op);
818828
}
@@ -883,8 +893,8 @@ void eckit::mpi::Comm::allGatherv(CIter first, CIter last, Iter rfirst, const in
883893
/// All to all methods, fixed data size
884894
///
885895

886-
template <typename T>
887-
void eckit::mpi::Comm::allToAll(const std::vector<T>& send, std::vector<T>& recv) const {
896+
template <typename T, typename A1, typename A2, std::enable_if_t<!eckit::mpi::detail::is_std_vector<T>::value>*>
897+
void eckit::mpi::Comm::allToAll(const std::vector<T, A1>& send, std::vector<T, A2>& recv) const {
888898
size_t commsize = size();
889899
ECKIT_MPI_ASSERT(commsize > 0);
890900
ECKIT_MPI_ASSERT(send.size() % commsize == 0);
@@ -988,8 +998,8 @@ eckit::mpi::Request eckit::mpi::Comm::iSend(const T& sendbuf, int dest, int tag)
988998
return iSend(&sendbuf, 1, Data::Type<T>::code(), dest, tag);
989999
}
9901000

991-
template <typename T, typename CIter>
992-
void eckit::mpi::Comm::allGatherv(CIter first, CIter last, mpi::Buffer<T>& recv) const {
1001+
template <typename T, typename A, typename CIter>
1002+
void eckit::mpi::Comm::allGatherv(CIter first, CIter last, mpi::Buffer<T, A>& recv) const {
9931003
int sendcnt = int(std::distance(first, last));
9941004

9951005
allGather(sendcnt, recv.counts.begin(), recv.counts.end());
@@ -1007,9 +1017,9 @@ void eckit::mpi::Comm::allGatherv(CIter first, CIter last, mpi::Buffer<T>& recv)
10071017
allGatherv(first, last, recv.buffer.data(), recv.counts.data(), recv.displs.data());
10081018
}
10091019

1010-
template <typename T>
1011-
void eckit::mpi::Comm::allToAll(const std::vector<std::vector<T> >& sendvec,
1012-
std::vector<std::vector<T> >& recvvec) const {
1020+
template <typename T, typename A1, typename A2, typename A3, typename A4>
1021+
void eckit::mpi::Comm::allToAll(const std::vector<std::vector<T, A1>, A3>& sendvec,
1022+
std::vector<std::vector<T, A2>, A4>& recvvec) const {
10131023
size_t commsize = size();
10141024
ECKIT_MPI_ASSERT(sendvec.size() == commsize);
10151025
ECKIT_MPI_ASSERT(recvvec.size() == commsize);

0 commit comments

Comments
 (0)