Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ add_subdirectory(nccl_device)
add_subdirectory(ras)
add_subdirectory(scheduler)
add_subdirectory(gin)
add_subdirectory(rma)

add_compile_options(-fmacro-prefix-map=${CMAKE_CURRENT_SOURCE_DIR}/=)

Expand All @@ -55,6 +56,7 @@ list(APPEND LIBSRCFILES
${SCHEDULER_SOURCES}
${GIN_SOURCES}
${DOCA_SOURCES}
${RMA_SOURCES}
)

###################### Create a shared NCCL library ############################
Expand Down
1 change: 1 addition & 0 deletions src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ LIBSRCFILES := \
$(wildcard nccl_device/*.cc) \
$(wildcard scheduler/*.cc) \
$(wildcard gin/*.cc) \
$(wildcard rma/*.cc) \
$(filter-out ras/client.cc,$(wildcard ras/*.cc))
BINSRCFILES := ras/client.cc

Expand Down
44 changes: 44 additions & 0 deletions src/collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ const char* ncclFuncToString(ncclFunc_t fn) {
case ncclFuncScatter: return "Scatter";
case ncclFuncSendRecv: return "SendRecv";
case ncclFuncSend: return "Send";
case ncclFuncPut: return "Put";
case ncclFuncSignal: return "Signal";
case ncclFuncWaitSignal: return "WaitSignal";
default: return "Invalid";
}
}
Expand Down Expand Up @@ -214,3 +217,44 @@ ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int
1, 1 };
return ncclEnqueueCheck(&info);
}

NCCL_API(ncclResult_t, ncclPut, int ctx, const void* localbuff, size_t count, ncclDataType_t datatype,
int peer, size_t peerWinOffset, ncclWindow_t peerWin, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclPut(int ctx, const void* localbuff, size_t count, ncclDataType_t datatype,
int peer, size_t peerWinOffset, ncclWindow_t peerWin, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream) {
NVTX3_FUNC_WITH_PARAMS(Put, NcclNvtxParamsPut,
NVTX3_PAYLOAD(comm ? comm->commHash : 0, count * ncclTypeSize(datatype), peer, ctx));

struct ncclInfo info = { ncclFuncPut, "Put",
localbuff, NULL, count, datatype, ncclSum, peer, comm, stream, /* Args */
1, 1, /* chunkSteps, sliceSteps */
peerWinOffset, peerWin, signalMode, ctx, /* peerWinOffset, peerWin, signalMode, ctx */
NULL, NULL, 0 }; /* peers, nsignals, npeers */
return ncclEnqueueCheck(&info);
}

