-
Notifications
You must be signed in to change notification settings - Fork 192
[Device] WarpSpeed feature enablement #2073
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 46 commits
8568514
8ef18e4
bf97215
3cd8227
ff1d576
67b95c7
e8546e6
3545ce9
e6e0b25
385729d
c2f80c3
1915071
773077b
dc05efc
312c75b
e4087e4
b1266f3
6ecb5b7
a1f32bb
7fa1926
378d54c
e15b9e0
d9e8cb4
99de243
6394818
0cda162
64c4549
3190e86
a77f6c8
a13c782
a47eddb
f77efa0
78259af
0c0eaaf
50d914b
9fefa8b
d717584
4cdf269
fc53c54
9b695e3
c87510b
ad8d80e
fa7d972
dff8220
c3539b3
e0ad5f8
9f18a01
6b8cf5c
e6905e9
8b5a1f5
0410c4d
ed81f5d
30367a2
75e327d
83751fc
1a32798
168e017
bf76570
4581393
15677ff
8f5dbe0
d808362
fd2b87b
d6e8b79
fe78376
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -269,7 +269,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, Metadata, Pi | |
| i4.data2 = (val >> 32); | ||
| i4.flag2 = flag; | ||
| *((u64_gptr) dst->v) = *((u64_gptr) i4.v); | ||
| *((u64_gptr) dst->v+1) = *((u64_gptr) i4.v+1); | ||
| *((u64_gptr) dst->v+1) = *((u64_gptr) i4.v+1); | ||
| #if defined(__gfx950__) && ROCM_VERSION < 70002 | ||
| __builtin_amdgcn_fence(__ATOMIC_RELEASE, ""); // flush cache on gfx950 if ROCr fix for hipHostMallocUncached is not available (ROCm version < 7.0.2) | ||
| #endif | ||
|
|
@@ -507,7 +507,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, Metadata, Pi | |
| nelem -= eltPerTrip; | ||
| offset += nthreads; | ||
| } | ||
| #ifdef __gfx950__ | ||
| #ifdef __gfx950__ | ||
| if constexpr (isMsccl(Metadata) && DST){ | ||
| // Wait for pending vector loads and stores | ||
| __builtin_amdgcn_s_waitcnt((15 << 8) | (7 << 4)); // s_waitcnt vmcnt(0) | ||
|
|
@@ -652,9 +652,10 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, Metadata, Pi | |
| bool ipcReg = false, bool netReg = false, int stepSize_ = 0 | ||
| ): | ||
| redOp(redOpArg), | ||
| tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), group(group), threadsPerBlock(blockDim.x), | ||
| tid(tid), nthreads(nthreads), wid(threadIdx.x%WARP_SIZE), group(group), threadsPerBlock(blockDim.x), | ||
|
||
| stepLines(ncclShmem.comm.buffSizes[NCCL_PROTO_LL]/NCCL_STEPS/sizeof(ncclLLFifoLine)) { | ||
| auto *channel = &ncclShmem.channel; | ||
| int warp = threadIdx.x / WARP_SIZE; | ||
| auto *channel = &ncclShmem.warpChannel[warp]; | ||
| barriers = &ncclShmem.groups[group].barrier; | ||
| // If we are going to support oneshot collNet + LL, then we would need to add connector index here | ||
| int nrecv=0, nsend=0; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -81,18 +81,14 @@ class Primitives< | |
|
|
||
| // Don't use barrier 0 as it's used by the final sync | ||
| inline __device__ void barrier() { | ||
| if (nthreads == WARP_SIZE) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we may need something here to prevent code motion in the general case like an __atomic_signal_fence. That's effectively what __syncwarp() is doing for us here I believe since our warps can't diverge at the hardware level. Same with below.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As discussed, I went with __syncwarp(), as this evaluation is out of scope for the PR |
||
| __syncwarp(); | ||
| else | ||
| if(nthreads != WARP_SIZE) | ||
| #if defined(__gfx942__) || defined(__gfx950__) | ||
| barrier_generic(__threadfence_block(), nworkers, barrier_next, barriers); | ||
| #else | ||
| barrier_generic(__threadfence(), nworkers, barrier_next, barriers); | ||
| #endif | ||
| } | ||
| inline __device__ void subBarrier() { | ||
| if (nworkers == WARP_SIZE) __syncwarp(); | ||
| else | ||
| barrier(); | ||
| } | ||
|
|
||
|
|
@@ -490,14 +486,14 @@ class Primitives< | |
|
|
||
| public: | ||
| static inline __device__ void sendPeerNotify(int peer, int connIndex, int steps) { | ||
| ncclDevChannelPeer* peerPtr = ncclShmem.channel.peers[peer]; | ||
| ncclDevChannelPeer* peerPtr = ncclShmem.warpChannel[threadIdx.x/WARP_SIZE].peers[peer]; | ||
| peerPtr->send[connIndex].step += steps; | ||
| st_relaxed_sys_global(peerPtr->send[connIndex].tail, peerPtr->send[connIndex].step); | ||
| } | ||
|
|
||
| static inline __device__ void recvPeerNotify(int peer, int connIndex, int steps) { | ||
| int spins = 0; | ||
| ncclDevChannelPeer* peerPtr = ncclShmem.channel.peers[peer]; | ||
| ncclDevChannelPeer* peerPtr = ncclShmem.warpChannel[threadIdx.x/WARP_SIZE].peers[peer]; | ||
| peerPtr->recv[connIndex].step += steps; | ||
| st_relaxed_sys_global(peerPtr->recv[connIndex].head, peerPtr->recv[connIndex].step); | ||
| while (ld_volatile_global(peerPtr->recv[connIndex].tail) < peerPtr->recv[connIndex].step) { | ||
|
|
@@ -758,7 +754,7 @@ class Primitives< | |
| struct ncclDevWorkP2p* p2pWork = nullptr, int stepSize_ = 0, int mode = primsModeDefault | ||
| ): | ||
| tid(tid), tidInBlock(threadIdx.x), nthreads(nthreads), /*compiler warnings*/ | ||
| stepSize(stepSize_ == 0 ? ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T) : stepSize_), group(group), threadsPerBlock(blockDim.x){ | ||
| stepSize(stepSize_ == 0 ? ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T) : stepSize_), group(ncclShmem.warpComm? tidInBlock / WARP_SIZE : group), threadsPerBlock(blockDim.x){ | ||
|
|
||
| barriers = &ncclShmem.groups[group].barrier; | ||
| // PAT uses the same barrier for each group | ||
|
|
@@ -819,9 +815,9 @@ class Primitives< | |
| } | ||
|
|
||
| // coverity[overrun-call] => Coverity think prims.index can be greater than 1 | ||
| if (flags & (RoleWaitRecv|RolePostRecv)) loadRecvConn(ncclShmem.channel.peers[peer], connIndexRecv, collWork ? collWork->direct : 0, recvIpcReg, recvNetReg); | ||
| if (flags & (RoleWaitRecv|RolePostRecv)) loadRecvConn(ncclShmem.warpChannel[tidInBlock/WARP_SIZE].peers[peer], connIndexRecv, collWork ? collWork->direct : 0, recvIpcReg, recvNetReg); | ||
| // coverity[overrun-call] => Coverity think prims.index can be greater than 1 | ||
| if (flags & (RoleWaitSend|RolePostSend)) loadSendConn(ncclShmem.channel.peers[peer], connIndexSend, collWork ? collWork->direct : 0, sendIpcReg, sendNetReg); | ||
| if (flags & (RoleWaitSend|RolePostSend)) loadSendConn(ncclShmem.warpChannel[tidInBlock/WARP_SIZE].peers[peer], connIndexSend, collWork ? collWork->direct : 0, sendIpcReg, sendNetReg); | ||
|
|
||
| // if (barrierAny(flags & NetDeviceUnpack)) { | ||
| // flags |= AnyNetDeviceUnpack; | ||
|
|
@@ -849,7 +845,7 @@ class Primitives< | |
| // Load recv peer | ||
| int recvPeer = mode == primsModePatRs ? (rank - delta + nranks) % nranks : (rank + delta) % nranks; | ||
| struct ncclPatPeer* peer = ((struct ncclPatPeer*)recvPeers)+tid; | ||
| struct ncclConnInfo* conn = peer->conn = ncclShmem.channel.peers[recvPeer]->recv+connIndexRecv; | ||
| struct ncclConnInfo* conn = peer->conn = ncclShmem.warpChannel[tidInBlock/WARP_SIZE].peers[recvPeer]->recv+connIndexRecv; | ||
| peer->step = conn->step; | ||
| peer->buff = conn->buffs[NCCL_PROTO_SIMPLE]; | ||
| peer->stepCache = loadStepValue(peer->tailPtr = conn->tail); | ||
|
|
@@ -859,7 +855,7 @@ class Primitives< | |
| // Load send peer | ||
| int sendPeer = mode == primsModePatAg ? (rank - delta + nranks) % nranks : (rank + delta) % nranks; | ||
| peer = ((struct ncclPatPeer*)sendPeers)+tid; | ||
| conn = peer->conn = ncclShmem.channel.peers[sendPeer]->send+connIndexSend; | ||
| conn = peer->conn = ncclShmem.warpChannel[tidInBlock/WARP_SIZE].peers[sendPeer]->send+connIndexSend; | ||
| peer->step = conn->step; | ||
| peer->connFifo = conn->connFifo; | ||
| peer->buff = conn->buffs[NCCL_PROTO_SIMPLE]; | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.