Support cuMem API in cross process shared memory management #217
Merged · +185 −17
Changes from all commits (53 commits by fzyzcjy)
@@ -12,6 +12,117 @@
#include "kernels/api.cuh"
#include "kernels/configs.cuh"

+namespace shared_memory {
+void cu_mem_set_access_all(void* ptr, size_t size) {
+    int device_count;
+    CUDA_CHECK(cudaGetDeviceCount(&device_count));
+
+    CUmemAccessDesc access_desc[device_count];
+    for (int idx = 0; idx < device_count; ++idx) {
+        access_desc[idx].location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        access_desc[idx].location.id = idx;
+        access_desc[idx].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+    }
+
+    CU_CHECK(cuMemSetAccess((CUdeviceptr)ptr, size, access_desc, device_count));
+}
+
+void cu_mem_free(void* ptr) {
+    CUmemGenericAllocationHandle handle;
+    CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));
+
+    size_t size = 0;
+    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));
+
+    CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size));
+    CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size));
+    CU_CHECK(cuMemRelease(handle));
+}
+
+size_t get_size_align_to_granularity(size_t size_raw, size_t granularity) {
+    size_t size = (size_raw + granularity - 1) & ~(granularity - 1);
+    if (size == 0)
+        size = granularity;
+    return size;
+}
+
+SharedMemoryAllocator::SharedMemoryAllocator(bool use_fabric) : use_fabric(use_fabric) {}
+
+void SharedMemoryAllocator::malloc(void** ptr, size_t size_raw) {
+    if (use_fabric) {
+        CUdevice device;
+        CU_CHECK(cuCtxGetDevice(&device));
+
+        CUmemAllocationProp prop = {};
+        prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+        prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
+        prop.location.id = device;
+
+        size_t granularity = 0;
+        CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
+
+        size_t size = get_size_align_to_granularity(size_raw, granularity);
+
+        CUmemGenericAllocationHandle handle;
+        CU_CHECK(cuMemCreate(&handle, size, &prop, 0));
+
+        CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, granularity, 0, 0));
+        CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
+        cu_mem_set_access_all(*ptr, size);
+    } else {
+        CUDA_CHECK(cudaMalloc(ptr, size_raw));
+    }
+}
+
+void SharedMemoryAllocator::free(void* ptr) {
+    if (use_fabric) {
+        cu_mem_free(ptr);
+    } else {
+        CUDA_CHECK(cudaFree(ptr));
+    }
+}
+
+void SharedMemoryAllocator::get_mem_handle(MemHandle* mem_handle, void* ptr) {
+    size_t size = 0;
+    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));
+
+    mem_handle->size = size;
+
+    if (use_fabric) {
+        CUmemGenericAllocationHandle handle;
+        CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));
+
+        CU_CHECK(cuMemExportToShareableHandle(&mem_handle->inner.cu_mem_fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
+    } else {
+        CUDA_CHECK(cudaIpcGetMemHandle(&mem_handle->inner.cuda_ipc_mem_handle, ptr));
Comment on lines +94 to +98

Contributor: mixing cuMem and cudaMalloc can be problematic 🤔

Contributor (Author): it seems to be a constant bool flag, if I understand correctly
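For context, a minimal sketch of why the two paths do not mix: use_fabric is fixed when the allocator is constructed, and every rank constructs its Buffer (and therefore its allocator) with the same value, so a buffer allocated through the cuMem path is exported, imported, and freed exclusively through cuMem calls, and likewise for the cudaMalloc/IPC path. SharedMemoryAllocator and MemHandle below are the types added in this PR (their header is not shown in this excerpt); the function names and the send_bytes/recv_bytes transport are placeholders, not part of the PR (Buffer::sync actually receives the handles through an all-gather).

#include <cstddef>

void send_bytes(const void* data, size_t n);  // placeholder transport
void recv_bytes(void* data, size_t n);        // placeholder transport

void exporting_rank(bool use_fabric, size_t bytes) {
    shared_memory::SharedMemoryAllocator allocator(use_fabric);  // same flag on every rank
    void* ptr = nullptr;
    allocator.malloc(&ptr, bytes);                // cuMemCreate + cuMemMap, or cudaMalloc

    shared_memory::MemHandle handle;
    allocator.get_mem_handle(&handle, ptr);       // fabric export, or cudaIpcGetMemHandle
    send_bytes(&handle, sizeof(handle));          // ship the raw handle bytes to peers

    // ... later, after peers are done ...
    allocator.free(ptr);                          // cuMemUnmap/cuMemRelease, or cudaFree
}

void importing_rank(bool use_fabric) {
    shared_memory::SharedMemoryAllocator allocator(use_fabric);  // must match the exporter
    shared_memory::MemHandle handle;
    recv_bytes(&handle, sizeof(handle));

    void* peer_ptr = nullptr;
    allocator.open_mem_handle(&peer_ptr, &handle);  // fabric import + map, or cudaIpcOpenMemHandle
    // ... use peer_ptr ...
    allocator.close_mem_handle(peer_ptr);           // unmap/release, or cudaIpcCloseMemHandle
}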
+    }
+}
+
+void SharedMemoryAllocator::open_mem_handle(void** ptr, MemHandle* mem_handle) {
+    if (use_fabric) {
+        size_t size = mem_handle->size;
+
+        CUmemGenericAllocationHandle handle;
+        CU_CHECK(cuMemImportFromShareableHandle(&handle, &mem_handle->inner.cu_mem_fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC));
+
+        CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, 0, 0, 0));
+        CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
+        cu_mem_set_access_all(*ptr, size);
+    } else {
+        CUDA_CHECK(cudaIpcOpenMemHandle(ptr, mem_handle->inner.cuda_ipc_mem_handle, cudaIpcMemLazyEnablePeerAccess));
+    }
+}
+
+void SharedMemoryAllocator::close_mem_handle(void* ptr) {
+    if (use_fabric) {
+        cu_mem_free(ptr);
+    } else {
+        CUDA_CHECK(cudaIpcCloseMemHandle(ptr));
+    }
+}
+}  // namespace shared_memory
+
namespace deep_ep {

Buffer::Buffer(int rank,

@@ -20,15 +131,17 @@ Buffer::Buffer(int rank,
               int64_t num_rdma_bytes,
               bool low_latency_mode,
               bool explicitly_destroy,
-               bool enable_shrink)
+               bool enable_shrink,
+               bool use_fabric)
    : rank(rank),
      num_ranks(num_ranks),
      num_nvl_bytes(num_nvl_bytes),
      num_rdma_bytes(num_rdma_bytes),
      enable_shrink(enable_shrink),
      low_latency_mode(low_latency_mode),
      explicitly_destroy(explicitly_destroy),
-      comm_stream(at::cuda::getStreamFromPool(true)) {
+      comm_stream(at::cuda::getStreamFromPool(true)),
+      shared_memory_allocator(use_fabric) {
    // Metadata memory
    int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
    int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);

@@ -66,8 +179,9 @@ Buffer::Buffer(int rank,

    if (num_nvl_bytes > 0) {
        // Local IPC: alloc local memory and set local IPC handles
-        CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes));
-        CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
+        shared_memory_allocator.malloc(&buffer_ptrs[nvl_rank],
+                                       num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes);
+        shared_memory_allocator.get_mem_handle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]);
        buffer_ptrs_gpu = reinterpret_cast<void**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes);

        // Set barrier signals

@@ -136,7 +250,8 @@ int Buffer::get_local_device_id() const {
}

pybind11::bytearray Buffer::get_local_ipc_handle() const {
-    return {ipc_handles[nvl_rank].reserved, CUDA_IPC_HANDLE_SIZE};
+    const shared_memory::MemHandle& handle = ipc_handles[nvl_rank];
+    return {reinterpret_cast<const char*>(&handle), sizeof(handle)};
}

pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {

@@ -176,11 +291,11 @@ void Buffer::destroy() {
    if (is_available()) {
        for (int i = 0; i < num_nvl_ranks; ++i)
            if (i != nvl_rank)
-                CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i]));
+                shared_memory_allocator.close_mem_handle(buffer_ptrs[i]);
    }

    // Free local buffer and error flag
-    CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank]));
+    shared_memory_allocator.free(buffer_ptrs[nvl_rank]);
}

// Free NVSHMEM

@@ -220,13 +335,13 @@ void Buffer::sync(const std::vector<int>& device_ids,
    for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++i) {
        EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value());
        auto handle_str = std::string(all_gathered_handles[offset + i].value());
-        EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE);
+        EP_HOST_ASSERT(handle_str.size() == shared_memory::HANDLE_SIZE);
        if (offset + i != rank) {
-            std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
-            CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
+            std::memcpy(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE);
+            shared_memory_allocator.open_mem_handle(&buffer_ptrs[i], &ipc_handles[i]);
            barrier_signal_ptrs[i] = reinterpret_cast<int*>(static_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
        } else {
-            EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0);
+            EP_HOST_ASSERT(std::memcmp(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE) == 0);
        }
    }

@@ -1739,7 +1854,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        .def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait);

    pybind11::class_<deep_ep::Buffer>(m, "Buffer")
-        .def(pybind11::init<int, int, int64_t, int64_t, bool, bool, bool>())
+        .def(pybind11::init<int, int, int64_t, int64_t, bool, bool, bool, bool>())
        .def("is_available", &deep_ep::Buffer::is_available)
        .def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks)
        .def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank)
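The diff refers to shared_memory::MemHandle and shared_memory::HANDLE_SIZE, which are declared elsewhere in the PR and not shown in this excerpt. Judging only from how they are used above (mem_handle->size, mem_handle->inner.cu_mem_fabric_handle, mem_handle->inner.cuda_ipc_mem_handle, and the whole-struct memcpy/memcmp against HANDLE_SIZE), the declaration is presumably something along these lines; the exact field order and naming are assumptions, not the PR's actual code.

#include <cstddef>
#include <cuda.h>          // CUmemFabricHandle (needs a CUDA toolkit recent enough for fabric handles)
#include <cuda_runtime.h>  // cudaIpcMemHandle_t

namespace shared_memory {

// Assumed shape, inferred from usage in the diff above.
struct MemHandle {
    union {
        CUmemFabricHandle cu_mem_fabric_handle;   // filled by cuMemExportToShareableHandle
        cudaIpcMemHandle_t cuda_ipc_mem_handle;   // filled by cudaIpcGetMemHandle
    } inner;
    size_t size;  // mapped size; the importer needs it for cuMemAddressReserve/cuMemMap
};

// sync() copies and compares the whole struct byte-for-byte, and get_local_ipc_handle()
// serializes sizeof(handle) bytes, so the constant is presumably just the struct size.
constexpr size_t HANDLE_SIZE = sizeof(MemHandle);

}  // namespace shared_memory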
Contributor: this has an implicit assumption that all ranks see the same number of GPUs. A better practice would be for the importer to call cuMemSetAccess for itself after importing.

Contributor (Author): I do it just for simplicity, and yes, it can be changed.
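A minimal sketch of the reviewer's suggestion, not part of this PR: after importing, each process grants read/write access only to the device bound to its current context, instead of looping over cudaGetDeviceCount as cu_mem_set_access_all does. CU_CHECK is the error-checking macro already used in the diff; the function name here is hypothetical.

#include <cuda.h>

// Hypothetical importer-side variant: map the imported allocation, then enable access
// only for this process's own device.
void open_mem_handle_local_access(void** ptr, shared_memory::MemHandle* mem_handle) {
    size_t size = mem_handle->size;

    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemImportFromShareableHandle(&handle, &mem_handle->inner.cu_mem_fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC));

    CU_CHECK(cuMemAddressReserve((CUdeviceptr*)ptr, size, 0, 0, 0));
    CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));

    // Grant access only to the device of the calling process's current context.
    CUdevice device;
    CU_CHECK(cuCtxGetDevice(&device));

    CUmemAccessDesc access_desc = {};
    access_desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    access_desc.location.id = device;
    access_desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    CU_CHECK(cuMemSetAccess((CUdeviceptr)*ptr, size, &access_desc, 1));
}

The exporting rank would likewise grant itself access right after cuMemMap in malloc(), so no rank would need to assume anything about how many GPUs its peers can see.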