NCCL_API(ncclResult_t, ncclSignal, int ctx, int peer, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclSignal(int ctx, int peer, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream) {
NVTX3_FUNC_WITH_PARAMS(Signal, NcclNvtxParamsSignal,
NVTX3_PAYLOAD(comm ? comm->commHash : 0, peer, ctx));

struct ncclInfo info = { ncclFuncSignal, "Signal",
NULL, NULL, 0, ncclInt8, ncclSum, peer, comm, stream, /* Args */
1, 1, /* chunkSteps, sliceSteps */
0, NULL, signalMode, ctx, /* peerWinOffset, peerWin, signalMode, ctx */
NULL, NULL, 0 }; /* peers, nsignals, npeers */
return ncclEnqueueCheck(&info);
}

NCCL_API(ncclResult_t, ncclWaitSignal, int ctx, int* peers, int* nsignals, int npeers, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream);
ncclResult_t ncclWaitSignal(int ctx, int* peers, int* nsignals, int npeers, ncclSignalMode_t signalMode, ncclComm_t comm, cudaStream_t stream) {
NVTX3_FUNC_WITH_PARAMS(WaitSignal, NcclNvtxParamsWaitSignal,
NVTX3_PAYLOAD(comm ? comm->commHash : 0, npeers, ctx));

struct ncclInfo info = { ncclFuncWaitSignal, "WaitSignal",
NULL, NULL, 0, ncclInt32, ncclSum, 0, comm, stream, /* Args */
1, 1, /* chunkSteps, sliceSteps */
0, NULL, signalMode, ctx, /* peerWinOffset, peerWin, signalMode, ctx */
peers, nsignals, npeers }; /* peers, nsignals, npeers */
return ncclEnqueueCheck(&info);
}
31 changes: 30 additions & 1 deletion src/dev_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "dev_runtime.h"
#include "comm.h"
#include "rma/rma.h"
#include "device.h"
#include "transport.h"
#include "group.h"
Expand All @@ -23,6 +24,8 @@ struct ncclDevrMemory {
size_t bigOffset; // offset in big VA space
void* ginHostWins[NCCL_GIN_MAX_CONTEXTS];
ncclGinWindow_t ginDevWins[NCCL_GIN_MAX_CONTEXTS];
void* rmaHostWins[NCCL_GIN_MAX_CONTEXTS];
ncclGinWindow_t rmaDevWins[NCCL_GIN_MAX_CONTEXTS];
};

struct ncclDevrWindowSorted {
Expand Down Expand Up @@ -354,6 +357,12 @@ static ncclResult_t symMemoryRegisterGin(struct ncclComm* comm, struct ncclDevrM
return ncclSuccess;
}

static ncclResult_t symMemoryRegisterRma(struct ncclComm* comm, struct ncclDevrMemory* mem) {
NCCLCHECK(ncclRmaProxyConnectOnce(comm));
NCCLCHECK(ncclRmaProxyRegister(comm, mem->primaryAddr, mem->size, mem->rmaHostWins, mem->rmaDevWins));
return ncclSuccess;
}

// On success we take caller's reference on memHandle.
// Due to multicast binds for each pre-exiting team, this function requires
// caller do a world barrier before returning to user.
Expand Down Expand Up @@ -402,6 +411,13 @@ static ncclResult_t symMemoryObtain(
NCCLCHECKGOTO(symMemoryRegisterGin(comm, mem), ret, fail_mem_space_teams);
}

// ginEnabled is set in ncclDevrCommCreateInternal, which might not be called for RMA proxy
// so we introduce rmaProxyEnabled to track if RMA proxy is enabled
devr->rmaProxyEnabled = comm->nNodes > 1 && comm->config.numRmaCtx > 0 && ncclParamGinType() == NCCL_NET_DEVICE_GIN_PROXY;
if (devr->rmaProxyEnabled) {
NCCLCHECKGOTO(symMemoryRegisterRma(comm, mem), ret, fail_mem_space_teams);
}

// Add to list of mems.
mem->next = devr->memHead;
devr->memHead = mem;
Expand Down Expand Up @@ -431,6 +447,9 @@ static void symMemoryDropRef(
if (devr->ginEnabled) {
ncclGinDeregister(comm, mem->ginHostWins);
}
if (devr->rmaProxyEnabled) {
ncclRmaProxyDeregister(comm, mem->rmaHostWins);
}
for (struct ncclDevrTeam* t = devr->teamHead; t != nullptr; t = t->next) {
symUnbindTeamMemory(comm, t, mem);
}
Expand Down Expand Up @@ -1020,7 +1039,6 @@ ncclResult_t ncclDevCommDestroy(
return ncclSuccess;
}


// Get the corresponding pointer in another lsa rank's symmetric memory window
ncclResult_t ncclDevrGetLsaRankPtr(struct ncclComm* comm, struct ncclDevrWindow* winHost, size_t offset, int lsaRank, void** outPtr) {
if (winHost == nullptr || outPtr == nullptr) {
Expand All @@ -1044,6 +1062,17 @@ ncclResult_t ncclDevrGetLsaRankPtr(struct ncclComm* comm, struct ncclDevrWindow*
return ncclSuccess;
}

// Get the RMA device window handle for a specific context
ncclGinWindow_t ncclDevrGetRmaDevWin(struct ncclDevrWindow* winHost, int ctx) {
if (winHost == nullptr || winHost->memory == nullptr) {
return nullptr;
}
if (ctx < 0 || ctx >= NCCL_GIN_MAX_CONTEXTS) {
return nullptr;
}
return winHost->memory->rmaDevWins[ctx];
}

// Get the multicast address for a given team
ncclResult_t ncclDevrGetLsaTeamPtrMC(struct ncclComm* comm, struct ncclDevrWindow* winHost, size_t offset, struct ncclTeam lsaTeam, void** outPtr){
if (winHost == nullptr || outPtr == nullptr) {
Expand Down
139 changes: 134 additions & 5 deletions src/enqueue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "ce_coll.h"
#include "nvtx.h"
#include "scheduler.h"
#include "rma/rma.h"

#include <cstring> // std::memcpy
#include <cinttypes> // PRIx64
Expand Down Expand Up @@ -1149,7 +1150,7 @@ namespace {
}

static ncclResult_t uploadWork(struct ncclComm* comm, struct ncclKernelPlan* plan) {
if (plan->isSymColl || plan->isCeColl) return ncclSuccess;
if (plan->isSymColl || plan->isCeColl || plan->isRma) return ncclSuccess;

size_t workBytes = plan->workBytes;
size_t batchBytes = plan->nWorkBatches*sizeof(struct ncclDevWorkBatch);
Expand Down Expand Up @@ -1423,7 +1424,8 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) {

if (planner->nTasksColl + planner->nTasksP2p != 0 ||
!ncclIntruQueueEmpty(&planner->collSymTaskQueue) ||
!ncclIntruQueueEmpty(&planner->collCeTaskQueue)) {
!ncclIntruQueueEmpty(&planner->collCeTaskQueue) ||
planner->nTasksRma != 0) {
do {
memset(&planner->wipPlan, 0, sizeof(planner->wipPlan));

Expand All @@ -1435,7 +1437,13 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) {
plan->workStorageType = persistent ? ncclDevWorkStorageTypePersistent
: ncclDevWorkStorageTypeFifo;

if (!ncclIntruQueueEmpty(&planner->collCeTaskQueue)) {
if (planner->nTasksRma != 0) {
NCCLCHECKGOTO(scheduleRmaTasksToPlan(comm, plan), result, failure);
if (plan->isRma && plan->rmaArgs != NULL && plan->rmaArgs->nRmaTasks > 0) {
ncclIntruQueueEnqueue(&planner->planQueue, plan);
nPlans += 1;
}
} else if (!ncclIntruQueueEmpty(&planner->collCeTaskQueue)) {
struct ncclTaskColl* task = ncclIntruQueueHead(&planner->collCeTaskQueue);
plan->isCeColl = true;
plan->ceCollArgs = ncclMemoryStackAlloc<struct ncclCeCollArgs>(&comm->memScoped);
Expand All @@ -1453,7 +1461,7 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) {
ncclMemoryPoolFree(&comm->memPool_ncclTaskColl, task);
nPlans += 1;
} else {
if (!ncclIntruQueueEmpty(&planner->collSymTaskQueue)) {
if (!ncclIntruQueueEmpty(&planner->collSymTaskQueue)) {
NCCLCHECKGOTO(ncclSymmetricTaskScheduler(comm, &planner->collSymTaskQueue, plan), result, failure);
}
else {
Expand Down Expand Up @@ -1483,7 +1491,8 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) {
}
} while (planner->nTasksColl + planner->nTasksP2p != 0 ||
!ncclIntruQueueEmpty(&planner->collSymTaskQueue) ||
!ncclIntruQueueEmpty(&planner->collCeTaskQueue));
!ncclIntruQueueEmpty(&planner->collCeTaskQueue) ||
planner->nTasksRma != 0);

struct ncclKernelPlan* planHead = ncclIntruQueueHead(&planner->planQueue);
planner->unlaunchedPlansHead = planHead;
Expand Down Expand Up @@ -2542,6 +2551,124 @@ static ncclResult_t ceCollTaskAppend(
return ncclSuccess;
}

static ncclResult_t rmaTaskAppend(
struct ncclComm* comm,
struct ncclInfo* info) {
struct ncclKernelPlanner *planner = &comm->planner;

void const* srcBuff = info->sendbuff;

if (!comm->symmetricSupport){
WARN("ncclPut: symmetric support is not enabled");
return ncclInvalidArgument;
}

// Check if user context is valid
if (info->ctx < 0 || info->ctx >= comm->config.numRmaCtx) {
WARN("User context index %d out of bounds (min: 0, max: %d)", info->ctx, comm->config.numRmaCtx - 1);
return ncclInvalidArgument;
}

// Initialize window pointers - only needed for Put and Signal
struct ncclDevrWindow* peerWinHost = NULL;
struct ncclDevrWindow* srcWinHost = NULL;
size_t srcWinOffset = 0;

if (info->coll == ncclFuncPut) {
// Validate peer window with detailed debugging
if (info->peerWin == NULL) {
WARN("ncclPut: peerWin is NULL");
return ncclInvalidArgument;
}

struct ncclWindow_vidmem* peerWinDevHost = NULL;
NCCLCHECK(ncclShadowPoolToHost(&comm->devrState.shadows, info->peerWin, &peerWinDevHost));
peerWinHost = (struct ncclDevrWindow*)peerWinDevHost->winHost;

// Validate source buffer and window
if (srcBuff == NULL) {
WARN("ncclPut: srcBuff is NULL");
return ncclInvalidArgument;
}
NCCLCHECK(ncclDevrFindWindow(comm, srcBuff, &srcWinHost));
if (srcWinHost == NULL || !(srcWinHost->winFlags & NCCL_WIN_COLL_SYMMETRIC)) {
WARN("ncclPut: srcWinHost is not in a valid symmetric window");
return ncclInvalidArgument;
}
srcWinOffset = (char*)srcBuff - (char*)srcWinHost->userPtr;
}
else if (info->coll == ncclFuncSignal) {
// Check if count is valid
if (info->count != 0) {
WARN("ncclSignal: count must be 0");
return ncclInvalidArgument;
}
// Check if signalMode is valid
if (info->signalMode == NCCL_SIGNAL_NONE) {
WARN("ncclSignal: signalMode is none");
return ncclInvalidArgument;
}
}
else if (info->coll == ncclFuncWaitSignal) {
// Check if signalMode, peers and nsignals are valid
if (info->signalMode == NCCL_SIGNAL_NONE || info->peers == NULL || info->nsignals == NULL || info->npeers == 0) {
WARN("ncclWaitSignal: invalid arguments");
return ncclInvalidArgument;
}
}

// Check if RMA CE needs initialization
if (!comm->rmaState.rmaCeState.initialized && ncclIntruQueueEmpty(&comm->rmaCeInitTaskQueue)) {
struct ncclRmaCeInitTask* ceTask;
NCCLCHECK(ncclCalloc(&ceTask, 1));
ceTask->comm = comm;
ncclIntruQueueEnqueue(&comm->rmaCeInitTaskQueue, ceTask);
ncclGroupCommJoin(comm, ncclGroupTaskTypeSymRegister);
}

// Must be in thread local group before tasks can be alloc'd in `comm->memScoped`.
ncclGroupCommJoin(info->comm, ncclGroupTaskTypeCollective);
NCCLCHECK(ncclPlannerSetCapturingGraph(comm, info));
struct ncclTaskRma* t = ncclMemoryPoolAlloc<struct ncclTaskRma>(&comm->memPool_ncclTaskRma, &comm->memPermanent);

t->func = info->coll;
t->srcBuff = srcBuff;
t->srcWinOffset = srcWinOffset;
t->srcWinHost = srcWinHost;
t->count = info->count;
t->datatype = info->datatype;
t->bytes = t->count * ncclTypeSize(t->datatype);
t->ctx = info->ctx;
t->peer = info->root;
t->peerWinOffset = info->peerWinOffset;
t->peerWinHost = peerWinHost;
t->signalMode = info->signalMode;

// Copy the peers and nsignals arrays if present
if (info->peers != NULL && info->nsignals != NULL && info->npeers > 0) {
int* peersCopy = ncclMemoryStackAlloc<int>(&comm->memScoped, info->npeers);
int* nsignalsCopy = ncclMemoryStackAlloc<int>(&comm->memScoped, info->npeers);
for (int i = 0; i < info->npeers; i++) {
peersCopy[i] = info->peers[i];
nsignalsCopy[i] = info->nsignals[i];
}
t->peers = peersCopy;
t->nsignals = nsignalsCopy;
} else {
t->peers = info->peers;
t->nsignals = info->nsignals;
}
t->npeers = info->npeers;

t->eActivationMask = __atomic_load_n(&ncclProfilerEventMask, __ATOMIC_RELAXED);

planner->nTasksRma++;
// Enqueue the task into the appropriate context queue
ncclIntruQueueEnqueue(&planner->rmaTaskQueues[t->ctx], t);

return ncclSuccess;
}

// Converts `info` to a task and adds it to `comm->planner`. The exception is with
// single rank communicators, collectives are issued as `ncclMemcpyAsync`s and
// thus don't need a task.
Expand All @@ -2550,6 +2677,8 @@ static ncclResult_t taskAppend(struct ncclComm* comm, struct ncclInfo* info) {

if (info->coll == ncclFuncSend || info->coll == ncclFuncRecv) {
NCCLCHECK(p2pTaskAppend(comm, info, info->coll, collAPI, (void*)info->recvbuff, info->count, info->datatype, info->root));
} else if (info->coll == ncclFuncPut || info->coll == ncclFuncSignal || info->coll == ncclFuncWaitSignal) {
NCCLCHECK(rmaTaskAppend(comm, info));
} else {
// Empty collectives can be discarded.
if (info->count == 0) return ncclSuccess;
Expand Down
Loading