Support cuMem API in cross process shared memory management #217
Conversation
csrc/deep_ep.cpp (Outdated)
```cpp
for (int device = 0; device < device_count; ++device) {
    int support = 0;
    CU_CHECK(cuDeviceGetAttribute(&support, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, device));
```
I think this check is not enough; see https://forums.developer.nvidia.com/t/cudevicegetattribute-shows-i-can-use-fabric-handle-but-actually-i-cannot/336426 . Even if the attribute says fabric handles are supported, we cannot actually use the allocation.

Let me know if your environment says something different.
Ah, that's weird.

> let me know if your environment says something different.

In my environment the code does work. If there is no good way to correctly detect fabric support, a workaround may be to let users pass in a bool flag saying whether they want to enable this.
> Applications that intend to use CU_MEM_HANDLE_TYPE_FABRIC based memory sharing must ensure: (1) the nvidia-caps-imex-channels character device is created by the driver and is listed under /proc/devices, and (2) they have at least one IMEX channel file accessible by the user launching the application.

I'm wondering whether that is related in your env.
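For context, those two prerequisites can be probed roughly as in the sketch below. This is only an illustration; the channel path `/dev/nvidia-caps-imex-channels/channel0` is an assumed default, not something specified in this thread, and the helper name is hypothetical.

```cpp
#include <fstream>
#include <string>
#include <unistd.h>

// Rough check of the two IMEX prerequisites quoted above.
bool imex_prerequisites_met() {
    // (1) the nvidia-caps-imex-channels character device is listed under /proc/devices
    std::ifstream devices("/proc/devices");
    std::string line;
    bool listed = false;
    while (std::getline(devices, line)) {
        if (line.find("nvidia-caps-imex-channels") != std::string::npos) {
            listed = true;
            break;
        }
    }
    if (!listed)
        return false;

    // (2) at least one IMEX channel file is accessible by the current user
    // (channel0 is an assumed default name; adjust for your setup)
    return access("/dev/nvidia-caps-imex-channels/channel0", R_OK | W_OK) == 0;
}
```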
> In my environment the code does work.

What does this mean? What do you get when running the code?

I think you can run it on a single-node H100 with CUDA 12.5+, without nvidia-caps-imex-channels set up, and see whether cuDeviceGetAttribute tells you the fabric handle is supported.

I added fabric handle support in PyTorch just now in pytorch/pytorch#156074; I use an actual cuMem call to see if the allocation is successful.
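For reference, that kind of probe looks roughly like the sketch below: instead of trusting the device attribute alone, attempt a real fabric allocation and export. The helper name and error handling are mine, not code from this PR or from the PyTorch patch, and it assumes cuInit has been called and a context is current for the device.

```cpp
#include <cuda.h>

// Probe fabric-handle usability by attempting an actual allocation and export,
// rather than trusting CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED alone.
bool fabric_allocation_works(int device) {
    CUmemAllocationProp prop = {};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id = device;

    size_t granularity = 0;
    if (cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM) != CUDA_SUCCESS)
        return false;

    // The allocation (or the export) is what actually fails when IMEX channels
    // are not set up, even though the device attribute reports support.
    CUmemGenericAllocationHandle handle;
    if (cuMemCreate(&handle, granularity, &prop, 0) != CUDA_SUCCESS)
        return false;

    CUmemFabricHandle fabric_handle;
    bool ok = cuMemExportToShareableHandle(&fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0) == CUDA_SUCCESS;
    cuMemRelease(handle);
    return ok;
}
```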
> what does this mean? what do you get running the code?

I do not run on H100 but on some other devices (that's why I made this PR; otherwise DeepEP fails to start up), and the tests pass. (I originally thought the question "let me know if your environment says something different" meant "does the code run on my env, i.e. my device, software, etc.".) I will check on a single-node H100 and update the code later when I have time.

> i use an actual cumem call to see if the allocation is successful.

That looks reasonable.

By the way, the check was taken from https://github.com/kvcache-ai/Mooncake/blob/main/mooncake-transfer-engine/src/transport/nvlink_transport/nvlink_transport.cpp and I also checked the NCCL code a bit. So if that check has issues, maybe Mooncake needs to be updated as well.
# Conflicts:
#   csrc/deep_ep.cpp
> will merge to main branch?

Yes, I hope so. I will do it once I have time (I have other higher-priority tasks now).
```cpp
void cu_mem_set_access_all(void* ptr, size_t size) {
    int device_count;
    CUDA_CHECK(cudaGetDeviceCount(&device_count));

    CUmemAccessDesc access_desc[device_count];
    for (int idx = 0; idx < device_count; ++idx) {
        access_desc[idx].location.type = CU_MEM_LOCATION_TYPE_DEVICE;
        access_desc[idx].location.id = idx;
        access_desc[idx].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    }

    CU_CHECK(cuMemSetAccess((CUdeviceptr)ptr, size, access_desc, device_count));
}
```
This has an implicit assumption that all ranks see the same number of GPUs.
A better practice would be for the importer to call cuMemSetAccess for itself after importing.
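A sketch of that importer-side approach is below, under the assumption that the virtual address range was already reserved with cuMemAddressReserve. The helper name is hypothetical; CU_CHECK is the PR's driver-API error-checking macro.

```cpp
#include <cuda.h>

// After importing the fabric handle, grant access only for the importer's own
// device instead of enumerating all local GPUs.
void import_and_enable_local_access(const CUmemFabricHandle& fabric_handle,
                                    CUdeviceptr ptr, size_t size, int local_device) {
    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemImportFromShareableHandle(&handle, (void*)&fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC));
    CU_CHECK(cuMemMap(ptr, size, 0, handle, 0));

    // Read/write access for this rank's device only.
    CUmemAccessDesc desc = {};
    desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    desc.location.id = local_device;
    desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    CU_CHECK(cuMemSetAccess(ptr, size, &desc, 1));
}
```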
I did it this way just for simplicity; yes, it can be changed.
```cpp
    CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));

    CU_CHECK(cuMemExportToShareableHandle(&mem_handle->inner.cu_mem_fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
} else {
    CUDA_CHECK(cudaIpcGetMemHandle(&mem_handle->inner.cuda_ipc_mem_handle, ptr));
```
Mixing cuMem and cudaMalloc can be problematic 🤔
It is selected by a constant bool flag, if I understand correctly, so the two allocation paths are not mixed within a run.
youkaichao left a comment
cuMem APIs are fragile and error-prone. If possible, I'd suggest using an existing library (e.g. PyTorch) to allocate such shared memory, and DeepEP would just use that buffer without all these pains.
That is also reasonable, though for simplicity I chose to replace cudaMalloc etc. with the almost-equivalent cuMem APIs.
The code is roughly like this; I will work more on related things, which will also exercise this PR further.

EDIT: it works well on the target hardware; I will try to find some time to beautify and generalize the code (probably some time later).
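For readers following the thread, replacing cudaMalloc with the "almost equivalent" cuMem calls generally looks like the sketch below. This is an illustration of the usual cuMemCreate / cuMemAddressReserve / cuMemMap / cuMemSetAccess sequence, not the exact code in this PR; cleanup and error paths are omitted, the function name is hypothetical, and CU_CHECK is the PR's error-checking macro.

```cpp
#include <cuda.h>

// Rough cudaMalloc replacement using the cuMem (VMM) driver API, padded to the
// allocation granularity and shareable via CU_MEM_HANDLE_TYPE_FABRIC.
CUdeviceptr cu_mem_alloc_shareable(size_t size, int device) {
    CUmemAllocationProp prop = {};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id = device;

    size_t granularity = 0;
    CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
    size_t padded = (size + granularity - 1) / granularity * granularity;

    // Physical allocation, virtual address reservation, and mapping replace a single cudaMalloc.
    CUmemGenericAllocationHandle handle;
    CU_CHECK(cuMemCreate(&handle, padded, &prop, 0));

    CUdeviceptr ptr;
    CU_CHECK(cuMemAddressReserve(&ptr, padded, granularity, 0, 0));
    CU_CHECK(cuMemMap(ptr, padded, 0, handle, 0));

    // Grant read/write access for the owning device; importers grant access for theirs.
    CUmemAccessDesc desc = {};
    desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    desc.location.id = device;
    desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    CU_CHECK(cuMemSetAccess(ptr, padded, &desc, 1));
    return ptr;
}
```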