Changes from all commits (74 commits)
be2eed8
more
fzyzcjy Jun 17, 2025
9ecb941
more
fzyzcjy Jun 17, 2025
9683d94
more
fzyzcjy Jun 17, 2025
8cf6bd8
more
fzyzcjy Jun 17, 2025
acf108a
more
fzyzcjy Jun 17, 2025
3e2cede
Revert "more"
fzyzcjy Jun 17, 2025
45fa1af
Revert "more"
fzyzcjy Jun 17, 2025
443bfa8
more
fzyzcjy Jun 17, 2025
b986cce
more
fzyzcjy Jun 17, 2025
3ea6f58
more
fzyzcjy Jun 17, 2025
5d3513b
more
fzyzcjy Jun 17, 2025
bda5695
more
fzyzcjy Jun 17, 2025
3740762
more
fzyzcjy Jun 17, 2025
ad4aee8
more
fzyzcjy Jun 17, 2025
b5e4aad
more
fzyzcjy Jun 17, 2025
240d058
more
fzyzcjy Jun 17, 2025
5379d59
more
fzyzcjy Jun 17, 2025
4fc8e79
more
fzyzcjy Jun 17, 2025
2e90afe
more
fzyzcjy Jun 17, 2025
3639a57
more
fzyzcjy Jun 17, 2025
4ef8f05
more
fzyzcjy Jun 17, 2025
047656e
more
fzyzcjy Jun 17, 2025
c21f36d
more
fzyzcjy Jun 17, 2025
7f3e4c0
more
fzyzcjy Jun 17, 2025
92fb573
more
fzyzcjy Jun 17, 2025
29f86f3
more
fzyzcjy Jun 17, 2025
5557e70
more
fzyzcjy Jun 17, 2025
9fd34e7
more
fzyzcjy Jun 17, 2025
6417393
more
fzyzcjy Jun 17, 2025
faaeaad
more
fzyzcjy Jun 17, 2025
c38dbed
more
fzyzcjy Jun 17, 2025
dc74c0a
more
fzyzcjy Jun 17, 2025
61dea30
more
fzyzcjy Jun 17, 2025
7d4bc93
more
fzyzcjy Jun 17, 2025
5b78f22
more
fzyzcjy Jun 17, 2025
75351cd
more
fzyzcjy Jun 17, 2025
7bb12d4
more
fzyzcjy Jun 17, 2025
0e5a155
more
fzyzcjy Jun 17, 2025
87b3980
more
fzyzcjy Jun 17, 2025
4398b5c
more
fzyzcjy Jun 17, 2025
d7e9ce3
more
fzyzcjy Jun 17, 2025
5b83cb8
more
fzyzcjy Jun 17, 2025
f024df5
more
fzyzcjy Jun 17, 2025
5a7b2f2
more
fzyzcjy Jun 17, 2025
6052379
more
fzyzcjy Jun 17, 2025
befcd27
more
fzyzcjy Jun 17, 2025
df598ea
more
fzyzcjy Jun 17, 2025
5b23a8a
more
fzyzcjy Jun 17, 2025
210e499
more
fzyzcjy Jun 17, 2025
379ac24
more
fzyzcjy Jun 17, 2025
43999dc
more
fzyzcjy Jun 17, 2025
7916011
more
fzyzcjy Jun 17, 2025
2f90c2d
Merge branch 'feat/cu_mem_api' into feat/deepep_normal_update
fzyzcjy Jun 17, 2025
0525f8f
more
fzyzcjy Jun 17, 2025
3032ede
Merge branch 'feat/cu_mem_api' into feat/deepep_normal_update
fzyzcjy Jun 17, 2025
dc652ea
more
fzyzcjy Jun 17, 2025
151993b
more
fzyzcjy Jun 17, 2025
06169d5
more
fzyzcjy Jun 17, 2025
4b54c98
more
fzyzcjy Jun 17, 2025
dec3315
more
fzyzcjy Jun 17, 2025
04f6a5b
more
fzyzcjy Jun 17, 2025
0613b1f
more
fzyzcjy Jun 17, 2025
b0ba0ea
Revert "more"
fzyzcjy Jun 17, 2025
01f0f90
more
fzyzcjy Jun 17, 2025
b80e0d4
more
fzyzcjy Jun 17, 2025
26130b2
more
fzyzcjy Jun 17, 2025
e395621
moew
fzyzcjy Jun 17, 2025
5b7e55a
more
fzyzcjy Jun 17, 2025
a8c6df8
temp
fzyzcjy Jun 17, 2025
e895366
more
fzyzcjy Jun 17, 2025
af060e6
more
fzyzcjy Jun 17, 2025
378f9b2
more
fzyzcjy Jun 18, 2025
0fc2a30
more
fzyzcjy Jun 18, 2025
1b14ad6
more
fzyzcjy Jun 18, 2025
144 changes: 135 additions & 9 deletions csrc/deep_ep.cpp
@@ -10,6 +10,131 @@
#include "kernels/api.cuh"
#include "kernels/configs.cuh"

namespace shared_memory {
void cu_mem_set_access_all(void* ptr, size_t size) {
int device_count;
CUDA_CHECK(cudaGetDeviceCount(&device_count));

CUmemAccessDesc access_desc[device_count];
for (int idx = 0; idx < device_count; ++idx) {
access_desc[idx].location.type = CU_MEM_LOCATION_TYPE_DEVICE;
access_desc[idx].location.id = idx;
access_desc[idx].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
}

CU_CHECK(cuMemSetAccess((CUdeviceptr)ptr, size, access_desc, device_count));
}

void cu_mem_free(void* ptr) {
CUmemGenericAllocationHandle handle;
CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));

size_t size = 0;
CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));

CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size));
CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size));
CU_CHECK(cuMemRelease(handle));
}

size_t get_size_align_to_granularity(size_t size_raw, size_t granularity) {
size_t size = (size_raw + granularity - 1) & ~(granularity - 1);
if (size == 0) size = granularity;
return size;
}

bool support_fabric() {
int device_count;
CUDA_CHECK(cudaGetDeviceCount(&device_count));

for (int device = 0; device < device_count; ++device) {
int support = 0;
CU_CHECK(cuDeviceGetAttribute(&support, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, device));
if (!support) {
return false;
}
}

return true;
}

SharedMemoryAllocator::SharedMemoryAllocator() : enable_fabric(support_fabric()) {}

void SharedMemoryAllocator::malloc(void** ptr, size_t size_raw) {
if (enable_fabric) {
CUdevice device;
CU_CHECK(cuCtxGetDevice(&device));

CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
prop.location.id = device;

size_t granularity = 0;
CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));

size_t size = get_size_align_to_granularity(size_raw, granularity);

CUmemGenericAllocationHandle handle;
CU_CHECK(cuMemCreate(&handle, size, &prop, 0));

CU_CHECK(cuMemAddressReserve((CUdeviceptr *)ptr, size, granularity, 0, 0));
CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
cu_mem_set_access_all(*ptr, size);
} else {
CUDA_CHECK(cudaMalloc(ptr, size_raw));
}
}

void SharedMemoryAllocator::free(void* ptr) {
if (enable_fabric) {
cu_mem_free(ptr);
} else {
CUDA_CHECK(cudaFree(ptr));
}
}

void SharedMemoryAllocator::get_mem_handle(MemHandle* mem_handle, void* ptr) {
size_t size = 0;
CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));

mem_handle->size = size;

if (enable_fabric) {
CUmemGenericAllocationHandle handle;
CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));

CU_CHECK(cuMemExportToShareableHandle(&mem_handle->inner.cu_mem_fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
} else {
CUDA_CHECK(cudaIpcGetMemHandle(&mem_handle->inner.cuda_ipc_mem_handle, ptr));
}
}

void SharedMemoryAllocator::open_mem_handle(void** ptr, MemHandle* mem_handle) {
if (enable_fabric) {
size_t size = mem_handle->size;

CUmemGenericAllocationHandle handle;
CU_CHECK(cuMemImportFromShareableHandle(&handle, &mem_handle->inner.cu_mem_fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC));

CU_CHECK(cuMemAddressReserve((CUdeviceptr *)ptr, size, 0, 0, 0));
CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
cu_mem_set_access_all(*ptr, size);
} else {
CUDA_CHECK(cudaIpcOpenMemHandle(ptr, mem_handle->inner.cuda_ipc_mem_handle, cudaIpcMemLazyEnablePeerAccess));
}
}

void SharedMemoryAllocator::close_mem_handle(void* ptr) {
if (enable_fabric) {
cu_mem_free(ptr);
} else {
CUDA_CHECK(cudaIpcCloseMemHandle(ptr));
}
}
}

namespace deep_ep {

Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode):
@@ -45,8 +170,8 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_

if (num_nvl_bytes > 0) {
// Local IPC: alloc local memory and set local IPC handles
CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes));
CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
shared_memory_allocator.malloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes);
shared_memory_allocator.get_mem_handle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]);
buffer_ptrs_gpu = reinterpret_cast<void**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes);

// Set barrier signals
@@ -92,11 +217,11 @@ Buffer::~Buffer() noexcept(false) {
// Close remote IPC
if (is_available()) {
for (int i = 0; i < num_nvl_ranks; ++ i) if (i != nvl_rank)
CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i]));
shared_memory_allocator.close_mem_handle(buffer_ptrs[i]);
}

// Free local buffer and error flag
CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank]));
shared_memory_allocator.free(buffer_ptrs[nvl_rank]);
}

// Free NVSHMEM
@@ -142,7 +267,8 @@ int Buffer::get_local_device_id() const {
}

pybind11::bytearray Buffer::get_local_ipc_handle() const {
return {ipc_handles[nvl_rank].reserved, CUDA_IPC_HANDLE_SIZE};
const shared_memory::MemHandle& handle = ipc_handles[nvl_rank];
return {reinterpret_cast<const char*>(&handle), sizeof(handle)};
}

pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
@@ -175,13 +301,13 @@ void Buffer::sync(const std::vector<int> &device_ids,
for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++ i) {
EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value());
auto handle_str = std::string(all_gathered_handles[offset + i].value());
EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE);
EP_HOST_ASSERT(handle_str.size() == shared_memory::HANDLE_SIZE);
if (offset + i != rank) {
std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
std::memcpy(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE);
shared_memory_allocator.open_mem_handle(&buffer_ptrs[i], &ipc_handles[i]);
barrier_signal_ptrs[i] = reinterpret_cast<int*>(static_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
} else {
EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0);
EP_HOST_ASSERT(std::memcmp(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE) == 0);
}
}

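For readers less familiar with the CUDA virtual memory management (VMM) calls used above, the sketch below condenses the fabric-enabled path that SharedMemoryAllocator::malloc, get_mem_handle and open_mem_handle spread across the Buffer lifecycle into one exporter/importer pair. It is illustrative only: the function names fabric_alloc_and_export and fabric_import are not part of this PR, access is granted only to the current device here (the PR grants it to every visible device via cu_mem_set_access_all), and moving the exported handle bytes between ranks is assumed to happen out of band (in the PR they travel through get_local_ipc_handle and Buffer::sync).

// Illustrative sketch of the fabric (VMM) allocation path; not the PR's code.
// Assumes <cuda.h> and the CU_CHECK macro from csrc/kernels/exception.cuh.
#include <cuda.h>
#include <cstddef>

// Exporter rank: allocate physical memory, map it, and export a fabric handle.
void* fabric_alloc_and_export(size_t size_raw, CUmemFabricHandle* out_handle, size_t* out_size) {
    CUdevice device;
    CU_CHECK(cuCtxGetDevice(&device));

    CUmemAllocationProp prop = {};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id = device;
    prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;

    // Round the requested size up to the allocation granularity.
    size_t granularity = 0;
    CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
    size_t size = (size_raw + granularity - 1) & ~(granularity - 1);

    // Create the physical allocation, reserve a VA range, and map the two together.
    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemCreate(&handle, size, &prop, 0));
    CUdeviceptr ptr;
    CU_CHECK(cuMemAddressReserve(&ptr, size, granularity, 0, 0));
    CU_CHECK(cuMemMap(ptr, size, 0, handle, 0));

    // Grant read/write access (only for the current device in this sketch).
    CUmemAccessDesc desc = {};
    desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    desc.location.id = device;
    desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    CU_CHECK(cuMemSetAccess(ptr, size, &desc, 1));

    // Export a fabric handle that another process in the NVLink domain can import.
    CU_CHECK(cuMemExportToShareableHandle(out_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
    *out_size = size;
    return reinterpret_cast<void*>(ptr);
}

// Importer rank: map the peer's allocation from the exchanged fabric handle.
void* fabric_import(CUmemFabricHandle* fabric_handle, size_t size) {
    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemImportFromShareableHandle(&handle, fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC));

    CUdeviceptr ptr;
    CU_CHECK(cuMemAddressReserve(&ptr, size, 0, 0, 0));
    CU_CHECK(cuMemMap(ptr, size, 0, handle, 0));

    CUdevice device;
    CU_CHECK(cuCtxGetDevice(&device));
    CUmemAccessDesc desc = {};
    desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    desc.location.id = device;
    desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    CU_CHECK(cuMemSetAccess(ptr, size, &desc, 1));
    return reinterpret_cast<void*>(ptr);
    // Teardown (cuMemUnmap / cuMemAddressFree / cuMemRelease) follows cu_mem_free above.
}
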
34 changes: 33 additions & 1 deletion csrc/deep_ep.hpp
@@ -20,10 +20,40 @@
#define TORCH_EXTENSION_NAME deep_ep_cpp
#endif

namespace shared_memory {

union MemHandleInner {
cudaIpcMemHandle_t cuda_ipc_mem_handle;
CUmemFabricHandle cu_mem_fabric_handle;
};

struct MemHandle {
MemHandleInner inner;
size_t size;
};

constexpr size_t HANDLE_SIZE = sizeof(MemHandle);

class SharedMemoryAllocator {
public:
SharedMemoryAllocator();
void malloc(void** ptr, size_t size);
void free(void* ptr);
void get_mem_handle(MemHandle* mem_handle, void* ptr);
void open_mem_handle(void** ptr, MemHandle* mem_handle);
void close_mem_handle(void* ptr);
private:
bool enable_fabric;
};
}

namespace deep_ep {

struct Buffer {

#ifndef NVLINK_DOMAIN_LARGE
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8");
#endif

private:
// Low-latency mode buffer
@@ -44,7 +74,7 @@
int num_device_sms;
int rank, rdma_rank, nvl_rank;
int num_ranks, num_rdma_ranks, num_nvl_ranks;
cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
shared_memory::MemHandle ipc_handles[NUM_MAX_NVL_PEERS];

// Stream for communication
at::cuda::CUDAStream comm_stream;
@@ -71,6 +101,8 @@
volatile int* moe_recv_rdma_counter = nullptr;
int* moe_recv_rdma_counter_mapped = nullptr;

shared_memory::SharedMemoryAllocator shared_memory_allocator;

public:
Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode);

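Since MemHandle is a plain fixed-size struct (a union of the legacy CUDA IPC handle and the fabric handle, plus the mapped size), it can cross the Python boundary as an opaque blob of HANDLE_SIZE bytes. The short sketch below shows that round trip; the helper names serialize_handle and deserialize_handle are illustrative and not part of the PR, which performs the same copies inline in get_local_ipc_handle and Buffer::sync.

// Illustrative only: round-tripping shared_memory::MemHandle through raw bytes.
// Assumes the declarations above from csrc/deep_ep.hpp and EP_HOST_ASSERT from csrc/kernels/exception.cuh.
#include <cstring>
#include <string>

std::string serialize_handle(const shared_memory::MemHandle& handle) {
    // Copy the whole struct byte-for-byte; both union members fit in the same fixed-size blob.
    return std::string(reinterpret_cast<const char*>(&handle), shared_memory::HANDLE_SIZE);
}

shared_memory::MemHandle deserialize_handle(const std::string& bytes) {
    EP_HOST_ASSERT(bytes.size() == shared_memory::HANDLE_SIZE);
    shared_memory::MemHandle handle;
    std::memcpy(&handle, bytes.data(), shared_memory::HANDLE_SIZE);
    return handle;
}
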
7 changes: 7 additions & 0 deletions csrc/kernels/configs.cuh
@@ -1,6 +1,13 @@
#pragma once

#define NVLINK_DOMAIN_LARGE

#ifdef NVLINK_DOMAIN_LARGE
#define NUM_MAX_NVL_PEERS 24
[Review thread on NUM_MAX_NVL_PEERS]

Reviewer: Can we increase it to 72?

@fzyzcjy (Contributor, Author), Jun 29, 2025: Wondering about the use case - it seems large-scale EP for prefill with 72 GPUs does not have benefits, IIRC.

Reviewer: It's for training on NVL72.

@fzyzcjy (Contributor, Author): Oh, that looks pretty reasonable! I think it is implementable, but since there are already a lot of PRs pending, waiting for LyricZhao to have time to review and merge, I may continue this PR a bit later.

@DorianZi, Jul 17, 2025: Correct me if I'm wrong: NVL72 is 18 nodes of 4 GPUs each, so the intra-node NVLink peer count is no more than 4, while inter-node NVSHMEM can itself discover cross-node NVLink. Why do we need to extend the intra-node NVLink peer count to 24 or larger?

Reviewer: My understanding:

  1. Cross-node NVLink (MNNVL) is implemented as if it were intra-node.
  2. DeepEP uses NVSHMEM's low-level InfiniBand API for inter-node traffic, so it doesn't benefit from NVSHMEM's MNNVL feature.

Reviewer: @vinjn Thanks for the reply. Wondering, without the changes here, how did sglang run DeepEP with EP48 on an NVLink-only NVL72?

@shifangx (Contributor), Sep 28, 2025, replying to the question above:

  • For SGLang decoding, we can get a performance gain with a large EP size such as EP48. It uses low-latency dispatch/combine, which already supports NVL72 at any EP size.
  • For SGLang prefill, it uses intranode/internode dispatch/combine, which are the kernels we are talking about here.

Without PR #218, intranode dispatch/combine cannot support an EP size larger than 4. Internode dispatch/combine supports any EP size, but it uses a two-hop transfer, so it is not the best solution for NVL72. With PR #218, intranode dispatch/combine can expand to EP24.

#else
#define NUM_MAX_NVL_PEERS 8
#endif

#define NUM_MAX_RDMA_PEERS 20
#define NUM_WORKSPACE_BYTES (32 * 1024 * 1024)
#define NUM_MAX_LOCAL_EXPERTS 1024
12 changes: 12 additions & 0 deletions csrc/kernels/exception.cuh
@@ -31,6 +31,18 @@ do { \
} while (0)
#endif

#ifndef CU_CHECK
#define CU_CHECK(cmd) \
do { \
CUresult e = (cmd); \
if (e != CUDA_SUCCESS) { \
const char *error_str = NULL; \
cuGetErrorString(e, &error_str); \
throw EPException("CU", __FILE__, __LINE__, std::string(error_str)); \
} \
} while (0)
#endif

#ifndef EP_HOST_ASSERT
#define EP_HOST_ASSERT(cond) \
do { \
6 changes: 6 additions & 0 deletions csrc/kernels/internode.cu
@@ -14,7 +14,9 @@ extern nvshmem_team_t cpu_rdma_team;
struct SourceMeta {
int src_rdma_rank, is_token_in_nvl_rank_bits;

#ifndef NVLINK_DOMAIN_LARGE
EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers");
#endif

__forceinline__ SourceMeta() = default;

@@ -85,6 +87,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
void* rdma_buffer_ptr,
void** buffer_ptrs, int** barrier_signal_ptrs, int rank,
const nvshmem_team_t rdma_team) {
#ifndef NVLINK_DOMAIN_LARGE
auto sm_id = static_cast<int>(blockIdx.x);
auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id();
auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / 32;
@@ -282,6 +285,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
prefix_row[i] += prefix_row[i - 1];
}
}
#endif
}

void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
@@ -349,6 +353,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
int rank, int num_ranks) {
#ifndef NVLINK_DOMAIN_LARGE
enum class WarpRole {
kRDMASender,
kRDMASenderCoordinator,
@@ -935,6 +940,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
}
}
#endif
}

void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
2 changes: 2 additions & 0 deletions csrc/kernels/launch.cuh
@@ -52,6 +52,7 @@ cfg.dynamicSmemBytes = smem_size;
case 2: case_macro(2); \
case 4: case_macro(4); \
case 8: case_macro(8); \
case 24: case_macro(24); \
default: EP_HOST_ASSERT(false and "Unsupported ranks"); \
} while (false)

@@ -72,6 +73,7 @@ cfg.dynamicSmemBytes = smem_size;
case 2: case_macro(dtype, 2); \
case 4: case_macro(dtype, 4); \
case 8: case_macro(dtype, 8); \
case 24: case_macro(dtype, 24); \
default: EP_HOST_ASSERT(false && "Unsupported ranks"); \
} while (false)
