diff --git a/CMakeLists.txt b/CMakeLists.txt index 892968cce3..87fd9a0b2f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -439,6 +439,7 @@ set(SRC_FILES src/bootstrap.cc src/channel.cc src/collectives.cc + src/commDump.cc src/debug.cc src/enqueue.cc src/group.cc diff --git a/src/commDump.cc b/src/commDump.cc new file mode 100644 index 0000000000..4d35502599 --- /dev/null +++ b/src/commDump.cc @@ -0,0 +1,26 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include "nccl.h" +#include +#include "comm.h" +#include "device.h" +#include "archinfo.h" + +__attribute__ ((visibility("default"))) +ncclResult_t ncclCommDump( + const ncclComm_t comm, + std::unordered_map& map) { + if (comm == nullptr) { + WARN("ncclCommDump comm is null"); + return ncclSuccess; + } + if (comm->proxyState->proxyTrace == nullptr) { + WARN("ncclCommDump comm->proxyState->proxyTrace is null"); + return ncclSuccess; + } + + WARN("ncclCommDump() ProxyTrace:"); + WARN("%s", comm->proxyState->proxyTrace->dump().c_str()); + + return ncclSuccess; +} diff --git a/src/include/proxy_trace/proxy_trace.h b/src/include/proxy_trace/proxy_trace.h index 0a1bfd8bd4..fdfcd20e53 100644 --- a/src/include/proxy_trace/proxy_trace.h +++ b/src/include/proxy_trace/proxy_trace.h @@ -15,6 +15,7 @@ #endif #include #include +#include #include namespace facebook_rccl { @@ -107,6 +108,10 @@ struct ProxyTraceOp { ProxyOpStepStatus status{ProxyOpStepStatus::UNINITIALIZED}; std::chrono::time_point startTs{}; std::chrono::time_point lastUpdateTs{}; + std::unordered_map> timestamps{ + {ProxyCounterTypes::POSTED, {}}, + {ProxyCounterTypes::KERNEL_COPY_READY, {}}, + }; void computeStatus(); // str the entry to a string std::string str(); @@ -123,11 +128,51 @@ using ProxyActiveOpIdTracker = std::unordered_map>; class ProxyTrace { -public: - ProxyTrace(int32_t rank) : myRank(rank) {} + public: + ProxyTrace(int32_t rank) : myRank(rank) {}; + + ProxyTrace() = delete; ProxyTrace(const ProxyTrace &) = delete; ProxyTrace &operator=(const ProxyTrace &) = delete; - bool initialized{false}; + + // + // Public APIs called by the proxy thread and ncclCommDump(). + // All these APIs lock the same shared mutex before executing. + // + + void updateProxyOpCounter( + const ProxyTraceRecordKey& traceKey, + ProxyCounterTypes counter, + int64_t val); + + void setProxyOpTimestamp( + const ProxyTraceRecordKey& traceKey, + ProxyCounterTypes counter); + + void addNewProxyOp( + ProxyTraceRecordKey& key, + const ProxyTraceExtraInfo& extraInfo, + ProxyOpType opType, + int channelId, + int nSteps, + uint32_t nbytes, + int peerRank); + + // Dump all trace for a given communicator + std::string dump(uint64_t commHash); + + // Dump all active ops + std::string dump(); + + // + // Getters called by public APIs as well as unit tests. + // These are not thread-safe unless called by the public APIs above. + // + + ProxyTraceOp *getProxyTraceOpPtr(const ProxyTraceRecordKey &traceKey); + float getMapSizeMB() const; + +private: void checkOpCompleted(const ProxyTraceRecordKey &key); void addNewProxyTraceOpImpl(const ProxyTraceRecordKey &key, @@ -139,21 +184,11 @@ class ProxyTrace { // If the opCount is not found, create a new entry for it and return 0 int64_t getOrCreateProxyOpId(uint64_t commHash, uint64_t opCount); - // Dump all trace for a given communicator - std::string dump(uint64_t commHash); - - // Dump all active ops - std::string dump(); - // check if an active send/recv operation exists for a given commHash:opCount bool checkActiveOpExist(uint64_t commHash, uint64_t opCount, uint32_t proxyOpId) const; - ProxyTraceOp *getProxyTraceOpPtr(const ProxyTraceRecordKey &traceKey); - float getMapSizeMB() const; - void resetAll(); - -private: + mutable std::mutex mutex_; int myRank{-1}; // Current active send/recv operations. @@ -170,15 +205,4 @@ class ProxyTrace { std::deque> finishedOps; }; struct ncclProxySubArgs; -void proxyTraceInit(std::unique_ptr &proxyTrace, - int32_t rank, uint64_t commHash); - -void updateProxyOpCounter(std::unique_ptr &proxyTraceObj, - const ProxyTraceRecordKey &traceKey, - ProxyCounterTypes counter, int64_t val); - -void addNewProxyOp( - std::unique_ptr &proxyTraceObj, ProxyTraceRecordKey &key, - const ProxyTraceExtraInfo &extraInfo, ProxyOpType opType, int channelId, - int nSteps, uint32_t nbytes, int peerRank); } // namespace facebook_rccl diff --git a/src/init.cc b/src/init.cc index 5ef4c40481..adfbac8cd6 100644 --- a/src/init.cc +++ b/src/init.cc @@ -439,7 +439,7 @@ static ncclResult_t commFree(ncclComm_t comm) { free(comm->connectRecv); if (rcclParamEnableProxyTrace()) { - WARN("ProxyTrace:"); + WARN("commFree() ProxyTrace:"); if (comm->proxyState && comm->proxyState->proxyTrace){ WARN("%s", comm->proxyState->proxyTrace->dump().c_str()); } diff --git a/src/misc/proxy_trace/proxy_trace.cc b/src/misc/proxy_trace/proxy_trace.cc index c93b961853..ea4be74fe6 100644 --- a/src/misc/proxy_trace/proxy_trace.cc +++ b/src/misc/proxy_trace/proxy_trace.cc @@ -24,13 +24,6 @@ static std::unordered_map {facebook_rccl::ProxyOpStepStatus::UNINITIALIZED, "ILLEGAL"}, }; -void facebook_rccl::ProxyTrace::resetAll() { - activeOps.clear(); - activeOpIdTracker.clear(); - myRank = -1; - initialized = false; -} - bool facebook_rccl::ProxyTrace::checkActiveOpExist(uint64_t commHash, uint64_t opCount, uint32_t proxyOpId) const { @@ -137,6 +130,7 @@ void facebook_rccl::ProxyTraceOp::computeStatus() { } std::string facebook_rccl::ProxyTrace::dump(uint64_t commHash) { + std::lock_guard lock(mutex_); std::string result = fmt::format("commDump for commHash:{}\n", commHash); std::map sortedDumpStrMap; for (auto &opCountMap : activeOps.at(commHash)) { @@ -154,6 +148,7 @@ std::string facebook_rccl::ProxyTrace::dump(uint64_t commHash) { } std::string facebook_rccl::ProxyTrace::dump() { + std::lock_guard lock(mutex_); std::string result = "commDump for all active ops "; result += fmt::format("mapSizeMB:{:.2f}\n", getMapSizeMB()); @@ -182,7 +177,7 @@ std::string facebook_rccl::ProxyTrace::dump() { std::string facebook_rccl::ProxyTraceOp::str() { computeStatus(); std::string ret = fmt::format( - "createT:{}, lastT:{}, cntNm:{}, {}, {}, {}->{}({}), " + "createT:{}, lastT:{}, postT:{}, sendT:{}, cntNm:{}, {}, {}, {}->{}({}), " "chan:{}, status:{}, ns:{}, nb:{}, po:{}, ke:{}, tail/h:{}, recvT:{}, " "connSz/h:{}, trans:{}, flushed:{}, recvd:{}, done:{}\n", std::chrono::duration_cast( @@ -191,6 +186,12 @@ std::string facebook_rccl::ProxyTraceOp::str() { std::chrono::duration_cast( lastUpdateTs.time_since_epoch()) .count(), + std::chrono::duration_cast( + timestamps[facebook_rccl::ProxyCounterTypes::POSTED].time_since_epoch()) + .count(), + std::chrono::duration_cast( + timestamps[facebook_rccl::ProxyCounterTypes::KERNEL_COPY_READY].time_since_epoch()) + .count(), static_cast(lastUpdatingCounter), traceKey.str(), extraInfo.str(), myRank, peerRank, opType == ProxyOpType::SEND ? "S" : "R", channelId, proxyStepStatusStrMap[status], nSteps, nbytes, @@ -220,44 +221,43 @@ float facebook_rccl::ProxyTrace::getMapSizeMB() const { return size / 1024.0 / 1024.0; } -void facebook_rccl::proxyTraceInit(std::unique_ptr &proxyTrace, - int32_t rank, uint64_t commHash) { - if (proxyTrace) { - WARN("[proxyTrace] Initializing non-empty proxyTrace! rank: %d, commHash: " - "%lu", - rank, commHash); - return; +void facebook_rccl::ProxyTrace::updateProxyOpCounter( + const ProxyTraceRecordKey& traceKey, + ProxyCounterTypes counter, + int64_t val) { + std::lock_guard lock(mutex_); + auto traceOpPtr = getProxyTraceOpPtr(traceKey); + if (traceOpPtr) { + traceOpPtr->counters[counter] = val; + traceOpPtr->lastUpdateTs = std::chrono::high_resolution_clock::now(); + traceOpPtr->lastUpdatingCounter = counter; + checkOpCompleted(traceKey); } - INFO(NCCL_PROXY, "Initializing ProxyTrace, rank: %d, commHash: %lu", rank, - commHash); - proxyTrace = std::make_unique(rank); - proxyTrace->initialized = true; } -void facebook_rccl::updateProxyOpCounter( - std::unique_ptr &proxyTraceObj, - const ProxyTraceRecordKey &traceKey, ProxyCounterTypes counter, - int64_t val) { - if (proxyTraceObj) { - auto traceOpPtr = proxyTraceObj->getProxyTraceOpPtr(traceKey); - if (traceOpPtr) { - traceOpPtr->counters[counter] = val; - traceOpPtr->lastUpdateTs = std::chrono::high_resolution_clock::now(); - traceOpPtr->lastUpdatingCounter = counter; - proxyTraceObj->checkOpCompleted(traceKey); - } +void facebook_rccl::ProxyTrace::setProxyOpTimestamp( + const ProxyTraceRecordKey& traceKey, + ProxyCounterTypes counter) { + std::lock_guard lock(mutex_); + auto traceOpPtr = getProxyTraceOpPtr(traceKey); + if (!traceOpPtr || traceOpPtr->timestamps.find(counter) == traceOpPtr->timestamps.end()) { + return; } + + traceOpPtr->timestamps[counter] = std::chrono::high_resolution_clock::now(); } -void facebook_rccl::addNewProxyOp(std::unique_ptr &proxyTraceObj, - ProxyTraceRecordKey &key, - const ProxyTraceExtraInfo &extraInfo, - ProxyOpType opType, int channelId, int nSteps, - uint32_t nbytes, int peerRank) { - if (proxyTraceObj) { - auto opId = proxyTraceObj->getOrCreateProxyOpId(key.commHash, key.opCount); - key.proxyOpId = opId; - proxyTraceObj->addNewProxyTraceOpImpl(key, extraInfo, opType, channelId, - nSteps, nbytes, peerRank); - } +void facebook_rccl::ProxyTrace::addNewProxyOp( + ProxyTraceRecordKey& key, + const ProxyTraceExtraInfo& extraInfo, + ProxyOpType opType, + int channelId, + int nSteps, + uint32_t nbytes, + int peerRank) { + std::lock_guard lock(mutex_); + auto opId = getOrCreateProxyOpId(key.commHash, key.opCount); + key.proxyOpId = opId; + addNewProxyTraceOpImpl( + key, extraInfo, opType, channelId, nSteps, nbytes, peerRank); } diff --git a/src/nccl.h.in b/src/nccl.h.in index 88c4c24344..a5f571cd06 100644 --- a/src/nccl.h.in +++ b/src/nccl.h.in @@ -948,4 +948,16 @@ ncclResult_t pncclGroupSimulateEnd(ncclSimInfo_t* simInfo); } // end extern "C" #endif +#ifdef __cplusplus +#define NCCL_COMM_DUMP + +#include +#include +/* Dump NCCL current internal state for a given communicator in a key-value store format. + * define outside extern "C"{} to pass C++ template */ +ncclResult_t ncclCommDump(ncclComm_t comm, std::unordered_map& map); +#else +#warning "NCCL C++ API is disabled because C compiler is being used. Please use a C++ compiler to build NCCL." +#endif + #endif // end include guard diff --git a/src/proxy.cc b/src/proxy.cc index f70e393e56..2ba97e203c 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -1873,7 +1873,8 @@ ncclResult_t ncclProxyInit(struct ncclComm* comm, struct ncclSocket* sock, union comm->proxyState->peerAddresses = peerAddresses; comm->proxyState->peerAddressesUDS = peerAddressesUDS; if (rcclParamEnableProxyTrace()) { - facebook_rccl::proxyTraceInit(comm->proxyState->proxyTrace, comm->rank, comm->commHash); + INFO(NCCL_PROXY, "Initializing ProxyTrace, rank: %d, commHash: %lu", comm->rank, comm->commHash); + comm->proxyState->proxyTrace = std::make_unique(comm->rank); } // UDS support diff --git a/src/transport/net.cc b/src/transport/net.cc index 1ad8de99fa..4e24d5985d 100644 --- a/src/transport/net.cc +++ b/src/transport/net.cc @@ -1270,9 +1270,16 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct resources->step = sub->base + sub->nsteps; sub->posted = sub->transmitted = sub->done = 0; ncclProfilerRecordProxyOpEventState(s, args, ncclProfilerProxyOpInProgress_v4); - facebook_rccl::addNewProxyOp(proxyState->proxyTrace, sub->traceKey, - sub->traceInfo, facebook_rccl::ProxyOpType::SEND, - sub->channelId, sub->nsteps, sub->nbytes, sub->peer); + if (proxyState->proxyTrace) { + proxyState->proxyTrace->addNewProxyOp( + sub->traceKey, + sub->traceInfo, + facebook_rccl::ProxyOpType::SEND, + sub->channelId, + sub->nsteps, + sub->nbytes, + sub->peer); + } if (!sub->reg) sub->sendMhandle = resources->mhandles[args->protocol]; } @@ -1313,7 +1320,10 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct sub->posted += args->sliceSteps; } ncclProfilerRecordProxyStepEventState(s, args, postedStepId, ncclProfilerProxyStepSendGPUWait); - facebook_rccl::updateProxyOpCounter(proxyState->proxyTrace, sub->traceKey, facebook_rccl::ProxyCounterTypes::POSTED, sub->posted); + if (proxyState->proxyTrace) { + proxyState->proxyTrace->updateProxyOpCounter(sub->traceKey, facebook_rccl::ProxyCounterTypes::POSTED, sub->posted); + proxyState->proxyTrace->setProxyOpTimestamp(sub->traceKey, facebook_rccl::ProxyCounterTypes::POSTED); + } args->idle = 0; continue; } @@ -1322,15 +1332,14 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct int buffSlot = (sub->base+sub->transmitted)%NCCL_STEPS; volatile uint64_t* recvTail = &resources->recvMem->tail; uint64_t tail = sub->base + sub->transmitted; - facebook_rccl::updateProxyOpCounter(proxyState->proxyTrace, sub->traceKey, - facebook_rccl::ProxyCounterTypes::RECV_TAIL, *recvTail); - - facebook_rccl::updateProxyOpCounter(proxyState->proxyTrace, sub->traceKey, - facebook_rccl::ProxyCounterTypes::TAIL_OR_HEAD, tail); - - facebook_rccl::updateProxyOpCounter( - proxyState->proxyTrace, sub->traceKey, - facebook_rccl::ProxyCounterTypes::FIFO_SZ_OR_HEAD_CACHE, connFifo[buffSlot].size); + if (proxyState->proxyTrace) { + proxyState->proxyTrace->updateProxyOpCounter(sub->traceKey, + facebook_rccl::ProxyCounterTypes::RECV_TAIL, *recvTail); + proxyState->proxyTrace->updateProxyOpCounter(sub->traceKey, + facebook_rccl::ProxyCounterTypes::TAIL_OR_HEAD, tail); + proxyState->proxyTrace->updateProxyOpCounter(sub->traceKey, + facebook_rccl::ProxyCounterTypes::FIFO_SZ_OR_HEAD_CACHE, connFifo[buffSlot].size); + } if (connFifo[buffSlot].size != -1 && (*recvTail > tail || p == NCCL_PROTO_LL)) { // We have something to receive, let's check if it's completely ready. @@ -1379,8 +1388,10 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct *resources->curr_hdp_reg = 1; } ncclProfilerRecordProxyStepEventState(s, args, transmittedStepId, ncclProfilerProxyStepSendPeerWait_v4); - facebook_rccl::updateProxyOpCounter(proxyState->proxyTrace, sub->traceKey, - facebook_rccl::ProxyCounterTypes::KERNEL_COPY_READY, sub->reg ? 1: sub->transmitted + args->sliceSteps); + if (proxyState->proxyTrace) { + proxyState->proxyTrace->updateProxyOpCounter(sub->traceKey, + facebook_rccl::ProxyCounterTypes::KERNEL_COPY_READY, sub->reg ? 1: sub->transmitted + args->sliceSteps); + } // Data is ready, try to send. // Coverity complains about the size here as pointing to an out-of-scope temporary. Which is nonsense, // since size is a plain integer. @@ -1394,6 +1405,9 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct if (ignoreCompletion) *requestPtr = (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION; NCCLCHECK(proxyState->ncclNet->isend(resources->netSendComm, buff, size, resources->tpRank, sub->sendMhandle, phandle, requestPtr)); if (*requestPtr != NULL) { + if (proxyState->proxyTrace) { + proxyState->proxyTrace->setProxyOpTimestamp(sub->traceKey, facebook_rccl::ProxyCounterTypes::KERNEL_COPY_READY); + } #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_NET_SEND_ENTRY) && defined(ENABLE_NPKIT_EVENT_NET_SEND_EXIT) NpKit::CollectCpuEvent( NPKIT_EVENT_NET_SEND_ENTRY, @@ -1413,8 +1427,10 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct sub->transSize = size; sub->transmitted += args->sliceSteps; ncclProfilerRecordProxyStepEventState(s, args, transmittedStepId, ncclProfilerProxyStepSendWait); - facebook_rccl::updateProxyOpCounter(proxyState->proxyTrace, sub->traceKey, - facebook_rccl::ProxyCounterTypes::TRANSMITTED, sub->transmitted); + if (proxyState->proxyTrace) { + proxyState->proxyTrace->updateProxyOpCounter(sub->traceKey, + facebook_rccl::ProxyCounterTypes::TRANSMITTED, sub->transmitted); + } args->idle = 0; continue; } @@ -1480,6 +1496,10 @@ static ncclResult_t sendProxyProgress(struct ncclProxyState* proxyState, struct TRACE(NCCL_NET, "sendProxy [%ld/%d/%d] request %p done", sub->done, buffSlot, sub->nsteps, sub->requests[buffSlot]); sub->done += args->sliceSteps; ncclProfilerStopProxyStepEvent(s, args, doneStepId); + if (proxyState->proxyTrace) { + proxyState->proxyTrace->updateProxyOpCounter(sub->traceKey, + facebook_rccl::ProxyCounterTypes::DONE, sub->done); + } if (resources->shared == 0) { volatile uint64_t* sendHead = resources->gdcSync ? resources->gdcSync : &resources->sendMem->head; *sendHead = sub->base + sub->done; @@ -1546,8 +1566,16 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct sub->regBufferReady = 0; for (int i=0; iproxyTrace, sub->traceKey, sub->traceInfo, - facebook_rccl::ProxyOpType::RECV, sub->channelId, sub->nsteps, sub->nbytes, sub->peer); + if (proxyState->proxyTrace) { + proxyState->proxyTrace->addNewProxyOp( + sub->traceKey, + sub->traceInfo, + facebook_rccl::ProxyOpType::RECV, + sub->channelId, + sub->nsteps, + sub->nbytes, + sub->peer); + } if (!sub->reg) sub->recvMhandle = resources->mhandles[args->protocol]; } @@ -1644,8 +1672,10 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct sub->posted += args->sliceSteps; ncclProfilerRecordProxyStepEventState(s+i, args, postedStepId, ncclProfilerProxyStepRecvWait); - facebook_rccl::updateProxyOpCounter(proxyState->proxyTrace, - sub->traceKey, facebook_rccl::ProxyCounterTypes::POSTED, sub->posted); + if (proxyState->proxyTrace) { + proxyState->proxyTrace->updateProxyOpCounter( + sub->traceKey, facebook_rccl::ProxyCounterTypes::POSTED, sub->posted); + } } args->idle = 0; } @@ -1693,7 +1723,9 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct sub->transSize = sizes[i]; sub->received += args->sliceSteps; ncclProfilerRecordProxyStepEventState(s+i, args, receivedStepId, ncclProfilerProxyStepRecvFlushWait); - facebook_rccl::updateProxyOpCounter(proxyState->proxyTrace, sub->traceKey, facebook_rccl::ProxyCounterTypes::RECEIVED, sub->received); + if (proxyState->proxyTrace) { + proxyState->proxyTrace->updateProxyOpCounter(sub->traceKey, facebook_rccl::ProxyCounterTypes::RECEIVED, sub->received); + } if (step < sub->nsteps) { struct recvNetResources* resources = (struct recvNetResources*) (sub->connection->transportResources); if (resources->useGdr) needFlush |= resources->needFlush; @@ -1754,7 +1786,9 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct if (subGroup->requests[step%NCCL_STEPS]){ for(int i=0; igroupSize; i++) { struct ncclProxySubArgs* sub = subGroup + i; - facebook_rccl::updateProxyOpCounter(proxyState->proxyTrace, sub->traceKey, facebook_rccl::ProxyCounterTypes::FLUSHED, sub->received); + if (proxyState->proxyTrace) { + proxyState->proxyTrace->updateProxyOpCounter(sub->traceKey, facebook_rccl::ProxyCounterTypes::FLUSHED, sub->received); + } } } if (once) { @@ -1783,13 +1817,17 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct sub->transmitted += args->sliceSteps; ncclProfilerRecordProxyStepEventState(s+i, args, transmittedStepId, ncclProfilerProxyStepRecvGPUWait); - facebook_rccl::updateProxyOpCounter(proxyState->proxyTrace, sub->traceKey, facebook_rccl::ProxyCounterTypes::TRANSMITTED, sub->transmitted); + if (proxyState->proxyTrace) { + proxyState->proxyTrace->updateProxyOpCounter(sub->traceKey, facebook_rccl::ProxyCounterTypes::TRANSMITTED, sub->transmitted); + } if (step < sub->nsteps) { __sync_synchronize(); struct recvNetResources* resources = (struct recvNetResources*) (sub->connection->transportResources); volatile uint64_t* recvTail = resources->gdcSync ? resources->gdcSync : &resources->recvMem->tail; *recvTail = sub->base + sub->transmitted; - facebook_rccl::updateProxyOpCounter(proxyState->proxyTrace, sub->traceKey, facebook_rccl::ProxyCounterTypes::RECV_TAIL, *recvTail); + if (proxyState->proxyTrace) { + proxyState->proxyTrace->updateProxyOpCounter(sub->traceKey, facebook_rccl::ProxyCounterTypes::RECV_TAIL, *recvTail); + } if (resources->gdcSync) wc_store_fence(); // Flush out WC write } } @@ -1808,8 +1846,10 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct struct recvNetResources* resources = (struct recvNetResources*) (sub->connection->transportResources); volatile uint64_t* sendHead = &resources->sendMem->head; uint64_t done = *sendHead; - facebook_rccl::updateProxyOpCounter(proxyState->proxyTrace, sub->traceKey, facebook_rccl::ProxyCounterTypes::TAIL_OR_HEAD, done); - facebook_rccl::updateProxyOpCounter(proxyState->proxyTrace, sub->traceKey, facebook_rccl::ProxyCounterTypes::FIFO_SZ_OR_HEAD_CACHE, sub->base + sub->done); + if (proxyState->proxyTrace) { + proxyState->proxyTrace->updateProxyOpCounter(sub->traceKey, facebook_rccl::ProxyCounterTypes::TAIL_OR_HEAD, done); + proxyState->proxyTrace->updateProxyOpCounter(sub->traceKey, facebook_rccl::ProxyCounterTypes::FIFO_SZ_OR_HEAD_CACHE, sub->base + sub->done); + } while (done > sub->base + sub->done && // LL and LL128 can acknowledge 0-bytes send before they even happen. Don't go past what we transmitted. sub->transmitted > sub->done) { @@ -1822,6 +1862,9 @@ static ncclResult_t recvProxyProgress(struct ncclProxyState* proxyState, struct int doneStepId = sub->done; sub->done += args->sliceSteps; ncclProfilerStopProxyStepEvent(s+i, args, doneStepId); + if (proxyState->proxyTrace) { + proxyState->proxyTrace->updateProxyOpCounter(sub->traceKey, facebook_rccl::ProxyCounterTypes::DONE, sub->done); + } args->idle = 0; if (sub->done == sub->nsteps) { args->done++; diff --git a/test/proxy_trace/ProxyTraceUnitTests.cpp b/test/proxy_trace/ProxyTraceUnitTests.cpp index a62d394442..949f1ee43e 100644 --- a/test/proxy_trace/ProxyTraceUnitTests.cpp +++ b/test/proxy_trace/ProxyTraceUnitTests.cpp @@ -9,6 +9,7 @@ #include "proxy_trace/proxy_trace.h" #include #include +#include #include namespace RcclUnitTesting { @@ -21,7 +22,7 @@ class ProxyTraceTestFixture : public ::testing::Test { int nSteps = 10; void SetUp() override { proxyState = new ncclProxyState(); - facebook_rccl::proxyTraceInit(proxyState->proxyTrace, 0, commHash); + proxyState->proxyTrace = std::make_unique(0); EXPECT_NE(proxyState->proxyTrace, nullptr); sub1 = new ncclProxySubArgs(); sub2 = new ncclProxySubArgs(); @@ -36,9 +37,12 @@ class ProxyTraceTestFixture : public ::testing::Test { delete proxyState; } void AddTraceOp(ncclProxySubArgs *sub, facebook_rccl::ProxyOpType opType) { - facebook_rccl::addNewProxyOp(proxyState->proxyTrace, sub->traceKey, - sub->traceInfo, opType, sub->channelId, - sub->nsteps, sub->nbytes, sub->peer); + proxyState->proxyTrace->addNewProxyOp( + sub->traceKey, + sub->traceInfo, + opType, + sub->channelId, + sub->nsteps, sub->nbytes, sub->peer); } }; @@ -49,16 +53,10 @@ TEST_F(ProxyTraceTestFixture, nonEmptySingleton) { TEST_F(ProxyTraceTestFixture, addTraceOp) { auto &tracer = proxyState->proxyTrace; - EXPECT_EQ(tracer->getOrCreateProxyOpId(sub1->traceKey.commHash, - sub1->traceKey.opCount), - 0); AddTraceOp(sub1, facebook_rccl::ProxyOpType::SEND); EXPECT_EQ(sub1->traceKey.proxyOpId, 0); AddTraceOp(sub2, facebook_rccl::ProxyOpType::RECV); EXPECT_EQ(sub2->traceKey.proxyOpId, 1); - EXPECT_EQ(tracer->getOrCreateProxyOpId(sub1->traceKey.commHash, - sub1->traceKey.opCount), - 2); auto traceRecordPtr = tracer->getProxyTraceOpPtr(sub1->traceKey); EXPECT_EQ(traceRecordPtr->opType, facebook_rccl::ProxyOpType::SEND); } @@ -73,9 +71,10 @@ TEST_F(ProxyTraceTestFixture, getMapSizeMB) { EXPECT_GT(size2, size1); // finish sub1 sub1->done = nSteps; - facebook_rccl::updateProxyOpCounter(tracer, sub1->traceKey, - facebook_rccl::ProxyCounterTypes::DONE, - sub1->done); + tracer->updateProxyOpCounter( + sub1->traceKey, + facebook_rccl::ProxyCounterTypes::DONE, + sub1->done); // sub1 is now serialized and should be moved from activeOps to finishedOps auto size3 = tracer->getMapSizeMB(); EXPECT_GT(size3, size1); @@ -84,13 +83,14 @@ TEST_F(ProxyTraceTestFixture, getMapSizeMB) { TEST_F(ProxyTraceTestFixture, updateTraceOp) { auto &tracer = proxyState->proxyTrace; AddTraceOp(sub1, facebook_rccl::ProxyOpType::SEND); - facebook_rccl::updateProxyOpCounter( - tracer, sub1->traceKey, - facebook_rccl::ProxyCounterTypes::KERNEL_COPY_READY, 1); - facebook_rccl::updateProxyOpCounter( - tracer, sub1->traceKey, facebook_rccl::ProxyCounterTypes::POSTED, 3); - facebook_rccl::updateProxyOpCounter( - tracer, sub1->traceKey, facebook_rccl::ProxyCounterTypes::TRANSMITTED, 2); + tracer->updateProxyOpCounter( + sub1->traceKey, + facebook_rccl::ProxyCounterTypes::KERNEL_COPY_READY, + 1); + tracer->updateProxyOpCounter( + sub1->traceKey, facebook_rccl::ProxyCounterTypes::POSTED, 3); + tracer->updateProxyOpCounter( + sub1->traceKey, facebook_rccl::ProxyCounterTypes::TRANSMITTED, 2); auto traceRecordPtr = tracer->getProxyTraceOpPtr(sub1->traceKey); EXPECT_NE(traceRecordPtr, nullptr); @@ -110,25 +110,12 @@ TEST_F(ProxyTraceTestFixture, updateTraceOp2) { AddTraceOp(sub1, facebook_rccl::ProxyOpType::SEND); int64_t rand = 123456789; sub1->posted = rand; - facebook_rccl::updateProxyOpCounter(tracer, sub1->traceKey, - facebook_rccl::ProxyCounterTypes::POSTED, - sub1->posted); + tracer->updateProxyOpCounter(sub1->traceKey, + facebook_rccl::ProxyCounterTypes::POSTED, + sub1->posted); auto traceRecordPtr = tracer->getProxyTraceOpPtr(sub1->traceKey); EXPECT_EQ(traceRecordPtr->counters[facebook_rccl::ProxyCounterTypes::POSTED], rand); } -TEST_F(ProxyTraceTestFixture, memoryReclaim) { - auto &tracer = proxyState->proxyTrace; - tracer->resetAll(); - AddTraceOp(sub1, facebook_rccl::ProxyOpType::SEND); - sub1->done = nSteps; - facebook_rccl::updateProxyOpCounter(tracer, sub1->traceKey, - facebook_rccl::ProxyCounterTypes::DONE, - sub1->done); - auto traceRecordPtr = tracer->getProxyTraceOpPtr(sub1->traceKey); - EXPECT_EQ(traceRecordPtr, nullptr); - EXPECT_GT(tracer->getMapSizeMB(), 0); -} - -} // namespace RcclUnitTesting \ No newline at end of file +} // namespace RcclUnitTesting