From 430081a7c673576ec35f70a9929a97fc16aeb146 Mon Sep 17 00:00:00 2001 From: Mark Santesson Date: Sat, 8 Nov 2025 03:43:22 +0000 Subject: [PATCH] NCCL Put API Preview --- src/CMakeLists.txt | 2 + src/Makefile | 1 + src/collectives.cc | 44 ++ src/dev_runtime.cc | 31 +- src/enqueue.cc | 139 ++++- src/group.cc | 20 + src/include/comm.h | 44 +- src/include/dev_runtime.h | 4 + src/include/group.h | 8 + src/include/info.h | 8 + src/include/nccl_common.h | 5 +- src/include/nvtx.h | 5 +- src/include/nvtx_payload_schemas.h | 25 + src/include/rma/rma.h | 34 ++ src/include/rma/rma_ce.h | 50 ++ src/include/rma/rma_proxy.h | 139 +++++ src/init.cc | 34 ++ src/nccl.h.in | 86 ++++ src/plugin/net.cc | 1 + src/rma/CMakeLists.txt | 9 + src/rma/rma.cc | 240 +++++++++ src/rma/rma_ce.cc | 230 +++++++++ src/rma/rma_proxy.cc | 802 +++++++++++++++++++++++++++++ 23 files changed, 1952 insertions(+), 9 deletions(-) create mode 100644 src/include/rma/rma.h create mode 100644 src/include/rma/rma_ce.h create mode 100644 src/include/rma/rma_proxy.h create mode 100644 src/rma/CMakeLists.txt create mode 100644 src/rma/rma.cc create mode 100644 src/rma/rma_ce.cc create mode 100644 src/rma/rma_proxy.cc diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b48ed1880..9450d74cd 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -40,6 +40,7 @@ add_subdirectory(nccl_device) add_subdirectory(ras) add_subdirectory(scheduler) add_subdirectory(gin) +add_subdirectory(rma) add_compile_options(-fmacro-prefix-map=${CMAKE_CURRENT_SOURCE_DIR}/=) @@ -55,6 +56,7 @@ list(APPEND LIBSRCFILES ${SCHEDULER_SOURCES} ${GIN_SOURCES} ${DOCA_SOURCES} + ${RMA_SOURCES} ) ###################### Create a shared NCCL library ############################ diff --git a/src/Makefile b/src/Makefile index 471a0335e..0581b9490 100644 --- a/src/Makefile +++ b/src/Makefile @@ -26,6 +26,7 @@ LIBSRCFILES := \ $(wildcard nccl_device/*.cc) \ $(wildcard scheduler/*.cc) \ $(wildcard gin/*.cc) \ + $(wildcard rma/*.cc) \ $(filter-out ras/client.cc,$(wildcard ras/*.cc)) BINSRCFILES := ras/client.cc diff --git a/src/collectives.cc b/src/collectives.cc index ca69c9a78..c93404102 100644 --- a/src/collectives.cc +++ b/src/collectives.cc @@ -23,6 +23,9 @@ const char* ncclFuncToString(ncclFunc_t fn) { case ncclFuncScatter: return "Scatter"; case ncclFuncSendRecv: return "SendRecv"; case ncclFuncSend: return "Send"; + case ncclFuncPut: return "Put"; + case ncclFuncSignal: return "Signal"; + case ncclFuncWaitSignal: return "WaitSignal"; default: return "Invalid"; } } @@ -214,3 +217,44 @@ ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int 1, 1 }; return ncclEnqueueCheck(&info); } + +NCCL_API(ncclResult_t, ncclPut, int ctx, const void* localbuff, size_t count, ncclDataType_t datatype, + int peer, size_t peerWinOffset, ncclWindow_t peerWin, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream); +ncclResult_t ncclPut(int ctx, const void* localbuff, size_t count, ncclDataType_t datatype, + int peer, size_t peerWinOffset, ncclWindow_t peerWin, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream) { +NVTX3_FUNC_WITH_PARAMS(Put, NcclNvtxParamsPut, + NVTX3_PAYLOAD(comm ? 
comm->commHash : 0, count * ncclTypeSize(datatype), peer, ctx)); + +struct ncclInfo info = { ncclFuncPut, "Put", + localbuff, NULL, count, datatype, ncclSum, peer, comm, stream, /* Args */ + 1, 1, /* chunkSteps, sliceSteps */ + peerWinOffset, peerWin, signalMode, ctx, /* peerWinOffset, peerWin, signalMode, ctx */ + NULL, NULL, 0 }; /* peers, nsignals, npeers */ +return ncclEnqueueCheck(&info); +} + +NCCL_API(ncclResult_t, ncclSignal, int ctx, int peer, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream); +ncclResult_t ncclSignal(int ctx, int peer, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream) { +NVTX3_FUNC_WITH_PARAMS(Signal, NcclNvtxParamsSignal, + NVTX3_PAYLOAD(comm ? comm->commHash : 0, peer, ctx)); + +struct ncclInfo info = { ncclFuncSignal, "Signal", + NULL, NULL, 0, ncclInt8, ncclSum, peer, comm, stream, /* Args */ + 1, 1, /* chunkSteps, sliceSteps */ + 0, NULL, signalMode, ctx, /* peerWinOffset, peerWin, signalMode, ctx */ + NULL, NULL, 0 }; /* peers, nsignals, npeers */ +return ncclEnqueueCheck(&info); +} + +NCCL_API(ncclResult_t, ncclWaitSignal, int ctx, int* peers, int* nsignals, int npeers, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream); +ncclResult_t ncclWaitSignal(int ctx, int* peers, int* nsignals, int npeers, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream) { +NVTX3_FUNC_WITH_PARAMS(WaitSignal, NcclNvtxParamsWaitSignal, + NVTX3_PAYLOAD(comm ? comm->commHash : 0, npeers, ctx)); + +struct ncclInfo info = { ncclFuncWaitSignal, "WaitSignal", + NULL, NULL, 0, ncclInt32, ncclSum, 0, comm, stream, /* Args */ + 1, 1, /* chunkSteps, sliceSteps */ + 0, NULL, signalMode, ctx, /* peerWinOffset, peerWin, signalMode, ctx */ + peers, nsignals, npeers }; /* peers, nsignals, npeers */ +return ncclEnqueueCheck(&info); +} \ No newline at end of file diff --git a/src/dev_runtime.cc b/src/dev_runtime.cc index 60cb200aa..470af0c3d 100644 --- a/src/dev_runtime.cc +++ b/src/dev_runtime.cc @@ -6,6 +6,7 @@ #include "dev_runtime.h" #include "comm.h" +#include "rma/rma.h" #include "device.h" #include "transport.h" #include "group.h" @@ -23,6 +24,8 @@ struct ncclDevrMemory { size_t bigOffset; // offset in big VA space void* ginHostWins[NCCL_GIN_MAX_CONTEXTS]; ncclGinWindow_t ginDevWins[NCCL_GIN_MAX_CONTEXTS]; + void* rmaHostWins[NCCL_GIN_MAX_CONTEXTS]; + ncclGinWindow_t rmaDevWins[NCCL_GIN_MAX_CONTEXTS]; }; struct ncclDevrWindowSorted { @@ -354,6 +357,12 @@ static ncclResult_t symMemoryRegisterGin(struct ncclComm* comm, struct ncclDevrM return ncclSuccess; } +static ncclResult_t symMemoryRegisterRma(struct ncclComm* comm, struct ncclDevrMemory* mem) { + NCCLCHECK(ncclRmaProxyConnectOnce(comm)); + NCCLCHECK(ncclRmaProxyRegister(comm, mem->primaryAddr, mem->size, mem->rmaHostWins, mem->rmaDevWins)); + return ncclSuccess; +} + // On success we take caller's reference on memHandle. // Due to multicast binds for each pre-exiting team, this function requires // caller do a world barrier before returning to user. 
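The helper above plugs into symMemoryObtain in the next hunk; the resulting registration sequence is sketched here (a summary of this patch's own code paths, not additional API):

  /* Symmetric-memory registration with RMA enabled (sketch):
   *   symMemoryObtain()
   *     -> symMemoryRegisterGin(comm, mem)   // existing per-context GIN windows
   *     -> symMemoryRegisterRma(comm, mem)   // new: ncclRmaProxyConnectOnce(),
   *                                          //      then ncclRmaProxyRegister()
   * The per-context handles land in mem->rmaHostWins[] / mem->rmaDevWins[] and can
   * later be looked up with ncclDevrGetRmaDevWin(winHost, ctx).
   */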
@@ -402,6 +411,13 @@ static ncclResult_t symMemoryObtain( NCCLCHECKGOTO(symMemoryRegisterGin(comm, mem), ret, fail_mem_space_teams); } + // ginEnabled is set in ncclDevrCommCreateInternal, which might not be called for RMA proxy + // so we introduce rmaProxyEnabled to track if RMA proxy is enabled + devr->rmaProxyEnabled = comm->nNodes > 1 && comm->config.numRmaCtx > 0 && ncclParamGinType() == NCCL_NET_DEVICE_GIN_PROXY; + if (devr->rmaProxyEnabled) { + NCCLCHECKGOTO(symMemoryRegisterRma(comm, mem), ret, fail_mem_space_teams); + } + // Add to list of mems. mem->next = devr->memHead; devr->memHead = mem; @@ -431,6 +447,9 @@ static void symMemoryDropRef( if (devr->ginEnabled) { ncclGinDeregister(comm, mem->ginHostWins); } + if (devr->rmaProxyEnabled) { + ncclRmaProxyDeregister(comm, mem->rmaHostWins); + } for (struct ncclDevrTeam* t = devr->teamHead; t != nullptr; t = t->next) { symUnbindTeamMemory(comm, t, mem); } @@ -1020,7 +1039,6 @@ ncclResult_t ncclDevCommDestroy( return ncclSuccess; } - // Get the corresponding pointer in another lsa rank's symmetric memory window ncclResult_t ncclDevrGetLsaRankPtr(struct ncclComm* comm, struct ncclDevrWindow* winHost, size_t offset, int lsaRank, void** outPtr) { if (winHost == nullptr || outPtr == nullptr) { @@ -1044,6 +1062,17 @@ ncclResult_t ncclDevrGetLsaRankPtr(struct ncclComm* comm, struct ncclDevrWindow* return ncclSuccess; } +// Get the RMA device window handle for a specific context +ncclGinWindow_t ncclDevrGetRmaDevWin(struct ncclDevrWindow* winHost, int ctx) { + if (winHost == nullptr || winHost->memory == nullptr) { + return nullptr; + } + if (ctx < 0 || ctx >= NCCL_GIN_MAX_CONTEXTS) { + return nullptr; + } + return winHost->memory->rmaDevWins[ctx]; +} + // Get the multicast address for a given team ncclResult_t ncclDevrGetLsaTeamPtrMC(struct ncclComm* comm, struct ncclDevrWindow* winHost, size_t offset, struct ncclTeam lsaTeam, void** outPtr){ if (winHost == nullptr || outPtr == nullptr) { diff --git a/src/enqueue.cc b/src/enqueue.cc index da45abe6f..84464f51e 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -17,6 +17,7 @@ #include "ce_coll.h" #include "nvtx.h" #include "scheduler.h" +#include "rma/rma.h" #include // std::memcpy #include // PRIx64 @@ -1149,7 +1150,7 @@ namespace { } static ncclResult_t uploadWork(struct ncclComm* comm, struct ncclKernelPlan* plan) { - if (plan->isSymColl || plan->isCeColl) return ncclSuccess; + if (plan->isSymColl || plan->isCeColl || plan->isRma) return ncclSuccess; size_t workBytes = plan->workBytes; size_t batchBytes = plan->nWorkBatches*sizeof(struct ncclDevWorkBatch); @@ -1423,7 +1424,8 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) { if (planner->nTasksColl + planner->nTasksP2p != 0 || !ncclIntruQueueEmpty(&planner->collSymTaskQueue) || - !ncclIntruQueueEmpty(&planner->collCeTaskQueue)) { + !ncclIntruQueueEmpty(&planner->collCeTaskQueue) || + planner->nTasksRma != 0) { do { memset(&planner->wipPlan, 0, sizeof(planner->wipPlan)); @@ -1435,7 +1437,13 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) { plan->workStorageType = persistent ? 
ncclDevWorkStorageTypePersistent : ncclDevWorkStorageTypeFifo; - if (!ncclIntruQueueEmpty(&planner->collCeTaskQueue)) { + if (planner->nTasksRma != 0) { + NCCLCHECKGOTO(scheduleRmaTasksToPlan(comm, plan), result, failure); + if (plan->isRma && plan->rmaArgs != NULL && plan->rmaArgs->nRmaTasks > 0) { + ncclIntruQueueEnqueue(&planner->planQueue, plan); + nPlans += 1; + } + } else if (!ncclIntruQueueEmpty(&planner->collCeTaskQueue)) { struct ncclTaskColl* task = ncclIntruQueueHead(&planner->collCeTaskQueue); plan->isCeColl = true; plan->ceCollArgs = ncclMemoryStackAlloc(&comm->memScoped); @@ -1453,7 +1461,7 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) { ncclMemoryPoolFree(&comm->memPool_ncclTaskColl, task); nPlans += 1; } else { - if (!ncclIntruQueueEmpty(&planner->collSymTaskQueue)) { + if (!ncclIntruQueueEmpty(&planner->collSymTaskQueue)) { NCCLCHECKGOTO(ncclSymmetricTaskScheduler(comm, &planner->collSymTaskQueue, plan), result, failure); } else { @@ -1483,7 +1491,8 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) { } } while (planner->nTasksColl + planner->nTasksP2p != 0 || !ncclIntruQueueEmpty(&planner->collSymTaskQueue) || - !ncclIntruQueueEmpty(&planner->collCeTaskQueue)); + !ncclIntruQueueEmpty(&planner->collCeTaskQueue) || + planner->nTasksRma != 0); struct ncclKernelPlan* planHead = ncclIntruQueueHead(&planner->planQueue); planner->unlaunchedPlansHead = planHead; @@ -2542,6 +2551,124 @@ static ncclResult_t ceCollTaskAppend( return ncclSuccess; } +static ncclResult_t rmaTaskAppend( + struct ncclComm* comm, + struct ncclInfo* info) { + struct ncclKernelPlanner *planner = &comm->planner; + + void const* srcBuff = info->sendbuff; + + if (!comm->symmetricSupport){ + WARN("ncclPut: symmetric support is not enabled"); + return ncclInvalidArgument; + } + + // Check if user context is valid + if (info->ctx < 0 || info->ctx >= comm->config.numRmaCtx) { + WARN("User context index %d out of bounds (min: 0, max: %d)", info->ctx, comm->config.numRmaCtx - 1); + return ncclInvalidArgument; + } + + // Initialize window pointers - only needed for Put and Signal + struct ncclDevrWindow* peerWinHost = NULL; + struct ncclDevrWindow* srcWinHost = NULL; + size_t srcWinOffset = 0; + + if (info->coll == ncclFuncPut) { + // Validate peer window with detailed debugging + if (info->peerWin == NULL) { + WARN("ncclPut: peerWin is NULL"); + return ncclInvalidArgument; + } + + struct ncclWindow_vidmem* peerWinDevHost = NULL; + NCCLCHECK(ncclShadowPoolToHost(&comm->devrState.shadows, info->peerWin, &peerWinDevHost)); + peerWinHost = (struct ncclDevrWindow*)peerWinDevHost->winHost; + + // Validate source buffer and window + if (srcBuff == NULL) { + WARN("ncclPut: srcBuff is NULL"); + return ncclInvalidArgument; + } + NCCLCHECK(ncclDevrFindWindow(comm, srcBuff, &srcWinHost)); + if (srcWinHost == NULL || !(srcWinHost->winFlags & NCCL_WIN_COLL_SYMMETRIC)) { + WARN("ncclPut: srcWinHost is not in a valid symmetric window"); + return ncclInvalidArgument; + } + srcWinOffset = (char*)srcBuff - (char*)srcWinHost->userPtr; + } + else if (info->coll == ncclFuncSignal) { + // Check if count is valid + if (info->count != 0) { + WARN("ncclSignal: count must be 0"); + return ncclInvalidArgument; + } + // Check if signalMode is valid + if (info->signalMode == NCCL_SIGNAL_NONE) { + WARN("ncclSignal: signalMode is none"); + return ncclInvalidArgument; + } + } + else if (info->coll == ncclFuncWaitSignal) { + // Check if signalMode, peers and nsignals are valid + if (info->signalMode == NCCL_SIGNAL_NONE || 
info->peers == NULL || info->nsignals == NULL || info->npeers == 0) { + WARN("ncclWaitSignal: invalid arguments"); + return ncclInvalidArgument; + } + } + + // Check if RMA CE needs initialization + if (!comm->rmaState.rmaCeState.initialized && ncclIntruQueueEmpty(&comm->rmaCeInitTaskQueue)) { + struct ncclRmaCeInitTask* ceTask; + NCCLCHECK(ncclCalloc(&ceTask, 1)); + ceTask->comm = comm; + ncclIntruQueueEnqueue(&comm->rmaCeInitTaskQueue, ceTask); + ncclGroupCommJoin(comm, ncclGroupTaskTypeSymRegister); + } + + // Must be in thread local group before tasks can be alloc'd in `comm->memScoped`. + ncclGroupCommJoin(info->comm, ncclGroupTaskTypeCollective); + NCCLCHECK(ncclPlannerSetCapturingGraph(comm, info)); + struct ncclTaskRma* t = ncclMemoryPoolAlloc(&comm->memPool_ncclTaskRma, &comm->memPermanent); + + t->func = info->coll; + t->srcBuff = srcBuff; + t->srcWinOffset = srcWinOffset; + t->srcWinHost = srcWinHost; + t->count = info->count; + t->datatype = info->datatype; + t->bytes = t->count * ncclTypeSize(t->datatype); + t->ctx = info->ctx; + t->peer = info->root; + t->peerWinOffset = info->peerWinOffset; + t->peerWinHost = peerWinHost; + t->signalMode = info->signalMode; + + // Copy the peers and nsignals arrays if present + if (info->peers != NULL && info->nsignals != NULL && info->npeers > 0) { + int* peersCopy = ncclMemoryStackAlloc(&comm->memScoped, info->npeers); + int* nsignalsCopy = ncclMemoryStackAlloc(&comm->memScoped, info->npeers); + for (int i = 0; i < info->npeers; i++) { + peersCopy[i] = info->peers[i]; + nsignalsCopy[i] = info->nsignals[i]; + } + t->peers = peersCopy; + t->nsignals = nsignalsCopy; + } else { + t->peers = info->peers; + t->nsignals = info->nsignals; + } + t->npeers = info->npeers; + + t->eActivationMask = __atomic_load_n(&ncclProfilerEventMask, __ATOMIC_RELAXED); + + planner->nTasksRma++; + // Enqueue the task into the appropriate context queue + ncclIntruQueueEnqueue(&planner->rmaTaskQueues[t->ctx], t); + + return ncclSuccess; +} + // Converts `info` to a task and adds it to `comm->planner`. The exception is with // single rank communicators, collectives are issued as `ncclMemcpyAsync`s and // thus don't need a task. @@ -2550,6 +2677,8 @@ static ncclResult_t taskAppend(struct ncclComm* comm, struct ncclInfo* info) { if (info->coll == ncclFuncSend || info->coll == ncclFuncRecv) { NCCLCHECK(p2pTaskAppend(comm, info, info->coll, collAPI, (void*)info->recvbuff, info->count, info->datatype, info->root)); + } else if (info->coll == ncclFuncPut || info->coll == ncclFuncSignal || info->coll == ncclFuncWaitSignal) { + NCCLCHECK(rmaTaskAppend(comm, info)); } else { // Empty collectives can be discarded. 
if (info->count == 0) return ncclSuccess; diff --git a/src/group.cc b/src/group.cc index aa2824412..f1d8eeaf7 100644 --- a/src/group.cc +++ b/src/group.cc @@ -14,6 +14,7 @@ #include "ce_coll.h" #include "profiler.h" #include "nvtx.h" +#include "rma/rma.h" #define GROUP_MAX_RECLAIM_STEPS 10 @@ -250,6 +251,12 @@ ncclResult_t ncclCommGroupRegisterSymmetric(struct ncclAsyncJob* job_) { free(task); } + while (!ncclIntruQueueEmpty(&comm->rmaCeInitTaskQueue)) { + struct ncclRmaCeInitTask* task = ncclIntruQueueDequeue(&comm->rmaCeInitTaskQueue); + NCCLCHECKGOTO(ncclRmaCeInit(task->comm), ret, fail); + free(task); + } + exit: return ret; fail: @@ -305,6 +312,8 @@ static ncclResult_t doLaunches(struct ncclComm* head) { NCCLCHECKGOTO(ncclLaunchKernelBefore_NoUncapturedCuda(comm, plan), result, failure); if (plan->isCeColl) { NCCLCHECKGOTO(ncclLaunchCeColl(comm, plan), result, failure); + } else if (plan->isRma) { + NCCLCHECKGOTO(ncclLaunchRma(comm, plan), result, failure); } else { NCCLCHECKGOTO(ncclLaunchKernel(comm, plan), result, failure); } @@ -372,9 +381,20 @@ static void groupCleanup(struct ncclComm** groupCommHeadPtr, struct ncclIntruQue { // Reset comm->planner to empty. ncclKernelPlanner::Peer* tmp = comm->planner.peers; + ncclIntruQueue* tmpRmaQueues = comm->planner.rmaTaskQueues; + int numRmaCtx = comm->config.numRmaCtx; + memset(&comm->planner, 0, sizeof(comm->planner)); + comm->planner.peers = tmp; if (comm->planner.peers != NULL) memset(comm->planner.peers, 0, comm->nRanks * sizeof(comm->planner.peers[0])); + + comm->planner.rmaTaskQueues = tmpRmaQueues; + if (comm->planner.rmaTaskQueues != NULL) { + for (int i = 0; i < numRmaCtx; i++) { + ncclIntruQueueConstruct(&comm->planner.rmaTaskQueues[i]); + } + } } } diff --git a/src/include/comm.h b/src/include/comm.h index e1b37db16..9a05f38b6 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -21,6 +21,7 @@ #include "dev_runtime.h" #include "sym_kernels.h" #include "ce_coll.h" +#include "rma/rma.h" #if CUDART_VERSION < 9000 struct cudaLaunchParams { @@ -249,6 +250,36 @@ struct ncclTaskP2p { uint8_t nChannels; }; +struct ncclTaskRma { + struct ncclTaskRma* next; + ncclFunc_t func; + int ctx; + size_t count; + ncclDataType_t datatype; + size_t bytes; + + void const* srcBuff; + size_t srcWinOffset; + struct ncclDevrWindow* srcWinHost; + + int peer; + size_t peerWinOffset; + struct ncclDevrWindow* peerWinHost; + + // Signal operations + ncclSignalMode_t signalMode; + int*peers; + int*nsignals; + int npeers; + + // Profiler plugin + int eActivationMask; + void* groupApiEventHandle; + void* rmaApiEventHandle; + void* eventHandle; + uint8_t nChannels; +}; + struct ncclKernelPlan { // A kernel plan is also a callback that reclaims itself. Hence this must // be the first member. 
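For reference, which fields of the new ncclTaskRma struct above each operation actually populates (summarized from rmaTaskAppend in enqueue.cc; illustrative, not normative):

  /* Field usage per RMA operation:
   *   ncclFuncPut        : srcBuff, srcWinOffset, srcWinHost, count, datatype, bytes,
   *                        peer, peerWinOffset, peerWinHost, signalMode, ctx
   *   ncclFuncSignal     : peer, signalMode, ctx          (count == 0, no windows)
   *   ncclFuncWaitSignal : peers[], nsignals[], npeers, signalMode, ctx
   */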
@@ -261,6 +292,7 @@ struct ncclKernelPlan { bool isHostCbEnq; bool isSymColl; bool isCeColl; + bool isRma; enum ncclDevWorkStorageType workStorageType; bool kernelSpecialized; void* kernelFn; @@ -268,6 +300,7 @@ struct ncclKernelPlan { struct ncclDevKernelArgs* kernelArgs; void* kernelSymArgs; struct ncclCeCollArgs* ceCollArgs; + struct ncclRmaArgs* rmaArgs; }; size_t kernelArgsSize; uint64_t channelMask; // bitset of which channels are present @@ -282,6 +315,8 @@ struct ncclKernelPlan { void* workBufPersistent; struct ncclIntruQueue p2pTaskQueue; + struct ncclIntruQueue rmaTaskQueueProxy; + struct ncclIntruQueue rmaTaskQueueCe; struct ncclIntruQueue collTaskQueue; struct ncclIntruQueue proxyOpQueue; @@ -377,7 +412,7 @@ struct ncclKernelPlanner { }; struct ncclTaskCollSorter collSorter; struct Peer* peers/*[nRanks]*/; - int nTasksColl, nTasksP2p; + int nTasksColl, nTasksP2p, nTasksRma; int nTasksP2pSend, nTasksP2pRecv; bool persistent; // The list of user streams aggregated over all tasks present. @@ -396,6 +431,7 @@ struct ncclKernelPlanner { struct ncclIntruQueue collTaskQueue; struct ncclIntruQueue collCeTaskQueue; + struct ncclIntruQueue *rmaTaskQueues; // Per-context queue for RMA tasks struct ncclIntruQueue collSymTaskQueue; struct ncclIntruQueue collWorkQueue; struct ncclIntruQueue tmpCollWorkQueue; @@ -603,8 +639,10 @@ struct ncclComm { // pools backed by comm->memPermanent struct ncclMemoryPool memPool_ncclTaskColl; struct ncclMemoryPool memPool_ncclTaskP2p; + struct ncclMemoryPool memPool_ncclTaskRma; struct ncclMemoryPool memPool_ncclProxyOp; struct ncclMemoryPool memPool_ncclKernelPlan; + struct ncclMemoryPool memPool_ncclRmaProxyDesc; // Next comm in this thread's active ncclGroup[Start|End](). Holds "0x1" when // this comm is not yet in a group. 
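Taken together, the planner and plan members added in this header drive the following flow (assembled from enqueue.cc, rma.cc and group.cc in this patch; a sketch, not a contract):

  /* ncclPut / ncclSignal / ncclWaitSignal
   *   -> rmaTaskAppend(): allocate an ncclTaskRma from memPool_ncclTaskRma, push it
   *      onto planner->rmaTaskQueues[ctx], bump planner->nTasksRma.
   *   -> scheduleRmaTasksToPlan(): drain one context queue into a plan marked isRma,
   *      splitting tasks between plan->rmaTaskQueueCe (LSA-reachable peers) and
   *      plan->rmaTaskQueueProxy (network peers).
   *   -> ncclLaunchRma(): run the CE queue with cudaMemcpyAsync / stream mem ops and
   *      hand the proxy queue to the GIN-based RMA proxy.
   */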
@@ -654,6 +692,10 @@ struct ncclComm { uint64_t seqNumber[NCCL_NUM_FUNCTIONS]; struct ncclProfilerProxy profiler; + // RMA state + struct ncclRmaState rmaState; + struct ncclIntruQueue rmaCeInitTaskQueue; + // CE Collective struct ncclCeColl ceColl; struct ncclIntruQueue ceInitTaskQueue; diff --git a/src/include/dev_runtime.h b/src/include/dev_runtime.h index 70bf77496..38dfc92cf 100644 --- a/src/include/dev_runtime.h +++ b/src/include/dev_runtime.h @@ -53,6 +53,7 @@ struct ncclDevrState { size_t granularity; // cuMemGetAllocationGranularity bool ginEnabled; + bool rmaProxyEnabled; struct ncclDevrMemory* memHead; struct ncclDevrWindowSorted* winSorted; int winSortedCapacity, winSortedCount; @@ -88,6 +89,9 @@ void freeDevCommRequirements( // Get the corresponding pointer in another lsa rank's symmetric memory window ncclResult_t ncclDevrGetLsaRankPtr(struct ncclComm* comm, struct ncclDevrWindow* winHost, size_t offset, int lsaRank, void** outPtr); +// Get the RMA device window handle for a specific context +ncclGinWindow_t ncclDevrGetRmaDevWin(struct ncclDevrWindow* winHost, int ctx); + // Get the multicast address for a given team ncclResult_t ncclDevrGetLsaTeamPtrMC(struct ncclComm* comm, struct ncclDevrWindow* winHost, size_t offset, struct ncclTeam lsaTeam, void** outPtr); #endif diff --git a/src/include/group.h b/src/include/group.h index 3b08d9f16..447cdebf6 100644 --- a/src/include/group.h +++ b/src/include/group.h @@ -116,8 +116,16 @@ inline void ncclGroupCommJoin(struct ncclComm* comm, int type) { if (type == ncclGroupTaskTypeCollective) { // Initialize planner ncclKernelPlanner::Peer* tmp = comm->planner.peers; + ncclIntruQueue* tmpRmaQueues = comm->planner.rmaTaskQueues; + int numRmaCtx = comm->config.numRmaCtx; memset(&comm->planner, 0, sizeof(comm->planner)); comm->planner.peers = tmp; + comm->planner.rmaTaskQueues = tmpRmaQueues; + if (comm->planner.rmaTaskQueues != NULL) { + for (int i = 0; i < numRmaCtx; i++) { + ncclIntruQueueConstruct(&comm->planner.rmaTaskQueues[i]); + } + } } } ncclGroupBlocking = comm->config.blocking; diff --git a/src/include/info.h b/src/include/info.h index 3cabae866..28471187a 100644 --- a/src/include/info.h +++ b/src/include/info.h @@ -28,6 +28,14 @@ struct ncclInfo { // Algorithm details int chunkSteps; int sliceSteps; + // One-sided ops + size_t peerWinOffset; + ncclWindow_t peerWin; + ncclSignalMode_t signalMode; + int ctx; + int* peers; + int* nsignals; + int npeers; }; #endif diff --git a/src/include/nccl_common.h b/src/include/nccl_common.h index 0a3842151..c3d0967b7 100644 --- a/src/include/nccl_common.h +++ b/src/include/nccl_common.h @@ -68,7 +68,10 @@ typedef enum { ncclFuncAlltoAll = 8, ncclFuncScatter = 9, ncclFuncGather = 10, - ncclNumFuncs = 11 + ncclFuncPut = 11, + ncclFuncSignal = 12, + ncclFuncWaitSignal = 13, + ncclNumFuncs = 14 } ncclFunc_t; diff --git a/src/include/nvtx.h b/src/include/nvtx.h index c1dbe5fd9..7a28ee14c 100644 --- a/src/include/nvtx.h +++ b/src/include/nvtx.h @@ -38,10 +38,13 @@ #define NVTX_SID_Gather 17 #define NVTX_SID_Scatter 18 #define NVTX_SID_CommRevoke 19 // same schema as NVTX_SID_CommInitRank +#define NVTX_SID_Put 20 +#define NVTX_SID_Signal 21 +#define NVTX_SID_WaitSignal 22 // When adding new schema IDs, DO NOT re-use/overlap with the enum schema ID below! // Define static schema ID for the reduction operation. 
-#define NVTX_PAYLOAD_ENTRY_NCCL_REDOP 20 + NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START +#define NVTX_PAYLOAD_ENTRY_NCCL_REDOP 23 + NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START extern const nvtxDomainHandle_t ncclNvtxDomainHandle; diff --git a/src/include/nvtx_payload_schemas.h b/src/include/nvtx_payload_schemas.h index 9a47fbe86..6d4970c55 100644 --- a/src/include/nvtx_payload_schemas.h +++ b/src/include/nvtx_payload_schemas.h @@ -156,4 +156,29 @@ NCCL_NVTX_DEFINE_STRUCT_WITH_SCHEMA_ENTRIES(NcclNvtxParamsSendRecv, static const ) ) +NCCL_NVTX_DEFINE_STRUCT_WITH_SCHEMA_ENTRIES(NcclNvtxParamsPut, static constexpr, + NCCL_NVTX_PAYLOAD_ENTRIES( + (uint64_t, comm, TYPE_UINT64, nccl_nvtxCommStr), + (size_t, bytes, TYPE_SIZE, nccl_nvtxMsgSizeStr), + (int, peer, TYPE_INT, "Peer rank"), + (int, ctx, TYPE_INT, "Context ID") + ) +) + +NCCL_NVTX_DEFINE_STRUCT_WITH_SCHEMA_ENTRIES(NcclNvtxParamsSignal, static constexpr, + NCCL_NVTX_PAYLOAD_ENTRIES( + (uint64_t, comm, TYPE_UINT64, nccl_nvtxCommStr), + (int, peer, TYPE_INT, "Peer rank"), + (int, ctx, TYPE_INT, "Context ID") + ) +) + +NCCL_NVTX_DEFINE_STRUCT_WITH_SCHEMA_ENTRIES(NcclNvtxParamsWaitSignal, static constexpr, + NCCL_NVTX_PAYLOAD_ENTRIES( + (uint64_t, comm, TYPE_UINT64, nccl_nvtxCommStr), + (int, npeers, TYPE_INT, "Number of peers"), + (int, ctx, TYPE_INT, "Context ID") + ) +) + #endif // end include guard diff --git a/src/include/rma/rma.h b/src/include/rma/rma.h new file mode 100644 index 000000000..744be235f --- /dev/null +++ b/src/include/rma/rma.h @@ -0,0 +1,34 @@ +/************************************************************************* + * Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#ifndef _NCCL_RMA_H_ +#define _NCCL_RMA_H_ + +#include "nccl.h" +#include "nccl_common.h" +#include "rma/rma_ce.h" +#include "rma/rma_proxy.h" + +struct ncclRmaArgs{ + int ctx; + ncclFunc_t func; + int nRmaTasks; + int nRmaTasksProxy; + int nRmaTasksCe; +}; + +struct ncclRmaState { + struct ncclRmaProxyState rmaProxyState; + struct ncclRmaCeState rmaCeState; +}; + +// Main RMA function declarations +ncclResult_t scheduleRmaTasksToPlan(struct ncclComm* comm, struct ncclKernelPlan* plan); +ncclResult_t ncclLaunchRma(struct ncclComm* comm, struct ncclKernelPlan* plan); +ncclResult_t ncclRmaWaitSignal(struct ncclComm* comm, struct ncclKernelPlan* plan, cudaStream_t stream); +ncclResult_t ncclRmaPut(struct ncclComm* comm, struct ncclKernelPlan* plan, cudaStream_t stream); + +#endif diff --git a/src/include/rma/rma_ce.h b/src/include/rma/rma_ce.h new file mode 100644 index 000000000..612fb07d2 --- /dev/null +++ b/src/include/rma/rma_ce.h @@ -0,0 +1,50 @@ +/************************************************************************* + * Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved. 
+ * + * See LICENSE.txt for license information + ************************************************************************/ + +#ifndef _NCCL_RMA_CE_H_ +#define _NCCL_RMA_CE_H_ + +#include "nccl.h" +#include "nccl_common.h" +#include "dev_runtime.h" + +struct ncclComm; +struct ncclRmaArgs; + +struct ncclRmaCeInitTask { + struct ncclRmaCeInitTask *next; + struct ncclComm* comm; +}; + +struct ncclRmaCeCtx { + struct ncclComm *comm; + + // Per-rank sequence number for the signal operations + uint64_t* signalOpSeqs; + + // Signal memory layout and management + // Each RMA context allocates a signal buffer with the following layout: + // - Offsets [0 to nRanks*8-1]: per-rank distinct signals (8 bytes per rank) + // - Offset [nRanks*8]: shared aggregate signal counter (8 bytes) + // Total signal buffer size: (nRanks + 1) * 8 bytes + struct ncclDevrWindow* signalsWin; + uint64_t *signalsDev; + uint64_t* signalsHost; // Host buffer to track the expected values of the signals +}; + + +struct ncclRmaCeState { + bool initialized; + int rmaCeCtxCount; + void** rmaCeCtxs; +}; + +// CE-specific function declarations +ncclResult_t ncclRmaCeInit(struct ncclComm* comm); +ncclResult_t ncclRmaCeFinalize(struct ncclComm* comm); +ncclResult_t ncclRmaPutCe(struct ncclComm* comm, struct ncclKernelPlan* plan, cudaStream_t stream); +ncclResult_t ncclRmaWaitSignalCe(struct ncclComm* comm, struct ncclKernelPlan* plan, cudaStream_t stream); +#endif diff --git a/src/include/rma/rma_proxy.h b/src/include/rma/rma_proxy.h new file mode 100644 index 000000000..4a9277509 --- /dev/null +++ b/src/include/rma/rma_proxy.h @@ -0,0 +1,139 @@ +/************************************************************************* + * Copyright (c) 2016-2022, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +#ifndef _NCCL_RMA_PROXY_H_ +#define _NCCL_RMA_PROXY_H_ + +#include "nccl.h" +#include "nccl_net.h" +#include "nccl_common.h" +#include "gin/gin_host.h" +#include "alloc.h" + +struct ncclComm; +struct ncclRmaArgs; + +struct ncclRmaSignal_t { + void *signalMhandle; + uint64_t offset; + uint64_t val; + uint32_t op; +}; + +typedef enum ncclRmaDescState_t { + ncclRmaDescStatePending = 0, + ncclRmaDescStateInProgress, +} ncclRmaDescState_t; + +struct ncclRmaProxyDesc { + struct ncclRmaProxyDesc *next; + + // Network function descriptor + uint64_t srcOff; + void *srcHandle; + uint64_t dstOff; + void *dstHandle; + size_t size; + int targetRank; + ncclRmaSignal_t signal; + + // Sequence number for the network operation + uint64_t seq; + + // State of the network function descriptor + ncclRmaDescState_t rmaDescState; + + // Request handle for the network operation + void * request; +}; + +struct ncclRmaProxyCtx { + struct ncclComm *comm; + + // GIN context for the RMA proxy context + void *ginCollComm; + ncclNetDeviceHandle_t *devHandle; + ncclNetProperties_t props; + + // Per-rank rmaProxyDescQueues: Pending Descs waiting for readySeq to be ready + struct ncclIntruQueue* rmaProxyDescQueues; + // Per-rank rmaProxyInProgressQueues: Descs with issued network operations waiting for completion + struct ncclIntruQueue* rmaProxyInProgressQueues; + // Mutex to protect Desc queue access (user thread enqueues, progress thread dequeues) + pthread_mutex_t* DescQueueLocks; + + // Per-rank sequence number and counters + uint64_t* opSeqs; + uint64_t* opSeqsDev; + void* opSeqsGdrHandle; + uint64_t* readySeqs; + uint64_t* readySeqsDev; + void* 
readySeqsGdrHandle; + uint64_t* doneSeqs; + uint64_t* doneSeqsDev; + void* doneSeqsGdrHandle; + + // Signal memory layout and management + // Each RMA context allocates a signal buffer with the following layout: + // - Offsets [0 to nRanks*8-1]: per-rank distinct signals (8 bytes per rank) + // - Offset [nRanks*8]: shared aggregate signal counter (8 bytes) + // Total signal buffer size: (nRanks + 1) * 8 bytes + CUmemGenericAllocationHandle signalsCumemhandle; + void *signalsMhandle; + void *signalsGinHandle; + uint64_t *signalsDev; + uint64_t* signalsHost; // Host buffer to track the expected values of the signals +}; + +struct ncclRmaProxyState { + struct ncclComm *comm; + ncclGin_t* ncclGin; + void* ginInstance; + bool connected; + int ginType; + + // Physical GIN communicator contexts + int ginCommCount; + void* ginComms[NCCL_GIN_MAX_CONTEXTS]; + ncclNetProperties_t props[NCCL_GIN_MAX_CONTEXTS]; + + // Virtual RMA proxy contexts + int rmaProxyCtxCount; + void** rmaProxyCtxs; + ncclNetDeviceHandle_t** rmaProxyDevHandles; + + int needsProxyProgress; // Whether we need to progress GIN operations with the proxy + int ginProgress; // GIN progress is enabled + pthread_t thread; + pthread_mutex_t threadLock; + pthread_cond_t threadCond; + ncclResult_t asyncResult; +}; + +// Proxy-specific function declarations +ncclResult_t ncclRmaPutProxy(struct ncclComm* comm, struct ncclKernelPlan* plan, cudaStream_t stream); +ncclResult_t ncclRmaWaitSignalProxy(struct ncclComm* comm, struct ncclKernelPlan* plan, cudaStream_t stream); + +// RMA Proxy lifecycle functions +ncclResult_t ncclRmaProxyConnectOnce(struct ncclComm* comm); +ncclResult_t ncclRmaProxyFinalize(struct ncclComm* comm); + +// RMA Proxy context management +ncclResult_t ncclRmaProxyCreateContext(struct ncclComm *comm, void *collComm, ncclNetProperties_t props, + void **outRmaProxyCtx, ncclNetDeviceHandle_t **outDevHandle); +ncclResult_t ncclRmaProxyDestroyContext(ncclGin_t* ginComm, void* ginCtx); +ncclResult_t ncclRmaProxyProgress(ncclGin_t* ncclGin, void* ginCtx); + +// RMA Proxy memory registration +ncclResult_t ncclRmaProxyRegister(struct ncclComm* comm, void* address, size_t size, + void* rmaHostWins[NCCL_GIN_MAX_CONTEXTS], + ncclGinWindow_t rmaDevWins[NCCL_GIN_MAX_CONTEXTS]); +ncclResult_t ncclRmaProxyDeregister(struct ncclComm* comm, void* rmaHostWins[NCCL_GIN_MAX_CONTEXTS]); + +// Progress thread function +void* ncclRmaProxyProgressThread(void* rmaProxyState_); + +#endif diff --git a/src/init.cc b/src/init.cc index 8e8b0fdaa..6b3ef7fc3 100644 --- a/src/init.cc +++ b/src/init.cc @@ -36,6 +36,7 @@ #include "ce_coll.h" #include "nvtx.h" #include "env.h" +#include "rma/rma.h" #define STR2(v) #v #define STR(v) STR2(v) @@ -59,6 +60,7 @@ NCCL_PARAM(WinEnable, "WIN_ENABLE", 1); NCCL_PARAM(CollnetEnable, "COLLNET_ENABLE", NCCL_CONFIG_UNDEF_INT); NCCL_PARAM(CtaPolicy, "CTA_POLICY", NCCL_CONFIG_UNDEF_INT); NCCL_PARAM(NvlsChannels, "NVLS_NCHANNELS", NCCL_CONFIG_UNDEF_INT); +NCCL_PARAM(NumRmaCtx, "NUM_RMA_CTX", NCCL_CONFIG_UNDEF_INT); NCCL_PARAM(SetCpuStackSize, "SET_CPU_STACK_SIZE", 1); extern int64_t ncclParamSingleProcMemRegEnable(); @@ -241,6 +243,7 @@ static ncclResult_t commFree(ncclComm_t comm) { return ncclSuccess; NCCLCHECK(ncclCeFinalize(comm)); + NCCLCHECK(ncclRmaCeFinalize(comm)); if (comm->symmetricSupport) { NCCLCHECK(ncclSymkFinalize(comm)); @@ -286,6 +289,7 @@ static ncclResult_t commFree(ncclComm_t comm) { // GIN may use proxy. We need to finalize it before destroying the proxy. 
NCCLCHECK(ncclGinFinalize(comm)); + NCCLCHECK(ncclRmaProxyFinalize(comm)); int sharedResRefCount = 0; if (comm->sharedRes) { @@ -456,6 +460,7 @@ static ncclResult_t commAlloc(struct ncclComm* comm, struct ncclComm* parent, in ncclMemoryPoolConstruct(&comm->memPool_ncclKernelPlan); ncclMemoryPoolConstruct(&comm->memPool_ncclProxyOp); + ncclMemoryPoolConstruct(&comm->memPool_ncclRmaProxyDesc); for (int i = 0; i < ncclGroupTaskTypeNum; i++) { comm->groupNext[i] = reinterpret_cast(0x1); @@ -1266,6 +1271,15 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p comm->planner.peers = ncclMemoryStackAlloc(&comm->memPermanent, comm->nRanks); NCCLCHECK(ncclP2pSchedule(comm)); + if (comm->config.numRmaCtx > 0) { + comm->planner.rmaTaskQueues = ncclMemoryStackAlloc>(&comm->memPermanent, comm->config.numRmaCtx); + for (int i = 0; i < comm->config.numRmaCtx; i++) { + ncclIntruQueueConstruct(&comm->planner.rmaTaskQueues[i]); + } + } else { + comm->planner.rmaTaskQueues = NULL; + } + comm->runtimeConn = comm->cuMemSupport && ncclParamRuntimeConnect(); if (comm->runtimeConn) { for (int c=0; cnChannels; c++) { @@ -1642,6 +1656,7 @@ static ncclResult_t envConfigOverride(ncclComm_t comm) { int nvlsCTAsEnv; int nChannelsPerNetPeerEnv; int nvlinkUtilCentricSchedEnableEnv; + int numRmaCtxEnv; /* override configuration with env variable. */ blockingEnv = ncclParamCommBlocking(); @@ -1703,6 +1718,14 @@ static ncclResult_t envConfigOverride(ncclComm_t comm) { } } + numRmaCtxEnv = ncclParamNumRmaCtx(); + if (numRmaCtxEnv != NCCL_CONFIG_UNDEF_INT) { + if (numRmaCtxEnv <= 0) + INFO(NCCL_ENV, "NCCL_NUM_RMA_CTX %d is too low, leaving it set at %d", numRmaCtxEnv, comm->config.numRmaCtx); + else + comm->config.numRmaCtx = numRmaCtxEnv; + } + envNetName = ncclGetEnv("NCCL_NET"); if (envNetName) tmpNetName = envNetName; @@ -1847,6 +1870,9 @@ static ncclResult_t parseCommConfig(ncclComm_t comm, ncclConfig_t *config) { internalConfigPtr->nChannelsPerNetPeer = defaultConfig.nChannelsPerNetPeer; internalConfigPtr->nvlinkCentricSched = defaultConfig.nvlinkCentricSched; } + if (internalConfigPtr->version < NCCL_VERSION(2, 29, 0)) { + internalConfigPtr->numRmaCtx = defaultConfig.numRmaCtx; + } } /* check input config attributes, -1 means user-undefined and we should use default value from NCCL. */ @@ -1915,6 +1941,12 @@ static ncclResult_t parseCommConfig(ncclComm_t comm, ncclConfig_t *config) { goto fail; } + if (internalConfigPtr->numRmaCtx != NCCL_CONFIG_UNDEF_INT && internalConfigPtr->numRmaCtx <= 0) { + WARN("Invalid config numRmaCtx attribute value %d", internalConfigPtr->numRmaCtx); + ret = ncclInvalidArgument; + goto fail; + } + /* default config value can be tuned on different platform. 
*/ NCCL_CONFIG_DEFAULT(internalConfigPtr, blocking, NCCL_CONFIG_UNDEF_INT, 1, "Blocking", "%d"); NCCL_CONFIG_DEFAULT(internalConfigPtr, cgaClusterSize, NCCL_CONFIG_UNDEF_INT, 4, "CGA cluster size", "%d"); @@ -1931,6 +1963,7 @@ static ncclResult_t parseCommConfig(ncclComm_t comm, ncclConfig_t *config) { NCCL_CONFIG_DEFAULT(internalConfigPtr, nChannelsPerNetPeer, NCCL_CONFIG_UNDEF_INT, NCCL_CONFIG_UNDEF_INT, "nChannelsPerNetPeer", "%d"); NCCL_CONFIG_DEFAULT(internalConfigPtr, nvlinkCentricSched, NCCL_CONFIG_UNDEF_INT, 0, "nvlinkCentricSched", "%d"); + NCCL_CONFIG_DEFAULT(internalConfigPtr, numRmaCtx, NCCL_CONFIG_UNDEF_INT, 1, "numRmaCtx", "%d"); /* assign config to communicator */ comm->config.blocking = internalConfigPtr->blocking; @@ -1947,6 +1980,7 @@ static ncclResult_t parseCommConfig(ncclComm_t comm, ncclConfig_t *config) { comm->config.nvlsCTAs = internalConfigPtr->nvlsCTAs; comm->config.nChannelsPerNetPeer = internalConfigPtr->nChannelsPerNetPeer; comm->config.nvlinkCentricSched = internalConfigPtr->nvlinkCentricSched; + comm->config.numRmaCtx = internalConfigPtr->numRmaCtx; NCCLCHECKGOTO(envConfigOverride(comm), ret, fail); exit: diff --git a/src/nccl.h.in b/src/nccl.h.in index 61de6b800..57c551471 100644 --- a/src/nccl.h.in +++ b/src/nccl.h.in @@ -94,6 +94,7 @@ typedef struct ncclConfig_v22800 { int nvlsCTAs; int nChannelsPerNetPeer; int nvlinkCentricSched; + int numRmaCtx; } ncclConfig_t; /* Config initializer must be assigned to initialize config structure when it is created. @@ -116,6 +117,7 @@ typedef struct ncclConfig_v22800 { NCCL_CONFIG_UNDEF_INT, /* nvlsCTAs */ \ NCCL_CONFIG_UNDEF_INT, /* nChannelsPerNetPeer */ \ NCCL_CONFIG_UNDEF_INT, /* nvlinkCentricSched */ \ + NCCL_CONFIG_UNDEF_INT, /* numRmaCtx */ \ } /* This struct will be used by ncclGroupSimulateEnd() API to query information about simulation. */ @@ -526,6 +528,90 @@ ncclResult_t pncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, in ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer, ncclComm_t comm, cudaStream_t stream); + +typedef enum { + NCCL_SIGNAL_NONE = 0, // No signaling + NCCL_SIGNAL_AGGREGATE = 1, // Signals can be aggregated/merged across peers + NCCL_SIGNAL_DISTINCT = 2 // Signals must remain distinct per-peer +} ncclSignalMode_t; + + +/* + * Put + * + * One-sided communication operation that writes data from the local buffer to a + * remote peer's registered memory window without explicit participation from the + * target process. 
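 *
 * The local buffer must itself reside in a window registered with the
 * NCCL_WIN_COLL_SYMMETRIC flag; the enqueue path rejects source buffers that
 * are not backed by such a window. A minimal call sketch, assuming sendbuf and
 * the peer window win were set up by the caller:
 *
 *   ncclPut(0, sendbuf, count, ncclFloat, peer, 0, win,
 *           NCCL_SIGNAL_DISTINCT, comm, stream);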
+ * + * Parameters: + * ctx - Context identifier for the operation + * localbuff - Local source buffer containing data to be transferred + * count - Number of elements to transfer + * datatype - NCCL data type of each element + * peer - Target rank to write data to + * peerWinOffset- Offset in bytes from the start of peer's registered window + * peerWin - Memory window object registered by the target peer + * signalMode - Signaling behavior: + * NCCL_SIGNAL_NONE: No signaling after put operation + * NCCL_SIGNAL_AGGREGATE: Signal can be merged with others (use for barrier-like patterns) + * NCCL_SIGNAL_DISTINCT: Signal must remain separate per-peer (use when peer identity matters) + * comm - NCCL communicator + * stream - CUDA stream to enqueue the operation on + * + * Returns: + * ncclSuccess on successful enqueue, error code otherwise + */ +ncclResult_t ncclPut(int ctx, const void* localbuff, size_t count, ncclDataType_t datatype, int peer, + size_t peerWinOffset, ncclWindow_t peerWin, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream); + +ncclResult_t pncclPut(int ctx, const void* localbuff, size_t count, ncclDataType_t datatype, + int peer, size_t peerWinOffset, ncclWindow_t peerWin, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream); + +/* + * Signal + * + * Sends a signal to the specified peer without transferring data. + * + * Parameters: + * ctx - Context identifier for the operation + * peer - Target rank to send signal to + * signalMode - Signaling behavior: + * NCCL_SIGNAL_AGGREGATE: Signal can be merged with others (use for barrier-like patterns) + * NCCL_SIGNAL_DISTINCT: Signal must remain separate per-peer (use when peer identity matters) + * Note: NCCL_SIGNAL_NONE is not valid for explicit signal operations + * comm - NCCL communicator + * stream - CUDA stream to enqueue the operation on + * + * Returns: + * ncclSuccess on successful signal enqueue, error code otherwise + */ +ncclResult_t ncclSignal(int ctx, int peer, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream); +ncclResult_t pncclSignal(int ctx, int peer, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream); + +/* + * Wait Signal + * + * Waits for specified number of signals from each peer. 
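 *
 * For example, waiting until peers 0 and 1 have each delivered one signal on
 * context 0 (a sketch; peers/nsignals are caller-provided arrays):
 *
 *   int peers[2]    = {0, 1};
 *   int nsignals[2] = {1, 1};
 *   ncclWaitSignal(0, peers, nsignals, 2, NCCL_SIGNAL_DISTINCT, comm, stream);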
+ * + * Parameters: + * ctx - Context identifier for the operation + * peers - Array of peer ranks to wait signals from + * nsignals - Array of signal counts, where nsignals[i] is the number of + * signals to wait for from peers[i] + * npeers - Number of peers (length of both peers and nsignals arrays) + * signalMode - Signaling behavior: + * NCCL_SIGNAL_AGGREGATE: Signals within same transport can be merged + * NCCL_SIGNAL_DISTINCT: All signals must remain separate per-peer + * Note: NCCL_SIGNAL_NONE is not valid for wait operations + * comm - NCCL communicator + * stream - CUDA stream to enqueue the operation on + * + * Returns: + * ncclSuccess when all required signals received, error code otherwise + */ +ncclResult_t ncclWaitSignal(int ctx, int* peers, int* nsignals, int npeers, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream); +ncclResult_t pncclWaitSignal(int ctx, int* peers, int* nsignals, int npeers, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream); + /* * Group semantics * diff --git a/src/plugin/net.cc b/src/plugin/net.cc index 6ec428a2d..ec549e459 100644 --- a/src/plugin/net.cc +++ b/src/plugin/net.cc @@ -248,6 +248,7 @@ static ncclResult_t ncclNetPluginAssignToComm(struct ncclComm* comm, int pluginI if (netPluginLibs[pluginIndex].ncclGinPluginState >= ncclNetPluginStateEnabled) { INFO(NCCL_INIT|NCCL_NET, "Assigned GIN plugin %s to comm", netPluginLibs[pluginIndex].ncclGin->name); comm->sharedRes->ginState.ncclGin = netPluginLibs[pluginIndex].ncclGin; + comm->rmaState.rmaProxyState.ncclGin = netPluginLibs[pluginIndex].ncclGin; } } exit: diff --git a/src/rma/CMakeLists.txt b/src/rma/CMakeLists.txt new file mode 100644 index 000000000..609639814 --- /dev/null +++ b/src/rma/CMakeLists.txt @@ -0,0 +1,9 @@ +# RMA sources +set(RMA_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/rma.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rma_proxy.cc + ${CMAKE_CURRENT_SOURCE_DIR}/rma_ce.cc +) + +# Add RMA sources to parent scope +set(RMA_SOURCES ${RMA_SOURCES} PARENT_SCOPE) diff --git a/src/rma/rma.cc b/src/rma/rma.cc new file mode 100644 index 000000000..fff1cd7af --- /dev/null +++ b/src/rma/rma.cc @@ -0,0 +1,240 @@ +#include +#include "nccl.h" +#include "alloc.h" +#include "checks.h" +#include "comm.h" +#include "rma/rma.h" + +static bool isLsaAccessible(struct ncclComm* comm, int rank) { + for (int i = 0; i < comm->devrState.lsaSize; i++) { + if (comm->devrState.lsaRankList[i] == rank) { + return true; + } + } + return false; +} + +ncclResult_t ncclRmaWaitSignal(struct ncclComm* comm, struct ncclKernelPlan* plan, cudaStream_t stream){ + ncclResult_t ret = ncclSuccess; + + if (plan->rmaArgs->nRmaTasksProxy > 0) { + NCCLCHECKGOTO(ncclRmaWaitSignalProxy(comm, plan, stream), ret, fail); + } + + if (plan->rmaArgs->nRmaTasksCe > 0) { + NCCLCHECKGOTO(ncclRmaWaitSignalCe(comm, plan, stream), ret, fail); + } + +exit: + return ret; +fail: + goto exit; +} + + +ncclResult_t ncclRmaPut(struct ncclComm* comm, struct ncclKernelPlan* plan, cudaStream_t stream){ + ncclResult_t ret = ncclSuccess; + + if (plan->rmaArgs->nRmaTasksProxy > 0) { + NCCLCHECKGOTO(ncclRmaPutProxy(comm, plan, stream), ret, fail); + } + + if (plan->rmaArgs->nRmaTasksCe > 0) { + NCCLCHECKGOTO(ncclRmaPutCe(comm, plan, stream), ret, fail); + } + +exit: + return ret; +fail: + goto exit; +} + +ncclResult_t ncclLaunchRma(struct ncclComm* comm, struct ncclKernelPlan* plan) { + ncclResult_t ret = ncclSuccess; + cudaStream_t stream = comm->planner.streams->stream; + + switch (plan->rmaArgs->func) { + case ncclFuncPut: + 
NCCLCHECKGOTO(ncclRmaPut(comm, plan, stream), ret, fail); + break; + case ncclFuncSignal: + NCCLCHECKGOTO(ncclRmaPut(comm, plan, stream), ret, fail); + break; + case ncclFuncWaitSignal: + NCCLCHECKGOTO(ncclRmaWaitSignal(comm, plan, stream), ret, fail); + break; + default: + ret = ncclInvalidUsage; + } + +exit: + return ret; +fail: + goto exit; +} + +static inline bool isRmaPutOrSignal(ncclFunc_t func) { + return (func == ncclFuncPut || func == ncclFuncSignal); +} + +// Check if two RMA tasks can be batched together +static inline bool canBatchRmaTasks(struct ncclTaskRma* task1, struct ncclTaskRma* task2) { + // Check if the tasks are in the same context + if (task1->ctx != task2->ctx) return false; + + // Check if the tasks are the same function + if (task1->func == task2->func) return true; + + // Put/Signal tasks can be batched together + if (isRmaPutOrSignal(task1->func) && isRmaPutOrSignal(task2->func)) { + return true; + } + + return false; +} + +// Schedule comm->planner RMA tasks to the plan and split the RMA tasks into CE and Proxy tasks +// Then seek opportunities to batch tasks, batching checked for consecutive operations targeting the same context +// - ncclFuncWaitSignal does not perform further batching as the API can already batch waitSignal from multiple peers +// - Consecutive put/signal operation can be batched into the same plan +ncclResult_t scheduleRmaTasksToPlan(struct ncclComm* comm, struct ncclKernelPlan* plan) { + ncclResult_t ret = ncclSuccess; + struct ncclKernelPlanner* planner = &comm->planner; + + // Find the first non-empty context queue + int ctx = -1; + for (int i = 0; i < comm->config.numRmaCtx; i++) { + if (!ncclIntruQueueEmpty(&planner->rmaTaskQueues[i])) { + ctx = i; + break; + } + } + + // No RMA tasks to schedule + if (ctx == -1) return ncclSuccess; + + struct ncclIntruQueue* ctxQueue = &planner->rmaTaskQueues[ctx]; + + // Get the first task to determine the operation category + struct ncclTaskRma* firstTask = ncclIntruQueueDequeue(ctxQueue); + + // Initialize plan + plan->isRma = true; + plan->rmaArgs = ncclMemoryStackAlloc(&comm->memScoped); + plan->rmaArgs->ctx = ctx; + plan->rmaArgs->func = firstTask->func; + plan->rmaArgs->nRmaTasks = 0; + plan->rmaArgs->nRmaTasksProxy = 0; + plan->rmaArgs->nRmaTasksCe = 0; + + // WaitSignal tasks + if (firstTask->func == ncclFuncWaitSignal) { + // Allocate temporary arrays to hold peers and nsignals for both proxy and CE paths + int* peersCe = ncclMemoryStackAlloc(&comm->memScoped, firstTask->npeers); + int* nsignalsCe = ncclMemoryStackAlloc(&comm->memScoped, firstTask->npeers); + int* peersProxy = ncclMemoryStackAlloc(&comm->memScoped, firstTask->npeers); + int* nsignalsProxy = ncclMemoryStackAlloc(&comm->memScoped, firstTask->npeers); + + int npeersCe = 0; + int npeersProxy = 0; + + // Go over the firstTask->peers and split them based on LSA accessibility + for (int i = 0; i < firstTask->npeers; i++) { + int peerRank = firstTask->peers[i]; + bool lsaAccessible = isLsaAccessible(comm, peerRank); + + if (lsaAccessible) { + // Add to CE list + peersCe[npeersCe] = peerRank; + nsignalsCe[npeersCe] = firstTask->nsignals[i]; + npeersCe++; + } else { + // Add to Proxy list + peersProxy[npeersProxy] = peerRank; + nsignalsProxy[npeersProxy] = firstTask->nsignals[i]; + npeersProxy++; + } + } + + // Initialize the CE task if there are CE peers + if (npeersCe > 0) { + struct ncclTaskRma* waitSignalTaskCe = ncclMemoryPoolAlloc(&comm->memPool_ncclTaskRma, &comm->memPermanent); + waitSignalTaskCe->func = ncclFuncWaitSignal; + 
waitSignalTaskCe->ctx = firstTask->ctx; + waitSignalTaskCe->signalMode = firstTask->signalMode; + waitSignalTaskCe->peers = peersCe; + waitSignalTaskCe->nsignals = nsignalsCe; + waitSignalTaskCe->npeers = npeersCe; + ncclIntruQueueEnqueue(&plan->rmaTaskQueueCe, waitSignalTaskCe); + plan->rmaArgs->nRmaTasksCe = 1; + } else { + plan->rmaArgs->nRmaTasksCe = 0; + } + + // Initialize the Proxy task if there are Proxy peers + if (npeersProxy > 0) { + struct ncclTaskRma* waitSignalTaskProxy = ncclMemoryPoolAlloc(&comm->memPool_ncclTaskRma, &comm->memPermanent); + waitSignalTaskProxy->func = ncclFuncWaitSignal; + waitSignalTaskProxy->ctx = firstTask->ctx; + waitSignalTaskProxy->signalMode = firstTask->signalMode; + waitSignalTaskProxy->peers = peersProxy; + waitSignalTaskProxy->nsignals = nsignalsProxy; + waitSignalTaskProxy->npeers = npeersProxy; + ncclIntruQueueEnqueue(&plan->rmaTaskQueueProxy, waitSignalTaskProxy); + plan->rmaArgs->nRmaTasksProxy = 1; + } else { + plan->rmaArgs->nRmaTasksProxy = 0; + } + + plan->rmaArgs->nRmaTasks = (npeersCe > 0 ? 1 : 0) + (npeersProxy > 0 ? 1 : 0); + planner->nTasksRma -= 1; + // Free the original WaitSignal task (split into CE and Proxy tasks) + ncclMemoryPoolFree(&comm->memPool_ncclTaskRma, firstTask); + } + // Put/Signal tasks + else { + // Check if the first task is LSA accessible + bool lsaAccessible = isLsaAccessible(comm, firstTask->peer); + + plan->rmaArgs->nRmaTasks = 1; + plan->rmaArgs->nRmaTasksProxy = lsaAccessible ? 0 : 1; + plan->rmaArgs->nRmaTasksCe = lsaAccessible ? 1 : 0; + + if (lsaAccessible) { + ncclIntruQueueEnqueue(&plan->rmaTaskQueueCe, firstTask); + } else { + ncclIntruQueueEnqueue(&plan->rmaTaskQueueProxy, firstTask); + } + + planner->nTasksRma -= 1; + + // Batch consecutive tasks from the same context that match operation category + while (!ncclIntruQueueEmpty(ctxQueue)) { + struct ncclTaskRma* task = ncclIntruQueueHead(ctxQueue); + + // Check if this task can be batched with the first task + if (!canBatchRmaTasks(firstTask, task)) { + break; + } + + bool lsaAccessible = isLsaAccessible(comm, task->peer); + + // If the task can be batched, remove from context queue and add to plan + ncclIntruQueueDequeue(ctxQueue); + if (lsaAccessible) { + ncclIntruQueueEnqueue(&plan->rmaTaskQueueCe, task); + plan->rmaArgs->nRmaTasksCe++; + } else { + ncclIntruQueueEnqueue(&plan->rmaTaskQueueProxy, task); + plan->rmaArgs->nRmaTasksProxy++; + } + plan->rmaArgs->nRmaTasks++; + planner->nTasksRma -= 1; + } + } + + INFO(NCCL_COLL, "scheduleRmaTasksToPlan: rank=%d ctx=%d func=%d nRmaTasks=%d nRmaTasksProxy=%d nRmaTasksCe=%d", + comm->rank, ctx, plan->rmaArgs->func, plan->rmaArgs->nRmaTasks, plan->rmaArgs->nRmaTasksProxy, plan->rmaArgs->nRmaTasksCe); + + return ret; +} diff --git a/src/rma/rma_ce.cc b/src/rma/rma_ce.cc new file mode 100644 index 000000000..2c31f08a1 --- /dev/null +++ b/src/rma/rma_ce.cc @@ -0,0 +1,230 @@ +#include +#include "nccl.h" +#include "alloc.h" +#include "checks.h" +#include "comm.h" +#include "collectives.h" +#include "rma/rma.h" +#include "rma/rma_ce.h" + +ncclResult_t ncclRmaCeInit(struct ncclComm* comm){ + ncclResult_t ret = ncclSuccess; + + // Ensure symmetric memory runtime is initialized + NCCLCHECKGOTO(ncclDevrInitOnce(comm), ret, fail); + + comm->rmaState.rmaCeState.rmaCeCtxCount = comm->config.numRmaCtx; + + NCCLCHECKGOTO(ncclCalloc(&comm->rmaState.rmaCeState.rmaCeCtxs, comm->rmaState.rmaCeState.rmaCeCtxCount), ret, fail); + for (int i = 0; i < comm->rmaState.rmaCeState.rmaCeCtxCount; i++) { + // Allocate the RMA CE 
context + struct ncclRmaCeCtx* ceCtx; + NCCLCHECKGOTO(ncclCalloc(&ceCtx, 1), ret, fail); + comm->rmaState.rmaCeState.rmaCeCtxs[i] = ceCtx; + + // Initialize context + ceCtx->comm = comm; + + // Allocate and register symmetric memory for signals + // Signal buffer layout: [0..nRanks-1] per-rank signals, [nRanks] aggregate signal + size_t signalsBufSize = (comm->nRanks + 1) * sizeof(uint64_t); + uint64_t* signalsDevBase; + ncclWindow_vidmem* signalsWinDev; + ncclWindow_vidmem* signalsWinDevHost; + + NCCLCHECKGOTO(ncclMemAlloc((void**)&signalsDevBase, signalsBufSize), ret, fail); + NCCLCHECKGOTO(ncclDevrWindowRegisterInGroup(comm, signalsDevBase, signalsBufSize, NCCL_WIN_COLL_SYMMETRIC, &signalsWinDev), ret, fail); + NCCLCHECKGOTO(ncclShadowPoolToHost(&comm->devrState.shadows, signalsWinDev, &signalsWinDevHost), ret, fail); + + // Get the ncclDevrWindow from the winHost field + ceCtx->signalsWin = (struct ncclDevrWindow*)signalsWinDevHost->winHost; + ceCtx->signalsDev = signalsDevBase; + + // Allocate host buffer to track expected signal values + NCCLCHECKGOTO(ncclCalloc(&ceCtx->signalsHost, signalsBufSize), ret, fail); + + // Allocate per-rank operation sequence counters + NCCLCHECKGOTO(ncclCalloc(&ceCtx->signalOpSeqs, comm->nRanks), ret, fail); + + } + + INFO(NCCL_INIT, "Rank %d: finished init RMA CE contexts, numRmaCeCtxs %d", comm->rank, comm->config.numRmaCtx); + + comm->rmaState.rmaCeState.initialized = true; + +exit: + return ret; +fail: + goto exit; +} + +ncclResult_t ncclRmaCeFinalize(struct ncclComm* comm){ + ncclResult_t ret = ncclSuccess; + + // Clean up rmaCeInitTaskQueue + while (!ncclIntruQueueEmpty(&comm->rmaCeInitTaskQueue)) { + struct ncclRmaCeInitTask* task = ncclIntruQueueDequeue(&comm->rmaCeInitTaskQueue); + free(task); + } + + for (int i = 0; i < comm->rmaState.rmaCeState.rmaCeCtxCount; i++) { + struct ncclRmaCeCtx* ceCtx = (struct ncclRmaCeCtx*)comm->rmaState.rmaCeState.rmaCeCtxs[i]; + + // Free per-rank operation sequence counters + if (ceCtx->signalOpSeqs) free(ceCtx->signalOpSeqs); + + // Free host signals buffer + if (ceCtx->signalsHost) free(ceCtx->signalsHost); + + // Deregister and free signal window + if (ceCtx->signalsWin) NCCLCHECKGOTO(ncclCommWindowDeregister(comm, ceCtx->signalsWin->vidmem), ret, fail); + + // Free signal device memory + if (ceCtx->signalsDev) NCCLCHECKGOTO(ncclMemFree(ceCtx->signalsDev), ret, fail); + + // Free the context itself + free(ceCtx); + comm->rmaState.rmaCeState.rmaCeCtxs[i] = NULL; + } + + // Reset the number of contexts and initialized flag + comm->rmaState.rmaCeState.rmaCeCtxCount = 0; + comm->rmaState.rmaCeState.initialized = false; + + free(comm->rmaState.rmaCeState.rmaCeCtxs); + comm->rmaState.rmaCeState.rmaCeCtxs = NULL; + +exit: + return ret; +fail: + goto exit; +} + +ncclResult_t ncclRmaPutCe(struct ncclComm* comm, struct ncclKernelPlan* plan, cudaStream_t stream){ + ncclResult_t ret = ncclSuccess; + + // Make sure the RMA CE is initialized + if (!comm->rmaState.rmaCeState.initialized) { + WARN("RMA CE is not initialized"); + return ncclInternalError; + } + + int nRmaTasksCe = plan->rmaArgs->nRmaTasksCe; + int ctx = plan->rmaArgs->ctx; + struct ncclRmaCeCtx* ceCtx = (struct ncclRmaCeCtx*)comm->rmaState.rmaCeState.rmaCeCtxs[ctx]; + + for (int i = 0; i < nRmaTasksCe; i++) { + struct ncclTaskRma* task = ncclIntruQueueHead(&plan->rmaTaskQueueCe); + ncclIntruQueueDequeue(&plan->rmaTaskQueueCe); + + // Convert global peer rank to LSA rank index + // LSA rank is computed as: peer % lsaSize (see dev_runtime.cc) + int peerLsaRank 
= task->peer % comm->devrState.lsaSize; + + size_t bytes = task->count * ncclTypeSize(task->datatype); + + if (bytes > 0) { + // Get the peer buffer from the peer window + void* peerBuff; + NCCLCHECKGOTO(ncclDevrGetLsaRankPtr(comm, task->peerWinHost, task->peerWinOffset, peerLsaRank, &peerBuff), ret, fail); + + // Validate peer buffer + if (peerBuff == NULL) { + WARN("RMA CE: peerBuff is NULL after ncclDevrGetLsaRankPtr"); + ret = ncclInvalidArgument; + goto fail; + } + + // Copy the data to the peer buffer + CUDACHECKGOTO(cudaMemcpyAsync(peerBuff, task->srcBuff, bytes, cudaMemcpyDeviceToDevice, stream), ret, fail); + } + + // Write signal if needed for the target rank + // CE over NVL only supports distinct signal + if (task->signalMode != NCCL_SIGNAL_NONE) { + // Get the signal location in peer's signal buffer where we write to notify them + // We write to offset [comm->rank] in peer's signal buffer, same as proxy version + // So peer waits on their signalsDev[comm->rank] to see our signals + void* peerSignal; + NCCLCHECKGOTO(ncclDevrGetLsaRankPtr(comm, ceCtx->signalsWin, comm->rank * sizeof(uint64_t), peerLsaRank, &peerSignal), ret, fail); + + // Increment our sequence number for operations to this peer + ceCtx->signalOpSeqs[task->peer]++; + + // Write the absolute sequence number - peer will wait for this value + CUCHECKGOTO(cuStreamWriteValue64(stream, (CUdeviceptr)peerSignal, ceCtx->signalOpSeqs[task->peer], CU_STREAM_WRITE_VALUE_DEFAULT), ret, fail); + } + + // Free the task after processing + ncclMemoryPoolFree(&comm->memPool_ncclTaskRma, task); + } + +exit: + return ret; +fail: + goto exit; +} + + +ncclResult_t ncclRmaWaitSignalCe(struct ncclComm* comm, struct ncclKernelPlan* plan, cudaStream_t stream){ + ncclResult_t ret = ncclSuccess; + + // Make sure the RMA CE is initialized + if (!comm->rmaState.rmaCeState.initialized) { + WARN("RMA CE is not initialized"); + return ncclInternalError; + } + + int ctx = plan->rmaArgs->ctx; + struct ncclRmaCeCtx* ceCtx = (struct ncclRmaCeCtx*)comm->rmaState.rmaCeState.rmaCeCtxs[ctx]; + + struct ncclTaskRma* task = ncclIntruQueueHead(&plan->rmaTaskQueueCe); + ncclIntruQueueDequeue(&plan->rmaTaskQueueCe); + + // Assert task func is ncclFuncWaitSignal + assert(task->func == ncclFuncWaitSignal); + // Assert task context is the same as the plan context + assert(task->ctx == ctx); + // Assert the plan has exactly one RMA CE task + assert(plan->rmaArgs->nRmaTasksCe == 1); + + size_t opIdx = 0; + CUstreamBatchMemOpParams* batchParams = nullptr; + + NCCLCHECK(ncclCalloc(&batchParams, task->npeers)); + + // NVL only supports per-rank signal + if (task->signalMode == NCCL_SIGNAL_DISTINCT || task->signalMode == NCCL_SIGNAL_AGGREGATE) { + for (int i = 0; i < task->npeers; i++) { + int peerRank = task->peers[i]; + + // Calculate the expected signal value from this peer + // We wait on signalsDev[peerRank] where peerRank writes their sequence numbers + uint64_t waitValue = ceCtx->signalsHost[peerRank] + task->nsignals[i]; + + // Update our expectation for future waits + ceCtx->signalsHost[peerRank] = waitValue; + + // Add wait operation to batch + // Wait on our local signal buffer at offset [peerRank] where peer writes to us + batchParams[opIdx] = {}; + batchParams[opIdx].waitValue.operation = CU_STREAM_MEM_OP_WAIT_VALUE_64; + batchParams[opIdx].waitValue.address = (CUdeviceptr)&ceCtx->signalsDev[peerRank]; + batchParams[opIdx].waitValue.value64 = waitValue; + batchParams[opIdx].waitValue.flags = CU_STREAM_WAIT_VALUE_GEQ; + opIdx++; + } + + // Execute all 
wait operations in a single batch
+    CUCHECKGOTO(cuStreamBatchMemOp(stream, opIdx, batchParams, 0), ret, fail);
+  }
+
+  // Free the task
+  ncclMemoryPoolFree(&comm->memPool_ncclTaskRma, task);
+
+exit:
+  if (batchParams) free(batchParams);
+  return ret;
+fail:
+  goto exit;
+}
\ No newline at end of file
diff --git a/src/rma/rma_proxy.cc b/src/rma/rma_proxy.cc
new file mode 100644
index 000000000..7b3dba8b3
--- /dev/null
+++ b/src/rma/rma_proxy.cc
@@ -0,0 +1,802 @@
+#include <unistd.h>
+#include <cassert>
+#include <cstring>
+#include <algorithm>
+#include "nccl.h"
+#include "alloc.h"
+#include "checks.h"
+#include "gdrwrap.h"
+#include "comm.h"
+#include "bootstrap.h"
+#include "rma/rma.h"
+#include "rma/rma_proxy.h"
+#include "dev_runtime.h"
+#include "nccl_device/gin/proxy/gin_proxy_device_host_common.h"
+
+
+extern int64_t ncclParamDmaBufEnable();
+extern int64_t ncclParamIbDataDirect();
+extern int64_t ncclParamGinEnable();
+extern int64_t ncclParamGinType();
+
+NCCL_PARAM(RmaProxyDumpSignal, "RMA_PROXY_DUMP_SIGNAL", -1);
+
+#include <csignal>
+static ncclRmaProxyState* ncclLastRmaProxyState;
+
+ncclResult_t dumpRmaProxyState(struct ncclRmaProxyState* rmaProxyState) {
+  ncclLastRmaProxyState = rmaProxyState;
+  if (rmaProxyState->comm) {
+    printf("Rank %d RMA Proxy State:\n", rmaProxyState->comm->rank);
+    printf("  ginProgress: %d\n", rmaProxyState->ginProgress);
+    printf("  ginCommCount: %d\n", rmaProxyState->ginCommCount);
+    printf("  rmaProxyCtxCount: %d\n", rmaProxyState->rmaProxyCtxCount);
+    printf("  connected: %d\n", rmaProxyState->connected);
+    printf("  needsProxyProgress: %d\n", rmaProxyState->needsProxyProgress);
+
+    // dump per-context information
+    for (int i = 0; i < rmaProxyState->rmaProxyCtxCount; i++) {
+      struct ncclRmaProxyCtx* ctx = (struct ncclRmaProxyCtx*)rmaProxyState->rmaProxyCtxs[i];
+      printf("  rmaCtx[%d]: %p\n", i, ctx);
+      if (ctx && ctx->comm) {
+        printf("    rmaDevHandles: %p\n", ctx->devHandle);
+        printf("    rmaCollComms: %p\n", ctx->ginCollComm);
+        printf("    nRanks: %d, myRank: %d\n", ctx->comm->nRanks, ctx->comm->rank);
+        // dump per-peer information
+        for (int peer = 0; peer < ctx->comm->nRanks; peer++) {
+          int pendingCount = 0, inProgressCount = 0;
+          uint64_t readySeq = __atomic_load_n(&ctx->readySeqs[peer], __ATOMIC_ACQUIRE);
+          uint64_t doneSeq = __atomic_load_n(&ctx->doneSeqs[peer], __ATOMIC_ACQUIRE);
+          uint64_t opSeq = __atomic_load_n(&ctx->opSeqs[peer], __ATOMIC_ACQUIRE);
+          printf("    Peer %d: readySeq: %lu, doneSeq: %lu, opSeq: %lu\n", peer, readySeq, doneSeq, opSeq);
+
+          // Count pending Descs
+          struct ncclRmaProxyDesc* desc = ncclIntruQueueHead(&ctx->rmaProxyDescQueues[peer]);
+          while (desc != NULL) {
+            pendingCount++;
+            desc = desc->next;
+          }
+          printf("    Pending Descs: %d\n", pendingCount);
+          // print all pending Descs
+          desc = ncclIntruQueueHead(&ctx->rmaProxyDescQueues[peer]);
+          while (desc != NULL) {
+            printf("      Desc: seq=%lu targetRank=%d size=%zu\n",
+                   desc->seq, desc->targetRank, desc->size);
+            desc = desc->next;
+          }
+
+          // Count in-progress Descs
+          desc = ncclIntruQueueHead(&ctx->rmaProxyInProgressQueues[peer]);
+          while (desc != NULL) {
+            inProgressCount++;
+            desc = desc->next;
+          }
+          printf("    In-progress Descs: %d\n", inProgressCount);
+          // print all in-progress Descs
+          desc = ncclIntruQueueHead(&ctx->rmaProxyInProgressQueues[peer]);
+          while (desc != NULL) {
+            printf("      Desc: seq=%lu targetRank=%d size=%zu\n",
+                   desc->seq, desc->targetRank, desc->size);
+            desc = desc->next;
+          }
+        }
+      } else {
+        printf("  rmaCtx[%d]: NULL\n", i);
+      }
+    }
+  }
+  return ncclSuccess;
+}
+
+void ncclDumpRmaProxyState(int signal) {
+  
dumpRmaProxyState(ncclLastRmaProxyState); +} + +static ncclResult_t getDmaBufFd(void *addr, size_t length, int *fd, + bool forceNonDataDirect = false) { + if (ncclParamDmaBufEnable() == 0) return ncclInvalidUsage; + +#if CUDA_VERSION >= 11070 + static size_t hostPageSize = sysconf(_SC_PAGESIZE); + size_t alignedSize = length; + ALIGN_SIZE(alignedSize, hostPageSize); + +#if CUDA_VERSION >= 12080 + if (ncclParamIbDataDirect() && !forceNonDataDirect) { + CUresult status = pfn_cuMemGetHandleForAddressRange( + (void *)fd, (CUdeviceptr)addr, alignedSize, CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, + CU_MEM_RANGE_FLAG_DMA_BUF_MAPPING_TYPE_PCIE); + if (status == CUDA_SUCCESS) return ncclSuccess; + } +#endif + CUresult status = pfn_cuMemGetHandleForAddressRange((void *)fd, (CUdeviceptr)addr, alignedSize, + CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0); + if (status == CUDA_SUCCESS) return ncclSuccess; +#endif + + return ncclInvalidUsage; +} + +// Check if the GIN plugin supports DMA-BUF, if so we can try to get the DMA-BUF handle from CUDA, +// if that fails we fallback to non-DMA-BUF +static ncclResult_t ncclRmaProxyRegMrSym(ncclGin_t *ginComm, void *ginCollComm, ncclNetProperties_t props, void *addr, + size_t size, int type, int mr_flags, void **mhandle, + void **ginHandle) { + if (type == NCCL_PTR_HOST) { + NCCLCHECK(ginComm->regMrSym(ginCollComm, addr, size, type, mr_flags, mhandle, ginHandle)); + } else if (type == NCCL_PTR_CUDA) { + ncclResult_t dmabufResult = ncclInvalidUsage; + if (ncclParamDmaBufEnable() && (props.ptrSupport & NCCL_PTR_DMABUF)) { + ncclResult_t registrationResult = ncclSuccess; + int dmabufFd = -1; + dmabufResult = getDmaBufFd(addr, size, &dmabufFd); + if (dmabufResult == ncclSuccess) { + registrationResult = ginComm->regMrSymDmaBuf(ginCollComm, addr, size, type, 0, dmabufFd, + mr_flags, mhandle, ginHandle); + close(dmabufFd); + } + if (registrationResult != ncclSuccess) { + dmabufFd = -1; + dmabufResult = getDmaBufFd(addr, size, &dmabufFd, true); + if (dmabufResult == ncclSuccess) { + NCCLCHECK(ginComm->regMrSymDmaBuf(ginCollComm, addr, size, type, 0, dmabufFd, + mr_flags, mhandle, ginHandle)); + close(dmabufFd); + } + } + } + // Fallback to non-DMA-BUF if the DMA-BUF handle is not supported + if (dmabufResult != ncclSuccess) { + NCCLCHECK(ginComm->regMrSym(ginCollComm, addr, size, type, mr_flags, mhandle, ginHandle)); + } + } else { + return ncclInvalidUsage; + } + + return ncclSuccess; +} + +// Depending on GDR, allocate memory on the CPU or GPU. +// host_flags is not used for now, but it is here for future use. +template +static ncclResult_t allocMemCPUAccessible(T **ptr, T **devPtr, size_t nelem, int host_flags, + void **gdrHandle, bool forceHost = false) { + if (ncclGdrCopy && !forceHost) { + NCCLCHECK(ncclGdrCudaCalloc(ptr, devPtr, nelem, gdrHandle)); + } else { + NCCLCHECK(ncclCuMemHostAlloc((void **)ptr, NULL, nelem * sizeof(T))); + memset((void *)*ptr, 0, nelem * sizeof(T)); + *devPtr = *ptr; + if (gdrHandle) *gdrHandle = NULL; // Mark as host allocated by nulling GDR handle + } + return ncclSuccess; +} + +// Depending on GDR, free memory on the CPU or GPU. 
+template +static ncclResult_t freeMemCPUAccessible(T *ptr, void *gdrHandle) { + if (gdrHandle != NULL) { // If a GDR handle exists, it was GDR memory + NCCLCHECK(ncclGdrCudaFree(gdrHandle)); + } else { // Otherwise, it was host memory (or GDR was off) + NCCLCHECK(ncclCuMemHostFree(ptr)); + } + return ncclSuccess; +} + +ncclResult_t ncclRmaProxyCreateContext(struct ncclComm *comm, void *collComm, ncclNetProperties_t props, + void **outRmaProxyCtx, ncclNetDeviceHandle_t **outDevHandle) { + // Get the GIN plugin interface + ncclGin_t *ginComm = (ncclGin_t *)comm->rmaState.rmaProxyState.ncclGin; + + // Allocate the RMA proxy context + struct ncclRmaProxyCtx *rmaProxyCtx = NULL; + NCCLCHECK(ncclCalloc(&rmaProxyCtx, 1)); + + rmaProxyCtx->comm = comm; + rmaProxyCtx->ginCollComm = collComm; + rmaProxyCtx->props = props; + + // Allocate the signals on the GPU and then register the memory region with the GIN plugin. + // Enforcing strong ordering on the signals mr is vital to ensure ordering between puts and signals. + size_t signalsBufSize = (comm->nRanks + 1) * sizeof(uint64_t); + NCCLCHECK(ncclCuMemAlloc((void **)&rmaProxyCtx->signalsDev, &rmaProxyCtx->signalsCumemhandle, + CU_MEM_HANDLE_TYPE_NONE, signalsBufSize)); + CUDACHECK(cudaMemset(rmaProxyCtx->signalsDev, 0, signalsBufSize)); + NCCLCHECK(ncclRmaProxyRegMrSym(ginComm, rmaProxyCtx->ginCollComm, rmaProxyCtx->props, rmaProxyCtx->signalsDev, signalsBufSize, + NCCL_PTR_CUDA, NCCL_NET_MR_FLAG_FORCE_SO, + &rmaProxyCtx->signalsMhandle, &rmaProxyCtx->signalsGinHandle)); + + // Allocate the host buffer to track the expected values of the signals + NCCLCHECK(ncclCalloc(&rmaProxyCtx->signalsHost, signalsBufSize)); + + // Allocate the sequence numbers for the per-rank network function descriptors + // These are allocated as CPU-accessible memory (either GDR or host memory) + NCCLCHECK(allocMemCPUAccessible(&rmaProxyCtx->opSeqs, &rmaProxyCtx->opSeqsDev, + comm->nRanks, 0, &rmaProxyCtx->opSeqsGdrHandle)); + NCCLCHECK(allocMemCPUAccessible(&rmaProxyCtx->readySeqs, &rmaProxyCtx->readySeqsDev, + comm->nRanks, 0, &rmaProxyCtx->readySeqsGdrHandle)); + NCCLCHECK(allocMemCPUAccessible(&rmaProxyCtx->doneSeqs, &rmaProxyCtx->doneSeqsDev, + comm->nRanks, 0, &rmaProxyCtx->doneSeqsGdrHandle)); + + // Allocate per-peer network function descriptor queues from permanent memory + // rmaProxyDescQueues: pending Descs waiting for readySeq + // rmaProxyInProgressQueues: Descs with issued operations waiting for completion + rmaProxyCtx->rmaProxyDescQueues = ncclMemoryStackAlloc>(&comm->memPermanent, comm->nRanks); + rmaProxyCtx->rmaProxyInProgressQueues = ncclMemoryStackAlloc>(&comm->memPermanent, comm->nRanks); + NCCLCHECK(ncclCalloc(&rmaProxyCtx->DescQueueLocks, comm->nRanks)); + for (int i = 0; i < comm->nRanks; i++) { + ncclIntruQueueConstruct(&rmaProxyCtx->rmaProxyDescQueues[i]); + ncclIntruQueueConstruct(&rmaProxyCtx->rmaProxyInProgressQueues[i]); + pthread_mutex_init(&rmaProxyCtx->DescQueueLocks[i], NULL); + } + + // Allocate and initialize device handle + ncclNetDeviceHandle_t *devHandle = NULL; + NCCLCHECK(ncclCalloc(&devHandle, 1)); + devHandle->netDeviceType = NCCL_NET_DEVICE_GIN_PROXY; + devHandle->netDeviceVersion = NCCL_GIN_PROXY_VERSION; + devHandle->handle = (void *)rmaProxyCtx; + devHandle->size = 0; + devHandle->needsProxyProgress = 1; + + rmaProxyCtx->devHandle = devHandle; + + *outDevHandle = devHandle; + *outRmaProxyCtx = rmaProxyCtx; + + return ncclSuccess; +} + +// Poll and test completion of InProgress Descs for a given peer +// Returns after testing 
head Desc (stops on first incomplete to enforce FIFO)
+static ncclResult_t ncclRmaProxyPollCompletion(ncclGin_t *ncclGin, struct ncclRmaProxyCtx *ctx, int peer) {
+  while (true) {
+    struct ncclRmaProxyDesc *inProgressDesc = ncclIntruQueueHead(&ctx->rmaProxyInProgressQueues[peer]);
+    if (inProgressDesc == NULL) break; // No InProgress Descs
+
+    int done = 0;
+    NCCLCHECK(ncclGin->test(ctx->ginCollComm, inProgressDesc->request, &done));
+    if (done) {
+      INFO(NCCL_COLL, "Rank %d ncclRmaProxyPollCompletion: targetRank=%d descSeq=%lu COMPLETED, updating doneSeq",
+           ctx->comm->rank, inProgressDesc->targetRank, inProgressDesc->seq);
+
+      // Update the doneSeq for the target rank with RELEASE to ensure the GPU sees it
+      __atomic_store_n(&ctx->doneSeqs[inProgressDesc->targetRank], inProgressDesc->seq, __ATOMIC_RELEASE); // pairs with the GEQ stream wait (cuStreamWaitValue) on the GPU side
+      // Dequeue the completed Desc and return it to the pool
+      ncclIntruQueueDequeue(&ctx->rmaProxyInProgressQueues[peer]);
+      ncclMemoryPoolFree(&ctx->comm->memPool_ncclRmaProxyDesc, inProgressDesc);
+    } else {
+      // Head is not done - stop testing to enforce FIFO completion order
+      break;
+    }
+  }
+  return ncclSuccess;
+}
+
+// Poll and issue ready Pending Descs for a given peer
+// Moves ready Descs from pending queue to InProgress queue
+static ncclResult_t ncclRmaProxyPollDesc(ncclGin_t *ncclGin, struct ncclRmaProxyCtx *ctx, int peer) {
+  while (true) {
+    // Lock mutex to safely check and dequeue from pending queue
+    pthread_mutex_lock(&ctx->DescQueueLocks[peer]);
+    struct ncclRmaProxyDesc *pendingDesc = ncclIntruQueueHead(&ctx->rmaProxyDescQueues[peer]);
+    if (pendingDesc == NULL) {
+      pthread_mutex_unlock(&ctx->DescQueueLocks[peer]);
+      break; // No pending Descs
+    }
+
+    // Check if this Desc is ready to be issued
+    uint64_t readySeq = __atomic_load_n(&ctx->readySeqs[peer], __ATOMIC_ACQUIRE);
+    if (readySeq >= pendingDesc->seq) {
+      // Dequeue while holding lock, then unlock before issuing (to minimize lock time)
+      ncclIntruQueueDequeue(&ctx->rmaProxyDescQueues[peer]); // may be speculated ahead of the readySeq check above, which is harmless here
+      pthread_mutex_unlock(&ctx->DescQueueLocks[peer]); // release
+
+      // Issue the network operation
+      if (pendingDesc->signal.op == 0) {
+        // No signal operation
+        NCCLCHECK(ncclGin->iput(ctx->ginCollComm,
+                                pendingDesc->srcOff, pendingDesc->srcHandle, pendingDesc->size,
+                                pendingDesc->dstOff, pendingDesc->dstHandle,
+                                pendingDesc->targetRank, &pendingDesc->request));
+      } else {
+        // Signal operation needed
+        NCCLCHECK(ncclGin->iputSignal(ctx->ginCollComm,
+                                      pendingDesc->srcOff, pendingDesc->srcHandle, pendingDesc->size,
+                                      pendingDesc->dstOff, pendingDesc->dstHandle,
+                                      pendingDesc->targetRank, pendingDesc->signal.offset, pendingDesc->signal.signalMhandle,
+                                      pendingDesc->signal.val, pendingDesc->signal.op, &pendingDesc->request));
+      }
+
+      // Enqueue to InProgress queue (no lock needed - progress thread only)
+      ncclIntruQueueEnqueue(&ctx->rmaProxyInProgressQueues[peer], pendingDesc);
+
+      INFO(NCCL_COLL, "Rank %d ncclRmaProxyPollDesc: targetRank=%d descSeq=%lu readySeq=%lu srcOff=%lu srcHandle=%p dstOff=%lu dstHandle=%p size=%lu - issuing network operation",
+           ctx->comm->rank, pendingDesc->targetRank, pendingDesc->seq, readySeq, pendingDesc->srcOff, pendingDesc->srcHandle, pendingDesc->dstOff, pendingDesc->dstHandle, pendingDesc->size);
+    } else {
+      // ReadySeq not ready yet - stop processing this peer's pending queue to maintain FIFO order
+      pthread_mutex_unlock(&ctx->DescQueueLocks[peer]);
+      
break; + } + } + return ncclSuccess; +} + +// Checks the RMA proxy progress. +ncclResult_t ncclRmaProxyProgress(ncclGin_t *ncclGin, void *rmaProxyCtx) { + struct ncclRmaProxyCtx *ctx = (struct ncclRmaProxyCtx *)rmaProxyCtx; + + // Loop through each peer's queues + for (int i = 0; i < ctx->comm->nRanks; i++) { + // Step 1: Poll completion of InProgress Descs + NCCLCHECK(ncclRmaProxyPollCompletion(ncclGin, ctx, i)); + + // Step 2: Poll and issue ready Pending Descs + NCCLCHECK(ncclRmaProxyPollDesc(ncclGin, ctx, i)); + } + return ncclSuccess; +} + +ncclResult_t ncclRmaProxyDestroyContext(ncclGin_t* ginComm, void* rmaProxyCtx){ + if (!rmaProxyCtx) return ncclSuccess; + struct ncclRmaProxyCtx *ctx = (struct ncclRmaProxyCtx *)rmaProxyCtx; + + // Free per-rank network function descriptor queues and their Descs + if (ctx->rmaProxyDescQueues) { + for (int i = 0; i < ctx->comm->nRanks; i++) { + struct ncclRmaProxyDesc *desc = ncclIntruQueueHead(&ctx->rmaProxyDescQueues[i]); + while (desc != NULL) { + struct ncclRmaProxyDesc *nextDesc = desc->next; + ncclIntruQueueDequeue(&ctx->rmaProxyDescQueues[i]); + ncclMemoryPoolFree(&ctx->comm->memPool_ncclRmaProxyDesc, desc); + desc = nextDesc; + } + } + } + if (ctx->rmaProxyInProgressQueues) { + for (int i = 0; i < ctx->comm->nRanks; i++) { + struct ncclRmaProxyDesc *desc = ncclIntruQueueHead(&ctx->rmaProxyInProgressQueues[i]); + while (desc != NULL) { + struct ncclRmaProxyDesc *nextDesc = desc->next; + ncclIntruQueueDequeue(&ctx->rmaProxyInProgressQueues[i]); + ncclMemoryPoolFree(&ctx->comm->memPool_ncclRmaProxyDesc, desc); + desc = nextDesc; + } + } + } + + // Destroy Desc queue locks + if (ctx->DescQueueLocks) { + for (int i = 0; i < ctx->comm->nRanks; i++) { + pthread_mutex_destroy(&ctx->DescQueueLocks[i]); + } + free(ctx->DescQueueLocks); + } + + // Free counters (using GDR-aware deallocation) + if (ctx->opSeqs) freeMemCPUAccessible(ctx->opSeqs, ctx->opSeqsGdrHandle); + if (ctx->readySeqs) freeMemCPUAccessible(ctx->readySeqs, ctx->readySeqsGdrHandle); + if (ctx->doneSeqs) freeMemCPUAccessible(ctx->doneSeqs, ctx->doneSeqsGdrHandle); + + // Free signals + if (ginComm && ctx->ginCollComm && ctx->signalsMhandle) + ginComm->deregMrSym(ctx->ginCollComm, ctx->signalsMhandle); + if (ctx->signalsDev) ncclCudaFree(ctx->signalsDev); + + // Free host signals buffer + if (ctx->signalsHost) free(ctx->signalsHost); + + ncclNetDeviceHandle_t *devHandle = (ncclNetDeviceHandle_t *)ctx->devHandle; + if (devHandle) { + // Note: devHandle->handle points to ctx itself, so we don't free it separately + free(devHandle); + } + + free(ctx); + + return ncclSuccess; +} + + +ncclResult_t ncclRmaProxyRegister(struct ncclComm* comm, void* address, size_t size, + void* rmaHostWins[NCCL_GIN_MAX_CONTEXTS], + ncclGinWindow_t rmaDevWins[NCCL_GIN_MAX_CONTEXTS]){ + struct ncclRmaProxyState* rmaProxyState = &comm->rmaState.rmaProxyState; + for (int n = 0; n < rmaProxyState->ginCommCount; n++) { + struct ncclRmaProxyCtx* ctx = (struct ncclRmaProxyCtx*)rmaProxyState->rmaProxyCtxs[n]; + NCCLCHECK(ncclRmaProxyRegMrSym(rmaProxyState->ncclGin, ctx->ginCollComm, ctx->props, address, size, + NCCL_PTR_CUDA, 0, &rmaHostWins[n], &rmaDevWins[n])); + if (rmaHostWins[n] == NULL) { + WARN("rank %d - GIN Symmetric register failed: buff %p, size %ld", comm->rank, address, size); + return ncclSystemError; + } + } + return ncclSuccess; +} + +ncclResult_t ncclRmaProxyDeregister(struct ncclComm* comm, void* rmaHostWins[NCCL_GIN_MAX_CONTEXTS]){ + struct ncclRmaProxyState* rmaProxyState = 
&comm->rmaState.rmaProxyState; + for (int n = 0; n < rmaProxyState->ginCommCount; n++) { + NCCLCHECK(rmaProxyState->ncclGin->deregMrSym(rmaProxyState->ginComms[n], rmaHostWins[n])); + } + return ncclSuccess; +} + +void* ncclRmaProxyProgressThread(void* rmaProxyState_) { + struct ncclRmaProxyState *rmaProxyState = (struct ncclRmaProxyState *)rmaProxyState_; + const int sig = ncclParamRmaProxyDumpSignal(); + if (sig != -1) signal(sig, ncclDumpRmaProxyState); + ncclLastRmaProxyState = rmaProxyState; + while (1) { + pthread_mutex_lock(&rmaProxyState->threadLock); + if (rmaProxyState->ginProgress == 1) { + pthread_mutex_unlock(&rmaProxyState->threadLock); + for (int n=0; nrmaProxyCtxCount; n++) { + ncclResult_t ret = ncclRmaProxyProgress(rmaProxyState->ncclGin, rmaProxyState->rmaProxyCtxs[n]); + if (ret != ncclSuccess) { + __atomic_store_n(&rmaProxyState->asyncResult, ret, __ATOMIC_RELEASE); + INFO(NCCL_ALL,"%s:%d -> %d [RMA Proxy Progress Thread]", __FILE__, __LINE__, ret); + rmaProxyState->ginProgress = -2; + return NULL; + } + } + sched_yield(); + } else if (rmaProxyState->ginProgress == -1) { + pthread_mutex_unlock(&rmaProxyState->threadLock); + return NULL; + } else if (rmaProxyState->ginProgress == 0) { + pthread_cond_wait(&rmaProxyState->threadCond, &rmaProxyState->threadLock); + pthread_mutex_unlock(&rmaProxyState->threadLock); + } else { + pthread_mutex_unlock(&rmaProxyState->threadLock); + INFO(NCCL_ALL,"%s:%d -> [RMA Proxy Progress Thread] state unknown %d", __FILE__, __LINE__, rmaProxyState->ginProgress); + rmaProxyState->ginProgress = -2; + return NULL; + } + } +} + +ncclResult_t ncclRmaProxyConnectOnce(struct ncclComm* comm) { + ncclResult_t ret = ncclSuccess; + struct ncclRmaProxyState *rmaProxyState = &comm->rmaState.rmaProxyState; + rmaProxyState->comm = comm; + if (rmaProxyState->ncclGin == NULL) { + WARN("GIN not supported."); + return ncclInvalidUsage; + } + if (ncclParamGinEnable() == 0) { + WARN("GIN is disabled."); + return ncclInternalError; + } + if (rmaProxyState->connected) return ncclSuccess; + + NCCLCHECK(rmaProxyState->ncclGin->init(&rmaProxyState->ginInstance, comm->commHash, ncclDebugLog)); + + int ndev = 0; + NCCLCHECK(rmaProxyState->ncclGin->devices(&ndev)); + if (ndev <= 0) { + WARN("No GIN-capable devices found."); + return ncclInternalError; + } + + ncclNetProperties_t props; + NCCLCHECK(rmaProxyState->ncclGin->getProperties(0, &props)); + rmaProxyState->ginType = props.netDeviceType; + if (((ncclParamGinType() != -1) && (rmaProxyState->ginType != ncclParamGinType())) || rmaProxyState->ginType != NCCL_NET_DEVICE_GIN_PROXY) { + WARN("GIN-capable device type mismatch."); + return ncclInternalError; + } + + int ginCommCount; + int64_t localNets[NCCL_TOPO_MAX_NODES]; + NCCLCHECK(ncclTopoGetLocalNets(comm->topo, comm->rank, localNets, &rmaProxyState->ginCommCount)); + ginCommCount = std::min(rmaProxyState->ginCommCount, NCCL_GIN_MAX_CONTEXTS); + ginCommCount = std::min(ginCommCount, ndev); + + int* allCommCounts = NULL; + void** handles = NULL; + char* allHandles = NULL; + + // Get the min local net count from all ranks + NCCLCHECK(ncclCalloc(&allCommCounts, comm->nRanks)); + allCommCounts[comm->rank] = ginCommCount; + NCCLCHECKGOTO(bootstrapAllGather(comm->bootstrap, allCommCounts, sizeof(int)), ret, fail); + for (int i = 0; i < comm->nRanks; i++) { + ginCommCount = std::min(ginCommCount, allCommCounts[i]); + } + free(allCommCounts); + allCommCounts = NULL; + + if (ginCommCount == 0) { + WARN("Gin connect : min local net count is zero"); + ret = ncclSystemError; 
+ goto fail; + } + rmaProxyState->ginCommCount = ginCommCount; + + NCCLCHECKGOTO(ncclCalloc(&allHandles, (size_t)comm->nRanks * NCCL_NET_HANDLE_MAXSIZE), ret, fail); + NCCLCHECKGOTO(ncclCalloc(&handles, comm->nRanks), ret, fail); + for (int r = 0; r < comm->nRanks; r++) handles[r] = allHandles + r * NCCL_NET_HANDLE_MAXSIZE; + + for (int n = 0; n < ginCommCount; n++) { + void* listenComm; + NCCLCHECKGOTO( + rmaProxyState->ncclGin->listen(rmaProxyState->ginInstance, localNets[n], + allHandles + NCCL_NET_HANDLE_MAXSIZE * comm->rank, &listenComm), + ret, fail); + NCCLCHECKGOTO(bootstrapAllGather(comm->bootstrap, allHandles, NCCL_NET_HANDLE_MAXSIZE), ret, + fail); + NCCLCHECKGOTO(rmaProxyState->ncclGin->connect(comm->netContext, handles, comm->nRanks, comm->rank, + listenComm, rmaProxyState->ginComms + n), + ret, fail); + NCCLCHECKGOTO(rmaProxyState->ncclGin->getProperties(localNets[n], &rmaProxyState->props[n]), ret, fail); + NCCLCHECKGOTO(rmaProxyState->ncclGin->closeListen(listenComm), ret, fail); + } + free(handles); + handles = NULL; + free(allHandles); + allHandles = NULL; + + // Create virtual RMA proxy contexts + rmaProxyState->rmaProxyCtxCount = comm->config.numRmaCtx; + NCCLCHECK(ncclCalloc(&rmaProxyState->rmaProxyCtxs, rmaProxyState->rmaProxyCtxCount)); + NCCLCHECK(ncclCalloc(&rmaProxyState->rmaProxyDevHandles, rmaProxyState->rmaProxyCtxCount)); + for (int n = 0; n < rmaProxyState->rmaProxyCtxCount; n++) { + // Round-robin mapping to physical GIN communicator contexts + int ginCommIdx = n % rmaProxyState->ginCommCount; + NCCLCHECKGOTO(ncclRmaProxyCreateContext(comm, rmaProxyState->ginComms[ginCommIdx], rmaProxyState->props[ginCommIdx], + &rmaProxyState->rmaProxyCtxs[n], &rmaProxyState->rmaProxyDevHandles[n]), + ret, fail); + } + + // Check whether we need proxy progress and if so, start / wake up the progress thread. 
+ rmaProxyState->needsProxyProgress = 0; + for (int n = 0; n < rmaProxyState->rmaProxyCtxCount; n++) { + if (rmaProxyState->rmaProxyDevHandles[n]->needsProxyProgress) rmaProxyState->needsProxyProgress = 1; + } + if (rmaProxyState->needsProxyProgress) { + rmaProxyState->ginProgress = 1; + pthread_mutex_init(&rmaProxyState->threadLock, NULL); + pthread_cond_init(&rmaProxyState->threadCond, NULL); + PTHREADCHECK(pthread_create(&rmaProxyState->thread, NULL, ncclRmaProxyProgressThread, rmaProxyState), "pthread_create"); + ncclSetThreadName(rmaProxyState->thread, "NCCL RMA Proxy Progress%2d", comm->cudaDev); + } + + INFO(NCCL_INIT, "Rank %d ncclRmaProxyConnectOnce: ginCommCount %d rmaProxyCtxCount:%d needsProxyProgress %d", comm->rank, ginCommCount, rmaProxyState->rmaProxyCtxCount, rmaProxyState->needsProxyProgress); + +exit: + if (ret == ncclSuccess) rmaProxyState->connected = true; + return ret; +fail: + free(allCommCounts); + free(allHandles); + free(handles); + goto exit; +} + +ncclResult_t ncclRmaProxyFinalize(struct ncclComm* comm) { + struct ncclRmaProxyState* rmaProxyState = &comm->rmaState.rmaProxyState; + if (!rmaProxyState->connected) return ncclSuccess; + + if (rmaProxyState->needsProxyProgress) { + pthread_mutex_lock(&rmaProxyState->threadLock); + rmaProxyState->ginProgress = -1; + pthread_cond_signal(&rmaProxyState->threadCond); + pthread_mutex_unlock(&rmaProxyState->threadLock); + PTHREADCHECK(pthread_join(rmaProxyState->thread, NULL), "pthread_join"); + } + + // Destroy all virtual RMA proxy contexts + if (rmaProxyState->rmaProxyCtxs) { + for (int n = 0; n < rmaProxyState->rmaProxyCtxCount; n++) { + if (rmaProxyState->rmaProxyCtxs[n] != NULL) { + NCCLCHECK(ncclRmaProxyDestroyContext(rmaProxyState->ncclGin, rmaProxyState->rmaProxyCtxs[n])); + rmaProxyState->rmaProxyCtxs[n] = NULL; + } + } + // Free the dynamically allocated context array + free(rmaProxyState->rmaProxyCtxs); + rmaProxyState->rmaProxyCtxs = NULL; + } + + // Free the device handles array + if (rmaProxyState->rmaProxyDevHandles) { + free(rmaProxyState->rmaProxyDevHandles); + rmaProxyState->rmaProxyDevHandles = NULL; + } + + // Close all physical GIN communicators + for (int n = 0; n < rmaProxyState->ginCommCount; n++) { + if (rmaProxyState->ginComms[n] != NULL) { + NCCLCHECK(rmaProxyState->ncclGin->closeColl(rmaProxyState->ginComms[n])); + rmaProxyState->ginComms[n] = NULL; + } + } + + // Finalize the GIN instance + NCCLCHECK(rmaProxyState->ncclGin->finalize(rmaProxyState->ginInstance)); + memset(rmaProxyState, 0, sizeof(*rmaProxyState)); + return ncclSuccess; +} + +ncclResult_t ncclRmaPutProxy(struct ncclComm* comm, struct ncclKernelPlan* plan, cudaStream_t stream){ + ncclResult_t ret = ncclSuccess; + + // Make sure the RMA proxy is connected + if (!comm->rmaState.rmaProxyState.connected) { + WARN("RMA proxy is not connected"); + return ncclInternalError; + } + + int ctx = plan->rmaArgs->ctx; + int nRmaTasksProxy = plan->rmaArgs->nRmaTasksProxy; + struct ncclRmaProxyCtx * rmaProxyCtx = (struct ncclRmaProxyCtx *)comm->rmaState.rmaProxyState.rmaProxyCtxs[ctx]; + + // Allocate 2*nRmaTasksProxy CUstreamBatchMemOpParams + CUstreamBatchMemOpParams* batchParams = NULL; + NCCLCHECK(ncclCalloc(&batchParams, 2*nRmaTasksProxy)); + + for (int i = 0; i < nRmaTasksProxy; i++) { + struct ncclTaskRma* task = ncclIntruQueueHead(&plan->rmaTaskQueueProxy); + ncclIntruQueueDequeue(&plan->rmaTaskQueueProxy); + + assert(task->ctx == ctx); + + struct ncclRmaProxyDesc *desc = NULL; + NCCLCHECK(ncclCalloc(&desc, 1)); + desc->srcOff = 
task->srcWinOffset;
+    desc->srcHandle = ncclDevrGetRmaDevWin(task->srcWinHost, ctx);
+    desc->dstOff = task->peerWinOffset;
+    desc->dstHandle = ncclDevrGetRmaDevWin(task->peerWinHost, ctx);
+    desc->size = task->count * ncclTypeSize(task->datatype);
+    desc->targetRank = task->peer;
+    desc->seq = ++rmaProxyCtx->opSeqs[task->peer]; // start at 1 so the GEQ wait below is never satisfied by the initial counter value of 0
+    desc->rmaDescState = ncclRmaDescStatePending;
+    desc->request = NULL;
+
+    // If the signal mode is none, we do not need to set the signal operation
+    if (task->signalMode == NCCL_SIGNAL_NONE) {
+      desc->signal.op = 0;
+    }
+    // If the signal mode is aggregate, we use the shared counter to aggregate the signals
+    else if (task->signalMode == NCCL_SIGNAL_AGGREGATE) {
+      desc->signal.op = NCCL_NET_SIGNAL_OP_ADD;
+      desc->signal.offset = comm->nRanks * sizeof(uint64_t); // Shared aggregate signal counter
+      desc->signal.signalMhandle = rmaProxyCtx->signalsMhandle;
+      desc->signal.val = 1;
+    }
+    // If the signal mode is distinct, we use the per-rank signal for the target rank
+    else if (task->signalMode == NCCL_SIGNAL_DISTINCT) {
+      desc->signal.op = NCCL_NET_SIGNAL_OP_ADD;
+      desc->signal.offset = comm->rank * sizeof(uint64_t); // Write to our rank slot in peer's buffer
+      desc->signal.signalMhandle = rmaProxyCtx->signalsMhandle;
+      desc->signal.val = 1;
+    }
+
+    // Prepare the readySeq write operation (64-bit op, so set value64)
+    batchParams[i].writeValue.operation = CU_STREAM_MEM_OP_WRITE_VALUE_64;
+    batchParams[i].writeValue.address = (CUdeviceptr)&rmaProxyCtx->readySeqsDev[task->peer];
+    batchParams[i].writeValue.value64 = desc->seq;
+    batchParams[i].writeValue.flags = CU_STREAM_WRITE_VALUE_DEFAULT;
+
+    // Prepare the doneSeq wait operation (64-bit op, so set value64)
+    batchParams[i+nRmaTasksProxy].waitValue.operation = CU_STREAM_MEM_OP_WAIT_VALUE_64;
+    batchParams[i+nRmaTasksProxy].waitValue.address = (CUdeviceptr)&rmaProxyCtx->doneSeqsDev[task->peer];
+    batchParams[i+nRmaTasksProxy].waitValue.value64 = desc->seq;
+    batchParams[i+nRmaTasksProxy].waitValue.flags = CU_STREAM_WAIT_VALUE_GEQ;
+
+    INFO(NCCL_COLL, "ncclRmaPutProxy enqueued Desc: rank=%d peer=%d ctx=%d size=%ld signalMode=%d readySeq=%lu doneSeq=%lu",
+         comm->rank, task->peer, ctx, task->count * ncclTypeSize(task->datatype), task->signalMode, (uint64_t)desc->seq, (uint64_t)desc->seq);
+
+    // Enqueue the network function descriptor to the target rank's queue in the RMA context
+    // Use mutex to protect against concurrent dequeue from progress thread
+    pthread_mutex_lock(&rmaProxyCtx->DescQueueLocks[task->peer]); // lock has acquire semantics: the enqueue cannot move above it; earlier work may sink below, but not past the unlock (release)
+    ncclIntruQueueEnqueue(&rmaProxyCtx->rmaProxyDescQueues[task->peer], desc);
+    pthread_mutex_unlock(&rmaProxyCtx->DescQueueLocks[task->peer]); // unlock has release semantics: the enqueue cannot sink below it, though the stream memory ops submitted later may be reordered before the enqueue
+
+    // Free the task
+    ncclMemoryPoolFree(&comm->memPool_ncclTaskRma, task);
+  }
+
+  // Submit all readySeq writes and doneSeq waits in a single batch after all Descs are enqueued
+  // (a host-only sketch of this readySeq/doneSeq handshake is appended after the patch)
+  CUCHECKGOTO(cuStreamBatchMemOp(stream, 2*nRmaTasksProxy, batchParams, 0), ret, fail);
+
+exit:
+  if (batchParams) free(batchParams);
+  return ret;
+fail:
+  goto exit;
+}
+
+
+
+ncclResult_t ncclRmaWaitSignalProxy(struct ncclComm* comm, struct ncclKernelPlan* plan, cudaStream_t stream){
+  ncclResult_t ret = ncclSuccess;
+
+  // Make sure the RMA proxy is connected
+  if (!comm->rmaState.rmaProxyState.connected) {
+    WARN("RMA proxy is not connected");
+    return ncclInternalError;
+  }
+
+  int 
ctx = plan->rmaArgs->ctx; + struct ncclRmaProxyCtx* proxyCtx = (struct ncclRmaProxyCtx*)comm->rmaState.rmaProxyState.rmaProxyCtxs[ctx]; + + struct ncclTaskRma* task = ncclIntruQueueHead(&plan->rmaTaskQueueProxy); + ncclIntruQueueDequeue(&plan->rmaTaskQueueProxy); + + // Assert task func is ncclFuncWaitSignal + assert(task->func == ncclFuncWaitSignal); + // Assert task context is the same as the plan context + assert(task->ctx == ctx); + // Assert the plan has exactly one RMA proxy task + assert(plan->rmaArgs->nRmaTasksProxy == 1); + + size_t opIdx = 0; + CUstreamBatchMemOpParams* batchParams = nullptr; + + NCCLCHECK(ncclCalloc(&batchParams, task->npeers)); + + // Use per-rank signal for the target rank if the signal mode is distinct + if (task->signalMode == NCCL_SIGNAL_DISTINCT) { + for (int i = 0; i < task->npeers; i++) { + int peerRank = task->peers[i]; + // Calculate the expected signal value from this peer + uint64_t waitValue = proxyCtx->signalsHost[peerRank] + task->nsignals[i]; + + // Update our expectation for future waits + proxyCtx->signalsHost[peerRank] = waitValue; + + // Add wait operation to batch + batchParams[opIdx] = {}; + batchParams[opIdx].waitValue.operation = CU_STREAM_MEM_OP_WAIT_VALUE_64; + batchParams[opIdx].waitValue.address = (CUdeviceptr)&proxyCtx->signalsDev[peerRank]; + batchParams[opIdx].waitValue.value64 = waitValue; + batchParams[opIdx].waitValue.flags = CU_STREAM_WAIT_VALUE_GEQ; + opIdx++; + } + + // Execute all wait operations in a single batch + CUCHECKGOTO(cuStreamBatchMemOp(stream, opIdx, batchParams, 0), ret, fail); + } + // Use shared signal whenever possible if the signal mode is aggregate + else if (task->signalMode == NCCL_SIGNAL_AGGREGATE) { + uint64_t networkAggregateSignals = 0; + + // Process all peers + for (int i = 0; i < task->npeers; i++) { + // Network peers: accumulate signals for aggregate counter + networkAggregateSignals += task->nsignals[i]; + } + + // Add wait operation for network aggregate signal if we have network peers + if (networkAggregateSignals > 0) { + // Update host-side tracking of aggregate signal + uint64_t currentAggregateValue = proxyCtx->signalsHost[comm->nRanks]; + uint64_t expectedAggregateValue = currentAggregateValue + networkAggregateSignals; + proxyCtx->signalsHost[comm->nRanks] = expectedAggregateValue; + + // Add aggregate wait operation to batch + batchParams[opIdx] = {}; + batchParams[opIdx].waitValue.operation = CU_STREAM_MEM_OP_WAIT_VALUE_64; + batchParams[opIdx].waitValue.address = (CUdeviceptr)&proxyCtx->signalsDev[comm->nRanks]; + batchParams[opIdx].waitValue.value64 = expectedAggregateValue; + batchParams[opIdx].waitValue.flags = CU_STREAM_WAIT_VALUE_GEQ; + opIdx++; + } + + // Execute all wait operations in a single batch + CUCHECKGOTO(cuStreamBatchMemOp(stream, opIdx, batchParams, 0), ret, fail); + } + + // Free the task + ncclMemoryPoolFree(&comm->memPool_ncclTaskRma, task); + +exit: + if (batchParams) free(batchParams); + return ret; +fail: + goto exit; +} \ No newline at end of file
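
For reviewers unfamiliar with the CUDA stream memory operations used by the CE path: ncclRmaPutCe publishes a per-peer sequence number into the peer's signal slot with cuStreamWriteValue64, and ncclRmaWaitSignalCe batches GEQ waits against host-tracked expected values (signalOpSeqs on the writer, signalsHost on the waiter). The standalone sketch below is not part of the patch; it mirrors that write/wait pairing on a single device with two streams, on platforms where stream memory operations are supported, and the rank count, stream names and counters are illustrative assumptions.

#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

#define CUCHECK(cmd) do { CUresult r = (cmd); if (r != CUDA_SUCCESS) { \
  const char* s = NULL; cuGetErrorString(r, &s); \
  fprintf(stderr, "CU error %s at %s:%d\n", s ? s : "?", __FILE__, __LINE__); return 1; } } while (0)
#define CUDACHECK(cmd) do { cudaError_t e = (cmd); if (e != cudaSuccess) { \
  fprintf(stderr, "CUDA error %s at %s:%d\n", cudaGetErrorString(e), __FILE__, __LINE__); return 1; } } while (0)

int main() {
  CUDACHECK(cudaSetDevice(0));
  CUDACHECK(cudaFree(0)); // create the context so the driver-API calls below have one

  const int nRanks = 4;       // illustrative communicator size
  const int producerRank = 1; // illustrative peer that "puts" to us

  // One uint64_t signal slot per rank, like the per-rank part of ceCtx->signalsDev
  CUdeviceptr signalsDev;
  CUCHECK(cuMemAlloc(&signalsDev, nRanks * sizeof(uint64_t)));
  CUCHECK(cuMemsetD8(signalsDev, 0, nRanks * sizeof(uint64_t)));

  cudaStream_t producerStream, consumerStream;
  CUDACHECK(cudaStreamCreate(&producerStream));
  CUDACHECK(cudaStreamCreate(&consumerStream));

  uint64_t producerOpSeq = 0;    // plays the role of signalOpSeqs[peer] on the writer
  uint64_t consumerExpected = 0; // plays the role of signalsHost[peer] on the waiter

  for (int iter = 0; iter < 3; iter++) {
    // Writer: after the payload copy (omitted), publish the next absolute sequence number
    producerOpSeq++;
    CUCHECK(cuStreamWriteValue64(producerStream, signalsDev + producerRank * sizeof(uint64_t),
                                 producerOpSeq, CU_STREAM_WRITE_VALUE_DEFAULT));

    // Waiter: block until the slot reaches the expected value (GEQ, so late waits also pass)
    consumerExpected += 1;
    CUstreamBatchMemOpParams op = {};
    op.waitValue.operation = CU_STREAM_MEM_OP_WAIT_VALUE_64;
    op.waitValue.address = signalsDev + producerRank * sizeof(uint64_t);
    op.waitValue.value64 = consumerExpected;
    op.waitValue.flags = CU_STREAM_WAIT_VALUE_GEQ;
    CUCHECK(cuStreamBatchMemOp(consumerStream, 1, &op, 0));
  }

  CUDACHECK(cudaStreamSynchronize(consumerStream));
  CUDACHECK(cudaStreamSynchronize(producerStream));
  printf("waiter observed %llu signals\n", (unsigned long long)consumerExpected);

  CUCHECK(cuMemFree(signalsDev));
  CUDACHECK(cudaStreamDestroy(producerStream));
  CUDACHECK(cudaStreamDestroy(consumerStream));
  return 0;
}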
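
The proxy path couples the user stream and the progress thread through the readySeq/doneSeq counters from allocMemCPUAccessible: ncclRmaPutProxy batches a 64-bit write of readySeq and a GEQ wait on doneSeq, ncclRmaProxyPollDesc issues the network put once readySeq reaches the descriptor's sequence number, and ncclRmaProxyPollCompletion advances doneSeq when the put completes. Below is a minimal host-side emulation of that handshake, not part of the patch; it assumes stream memory operations may target mapped pinned memory (the same assumption the host fallback in this patch makes), mapped pinned memory stands in for the GDR/host counters, a std::thread stands in for the proxy progress thread, and the network operation is a no-op.

#include <cuda.h>
#include <cuda_runtime.h>
#include <atomic>
#include <thread>
#include <cstdint>
#include <cstdio>

#define CUCHECK(cmd) do { CUresult r = (cmd); if (r != CUDA_SUCCESS) { \
  fprintf(stderr, "CU error %d at %s:%d\n", (int)r, __FILE__, __LINE__); return 1; } } while (0)
#define CUDACHECK(cmd) do { cudaError_t e = (cmd); if (e != cudaSuccess) { \
  fprintf(stderr, "CUDA error %s at %s:%d\n", cudaGetErrorString(e), __FILE__, __LINE__); return 1; } } while (0)

int main() {
  CUDACHECK(cudaSetDevice(0));
  CUDACHECK(cudaFree(0));

  // CPU-visible counters, like the host fallback in allocMemCPUAccessible(): the GPU stream
  // writes/waits on the device pointer, the "proxy" thread polls the host pointer.
  uint64_t *readySeq = NULL, *doneSeq = NULL;
  CUDACHECK(cudaHostAlloc((void**)&readySeq, sizeof(uint64_t), cudaHostAllocMapped));
  CUDACHECK(cudaHostAlloc((void**)&doneSeq, sizeof(uint64_t), cudaHostAllocMapped));
  *readySeq = 0; *doneSeq = 0;
  uint64_t *readySeqDev = NULL, *doneSeqDev = NULL;
  CUDACHECK(cudaHostGetDevicePointer((void**)&readySeqDev, readySeq, 0));
  CUDACHECK(cudaHostGetDevicePointer((void**)&doneSeqDev, doneSeq, 0));

  std::atomic<bool> stop{false};
  std::thread proxy([&] {
    uint64_t issued = 0;
    while (!stop.load(std::memory_order_acquire)) {
      uint64_t ready = __atomic_load_n(readySeq, __ATOMIC_ACQUIRE); // published by the GPU stream
      while (issued < ready) {
        issued++;                                                   // stand-in for iput()/iputSignal() + test()
        __atomic_store_n(doneSeq, issued, __ATOMIC_RELEASE);        // completion: advance doneSeq
      }
      std::this_thread::yield();
    }
  });

  cudaStream_t stream;
  CUDACHECK(cudaStreamCreate(&stream));
  const uint64_t nPuts = 8;
  for (uint64_t seq = 1; seq <= nPuts; seq++) {
    // Publish readiness of descriptor `seq` (the write half of the batch in ncclRmaPutProxy)
    CUCHECK(cuStreamWriteValue64(stream, (CUdeviceptr)(uintptr_t)readySeqDev, seq, CU_STREAM_WRITE_VALUE_DEFAULT));
    // Block the stream until the proxy reports completion (the wait half)
    CUCHECK(cuStreamWaitValue64(stream, (CUdeviceptr)(uintptr_t)doneSeqDev, seq, CU_STREAM_WAIT_VALUE_GEQ));
  }
  CUDACHECK(cudaStreamSynchronize(stream));
  printf("stream observed doneSeq=%llu\n", (unsigned long long)*doneSeq);

  stop.store(true, std::memory_order_release);
  proxy.join();
  CUDACHECK(cudaStreamDestroy(stream));
  CUDACHECK(cudaFreeHost(readySeq));
  CUDACHECK(cudaFreeHost(doneSeq));
  return 0;
}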
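
ncclRmaProxyPollDesc and ncclRmaProxyPollCompletion keep two queues per peer: a mutex-protected pending queue fed by ncclRmaPutProxy, and an in-progress queue owned exclusively by the progress thread and drained strictly from the head. The sketch below is an illustrative host-only rendering of that discipline; std::deque and std::mutex stand in for the intrusive queues and DescQueueLocks, and the issue/test calls are reduced to setting a flag.

#include <atomic>
#include <cstdint>
#include <cstdio>
#include <deque>
#include <mutex>

struct Desc { uint64_t seq; bool done = false; };

std::mutex pendingLock;            // per-peer DescQueueLocks in the patch
std::deque<Desc*> pending;         // rmaProxyDescQueues: waiting for readySeq
std::deque<Desc*> inProgress;      // rmaProxyInProgressQueues: progress thread only
std::atomic<uint64_t> readySeq{0}; // written by the "GPU stream"
std::atomic<uint64_t> doneSeq{0};  // read back by the "GPU stream"

void progressOnce() {
  // Complete in-progress descriptors strictly from the head (FIFO)
  while (!inProgress.empty() && inProgress.front()->done) {
    Desc* d = inProgress.front(); inProgress.pop_front();
    doneSeq.store(d->seq, std::memory_order_release);
    delete d;
  }
  // Issue pending descriptors whose readySeq gate has been published
  while (true) {
    std::unique_lock<std::mutex> lk(pendingLock);
    if (pending.empty() || readySeq.load(std::memory_order_acquire) < pending.front()->seq) break;
    Desc* d = pending.front(); pending.pop_front();
    lk.unlock();
    d->done = true;               // stand-in for ncclGin->iput(); completion would come from test()
    inProgress.push_back(d);
  }
}

int main() {
  uint64_t seq = 0;
  for (int i = 0; i < 4; i++) {
    { std::lock_guard<std::mutex> lk(pendingLock); pending.push_back(new Desc{++seq}); }
    readySeq.store(seq, std::memory_order_release);   // the batched readySeq write in ncclRmaPutProxy
  }
  while (doneSeq.load(std::memory_order_acquire) < seq) progressOnce();  // the GEQ doneSeq wait
  printf("doneSeq=%llu\n", (unsigned long long)doneSeq.load());
  return 0;
}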
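
getDmaBufFd is what allows symmetric windows to be registered with the GIN plugin over dmabuf before falling back to regular registration. The sketch below shows the underlying driver call, cuMemGetHandleForAddressRange, with the same page-size alignment; it is illustrative only and assumes CUDA 11.7 or newer, a dmabuf-capable GPU and kernel driver, and a plain cudaMalloc allocation.

#include <cuda.h>
#include <cuda_runtime.h>
#include <unistd.h>
#include <cstdio>

int main() {
  cudaSetDevice(0);
  cudaFree(0); // make sure a CUDA context is current for the driver-API calls below

  int dmabufSupported = 0;
  cuDeviceGetAttribute(&dmabufSupported, CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED, 0);
  if (!dmabufSupported) { printf("dmabuf not supported on this GPU/driver\n"); return 0; }

  size_t size = 1 << 20;
  void* buf = NULL;
  cudaMalloc(&buf, size);

  // Align the range to the host page size, as getDmaBufFd() does
  size_t pageSize = (size_t)sysconf(_SC_PAGESIZE);
  size_t alignedSize = (size + pageSize - 1) & ~(pageSize - 1);

  int fd = -1;
  CUresult res = cuMemGetHandleForAddressRange(&fd, (CUdeviceptr)buf, alignedSize,
                                               CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0);
  if (res == CUDA_SUCCESS) {
    printf("got dmabuf fd %d for %p (%zu bytes)\n", fd, buf, alignedSize);
    // A network plugin would hand this fd to its dmabuf registration path
    // (regMrSymDmaBuf in this patch) and close it afterwards.
    close(fd);
  } else {
    printf("cuMemGetHandleForAddressRange failed (%d); fall back to regular registration\n", (int)res);
  }

  cudaFree(buf);
  return 0;
}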