@aamirshafi commented on Dec 15, 2025

Summary

This PR adds NCCL as an additional communication backend for DeepEP, leveraging NCCL's Device API for GPU-initiated
network operations. The integration introduces a CommunicationBackend abstraction that lets users select either
NVSHMEM or NCCL as the communication backend, depending on their deployment requirements.


Build/Runtime Selection of NVSHMEM and NCCL Backends

Both backends are fully supported and can be selected at build/runtime:

Backend | Build                                | Runtime              | Status
NVSHMEM | NVSHMEM_DIR=/path/to/nvshmem         | Default              | ✅ Existing (unchanged)
NCCL    | ENABLE_NCCL=1 NCCL_DIR=/path/to/nccl | DEEP_EP_BACKEND=nccl | ✅ New

Why Support Both NVSHMEM and NCCL Backends:

  • Existing NVSHMEM deployments continue to work without any changes
  • NCCL provides a portable, widely-adopted additional option
  • Users can choose based on their infrastructure and preferences

Unchanged Buffer Interface

The Buffer class interface exported to AI frameworks remains completely unchanged. This means:

  • Existing AI frameworks (e.g., training and inference systems using DeepEP) can utilize NCCL without any code changes
  • The same dispatch(), combine(), low_latency_dispatch(), and low_latency_combine() APIs work transparently with either backend
  • Backend selection is handled via environment variables at build/runtime, not in application code

Communication Backend Interface

We introduce an abstract CommunicationBackend interface that decouples DeepEP kernels from the underlying communication library:

class CommunicationBackend {
public:
    // Initialization/Cleanup
    virtual int init(const std::vector<uint8_t>& root_unique_id, int rank, 
                     int num_ranks, bool low_latency_mode, int qps_per_rank) = 0;
    virtual void finalize() = 0;
    virtual void get_unique_id(void* unique_id) = 0;

    // Synchronization
    virtual void barrier() = 0;

    // Memory Management (RDMA-registered buffers)
    virtual void* alloc(size_t size, size_t alignment) = 0;
    virtual void free(void* ptr) = 0;

    // Query
    virtual int get_rank() const = 0;
    virtual int get_num_ranks() const = 0;
    virtual BackendType get_backend_type() const = 0;
};
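
For orientation, here is a minimal sketch of what an NCCL-backed implementation of this interface could look like, using only the standard NCCL host API (ncclGetUniqueId, ncclCommInitRank, ncclMemAlloc, ncclMemFree, ncclCommDestroy). The class name NcclBackend, the BackendType::NCCL enumerator, and the single-communicator simplification are assumptions for illustration only; the actual backend in this PR manages multiple communicators and additionally sets up the Device API for GPU-initiated operations.

// Illustrative sketch only: assumes the CommunicationBackend interface above is
// visible, uses a single communicator, and an assumed BackendType::NCCL value.
// The real backend creates ceil(QPs/4) communicators and registers windows for
// GPU-initiated operations.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>
#include <nccl.h>

class NcclBackend : public CommunicationBackend {
public:
    int init(const std::vector<uint8_t>& root_unique_id, int rank,
             int num_ranks, bool low_latency_mode, int qps_per_rank) override {
        (void) low_latency_mode; (void) qps_per_rank;  // multi-QP wiring omitted here
        rank_ = rank;
        num_ranks_ = num_ranks;
        ncclUniqueId id;
        std::memcpy(&id, root_unique_id.data(), sizeof(ncclUniqueId));
        return ncclCommInitRank(&comm_, num_ranks, id, rank) == ncclSuccess ? 0 : -1;
    }

    void finalize() override { ncclCommDestroy(comm_); }

    void get_unique_id(void* unique_id) override {
        ncclGetUniqueId(static_cast<ncclUniqueId*>(unique_id));
    }

    void barrier() override {
        // A real implementation would rendezvous all ranks here (e.g. via a
        // small collective); omitted in this sketch.
    }

    void* alloc(size_t size, size_t alignment) override {
        (void) alignment;          // ncclMemAlloc has no alignment parameter;
        void* ptr = nullptr;       // assume its allocations are aligned enough
        ncclMemAlloc(&ptr, size);
        return ptr;
    }
    void free(void* ptr) override { ncclMemFree(ptr); }

    int get_rank() const override { return rank_; }
    int get_num_ranks() const override { return num_ranks_; }
    BackendType get_backend_type() const override { return BackendType::NCCL; }

private:
    ncclComm_t comm_ = nullptr;
    int rank_ = 0, num_ranks_ = 0;
};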

Backend Selection:

// Auto-detect from DEEP_EP_BACKEND environment variable
BackendType detect_backend_type();

// Global singleton management
void initialize_backend(BackendType type, ...);
CommunicationBackend* get_backend();
void finalize_backend();
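
A hypothetical call-site sketch of this flow is shown below; initialize_backend()'s full argument list is elided in this description, so its call is only indicated as a comment rather than invented here.

// Hypothetical usage sketch built only from the declarations above.
#include <cstddef>

void setup_deep_ep_backend(size_t num_bytes) {
    BackendType type = detect_backend_type();         // honors DEEP_EP_BACKEND (e.g. "nccl")
    (void) type;
    // initialize_backend(type, /* rank, num_ranks, ... (elided above) */);

    CommunicationBackend* backend = get_backend();    // global singleton
    void* rdma_buffer = backend->alloc(num_bytes, 128);  // RDMA-registered allocation
    backend->barrier();                               // rendezvous across ranks

    // ... dispatch/combine kernels use rdma_buffer ...

    backend->free(rdma_buffer);
    finalize_backend();
}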

NCCL Backend Integration

The NCCL backend uses NCCL's Device API for GPU-initiated network operations (GIN), translating the NVSHMEM-style PGAS operations used by the DeepEP kernels into NCCL's window-based model:

Aspect          | NVSHMEM/IBGDA                        | NCCL
Memory Model    | Pointer-based addressing             | Window-based with offset addressing
Data Transfer   | put_nbi(dst_ptr, src_ptr, count, pe) | put(peer, dstWin, dstOff, srcWin, srcOff, bytes)
Synchronization | Memory atomics                       | Signal atomics (signal(), readSignal())
Completion      | Per-QP quiet()                       | Per-context flush()
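
To make the first two rows concrete, the sketch below (hypothetical names, not the PR's actual code) shows how a destination pointer inside a registered buffer can be translated into the (window, offset) pair that a window-based put() expects:

#include <cstddef>
#include <cstdint>

// Hypothetical per-buffer bookkeeping: the pointer-based addressing used with
// put_nbi() becomes a (window handle, byte offset) pair for a window-based put().
struct RegisteredRegion {
    void*    base;             // start of the RDMA-registered buffer
    size_t   size;             // registered length in bytes
    uint64_t window_handle;    // opaque window handle for this buffer
};

inline bool to_window_offset(const RegisteredRegion& region, const void* dst_ptr,
                             uint64_t* win, size_t* offset) {
    auto base = static_cast<const unsigned char*>(region.base);
    auto dst  = static_cast<const unsigned char*>(dst_ptr);
    if (dst < base || dst >= base + region.size)
        return false;          // pointer is outside this registration
    *win    = region.window_handle;
    *offset = static_cast<size_t>(dst - base);
    return true;
}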

Key Integration Challenges:

  1. Multi-Communicator Mapping: Each NCCL GIN context wraps a Queue Pair (QP), and NCCL provides 4 GIN contexts per communicator. DeepEP's QP requirements are met using ⌈QPs/4⌉ communicators (a host-side sketch of this mapping follows this list):

    comm_id = channel_id / 4
    ctx_id  = channel_id % 4
    
  2. Memory Registration: Buffers are registered with all communicators; the resulting window handles are stored in GPU memory for kernel access.

  3. Signal-Based Synchronization: Pre-allocated signal layouts map the memory atomics used for synchronization onto NCCL signal primitives.

  4. Semantic Preservation: A zero-byte put() with SignalAdd is used for signaling; it is semantically equivalent to net.signal() but performs better in the current NCCL release.
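
A small host-side sketch of the communicator/context mapping from item 1 (illustrative names; the actual PR code differs):

#include <cstdio>

// Each NCCL communicator exposes at most 4 GIN contexts, so DeepEP channels
// (one QP each) are spread over ceil(QPs/4) communicators.
constexpr int kGinContextsPerComm = 4;

inline int num_comms_needed(int num_channels) {
    return (num_channels + kGinContextsPerComm - 1) / kGinContextsPerComm;
}

struct CommCtx { int comm_id; int ctx_id; };

inline CommCtx map_channel(int channel_id) {
    return { channel_id / kGinContextsPerComm,    // which communicator
             channel_id % kGinContextsPerComm };  // which GIN context within it
}

int main() {
    const int num_channels = 10;   // e.g. 10 QPs per rank
    std::printf("communicators needed: %d\n", num_comms_needed(num_channels));  // -> 3
    for (int c = 0; c < num_channels; ++c) {
        CommCtx m = map_channel(c);
        std::printf("channel %d -> comm %d, ctx %d\n", c, m.comm_id, m.ctx_id);
    }
    return 0;
}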


NCCL GIN Type Selection

For GPU-initiated network operations, NCCL supports multiple backends:

NCCL_GIN_TYPE | Backend | Description
2             | Proxy   | CPU proxy for network operations
3             | GDAKI   | GPU Direct Async Kernel-Initiated (recommended)

Performance Results

Benchmarked on H100 (900 GB/s NVLink) with 8×400 Gbit/s InfiniBand (~50 GB/s per NIC).

High-Throughput Kernels

4096 tokens, 7168 hidden, top-8 experts, BF16 dispatch, BF16 combine

Backend | #EP | Dispatch BW | Combine BW
NVSHMEM | 16  | 79.7 GB/s   | 66.4 GB/s
NCCL    | 16  | 76.9 GB/s   | 66.2 GB/s
NVSHMEM | 32  | 62.9 GB/s   | 62.9 GB/s
NCCL    | 32  | 61.7 GB/s   | 62.3 GB/s
NVSHMEM | 64  | 53.5 GB/s   | 53.2 GB/s
NCCL    | 64  | 52.7 GB/s   | 52.9 GB/s

Low-Latency Kernels

128 tokens, 7168 hidden, top-8 experts, FP8 dispatch, BF16 combine

Backend | #EP | Dispatch Latency | Dispatch BW | Combine Latency | Combine BW
NVSHMEM | 8   | 160.7 μs         | 46.8 GB/s   | 304.2 μs        | 47.8 GB/s
NCCL    | 8   | 160.8 μs         | 47.0 GB/s   | 302.8 μs        | 47.9 GB/s
NVSHMEM | 16  | 182.3 μs         | 41.4 GB/s   | 319.8 μs        | 45.5 GB/s
NCCL    | 16  | 178.6 μs         | 42.2 GB/s   | 320.8 μs        | 45.3 GB/s
NVSHMEM | 32  | 188.7 μs         | 40.0 GB/s   | 332.9 μs        | 43.7 GB/s
NCCL    | 32  | 190.1 μs         | 39.8 GB/s   | 333.2 μs        | 43.6 GB/s
NVSHMEM | 64  | 225.1 μs         | 34.8 GB/s   | 343.1 μs        | 42.4 GB/s
NCCL    | 64  | 218.9 μs         | 34.5 GB/s   | 351.1 μs        | 41.4 GB/s

Requirements

  • NCCL 2.28.9 (or above) with Device API and GIN support (required)

Tested configurations:

  • CUDA 12.8+
  • Hopper (H100) and Blackwell GPUs

Usage

# Build with NCCL backend
export NCCL_DIR=/path/to/nccl/build
export ENABLE_NCCL=1
python3 setup.py build_ext --inplace && pip install --no-build-isolation .

# Run with NCCL
DEEP_EP_BACKEND=nccl NCCL_GIN_TYPE=3 python3 tests/test_low_latency.py
DEEP_EP_BACKEND=nccl NCCL_GIN_TYPE=3 python3 tests/test_internode.py

Future Work

  1. Reduce Communicator Count: Currently, we initialize multiple NCCL communicators due to the limitation of 4 GIN contexts per communicator. Work is in progress on the NCCL roadmap to increase the number of GIN contexts per communicator, which will simplify the integration.

  2. Unified Backend Interface for NVSHMEM: The current NVSHMEM code does not follow the CommunicationBackend interface in order to stay close to the upstream DeepEP main branch. Migrating NVSHMEM to use this interface should be addressed in consultation with DeepEP maintainers.


References

📄 Paper: GPU-Initiated Networking for NCCL (arXiv:2511.15076)

🔗 NCCL Repository: https://github.com/NVIDIA/nccl

📄 NCCL README (in the NCCL repository above)


Co-authored with @grtheod.

fzyzcjy and others added 30 commits November 4, 2025 09:45
…-ai#217)

* more

* add flag

* add test

* fix

* more

* apply
- Add TORCH_DISTRIBUTED_BACKEND env var configuration
- Fix tensor shape compatibility between NCCL and Gloo
- Add backend-aware wrappers for distributed operations
- Update test files to work with different backends
Signed-off-by: Georgios Theodorakis <[email protected]>
Signed-off-by: Georgios Theodorakis <[email protected]>
…. Removed redundant/dead code.

Signed-off-by: Georgios Theodorakis <[email protected]>
Signed-off-by: Georgios Theodorakis <[email protected]>
…de for the DeepEP buffers.

Signed-off-by: Georgios Theodorakis <[email protected]>
NVLink comms is disabled for now.
… enabled. Removing unnecessary comments from internode.cu

Signed-off-by: Georgios Theodorakis <[email protected]>
…mmunicators only across symmetric RDMA ranks).

Signed-off-by: Georgios Theodorakis <[email protected]>
Signed-off-by: Georgios Theodorakis <[email protected]>
- Fix COMBINE_LAUNCH_CASE macro redefinition in internode_ll.cu
- Fix Python linting errors (unused imports/variables, missing imports)
- Enable half and bfloat16 operators by undefining PyTorch's NO_* flags in setup.py
- Fix missing os import in test_low_latency.py
- Remove trailing whitespace in test utils
Signed-off-by: Georgios Theodorakis <[email protected]>
Extends existing --disable-nvlink flag (already supported in NVSHMEM) to
work with NCCL GIN backend. Implements device-side P2P pointer resolution
via ncclGetPeerPointer with fallback to RDMA when P2P unavailable or disabled.
buffer cleanup to prepare for next dispatch/combine. This helps avoid
a sync at the end of dispatch.

Previously signals for the current dispatch/combine were cleared towards
the end of the dispatch/combine kernels.
grtheod and others added 24 commits December 12, 2025 01:12
This commit resolves compilation issues when building with NCCL-only or
NVSHMEM modes by properly guarding backend-specific code with preprocessor
directives. It also documents the LD_PRELOAD workaround needed for NCCL
symbol resolution with PyTorch.

Key changes:

1. Renamed flag DISABLE_NVSHMEM to DISABLE_NVSHMEM_AND_NCCL for clarity
   - Updated all occurrences across setup.py, configs.cuh, runtime.cu,
     config.hpp, and deep_ep.cpp

2. Fixed conditional compilation in internode.cu and internode_ll.cu:
   - Guarded #include "ibgda_device.cuh" with #ifdef ENABLE_NVSHMEM
   - Guarded extern nvshmem_team_t cpu_rdma_team with #ifdef ENABLE_NVSHMEM
   - Changed all #else blocks to #elif defined(ENABLE_NVSHMEM) for explicit
     backend selection (31 instances in internode.cu, 18 in internode_ll.cu)
   - Made function signatures and kernel parameters conditional based on
     ENABLE_NCCL vs ENABLE_NVSHMEM
   - Fixed kernel launch macros to pass correct parameters per backend

3. Fixed runtime.cu NVSHMEM header guards:
   - Added nested #ifdef ENABLE_NVSHMEM within #ifndef DISABLE_NVSHMEM_AND_NCCL
   - Ensures NVSHMEM headers only included when actually enabled

4. Fixed setup.py:
   - Removed duplicate include_dirs.append()
   - Fixed undefined nccl_lib variable reference
   - Added -dlink flag to nvcc_dlink for proper CUDA device linking with RDC

5. Added cooperative_groups support to internode_ll.cu:
   - Added #include <cooperative_groups.h> and namespace alias
   - Resolves cg::this_grid().sync() compilation errors

6. Fixed internode.cu warnings:
   - Added #undef DISPATCH_LAUNCH_CASE to prevent macro redefinition
   - Guarded rdma_rank declaration with #ifdef ENABLE_NVSHMEM

7. Fixed NVSHMEM kernel parameter mismatch:
   - Added missing nccl_windows and signals_base to NOTIFY_DISPATCH_LAUNCH_CASE
   - Added missing signals_base to cached_notify kernel launch

8. Documentation:
   - Added LD_PRELOAD documentation to README-NCCL.md explaining the
     workaround for PyTorch's bundled NCCL vs custom GIN-enabled NCCL

This allows clean compilation in both NCCL-only mode (ENABLE_NCCL=1) and
NVSHMEM mode (NVSHMEM_DIR set), with proper symbol resolution at runtime.
…intranode kernels, cleaned-up nvshmem/nccl gin only code.

Signed-off-by: Georgios Theodorakis <[email protected]>
- Fix include order: move configs.cuh before CUDA headers in internode.cu and internode_ll.cu to properly undef PyTorch's half/bfloat16 restrictions
- Remove redundant -U flags from setup.py (no longer needed with correct include order)
- Remove NCCL src/include path, use only build/include (public API)
- Consolidate extra_link_args.extend() calls in setup.py
- Add NCCL path to build output
- Update internode::init() signature to accept num_rdma_ranks parameter
- Add reference to GIN paper (arXiv:2511.15076) in README-NCCL.md
Signed-off-by: Georgios Theodorakis <[email protected]>
- Add rdma_rank parameter to init() function signature
- Update call site in Buffer::sync to pass rdma_rank
- Fix default backend type string to "nccl"
…atch

- Remove bulk signal reset loops from dispatch/combine initialization
- Add inline net.resetSignal() after receiving data in dispatch recv
- This ensures signals are reset immediately after use rather than upfront
- Remove num_comms parameter from nccl_get_p2p_ptr
- Remove signals_base_next from dispatch/combine kernel signatures
- Simplify nccl_get_p2p_ptr return statement to single line
- Conditionally clean next buffer only when P2P is not disabled
- Remove unused comments
Signed-off-by: Georgios Theodorakis <[email protected]>
…tion

 1) Used ncclCoopWarp() for RDMA put operations
 2) Added net.signal() code for signaling. Keeping net.put() with 0 bytes for signaling because of better Dispatch-Send and Combine-Send latency.
…tion

- setup.py: Remove /build/ suffix from include/lib paths
- README-NCCL.md: Update documentation for NCCL_DIR convention
- All LL/HT scripts: Update NCCL_DIR and LD_LIBRARY_PATH accordingly
Replace the hardcoded 8 GPUs/node assumption with dynamic NCCL LSA
team-based peer detection. This allows low-latency mode to work on
systems with varying GPU configurations without recompilation.

Also removes unused DEEPEP_DEBUG_PRINT macro.
…onfig

Revert README.md to H800 benchmarks, add NCCL_GIN_TYPE docs and QP depth env vars
@alpha-baby (Contributor) commented:

good job!

It seems that NCCL v2.29.1 has not been released yet; how did you get the code for this version?

@polarstormx (Contributor) replied:

> good job!
>
> It seems that NCCL v2.29.1 has not been released yet; how did you get the code for this version?

Because he works at NVIDIA lol

@aamirshafi (Author) replied:

@alpha-baby - Thanks for letting us know. Updated the NCCL version to 2.28.9.

@polarstormx - You got that right 💚
