Skip to content

Commit ae7aed1

Browse files
committed
NCCL 2.28.7-1
GPU-Initiated Networking (GIN): * Provides device-side API for integrating GPU-Initiated Networking capability into application kernels. * New transport layer called DOCA GPUNetIO. * New ncclGin construct to create, destroy and manipulate GIN contexts. * New ncclGinBarrierSession to provide synchronization functionality. * New put, signal, counter operations for data movement and signaling. * GIN API signatures and functionalities are subject to change. * GIN Support Requirements * CUDA 12.2 or later when compiling the GPU code * NVIDIA GPUs: Volta or newer. NVIDIA GPU drivers >= 510.40.3 * NVIDIA NICs: CX4 or newer. rdma-core >= 44.0 * Requires nvidia-peermem or DMABUF support. When using DMABUF, linux kernel >= 6.1 is required. New ncclCommRevoke API for fault tolerance: * Introduces ncclCommRevoke to quiesce ongoing NCCL work on a communicator without freeing resources. * This answers the need for a lightweight way to cancel in-flight collectives and bring a communicator to a safe state before split/shrink/finalize/destroy. * Includes optional cross-rank coordination (global barrier) and supports blocking/non-blocking usage. New NCCL Environment Plugin: * The env plugin allows users to set NCCL environment variables, for example, after loading them from a centralized database. * The NCCL_ENV_PLUGIN variable can be used to let NCCL load an external environment plugin. New NCCL Examples on GitHub: * The NCCL examples directory provides users and developers with practical code samples that highlight NCCL’s core features. * It covers basic operations like communicator initialization, point-to-point communication, and collective operations, as well as advanced features such as user buffer registration, symmetric memory, and the device API. Device API improvements: * Adds ncclFindWindow API. * Adds new ncclBarrierSession to provide hybrid synchronization functionality. * Makes multimem available with as few as two ranks. * Removes distance (NCCL_P2P_LEVEL) considerations from determining the availability of symmetric memory. Enhanced NCCL RAS output: * Extends RAS subsystem with JSON format to support machine-parsable metrics collection. * Enables structured data export for monitoring tools, dashboards, and automated analysis systems. Github Pull Requests resolved: * Fast Init - CPU Optimizations for NCCL Initialization Large Scale. (PR #1789) * Fast Init - Improve Bootstrap AllGather by 2x at large scale by sending bootstrap information bidirectionally. (PR #1791) * Fixes spurious failures when PyTorch is statically linked with NCCL-2.28.3 because error is not drained, but rather gets propagated into the next CUDA kernel invocation. (PR #1864) Other notable improvements: * Fixes multicast object leaks in case of failed NVLS user buffer registrations, which could lead to crashes. Avoids such registration attempts in case of the use of incompatible memory allocators. * Fixes potential data corruption with built-in symmetric kernels for small messages with size granularity under 8 bytes or when multiple symmetric operations were aggregated in a group. * Generalizes the existing point-to-point scheduling to the case of un-even GPU count per node. * Fixes a crash when network plugin assignment fails. * Fixes a large performance issue with NCCL_CROSS_NIC=0 and certain split mask settings, where NCCL cannot find a viable ring. * Fixes crash when NCCL is compiled with recent CUDA versions but running on hosts with certain specific older CUDA drivers.
1 parent 834ef72 commit ae7aed1

File tree

165 files changed

+28241
-497
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

165 files changed

+28241
-497
lines changed

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ if(MAX_EXT_NET_PLUGINS GREATER 0)
148148
add_definitions(-DNCCL_NET_MAX_PLUGINS=${MAX_EXT_NET_PLUGINS})
149149
endif()
150150

151+
add_definitions(-DDOCA_VERBS_USE_CUDA_WRAPPER)
152+
add_definitions(-DDOCA_VERBS_USE_NET_WRAPPER)
153+
add_definitions(-DNCCL_GIN_PROXY_ENABLE=1)
154+
151155
# Library dependencies
152156
find_library(RT_LIBRARY NAMES rt)
153157
if(RT_LIBRARY)

ext-tuner/example/plugin.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ __hidden ncclResult_t pluginInit(void** context, uint64_t commId, size_t nRanks,
308308
// Set NVLSTree base network latency to 24us
309309
constants->hwLatencies[NCCL_HW_NET][NCCL_ALGO_NVLS][NCCL_PROTO_SIMPLE] = 24.0;
310310
}
311-
311+
312312
TunerContext* ctx = (TunerContext*)malloc(sizeof(TunerContext));
313313
if (!ctx) return ncclSystemError;
314314

ext-tuner/example/test/test_plugin.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -767,16 +767,16 @@ int test_nvl_domain_info() {
767767
.minRanksPerNvlDomain = 3, // minimum ranks across all domains (bottleneck)
768768
.maxRanksPerNvlDomain = 5 // maximum ranks across all domains (capacity)
769769
};
770-
770+
771771
void* context = NULL;
772772
ncclResult_t result = pluginInit(&context, 0, 8, 2, mock_logger, &nvl_domain, NULL);
773773
TEST_ASSERT(result == ncclSuccess, "Plugin init with NVLink domains should succeed");
774-
774+
775775
// Validate NVLD info structure
776776
TEST_ASSERT(nvl_domain.nNvlDomains == 2, "Should have 2 domains (nodes)");
777777
TEST_ASSERT(nvl_domain.minRanksPerNvlDomain == 3, "Should have minimum 3 ranks per domain");
778778
TEST_ASSERT(nvl_domain.maxRanksPerNvlDomain == 5, "Should have maximum 5 ranks per domain");
779-
779+
780780
// Clean up
781781
pluginFinalize(context);
782782
printf("NVLink domain info test passed!\n");

makefiles/common.mk

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ NET_PROFILER ?= 0
2020
MLX5DV ?= 0
2121
MAX_EXT_NET_PLUGINS ?= 0
2222

23-
NVCC = $(CUDA_HOME)/bin/nvcc
23+
NVCC ?= $(CUDA_HOME)/bin/nvcc
2424

2525
CUDA_LIB ?= $(CUDA_HOME)/lib64
2626
CUDA_INC ?= $(CUDA_HOME)/include
@@ -85,6 +85,8 @@ NVCUFLAGS := -ccbin $(CXX) $(NVCC_GENCODE) $(CXXSTD) --expt-extended-lambda -Xp
8585
# Use addprefix so that we can specify more than one path
8686
NVLDFLAGS := -L${CUDA_LIB} -lcudart -lrt
8787

88+
NVCUFLAGS_SYM :=
89+
8890
########## GCOV ##########
8991
GCOV ?= 0 # disable by default.
9092
GCOV_FLAGS := $(if $(filter 0,${GCOV} ${DEBUG}),,--coverage) # only gcov=1 and debug =1
@@ -158,3 +160,8 @@ endif
158160
ifneq ($(MAX_EXT_NET_PLUGINS), 0)
159161
CXXFLAGS += -DNCCL_NET_MAX_PLUGINS=$(MAX_EXT_NET_PLUGINS)
160162
endif
163+
164+
CXXFLAGS += -DDOCA_VERBS_USE_CUDA_WRAPPER -DDOCA_VERBS_USE_NET_WRAPPER
165+
NVCUFLAGS += -DDOCA_VERBS_USE_CUDA_WRAPPER -DDOCA_VERBS_USE_NET_WRAPPER
166+
167+
CXXFLAGS += -DNCCL_GIN_PROXY_ENABLE=1

makefiles/version.mk

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
##### version
22
NCCL_MAJOR := 2
33
NCCL_MINOR := 28
4-
NCCL_PATCH := 3
4+
NCCL_PATCH := 7
55
NCCL_SUFFIX :=
66
PKG_REVISION := 1

src/CMakeLists.txt

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ add_subdirectory(device)
3939
add_subdirectory(nccl_device)
4040
add_subdirectory(ras)
4141
add_subdirectory(scheduler)
42+
add_subdirectory(gin)
4243

4344
add_compile_options(-fmacro-prefix-map=${CMAKE_CURRENT_SOURCE_DIR}/=)
4445

@@ -52,6 +53,8 @@ list(APPEND LIBSRCFILES
5253
${RAS_SOURCES}
5354
${SYM_SOURCES}
5455
${SCHEDULER_SOURCES}
56+
${GIN_SOURCES}
57+
${DOCA_SOURCES}
5558
)
5659

5760
###################### Create a shared NCCL library ############################
@@ -65,6 +68,7 @@ target_include_directories(nccl PUBLIC
6568
${CMAKE_CURRENT_SOURCE_DIR}/include
6669
${CMAKE_CURRENT_SOURCE_DIR}/include/plugin
6770
${CUDAToolkit_INCLUDE_DIRS}
71+
${DOCA_HOME}/include
6872
${CUDAToolkit_INCLUDE_DIRS}/cccl
6973
)
7074

@@ -80,9 +84,25 @@ add_custom_command(
8084
BYPRODUCTS ${CMAKE_BINARY_DIR}/include/nccl.h
8185
)
8286

83-
add_custom_target(nccl_header DEPENDS ${CMAKE_BINARY_DIR}/include/nccl.h)
87+
file(GLOB_RECURSE SRC_DEVICE_HEADERS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/include/nccl_device/*.h)
88+
89+
# Copy all device header files to the destination
90+
foreach(HEADER_FILE ${SRC_DEVICE_HEADERS})
91+
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/${HEADER_FILE} ${CMAKE_BINARY_DIR}/${HEADER_FILE} COPYONLY)
92+
list(APPEND DEVICE_HEADERS ${CMAKE_BINARY_DIR}/${HEADER_FILE})
93+
endforeach()
94+
95+
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/include/nccl_device.h ${CMAKE_BINARY_DIR}/include/nccl_device.h COPYONLY)
96+
97+
add_custom_target(nccl_header DEPENDS
98+
${CMAKE_BINARY_DIR}/include/nccl.h
99+
${CMAKE_BINARY_DIR}/include/nccl_device.h
100+
${DEVICE_HEADERS}
101+
${DEVICE_DOCA_HEADERS}
102+
)
84103

85104
add_dependencies(nccl nccl_header)
105+
add_dependencies(nccl_device nccl_header)
86106

87107
# Set version and output name
88108
set_target_properties(nccl PROPERTIES
@@ -111,6 +131,11 @@ target_link_libraries(nccl
111131
${EXTRA_LIBS}
112132
)
113133

134+
# Add version script for symbol visibility control
135+
target_link_options(nccl PRIVATE
136+
"-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/libnccl.map"
137+
)
138+
114139
# Set output directories for nccl shared library
115140
set_target_properties(nccl PROPERTIES
116141
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib"
@@ -149,6 +174,7 @@ target_include_directories(nccl_static PUBLIC
149174
${CMAKE_CURRENT_SOURCE_DIR}/include
150175
${CMAKE_CURRENT_SOURCE_DIR}/include/plugin
151176
${CUDAToolkit_INCLUDE_DIRS}
177+
transport/gdaki/doca-gpunetio/include
152178
${CUDAToolkit_INCLUDE_DIRS}/cccl
153179
)
154180

src/Makefile

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,24 @@ include ../makefiles/version.mk
88

99
##### src files
1010
INCEXPORTS := nccl.h nccl_device.h \
11-
$(patsubst include/%,%,$(wildcard include/nccl_device/*.h include/nccl_device/impl/*.h))
11+
$(patsubst include/%,%,$(wildcard include/nccl_device/*.h include/nccl_device/*/*.h include/nccl_device/*/*/*.h))
1212

