diff --git a/src/graph/paths.cc b/src/graph/paths.cc
index ee8cbcd8f..edf8a4aed 100644
--- a/src/graph/paths.cc
+++ b/src/graph/paths.cc
@@ -522,7 +522,8 @@ ncclResult_t ncclTopoIsGdrAvail(struct ncclTopoSystem* system, int rank, bool *a
 NCCL_PARAM(NetForceFlush, "NET_FORCE_FLUSH", 0);
 
 // Determine whether we need to flush the GDR recv buffers
-ncclResult_t ncclTopoNeedFlush(struct ncclComm* comm, int64_t netId, int netDev, int rank, int* flush) {
+// netManaged: on HIP/AMD builds *flush is assigned !netManaged, so passing true requests no flush
+ncclResult_t ncclTopoNeedFlush(struct ncclComm* comm, int64_t netId, int netDev, int rank, bool netManaged, int* flush) {
   *flush = 1;
   ncclNetProperties_t props;
   NCCLCHECK(comm->ncclNet->getProperties(netDev, &props));
@@ -531,7 +532,7 @@ ncclResult_t ncclTopoNeedFlush(struct ncclComm* comm, int64_t netId, int netDev,
   struct ncclTopoSystem* system = comm->topo;
   NCCLCHECK(ncclTopoRankToIndex(system, rank, &g, /*showWarn=*/true));
 #if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
-  *flush = 1;
+  *flush = !netManaged;
 #else
   struct ncclTopoNode* gpu = system->nodes[GPU].nodes+g; // unused variable - compiler warning
   // Flush is required on Ampere and earlier
diff --git a/src/include/graph.h b/src/include/graph.h
index 7b5bb9537..d71123950 100644
--- a/src/include/graph.h
+++ b/src/include/graph.h
@@ -44,7 +44,7 @@ enum ncclTopoGdrMode {
   ncclTopoGdrModeNum = 3
 };
 ncclResult_t ncclTopoCheckGdr(struct ncclTopoSystem* topo, int rank, int64_t netId, int read, enum ncclTopoGdrMode* gdrMode);
-ncclResult_t ncclTopoNeedFlush(struct ncclComm* comm, int64_t netId, int netDev, int rank, int* flush);
+ncclResult_t ncclTopoNeedFlush(struct ncclComm* comm, int64_t netId, int netDev, int rank, bool netManaged, int* flush);
 ncclResult_t ncclTopoIsGdrAvail(struct ncclTopoSystem* system, int rank, bool *avail);
 ncclResult_t ncclTopoCheckNet(struct ncclTopoSystem* system, int rank1, int rank2, int* net);
 int ncclPxnDisable(struct ncclComm* comm);
diff --git a/src/transport/coll_net.cc b/src/transport/coll_net.cc
index d1420e899..7926f565f 100644
--- a/src/transport/coll_net.cc
+++ b/src/transport/coll_net.cc
@@ -196,7 +196,7 @@ static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph
   NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->rank, netId, 0, &req.useGdr));
   recv->conn.flags |= req.useGdr ? NCCL_DIRECT_NIC : 0;
   // Determine whether we need to flush the GDR buffer on recv or not
-  if (req.useGdr) NCCLCHECK(ncclTopoNeedFlush(comm, netId, req.netDev, myInfo->rank, &req.needFlush));
+  if (req.useGdr) NCCLCHECK(ncclTopoNeedFlush(comm, netId, req.netDev, myInfo->rank, false, &req.needFlush));
 
   recv->proxyConn.tpLocalRank = comm->topParentLocalRanks[comm->localRank];
   NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_COLLNET, 0, myInfo->rank, &recv->proxyConn));
diff --git a/src/transport/net.cc b/src/transport/net.cc
index 1ad8de99f..2b93c27d0 100644
--- a/src/transport/net.cc
+++ b/src/transport/net.cc
@@ -290,7 +290,11 @@ static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph
 
   // Determine whether we need to flush the GDR buffer on recv or not
   if (req.useGdr) {
-    NCCLCHECK(ncclTopoNeedFlush(comm, netId, req.netDev, myInfo->rank, &req.needFlush));
+    int managed = 0;
+    // Flush is not needed when the hardware supports direct managed memory access from host.
+    // Query this rank's own device (myInfo->cudaDev), the same device used for the HDP register query below.
+    CUDACHECK(hipDeviceGetAttribute(&managed, hipDeviceAttributeDirectManagedMemAccessFromHost, myInfo->cudaDev));
+    NCCLCHECK(ncclTopoNeedFlush(comm, netId, req.netDev, myInfo->rank, (bool)managed, &req.needFlush));
     CUDACHECK(hipDeviceGetAttribute((int*)&req.curr_hdp_reg, hipDeviceAttributeHdpMemFlushCntl, myInfo->cudaDev));
     recv->conn.curr_hdp_reg = req.curr_hdp_reg;
   }