Skip to content

Commit a465298

Browse files
committed
add flag
1 parent 84ff679 commit a465298

File tree

3 files changed

+17
-27
lines changed

3 files changed

+17
-27
lines changed

csrc/deep_ep.cpp

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -45,25 +45,10 @@ size_t get_size_align_to_granularity(size_t size_raw, size_t granularity) {
4545
return size;
4646
}
4747

48-
bool support_fabric() {
49-
int device_count;
50-
CUDA_CHECK(cudaGetDeviceCount(&device_count));
51-
52-
for (int device = 0; device < device_count; ++device) {
53-
int support = 0;
54-
CU_CHECK(cuDeviceGetAttribute(&support, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, device));
55-
if (!support) {
56-
return false;
57-
}
58-
}
59-
60-
return true;
61-
}
62-
63-
SharedMemoryAllocator::SharedMemoryAllocator() : enable_fabric(support_fabric()) {}
48+
SharedMemoryAllocator::SharedMemoryAllocator(bool use_fabric) : use_fabric(use_fabric) {}
6449

6550
void SharedMemoryAllocator::malloc(void** ptr, size_t size_raw) {
66-
if (enable_fabric) {
51+
if (use_fabric) {
6752
CUdevice device;
6853
CU_CHECK(cuCtxGetDevice(&device));
6954

@@ -90,7 +75,7 @@ void SharedMemoryAllocator::malloc(void** ptr, size_t size_raw) {
9075
}
9176

9277
void SharedMemoryAllocator::free(void* ptr) {
93-
if (enable_fabric) {
78+
if (use_fabric) {
9479
cu_mem_free(ptr);
9580
} else {
9681
CUDA_CHECK(cudaFree(ptr));
@@ -103,7 +88,7 @@ void SharedMemoryAllocator::get_mem_handle(MemHandle* mem_handle, void* ptr) {
10388

10489
mem_handle->size = size;
10590

106-
if (enable_fabric) {
91+
if (use_fabric) {
10792
CUmemGenericAllocationHandle handle;
10893
CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));
10994

@@ -114,7 +99,7 @@ void SharedMemoryAllocator::get_mem_handle(MemHandle* mem_handle, void* ptr) {
11499
}
115100

116101
void SharedMemoryAllocator::open_mem_handle(void** ptr, MemHandle* mem_handle) {
117-
if (enable_fabric) {
102+
if (use_fabric) {
118103
size_t size = mem_handle->size;
119104

120105
CUmemGenericAllocationHandle handle;
@@ -129,7 +114,7 @@ void SharedMemoryAllocator::open_mem_handle(void** ptr, MemHandle* mem_handle) {
129114
}
130115

131116
void SharedMemoryAllocator::close_mem_handle(void* ptr) {
132-
if (enable_fabric) {
117+
if (use_fabric) {
133118
cu_mem_free(ptr);
134119
} else {
135120
CUDA_CHECK(cudaIpcCloseMemHandle(ptr));
@@ -145,15 +130,17 @@ Buffer::Buffer(int rank,
145130
int64_t num_rdma_bytes,
146131
bool low_latency_mode,
147132
bool explicitly_destroy,
148-
bool enable_shrink)
133+
bool enable_shrink,
134+
bool use_fabric)
149135
: rank(rank),
150136
num_ranks(num_ranks),
151137
num_nvl_bytes(num_nvl_bytes),
152138
num_rdma_bytes(num_rdma_bytes),
153139
enable_shrink(enable_shrink),
154140
low_latency_mode(low_latency_mode),
155141
explicitly_destroy(explicitly_destroy),
156-
comm_stream(at::cuda::getStreamFromPool(true)) {
142+
comm_stream(at::cuda::getStreamFromPool(true)),
143+
shared_memory_allocator(use_fabric) {
157144
// Metadata memory
158145
int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
159146
int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);

csrc/deep_ep.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@ constexpr size_t HANDLE_SIZE = sizeof(MemHandle);
3737

3838
class SharedMemoryAllocator {
3939
public:
40-
SharedMemoryAllocator();
40+
SharedMemoryAllocator(bool use_fabric);
4141
void malloc(void** ptr, size_t size);
4242
void free(void* ptr);
4343
void get_mem_handle(MemHandle* mem_handle, void* ptr);
4444
void open_mem_handle(void** ptr, MemHandle* mem_handle);
4545
void close_mem_handle(void* ptr);
4646
private:
47-
bool enable_fabric;
47+
bool use_fabric;
4848
};
4949
}
5050

@@ -118,7 +118,8 @@ struct Buffer {
118118
int64_t num_rdma_bytes,
119119
bool low_latency_mode,
120120
bool explicitly_destroy,
121-
bool enable_shrink);
121+
bool enable_shrink,
122+
bool use_fabric);
122123

123124
~Buffer() noexcept(false);
124125

deep_ep/buffer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(self,
3737
num_qps_per_rank: int = 24,
3838
allow_nvlink_for_low_latency_mode: bool = True,
3939
allow_mnnvl: bool = False,
40+
use_fabric: bool = False,
4041
explicitly_destroy: bool = False,
4142
enable_shrink: bool = False,
4243
comm: Optional["mpi4py.MPI.Comm"] = None) -> None: # noqa: F821
@@ -55,6 +56,7 @@ def __init__(self,
5556
Warning: PCIe connections may lead to errors due to memory ordering issues,
5657
please make sure all connections are via NVLink.
5758
allow_mnnvl: whether to allow MNNVL
59+
use_fabric: whether to use the fabric API for memory buffers.
5860
enable_shrink: whether to enable shrink mode. When enabled, a mask buffer is allocated to support masking ranks dynamically.
5961
explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources;
6062
otherwise, the resources will be released by the destructor.
@@ -88,7 +90,7 @@ def all_gather_object(obj):
8890
self.explicitly_destroy = explicitly_destroy
8991
self.enable_shrink = enable_shrink
9092
self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode, explicitly_destroy,
91-
enable_shrink)
93+
enable_shrink, use_fabric)
9294

9395
# Synchronize device IDs
9496
local_device_id = self.runtime.get_local_device_id()

0 commit comments

Comments
 (0)