Skip to content

Commit d80c369

Browse files
Added inline ctors to LL and LL128 Primitives that match the Simple signature used in all_gather, and explicitly forward params to the proper ctor. This removes the warning and any ambiguity around the implicit overloading of p2pWork and stepSize_.
1 parent 9efc459 commit d80c369

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

src/device/all_gather.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ namespace {
7070
// FanSymmetric<1>, only the first element is ever accessed, so it's fine.
7171
// coverity[callee_ptr_arith:FALSE]
7272
Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0, isNetOffload> prims
73-
(tid, workNthreads, &ring->prev, &ring->next, inputBuf, outputBuf, work->redOpArg, 0, work->connIndex, work->connIndex, work, NULL, isNetOffload ? NCCL_MAX_NET_SIZE : 0);
73+
(tid, workNthreads, &ring->prev, &ring->next, inputBuf, outputBuf, work->redOpArg, 0, work->connIndex, work->connIndex, work, nullptr, isNetOffload ? NCCL_MAX_NET_SIZE : 0);
7474

7575
#if defined(ENABLE_NPKIT)
7676
if (tid == 0) {

src/device/prims_ll.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,14 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, Metadata, Pi
680680
setDataPtrs(inputBuf, outputBuf, e != nullptr ? e->acc : nullptr);
681681
}
682682

683+
__forceinline__ __device__ Primitives(
684+
int tid, int nthreads, int const *recvPeers, int const *sendPeers,
685+
void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint8_t group,
686+
uint8_t connIndexRecv, uint8_t connIndexSend, struct ncclDevWorkColl* collWork,
687+
struct ncclDevWorkP2p* p2pWork, int stepSize_ = 0, int mode = primsModeDefault
688+
): Primitives(tid, nthreads, recvPeers, sendPeers, inputBuf, outputBuf, redOpArg, group,
689+
connIndexRecv, connIndexSend, collWork) {}
690+
683691
__device__ ~Primitives() {
684692
// Save steps for the next operation
685693
if (tid >= nthreads-WARP_SIZE && wid < fan.nrecv())

src/device/prims_ll128.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,14 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload, Metadata,
600600
setDataPtrs(inputBuf, outputBuf, e != nullptr ? e->acc : nullptr);
601601
}
602602

603+
__forceinline__ __device__ Primitives(
604+
int tid, int nthreads, int const *recvPeers, int const *sendPeers,
605+
void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint8_t group,
606+
uint8_t connIndexRecv, uint8_t connIndexSend, struct ncclDevWorkColl* collWork,
607+
struct ncclDevWorkP2p* p2pWork, int stepSize_ = 0, int mode = primsModeDefault
608+
): Primitives(tid, nthreads, recvPeers, sendPeers, inputBuf, outputBuf, redOpArg, group,
609+
connIndexRecv, connIndexSend, collWork) {}
610+
603611
__device__ ~Primitives() {
604612
// Save steps for the next operation
605613
if (tid >= nthreads-WARP_SIZE && wid < fan.nrecv())

0 commit comments

Comments
 (0)