From f309cfdf56a16b448137c77d377a1e56e522da5c Mon Sep 17 00:00:00 2001 From: Harold Huang Date: Mon, 14 Jul 2025 19:50:46 +0800 Subject: [PATCH] support binding client socket address to specific ifname --- src/bootstrap.cc | 12 +++---- src/include/socket.h | 10 +++--- src/misc/socket.cc | 67 +++++++++++++++++++++++++++---------- src/proxy.cc | 4 +-- src/ras/collectives.cc | 2 +- src/ras/peers.cc | 2 +- src/ras/ras.cc | 16 ++++----- src/ras/rasnet.cc | 30 ++++++++--------- src/transport/net_ib.cc | 10 +++--- src/transport/net_socket.cc | 6 ++-- 10 files changed, 96 insertions(+), 63 deletions(-) diff --git a/src/bootstrap.cc b/src/bootstrap.cc index f05337249..afaf7d594 100644 --- a/src/bootstrap.cc +++ b/src/bootstrap.cc @@ -258,7 +258,7 @@ static ncclResult_t setFilesLimit() { static ncclResult_t rootSend(union ncclSocketAddress* addr, uint64_t magic, union ringConnectInfo* info) { ncclResult_t res = ncclSuccess; struct ncclSocket sock; - NCCLCHECKGOTO(ncclSocketInit(&sock, addr, magic, ncclSocketTypeBootstrap), res, fail); + NCCLCHECKGOTO(ncclSocketInit(&sock, &bootstrapNetIfAddr, addr, magic, ncclSocketTypeBootstrap), res, fail); NCCLCHECKGOTO(ncclSocketConnect(&sock), res, fail); NCCLCHECKGOTO(socketSend(&sock, info, sizeof(union ringConnectInfo)), res, fail); NCCLCHECK(ncclSocketClose(&sock)); @@ -381,7 +381,7 @@ ncclResult_t bootstrapCreateRoot(struct ncclBootstrapHandle* handle, bool idFrom pthread_t thread; NCCLCHECK(ncclCalloc(&listenSock, 1)); - NCCLCHECKGOTO(ncclSocketInit(listenSock, &handle->addr, handle->magic, ncclSocketTypeBootstrap, NULL, 0), ret, fail); + NCCLCHECKGOTO(ncclSocketInit(listenSock, &handle->addr, NULL, handle->magic, ncclSocketTypeBootstrap, NULL, 0), ret, fail); NCCLCHECKGOTO(ncclSocketListen(listenSock), ret, fail); NCCLCHECKGOTO(ncclSocketGetAddr(listenSock, &handle->addr), ret, fail); @@ -470,7 +470,7 @@ struct bootstrapState { // helper functions static ncclResult_t createListenSocket(struct ncclComm* comm, uint64_t magic, struct ncclSocket* socket, union ncclSocketAddress* addr, ncclSocketType type) { - NCCLCHECK(ncclSocketInit(socket, &bootstrapNetIfAddr, magic, type, comm->abortFlag)); + NCCLCHECK(ncclSocketInit(socket, &bootstrapNetIfAddr, NULL, magic, type, comm->abortFlag)); NCCLCHECK(ncclSocketListen(socket)); NCCLCHECK(ncclSocketGetAddr(socket, addr)); return ncclSuccess; @@ -550,7 +550,7 @@ static ncclResult_t netRingConnect(ncclNet_t* net, struct bootstrapListen_t* lis return ncclSuccess; } static ncclResult_t socketRingConnect(ncclSocketAddress* addr, struct ncclSocket* sendSocket, struct ncclSocket* listenSock, struct ncclSocket* recvSocket, uint64_t magic, volatile uint32_t* abortFlag) { - NCCLCHECK(ncclSocketInit(sendSocket, addr, magic, ncclSocketTypeBootstrap, abortFlag)); + NCCLCHECK(ncclSocketInit(sendSocket, &bootstrapNetIfAddr, addr, magic, ncclSocketTypeBootstrap, abortFlag)); NCCLCHECK(ncclSocketConnect(sendSocket)); NCCLCHECK(ncclSocketInit(recvSocket)); NCCLCHECK(ncclSocketAccept(recvSocket, listenSock)); @@ -604,7 +604,7 @@ static ncclResult_t ringAllInfo(struct ncclComm* comm, struct bootstrapState* st static ncclResult_t sendToRoot(struct ncclBootstrapHandle* handle, struct ncclComm* comm, struct extInfo* info) { ncclResult_t ret = ncclSuccess; struct ncclSocket sock; - NCCLCHECK(ncclSocketInit(&sock, &handle->addr, handle->magic, ncclSocketTypeBootstrap, comm->abortFlag)); + NCCLCHECK(ncclSocketInit(&sock, &bootstrapNetIfAddr, &handle->addr, handle->magic, ncclSocketTypeBootstrap, comm->abortFlag)); NCCLCHECKGOTO(ncclSocketConnect(&sock), ret, fail); NCCLCHECKGOTO(socketSend(&sock, info, sizeof(struct extInfo)), ret, fail); NCCLCHECK(ncclSocketClose(&sock)); @@ -867,7 +867,7 @@ static ncclResult_t socketConnect(void* commState, int peer, int tag, struct ncc struct bootstrapState* state = (struct bootstrapState*)commState; struct socketAckInfo ack = (struct socketAckInfo){.rank = state->rank, .tag = tag}; - NCCLCHECKGOTO(ncclSocketInit(sock, state->peerP2pAddresses + peer, state->magic, ncclSocketTypeBootstrap, state->abortFlag), ret, fail); + NCCLCHECKGOTO(ncclSocketInit(sock, &bootstrapNetIfAddr, state->peerP2pAddresses + peer, state->magic, ncclSocketTypeBootstrap, state->abortFlag), ret, fail); NCCLCHECKGOTO(ncclSocketConnect(sock), ret, fail); NCCLCHECKGOTO(socketSend(sock, &ack, sizeof(struct socketAckInfo)), ret, fail); return ncclSuccess; diff --git a/src/include/socket.h b/src/include/socket.h index adeae9b2a..0599e64c6 100644 --- a/src/include/socket.h +++ b/src/include/socket.h @@ -56,6 +56,8 @@ struct ncclSocket { int acceptFd; int errorRetries; union ncclSocketAddress addr; + union ncclSocketAddress peerAddr; + int family; volatile uint32_t* abortFlag; int asyncFlag; enum ncclSocketState state; @@ -75,15 +77,15 @@ ncclResult_t ncclFindInterfaces(char* ifNames, union ncclSocketAddress *ifAddrs, int* nIfs); // Initialize a socket -ncclResult_t ncclSocketInit(struct ncclSocket* sock, const union ncclSocketAddress* addr = NULL, uint64_t magic = NCCL_SOCKET_MAGIC, enum ncclSocketType type = ncclSocketTypeUnknown, volatile uint32_t* abortFlag = NULL, int asyncFlag = 0, int customRetry = 0); +ncclResult_t ncclSocketInit(struct ncclSocket* sock, const union ncclSocketAddress* addr = NULL, const union ncclSocketAddress* peerAddr = NULL, uint64_t magic = NCCL_SOCKET_MAGIC, enum ncclSocketType type = ncclSocketTypeUnknown, volatile uint32_t* abortFlag = NULL, int asyncFlag = 0, int customRetry = 0); // Create a listening socket. sock->addr can be pre-filled with IP & port info. sock->fd is set after a successful call ncclResult_t ncclSocketListen(struct ncclSocket* sock); -ncclResult_t ncclSocketGetAddr(struct ncclSocket* sock, union ncclSocketAddress* addr); -// Connect to sock->addr. sock->fd is set after a successful call. +ncclResult_t ncclSocketGetAddr(struct ncclSocket* sock, union ncclSocketAddress* addr, bool isPeer = false); +// Connect to sock->peerAddr. sock->fd is set after a successful call. ncclResult_t ncclSocketConnect(struct ncclSocket* sock); // Return socket connection state. ncclResult_t ncclSocketReady(struct ncclSocket* sock, int *running); -// Accept an incoming connection from listenSock->fd and keep the file descriptor in sock->fd, with the remote side IP/port in sock->addr. +// Accept an incoming connection from listenSock->fd and keep the file descriptor in sock->fd, with the remote side IP/port in sock->peerAddr. ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* ulistenSock); ncclResult_t ncclSocketGetFd(struct ncclSocket* sock, int* fd); ncclResult_t ncclSocketSetFd(int fd, struct ncclSocket* sock); diff --git a/src/misc/socket.cc b/src/misc/socket.cc index d066d2829..1e4bd89ac 100644 --- a/src/misc/socket.cc +++ b/src/misc/socket.cc @@ -44,7 +44,7 @@ static ncclResult_t socketProgressOpt(int op, struct ncclSocket* sock, void* ptr } if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { WARN("socketProgressOpt: Call to %s %s failed : %s", (op == NCCL_SOCKET_RECV ? "recv from" : "send to"), - ncclSocketToString(&sock->addr, line), strerror(errno)); + ncclSocketToString(&sock->peerAddr, line), strerror(errno)); return ncclRemoteError; } else { bytes = 0; @@ -69,7 +69,7 @@ static ncclResult_t socketProgress(int op, struct ncclSocket* sock, void* ptr, i } else { char line[SOCKET_NAME_MAXLEN+1]; WARN("socketProgress: Connection closed by remote peer %s", - ncclSocketToString(&sock->addr, line, /*numericHostForm*/0)); + ncclSocketToString(&sock->peerAddr, line, /*numericHostForm*/0)); return ncclRemoteError; } } @@ -425,19 +425,22 @@ ncclResult_t ncclSocketListen(struct ncclSocket* sock) { return ncclSuccess; } -ncclResult_t ncclSocketGetAddr(struct ncclSocket* sock, union ncclSocketAddress* addr) { +ncclResult_t ncclSocketGetAddr(struct ncclSocket* sock, union ncclSocketAddress* addr, bool isPeer) { if (sock == NULL) { WARN("ncclSocketGetAddr: pass NULL socket"); return ncclInvalidArgument; } if (sock->state != ncclSocketStateReady) return ncclInternalError; - memcpy(addr, &sock->addr, sizeof(union ncclSocketAddress)); + if (isPeer) + memcpy(addr, &sock->peerAddr, sizeof(union ncclSocketAddress)); + else + memcpy(addr, &sock->addr, sizeof(union ncclSocketAddress)); return ncclSuccess; } static ncclResult_t socketTryAccept(struct ncclSocket* sock) { socklen_t socklen = sizeof(union ncclSocketAddress); - sock->fd = accept(sock->acceptFd, (struct sockaddr*)&sock->addr, &socklen); + sock->fd = accept(sock->acceptFd, (struct sockaddr*)&sock->peerAddr, &socklen); if (sock->fd != -1) { sock->state = ncclSocketStateAccepted; } else if (errno == ENETDOWN || errno == EPROTO || errno == ENOPROTOOPT || errno == EHOSTDOWN || @@ -545,7 +548,7 @@ static ncclResult_t socketFinalizeAccept(struct ncclSocket* sock) { static ncclResult_t socketResetFd(struct ncclSocket* sock) { ncclResult_t ret = ncclSuccess; int fd = -1; - SYSCHECKGOTO(fd = socket(sock->addr.sa.sa_family, SOCK_STREAM, 0), "socket", ret, cleanup); + SYSCHECKGOTO(fd = socket(sock->family, SOCK_STREAM, 0), "socket", ret, cleanup); // if sock->fd is valid, close it and reuse its number if (sock->fd != -1) { SYSCHECKGOTO(dup2(fd, sock->fd), "dup2", ret, cleanup); @@ -589,7 +592,7 @@ static ncclResult_t socketConnectCheck(struct ncclSocket* sock, int errCode, con sock->state = ncclSocketStateConnecting; } else { sock->state = ncclSocketStateError; - WARN("%s: connect to %s failed : %s", funcName, ncclSocketToString(&sock->addr, line), strerror(errCode)); + WARN("%s: connect to %s failed : %s", funcName, ncclSocketToString(&sock->peerAddr, line), strerror(errCode)); return ncclSystemError; } return ncclSuccess; @@ -597,7 +600,7 @@ static ncclResult_t socketConnectCheck(struct ncclSocket* sock, int errCode, con static ncclResult_t socketStartConnect(struct ncclSocket* sock) { /* blocking/non-blocking connect() is determined by asyncFlag. */ - int ret = connect(sock->fd, &sock->addr.sa, sock->salen); + int ret = connect(sock->fd, &sock->peerAddr.sa, sock->salen); return socketConnectCheck(sock, (ret == -1) ? errno : 0, __func__); } @@ -695,6 +698,7 @@ ncclResult_t ncclSocketReady(struct ncclSocket* sock, int *running) { ncclResult_t ncclSocketConnect(struct ncclSocket* sock) { #ifdef ENABLE_TRACE char line[SOCKET_NAME_MAXLEN+1]; + char linePeer[SOCKET_NAME_MAXLEN+1]; #endif if (sock == NULL) { @@ -711,7 +715,15 @@ ncclResult_t ncclSocketConnect(struct ncclSocket* sock) { if (sock->state == ncclSocketStateError) return ncclRemoteError; return ncclInternalError; } - TRACE(NCCL_INIT|NCCL_NET,"Connecting to socket %s", ncclSocketToString(&sock->addr, line)); + SYSCHECK(bind(sock->fd, &sock->addr.sa, sock->salen), "bind"); + + /* Get the assigned Port */ + socklen_t size = sock->salen; + SYSCHECK(getsockname(sock->fd, &sock->addr.sa, &size), "getsockname"); + +#ifdef ENABLE_TRACE + TRACE(NCCL_INIT|NCCL_NET,"Connecting to socket local addr: %s, peer addr: %s", ncclSocketToString(&sock->addr, line), ncclSocketToString(&sock->peerAddr, linePeer)); +#endif sock->state = ncclSocketStateConnecting; sock->finalizeCounter = 0; @@ -791,8 +803,9 @@ ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listen return ret; } -ncclResult_t ncclSocketInit(struct ncclSocket* sock, const union ncclSocketAddress* addr, uint64_t magic, enum ncclSocketType type, volatile uint32_t* abortFlag, int asyncFlag, int customRetry) { +ncclResult_t ncclSocketInit(struct ncclSocket* sock, const union ncclSocketAddress* addr, const union ncclSocketAddress* peerAddr, uint64_t magic, enum ncclSocketType type, volatile uint32_t* abortFlag, int asyncFlag, int customRetry) { ncclResult_t ret = ncclSuccess; + int family = -1; if (sock == NULL) goto exit; sock->errorRetries = 0; @@ -804,24 +817,42 @@ ncclResult_t ncclSocketInit(struct ncclSocket* sock, const union ncclSocketAddre sock->fd = -1; sock->acceptFd = -1; sock->customRetry = customRetry; + sock->family = -1; if (addr) { - /* IPv4/IPv6 support */ - int family; memcpy(&sock->addr, addr, sizeof(union ncclSocketAddress)); - family = sock->addr.sa.sa_family; + } else { + memset(&sock->addr, 0, sizeof(union ncclSocketAddress)); + } + if (peerAddr) { + memcpy(&sock->peerAddr, peerAddr, sizeof(union ncclSocketAddress)); + } else { + memset(&sock->peerAddr, 0, sizeof(union ncclSocketAddress)); + } + if (addr && peerAddr) { + if (addr->sa.sa_family != peerAddr->sa.sa_family) { + WARN("ncclSocketInit: local address and peer address family should be the same"); + ret = ncclInternalError; + goto exit; + } + family = addr->sa.sa_family; + } else if (addr) { + family = addr->sa.sa_family; + } else if (peerAddr) { + family = peerAddr->sa.sa_family; + } + if (addr || peerAddr) { + /* IPv4/IPv6 support */ if (family != AF_INET && family != AF_INET6) { - char line[SOCKET_NAME_MAXLEN+1]; - WARN("ncclSocketInit: connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)", - ncclSocketToString(&sock->addr, line), family, AF_INET, AF_INET6); + WARN("ncclSocketInit: socket address family %d is neither AF_INET(%d) nor AF_INET6(%d)", + family, AF_INET, AF_INET6); ret = ncclInternalError; goto exit; } + sock->family = family; sock->salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6); // in case of error, we close the fd before returning as it's unclear if the caller has to use ncclSocketClose for cleanup NCCLCHECKGOTO(socketResetFd(sock), ret, fail); - } else { - memset(&sock->addr, 0, sizeof(union ncclSocketAddress)); } exit: return ret; diff --git a/src/proxy.cc b/src/proxy.cc index 74ec70f0e..5ffaf9d30 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -1105,7 +1105,7 @@ ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, in sock = sharedProxyState->peerSocks + proxyConn->tpLocalRank; NCCLCHECK(ncclSocketReady(sock, &ready)); if (!ready) { - NCCLCHECK(ncclSocketInit(sock, sharedProxyState->peerAddresses+proxyConn->tpRank, comm->sharedRes->magic, ncclSocketTypeProxy, comm->abortFlag)); + NCCLCHECK(ncclSocketInit(sock, NULL, sharedProxyState->peerAddresses+proxyConn->tpRank, comm->sharedRes->magic, ncclSocketTypeProxy, comm->abortFlag)); NCCLCHECK(ncclSocketConnect(sock)); } @@ -1855,7 +1855,7 @@ ncclResult_t ncclProxyStop(struct ncclComm* comm) { // We need to send a ncclProxyMsgStop message to our own proxy struct ncclSocket sock; int type = ncclProxyMsgStop; - NCCLCHECK(ncclSocketInit(&sock, sharedProxyState->peerAddresses + comm->topParentRanks[comm->rank], comm->sharedRes->magic, ncclSocketTypeProxy, comm->abortFlag)); + NCCLCHECK(ncclSocketInit(&sock, NULL, sharedProxyState->peerAddresses + comm->topParentRanks[comm->rank], comm->sharedRes->magic, ncclSocketTypeProxy, comm->abortFlag)); if (ncclSocketConnect(&sock) == ncclSuccess) { (void)ncclSocketSend(&sock, &type, sizeof(int)); } diff --git a/src/ras/collectives.cc b/src/ras/collectives.cc index 4f8b6efc4..1fd9a98a3 100644 --- a/src/ras/collectives.cc +++ b/src/ras/collectives.cc @@ -323,7 +323,7 @@ ncclResult_t rasMsgHandleCollResp(struct rasMsg* msg, struct rasSocket* sock) { if (coll == nullptr) { INFO(NCCL_RAS, "RAS failed to find a matching ongoing collective for response %s:%ld from %s!", ncclSocketToString(&msg->collResp.rootAddr, line), msg->collResp.rootId, - ncclSocketToString(&sock->sock.addr, rasLine)); + ncclSocketToString(&sock->sock.peerAddr, rasLine)); goto exit; } diff --git a/src/ras/peers.cc b/src/ras/peers.cc index 8573209f1..888443260 100644 --- a/src/ras/peers.cc +++ b/src/ras/peers.cc @@ -493,7 +493,7 @@ ncclResult_t rasMsgHandlePeersUpdate(struct rasMsg* msg, struct rasSocket* sock) bool updatePeers, updateDeadPeers; INFO(NCCL_RAS, "RAS handling peersUpdate from %s (peersHash 0x%lx, deadPeersHash 0x%lx, nPeers %d, nDeadPeers %d)", - ncclSocketToString(&sock->sock.addr, rasLine), msg->peersUpdate.peersHash, msg->peersUpdate.deadPeersHash, + ncclSocketToString(&sock->sock.peerAddr, rasLine), msg->peersUpdate.peersHash, msg->peersUpdate.deadPeersHash, msg->peersUpdate.nPeers, msg->peersUpdate.nDeadPeers); INFO(NCCL_RAS, "RAS my old rasPeersHash 0x%lx, rasDeadPeersHash 0x%lx, nRasPeers %d, nRasDeadPeers %d", rasPeersHash, rasDeadPeersHash, nRasPeers, nRasDeadPeers); diff --git a/src/ras/ras.cc b/src/ras/ras.cc index 8ef551c64..10bb01ba0 100644 --- a/src/ras/ras.cc +++ b/src/ras/ras.cc @@ -97,7 +97,7 @@ ncclResult_t ncclRasCommInit(struct ncclComm* comm, struct rasRankInit* myRank) memcpy(&addr, &myRank->addr, sizeof(addr)); (addr.sa.sa_family == AF_INET ? addr.sin.sin_port : addr.sin6.sin6_port) = htons(0); - NCCLCHECKGOTO(ncclSocketInit(&rasNetListeningSocket, &addr, NCCL_SOCKET_MAGIC, ncclSocketTypeRasNetwork, + NCCLCHECKGOTO(ncclSocketInit(&rasNetListeningSocket, &addr, NULL, NCCL_SOCKET_MAGIC, ncclSocketTypeRasNetwork, /*abortFlag*/nullptr, /*asyncFlag*/1), ret, fail); NCCLCHECKGOTO(ncclSocketListen(&rasNetListeningSocket), ret, fail); INFO(NCCL_RAS, "RAS network listening socket at %s", @@ -405,7 +405,7 @@ ncclResult_t rasMsgHandle(struct rasMsg* msg, struct rasSocket* sock) { } else if (msg->type == RAS_MSG_COLLRESP) { NCCLCHECK(rasMsgHandleCollResp(msg, sock)); } else { - WARN("RAS received unknown message type (%d) from %s", msg->type, ncclSocketToString(&sock->sock.addr, rasLine)); + WARN("RAS received unknown message type (%d) from %s", msg->type, ncclSocketToString(&sock->sock.peerAddr, rasLine)); return ncclInternalError; } @@ -422,13 +422,13 @@ static ncclResult_t rasMsgHandleConnInit(const struct rasMsg* msg, struct rasSoc char line[SOCKET_NAME_MAXLEN+1]; INFO(NCCL_RAS, "RAS handling connInit from %s (version %d, listeningAddr %s, peersHash 0x%lx, deadPeersHash 0x%lx)", - ncclSocketToString(&sock->sock.addr, rasLine), msg->connInit.ncclVersion, + ncclSocketToString(&sock->sock.peerAddr, rasLine), msg->connInit.ncclVersion, ncclSocketToString(&msg->connInit.listeningAddr, line), msg->connInit.peersHash, msg->connInit.deadPeersHash); if (msg->connInit.ncclVersion != NCCL_VERSION_CODE) { // Close any such sockets immediately! This is basically unrecoverable... WARN("NCCL version mismatch with remote peer %s (local: %d, remote %d)", - ncclSocketToString(&sock->sock.addr, rasLine), NCCL_VERSION_CODE, msg->connInit.ncclVersion); + ncclSocketToString(&sock->sock.peerAddr, rasLine), NCCL_VERSION_CODE, msg->connInit.ncclVersion); rasNetSendNack(sock); rasSocketTerminate(sock, /*finalize*/true); ret = ncclInvalidUsage; @@ -482,7 +482,7 @@ static ncclResult_t rasMsgHandleConnInit(const struct rasMsg* msg, struct rasSoc conn->sock = sock; sock->conn = conn; - memcpy(&sock->sock.addr, &msg->connInit.listeningAddr, sizeof(sock->sock.addr)); + memcpy(&sock->sock.peerAddr, &msg->connInit.listeningAddr, sizeof(sock->sock.peerAddr)); // Make sure that the connection is part of the right links forming the RAS network. At this point we only // update the expected (non-external) connections; external ones will be added during keep-alive handling. @@ -518,13 +518,13 @@ static ncclResult_t rasMsgHandleConnInit(const struct rasMsg* msg, struct rasSoc // Handles the second message sent over a RAS socket as part of the handshake. static ncclResult_t rasMsgHandleConnInitAck(const struct rasMsg* msg, struct rasSocket* sock) { INFO(NCCL_RAS, "RAS handling connInitAck from %s (nack %d)", - ncclSocketToString(&sock->sock.addr, rasLine), msg->connInitAck.nack); + ncclSocketToString(&sock->sock.peerAddr, rasLine), msg->connInitAck.nack); if (msg->connInitAck.nack) { // The remote peer doesn't want to talk to us. The easiest way to prevent it is by declaring it dead. // We make a copy of the address because rasConnDisconnect will terminate the rasSocket. union ncclSocketAddress addr; - memcpy(&addr, &sock->sock.addr, sizeof(addr)); + memcpy(&addr, &sock->sock.peerAddr, sizeof(addr)); rasConnDisconnect(&addr); (void)rasPeerDeclareDead(&addr); @@ -563,7 +563,7 @@ static ncclResult_t rasNetSendNack(struct rasSocket* sock) { int closed = 0; int offset; - INFO(NCCL_RAS, "RAS sending NACK to %s", ncclSocketToString(&sock->sock.addr, rasLine)); + INFO(NCCL_RAS, "RAS sending NACK to %s", ncclSocketToString(&sock->sock.peerAddr, rasLine)); memset(&msg, '\0', sizeof(msg)); msg.type = RAS_MSG_CONNINITACK; diff --git a/src/ras/rasnet.cc b/src/ras/rasnet.cc index 1194e61b5..f0b7b5251 100644 --- a/src/ras/rasnet.cc +++ b/src/ras/rasnet.cc @@ -128,7 +128,7 @@ static void rasConnOpen(struct rasConnection* conn) { int ready; NCCLCHECKGOTO(getNewSockEntry(&sock), ret, fail); - NCCLCHECKGOTO(ncclSocketInit(&sock->sock, &conn->addr, NCCL_SOCKET_MAGIC, ncclSocketTypeRasNetwork, nullptr, + NCCLCHECKGOTO(ncclSocketInit(&sock->sock, nullptr, &conn->addr, NCCL_SOCKET_MAGIC, ncclSocketTypeRasNetwork, nullptr, /*asyncFlag*/1, /*customRetry*/1), ret, fail); closeSocketOnFail = true; NCCLCHECKGOTO(ncclSocketConnect(&sock->sock), ret, fail); @@ -359,7 +359,7 @@ ncclResult_t rasNetAcceptNewSocket() { bool socketInitialized = false; NCCLCHECKGOTO(getNewSockEntry(&sock), ret, fail); - NCCLCHECKGOTO(ncclSocketInit(&sock->sock, nullptr, NCCL_SOCKET_MAGIC, ncclSocketTypeRasNetwork, nullptr, + NCCLCHECKGOTO(ncclSocketInit(&sock->sock, nullptr, nullptr, NCCL_SOCKET_MAGIC, ncclSocketTypeRasNetwork, nullptr, /*asyncFlag*/1), ret, fail); socketInitialized = true; NCCLCHECKGOTO(ncclSocketAccept(&sock->sock, &rasNetListeningSocket), ret, fail); @@ -374,7 +374,7 @@ ncclResult_t rasNetAcceptNewSocket() { // helps the code tell the sides apart. sock->status = RAS_SOCK_CONNECTING; - INFO(NCCL_RAS, "RAS new incoming socket connection from %s", ncclSocketToString(&sock->sock.addr, rasLine)); + INFO(NCCL_RAS, "RAS new incoming socket connection from %s", ncclSocketToString(&sock->sock.peerAddr, rasLine)); exit: return ret; @@ -432,11 +432,11 @@ void rasSocksHandleTimeouts(int64_t now, int64_t* nextWakeup) { if (now - sock->createTime > RAS_STUCK_TIMEOUT) { if (sock->conn == nullptr) { INFO(NCCL_RAS, "RAS init timeout error (%lds) on incoming socket connection from %s", - (now-sock->createTime)/CLOCK_UNITS_PER_SEC, ncclSocketToString(&sock->sock.addr, rasLine)); + (now-sock->createTime)/CLOCK_UNITS_PER_SEC, ncclSocketToString(&sock->sock.peerAddr, rasLine)); } else { INFO(NCCL_RAS, "RAS init timeout error (%lds) on socket connection with %s " "(experiencingDelays %d, startRetryTime %.2fs, socket status %d)", - (now-sock->createTime)/CLOCK_UNITS_PER_SEC, ncclSocketToString(&sock->sock.addr, rasLine), + (now-sock->createTime)/CLOCK_UNITS_PER_SEC, ncclSocketToString(&sock->sock.peerAddr, rasLine), sock->conn->experiencingDelays, (sock->conn->startRetryTime ? (now-sock->conn->startRetryTime)/1e9 : 0.0), sock->status); } @@ -450,7 +450,7 @@ void rasSocksHandleTimeouts(int64_t now, int64_t* nextWakeup) { if (now - std::max(sock->lastSendTime, sock->lastRecvTime) > RAS_STUCK_TIMEOUT) { INFO(NCCL_RAS, "RAS termination stuck timeout error (%lds) on socket connection with %s", (now-std::max(sock->lastSendTime, sock->lastRecvTime)) / CLOCK_UNITS_PER_SEC, - ncclSocketToString(&sock->sock.addr, rasLine)); + ncclSocketToString(&sock->sock.peerAddr, rasLine)); rasSocketTerminate(sock, /*finalize*/true); // This socket is presumably already being re-established, if needed. } else { @@ -463,7 +463,7 @@ void rasSocksHandleTimeouts(int64_t now, int64_t* nextWakeup) { if (now - std::max(sock->lastSendTime, sock->lastRecvTime) > RAS_IDLE_TIMEOUT) { INFO(NCCL_RAS, "RAS idle timeout (%lds) on socket connection with %s", (now - std::max(sock->lastSendTime, sock->lastRecvTime)) / CLOCK_UNITS_PER_SEC, - ncclSocketToString(&sock->sock.addr, rasLine)); + ncclSocketToString(&sock->sock.peerAddr, rasLine)); rasSocketTerminate(sock, /*finalize*/false, /*startRetryOffset*/0, /*retry*/false); // The RAS network timeout handler will terminate the conn it was associated with, if any. } else { @@ -562,7 +562,7 @@ void rasSockEventLoop(struct rasSocket* sock, int pollIdx) { // Socket is not yet fully established. Continue the OS or NCCL-level handshake. if (ncclSocketReady(&sock->sock, &ready) != ncclSuccess) { INFO(NCCL_RAS, "RAS unexpected error from ncclSocketReady; terminating the socket connection with %s", - ncclSocketToString(&sock->sock.addr, rasLine)); + ncclSocketToString(&sock->sock.peerAddr, rasLine)); rasSocketTerminate(sock); // We may retry further down. } else { @@ -576,7 +576,7 @@ void rasSockEventLoop(struct rasSocket* sock, int pollIdx) { if (sock->conn->sock == sock) { if (rasConnPrepare(sock->conn) != ncclSuccess) { INFO(NCCL_RAS, "RAS unexpected error from rasConnPrepare; terminating the socket connection with %s", - ncclSocketToString(&sock->sock.addr, rasLine)); + ncclSocketToString(&sock->sock.peerAddr, rasLine)); rasSocketTerminate(sock); // We may retry further down. } @@ -584,7 +584,7 @@ void rasSockEventLoop(struct rasSocket* sock, int pollIdx) { // The connection this socket is associated with no longer considers it to be the current one. // This could possibly happen due to a race condition. Simply terminate it. INFO(NCCL_RAS, "RAS connected with %s via a socket that's no longer current!", - ncclSocketToString(&sock->sock.addr, rasLine)); + ncclSocketToString(&sock->sock.peerAddr, rasLine)); rasSocketTerminate(sock); } } // if (connectSide) @@ -603,12 +603,12 @@ void rasSockEventLoop(struct rasSocket* sock, int pollIdx) { assert(sock->conn->sock == sock); if (rasConnSendMsg(sock->conn, &closed, &allSent) != ncclSuccess) { INFO(NCCL_RAS, "RAS unexpected error from rasConnSendMsg; terminating the socket connection with %s", - ncclSocketToString(&sock->sock.addr, rasLine)); + ncclSocketToString(&sock->sock.peerAddr, rasLine)); rasSocketTerminate(sock); // We may retry further down. } else if (closed) { INFO(NCCL_RAS, "RAS socket connection with %s closed by peer on send; terminating it", - ncclSocketToString(&sock->sock.addr, rasLine)); + ncclSocketToString(&sock->sock.peerAddr, rasLine)); rasSocketTerminate(sock); // We may retry further down. } else { @@ -624,7 +624,7 @@ void rasSockEventLoop(struct rasSocket* sock, int pollIdx) { msg = nullptr; if (rasMsgRecv(sock, &msg, &closed) != ncclSuccess) { INFO(NCCL_RAS, "RAS unexpected error from rasMsgRecv; terminating the socket connection with %s", - ncclSocketToString(&sock->sock.addr, rasLine)); + ncclSocketToString(&sock->sock.peerAddr, rasLine)); rasSocketTerminate(sock, /*finalize*/true); // We may retry further down. } else if (closed) { @@ -638,7 +638,7 @@ void rasSockEventLoop(struct rasSocket* sock, int pollIdx) { else socketType = "current"; INFO(NCCL_RAS, "RAS %s socket connection with %s closed by peer on receive; terminating it", - socketType, ncclSocketToString(&sock->sock.addr, rasLine)); + socketType, ncclSocketToString(&sock->sock.peerAddr, rasLine)); rasSocketTerminate(sock, /*finalize*/true); // We may retry further down. } else { // !closed @@ -844,7 +844,7 @@ ncclResult_t rasMsgHandleKeepAlive(const struct rasMsg* msg, struct rasSocket* s // Just in case there's some unforeseen problem with the peers propagation though, exchange with the // remote to get everybody in sync. INFO(NCCL_RAS, "RAS keepAlive hash mismatch from %s (peersHash 0x%lx, deadPeersHash 0x%lx)", - ncclSocketToString(&sock->sock.addr, rasLine), msg->keepAlive.peersHash, msg->keepAlive.deadPeersHash); + ncclSocketToString(&sock->sock.peerAddr, rasLine), msg->keepAlive.peersHash, msg->keepAlive.deadPeersHash); INFO(NCCL_RAS, "RAS my peersHash 0x%lx, deadPeersHash 0x%lx", rasPeersHash, rasDeadPeersHash); NCCLCHECK(rasConnSendPeersUpdate(sock->conn, rasPeers, nRasPeers)); } diff --git a/src/transport/net_ib.cc b/src/transport/net_ib.cc index 709e7ad40..477ee8455 100644 --- a/src/transport/net_ib.cc +++ b/src/transport/net_ib.cc @@ -1259,7 +1259,7 @@ ncclResult_t ncclIbListen(int dev, void* opaqueHandle, void** listenComm) { memset(handle, 0, sizeof(struct ncclIbHandle)); comm->dev = dev; handle->magic = NCCL_SOCKET_MAGIC; - NCCLCHECKGOTO(ncclSocketInit(&comm->sock, &ncclIbIfAddr, handle->magic, ncclSocketTypeNetIb, NULL, 1), ret, fail); + NCCLCHECKGOTO(ncclSocketInit(&comm->sock, &ncclIbIfAddr, NULL, handle->magic, ncclSocketTypeNetIb, NULL, 1), ret, fail); NCCLCHECKGOTO(ncclSocketListen(&comm->sock), ret, fail); NCCLCHECKGOTO(ncclSocketGetAddr(&comm->sock, &handle->connectAddr), ret, fail); *listenComm = comm; @@ -1294,7 +1294,7 @@ ncclResult_t ncclIbConnect(int dev, ncclNetCommConfig_t* config, void* opaqueHan NCCLCHECK(ncclIbMalloc((void**)&comm, sizeof(struct ncclIbSendComm))); NCCLCHECKGOTO(ncclIbStatsInit(&comm->base.stats), ret, fail); - NCCLCHECKGOTO(ncclSocketInit(&comm->base.sock, &handle->connectAddr, handle->magic, ncclSocketTypeNetIb, NULL, 1), ret, fail); + NCCLCHECKGOTO(ncclSocketInit(&comm->base.sock, &ncclIbIfAddr, &handle->connectAddr, handle->magic, ncclSocketTypeNetIb, NULL, 1), ret, fail); stage->comm = comm; stage->state = ncclIbCommStateConnect; NCCLCHECKGOTO(ncclSocketConnect(&comm->base.sock), ret, fail); @@ -2140,7 +2140,7 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void* if (slots[r].size < 0 || slots[r].addr == 0 || slots[r].rkeys[0] == 0) { char line[SOCKET_NAME_MAXLEN + 1]; union ncclSocketAddress addr; - ncclSocketGetAddr(&comm->base.sock, &addr); + ncclSocketGetAddr(&comm->base.sock, &addr, true); WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %ld addr %lx rkeys[0]=%x", r, nreqs, tag, ncclSocketToString(&addr, line), slots[r].size, slots[r].addr, slots[r].rkeys[0]); return ncclInternalError; @@ -2444,7 +2444,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { struct ibv_wc *wc = wcs+w; if (wc->status != IBV_WC_SUCCESS) { union ncclSocketAddress addr; - ncclSocketGetAddr(r->sock, &addr); + ncclSocketGetAddr(r->sock, &addr, true); char localGidString[INET6_ADDRSTRLEN] = ""; char remoteGidString[INET6_ADDRSTRLEN] = ""; const char* localGidStr = NULL, *remoteGidStr = NULL; @@ -2462,7 +2462,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { } union ncclSocketAddress addr; - ncclSocketGetAddr(r->sock, &addr); + ncclSocketGetAddr(r->sock, &addr, true); struct ncclIbRequest* req = r->base->reqs+(wc->wr_id & 0xff); #ifdef ENABLE_TRACE diff --git a/src/transport/net_socket.cc b/src/transport/net_socket.cc index 985810c47..73de17377 100644 --- a/src/transport/net_socket.cc +++ b/src/transport/net_socket.cc @@ -347,7 +347,7 @@ ncclResult_t ncclNetSocketListen(int dev, void* opaqueHandle, void** listenComm) struct ncclNetSocketListenComm* comm; NCCLCHECK(ncclCalloc(&comm, 1)); handle->magic = NCCL_SOCKET_MAGIC; - NCCLCHECKGOTO(ncclSocketInit(&comm->sock, &ncclNetSocketDevs[dev].addr, handle->magic, ncclSocketTypeNetSocket, NULL, 1), ret, fail); + NCCLCHECKGOTO(ncclSocketInit(&comm->sock, &ncclNetSocketDevs[dev].addr, NULL, handle->magic, ncclSocketTypeNetSocket, NULL, 1), ret, fail); NCCLCHECKGOTO(ncclSocketListen(&comm->sock), ret, fail); NCCLCHECKGOTO(ncclSocketGetAddr(&comm->sock, &handle->connectAddr), ret, fail); NCCLCHECKGOTO(ncclNetSocketGetNsockNthread(dev, &comm->nSocks, &comm->nThreads), ret, fail); @@ -388,7 +388,7 @@ ncclResult_t ncclNetSocketConnect(int dev, ncclNetCommConfig_t* config, void* op CUDACHECK(cudaGetDevice(&comm->cudaDev)); for (; inSocks+1; i++) { sock = (i == comm->nSocks) ? &comm->ctrlSock : comm->socks+i; - NCCLCHECK(ncclSocketInit(sock, &handle->connectAddr, handle->magic, ncclSocketTypeNetSocket, NULL, 1)); + NCCLCHECK(ncclSocketInit(sock, &ncclNetSocketDevs[dev].addr, &handle->connectAddr, handle->magic, ncclSocketTypeNetSocket, NULL, 1)); stage->sock = sock; stage->state = ncclNetSocketCommStateConnect; @@ -557,7 +557,7 @@ ncclResult_t ncclNetSocketTest(void* request, int* done, int* size) { if (senderSize > r->size) { char line[SOCKET_NAME_MAXLEN + 1]; union ncclSocketAddress addr; - NCCLCHECK(ncclSocketGetAddr(r->ctrlSock, &addr)); + NCCLCHECK(ncclSocketGetAddr(r->ctrlSock, &addr,/*peer address*/1)); WARN("NET/Socket : peer %s message truncated : receiving %d bytes instead of %d. If you believe your socket network is in a healthy state, " "there may be a mismatch in collective sizes or environment settings (e.g. NCCL_PROTO, NCCL_ALGO) between ranks", ncclSocketToString(&addr, line), senderSize, r->size);