Changes from 46 commits

Commits (65)
8568514
Add support for 256 channel count
mustafabar Nov 7, 2025
8ef18e4
Add LL with validation issues
mustafabar Nov 9, 2025
bf97215
Fix bug and add concept simple support
mustafabar Nov 10, 2025
3cd8227
Add cleanup
mustafabar Nov 10, 2025
ff1d576
Add minor edits
mustafabar Nov 10, 2025
67b95c7
Add slightly improved version
mustafabar Nov 11, 2025
e8546e6
Add a working v1 in drafty phase
mustafabar Nov 11, 2025
3545ce9
Force channels to be multiple of 7
mustafabar Nov 11, 2025
e6e0b25
No barrier when nthreads == WARP_SIZE
mustafabar Nov 11, 2025
385729d
Add all_gather warp_level
mustafabar Nov 11, 2025
c2f80c3
Gen all unrolls for mi350
mustafabar Nov 11, 2025
1915071
Fix cases where nChannels is not multiple of 7 for single node
mustafabar Nov 11, 2025
773077b
Enable up to 512 channels
mustafabar Nov 12, 2025
dc05efc
Enable any thread block size
mustafabar Nov 12, 2025
312c75b
Add support for LL128
mustafabar Nov 12, 2025
e4087e4
Revert MinTrafficPerChannel change
mustafabar Nov 13, 2025
b1266f3
Add threads per block control
mustafabar Nov 13, 2025
6ecb5b7
Add RS support
mustafabar Nov 13, 2025
a1f32bb
Generate unroll 3 and add env var
mustafabar Nov 13, 2025
7fa1926
Fix SendRecv and Tree
mustafabar Nov 14, 2025
378d54c
Rename and simplify symbols
mustafabar Nov 14, 2025
e15b9e0
Avoid more than 64 channels for Tree
mustafabar Nov 14, 2025
d9e8cb4
Added install.sh flag to suppress warnings.
thananon Nov 13, 2025
99de243
Add feature knobs and refactor changes
mustafabar Nov 16, 2025
6394818
Merge branch 'warp_speed_v1' of github.com:mustafabar/rccl into warp_…
mustafabar Nov 16, 2025
0cda162
Add warpspeed tuning
mustafabar Nov 17, 2025
64c4549
Merge conflicts
mustafabar Nov 17, 2025
3190e86
Fix channel tuning for multinode
mustafabar Nov 17, 2025
a77f6c8
Reduce Kernel Argsize
mustafabar Nov 17, 2025
a13c782
Add broadcast support
mustafabar Nov 21, 2025
a47eddb
Add Reduce support
mustafabar Nov 21, 2025
f77efa0
Revert "Added install.sh flag to suppress warnings."
mustafabar Nov 21, 2025
78259af
Use NCCL_MAX_GROUPS for max Warps per block
mustafabar Nov 21, 2025
0c0eaaf
Add clarifying comments on the Warp's channel loading
mustafabar Nov 21, 2025
50d914b
Reuse tidInBlock in init
mustafabar Nov 21, 2025
9fefa8b
Remove comment
mustafabar Nov 21, 2025
d717584
Edit comments
mustafabar Nov 21, 2025
4cdf269
Reflect correct type name
mustafabar Nov 21, 2025
fc53c54
Return channel logic for WARP_SIZE < 64
mustafabar Nov 21, 2025
9b695e3
Return UNROLLs to original
mustafabar Nov 21, 2025
c87510b
Add unroll 4
mustafabar Nov 21, 2025
ad8d80e
Unroll back to what they were
mustafabar Nov 21, 2025
fa7d972
Go back to -O1 for debug build
mustafabar Nov 22, 2025
dff8220
Modify unroll factor treatment
mustafabar Nov 22, 2025
c3539b3
Use -O1 for debug
mustafabar Nov 22, 2025
e0ad5f8
Remove unneeded ringIx
mustafabar Nov 22, 2025
9f18a01
Fix MSCCL compatibility
mustafabar Nov 24, 2025
6b8cf5c
Guard changes by MACRO enabled for MI3xx targets only
mustafabar Nov 25, 2025
e6905e9
Better align diffs for RunWorkBatch
mustafabar Nov 25, 2025
8b5a1f5
Merge branch 'develop' into warp_speed_v1
mustafabar Nov 25, 2025
0410c4d
Fix preprocessor directive syntax in all_gather.h
mustafabar Nov 26, 2025
ed81f5d
Modify comment
mustafabar Dec 1, 2025
30367a2
Revert barrier changes to use __syncwarp()
mustafabar Dec 1, 2025
75e327d
Remove nested/unneeded macro check
mustafabar Dec 1, 2025
83751fc
Double channels instead of hard-code to 8
mustafabar Dec 1, 2025
1a32798
Modify comment on channel increase
mustafabar Dec 1, 2025
168e017
Assign wid using tid
mustafabar Dec 2, 2025
bf76570
Undo space-changes in prims_ll.h
mustafabar Dec 2, 2025
4581393
Remove unnecessary whitespace in generate.py
mustafabar Dec 2, 2025
15677ff
Remove unneeded whitespace changes in generate.py
mustafabar Dec 2, 2025
8f5dbe0
Undo whitespace changes
mustafabar Dec 2, 2025
d808362
Merge branch 'develop' into warp_speed_v1
mustafabar Dec 2, 2025
fd2b87b
Skip model detection to support lower subscriptions
mustafabar Dec 2, 2025
d6e8b79
Merge branch 'warp_speed_v1' of github.com:mustafabar/rccl into warp_…
mustafabar Dec 2, 2025
fe78376
Use all threads in the warp to copy the channel data in parallel
mustafabar Dec 2, 2025
7 changes: 4 additions & 3 deletions src/device/all_gather.h
@@ -20,11 +20,12 @@ namespace {
const int bid = ncclShmem.channelId - work->channelLo;
int npKitCtxIdx = bid; // unused variable - compiler warning
#endif
ncclRing *ring = &ncclShmem.channel.ring;
int warp = threadIdx.x / WARP_SIZE;
ncclRing *ring = &ncclShmem.warpChannel[warp].ring;
const int *ringRanks = ring->userRanks;
const int nranks = ncclShmem.comm.nRanks;
ssize_t count, partOffset, partCount, chunkCount;
ncclCollCbdPart(work, ncclShmem.channelId, Proto::Id, sizeof(T), &count, &partOffset, &partCount, &chunkCount);
ncclCollCbdPart(work, ncclShmem.warpChannelId[warp], Proto::Id, sizeof(T), &count, &partOffset, &partCount, &chunkCount);
ssize_t offset;
ssize_t dataOffset;
int nelem;
@@ -142,7 +143,7 @@ namespace {
#endif
// Final wait/copy.
prims.directRecv(offset, nelem);

#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_ALL_GATHER_RING_DIRECT_RECV_EXIT)
if (tid == 0) {
NpKit::CollectGpuEvent(NPKIT_EVENT_ALL_GATHER_RING_DIRECT_RECV_EXIT, nelem*sizeof(T), prims.npKitDataProcessTotalTime, NPKIT_GET_GPU_TIMESTAMP(),
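The substitution above (per-warp warpChannel/warpChannelId in place of the block-wide channel/channelId) recurs in all_reduce.h, broadcast.h, reduce.h, and the primitives below. A minimal sketch of the shared-memory layout it relies on, with simplified types and an assumed NCCL_MAX_GROUPS value; the real fields are added to ncclShmemData in common.h further down:

constexpr int kMaxWarpsPerBlock = 8;               // stands in for NCCL_MAX_GROUPS (value assumed here)
struct DevChannelSketch { /* ring, peers, ... */ };
struct ShmemDataSketch {
  int channelId;                                   // legacy per-block channel id
  int warpChannelId[kMaxWarpsPerBlock];            // per-warp channel id, -1 if unassigned
  DevChannelSketch warpChannel[kMaxWarpsPerBlock]; // per-warp copy of the channel state
};
// Each collective then selects its state with: int warp = threadIdx.x / WARP_SIZE;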
6 changes: 4 additions & 2 deletions src/device/all_reduce.h
@@ -20,8 +20,10 @@ namespace {
#else
__device__ __attribute__((noinline)) void runRing(int tid, int nthreads, struct ncclDevWorkColl* work) {
#endif
ncclRing *ring = &ncclShmem.channel.ring;
int warp = threadIdx.x / WARP_SIZE;
ncclRing *ring = &ncclShmem.warpChannel[warp].ring;
int ringIx = ring->index;

const int nranks = ncclShmem.comm.nRanks;
#if defined(ENABLE_NPKIT)
const int bid = ncclShmem.channelId - work->channelLo;
@@ -31,7 +33,7 @@
ssize_t gridOffset;
ssize_t channelCount;
ssize_t chunkCount;
ncclCollCbdPart(work, ncclShmem.channelId, Proto::Id, sizeof(T), &size, &gridOffset, &channelCount, &chunkCount);
ncclCollCbdPart(work, ncclShmem.warpChannelId[warp], Proto::Id, sizeof(T), &size, &gridOffset, &channelCount, &chunkCount);
const ssize_t loopCount = nranks * chunkCount;
ssize_t offset;
int nelem;
5 changes: 3 additions & 2 deletions src/device/broadcast.h
@@ -19,15 +19,16 @@ namespace {
const int bid = ncclShmem.channelId - work->channelLo;
int npKitCtxIdx = bid; // unused variable - compiler warning
#endif
ncclRing *ring = &ncclShmem.channel.ring;
int warp = threadIdx.x / WARP_SIZE;
ncclRing *ring = &ncclShmem.warpChannel[warp].ring;
const int rank = ring->userRanks[0];
const int nextRank = ring->userRanks[1];
const int root = work->root;
ssize_t size;
ssize_t chunkCount;
ssize_t channelCount;
ssize_t gridOffset;
ncclCollCbdPart(work, ncclShmem.channelId, Proto::Id, sizeof(T), &size, &gridOffset, &channelCount, &chunkCount);
ncclCollCbdPart(work, ncclShmem.warpChannelId[warp], Proto::Id, sizeof(T), &size, &gridOffset, &channelCount, &chunkCount);
size_t offset;
int nelem;
int workNthreads;
24 changes: 12 additions & 12 deletions src/device/common.cu
@@ -17,24 +17,24 @@ struct RunWorkNop {
__device__ void run() {}
};

__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/1>(&args4K.args);
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_1(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/1>(&argsStorage.args);
}
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/2>(&args4K.args);
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_2(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/2>(&argsStorage.args);
}
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/4>(&args4K.args);
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_4(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/4>(&argsStorage.args);
}
#ifdef ENABLE_COLLTRACE
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/1>(&args4K.args);
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/1>(&argsStorage.args);
}
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/2>(&args4K.args);
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/2>(&argsStorage.args);
}
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/4>(&args4K.args);
__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage) {
ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/4>(&argsStorage.args);
}
#endif

62 changes: 54 additions & 8 deletions src/device/common.h
@@ -135,9 +135,12 @@ struct ncclShmemGroup {
struct ncclShmemData {
struct ncclDevKernelArgs args;
int channelId;
int warpChannelId[NCCL_MAX_GROUPS];
int warpComm;
int aborted;
alignas(16) struct ncclDevComm comm;
alignas(16) struct ncclDevChannel channel;
alignas(16) struct ncclDevChannel warpChannel[NCCL_MAX_GROUPS];

int batchIx, nextBatchIx;
enum ncclDevWorkType workType;
@@ -445,7 +448,10 @@ struct RunWorkBatch {
// Coverity reports a possible thread divergence due to not all threads participating in the collective.
// However, the code ensures that the participation is on a per-warp basis.
// coverity[device_thread_diverged:FALSE]
if (tid < subtn) RunWorkColl<Fn, T, RedOp, Algo, Proto>().run(tid, subtn, work);
if (tid < subtn) {
if(ncclShmem.warpComm == 0 || Algo != NCCL_ALGO_RING) RunWorkColl<Fn, T, RedOp, Algo, Proto>().run(tid, subtn, work);
else if (ncclShmem.warpChannelId[tid / WARP_SIZE] >= 0) RunWorkColl<Fn, T, RedOp, Algo, Proto>().run(tid % WARP_SIZE, WARP_SIZE, work);
}
}
}
};
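To restate the branch above in isolation: when warp-level mode is enabled and the algorithm is RING, each warp that was assigned a valid channel re-enters the collective with warp-local indices, so one block can drive several ring channels concurrently. A hedged sketch with simplified names (runColl, warpComm, and warpChannelId stand in for the real RunWorkColl<...>().run, ncclShmem.warpComm, and ncclShmem.warpChannelId):

if (tid < subtn) {
  if (warpComm == 0 || Algo != NCCL_ALGO_RING) {
    runColl(tid, subtn, work);                   // legacy path: whole block, one channel
  } else if (warpChannelId[tid / WARP_SIZE] >= 0) {
    runColl(tid % WARP_SIZE, WARP_SIZE, work);   // WarpSpeed: this warp drives its own channel
  }
}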
@@ -490,6 +496,11 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a
int total = 0, y;
int num = MAXCHANNELS/64 > 0 ? MAXCHANNELS/64 : 1;

int warpCount = tn / WARP_SIZE;
int localWarpId = tid / WARP_SIZE;
int globalWarpId = (warpCount * blockIdx.x) + localWarpId;
int laneId = tid % WARP_SIZE;

// Copy kernel args to shmem and then only read those. Otherwise the compiler
// will end up putting the args into thread local stack which is very wasteful.
if (tid < sizeof(ncclDevKernelArgs)/sizeof(uint32_t)) {
@@ -584,8 +595,43 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a
ncclShmem.collTraceTail = args->comm->collTraceTail + ncclShmem.channelId;
}
#endif
if(tid == 0) {
ncclShmem.warpComm = args->comm->warpLevelComm;
}
__syncthreads(); // publish shmem


// Determine per-warp channel assignment for WarpSpeed enablement
total = 0;
if(ncclShmem.warpComm == 1) { // If warpComm is enabled, assign warps to channels whose bit is set in the channel mask
ncclShmem.warpChannelId[localWarpId] = -1;
__syncthreads();
for (int i = 0; i < num; i++) {
if (args->channelMask.masks[i] & (1ull<<laneId)) {
y = __popcll(args->channelMask.masks[i] & ((1ull<<laneId)-1));
y = total + y;
if (globalWarpId == y) {
ncclShmem.warpChannelId[localWarpId] = laneId + total;
break;
}
}
total = total + __popcll(args->channelMask.masks[i]);
}
__syncthreads();
if(ncclShmem.warpChannelId[localWarpId] >= 0) {
void* dst = &ncclShmem.warpChannel[localWarpId];
void* src = &((ncclDevCommAndChannels*)ncclShmem.args.comm)->channels[ncclShmem.warpChannelId[localWarpId]];
int bytes = sizeof(ncclDevChannel);
static_assert(sizeof(ncclDevChannel) <= 16*WARP_SIZE, "ncclDevChannel cannot be loaded by a single warp in one insn.");
// assert((tid-localWarpId*WARP_SIZE) >= 0 && (tid-localWarpId*WARP_SIZE) < WARP_SIZE);
copyToShmem16(tid-localWarpId*WARP_SIZE, dst, src, bytes);
}
} else if(laneId == 0) { // If warpComm is disabled, all warps use the same channel as the block
ncclShmem.warpChannelId[localWarpId] = ncclShmem.channelId;
ncclShmem.warpChannel[localWarpId] = ncclShmem.channel;
}
__syncthreads();

#ifdef ENABLE_PROFILING
if (tid == 0) {
ncclShmem.prof.count = 0;
@@ -648,17 +694,17 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a
#endif
}

__global__ void ncclDevKernel_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
__global__ void ncclDevKernel_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
__global__ void ncclDevKernel_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
__global__ void ncclDevKernel_Generic_1(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage);
__global__ void ncclDevKernel_Generic_2(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage);
__global__ void ncclDevKernel_Generic_4(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage);
#ifdef ENABLE_COLLTRACE
__global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
__global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
__global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K);
__global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage);
__global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage);
__global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage);
#endif

#define DEFINE_ncclDevKernel_nop(suffix, coll, redop, ty, algo, proto, specializedFnId) \
__global__ void ncclDevKernel_##suffix(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {}
__global__ void ncclDevKernel_##suffix(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage) {}

#ifdef USE_INDIRECT_FUNCTION_CALL
#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, acc, pipeline, unroll) \
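The per-warp channel assignment added to ncclKernelMain above maps globalWarpId to the n-th enabled bit of the channel mask, using one lane per candidate bit and __popcll. Below is a host-side restatement that computes the same mapping sequentially; it is an illustration only, with assumed names, while the device code stores the result in ncclShmem.warpChannelId and then copies the channel into shared memory warp-wide.

#include <cstdint>
#include <cstdio>

// Returns the channel that warp 'globalWarpId' would claim, or -1 if the grid
// has fewer enabled channels than warps. Mirrors the arithmetic in the diff:
// the assigned id is laneId + (set bits counted in earlier mask words).
int channelForWarp(const uint64_t* masks, int numWords, int globalWarpId) {
  int total = 0;                                  // set bits in earlier mask words
  for (int i = 0; i < numWords; ++i) {
    for (int lane = 0; lane < 64; ++lane) {
      if (masks[i] & (1ull << lane)) {
        int y = total + __builtin_popcountll(masks[i] & ((1ull << lane) - 1));
        if (y == globalWarpId) return lane + total;
      }
    }
    total += __builtin_popcountll(masks[i]);
  }
  return -1;
}

int main() {
  uint64_t masks[2] = {0xB5ull, 0x0ull};          // channels 0, 2, 4, 5, 7 enabled
  for (int w = 0; w < 6; ++w)
    printf("warp %d -> channel %d\n", w, channelForWarp(masks, 2, w));
  return 0;                                       // warp 5 prints -1: no channel left
}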
6 changes: 3 additions & 3 deletions src/device/generate.py
@@ -198,7 +198,7 @@ def calc_unroll_and_pipeline_for_local_arch():
# We want to remove duplicates but cannot use a dictionary since same gfx name can have different cu counts
# Use (gfx_name, cu_count) as key for dictionary and convert it to list here
gfx_targets = list(gfx_targets.keys())

# Homogeneous system is required to build for only 1 variant of unroll factor (except for gfx950)
if len(gfx_targets) == 1:
gfx_name, cu_count = gfx_targets[0]
@@ -505,7 +505,7 @@ def get_arch_guard(fn):
key = ((coll_idx & 0x3F) | ((proto_idx & 0x3F) << 8))
if fn.coll in ["SendRecv", "AllToAllPivot"]:
key = ((coll_idx & 0x3F))

out(f' {{{key}, {fn_id}}}, {comment}\n')
out("};\n")

@@ -577,7 +577,7 @@ def partition_by_name(fns):
.format(sym=sym, coll=fn.coll, redop_cxx=redop_to_cxx[fn.redop], ty_cxx=ty_to_cxx[fn.ty],
algo=(fn.algo or "RING"), proto=(fn.proto or "SIMPLE"), acc=fn.acc, pipeline=fn.pipeline, unroll=fn.unroll)
)
if guard:
if guard:
out("#endif\n")

# Generate each <gensrc>/<msccl_impl>.cpp
9 changes: 5 additions & 4 deletions src/device/prims_ll.h
@@ -269,7 +269,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, Metadata, Pi
i4.data2 = (val >> 32);
i4.flag2 = flag;
*((u64_gptr) dst->v) = *((u64_gptr) i4.v);
*((u64_gptr) dst->v+1) = *((u64_gptr) i4.v+1);
*((u64_gptr) dst->v+1) = *((u64_gptr) i4.v+1);
#if defined(__gfx950__) && ROCM_VERSION < 70002
__builtin_amdgcn_fence(__ATOMIC_RELEASE, ""); // flush cache on gfx950 if ROCr fix for hipHostMallocUncached is not available (ROCm version < 7.0.2)
#endif
@@ -507,7 +507,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, Metadata, Pi
nelem -= eltPerTrip;
offset += nthreads;
}
#ifdef __gfx950__
#ifdef __gfx950__
if constexpr (isMsccl(Metadata) && DST){
// Wait for pending vector loads and stores
__builtin_amdgcn_s_waitcnt((15 << 8) | (7 << 4)); // s_waitcnt vmcnt(0)
@@ -652,9 +652,10 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL, P2p, isNetOffload, Metadata, Pi
bool ipcReg = false, bool netReg = false, int stepSize_ = 0
):
redOp(redOpArg),
tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), group(group), threadsPerBlock(blockDim.x),
tid(tid), nthreads(nthreads), wid(threadIdx.x%WARP_SIZE), group(group), threadsPerBlock(blockDim.x),
alex-breslow-amd (Contributor) commented on Dec 1, 2025:

Is the wid(threadIdx.x%WARP_SIZE) change safe in general? I think this may need to be contingent on whether it's WarpSpeed. Let me know if I am wrong. If I recall correctly, there are cases where the thread ID is virtualized, as in AllToAll, so the wid in that case is also virtual and not threadIdx.x%WARP_SIZE.

mustafabar (Contributor, author) replied:

Makes sense. Having thought about it, threadIdx.x%WARP_SIZE is no different from (threadIdx.x%WARP_SIZE)%WARP_SIZE, which is what I intended to address in this change, so it is safe to revert it.

stepLines(ncclShmem.comm.buffSizes[NCCL_PROTO_LL]/NCCL_STEPS/sizeof(ncclLLFifoLine)) {
auto *channel = &ncclShmem.channel;
int warp = threadIdx.x / WARP_SIZE;
auto *channel = &ncclShmem.warpChannel[warp];
barriers = &ncclShmem.groups[group].barrier;
// If we are going to support oneshot collNet + LL, then we would need to add connector index here
int nrecv=0, nsend=0;
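The review exchange above hinges on a small identity: when tid is already warp-local, reducing it modulo WARP_SIZE again changes nothing, so wid(tid%WARP_SIZE) and wid(threadIdx.x%WARP_SIZE) only differ where tid is virtualized (as in AllToAll), which is why reverting is the safer choice. A trivial host-side check of that identity (WARP_SIZE = 64 assumed for illustration):

#include <cassert>

int main() {
  const int WARP_SIZE = 64;
  for (int t = 0; t < 4096; ++t) {
    int warpLocal = t % WARP_SIZE;                // what tid is in the warp-level path
    assert(warpLocal % WARP_SIZE == warpLocal);   // (x % W) % W == x % W
  }
  return 0;
}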
2 changes: 1 addition & 1 deletion src/device/prims_ll128.h
@@ -579,7 +579,7 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p, isNetOffload, Metadata,
tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), /*compiler warnings*/
stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS/sizeof(uint64_t)),
warp(tid/WARP_SIZE), warpInBlock(threadIdx.x/WARP_SIZE), flagThread((tid%4)==3), group(group), threadsPerBlock(blockDim.x){
auto *channel = &ncclShmem.channel;
auto *channel = &ncclShmem.warpChannel[warpInBlock];
barriers = &ncclShmem.groups[group].barrier;
int nrecv=0, nsend=0;
while (nrecv < MaxRecv && recvPeers[nrecv] >= 0) {
20 changes: 8 additions & 12 deletions src/device/prims_simple.h
@@ -81,18 +81,14 @@ class Primitives<

// Don't use barrier 0 as it's used by the final sync
inline __device__ void barrier() {
if (nthreads == WARP_SIZE)
A reviewer (Contributor) commented:

I think we may need something here to prevent code motion in the general case, like an __atomic_signal_fence. That's effectively what __syncwarp() is doing for us here, I believe, since our warps can't diverge at the hardware level. Same with below.

mustafabar (Contributor, author) replied:

As discussed, I went with __syncwarp(), as this evaluation is out of scope for the PR.

__syncwarp();
else
if(nthreads != WARP_SIZE)
#if defined(__gfx942__) || defined(__gfx950__)
barrier_generic(__threadfence_block(), nworkers, barrier_next, barriers);
#else
barrier_generic(__threadfence(), nworkers, barrier_next, barriers);
#endif
}
inline __device__ void subBarrier() {
if (nworkers == WARP_SIZE) __syncwarp();
else
barrier();
}
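The thread above concerns what, if anything, must remain at the sync point when a primitive runs with a single warp: AMD wavefronts execute in lockstep, so no cross-warp barrier is needed, but the compiler still must not move memory accesses across the point. A hedged sketch of the two options discussed; this is not asserted to be the PR's final form (the commit list shows a later revert back to __syncwarp()):

__device__ inline void singleWarpSync() {
#if 1
  __syncwarp();                              // warp-scope sync; also acts as a codegen barrier
#else
  __atomic_signal_fence(__ATOMIC_ACQ_REL);   // compiler-only fence, the reviewer's alternative
#endif
}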

@@ -490,14 +486,14 @@ class Primitives<

public:
static inline __device__ void sendPeerNotify(int peer, int connIndex, int steps) {
ncclDevChannelPeer* peerPtr = ncclShmem.channel.peers[peer];
ncclDevChannelPeer* peerPtr = ncclShmem.warpChannel[threadIdx.x/WARP_SIZE].peers[peer];
peerPtr->send[connIndex].step += steps;
st_relaxed_sys_global(peerPtr->send[connIndex].tail, peerPtr->send[connIndex].step);
}

static inline __device__ void recvPeerNotify(int peer, int connIndex, int steps) {
int spins = 0;
ncclDevChannelPeer* peerPtr = ncclShmem.channel.peers[peer];
ncclDevChannelPeer* peerPtr = ncclShmem.warpChannel[threadIdx.x/WARP_SIZE].peers[peer];
peerPtr->recv[connIndex].step += steps;
st_relaxed_sys_global(peerPtr->recv[connIndex].head, peerPtr->recv[connIndex].step);
while (ld_volatile_global(peerPtr->recv[connIndex].tail) < peerPtr->recv[connIndex].step) {
@@ -758,7 +754,7 @@ class Primitives<
struct ncclDevWorkP2p* p2pWork = nullptr, int stepSize_ = 0, int mode = primsModeDefault
):
tid(tid), tidInBlock(threadIdx.x), nthreads(nthreads), /*compiler warnings*/
stepSize(stepSize_ == 0 ? ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T) : stepSize_), group(group), threadsPerBlock(blockDim.x){
stepSize(stepSize_ == 0 ? ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T) : stepSize_), group(ncclShmem.warpComm? tidInBlock / WARP_SIZE : group), threadsPerBlock(blockDim.x){

barriers = &ncclShmem.groups[group].barrier;
// PAT uses the same barrier for each group
@@ -819,9 +815,9 @@ class Primitives<
}

// coverity[overrun-call] => Coverity think prims.index can be greater than 1
if (flags & (RoleWaitRecv|RolePostRecv)) loadRecvConn(ncclShmem.channel.peers[peer], connIndexRecv, collWork ? collWork->direct : 0, recvIpcReg, recvNetReg);
if (flags & (RoleWaitRecv|RolePostRecv)) loadRecvConn(ncclShmem.warpChannel[tidInBlock/WARP_SIZE].peers[peer], connIndexRecv, collWork ? collWork->direct : 0, recvIpcReg, recvNetReg);
// coverity[overrun-call] => Coverity think prims.index can be greater than 1
if (flags & (RoleWaitSend|RolePostSend)) loadSendConn(ncclShmem.channel.peers[peer], connIndexSend, collWork ? collWork->direct : 0, sendIpcReg, sendNetReg);
if (flags & (RoleWaitSend|RolePostSend)) loadSendConn(ncclShmem.warpChannel[tidInBlock/WARP_SIZE].peers[peer], connIndexSend, collWork ? collWork->direct : 0, sendIpcReg, sendNetReg);

// if (barrierAny(flags & NetDeviceUnpack)) {
// flags |= AnyNetDeviceUnpack;
@@ -849,7 +845,7 @@ class Primitives<
// Load recv peer
int recvPeer = mode == primsModePatRs ? (rank - delta + nranks) % nranks : (rank + delta) % nranks;
struct ncclPatPeer* peer = ((struct ncclPatPeer*)recvPeers)+tid;
struct ncclConnInfo* conn = peer->conn = ncclShmem.channel.peers[recvPeer]->recv+connIndexRecv;
struct ncclConnInfo* conn = peer->conn = ncclShmem.warpChannel[tidInBlock/WARP_SIZE].peers[recvPeer]->recv+connIndexRecv;
peer->step = conn->step;
peer->buff = conn->buffs[NCCL_PROTO_SIMPLE];
peer->stepCache = loadStepValue(peer->tailPtr = conn->tail);
@@ -859,7 +855,7 @@
// Load send peer
int sendPeer = mode == primsModePatAg ? (rank - delta + nranks) % nranks : (rank + delta) % nranks;
peer = ((struct ncclPatPeer*)sendPeers)+tid;
conn = peer->conn = ncclShmem.channel.peers[sendPeer]->send+connIndexSend;
conn = peer->conn = ncclShmem.warpChannel[tidInBlock/WARP_SIZE].peers[sendPeer]->send+connIndexSend;
peer->step = conn->step;
peer->connFifo = conn->connFifo;
peer->buff = conn->buffs[NCCL_PROTO_SIMPLE];
5 changes: 3 additions & 2 deletions src/device/reduce.h
@@ -16,15 +16,16 @@ namespace {
#else
__device__ __attribute__((noinline)) void runRing(int tid, int nthreads, struct ncclDevWorkColl* work) {
#endif
ncclRing *ring = &ncclShmem.channel.ring;
int warp = threadIdx.x / WARP_SIZE;
ncclRing *ring = &ncclShmem.warpChannel[warp].ring;
const int nranks = ncclShmem.comm.nRanks;
const int rank = ncclShmem.comm.rank;
const int prevRank = ring->userRanks[nranks-1];
const int root = work->root;
size_t chunkCount;
size_t channelCount;
size_t gridOffset;
ncclCollCbdPart(work, ncclShmem.channelId, Proto::Id, sizeof(T), (size_t*)nullptr, &gridOffset, &channelCount, &chunkCount);
ncclCollCbdPart(work, ncclShmem.warpChannelId[warp], Proto::Id, sizeof(T), (size_t*)nullptr, &gridOffset, &channelCount, &chunkCount);
size_t offset;
int nelem;
