diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp
index a29ed598..195b1856 100644
--- a/csrc/deep_ep.cpp
+++ b/csrc/deep_ep.cpp
@@ -12,6 +12,117 @@
 #include "kernels/api.cuh"
 #include "kernels/configs.cuh"
 
+namespace shared_memory {
+void cu_mem_set_access_all(void* ptr, size_t size) {
+    int device_count;
+    CUDA_CHECK(cudaGetDeviceCount(&device_count));
+
+    CUmemAccessDesc access_desc[device_count];
+    for (int idx = 0; idx < device_count; ++idx) {
+        access_desc[idx].location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        access_desc[idx].location.id = idx;
+        access_desc[idx].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+    }
+
+    CU_CHECK(cuMemSetAccess((CUdeviceptr)ptr, size, access_desc, device_count));
+}
+
+void cu_mem_free(void* ptr) {
+    CUmemGenericAllocationHandle handle;
+    CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));
+
+    size_t size = 0;
+    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));
+
+    CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size));
+    CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size));
+    CU_CHECK(cuMemRelease(handle));
+}
+
+size_t get_size_align_to_granularity(size_t size_raw, size_t granularity) {
+    size_t size = (size_raw + granularity - 1) & ~(granularity - 1);
+    if (size == 0)
+        size = granularity;
+    return size;
+}
+
+SharedMemoryAllocator::SharedMemoryAllocator(bool use_fabric) : use_fabric(use_fabric) {}
+
+void SharedMemoryAllocator::malloc(void** ptr, size_t size_raw) {
+    if (use_fabric) {
+        CUdevice device;
+        CU_CHECK(cuCtxGetDevice(&device));
+
+        CUmemAllocationProp prop = {};
+        prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+        prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
+        prop.location.id = device;
+
+        size_t granularity = 0;
+        CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
+
+        size_t size = get_size_align_to_granularity(size_raw, granularity);
+
+        CUmemGenericAllocationHandle handle;
+        CU_CHECK(cuMemCreate(&handle, size, &prop, 0));
+
+        CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, granularity, 0, 0));
+        CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
+        cu_mem_set_access_all(*ptr, size);
+    } else {
+        CUDA_CHECK(cudaMalloc(ptr, size_raw));
+    }
+}
+
+void SharedMemoryAllocator::free(void* ptr) {
+    if (use_fabric) {
+        cu_mem_free(ptr);
+    } else {
+        CUDA_CHECK(cudaFree(ptr));
+    }
+}
+
+void SharedMemoryAllocator::get_mem_handle(MemHandle* mem_handle, void* ptr) {
+    size_t size = 0;
+    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));
+
+    mem_handle->size = size;
+
+    if (use_fabric) {
+        CUmemGenericAllocationHandle handle;
+        CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));
+
+        CU_CHECK(cuMemExportToShareableHandle(&mem_handle->inner.cu_mem_fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
+    } else {
+        CUDA_CHECK(cudaIpcGetMemHandle(&mem_handle->inner.cuda_ipc_mem_handle, ptr));
+    }
+}
+
+void SharedMemoryAllocator::open_mem_handle(void** ptr, MemHandle* mem_handle) {
+    if (use_fabric) {
+        size_t size = mem_handle->size;
+
+        CUmemGenericAllocationHandle handle;
+        CU_CHECK(cuMemImportFromShareableHandle(&handle, &mem_handle->inner.cu_mem_fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC));
+
+        CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, 0, 0, 0));
+        CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
+        cu_mem_set_access_all(*ptr, size);
+    } else {
+        CUDA_CHECK(cudaIpcOpenMemHandle(ptr, mem_handle->inner.cuda_ipc_mem_handle, cudaIpcMemLazyEnablePeerAccess));
+    }
+}
+
+void SharedMemoryAllocator::close_mem_handle(void* ptr) {
+    if (use_fabric) {
+        cu_mem_free(ptr);
+    } else {
+        CUDA_CHECK(cudaIpcCloseMemHandle(ptr));
+    }
+}
+} // namespace shared_memory
+
 namespace deep_ep {
 
 Buffer::Buffer(int rank,
@@ -20,7 +131,8 @@ Buffer::Buffer(int rank,
                int64_t num_rdma_bytes,
                bool low_latency_mode,
                bool explicitly_destroy,
-               bool enable_shrink)
+               bool enable_shrink,
+               bool use_fabric)
     : rank(rank),
       num_ranks(num_ranks),
       num_nvl_bytes(num_nvl_bytes),
@@ -28,7 +140,8 @@ Buffer::Buffer(int rank,
       enable_shrink(enable_shrink),
       low_latency_mode(low_latency_mode),
       explicitly_destroy(explicitly_destroy),
-      comm_stream(at::cuda::getStreamFromPool(true)) {
+      comm_stream(at::cuda::getStreamFromPool(true)),
+      shared_memory_allocator(use_fabric) {
     // Metadata memory
     int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
     int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);
@@ -66,8 +179,9 @@ Buffer::Buffer(int rank,
 
     if (num_nvl_bytes > 0) {
         // Local IPC: alloc local memory and set local IPC handles
-        CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes));
-        CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
+        shared_memory_allocator.malloc(&buffer_ptrs[nvl_rank],
+                                       num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes);
+        shared_memory_allocator.get_mem_handle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]);
         buffer_ptrs_gpu = reinterpret_cast<void**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes);
 
         // Set barrier signals
@@ -136,7 +250,8 @@ int Buffer::get_local_device_id() const {
 }
 
 pybind11::bytearray Buffer::get_local_ipc_handle() const {
-    return {ipc_handles[nvl_rank].reserved, CUDA_IPC_HANDLE_SIZE};
+    const shared_memory::MemHandle& handle = ipc_handles[nvl_rank];
+    return {reinterpret_cast<const char*>(&handle), sizeof(handle)};
 }
 
 pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
@@ -176,11 +291,11 @@ void Buffer::destroy() {
         if (is_available()) {
            for (int i = 0; i < num_nvl_ranks; ++i)
                if (i != nvl_rank)
-                   CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i]));
+                   shared_memory_allocator.close_mem_handle(buffer_ptrs[i]);
         }
 
         // Free local buffer and error flag
-        CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank]));
+        shared_memory_allocator.free(buffer_ptrs[nvl_rank]);
     }
 
     // Free NVSHMEM
@@ -220,13 +335,13 @@ void Buffer::sync(const std::vector<int>& device_ids,
         for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++i) {
             EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value());
             auto handle_str = std::string(all_gathered_handles[offset + i].value());
-            EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE);
+            EP_HOST_ASSERT(handle_str.size() == shared_memory::HANDLE_SIZE);
             if (offset + i != rank) {
-                std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
-                CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
+                std::memcpy(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE);
+                shared_memory_allocator.open_mem_handle(&buffer_ptrs[i], &ipc_handles[i]);
                 barrier_signal_ptrs[i] = reinterpret_cast<int*>(static_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
             } else {
-                EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0);
+                EP_HOST_ASSERT(std::memcmp(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE) == 0);
             }
         }
 
@@ -1739,7 +1854,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait); pybind11::class_(m, "Buffer") - .def(pybind11::init()) + .def(pybind11::init()) .def("is_available", &deep_ep::Buffer::is_available) .def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks) .def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank) diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index 7213bdd4..96a0f8f4 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -21,6 +21,34 @@ #define TORCH_EXTENSION_NAME deep_ep_cpp #endif +namespace shared_memory { + +union MemHandleInner { + cudaIpcMemHandle_t cuda_ipc_mem_handle; + CUmemFabricHandle cu_mem_fabric_handle; +}; + +struct MemHandle { + MemHandleInner inner; + size_t size; +}; + +constexpr size_t HANDLE_SIZE = sizeof(MemHandle); + +class SharedMemoryAllocator { +public: + SharedMemoryAllocator(bool use_fabric); + void malloc(void** ptr, size_t size); + void free(void* ptr); + void get_mem_handle(MemHandle* mem_handle, void* ptr); + void open_mem_handle(void** ptr, MemHandle* mem_handle); + void close_mem_handle(void* ptr); + +private: + bool use_fabric; +}; +} // namespace shared_memory + namespace deep_ep { struct Buffer { @@ -50,7 +78,7 @@ struct Buffer { int num_device_sms; int rank, rdma_rank, nvl_rank; int num_ranks, num_rdma_ranks, num_nvl_ranks; - cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS]; + shared_memory::MemHandle ipc_handles[NUM_MAX_NVL_PEERS]; // Stream for communication at::cuda::CUDAStream comm_stream; @@ -82,6 +110,8 @@ struct Buffer { volatile int* moe_recv_rdma_counter = nullptr; int* moe_recv_rdma_counter_mapped = nullptr; + shared_memory::SharedMemoryAllocator shared_memory_allocator; + public: Buffer(int rank, int num_ranks, @@ -89,7 +119,8 @@ struct Buffer { int64_t num_rdma_bytes, bool low_latency_mode, bool explicitly_destroy, - bool enable_shrink); + bool enable_shrink, + bool use_fabric); ~Buffer() noexcept(false); diff --git a/csrc/kernels/exception.cuh b/csrc/kernels/exception.cuh index 4c48f4e6..507efa28 100644 --- a/csrc/kernels/exception.cuh +++ b/csrc/kernels/exception.cuh @@ -31,6 +31,18 @@ public: } while (0) #endif +#ifndef CU_CHECK +#define CU_CHECK(cmd) \ + do { \ + CUresult e = (cmd); \ + if (e != CUDA_SUCCESS) { \ + const char* error_str = NULL; \ + cuGetErrorString(e, &error_str); \ + throw EPException("CU", __FILE__, __LINE__, std::string(error_str)); \ + } \ + } while (0) +#endif + #ifndef EP_HOST_ASSERT #define EP_HOST_ASSERT(cond) \ do { \ diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py index bdf26e8e..aca01167 100644 --- a/deep_ep/buffer.py +++ b/deep_ep/buffer.py @@ -37,6 +37,7 @@ def __init__(self, num_qps_per_rank: int = 24, allow_nvlink_for_low_latency_mode: bool = True, allow_mnnvl: bool = False, + use_fabric: bool = False, explicitly_destroy: bool = False, enable_shrink: bool = False, comm: Optional["mpi4py.MPI.Comm"] = None) -> None: # noqa: F821 @@ -55,6 +56,7 @@ def __init__(self, Warning: PCIe connections may lead to errors due to memory ordering issues, please make sure all connections are via NVLink. allow_mnnvl: whether to allow MNNVL + use_fabric: whether to use fabric API for memory buffers. enable_shrink: whether to enable shrink mode. The enable mode allocates a mask buffer to support masking ranks dynamically. explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources; otherwise, the resources will be released by the destructor. 
@@ -88,7 +90,7 @@ def all_gather_object(obj):
         self.explicitly_destroy = explicitly_destroy
         self.enable_shrink = enable_shrink
         self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode, explicitly_destroy,
-                                          enable_shrink)
+                                          enable_shrink, use_fabric)
 
         # Synchronize device IDs
         local_device_id = self.runtime.get_local_device_id()
diff --git a/format.sh b/format.sh
index f8abd277..6ab74344 100755
--- a/format.sh
+++ b/format.sh
@@ -184,6 +184,10 @@ if ! git diff --quiet &>/dev/null; then
     echo
     git --no-pager diff --name-only
 
+    echo 'You can also copy-paste the diff below to fix the lint:'
+    echo
+    git --no-pager diff
+
     exit 1
 fi
diff --git a/setup.py b/setup.py
index 2b51d4f4..e89a9dbb 100644
--- a/setup.py
+++ b/setup.py
@@ -41,7 +41,7 @@ def get_nvshmem_host_lib_name(base_dir):
     include_dirs = ['csrc/']
     library_dirs = []
     nvcc_dlink = []
-    extra_link_args = []
+    extra_link_args = ['-lcuda']
 
     # NVSHMEM flags
     if disable_nvshmem:
diff --git a/tests/test_intranode.py b/tests/test_intranode.py
index 50dd5f17..48749137 100644
--- a/tests/test_intranode.py
+++ b/tests/test_intranode.py
@@ -275,7 +275,9 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
                             num_rdma_bytes,
                             low_latency_mode=test_ll_compatibility,
                             num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1),
-                            explicitly_destroy=True)
+                            explicitly_destroy=True,
+                            allow_mnnvl=args.allow_mnnvl,
+                            use_fabric=args.use_fabric)
 
     torch.manual_seed(rank)
     for i in (24, ):
@@ -301,6 +303,8 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
     parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)')
     parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)')
     parser.add_argument('--num-experts', type=int, default=256, help='Number of experts (default: 256)')
+    parser.add_argument('--allow-mnnvl', action="store_true", help='Enable MNNVL support')
+    parser.add_argument('--use-fabric', action="store_true", help='Enable fabric mode')
 
     args = parser.parse_args()
     num_processes = args.num_processes
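
Usage note (not part of the patch): the new flag is threaded from `deep_ep.Buffer` through `deep_ep_cpp.Buffer` into `SharedMemoryAllocator`, so opting into fabric handles is a one-argument change on the Python side. A minimal sketch, assuming an already-initialized torch.distributed process group bound to the name `group`; the byte sizes are illustrative placeholders:

    import deep_ep

    # use_fabric=True selects the cuMemCreate / CU_MEM_HANDLE_TYPE_FABRIC path
    # inside SharedMemoryAllocator instead of cudaMalloc + cudaIpcGetMemHandle.
    buffer = deep_ep.Buffer(group,
                            num_nvl_bytes=1 << 30,   # placeholder size
                            num_rdma_bytes=0,
                            use_fabric=True)

The intranode test exposes the same switch, e.g. `python tests/test_intranode.py --num-processes 8 --use-fabric` (the process count here is illustrative).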