
Commit 277747c

[Device] Add dynamic fetch/reduce pipelining for reduction collectives - Simple protocol (#1861)
* Support pipelining codegen and template specialization
* Support ReduceCopy pipelining for AllReduce, ReduceScatter, and Reduce (currently enabled for bfloat16)
* Remove need for FUNC_INDEX_TOTAL
* Add pipeline field to device function key construction logic
* Avoid unneeded codegen for LL/LL64 kernels
* Modify conditions and add pipeline dtypes env
* Optimize selection for both gfx942 and gfx950
* Increase pipeline bitfield width
* Use __forceinline__ for all device functions
* Realign reduceCopy with original form
* Add opt-out option to enable perf debugs
* Remove force-reduce-pipelining option from README
* Update CHANGELOG.md

---------

Co-authored-by: Jeffrey Novotny <[email protected]>
1 parent b882af9 · commit 277747c
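For context, the headline change is software pipelining of the fetch/reduce loop in the Simple protocol: with double buffering, the loads for pack i+1 are issued while the (relatively slow) bf16 reduction of pack i is still in flight. The sketch below is a minimal host-side illustration of that pattern, not RCCL's actual `reduceCopyPacks`; all names are made up for illustration.

```cpp
// Minimal sketch of double-buffered fetch/reduce pipelining.
// Illustrative only -- not RCCL's reduceCopyPacks; names are hypothetical.
#include <cstddef>

template<typename Pack, typename RedOp, int NSRCS>
void reduceCopyPipelined(Pack* dst, const Pack* const* srcs,
                         std::size_t npacks, RedOp op) {
  if (npacks == 0) return;
  Pack buf[2][NSRCS];  // two in-flight buffers: one being filled, one being reduced

  auto fetch = [&](std::size_t i, Pack (&b)[NSRCS]) {
    for (int s = 0; s < NSRCS; s++) b[s] = srcs[s][i];   // issue the loads for pack i
  };
  auto reduce = [&](const Pack (&b)[NSRCS]) {
    Pack acc = b[0];
    for (int s = 1; s < NSRCS; s++) acc = op(acc, b[s]); // e.g. the slow bf16 adds
    return acc;
  };

  fetch(0, buf[0]);                                      // prologue: prime the pipeline
  for (std::size_t i = 0; i + 1 < npacks; i++) {
    fetch(i + 1, buf[(i + 1) & 1]);                      // fetch pack i+1 ...
    dst[i] = reduce(buf[i & 1]);                         // ... while pack i is reduced
  }
  dst[npacks - 1] = reduce(buf[(npacks - 1) & 1]);       // epilogue: drain
}
```

On a GPU the same structure lets the compiler schedule the vector loads of the next pack ahead of the dependent arithmetic of the current one; for two sources this might be invoked as `reduceCopyPipelined<float, std::plus<float>, 2>(out, srcs, n, {})`.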

File tree

18 files changed: +286 −170 lines


CHANGELOG.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -22,10 +22,10 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https:
 * Multi-node tuning for AllGather, AllReduce, and ReduceScatter that leverages LL/LL64/LL128 protocol to use nontemporal vector load/store for tunable message size ranges.
 * LL/LL128 usage ranges for AR, AG, and RS are part of the tuning models, which enable architecture-specific tuning in conjunction with the existing Rome Models scheme in RCCL.
 * Two new APIs are exposed as part of an initiative to separate RCCL code. These APIs are `rcclGetAlgoInfo` and `rcclFuncMaxSendRecvCount`. However, user-level invocation requires that RCCL be built with `RCCL_EXPOSE_STATIC` enabled.
-* Enabled double-buffering in `reduceCopyPacks` to trigger pipelining, especially to overlap bf16 arithmetic.
-* Added `--force-reduce-pipeline` as an option that can be passed to the `install.sh` script. Passing this option will enable software-triggered pipelining of `bfloat16` reductions (i.e. `all_reduce`, `reduce_scatter`, and `reduce`).
+* Enabled double-buffering in `reduceCopyPacks` to trigger pipelining, especially to overlap `bf16` arithmetic and bridge the gap between `fp32` and `bf16` performance on both `gfx942` and `gfx950`. Pipelining has been made tunable via `rcclSetPipelining`, similar to algorithms/protocols, so that regressions are avoided at certain message sizes.
 * Added a direct allgather algorithm. This is enabled by default for multi-node if there are 16 nodes or fewer. The message size threshold is 4MB.
 
+
 ### Changed
 
 * Compatibility with NCCL 2.23.4
```
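The changelog's point about avoiding regressions at certain message sizes suggests a runtime gate around the pipelined path. The snippet below is a purely hypothetical illustration of such a gate; the actual `rcclSetPipelining` signature and selection policy are not part of this diff.

```cpp
// Hypothetical illustration of a message-size gate for pipelining.
// The real rcclSetPipelining API and policy are not shown in this commit.
#include <cstddef>

enum class PipelineMode { Auto, ForceOn, ForceOff };

bool usePipelining(PipelineMode mode, std::size_t bytes, bool isBf16) {
  if (mode == PipelineMode::ForceOn)  return true;
  if (mode == PipelineMode::ForceOff) return false;
  // Assumed Auto policy: pipeline only bf16 reductions above a threshold where
  // the extra register pressure pays off; smaller messages keep the scalar
  // path, which is the kind of per-size regression the changelog refers to.
  return isBf16 && bytes >= (std::size_t{1} << 20);
}
```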

CMakeLists.txt

Lines changed: 0 additions & 13 deletions
```diff
@@ -38,7 +38,6 @@ option(PROFILE "Enable profiling"
 option(TIMETRACE "Enable time-trace during compilation" OFF)
 option(TRACE "Enable additional tracing" OFF)
 option(FAULT_INJECTION "Enable fault injection" ON)
-option(FORCE_REDUCE_PIPELINING "Force reduce pipelining" OFF)
 
 # Default GPU architectures to build
 #==================================================================================================
@@ -848,18 +847,6 @@ foreach(file ${GENERATED_FILES})
   list(APPEND HIP_SOURCES ${file})
 endforeach()
 
-# Enable SW pipelining where needed
-foreach(SOURCE_FILE ${HIP_SOURCES})
-  # TODO: enable bf16 pipelining by default upon having the pipelined/scalar switching feature
-  # if (FORCE_REDUCE_PIPELINING AND (SOURCE_FILE MATCHES "gensrc/reduce_.*" OR SOURCE_FILE MATCHES "gensrc/reduce_scatter_.*" OR SOURCE_FILE MATCHES "gensrc/all_reduce_.*"))
-  #   message(STATUS "RCCL_ENABLE_SW_PIPELINE enabled for ${SOURCE_FILE}")
-  #   set_source_files_properties(${SOURCE_FILE} PROPERTIES COMPILE_FLAGS "-DRCCL_ENABLE_SW_PIPELINE")
-  if(FORCE_REDUCE_PIPELINING AND SOURCE_FILE MATCHES "gensrc/(reduce|reduce_scatter|all_reduce).*_bf16\\.cpp$")
-    message(STATUS "BF16 pipelining support enabled for ${SOURCE_FILE}")
-    set_source_files_properties(${SOURCE_FILE} PROPERTIES COMPILE_FLAGS "-DRCCL_ENABLE_SW_PIPELINE")
-  endif()
-endforeach()
-
 # Create an initial git_version.cpp file (that will be updated with latest git version)
 #==================================================================================================
 file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/git_version.cpp "")
```
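The removed block attached a per-source compile definition only to the generated bf16 reduction sources. On the consumer side, such a flag typically selects the pipelined path at preprocessing time, roughly as sketched below (an assumption, not the literal RCCL code); this commit replaces the build-time switch with the `Pipeline` template parameter threaded through the device functions.

```cpp
// Assumed consumer-side pattern for a per-source define such as
// -DRCCL_ENABLE_SW_PIPELINE; the actual guarded code in RCCL may differ.
#ifdef RCCL_ENABLE_SW_PIPELINE
constexpr int kReducePipeline = 1;  // compile the double-buffered reduce-copy loop
#else
constexpr int kReducePipeline = 0;  // compile the scalar reduce-copy loop
#endif
```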

README.md

Lines changed: 0 additions & 1 deletion
````diff
@@ -64,7 +64,6 @@ RCCL build & installation helper script
   -t|--tests_build          Build rccl unit tests, but do not run
   --time-trace              Plot the build time of RCCL (requires `ninja-build` package installed on the system)
   --verbose                 Show compile commands
-  --force-reduce-pipeline   Force the reduce_copy SW pipeline to be used for all reduce-based collectives and datatypes
 ```
 
 By default, RCCL builds for all GPU targets defined in `DEFAULT_GPUS` in `CMakeLists.txt`. To target specific GPU(s), and potentially reduce build time, use `--amdgpu_targets` as a `;` separated string listing GPU(s) to target.
````

cmake/scripts/add_unroll.sh

File mode changed: 100644 → 100755
Lines changed: 13 additions & 4 deletions
```diff
@@ -21,13 +21,22 @@
 HIP_FILE=$1
 
 if [[ "$HIP_FILE" =~ .*/src/device/.*\.h ]]; then
-  perl -pi -e 's/(template<typename T, typename RedOp(?:, typename Proto)?)(, bool isNetOffload.*?)?>/\1, int USE_ACC, int COLL_UNROLL\2>/g' "$HIP_FILE"
+  perl -pi -e 's/(template<typename T, typename RedOp(?:, typename Proto)?)(, bool isNetOffload.*?)?>/\1, int USE_ACC, int COLL_UNROLL, int Pipeline\2>/g' "$HIP_FILE"
+  perl -pi -e 's/(template<typename T, typename RedOp(?:, typename Proto)?(?:, int RCCLMetadata)?)(, bool isNetOffload.*?)?>/\1, int USE_ACC, int COLL_UNROLL, int Pipeline\2>/g' "$HIP_FILE"
   perl -pi -e 's/(ProtoSimple<[^,]*?,[^,]+?)>/\1, USE_ACC, COLL_UNROLL>/g' "$HIP_FILE"
   perl -pi -e 's/(runRing<T.*?)((, (true|false))?>\()/\1, USE_ACC, COLL_UNROLL\2/g' "$HIP_FILE"
   perl -pi -e 's/(runTreeUpDown<T.*?)>\(/\1, USE_ACC, COLL_UNROLL>(/' "$HIP_FILE"
   perl -pi -e 's/(runTreeSplit<T.*?)>\(/\1, USE_ACC, COLL_UNROLL>(/' "$HIP_FILE"
-  sed -i "s/\\(struct RunWorkColl<ncclFunc[^>]*\\)>*/\\1, USE_ACC, COLL_UNROLL>/" "$HIP_FILE"
-  sed -i "s/\\(struct RunWorkBatch<ncclFunc[^>]*\\)>*/\\1, USE_ACC, COLL_UNROLL>/" "$HIP_FILE"
-  echo "Added COLL_UNROLL and USE_ACC template arguments to $HIP_FILE"
 
+  perl -pi -e 's/(runTreeSplit<T, RedOp, (ProtoLL|ProtoLL128), USE_ACC, COLL_UNROLL.*?)>/\1, 0>/' "$HIP_FILE"
+  perl -pi -e 's/(runTreeUpDown<T, RedOp, (ProtoLL|ProtoLL128), USE_ACC, COLL_UNROLL.*?)>/\1, 0>/' "$HIP_FILE"
+  perl -pi -e 's/(runRing<T, RedOp, (ProtoLL|ProtoLL128), USE_ACC, COLL_UNROLL.*?)>/\1, 0>/' "$HIP_FILE"
+  perl -pi -e 's/(runRing<T, RedOp, (ProtoLL|ProtoLL128), (RCCL_ONE_NODE_RING_SIMPLE|RCCL_METADATA_EMPTY), USE_ACC, COLL_UNROLL.*?)>/\1, 0>/' "$HIP_FILE"
+
+  perl -pi -e 's/(runRing<T, RedOp, Proto, (RCCL_ONE_NODE_RING_SIMPLE|RCCL_METADATA_EMPTY), USE_ACC, COLL_UNROLL.*?)>/\1, Pipeline>/' "$HIP_FILE"
+  perl -pi -e 's/(runRing<T, RedOp, Proto, USE_ACC, COLL_UNROLL.*?)>/\1, Pipeline>/' "$HIP_FILE"
+  perl -pi -e 's/(runTreeSplit<T, RedOp, Proto, USE_ACC, COLL_UNROLL.*?)>/\1, Pipeline>/' "$HIP_FILE"
+  perl -pi -e 's/(runTreeUpDown<T, RedOp, Proto, USE_ACC, COLL_UNROLL.*?)>/\1, Pipeline>/' "$HIP_FILE"
+  sed -i "s/\\(struct RunWorkBatch<ncclFunc[^>]*\\)>*/\\1, USE_ACC, COLL_UNROLL, Pipeline>/" "$HIP_FILE"
+  sed -i "s/\\(RunWorkColl<[^,]*,[^,]*,[^,]*,[^,]*,[^>]*\\)>/\\1, USE_ACC, COLL_UNROLL, Pipeline>/" "$HIP_FILE"
 fi
```
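To make the rewrites concrete, here is an assumed before/after of what the script does to a device header at build time; the signatures are simplified relative to the real headers.

```cpp
// Assumed before/after of a device-header template once add_unroll.sh has run.

// Before: the checked-in header declares only the type parameters.
template<typename T, typename RedOp, typename Proto>
struct RunRingBefore {};

// After: the script splices in the build-time non-type parameters.
template<typename T, typename RedOp, typename Proto,
         int USE_ACC, int COLL_UNROLL, int Pipeline>
struct RunRingAfter {};

// Call sites are patched to match: LL/LL64/LL128 instantiations are pinned to
// Pipeline = 0 (no pipelining), e.g.
//   RunRingAfter<T, RedOp, ProtoLL, USE_ACC, COLL_UNROLL, 0>
// while Simple-protocol sites forward Pipeline through, e.g.
//   RunRingAfter<T, RedOp, Proto, USE_ACC, COLL_UNROLL, Pipeline>
```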

src/device/all_reduce.h

Lines changed: 9 additions & 9 deletions
```diff
@@ -14,7 +14,7 @@
 #endif
 
 namespace {
-template<typename T, typename RedOp, typename Proto, int RCCLMetadata, int USE_ACC, int COLL_UNROLL>
+template<typename T, typename RedOp, typename Proto, int RCCLMetadata>
 #if defined(USE_INDIRECT_FUNCTION_CALL) && !defined(__gfx942__) && !defined(__gfx950__)
 __device__ void runRing(int tid, int nthreads, struct ncclDevWorkColl* work) {
 #else
@@ -61,7 +61,7 @@ namespace {
   // Coverity reports that the callee treats &ring->next as an array. However, due to the use of
   // FanSymmetric<1>, only the first element is ever accessed, so it's fine.
   // coverity[callee_ptr_arith:FALSE]
-  Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0, false, RCCLMetadata> prims
+  Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0, false, RCCLMetadata, Pipeline> prims
     (tid, nthreads, &ring->prev, &ring->next, work->sendbuff, work->recvbuff, work->redOpArg, 0, work->connIndex, work->connIndex, work);
 
 #if defined(ENABLE_NPKIT)
@@ -252,7 +252,7 @@ namespace {
 #endif
 
     { // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
-      Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto, 0> prims
+      Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto, 0, false, 0, Pipeline> prims
         (tid, nthreads, tree->down, &tree->up, work->sendbuff, work->recvbuff, work->redOpArg, 0, 0, 0, work);
 
 #if defined(ENABLE_NPKIT)
@@ -301,7 +301,7 @@ namespace {
     }
 
     { // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
-      Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto, 0> prims
+      Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto, 0, false, 0, Pipeline> prims
         (tid, nthreads, &tree->up, tree->down, work->sendbuff, work->recvbuff, work->redOpArg, 0, 0, 0, work);
 
 #if defined(ENABLE_NPKIT)
@@ -420,7 +420,7 @@ namespace {
 
     if (tree->up == -1) {
       // Reduce and broadcast. Max number of recv is 2, max number of send is 2
-      Primitives<T, RedOp, FanSymmetric<NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto,USE_ACC >
+      Primitives<T, RedOp, FanSymmetric<NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto, 0, false, 0, Pipeline, USE_ACC>
         prims(tid, nthreads, tree->down, tree->down, work->sendbuff, work->recvbuff, work->redOpArg, 0, 0, 0, work);
 
 #if defined(ENABLE_NPKIT)
@@ -463,7 +463,7 @@ namespace {
       // Coverity reports that the callee treats &tree->up as an array. However, due to the use of
       // FanAsymmetric<n, 1>, only the first element is ever accessed, so it's fine.
       // coverity[callee_ptr_arith:FALSE]
-      Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto, 0>
+      Primitives<T, RedOp, FanAsymmetric<NCCL_MAX_DEV_ARITY, 1>, /*Direct=*/0, Proto, 0, false, 0, Pipeline>
         prims(tid, nthreadsSplit, tree->down, &tree->up, work->sendbuff, work->recvbuff, work->redOpArg, 0*Proto::MaxGroupWidth, 0, 0, work);
 
 #if defined(ENABLE_NPKIT)
@@ -508,7 +508,7 @@ namespace {
       // Coverity reports that the callee treats &tree->up as an array. However, due to the use of
       // FanAsymmetric<1, n>, only the first element is ever accessed, so it's fine.
       // coverity[callee_ptr_arith:FALSE]
-      Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto, 0>
+      Primitives<T, RedOp, FanAsymmetric<1, NCCL_MAX_DEV_ARITY>, /*Direct=*/0, Proto, 0, false, 0, Pipeline>
         prims(tid-nthreadsSplit, nthreads-nthreadsSplit, &tree->up, tree->down, work->sendbuff, work->recvbuff,
               work->redOpArg, 1*Proto::MaxGroupWidth, 0, 0, work);
 
@@ -560,7 +560,7 @@ namespace {
     }
   }
 
-#if defined(__gfx942__) || defined(__gfx950__) // Use a single slice per simple primitive for a single node on some GFX9 devices.
+#if defined(__gfx942__) || defined(__gfx950__) // Use a single slice per simple primitive for a single node on some GFX9 devices.
 #define rcclAllReduceRunRingSimpleProtoImpl(tid, nthreads, work) \
   if(work->rcclUseOneSlice){ \
     using Proto = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS_SINGLE_NODE, ALLREDUCE_SLICESTEPS_SINGLE_NODE>; \
@@ -579,7 +579,7 @@ namespace {
 template<typename T, typename RedOp>
 struct RunWorkColl<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
   __device__ __forceinline__ void run(int tid, int nthreads, struct ncclDevWorkColl* work) {
-    rcclAllReduceRunRingSimpleProtoImpl(tid, nthreads, work);
+    rcclAllReduceRunRingSimpleProtoImpl(tid, nthreads, work);
   }
 };
```

(The last two hunks change only whitespace.)
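Threading `Pipeline` through `Primitives` as a non-type template parameter means the pipelined and scalar inner loops can be selected at instantiation time, with no runtime branch. A minimal device-code sketch of that dispatch, assuming a simple two-operand sum rather than RCCL's real `Primitives` machinery:

```cpp
// Device-code sketch (not RCCL's actual Primitives) of compile-time dispatch
// on a Pipeline flag: the branch is resolved per instantiation, so each
// variant pays nothing for the other.
template<typename T, int Pipeline>
struct ReduceLoop {
  static __device__ __forceinline__ void run(T* dst, const T* a, const T* b, int n) {
    if (n <= 0) return;
    if constexpr (Pipeline != 0) {
      // Double-buffered: fetch the operands of element i+1 while element i
      // is still being added (useful when the add, e.g. bf16, is slow).
      T a0 = a[0], b0 = b[0];
      for (int i = 0; i + 1 < n; i++) {
        T a1 = a[i + 1], b1 = b[i + 1];
        dst[i] = a0 + b0;
        a0 = a1; b0 = b1;
      }
      dst[n - 1] = a0 + b0;
    } else {
      for (int i = 0; i < n; i++) dst[i] = a[i] + b[i];  // scalar path
    }
  }
};
```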

src/device/common.h

Lines changed: 8 additions & 8 deletions
```diff
@@ -392,22 +392,22 @@ __device__ __forceinline__ void loadWorkBatchToShmem(
   }
 }
 
-template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL>
+template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL, int Pipeline>
 struct RunWorkColl {
   __device__ void run(int tid, int tn, struct ncclDevWorkColl* work) {
     // Put NOT IMPLEMENTED behavior here.
   }
 };
 
-template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL>
+template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL, int Pipeline>
 struct RunWorkBatch;
 
 // Specialized for P2p in sendrecv.h
 template<typename T, typename RedOp>
 struct RunWorkBatch<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE>;
 
 // Specialized here for non-P2p (Coll and CollReg)
-template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL>
+template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int USE_ACC, int COLL_UNROLL, int Pipeline>
 struct RunWorkBatch {
   // This __forceinline__ is necessary. The compiler was inserting a function call
   // here from the LL ncclKernel.
@@ -437,7 +437,7 @@ struct RunWorkBatch {
       // Coverity reports a possible thread divergence due to not all threads participating in the collective.
       // However, the code ensures that the participation is on a per-warp basis.
       // coverity[device_thread_diverged:FALSE]
-      if (tid < subtn) RunWorkColl<Fn, T, RedOp, Algo, Proto, USE_ACC, COLL_UNROLL>().run(tid, subtn, work);
+      if (tid < subtn) RunWorkColl<Fn, T, RedOp, Algo, Proto>().run(tid, subtn, work);
     }
   }
 };
@@ -672,14 +672,14 @@ __global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT
   __global__ void ncclDevKernel_##suffix(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {}
 
 #ifdef USE_INDIRECT_FUNCTION_CALL
-#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, acc, unroll) \
+#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, acc, pipeline, unroll) \
   __device__ void ncclDevFunc_##suffix() { \
-    RunWorkBatch<coll, ty, redop<ty>, algo, proto, acc, unroll>().run(); \
+    RunWorkBatch<coll, ty, redop<ty>, algo, proto, acc, unroll, pipeline>().run(); \
   }
 #else
-#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, acc, unroll) \
+#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, acc, pipeline, unroll) \
   __device__ __attribute__((noinline)) void ncclDevFunc_##suffix() { \
-    RunWorkBatch<coll, ty, redop<ty>, algo, proto, acc, unroll>().run(); \
+    RunWorkBatch<coll, ty, redop<ty>, algo, proto, acc, unroll, pipeline>().run(); \
   }
 #endif
```
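With the extra `pipeline` argument in `DEFINE_ncclDevFunc`, each generated device function now encodes its pipelining choice in the `RunWorkBatch` instantiation. A hypothetical invocation (suffix and argument values assumed for illustration) and its rough expansion:

```cpp
// Hypothetical example only: suffix and argument values are made up, but the
// expansion follows the macro definition above.
//
//   DEFINE_ncclDevFunc(AllReduce_Sum_bf16_RING_SIMPLE,
//                      ncclFuncAllReduce, FuncSum, hip_bfloat16,
//                      NCCL_ALGO_RING, NCCL_PROTO_SIMPLE,
//                      /*acc=*/0, /*pipeline=*/1, /*unroll=*/4)
//
// expands (in the indirect-function-call build) to roughly:
//
//   __device__ void ncclDevFunc_AllReduce_Sum_bf16_RING_SIMPLE() {
//     RunWorkBatch<ncclFuncAllReduce, hip_bfloat16, FuncSum<hip_bfloat16>,
//                  NCCL_ALGO_RING, NCCL_PROTO_SIMPLE,
//                  /*USE_ACC=*/0, /*COLL_UNROLL=*/4, /*Pipeline=*/1>().run();
//   }
```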
