diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp
index 9c90178b..d872c037 100644
--- a/csrc/deep_ep.cpp
+++ b/csrc/deep_ep.cpp
@@ -10,6 +10,131 @@
 #include "kernels/api.cuh"
 #include "kernels/configs.cuh"
 
+namespace shared_memory {
+void cu_mem_set_access_all(void* ptr, size_t size) {
+    int device_count;
+    CUDA_CHECK(cudaGetDeviceCount(&device_count));
+
+    CUmemAccessDesc access_desc[device_count];
+    for (int idx = 0; idx < device_count; ++idx) {
+        access_desc[idx].location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        access_desc[idx].location.id = idx;
+        access_desc[idx].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+    }
+
+    CU_CHECK(cuMemSetAccess((CUdeviceptr)ptr, size, access_desc, device_count));
+}
+
+void cu_mem_free(void* ptr) {
+    CUmemGenericAllocationHandle handle;
+    CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));
+
+    size_t size = 0;
+    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));
+
+    CU_CHECK(cuMemUnmap((CUdeviceptr)ptr, size));
+    CU_CHECK(cuMemAddressFree((CUdeviceptr)ptr, size));
+    CU_CHECK(cuMemRelease(handle));
+}
+
+size_t get_size_align_to_granularity(size_t size_raw, size_t granularity) {
+    size_t size = (size_raw + granularity - 1) & ~(granularity - 1);
+    if (size == 0) size = granularity;
+    return size;
+}
+
+bool support_fabric() {
+    int device_count;
+    CUDA_CHECK(cudaGetDeviceCount(&device_count));
+
+    for (int device = 0; device < device_count; ++device) {
+        int support = 0;
+        CU_CHECK(cuDeviceGetAttribute(&support, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, device));
+        if (!support) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+SharedMemoryAllocator::SharedMemoryAllocator() : enable_fabric(support_fabric()) {}
+
+void SharedMemoryAllocator::malloc(void** ptr, size_t size_raw) {
+    if (enable_fabric) {
+        CUdevice device;
+        CU_CHECK(cuCtxGetDevice(&device));
+
+        CUmemAllocationProp prop = {};
+        prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+        prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
+        prop.location.id = device;
+
+        size_t granularity = 0;
+        CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
+
+        size_t size = get_size_align_to_granularity(size_raw, granularity);
+
+        CUmemGenericAllocationHandle handle;
+        CU_CHECK(cuMemCreate(&handle, size, &prop, 0));
+
+        CU_CHECK(cuMemAddressReserve((CUdeviceptr *)ptr, size, granularity, 0, 0));
+        CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
+        cu_mem_set_access_all(*ptr, size);
+    } else {
+        CUDA_CHECK(cudaMalloc(ptr, size_raw));
+    }
+}
+
+void SharedMemoryAllocator::free(void* ptr) {
+    if (enable_fabric) {
+        cu_mem_free(ptr);
+    } else {
+        CUDA_CHECK(cudaFree(ptr));
+    }
+}
+
+void SharedMemoryAllocator::get_mem_handle(MemHandle* mem_handle, void* ptr) {
+    size_t size = 0;
+    CU_CHECK(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr));
+
+    mem_handle->size = size;
+
+    if (enable_fabric) {
+        CUmemGenericAllocationHandle handle;
+        CU_CHECK(cuMemRetainAllocationHandle(&handle, ptr));
+
+        CU_CHECK(cuMemExportToShareableHandle(&mem_handle->inner.cu_mem_fabric_handle, handle, CU_MEM_HANDLE_TYPE_FABRIC, 0));
+    } else {
+        CUDA_CHECK(cudaIpcGetMemHandle(&mem_handle->inner.cuda_ipc_mem_handle, ptr));
+    }
+}
+
+void SharedMemoryAllocator::open_mem_handle(void** ptr, MemHandle* mem_handle) {
+    if (enable_fabric) {
+        size_t size = mem_handle->size;
+
+        CUmemGenericAllocationHandle handle;
+        CU_CHECK(cuMemImportFromShareableHandle(&handle, &mem_handle->inner.cu_mem_fabric_handle, CU_MEM_HANDLE_TYPE_FABRIC));
+
+        CU_CHECK(cuMemAddressReserve((CUdeviceptr *)ptr, size, 0, 0, 0));
+        CU_CHECK(cuMemMap((CUdeviceptr)*ptr, size, 0, handle, 0));
+        cu_mem_set_access_all(*ptr, size);
+    } else {
+        CUDA_CHECK(cudaIpcOpenMemHandle(ptr, mem_handle->inner.cuda_ipc_mem_handle, cudaIpcMemLazyEnablePeerAccess));
+    }
+}
+
+void SharedMemoryAllocator::close_mem_handle(void* ptr) {
+    if (enable_fabric) {
+        cu_mem_free(ptr);
+    } else {
+        CUDA_CHECK(cudaIpcCloseMemHandle(ptr));
+    }
+}
+}
+
 namespace deep_ep {
 
 Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode):
@@ -45,8 +170,8 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_
 
     if (num_nvl_bytes > 0) {
         // Local IPC: alloc local memory and set local IPC handles
-        CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes));
-        CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]));
+        shared_memory_allocator.malloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes);
+        shared_memory_allocator.get_mem_handle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank]);
         buffer_ptrs_gpu = reinterpret_cast<void**>(static_cast<uint8_t*>(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes);
 
         // Set barrier signals
@@ -92,11 +217,11 @@ Buffer::~Buffer() noexcept(false) {
         // Close remote IPC
         if (is_available()) {
             for (int i = 0; i < num_nvl_ranks; ++ i) if (i != nvl_rank)
-                CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i]));
+                shared_memory_allocator.close_mem_handle(buffer_ptrs[i]);
         }
 
         // Free local buffer and error flag
-        CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank]));
+        shared_memory_allocator.free(buffer_ptrs[nvl_rank]);
     }
 
     // Free NVSHMEM
@@ -142,7 +267,8 @@ int Buffer::get_local_device_id() const {
 }
 
 pybind11::bytearray Buffer::get_local_ipc_handle() const {
-    return {ipc_handles[nvl_rank].reserved, CUDA_IPC_HANDLE_SIZE};
+    const shared_memory::MemHandle& handle = ipc_handles[nvl_rank];
+    return {reinterpret_cast<const char*>(&handle), sizeof(handle)};
 }
 
 pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const {
@@ -175,13 +301,13 @@ void Buffer::sync(const std::vector<int> &device_ids,
         for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++ i) {
             EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value());
             auto handle_str = std::string(all_gathered_handles[offset + i].value());
-            EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE);
+            EP_HOST_ASSERT(handle_str.size() == shared_memory::HANDLE_SIZE);
             if (offset + i != rank) {
-                std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
-                CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess));
+                std::memcpy(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE);
+                shared_memory_allocator.open_mem_handle(&buffer_ptrs[i], &ipc_handles[i]);
                 barrier_signal_ptrs[i] = reinterpret_cast<int*>(static_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
             } else {
-                EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0);
+                EP_HOST_ASSERT(std::memcmp(&ipc_handles[i], handle_str.c_str(), shared_memory::HANDLE_SIZE) == 0);
             }
         }
 
diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp
index dfa2202d..cf7b090e 100644
--- a/csrc/deep_ep.hpp
+++ b/csrc/deep_ep.hpp
@@ -20,10 +20,40 @@
 #define TORCH_EXTENSION_NAME deep_ep_cpp
 #endif
 
+namespace shared_memory {
+
+union MemHandleInner {
+    cudaIpcMemHandle_t cuda_ipc_mem_handle;
+    CUmemFabricHandle cu_mem_fabric_handle;
+};
+
+struct MemHandle {
+    MemHandleInner inner;
+    size_t size;
+};
+
+constexpr size_t HANDLE_SIZE = sizeof(MemHandle);
+
+class SharedMemoryAllocator {
+public:
+    SharedMemoryAllocator();
+    void malloc(void** ptr, size_t size);
+    void free(void* ptr);
+    void get_mem_handle(MemHandle* mem_handle, void* ptr);
+    void open_mem_handle(void** ptr, MemHandle* mem_handle);
+    void close_mem_handle(void* ptr);
+private:
+    bool enable_fabric;
+};
+}
+
 namespace deep_ep {
 
 struct Buffer {
+
+#ifndef NVLINK_DOMAIN_LARGE
     EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8");
+#endif
 
 private:
     // Low-latency mode buffer
@@ -44,7 +74,7 @@ struct Buffer {
     int num_device_sms;
     int rank, rdma_rank, nvl_rank;
     int num_ranks, num_rdma_ranks, num_nvl_ranks;
-    cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS];
+    shared_memory::MemHandle ipc_handles[NUM_MAX_NVL_PEERS];
 
     // Stream for communication
     at::cuda::CUDAStream comm_stream;
@@ -71,6 +101,8 @@ struct Buffer {
     volatile int* moe_recv_rdma_counter = nullptr;
     int* moe_recv_rdma_counter_mapped = nullptr;
 
+    shared_memory::SharedMemoryAllocator shared_memory_allocator;
+
 public:
     Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode);
 
diff --git a/csrc/kernels/configs.cuh b/csrc/kernels/configs.cuh
index 8893b79e..1c78ab68 100644
--- a/csrc/kernels/configs.cuh
+++ b/csrc/kernels/configs.cuh
@@ -1,6 +1,13 @@
 #pragma once
 
+#define NVLINK_DOMAIN_LARGE
+
+#ifdef NVLINK_DOMAIN_LARGE
+#define NUM_MAX_NVL_PEERS 24
+#else
 #define NUM_MAX_NVL_PEERS 8
+#endif
+
 #define NUM_MAX_RDMA_PEERS 20
 #define NUM_WORKSPACE_BYTES (32 * 1024 * 1024)
 #define NUM_MAX_LOCAL_EXPERTS 1024
diff --git a/csrc/kernels/exception.cuh b/csrc/kernels/exception.cuh
index 7db0ddb7..3026374b 100644
--- a/csrc/kernels/exception.cuh
+++ b/csrc/kernels/exception.cuh
@@ -31,6 +31,18 @@ do { \
 } while (0)
 #endif
 
+#ifndef CU_CHECK
+#define CU_CHECK(cmd) \
+do { \
+    CUresult e = (cmd); \
+    if (e != CUDA_SUCCESS) { \
+        const char *error_str = NULL; \
+        cuGetErrorString(e, &error_str); \
+        throw EPException("CU", __FILE__, __LINE__, std::string(error_str)); \
+    } \
+} while (0)
+#endif
+
 #ifndef EP_HOST_ASSERT
 #define EP_HOST_ASSERT(cond) \
 do { \
diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu
index 4a33f17c..24621f64 100644
--- a/csrc/kernels/internode.cu
+++ b/csrc/kernels/internode.cu
@@ -14,7 +14,9 @@ extern nvshmem_team_t cpu_rdma_team;
 
 struct SourceMeta {
     int src_rdma_rank, is_token_in_nvl_rank_bits;
+#ifndef NVLINK_DOMAIN_LARGE
     EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers");
+#endif
 
     __forceinline__ SourceMeta() = default;
 
@@ -85,6 +87,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
                 void* rdma_buffer_ptr,
                 void** buffer_ptrs, int** barrier_signal_ptrs, int rank,
                 const nvshmem_team_t rdma_team) {
+#ifndef NVLINK_DOMAIN_LARGE
     auto sm_id = static_cast<int>(blockIdx.x);
     auto thread_id = static_cast<int>(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id();
     auto num_threads = static_cast<int>(blockDim.x), num_warps = num_threads / 32;
@@ -282,6 +285,7 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
             prefix_row[i] += prefix_row[i - 1];
         }
     }
+#endif
 }
 
 void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
@@ -349,6 +353,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
         void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens,
         void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens,
         int rank, int num_ranks) {
+#ifndef NVLINK_DOMAIN_LARGE
     enum class WarpRole {
         kRDMASender,
         kRDMASenderCoordinator,
@@ -935,6 +940,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
             st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx);
         }
     }
+#endif
 }
 
 void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta,
diff --git a/csrc/kernels/launch.cuh b/csrc/kernels/launch.cuh
index 5b398bff..34ad59a1 100644
--- a/csrc/kernels/launch.cuh
+++ b/csrc/kernels/launch.cuh
@@ -52,6 +52,7 @@ cfg.dynamicSmemBytes = smem_size;
         case 2: case_macro(2); \
         case 4: case_macro(4); \
         case 8: case_macro(8); \
+        case 24: case_macro(24); \
         default: EP_HOST_ASSERT(false and "Unsupported ranks"); \
     } while (false)
 
@@ -72,6 +73,7 @@ cfg.dynamicSmemBytes = smem_size;
         case 2: case_macro(dtype, 2); \
         case 4: case_macro(dtype, 4); \
         case 8: case_macro(dtype, 8); \
+        case 24: case_macro(dtype, 24); \
         default: EP_HOST_ASSERT(false && "Unsupported ranks"); \
     } while (false)
 
diff --git a/csrc/kernels/layout.cu b/csrc/kernels/layout.cu
index 829d5bc6..9d68efcd 100644
--- a/csrc/kernels/layout.cu
+++ b/csrc/kernels/layout.cu
@@ -51,12 +51,22 @@ get_dispatch_layout(const int64_t* topk_idx,
     EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS);
 
     // Count rank statistics
+#ifndef NVLINK_DOMAIN_LARGE
     constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS;
+#endif
+
     __shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM];
+#ifndef NVLINK_DOMAIN_LARGE
     __shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM];
+#endif
+
     auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM;
+
     int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks);
+#ifndef NVLINK_DOMAIN_LARGE
     int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS;
+#endif
+
     if (rank_begin_idx < rank_end_idx) {
         const auto num_expert_per_rank = num_experts / num_ranks;
         auto expert_begin = rank_begin_idx * num_expert_per_rank;
@@ -66,20 +76,32 @@
         #pragma unroll
         for (int i = 0; i < kNumRanksPerSM; ++ i)
             num_tokens_per_rank_per_thread[thread_id][i] = 0;
+
+#ifndef NVLINK_DOMAIN_LARGE
         #pragma unroll
        for (int i = 0; i < kNumRDMARanksPerSM; ++ i)
             num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0;
+#endif
+
         #pragma unroll
         for (int i = thread_id; i < num_tokens; i += kNumThreads) {
             auto shifted_topk_idx = topk_idx + i * num_topk;
-            int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0};
+
+            int is_in_rank[kNumRanksPerSM] = {0};
+#ifndef NVLINK_DOMAIN_LARGE
+            int is_in_rdma_rank[kNumRDMARanksPerSM] = {0};
+#endif
+
             #pragma unroll
             for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) {
                 expert_idx = static_cast<int>(shifted_topk_idx[j]);
                 if (expert_begin <= expert_idx and expert_idx < expert_end) {
                     // Count single rank
                     rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx;
-                    is_in_rank[rank_idx] ++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS] ++;
+                    is_in_rank[rank_idx] ++;
+#ifndef NVLINK_DOMAIN_LARGE
+                    is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS] ++;
+#endif
                 }
             }
 
@@ -90,9 +112,11 @@ get_dispatch_layout(const int64_t* topk_idx,
                 num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0);
             }
 
+#ifndef NVLINK_DOMAIN_LARGE
             #pragma unroll
             for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++ j)
                 num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0);
+#endif
         }
         __syncthreads();
 
@@ -106,6 +130,7 @@ get_dispatch_layout(const int64_t* topk_idx,
             num_tokens_per_rank[rank_begin_idx + thread_id] = sum;
         }
 
+#ifndef NVLINK_DOMAIN_LARGE
         if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) {
             int sum = 0;
             #pragma unroll
@@ -113,6 +138,9 @@ get_dispatch_layout(const int64_t* topk_idx,
                 sum += num_tokens_per_rdma_rank_per_thread[i][thread_id];
             num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum;
         }
+#else
+        EP_DEVICE_ASSERT(num_tokens_per_rdma_rank == nullptr);
+#endif
     }
 }
 
@@ -123,7 +151,10 @@ void get_dispatch_layout(const int64_t* topk_idx,
                          cudaStream_t stream) {
     constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8;
     int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM;
+
+#ifndef NVLINK_DOMAIN_LARGE
     EP_STATIC_ASSERT(kNumExpertsPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of experts per SM");
+#endif
 
     SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
     LAUNCH_KERNEL(&cfg, (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>),
diff --git a/setup.py b/setup.py
index b16310a7..93294f74 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@
     include_dirs = ['csrc/']
     library_dirs = []
     nvcc_dlink = []
-    extra_link_args = []
+    extra_link_args = ['-lcuda']
 
     # NVSHMEM flags
     if disable_nvshmem:
diff --git a/tests/test_intranode.py b/tests/test_intranode.py
index 14c81cf9..8b5ecd01 100644
--- a/tests/test_intranode.py
+++ b/tests/test_intranode.py
@@ -1,3 +1,4 @@
+import os
 import time
 import torch
 import torch.distributed as dist
@@ -12,7 +13,12 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer:
               deep_ep.Buffer, group: dist.ProcessGroup):
     # Settings
-    num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks
+    # num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks
+    num_tokens = int(os.environ.get("DEEPEP_TEST_NUM_TOKENS", "4096"))
+    hidden = int(os.environ.get("DEEPEP_TEST_HIDDEN", "7168"))
+    num_topk = int(os.environ.get("DEEPEP_TEST_NUM_TOPK", "8"))
+    num_experts = int(os.environ.get("DEEPEP_TEST_NUM_EXPERTS", str((256 // num_ranks) * num_ranks)))
+    assert num_experts % num_ranks == 0
     if local_rank == 0:
         print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True)
 
@@ -184,9 +190,9 @@ def check_data(check_x, rank_prefix_matrix):
                     best_time, best_results = t, (num_sms, nvl_chunk_size)
                 if local_rank == 0:
                     print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
-                          f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
+                          f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) t={t * 1e3}ms', flush=True)
         if local_rank == 0:
-            print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
+            print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL) t={best_time * 1e3}ms', flush=True)
             print('', flush=True)
 
     # Gather the best config from rank 0 and the first test setting
@@ -215,12 +221,12 @@ def check_data(check_x, rank_prefix_matrix):
                 t = bench(lambda: buffer.combine(**tune_args))[0]
                 if local_rank == 0:
                     print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: '
-                          f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', flush=True)
+                          f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) t={t * 1e3}ms', flush=True)
                 if t < best_time and nvl_chunk_size > 0:
                     best_time, best_results = t, (num_sms, nvl_chunk_size)
 
         if local_rank == 0:
-            print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)', flush=True)
+            print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL) t={best_time * 1e3}ms', flush=True)
             print('', flush=True)
 
@@ -236,7 +242,9 @@ def test_loop(local_rank: int, num_local_ranks: int):
                               num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1))
     torch.manual_seed(rank)
 
-    for i in (24, ):
+    num_sms = int(os.environ.get("DEEPEP_TEST_NUM_SMS", "24"))
+
+    for i in (num_sms, ):
         test_main(i, local_rank, num_ranks, rank, buffer, group)
         if local_rank == 0:
             print('', flush=True)
@@ -252,5 +260,6 @@ def test_loop(local_rank: int, num_local_ranks: int):
 
 
 if __name__ == '__main__':
-    num_processes = 8
+    # num_processes = 8
+    num_processes = int(os.environ.get("DEEPEP_TEST_NUM_PROCESSES", "8"))
     torch.multiprocessing.spawn(test_loop, args=(num_processes, ), nprocs=num_processes)
diff --git a/tests/utils.py b/tests/utils.py
index 1a9c176e..5164f89f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -12,7 +12,9 @@ def init_dist(local_rank: int, num_local_ranks: int):
     port = int(os.getenv('MASTER_PORT', '8361'))
     num_nodes = int(os.getenv('WORLD_SIZE', 1))
     node_rank = int(os.getenv('RANK', 0))
-    assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
+
+    print('HACK: remove init_dist assertion')
+    # assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8
 
     dist.init_process_group(
         backend='nccl',
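
For reference, a minimal sketch (not part of the patch) of how the shared_memory::SharedMemoryAllocator introduced above is meant to be driven, mirroring the allocate/export/import flow in Buffer::Buffer and Buffer::sync. The example_export/example_import helpers and the idea of shipping the handle bytes as a std::vector are illustrative assumptions only; in DeepEP the handle bytes actually travel through the Python-side all-gather before Buffer::sync is called.

// Illustrative sketch only; assumes a current CUDA context on the calling rank.
#include <cstring>
#include <vector>
#include "deep_ep.hpp"

// Producer rank: allocate a buffer and export its handle as raw bytes.
std::vector<char> example_export(shared_memory::SharedMemoryAllocator& allocator,
                                 void** ptr, size_t num_bytes) {
    allocator.malloc(ptr, num_bytes);             // fabric path or cudaMalloc fallback
    shared_memory::MemHandle handle;
    allocator.get_mem_handle(&handle, *ptr);      // fabric export or cudaIpcGetMemHandle
    std::vector<char> bytes(shared_memory::HANDLE_SIZE);
    std::memcpy(bytes.data(), &handle, shared_memory::HANDLE_SIZE);
    return bytes;                                 // shipped to peer ranks out-of-band
}

// Consumer rank: map a peer's buffer from the received handle bytes.
void* example_import(shared_memory::SharedMemoryAllocator& allocator,
                     const std::vector<char>& bytes) {
    shared_memory::MemHandle handle;
    std::memcpy(&handle, bytes.data(), shared_memory::HANDLE_SIZE);
    void* peer_ptr = nullptr;
    allocator.open_mem_handle(&peer_ptr, &handle);  // fabric import or cudaIpcOpenMemHandle
    return peer_ptr;                                // release later with close_mem_handle()
}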