diff --git a/src/include/socket.h b/src/include/socket.h index adeae9b2a..98bf811f6 100644 --- a/src/include/socket.h +++ b/src/include/socket.h @@ -39,7 +39,8 @@ enum ncclSocketState { ncclSocketStateTerminating = 8, ncclSocketStateClosed = 9, ncclSocketStateError = 10, - ncclSocketStateNum = 11 + ncclSocketStateBadMagic = 11, + ncclSocketStateNum = 12 }; enum ncclSocketType { @@ -84,7 +85,7 @@ ncclResult_t ncclSocketConnect(struct ncclSocket* sock); // Return socket connection state. ncclResult_t ncclSocketReady(struct ncclSocket* sock, int *running); // Accept an incoming connection from listenSock->fd and keep the file descriptor in sock->fd, with the remote side IP/port in sock->addr. -ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* ulistenSock); +ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* ulistenSock, bool retryOnBadMagic = true); ncclResult_t ncclSocketGetFd(struct ncclSocket* sock, int* fd); ncclResult_t ncclSocketSetFd(int fd, struct ncclSocket* sock); diff --git a/src/misc/socket.cc b/src/misc/socket.cc index 5633fef3e..17468b315 100644 --- a/src/misc/socket.cc +++ b/src/misc/socket.cc @@ -521,6 +521,7 @@ static ncclResult_t socketFinalizeAccept(struct ncclSocket* sock) { } if (magic != sock->magic) { socketResetAccept(sock); + sock->state = ncclSocketStateBadMagic; return ncclSuccess; } } @@ -743,7 +744,7 @@ ncclResult_t ncclSocketConnect(struct ncclSocket* sock) { } } -ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listenSock) { +ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listenSock, bool retryOnBadMagic) { ncclResult_t ret = ncclSuccess; if (listenSock == NULL || sock == NULL) { @@ -769,6 +770,9 @@ ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listen do { NCCLCHECKGOTO(socketProgressState(sock), ret, exit); + if (sock->state == ncclSocketStateBadMagic && retryOnBadMagic) { + sock->state = ncclSocketStateAccepting; + } } while (sock->asyncFlag == 0 && (sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE) == 0) && (sock->state == ncclSocketStateAccepting || @@ -780,6 +784,7 @@ ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listen case ncclSocketStateAccepting: case ncclSocketStateAccepted: case ncclSocketStateReady: + case ncclSocketStateBadMagic: ret = ncclSuccess; break; case ncclSocketStateError: diff --git a/src/proxy.cc b/src/proxy.cc index 25a14cd64..f1ab33e1d 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -1668,7 +1668,7 @@ void* ncclProxyService(void* _args) { WARN("[Service thread] Initialize peers[%d].sock fails", s); return NULL; } - if (ncclSocketAccept(&peers[s].sock, proxyState->listenSock) != ncclSuccess) { + if (ncclSocketAccept(&peers[s].sock, proxyState->listenSock, false) != ncclSuccess) { WARN("[Service thread] Accept failed %s", strerror(errno)); } else { if (ncclSocketGetFd(&peers[s].sock, &pollfds[s].fd) != ncclSuccess) { diff --git a/src/ras/rasnet.cc b/src/ras/rasnet.cc index 1194e61b5..928623dd5 100644 --- a/src/ras/rasnet.cc +++ b/src/ras/rasnet.cc @@ -362,7 +362,7 @@ ncclResult_t rasNetAcceptNewSocket() { NCCLCHECKGOTO(ncclSocketInit(&sock->sock, nullptr, NCCL_SOCKET_MAGIC, ncclSocketTypeRasNetwork, nullptr, /*asyncFlag*/1), ret, fail); socketInitialized = true; - NCCLCHECKGOTO(ncclSocketAccept(&sock->sock, &rasNetListeningSocket), ret, fail); + NCCLCHECKGOTO(ncclSocketAccept(&sock->sock, &rasNetListeningSocket, false), ret, fail); NCCLCHECKGOTO(ncclSocketReady(&sock->sock, &ready), ret, fail); if (sock->sock.fd == -1)