Merged
53 commits by fzyzcjy, Jun 17 to Oct 28, 2025: 46 incremental commits titled "more" (Jun 17); merges of main-upstream_public into feat/cu_mem_api (c1d3606, Sep 1; 84ff679, Oct 28); then "add flag" (a465298), "add test" (2e61613), "fix" (0d3a994), "more" (491eb95), and "apply" (94fc763) on Oct 28.
139 changes: 127 additions & 12 deletions csrc/deep_ep.cpp
@@ -12,6 +12,117 @@
#include "kernels/api.cuh"
#include "kernels/configs.cuh"

namespace shared_memory {
void cu_mem_set_access_all(void* ptr, size_t size) {
int device_count;
CUDA_CHECK(cudaGetDeviceCount(&device_count));

CUmemAccessDesc access_desc[device_count];
for (int idx = 0; idx < device_count; ++idx) {
access_desc[idx].location.type = CU_MEM_LOCATION_TYPE_DEVICE;
access_desc[idx].location.id = idx;
access_desc[idx].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
}

CU_CHECK(cuMemSetAccess((CUdeviceptr)ptr, size, access_desc, device_count));
}
Comment on lines +16 to +28

Contributor:
This has an implicit assumption that all ranks see the same number of GPUs.

A better practice would be for the importer to call cuMemSetAccess for itself after importing.

Contributor Author (fzyzcjy):
I do it just for simplicity, and yes, it can be changed. (A sketch of the importer-side variant follows this thread.)
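Editor's note: a minimal sketch of the reviewer's suggestion, assuming each importer only needs read/write access for the device it is currently using; the helper name cu_mem_set_access_self is hypothetical and not part of this PR.

void cu_mem_set_access_self(void* ptr, size_t size) {
    // Grant access only for the caller's current device, so nothing is assumed
    // about how many GPUs other ranks can see.
    CUdevice device;
    CU_CHECK(cuCtxGetDevice(&device));

    CUmemAccessDesc access_desc = {};
    access_desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    access_desc.location.id = device;
    access_desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;

    CU_CHECK(cuMemSetAccess((CUdeviceptr)ptr, size, &access_desc, 1));
}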


void cu_mem_free(void* ptr) {
CUmemGenericAllocationHandle handle;
CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));

size_t size = 0;
CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));

CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size));
CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size));
CU_CHECK(cuMemRelease(handle));
}

size_t get_size_align_to_granularity(size_t size_raw, size_t granularity) {
size_t size = (size_raw + granularity - 1) & ~(granularity - 1);
if (size == 0)
size = granularity;
return size;
}
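Editor's worked example (not part of the diff): the bitmask rounding assumes the granularity is a power of two, which holds for the 2 MiB minimum granularity typically reported by cuMemGetAllocationGranularity; assert is from <cassert>.

// With granularity = 2 MiB (2u << 20):
assert(get_size_align_to_granularity(1, 2u << 20) == (2u << 20));         // 1 B   -> 2 MiB
assert(get_size_align_to_granularity(5u << 20, 2u << 20) == (6u << 20));  // 5 MiB -> 6 MiB
assert(get_size_align_to_granularity(0, 2u << 20) == (2u << 20));         // 0 B   -> granularity (special case)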

SharedMemoryAllocator::SharedMemoryAllocator(bool use_fabric) : use_fabric(use_fabric) {}

void SharedMemoryAllocator::malloc(void** ptr, size_t size_raw) {
if (use_fabric) {
CUdevice device;
CU_CHECK(cuCtxGetDevice(&device));

CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
prop.location.id = device;

size_t granularity = 0;
CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));

size_t size = get_size_align_to_granularity(size_raw, granularity);

CUmemGenericAllocationHandle handle;
CU_CHECK(cuMemCreate(&handle, size, &prop, 0));

CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, granularity, 0, 0));
CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
cu_mem_set_access_all(*ptr, size);
} else {
CUDA_CHECK(cudaMalloc(ptr, size_raw));
}
}

void SharedMemoryAllocator::free(void* ptr) {
if (use_fabric) {
cu_mem_free(ptr);
} else {
CUDA_CHECK(cudaFree(ptr));
}
}

void SharedMemoryAllocator::get_mem_handle(MemHandle* mem_handle, void* ptr) {
size_t size = 0;
CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));

mem_handle->size = size;

if (use_fabric) {
CUmemGenericAllocationHandle handle;
CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));

CU_CHECK(cuMemExportToShareableHandle(&mem_handle->inner.cu_mem_fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
} else {
CUDA_CHECK(cudaIpcGetMemHandle(&mem_handle->inner.cuda_ipc_mem_handle, ptr));
Comment on lines +94 to +98

Contributor:
Mixing cuMem and cudaMalloc can be problematic 🤔

Contributor Author (@fzyzcjy, Nov 5, 2025):
It seems to be a constant bool flag, if I understand correctly. (See the debug-guard sketch after this function.)

}
}
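Editor's sketch (hypothetical, not in the PR): because use_fabric is fixed at construction, each allocator instance takes a single path; a small debug registry like the one below could turn the reviewer's concern into an explicit check that a cudaMalloc'd pointer is never exported as a fabric handle, and vice versa.

#include <unordered_map>

struct AllocationRegistry {
    // ptr -> true if allocated via the cuMem fabric path, false if via cudaMalloc
    std::unordered_map<void*, bool> allocated_with_fabric;

    void record(void* ptr, bool fabric) { allocated_with_fabric[ptr] = fabric; }

    void check(void* ptr, bool fabric) const {
        auto it = allocated_with_fabric.find(ptr);
        EP_HOST_ASSERT(it != allocated_with_fabric.end() && it->second == fabric);
    }
};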

void SharedMemoryAllocator::open_mem_handle(void** ptr, MemHandle* mem_handle) {
if (use_fabric) {
size_t size = mem_handle->size;

CUmemGenericAllocationHandle handle;
CU_CHECK(cuMemImportFromShareableHandle(&handle, &mem_handle->inner.cu_mem_fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC));

CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, 0, 0, 0));
CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
cu_mem_set_access_all(*ptr, size);
} else {
CUDA_CHECK(cudaIpcOpenMemHandle(ptr, mem_handle->inner.cuda_ipc_mem_handle, cudaIpcMemLazyEnablePeerAccess));
}
}

void SharedMemoryAllocator::close_mem_handle(void* ptr) {
if (use_fabric) {
cu_mem_free(ptr);
} else {
CUDA_CHECK(cudaIpcCloseMemHandle(ptr));
}
}
} // namespace shared_memory

namespace deep_ep {

Buffer::Buffer(int rank,
@@ -20,15 +131,17 @@ Buffer::Buffer(int rank,
int64_t num_rdma_bytes,
bool low_latency_mode,
bool explicitly_destroy,
bool enable_shrink)
bool enable_shrink,
bool use_fabric)
: rank(rank),
num_ranks(num_ranks),
num_nvl_bytes(num_nvl_bytes),
num_rdma_bytes(num_rdma_bytes),
enable_shrink(enable_shrink),
low_latency_mode(low_latency_mode),
explicitly_destroy(explicitly_destroy),
comm_stream(at::cuda::getStreamFromPool(true)) {
comm_stream(at::cuda::getStreamFromPool(true)),
shared_memory_allocator(use_fabric) {
// Metadata memory
int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);
@@ -66,8 +179,9 @@ Buffer::Buffer(int rank,

if (num_nvl_bytes > 0) {
// Local IPC: alloc local memory and set local IPC handles
CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes));
CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
shared_memory_allocator.malloc(&buffer_ptrs[nvl_rank],
num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes);
shared_memory_allocator.get_mem_handle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]);
buffer_ptrs_gpu = reinterpret_cast<void**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes);

// Set barrier signals
@@ -136,7 +250,8 @@ int Buffer::get_local_device_id() const {
}

pybind11::bytearray Buffer::get_local_ipc_handle() const {
return {ipc_handles[nvl_rank].reserved, CUDA_IPC_HANDLE_SIZE};
const shared_memory::MemHandle& handle = ipc_handles[nvl_rank];
return {reinterpret_cast<const char*>(&handle), sizeof(handle)};
}

pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
@@ -176,11 +291,11 @@ void Buffer::destroy() {
if (is_available()) {
for (int i = 0; i < num_nvl_ranks; ++i)
if (i != nvl_rank)
CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i]));
shared_memory_allocator.close_mem_handle(buffer_ptrs[i]);
}

// Free local buffer and error flag
CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank]));
shared_memory_allocator.free(buffer_ptrs[nvl_rank]);
}

// Free NVSHMEM
@@ -220,13 +335,13 @@ void Buffer::sync(const std::vector<int>& device_ids,
for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++i) {
EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value());
auto handle_str = std::string(all_gathered_handles[offset + i].value());
EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE);
EP_HOST_ASSERT(handle_str.size() == shared_memory::HANDLE_SIZE);
if (offset + i != rank) {
std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
std::memcpy(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE);
shared_memory_allocator.open_mem_handle(&buffer_ptrs[i], &ipc_handles[i]);
barrier_signal_ptrs[i] = reinterpret_cast<int*>(static_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
} else {
EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0);
EP_HOST_ASSERT(std::memcmp(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE) == 0);
}
}

@@ -1739,7 +1854,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait);

pybind11::class_<deep_ep::Buffer>(m, "Buffer")
.def(pybind11::init<int, int, int64_t, int64_t, bool, bool, bool>())
.def(pybind11::init<int, int, int64_t, int64_t, bool, bool, bool, bool>())
.def("is_available", &deep_ep::Buffer::is_available)
.def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks)
.def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank)
35 changes: 33 additions & 2 deletions csrc/deep_ep.hpp
@@ -21,6 +21,34 @@
#define TORCH_EXTENSION_NAME deep_ep_cpp
#endif

namespace shared_memory {

union MemHandleInner {
cudaIpcMemHandle_t cuda_ipc_mem_handle;
CUmemFabricHandle cu_mem_fabric_handle;
};

struct MemHandle {
MemHandleInner inner;
size_t size;
};

constexpr size_t HANDLE_SIZE = sizeof(MemHandle);

class SharedMemoryAllocator {
public:
SharedMemoryAllocator(bool use_fabric);
void malloc(void** ptr, size_t size);
void free(void* ptr);
void get_mem_handle(MemHandle* mem_handle, void* ptr);
void open_mem_handle(void** ptr, MemHandle* mem_handle);
void close_mem_handle(void* ptr);

private:
bool use_fabric;
};
} // namespace shared_memory
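Editor's note: get_local_ipc_handle() returns the whole MemHandle as a pybind11 bytearray and sync() copies it back with std::memcpy, so the struct should stay trivially copyable with a stable size. A sketch of compile-time checks, not part of the PR:

#include <type_traits>

static_assert(std::is_trivially_copyable<shared_memory::MemHandle>::value,
              "MemHandle is exchanged as raw bytes across ranks");
static_assert(shared_memory::HANDLE_SIZE == sizeof(shared_memory::MemHandle),
              "HANDLE_SIZE must match the serialized size");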

namespace deep_ep {

struct Buffer {
@@ -50,7 +78,7 @@ struct Buffer {
int num_device_sms;
int rank, rdma_rank, nvl_rank;
int num_ranks, num_rdma_ranks, num_nvl_ranks;
cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
shared_memory::MemHandle ipc_handles[NUM_MAX_NVL_PEERS];

// Stream for communication
at::cuda::CUDAStream comm_stream;
@@ -82,14 +110,17 @@ struct Buffer {
volatile int* moe_recv_rdma_counter = nullptr;
int* moe_recv_rdma_counter_mapped = nullptr;

shared_memory::SharedMemoryAllocator shared_memory_allocator;

public:
Buffer(int rank,
int num_ranks,
int64_t num_nvl_bytes,
int64_t num_rdma_bytes,
bool low_latency_mode,
bool explicitly_destroy,
bool enable_shrink);
bool enable_shrink,
bool use_fabric);

~Buffer() noexcept(false);

12 changes: 12 additions & 0 deletions csrc/kernels/exception.cuh
@@ -31,6 +31,18 @@
} while (0)
#endif

#ifndef CU_CHECK
#define CU_CHECK(cmd) \
do { \
CUresult e = (cmd); \
if (e != CUDA_SUCCESS) { \
const char* error_str = NULL; \
cuGetErrorString(e, &error_str); \
throw EPException("CU", __FILE__, __LINE__, std::string(error_str)); \
} \
} while (0)
#endif

#ifndef EP_HOST_ASSERT
#define EP_HOST_ASSERT(cond) \
do { \
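Editor's illustration (minimal sketch): CU_CHECK wraps CUDA driver-API calls the same way the existing CUDA_CHECK wraps runtime-API calls, turning a failing CUresult into an EPException that carries the driver's error string.

CUdevice device;
int device_count = 0;
CU_CHECK(cuInit(0));
CU_CHECK(cuDeviceGetCount(&device_count));
CU_CHECK(cuDeviceGet(&device, 0));  // throws EPException if, e.g., no CUDA device is present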
4 changes: 3 additions & 1 deletion deep_ep/buffer.py
@@ -37,6 +37,7 @@ def __init__(self,
num_qps_per_rank: int = 24,
allow_nvlink_for_low_latency_mode: bool = True,
allow_mnnvl: bool = False,
use_fabric: bool = False,
explicitly_destroy: bool = False,
enable_shrink: bool = False,
comm: Optional["mpi4py.MPI.Comm"] = None) -> None: # noqa: F821
@@ -55,6 +56,7 @@ def __init__(self,
Warning: PCIe connections may lead to errors due to memory ordering issues,
please make sure all connections are via NVLink.
allow_mnnvl: whether to allow MNNVL
use_fabric: whether to use the CUDA fabric (cuMem) API for memory buffers instead of cudaMalloc + CUDA IPC.
enable_shrink: whether to enable shrink mode. When enabled, a mask buffer is allocated to support masking ranks dynamically.
explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources;
otherwise, the resources will be released by the destructor.
@@ -88,7 +90,7 @@ def all_gather_object(obj):
self.explicitly_destroy = explicitly_destroy
self.enable_shrink = enable_shrink
self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode, explicitly_destroy,
enable_shrink)
enable_shrink, use_fabric)

# Synchronize device IDs
local_device_id = self.runtime.get_local_device_id()
4 changes: 4 additions & 0 deletions format.sh
@@ -184,6 +184,10 @@ if ! git diff --quiet &>/dev/null; then
echo
git --no-pager diff --name-only

echo 'You can also copy-paste the diff below to fix the lint:'
echo
git --no-pager diff

exit 1
fi

2 changes: 1 addition & 1 deletion setup.py
@@ -41,7 +41,7 @@ def get_nvshmem_host_lib_name(base_dir):
include_dirs = ['csrc/']
library_dirs = []
nvcc_dlink = []
extra_link_args = []
extra_link_args = ['-lcuda']

# NVSHMEM flags
if disable_nvshmem:
6 changes: 5 additions & 1 deletion tests/test_intranode.py
@@ -275,7 +275,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_rdma_bytes,
low_latency_mode=test_ll_compatibility,
num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1),
explicitly_destroy=True)
explicitly_destroy=True,
allow_mnnvl=args.allow_mnnvl,
use_fabric=args.use_fabric)
torch.manual_seed(rank)

for i in (24, ):
@@ -301,6 +303,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)')
parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)')
parser.add_argument('--num-experts', type=int, default=256, help='Number of experts (default: 256)')
parser.add_argument('--allow-mnnvl', action="store_true", help='Enable MNNVL support')
parser.add_argument('--use-fabric', action="store_true", help='Enable fabric mode')
args = parser.parse_args()

num_processes = args.num_processes