@@ -45,25 +45,10 @@ size_t get_size_align_to_granularity(size_t size_raw, size_t granularity) {
     return size;
 }
 
-bool support_fabric() {
-    int device_count;
-    CUDA_CHECK(cudaGetDeviceCount(&device_count));
-
-    for (int device = 0; device < device_count; ++device) {
-        int support = 0;
-        CU_CHECK(cuDeviceGetAttribute(&support, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, device));
-        if (!support) {
-            return false;
-        }
-    }
-
-    return true;
-}
-
-SharedMemoryAllocator::SharedMemoryAllocator() : enable_fabric(support_fabric()) {}
+SharedMemoryAllocator::SharedMemoryAllocator(bool use_fabric) : use_fabric(use_fabric) {}
 
 void SharedMemoryAllocator::malloc(void** ptr, size_t size_raw) {
-    if (enable_fabric) {
+    if (use_fabric) {
         CUdevice device;
         CU_CHECK(cuCtxGetDevice(&device));
 
@@ -90,7 +75,7 @@ void SharedMemoryAllocator::malloc(void** ptr, size_t size_raw) {
 }
 
 void SharedMemoryAllocator::free(void* ptr) {
-    if (enable_fabric) {
+    if (use_fabric) {
         cu_mem_free(ptr);
     } else {
         CUDA_CHECK(cudaFree(ptr));
@@ -103,7 +88,7 @@ void SharedMemoryAllocator::get_mem_handle(MemHandle* mem_handle, void* ptr) {
 
     mem_handle->size = size;
 
-    if (enable_fabric) {
+    if (use_fabric) {
         CUmemGenericAllocationHandle handle;
         CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));
 
@@ -114,7 +99,7 @@ void SharedMemoryAllocator::get_mem_handle(MemHandle* mem_handle, void* ptr) {
 }
 
 void SharedMemoryAllocator::open_mem_handle(void** ptr, MemHandle* mem_handle) {
-    if (enable_fabric) {
+    if (use_fabric) {
         size_t size = mem_handle->size;
 
         CUmemGenericAllocationHandle handle;
@@ -129,7 +114,7 @@ void SharedMemoryAllocator::open_mem_handle(void** ptr, MemHandle* mem_handle) {
 }
 
 void SharedMemoryAllocator::close_mem_handle(void* ptr) {
-    if (enable_fabric) {
+    if (use_fabric) {
         cu_mem_free(ptr);
     } else {
         CUDA_CHECK(cudaIpcCloseMemHandle(ptr));
@@ -145,15 +130,17 @@ Buffer::Buffer(int rank,
                int64_t num_rdma_bytes,
                bool low_latency_mode,
                bool explicitly_destroy,
-               bool enable_shrink)
+               bool enable_shrink,
+               bool use_fabric)
     : rank(rank),
       num_ranks(num_ranks),
       num_nvl_bytes(num_nvl_bytes),
       num_rdma_bytes(num_rdma_bytes),
       enable_shrink(enable_shrink),
       low_latency_mode(low_latency_mode),
       explicitly_destroy(explicitly_destroy),
-      comm_stream(at::cuda::getStreamFromPool(true)) {
+      comm_stream(at::cuda::getStreamFromPool(true)),
+      shared_memory_allocator(use_fabric) {
     // Metadata memory
     int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int);
     int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*);
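
Below is a minimal sketch, not part of the patch, of how a caller might now reproduce the fabric-support probe that the removed support_fabric() helper performed and feed the result into the new use_fabric parameter. The all_devices_support_fabric name and the commented Buffer call are hypothetical, and error handling is simplified compared with the project's CUDA_CHECK / CU_CHECK macros.

// Sketch only: caller-side fabric-support detection, mirroring the logic of
// the removed support_fabric() helper.
#include <cuda.h>
#include <cuda_runtime.h>

static bool all_devices_support_fabric() {
    int device_count = 0;
    // The runtime call also initializes the driver, as in the removed helper.
    if (cudaGetDeviceCount(&device_count) != cudaSuccess || device_count == 0)
        return false;
    for (int device = 0; device < device_count; ++device) {
        int support = 0;
        // Same attribute the removed helper queried (recent CUDA toolkits).
        if (cuDeviceGetAttribute(&support,
                                 CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED,
                                 device) != CUDA_SUCCESS || !support)
            return false;  // one unsupported device disables fabric handles everywhere
    }
    return true;
}

// Hypothetical call site: the flag now flows explicitly from the caller,
// through Buffer, into SharedMemoryAllocator instead of being auto-detected
// in the allocator's constructor.
//
//   Buffer buffer(rank, num_ranks, num_nvl_bytes, num_rdma_bytes,
//                 low_latency_mode, explicitly_destroy, enable_shrink,
//                 /*use_fabric=*/all_devices_support_fabric());

One apparent benefit of moving detection to the call site is that the caller (or a test) can force fabric handles on or off regardless of what the local devices report.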