1313
LIBSRCFILES := \
1414
bootstrap.cc channel.cc collectives.cc debug.cc enqueue.cc group.cc \
1515
init.cc init_nvtx.cc proxy.cc transport.cc mnnvl.cc allocator.cc dev_runtime.cc sym_kernels.cc ce_coll.cc \
1616
$(wildcard graph/*.cc) \
1717
$(wildcard misc/*.cc) \
1818
$(wildcard transport/*.cc) \
19+
$(wildcard transport/gdaki/*.cc) \
1920
$(wildcard register/*.cc) \
2021
$(wildcard plugin/*.cc) \
2122
$(wildcard plugin/net/*.cc) \
2223
$(wildcard plugin/tuner/*.cc) \
2324
$(wildcard plugin/profiler/*.cc) \
25+
$(wildcard plugin/env/*.cc) \
2426
$(wildcard nccl_device/*.cc) \
2527
$(wildcard scheduler/*.cc) \
28+
$(wildcard gin/*.cc) \
2629
$(filter-out ras/client.cc,$(wildcard ras/*.cc))
2730
BINSRCFILES := ras/client.cc
2831

@@ -40,6 +43,7 @@ LIBDIR := $(BUILDDIR)/lib
4043
OBJDIR := $(BUILDDIR)/obj
4144
PKGDIR := $(BUILDDIR)/lib/pkgconfig
4245
BINDIR := $(BUILDDIR)/bin
46+
4347
##### target files
4448
CUDARTLIB ?= cudart_static
4549

@@ -61,6 +65,17 @@ INCPLUGIN := include/plugin
6165

6266
DEVMANIFEST := $(BUILDDIR)/obj/device/manifest
6367

68+
# DOCA GPUNetIO definitions
69+
DOCA_HOME ?= transport/gdaki/doca-gpunetio
70+
DOCA_INC_INSTALL := $(INCDIR)/nccl_device/gin/gdaki/doca_gpunetio
71+
DOCA_OBJDIR := $(OBJDIR)/transport/gdaki/doca-gpunetio
72+
DOCA_INCLUDES := $(DOCA_HOME)/include/doca_gpunetio_device.h $(wildcard $(DOCA_HOME)/include/common/*.h) $(wildcard $(DOCA_HOME)/include/device/*.cuh)
73+
DOCA_INCTARGETS := $(DOCA_INCLUDES:$(DOCA_HOME)/include/%=$(DOCA_INC_INSTALL)/%)
74+
INCTARGETS += $(DOCA_INCTARGETS)
75+
DOCA_LIBSRC := doca_verbs_qp.cpp doca_verbs_cq.cpp doca_verbs_device_attr.cpp doca_verbs_umem.cpp doca_verbs_srq.cpp doca_verbs_uar.cpp doca_gpunetio.cpp doca_gpunetio_log.cpp doca_gpunetio_high_level.cpp doca_verbs_cuda_wrapper.cpp doca_verbs_mlx5dv_wrapper.cpp doca_verbs_ibv_wrapper.cpp doca_gpunetio_gdrcopy.cpp
76+
DOCA_LIBOBJ := $(DOCA_LIBSRC:%.cpp=$(DOCA_OBJDIR)/%.o)
77+
LIBOBJ += $(DOCA_LIBOBJ)
78+
6479
##### rules
6580
build : lib staticlib binary
6681

@@ -94,7 +109,7 @@ $(INCDIR)/nccl.h : nccl.h.in ../makefiles/version.mk
94109
$(LIBDIR)/$(LIBTARGET): $(LIBOBJ) $(DEVMANIFEST)
95110
@printf "Linking %-35s > %s\n" $(LIBTARGET) $@
96111
mkdir -p $(LIBDIR)
97-
$(CXX) $(CXXFLAGS) -shared -Wl,--no-as-needed -Wl,-soname,$(LIBSONAME) -o $@ $(LIBOBJ) $$(cat $(DEVMANIFEST)) $(LDFLAGS)
112+
$(CXX) $(CXXFLAGS) -shared -Wl,--no-as-needed -Wl,-soname,$(LIBSONAME) -o $@ $(LIBOBJ) $$(cat $(DEVMANIFEST)) $(LDFLAGS) -Wl,--version-script=libnccl.map
98113
ln -sf $(LIBSONAME) $(LIBDIR)/$(LIBNAME)
99114
ln -sf $(LIBTARGET) $(LIBDIR)/$(LIBSONAME)
100115

@@ -137,6 +152,36 @@ $(INCDIR)/nccl_device/impl/%.h: include/nccl_device/impl/%.h
137152
mkdir -p $(INCDIR)/nccl_device/impl
138153
install -m 644 $< $@
139154

155+
$(INCDIR)/nccl_device/gin/%.h: include/nccl_device/gin/%.h
156+
@printf "Grabbing %-35s > %s\n" $< $@
157+
mkdir -p $(INCDIR)/nccl_device/gin
158+
install -m 644 $< $@
159+
160+
$(INCDIR)/nccl_device/gin/gdaki/%.h: include/nccl_device/gin/gdaki/%.h
161+
@printf "Grabbing %-35s > %s\n" $< $@
162+
mkdir -p $(INCDIR)/nccl_device/gin/gdaki
163+
install -m 644 $< $@
164+
165+
$(INCDIR)/nccl_device/gin/proxy/%.h: include/nccl_device/gin/proxy/%.h
166+
@printf "Grabbing %-35s > %s\n" $< $@
167+
mkdir -p $(INCDIR)/nccl_device/gin/proxy
168+
install -m 644 $< $@
169+
170+
$(DOCA_INC_INSTALL)/%.h: $(DOCA_HOME)/include/%.h
171+
@printf "Grabbing %-35s > %s\n" $< $@
172+
mkdir -p $(DOCA_INC_INSTALL)
173+
install -m 644 $< $@
174+
175+
$(DOCA_INC_INSTALL)/common/%.h: $(DOCA_HOME)/include/common/%.h
176+
@printf "Grabbing %-35s > %s\n" $< $@
177+
mkdir -p $(DOCA_INC_INSTALL)/common
178+
install -m 644 $< $@
179+
180+
$(DOCA_INC_INSTALL)/device/%.cuh: $(DOCA_HOME)/include/device/%.cuh
181+
@printf "Grabbing %-35s > %s\n" $< $@
182+
mkdir -p $(DOCA_INC_INSTALL)/device
183+
install -m 644 $< $@
184+
140185
$(PKGDIR)/%.pc : %.pc
141186
@printf "Grabbing %-35s > %s\n" $< $@
142187
mkdir -p $(PKGDIR)
@@ -145,8 +190,18 @@ $(PKGDIR)/%.pc : %.pc
145190
$(OBJDIR)/%.o : %.cc $(INCTARGETS)
146191
@printf "Compiling %-35s > %s\n" $< $@
147192
mkdir -p `dirname $@`
148-
$(CXX) -I. -I$(INCDIR) $(CXXFLAGS) -Iinclude -I$(INCPLUGIN) -c $< -o $@
149-
@$(CXX) -I. -I$(INCDIR) $(CXXFLAGS) -Iinclude -I$(INCPLUGIN) -M $< > $(@:%.o=%.d.tmp)
193+
$(CXX) -I. -I$(INCDIR) $(CXXFLAGS) -Iinclude -I$(INCPLUGIN) -I$(DOCA_HOME)/include -c $< -o $@
194+
@$(CXX) -I. -I$(INCDIR) $(CXXFLAGS) -Iinclude -I$(INCPLUGIN) -I$(DOCA_HOME)/include -M $< > $(@:%.o=%.d.tmp)
195+
@sed "0,/^.*:/s//$(subst /,\/,$@):/" $(@:%.o=%.d.tmp) > $(@:%.o=%.d)
196+
@sed -e 's/.*://' -e 's/\\$$//' < $(@:%.o=%.d.tmp) | fmt -1 | \
197+
sed -e 's/^ *//' -e 's/$$/:/' >> $(@:%.o=%.d)
198+
@rm -f $(@:%.o=%.d.tmp)
199+
200+
$(DOCA_OBJDIR)/%.o : $(DOCA_HOME)/src/%.cpp
201+
@printf "Compiling %-35s > %s\n" $< $@
202+
mkdir -p `dirname $@`
203+
$(CXX) -I$(DOCA_HOME)/src -I$(DOCA_HOME)/include $(CXXFLAGS) -c $< -o $@
204+
@$(CXX) -I$(DOCA_HOME)/src -I$(DOCA_HOME)/include $(CXXFLAGS) -M $< > $(@:%.o=%.d.tmp)
150205
@sed "0,/^.*:/s//$(subst /,\/,$@):/" $(@:%.o=%.d.tmp) > $(@:%.o=%.d)
151206
@sed -e 's/.*://' -e 's/\\$$//' < $(@:%.o=%.d.tmp) | fmt -1 | \
152207
sed -e 's/^ *//' -e 's/$$/:/' >> $(@:%.o=%.d)

src/bootstrap.cc

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,21 @@ static ncclResult_t socketSendRecv(struct ncclSocket* sendSock, void* sendData,
226226
return ncclSuccess;
227227
}
228228

229+
static ncclResult_t socketDoubleSendRecv(struct ncclSocketOp ops[4]) {
230+
// ops synchronously exchange size then asynchronously exchange data in send->recv->send->recv order
231+
int senderRecvSize1, senderRecvSize2;
232+
NCCLCHECK(ncclSocketSendRecv(ops[0].sock, &ops[0].size, sizeof(int), ops[1].sock, &senderRecvSize1, sizeof(int)));
233+
NCCLCHECK(ncclSocketSendRecv(ops[2].sock, &ops[2].size, sizeof(int), ops[3].sock, &senderRecvSize2, sizeof(int)));
234+
if (senderRecvSize1 > ops[1].size || senderRecvSize2 > ops[3].size) {
235+
WARN("Message truncated : received %d,%d bytes instead of %d,%d", senderRecvSize1, senderRecvSize2, ops[1].size, ops[3].size);
236+
return ncclInternalError;
237+
}
238+
ops[1].size = std::min(ops[1].size, senderRecvSize1);
239+
ops[3].size = std::min(ops[3].size, senderRecvSize2);
240+
NCCLCHECK(ncclSocketMultiOp(ops, 4));
241+
return ncclSuccess;
242+
}
243+
229244
union ringConnectInfo {
230245
union ncclSocketAddress addr;
231246
char handle[NCCL_NET_HANDLE_MAXSIZE];
@@ -1007,22 +1022,40 @@ static ncclResult_t netRingAllGather(ncclNet_t* net, void* sendComm, void* recvC
10071022
if (recvDataHandle) netDereg(net, recvComm, &recvDataHandle);
10081023
return res;
10091024
}
1010-
static ncclResult_t socketRingAllGather(struct ncclSocket* sendSock, struct ncclSocket* recvSock, int rank, int nranks, char* data, int size) {
1025+
static ncclResult_t socketRingAllGather(struct ncclSocket* nextSock, struct ncclSocket* prevSock, int rank, int nranks, char* data, int size) {
10111026
ncclResult_t res = ncclSuccess;
10121027
uint64_t tFirst = 0, tRest = 0;
10131028
/* Simple ring based AllGather
10141029
* At each step i receive data from (rank-i-1) from prev
10151030
* and send previous step's data from (rank-i) to next
10161031
*/
1017-
TRACE(NCCL_BOOTSTRAP, "socketRingAllGather started");
1032+
TRACE(NCCL_BOOTSTRAP, "socketRingAllGather started: rank=%d nranks=%d", rank, nranks);
1033+
int totalSteps = nranks / 2;
1034+
TRACE(NCCL_BOOTSTRAP, "bidirectional bootstrap: totalSteps=%d", totalSteps);
10181035
BOOTSTRAP_PROF_OPEN(tFirst);
1019-
for (int i = 0; i < nranks - 1; i++) {
1020-
size_t rslice = (rank - i - 1 + nranks) % nranks;
1021-
size_t sslice = (rank - i + nranks) % nranks;
1022-
void* recv_data = data + rslice * size;
1023-
void* send_data = data + sslice * size;
1024-
NCCLCHECKGOTO(socketSendRecv(sendSock, send_data, size, recvSock, recv_data, size), res, exit);
1025-
if (i == 0) {
1036+
for (int step = 0; step < totalSteps; step++) {
1037+
// N ranks requires (N-1)/2 steps for the double ring algorithm. If N is even, the last step is requires a single send/recv
1038+
bool isFinalUnidirectional = (step == totalSteps - 1) && (nranks % 2 == 0);
1039+
// Ring0: ring from previous to next
1040+
int sendSliceRing0 = (rank - step + nranks) % nranks; // Send this slice to next neighbor
1041+
int recvSliceRing0 = (rank - step - 1 + nranks) % nranks; // Receive this slice from prev neighbor
1042+
// Ring1: ring from next to previous
1043+
int sendSliceRing1 = (rank + step) % nranks; // Send this slice to prev neighbor
1044+
int recvSliceRing1 = (rank + step + 1) % nranks; // Receive this slice from next neighbor
1045+
if (isFinalUnidirectional) {
1046+
// Final unidirectional step, only Ring0 is used
1047+
NCCLCHECKGOTO(socketSendRecv(nextSock, data + sendSliceRing0 * size, size, prevSock, data + recvSliceRing0 * size, size), res, exit);
1048+
} else {
1049+
// Bidirectional step: Ring0 and Ring1 are used simultaneously
1050+
struct ncclSocketOp ops[4] = {
1051+
{NCCL_SOCKET_SEND, nextSock, data + sendSliceRing0 * size, size, 0}, // Ring0: send to next
1052+
{NCCL_SOCKET_RECV, prevSock, data + recvSliceRing0 * size, size, 0}, // Ring0: recv from prev
1053+
{NCCL_SOCKET_SEND, prevSock, data + sendSliceRing1 * size, size, 0}, // Ring1: send to prev
1054+
{NCCL_SOCKET_RECV, nextSock, data + recvSliceRing1 * size, size, 0} // Ring1: recv from next
1055+
};
1056+
NCCLCHECKGOTO(socketDoubleSendRecv(ops), res, exit);
1057+
}
1058+
if (step == 0) {
10261059
BOOTSTRAP_PROF_CLOSE(tFirst);
10271060
BOOTSTRAP_PROF_OPEN(tRest);
10281061
}

0 commit comments

Comments
 (0)