@@ -80,6 +80,16 @@ static ncclResult_t bootstrapNetRecv(struct ncclSocket* sock, void* data, int si
8080 NCCLCHECK (ncclSocketRecv (sock, data, std::min (recvSize, size)));
8181 return ncclSuccess;
8282}
83+ static ncclResult_t bootstrapNetSendRecv (struct ncclSocket * sendSock, void * sendData, int sendSize, struct ncclSocket * recvSock, void * recvData, int recvSize) {
84+ int senderRecvSize;
85+ NCCLCHECK (ncclSocketSendRecv (sendSock, &sendSize, sizeof (int ), recvSock, &senderRecvSize, sizeof (int )));
86+ if (senderRecvSize > recvSize) {
87+ WARN (" Message truncated : received %d bytes instead of %d" , senderRecvSize, recvSize);
88+ return ncclInternalError;
89+ }
90+ NCCLCHECK (ncclSocketSendRecv (sendSock, sendData, sendSize, recvSock, recvData, recvSize));
91+ return ncclSuccess;
92+ }
8393
8494struct extInfo {
8595 int rank;
@@ -390,103 +400,40 @@ ncclResult_t bootstrapSplit(struct ncclBootstrapHandle* handle, struct ncclComm*
390400 goto exit;
391401}
392402
393- ncclResult_t bootstrapAllGather (void * commState, void * allData, int size) {
394- struct bootstrapState * state = (struct bootstrapState *)commState;
395- char * data = (char *)allData;
396- int rank = state->rank ;
397- int nranks = state->nranks ;
398-
399- TRACE (NCCL_INIT, " rank %d nranks %d size %d" , rank, nranks, size);
403+ // Bootstrap send/receive functions
404+ //
405+ // We do not keep connections opened with all ranks at all times, and we have no guarantee
406+ // that connections to our unique listen socket will arrive in the same order as we need
407+ // them. Therefore, when establishing a connection, the sender sends a (peer, tag) tuple to
408+ // allow the receiver to identify the flow, and keep it in an unexpected queue if needed.
400409
401- /* Simple ring based AllGather
402- * At each step i receive data from (rank-i-1) from left
403- * and send previous step's data from (rank-i) to right
404- */
405- for (int i=0 ; i<nranks-1 ; i++) {
406- size_t rslice = (rank - i - 1 + nranks) % nranks;
407- size_t sslice = (rank - i + nranks) % nranks;
408-
409- // Send slice to the right
410- NCCLCHECK (bootstrapNetSend (&state->ringSendSocket , data+sslice*size, size));
411- // Recv slice from the left
412- NCCLCHECK (bootstrapNetRecv (&state->ringRecvSocket , data+rslice*size, size));
413- }
410+ ncclResult_t bootstrapConnect (void * commState, int peer, int tag, struct ncclSocket * sock) {
411+ ncclResult_t ret = ncclSuccess;
412+ struct bootstrapState * state = (struct bootstrapState *)commState;
414413
415- TRACE (NCCL_INIT, " rank %d nranks %d size %d - DONE" , rank, nranks, size);
414+ NCCLCHECKGOTO (ncclSocketInit (sock, state->peerCommAddresses +peer, state->magic , ncclSocketTypeBootstrap), ret, fail);
415+ NCCLCHECKGOTO (ncclSocketConnect (sock), ret, fail);
416+ NCCLCHECKGOTO (bootstrapNetSend (sock, &state->rank , sizeof (int )), ret, fail);
417+ NCCLCHECKGOTO (bootstrapNetSend (sock, &tag, sizeof (int )), ret, fail);
416418 return ncclSuccess;
419+ fail:
420+ NCCLCHECK (ncclSocketClose (sock));
421+ return ret;
417422}
418423
419424ncclResult_t bootstrapSend (void * commState, int peer, int tag, void * data, int size) {
420425 ncclResult_t ret = ncclSuccess;
421- struct bootstrapState * state = (struct bootstrapState *)commState;
422426 struct ncclSocket sock;
423427
424- NCCLCHECKGOTO ( ncclSocketInit (&sock, state-> peerCommAddresses + peer, state-> magic , ncclSocketTypeBootstrap), ret, fail );
425- NCCLCHECKGOTO ( ncclSocketConnect (&sock), ret, fail );
426- NCCLCHECKGOTO (bootstrapNetSend (&sock, &state-> rank , sizeof ( int )) , ret, fail );
427- NCCLCHECKGOTO ( bootstrapNetSend (&sock, &tag, sizeof ( int )), ret, fail);
428- NCCLCHECKGOTO ( bootstrapNetSend (&sock, data, size), ret, fail );
428+ TRACE (NCCL_BOOTSTRAP, " Sending to peer=%d tag=%d size=%d " , peer, tag, size );
429+ NCCLCHECK ( bootstrapConnect (commState, peer, tag, &sock) );
430+ NCCLCHECKGOTO (bootstrapNetSend (&sock, data, size) , ret, exit );
431+
432+ TRACE (NCCL_BOOTSTRAP, " Sent to peer=%d tag=%d size=%d " , peer, tag, size );
429433
430434exit:
431435 NCCLCHECK (ncclSocketClose (&sock));
432436 return ret;
433- fail:
434- goto exit;
435- }
436-
437- ncclResult_t bootstrapBarrier (void * commState, int *ranks, int rank, int nranks, int tag) {
438- if (nranks == 1 ) return ncclSuccess;
439- TRACE (NCCL_INIT, " rank %d nranks %d tag %x - ENTER" , rank, nranks, tag);
440-
441- /* Simple intra process barrier
442- *
443- * Based on the dissemination algorithm by Debra Hensgen, Raphael Finkel, and Udi Manbet,
444- * "Two Algorithms for Barrier Synchronization," International Journal of Parallel Programming, 17(1):1-17, 1988"
445- */
446- int data[1 ];
447- for (int mask=1 ; mask<nranks; mask<<=1 ) {
448- int src = (rank - mask + nranks) % nranks;
449- int dst = (rank + mask) % nranks;
450- NCCLCHECK (bootstrapSend (commState, ranks[dst], tag, data, sizeof (data)));
451- NCCLCHECK (bootstrapRecv (commState, ranks[src], tag, data, sizeof (data)));
452- }
453-
454- TRACE (NCCL_INIT, " rank %d nranks %d tag %x - DONE" , rank, nranks, tag);
455- return ncclSuccess;
456- }
457-
458- ncclResult_t bootstrapIntraNodeAllGather (void * commState, int *ranks, int rank, int nranks, void * allData, int size) {
459- if (nranks == 1 ) return ncclSuccess;
460- char * data = (char *)allData;
461- TRACE (NCCL_INIT, " rank %d nranks %d size %d - ENTER" , rank, nranks, size);
462-
463- for (int i=1 ; i<nranks; i++) {
464- int src = (rank - i + nranks) % nranks;
465- int dst = (rank + i) % nranks;
466- NCCLCHECK (bootstrapSend (commState, ranks[dst], /* tag=*/ i, data+rank*size, size));
467- NCCLCHECK (bootstrapRecv (commState, ranks[src], /* tag=*/ i, data+src*size, size));
468- }
469-
470- TRACE (NCCL_INIT, " rank %d nranks %d size %d - DONE" , rank, nranks, size);
471- return ncclSuccess;
472- }
473-
474- // IntraNode in-place Broadcast
475- ncclResult_t bootstrapIntraNodeBroadcast (void * commState, int *ranks, int rank, int nranks, int root, void * bcastData, int size) {
476- if (nranks == 1 ) return ncclSuccess;
477- TRACE (NCCL_INIT, " rank %d nranks %d root %d size %d - ENTER" , rank, nranks, root, size);
478-
479- if (rank == root) {
480- for (int i=0 ; i<nranks; i++) {
481- if (i != root) NCCLCHECK (bootstrapSend (commState, ranks[i], /* tag=*/ ranks[i], bcastData, size));
482- }
483- }
484- else {
485- NCCLCHECK (bootstrapRecv (commState, ranks[root], /* tag=*/ ranks[rank], bcastData, size));
486- }
487-
488- TRACE (NCCL_INIT, " rank %d nranks %d root %d size %d - DONE" , rank, nranks, root, size);
489- return ncclSuccess;
490437}
491438
492439ncclResult_t unexpectedEnqueue (struct bootstrapState * state, int peer, int tag, struct ncclSocket * sock) {
@@ -543,38 +490,136 @@ static void unexpectedFree(struct bootstrapState* state) {
543490}
544491
545492// We can't know who we'll receive from, so we need to receive everything at once
546- ncclResult_t bootstrapRecv (void * commState, int peer, int tag, void * data, int size ) {
493+ ncclResult_t bootstrapAccept (void * commState, int peer, int tag, struct ncclSocket * sock ) {
547494 ncclResult_t ret = ncclSuccess;
548495 struct bootstrapState * state = (struct bootstrapState *)commState;
549- struct ncclSocket sock;
550496 int newPeer, newTag;
551497
552498 // Search unexpected connections first
553499 int found;
554- NCCLCHECK (unexpectedDequeue (state, peer, tag, &sock, &found));
555- if (found) {
556- NCCLCHECKGOTO (bootstrapNetRecv (&sock, ((char *)data), size), ret, fail);
557- goto exit;
558- }
500+ NCCLCHECK (unexpectedDequeue (state, peer, tag, sock, &found));
501+ if (found) return ncclSuccess;
559502
560503 // Then look for new connections
561504 while (1 ) {
562- NCCLCHECKGOTO (ncclSocketInit (&sock), ret, fail);
563- NCCLCHECKGOTO (ncclSocketAccept (&sock, &state->listenSock ), ret, fail);
564- NCCLCHECKGOTO (bootstrapNetRecv (&sock, &newPeer, sizeof (int )), ret, fail);
565- NCCLCHECKGOTO (bootstrapNetRecv (&sock, &newTag, sizeof (int )), ret, fail);
566- if (newPeer == peer && newTag == tag) {
567- NCCLCHECKGOTO (bootstrapNetRecv (&sock, ((char *)data), size), ret, fail);
568- goto exit;
569- }
570- // Unexpected connection. Save for later.
571- NCCLCHECKGOTO (unexpectedEnqueue (state, newPeer, newTag, &sock), ret, fail);
505+ NCCLCHECKGOTO (ncclSocketInit (sock), ret, fail);
506+ NCCLCHECKGOTO (ncclSocketAccept (sock, &state->listenSock ), ret, fail);
507+ NCCLCHECKGOTO (bootstrapNetRecv (sock, &newPeer, sizeof (int )), ret, fail);
508+ NCCLCHECKGOTO (bootstrapNetRecv (sock, &newTag, sizeof (int )), ret, fail);
509+ if (newPeer == peer && newTag == tag) return ncclSuccess;
510+ NCCLCHECKGOTO (unexpectedEnqueue (state, newPeer, newTag, sock), ret, fail);
572511 }
512+ return ncclSuccess;
513+ fail:
514+ NCCLCHECK (ncclSocketClose (sock));
515+ return ret;
516+ }
517+
518+ // We can't know who we'll receive from, so we need to receive everything at once
519+ ncclResult_t bootstrapRecv (void * commState, int peer, int tag, void * data, int size) {
520+ ncclResult_t ret;
521+ struct ncclSocket sock;
522+ NCCLCHECK (bootstrapAccept (commState, peer, tag, &sock));
523+ TRACE (NCCL_BOOTSTRAP, " Receiving tag=%d peer=%d size=%d" , tag, peer, size);
524+ NCCLCHECKGOTO (bootstrapNetRecv (&sock, ((char *)data), size), ret, exit);
573525exit:
574526 NCCLCHECK (ncclSocketClose (&sock));
575527 return ret;
576- fail:
577- goto exit;
528+ }
529+
530+ // Collective algorithms, based on bootstrapSend/Recv, and sometimes bootstrapConnect/Accept
531+
532+ ncclResult_t bootstrapRingAllGather (struct ncclSocket * prevSocket, struct ncclSocket * nextSocket, int rank, int nranks, char * data, int size) {
533+ /* Simple ring based AllGather
534+ * At each step i receive data from (rank-i-1) from prev
535+ * and send previous step's data from (rank-i) to next
536+ */
537+ for (int i=0 ; i<nranks-1 ; i++) {
538+ size_t rslice = (rank - i - 1 + nranks) % nranks;
539+ size_t sslice = (rank - i + nranks) % nranks;
540+
541+ // Send slice to the right, recv slice from the left
542+ NCCLCHECK (bootstrapNetSendRecv (nextSocket, data+sslice*size, size, prevSocket, data+rslice*size, size));
543+ }
544+ return ncclSuccess;
545+ }
546+ ncclResult_t bootstrapAllGather (void * commState, void * allData, int size) {
547+ struct bootstrapState * state = (struct bootstrapState *)commState;
548+ int rank = state->rank ;
549+ int nranks = state->nranks ;
550+
551+ TRACE (NCCL_INIT, " rank %d nranks %d size %d" , rank, nranks, size);
552+
553+ NCCLCHECK (bootstrapRingAllGather (&state->ringRecvSocket , &state->ringSendSocket , rank, nranks, (char *)allData, size));
554+
555+ TRACE (NCCL_INIT, " rank %d nranks %d size %d - DONE" , rank, nranks, size);
556+ return ncclSuccess;
557+ }
558+
559+ ncclResult_t bootstrapIntraNodeBarrier (void * commState, int *ranks, int rank, int nranks, int tag) {
560+ if (nranks == 1 ) return ncclSuccess;
561+ TRACE (NCCL_INIT, " rank %d nranks %d tag %x - ENTER" , rank, nranks, tag);
562+
563+ /* Simple [intra] process barrier
564+ *
565+ * Based on the dissemination algorithm by Debra Hensgen, Raphael Finkel, and Udi Manbet,
566+ * "Two Algorithms for Barrier Synchronization," International Journal of Parallel Programming, 17(1):1-17, 1988"
567+ */
568+ int data[1 ];
569+ for (int mask=1 ; mask<nranks; mask<<=1 ) {
570+ int src = (rank - mask + nranks) % nranks;
571+ int dst = (rank + mask) % nranks;
572+ NCCLCHECK (bootstrapSend (commState, ranks ? ranks[dst] : dst, tag, data, sizeof (data)));
573+ NCCLCHECK (bootstrapRecv (commState, ranks ? ranks[src] : src, tag, data, sizeof (data)));
574+ }
575+
576+ TRACE (NCCL_INIT, " rank %d nranks %d tag %x - DONE" , rank, nranks, tag);
577+ return ncclSuccess;
578+ }
579+
580+ ncclResult_t bootstrapBarrier (void * commState, int rank, int nranks, int tag) {
581+ return bootstrapIntraNodeBarrier (commState, NULL , rank, nranks, tag);
582+ }
583+
584+ ncclResult_t bootstrapIntraNodeAllGather (void * commState, int *ranks, int rank, int nranks, void * allData, int size) {
585+ if (nranks == 1 ) return ncclSuccess;
586+ TRACE (NCCL_INIT, " rank %d nranks %d size %d - ENTER" , rank, nranks, size);
587+
588+ int prevRank = ranks[(rank - 1 + nranks)%nranks];
589+ int nextRank = ranks[(rank + 1 ) % nranks];
590+ struct ncclSocket prevSocket, nextSocket;
591+ NCCLCHECK (bootstrapConnect (commState, nextRank, 0 , &nextSocket));
592+ NCCLCHECK (bootstrapAccept (commState, prevRank, 0 , &prevSocket));
593+
594+ NCCLCHECK (bootstrapRingAllGather (&prevSocket, &nextSocket, rank, nranks, (char *)allData, size));
595+
596+ NCCLCHECK (ncclSocketClose (&nextSocket));
597+ NCCLCHECK (ncclSocketClose (&prevSocket));
598+
599+ TRACE (NCCL_INIT, " rank %d nranks %d size %d - DONE" , rank, nranks, size);
600+ return ncclSuccess;
601+ }
602+
603+ // [IntraNode] in-place Broadcast
604+ ncclResult_t bootstrapIntraNodeBroadcast (void * commState, int *ranks, int rank, int nranks, int root, void * bcastData, int size) {
605+ if (nranks == 1 ) return ncclSuccess;
606+ TRACE (NCCL_INIT, " rank %d nranks %d root %d size %d - ENTER" , rank, nranks, root, size);
607+
608+ if (rank == root) {
609+ for (int i=0 ; i<nranks; i++) {
610+ if (i != root) NCCLCHECK (bootstrapSend (commState, ranks ? ranks[i] : i, /* tag=*/ ranks ? ranks[i] : i, bcastData, size));
611+ }
612+ }
613+ else {
614+ NCCLCHECK (bootstrapRecv (commState, ranks ? ranks[root] : root, /* tag=*/ ranks ? ranks[rank] : rank, bcastData, size));
615+ }
616+
617+ TRACE (NCCL_INIT, " rank %d nranks %d root %d size %d - DONE" , rank, nranks, root, size);
618+ return ncclSuccess;
619+ }
620+
621+ ncclResult_t bootstrapBroadcast (void * commState, int rank, int nranks, int root, void * bcastData, int size) {
622+ return bootstrapIntraNodeBroadcast (commState, NULL , rank, nranks, root, bcastData, size);
578623}
579624
580625ncclResult_t bootstrapClose (void * commState) {
0 commit comments