Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/include/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ enum ncclSocketState {
ncclSocketStateTerminating = 8,
ncclSocketStateClosed = 9,
ncclSocketStateError = 10,
ncclSocketStateNum = 11
ncclSocketStateInvalidMagic = 11,
ncclSocketStateNum = 12
};

enum ncclSocketType {
Expand Down
51 changes: 48 additions & 3 deletions src/misc/socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

NCCL_PARAM(RetryCnt, "SOCKET_RETRY_CNT", 34);
NCCL_PARAM(RetryTimeOut, "SOCKET_RETRY_SLEEP_MSEC", 100);
NCCL_PARAM(RecvWarnTime, "SOCKET_RECV_WARN_TIME_MSEC", 6000);
NCCL_PARAM(RecvTimeOut, "SOCKET_RECV_SLEEP_MSEC", 180000);
static void msleep(unsigned int time_msec) {
const long c_1e6 = 1e6;
struct timespec tv = (struct timespec){
Expand Down Expand Up @@ -68,17 +70,48 @@ static ncclResult_t socketProgress(int op, struct ncclSocket* sock, void* ptr, i
return ncclSuccess;
} else {
char line[SOCKET_NAME_MAXLEN+1];
WARN("socketProgress: Connection closed by remote peer %s",
ncclSocketToString(&sock->addr, line, /*numericHostForm*/0));
// Added printing of local listening address and link type information
ncclSocketAddress listenAddr;
struct sockaddr_in addr;
socklen_t len = sizeof(addr);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: shouldn't this be a constexpr? Or does getsockname not accept that?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: shouldn't this be a constexpr? Or does getsockname not accept that?

I checked some online examples, and this way of writing it seems fine. Could you elaborate on your concern?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
socklen_t len = sizeof(addr);
constexpr socklen_t len = sizeof(addr);

getsockname(sock->acceptFd, (struct sockaddr*)&addr, &len);
char listenLine[SOCKET_NAME_MAXLEN+1];
listenAddr.sin = addr;
WARN("socketProgress: local: %s connection type: %d, closed by remote peer %s", ncclSocketToString(&listenAddr, listenLine), sock->type, ncclSocketToString(&sock->addr, line));
return ncclRemoteError;
}
}
return ncclSuccess;
}

// Repeatedly calls socketProgress() until *offset reaches size.
//
// Receives (op == NCCL_SOCKET_RECV) are time-bounded to address two potential hangs:
//   1. the peer never sends a data packet after the connection is established;
//   2. the peer sends less data than expected after the connection is established.
// While a receive is pending, a warning is logged each time ncclParamRecvWarnTime()
// milliseconds have passed since the previous warning, and the wait is abandoned with
// ncclRemoteError once more than ncclParamRecvTimeOut() milliseconds have elapsed in total.
// Other operations are not time-limited.
//
// Returns ncclSuccess once *offset >= size and ncclRemoteError on a receive timeout.
// A failing socketProgress() is handled by NCCLCHECK (presumably by returning its
// error code to the caller -- confirm against the macro definition).
static ncclResult_t socketWait(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) {
  // Loop-invariant values: parameters are in milliseconds, clockNano() deltas in nanoseconds.
  const bool isRecv = (op == NCCL_SOCKET_RECV);
  const int64_t warnIntervalMs = ncclParamRecvWarnTime();
  const int64_t timeoutMs = ncclParamRecvTimeOut();
  const int64_t warnIntervalNs = warnIntervalMs * 1000000LL;
  const int64_t timeoutNs = timeoutMs * 1000000LL;

  const int64_t st = clockNano();
  int64_t lastWarnTime = st; // Time of the most recent "still waiting" warning

  while (*offset < size) {
    if (isRecv) {
      const int64_t et = clockNano();
      const int64_t elapsed = et - st;

      // Periodic reminder that the receive is still pending.
      if (et - lastWarnTime >= warnIntervalNs) {
        char line[SOCKET_NAME_MAXLEN+1];
        WARN("socketWait: still waiting data from %s (elapsed %.1f s / timeout %.1f s)",
             ncclSocketToString(&sock->addr, line),
             (double)elapsed / 1000000000.0, timeoutMs / 1000.0);
        lastWarnTime = et;
      }

      // Give up once the overall receive deadline has passed.
      if (elapsed > timeoutNs) {
        char line[SOCKET_NAME_MAXLEN+1];
        WARN("socketWait: remote peer %s exceeded max retry time (incomplete data received)", ncclSocketToString(&sock->addr, line));
        return ncclRemoteError;
      }
    }
    NCCLCHECK(socketProgress(op, sock, ptr, size, offset));
  }
  return ncclSuccess;
}

Expand Down Expand Up @@ -489,6 +522,9 @@ static void socketResetAccept(struct ncclSocket* sock) {
sock->fd = -1;
sock->state = ncclSocketStateAccepting;
sock->finalizeCounter = 0;
if (sock->type == ncclSocketTypeProxy || sock->type == ncclSocketTypeRasNetwork) {
sock->state = ncclSocketStateInvalidMagic;
}
}

static ncclResult_t socketFinalizeAccept(struct ncclSocket* sock) {
Expand Down Expand Up @@ -520,6 +556,7 @@ static ncclResult_t socketFinalizeAccept(struct ncclSocket* sock) {
memcpy(&magic, sock->finalizeBuffer, sizeof(magic));
}
if (magic != sock->magic) {
WARN("socketFinalizeAccept: received invalid magic from %s, expected %lu but got %lu", ncclSocketToString(&sock->addr, line), sock->magic, magic);
socketResetAccept(sock);
return ncclSuccess;
}
Expand Down Expand Up @@ -692,6 +729,10 @@ ncclResult_t ncclSocketReady(struct ncclSocket* sock, int *running) {
if (*running == 0) {
NCCLCHECK(socketProgressState(sock));
*running = (sock->state == ncclSocketStateReady) ? 1 : 0;
if (sock->type == ncclSocketTypeRasNetwork && sock->state == ncclSocketStateInvalidMagic) {
WARN("ncclSocketReady type: %d: invalid magic", sock->type);
return ncclRemoteError;
}
}
return ncclSuccess;
}
Expand Down Expand Up @@ -785,6 +826,10 @@ ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listen
case ncclSocketStateError:
ret = ncclSystemError;
break;
case ncclSocketStateInvalidMagic:
WARN("ncclSocketAccept: invalid magic socket state %d", sock->state);
return ncclRemoteError;
break;
default:
WARN("ncclSocketAccept: wrong socket state %d", sock->state);
ret = ncclInternalError;
Expand Down