diff --git a/src/misc/socket.cc b/src/misc/socket.cc index 278fb5c51..a50e294bc 100644 --- a/src/misc/socket.cc +++ b/src/misc/socket.cc @@ -16,6 +16,8 @@ NCCL_PARAM(RetryCnt, "SOCKET_RETRY_CNT", 34); NCCL_PARAM(RetryTimeOut, "SOCKET_RETRY_SLEEP_MSEC", 100); +NCCL_PARAM(PollTimeOut, "SOCKET_POLL_TIMEOUT_MSEC", 0); + static void msleep(unsigned int time_msec) { const long c_1e6 = 1e6; struct timespec tv = (struct timespec){ @@ -25,6 +27,14 @@ static void msleep(unsigned int time_msec) { nanosleep(&tv, NULL); } +static void pollSocket(int fd, int op) { + struct pollfd pfd; + pfd.fd = fd; + pfd.events = (op == NCCL_SOCKET_RECV) ? POLLIN : POLLOUT; + pfd.revents = 0; + poll(&pfd, 1, ncclParamPollTimeOut()); +} + static ncclResult_t socketProgressOpt(int op, struct ncclSocket* sock, void* ptr, int size, int* offset, int block, int* closed) { int bytes = 0; *closed = 0; @@ -77,8 +87,12 @@ static ncclResult_t socketProgress(int op, struct ncclSocket* sock, void* ptr, i } static ncclResult_t socketWait(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) { - while (*offset < size) + while (*offset < size) { NCCLCHECK(socketProgress(op, sock, ptr, size, offset)); + // If we have more data to read or write, use the poll system call to wait + // until the socket becomes readable or writable again. + if ((*offset < size) && ncclParamPollTimeOut()) pollSocket(sock->fd, op); + } return ncclSuccess; }