diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index 0961de23..9d91b76b 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -1277,7 +1277,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id const std::optional& combine_wait_recv_cost_stats, int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook, - const std::optional& out) { + const std::optional& out, + bool overlap, const std::optional& src_signals, uint32_t src_signal_expect_value) { #ifndef DISABLE_NVSHMEM EP_HOST_ASSERT(low_latency_mode); @@ -1347,7 +1348,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id num_topk, num_experts, rank, num_ranks, use_logfmt, workspace, num_device_sms, - launch_stream, phases, zero_copy); + launch_stream, phases, zero_copy, + overlap, src_signals.has_value() ? src_signals->data_ptr() : nullptr, src_signal_expect_value); }; launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index ff3015d7..383e88de 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -180,7 +180,8 @@ struct Buffer { const std::optional& combine_wait_recv_cost_stats, int num_max_dispatch_tokens_per_rank, int num_experts, bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook, - const std::optional& out = std::nullopt); + const std::optional& out = std::nullopt, + bool overlap = false, const std::optional& src_signals = std::nullopt, uint32_t src_signal_expect_value = 0); torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const; diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh index 276fc01f..3f12e48d 100644 --- a/csrc/kernels/api.cuh +++ b/csrc/kernels/api.cuh @@ -166,7 +166,8 @@ void combine(void* combined_x, int num_topk, int num_experts, int rank, int num_ranks, bool use_logfmt, void* workspace, int num_device_sms, - cudaStream_t stream, int phases, bool zero_copy); + cudaStream_t stream, int phases, bool zero_copy, + bool overlap, uint32_t* src_signals, uint32_t src_signal_expect_value); } // namespace internode_ll diff --git a/csrc/kernels/exception.cuh b/csrc/kernels/exception.cuh index 7db0ddb7..3026374b 100644 --- a/csrc/kernels/exception.cuh +++ b/csrc/kernels/exception.cuh @@ -31,6 +31,18 @@ do { \ } while (0) #endif +#ifndef CU_CHECK +#define CU_CHECK(cmd) \ +do { \ + CUresult e = (cmd); \ + if (e != CUDA_SUCCESS) { \ + const char *error_str = NULL; \ + cuGetErrorString(e, &error_str); \ + throw EPException("CU", __FILE__, __LINE__, std::string(error_str)); \ + } \ +} while (0) +#endif + #ifndef EP_HOST_ASSERT #define EP_HOST_ASSERT(cond) \ do { \ diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index d8f7d24b..50463bbf 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -769,6 +769,495 @@ __forceinline__ __device__ void decode_and_accumulate(uint32_t* ld_buffer, float } } +template +__global__ +__launch_bounds__(1024, 1) +void +combine_overlappable(void* combined_x, + void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, + const void* x, const int64_t* topk_idx, const float* topk_weights, + const int* src_info, const int64_t* layout_range, + int64_t* combine_wait_recv_cost_stats, + int* next_clean, int num_next_clean_int, + int* atomic_clean_flag, + int num_combined_tokens, int hidden, int num_topk, + int num_max_dispatch_tokens_per_rank, + int 
num_experts, int rank, int num_ranks, + int num_warp_groups, int num_warps_per_group, + int phases, bool zero_copy, + uint32_t* src_signals, uint32_t src_signal_expect_value) { + const auto sm_id = __shfl_sync(0xffffffff, static_cast(blockIdx.x), 0); + const auto num_sms = __shfl_sync(0xffffffff, static_cast(gridDim.x), 0); + const auto thread_id = static_cast(threadIdx.x); + const auto num_threads = __shfl_sync(0xffffffff, static_cast(blockDim.x), 0); + const auto warp_id = __shfl_sync(0xffffffff, thread_id / 32, 0), lane_id = get_lane_id(); + const auto num_local_experts = num_experts / num_ranks; + const auto warp_group_id = warp_id / num_warps_per_group; + const auto sub_warp_id = warp_id % num_warps_per_group; + const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id; + + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Data type staffs + constexpr int kNumElemsPerInt4 = sizeof(int4) / sizeof(nv_bfloat16); + constexpr int64_t hidden_bf16_int4 = kHidden / kNumElemsPerInt4; + + // Use different unroll factors for send and recv phases + constexpr int kNumSendUnrolls = kHidden % (32 * 4 * sizeof(int4) / sizeof(nv_bfloat16)) == 0 ? 4 : 2; + constexpr int kNumRecvUnrolls = 2; + constexpr int hidden_bf16_int4_pad = align_up(static_cast(hidden_bf16_int4), 32 * kNumSendUnrolls); + EP_STATIC_ASSERT(kHidden % (32 * 2 * sizeof(int4) / sizeof(nv_bfloat16)) == 0, "Invalid hidden"); + EP_STATIC_ASSERT(kNumSendUnrolls <= kNumMaxUnrolls and kNumRecvUnrolls <= kNumMaxUnrolls, "Invalid unrolls"); + EP_STATIC_ASSERT(hidden_bf16_int4 % kNumSendUnrolls == 0, "Invalid hidden"); + EP_STATIC_ASSERT(kNumSendUnrolls >= kNumRecvUnrolls, "Invalid unroll factors"); + + // Message package + EP_STATIC_ASSERT(kHidden % 128 == 0, "Invalid hidden"); + constexpr int kNumDivisions = kHidden / 128; + constexpr int kNumMetaBytes = kNumDivisions * sizeof(nv_bfloat162); + constexpr size_t num_bytes_per_slot = kHidden * sizeof(nv_bfloat16) + kNumMetaBytes; + EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization"); + + // Sending phase + if ((phases & LOW_LATENCY_SEND_PHASE) == 0) + goto LOW_LATENCY_COMBINE_RECV; + + // Clean up next buffer + if (sm_id == 0 and warp_group_id == 0 and sub_warp_id == 0) { + #pragma unroll + for (int i = lane_id; i < num_next_clean_int; i += 32) + next_clean[i] = 0; + + // Notify before executing `int_p` + __syncwarp(); + if (lane_id == 0) + atomic_add_release_global(atomic_clean_flag, num_experts); + } + + // Issue IBGDA sends + if (responsible_expert_idx < num_experts) { + // ------------------------------------------ START tma-related ------------------------------------------------- + // TMA stuffs + constexpr int kNumTMABufferBytes = sizeof(int4) * 32 * kNumSendUnrolls; + constexpr int kNumStages = 3; + constexpr int kNumPrefetch = 1; + EP_STATIC_ASSERT(kNumStages == 3 and kNumPrefetch == 1, "Invalid stages"); + + auto smem_ptr = smem_buffer + warp_id * (kNumStages * (kNumTMABufferBytes + 16) + kNumMetaBytes); + uint32_t tma_phase = 0; + auto tma_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast(smem_ptr + i * (kNumTMABufferBytes + 16)); }); + auto full_barriers = PatternVisitor([=](const int& i) { return reinterpret_cast(smem_ptr + i * (kNumTMABufferBytes + 16) + kNumTMABufferBytes); }); + auto meta_buffers = kUseLogFMT ? 
reinterpret_cast(smem_ptr + kNumStages * (kNumTMABufferBytes + 16)) : nullptr; + EP_STATIC_ASSERT(kNumSendUnrolls * kNumStages <= 12, "TMA buffer size exceed limit"); + + // Initialize m-barriers + if (lane_id < kNumStages) { + mbarrier_init(full_barriers[lane_id], 1); + fence_view_async_shared(); + fence_barrier_init(); + } + __syncwarp(); + + constexpr int kNumIters = hidden_bf16_int4_pad / (32 * kNumSendUnrolls); + auto tma_load_and_arrive = [&](const int& stage_idx, const int4* gmem_ptr, const int& num_bytes) { + tma_load_1d(tma_buffers[stage_idx], gmem_ptr, full_barriers[stage_idx], num_bytes); + mbarrier_arrive_and_expect_tx(full_barriers[stage_idx], num_bytes); + }; + auto get_num_tma_bytes = [&](const int& offset_int4) { + return min(kNumTMABufferBytes, static_cast((hidden_bf16_int4 - offset_int4) * sizeof(int4))); + }; + // -------------------------------------------- END tma-related ----------------------------------------------- + + const auto dst_rank = responsible_expert_idx / num_local_experts; + + // NOTE + // before: "one warp group --- all tokens for one (dsk_rank, local_expert_idx)" + // after: "multiple warp groups --- cooperate on tokens for one (dsk_rank, local_expert_idx)" + for (int local_expert_idx = 0; local_expert_idx < num_local_experts; ++local_expert_idx) { + // NOTE changed + // const auto local_expert_idx = responsible_expert_idx % num_local_experts; + const auto token_cooperate_part_idx = responsible_expert_idx % num_local_experts; + const auto num_token_cooperate_parts = num_local_experts; + + const auto global_expert_idx = rank * num_local_experts + local_expert_idx; + const auto layout = __ldg(layout_range + local_expert_idx * num_ranks + dst_rank); + const auto local_x = static_cast(x) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_bf16_int4; + const auto local_src_info = src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; + const auto rdma_send_x_vec = static_cast(rdma_send_x) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_slot; + + // Unpack layout + int offset, num_tokens_to_send; + unpack2(layout, num_tokens_to_send, offset); + + // NOTE added + if (src_signals != nullptr) { + // TODO shall we let 1st expert be separately computed and then do *not* wait for it + if (threadIdx.x == 0) { + wait_signal(src_signals + local_expert_idx, src_signal_expect_value); + } + + __syncthreads(); + } + + // Issue IBGDA send + // NOTE changed + // for (int token_idx = offset + sub_warp_id; token_idx < offset + num_tokens_to_send; token_idx += num_warps_per_group) { + const int num_tokens_to_send_per_cooperate_part = ceil_div(num_tokens_to_send, num_token_cooperate_parts); + const int token_idx_part_end = offset + min(num_tokens_to_send, num_tokens_to_send_per_cooperate_part * (token_cooperate_part_idx + 1)); + for ( + int token_idx = offset + num_tokens_to_send_per_cooperate_part * token_cooperate_part_idx + sub_warp_id; + token_idx < token_idx_part_end; + token_idx += num_warps_per_group + ) { + const auto x_int4 = local_x + token_idx * hidden_bf16_int4; + const auto rdma_send_type_row = reinterpret_cast(rdma_send_x_vec + token_idx * num_bytes_per_slot); + const auto rdma_send_x_vec_row = reinterpret_cast(rdma_send_type_row); + + // Copy directly to local rank, or copy to buffer and issue RDMA + const auto src_idx = __shfl_sync(0xffffffff, __ldg(local_src_info + token_idx), 0); + const auto buf_ptr = reinterpret_cast(rdma_send_x_vec_row); + const auto dst_ptr = 
reinterpret_cast(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot; + const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); + int num_send_bytes = hidden * sizeof(nv_bfloat16); + + if (not zero_copy or dst_p2p_ptr != 0) { + // Read from `cpy_src_int4_ptr` and copy into `cpy_dst_int4_ptr` + const auto cpy_src_int4_ptr = zero_copy ? reinterpret_cast(buf_ptr) : x_int4; + const auto cpy_dst_int4_ptr = dst_p2p_ptr == 0 ? reinterpret_cast(buf_ptr) : reinterpret_cast(dst_p2p_ptr); + + // Prefetch + if (elect_one_sync()) + tma_load_and_arrive(0, cpy_src_int4_ptr, get_num_tma_bytes(0)); + __syncwarp(); + + int tma_offset_bytes = kNumMetaBytes; + #pragma unroll + for (int i = lane_id * kNumSendUnrolls, iter_idx = 0; i < hidden_bf16_int4_pad; i += 32 * kNumSendUnrolls, ++ iter_idx) { + // Load the next iteration + const int& stage_idx = iter_idx % kNumStages; + const int& next_stage_idx = (iter_idx + 1) % kNumStages; + if (iter_idx + 1 < kNumIters and elect_one_sync()) { + tma_store_wait(); + const auto& offset_int4 = i + 32 * kNumSendUnrolls; + tma_load_and_arrive(next_stage_idx, cpy_src_int4_ptr + offset_int4, get_num_tma_bytes(offset_int4)); + } + __syncwarp(); + + // Wait the current TMA arrival + EP_STATIC_ASSERT(kNumStages < 32, "Too many stages"); + mbarrier_wait(full_barriers[stage_idx], tma_phase, stage_idx); + if constexpr (kUseLogFMT) { + // Cast if possible + constexpr int kNumInt4PerDivision = 128 / kNumElemsPerInt4; + int num_tma_bytes = logfmt_encode( + tma_buffers[stage_idx], + // NOTES: only the leader lane will write the result + (i % kNumInt4PerDivision == 0) ? meta_buffers + i / kNumInt4PerDivision : nullptr, + lane_id); + if (elect_one_sync()) + tma_store_1d(tma_buffers[stage_idx], reinterpret_cast(cpy_dst_int4_ptr) + tma_offset_bytes, num_tma_bytes); + tma_offset_bytes += num_tma_bytes; + } else { + // BF16 original values + if (elect_one_sync()) + tma_store_1d(tma_buffers[stage_idx], cpy_dst_int4_ptr + i, get_num_tma_bytes(i)); + } + __syncwarp(); + } + + // Store metadata (min/max values) for LogFMT + if constexpr (kUseLogFMT) { + num_send_bytes = tma_offset_bytes; + if (elect_one_sync()) + tma_store_1d(meta_buffers, cpy_dst_int4_ptr, kNumMetaBytes); + } + + // Flush all stores + tma_store_wait<0>(); + __syncwarp(); + } + + // Issue RDMA + // NOTES: for zero-copy mode, we assume the data is already in the send buffer + if (dst_p2p_ptr == 0) + nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, num_send_bytes, dst_rank, local_expert_idx, lane_id, token_idx - offset); + } + } + + // Put the finishing flag + EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 16); + asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(num_warps_per_group * 32)); + if (sub_warp_id == 1 and lane_id == 0) { + // copied from global to this part + const auto local_expert_idx_for_signal = responsible_expert_idx % num_local_experts; + const auto global_expert_idx_for_signal = rank * num_local_experts + local_expert_idx_for_signal; + // ============================================= + + while (ld_acquire_global(atomic_clean_flag) == 0); + auto dst_ptr = reinterpret_cast(rdma_recv_flag + global_expert_idx_for_signal); + auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); + if (dst_p2p_ptr == 0) { + // will not visit this branch + // nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast(dst_ptr), 1, dst_rank, local_expert_idx); + EP_DEVICE_ASSERT(0); + } else { + st_release_sys_global(reinterpret_cast(dst_p2p_ptr), 
1); + } + atomic_add_release_global(atomic_clean_flag, -1); + } + __syncwarp(); + + // Destroy m-barriers + if (lane_id < kNumStages) { + mbarrier_inval(full_barriers[lane_id]); + fence_view_async_shared(); + fence_barrier_init(); + } + __syncwarp(); + } else { + // NOTE add + for (int local_expert_idx = 0; local_expert_idx < num_local_experts; ++local_expert_idx) { + if (src_signals != nullptr) { + // TODO original code uses NamedBarrier, better than this? + __syncthreads(); + } + } + } + + // Receiving phase + LOW_LATENCY_COMBINE_RECV: + if ((phases & LOW_LATENCY_RECV_PHASE) == 0) + return; + + // Wait all ranks to arrive + if (responsible_expert_idx < num_experts) { + EP_DEVICE_ASSERT(num_warps_per_group > 1); + if (sub_warp_id == 0 and lane_id == 0) { + auto start_time = clock64(); + while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0); + auto wait_recv_cost = clock64() - start_time; + if (combine_wait_recv_cost_stats != nullptr) { + const auto& src_rank = responsible_expert_idx / num_local_experts; + atomicAdd(reinterpret_cast(combine_wait_recv_cost_stats + src_rank), wait_recv_cost); + } + } + } + cg::this_grid().sync(); + + // Reassign warp groups + constexpr int kMaxNumGroups = 2; + const int num_decode_warps = hidden_bf16_int4_pad / (kNumRecvUnrolls * 32); + const int num_groups = min(kMaxNumGroups, (num_threads / 32) / (num_decode_warps + 1)); + const int decode_warp_idx = __shfl_sync(0xffffffff, warp_id % (num_decode_warps + 1), 0); + const int group_idx = __shfl_sync(0xffffffff, warp_id / (num_decode_warps + 1), 0); + EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0, "Invalid vectorization"); + EP_DEVICE_ASSERT(num_topk <= 32); + EP_DEVICE_ASSERT(num_groups > 0); + + if (group_idx < num_groups) { + constexpr int kNumStages = 3; + constexpr int kNumTMABufferBytes = 16 * 2 + kHidden * 2; + constexpr int kNumBF16PerWarpBytes = 32 * kNumRecvUnrolls * kNumElemsPerInt4 * 2; + constexpr int kNumLogFMTPerWarpBytes = kNumBF16PerWarpBytes / 16 * 10; + constexpr int kNumDivisionBytes = kNumDivisions * sizeof(uint32_t); + constexpr int kNumBytesPerGroup = kNumStages * kNumTMABufferBytes + kHidden * 2 + kNumStages * kNumDivisionBytes * 3; + + // Reallocate shared memory + const auto smem_group_buffer = smem_buffer + kNumBytesPerGroup * group_idx; + auto full_barriers = PatternVisitor([=](const int& i) { return reinterpret_cast(smem_group_buffer + i * kNumTMABufferBytes); }); + auto empty_barriers = PatternVisitor([=](const int& i) { return reinterpret_cast(smem_group_buffer + i * kNumTMABufferBytes + 8); }); + auto tma_ld_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast(smem_group_buffer + i * kNumTMABufferBytes + 16); }); + auto tma_st_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast(smem_group_buffer + kNumStages * kNumTMABufferBytes + i * kNumBF16PerWarpBytes); }); + + // Redundant when logfmt is disabled + const auto smem_group_ptr = smem_group_buffer + kNumStages * kNumTMABufferBytes + kHidden * 2; + auto log_amax_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast(smem_group_ptr + i * kNumDivisionBytes); }); + auto log_amin_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast(smem_group_ptr + kNumStages * kNumDivisionBytes + i * kNumDivisionBytes); }); + auto cast_info_buffers = PatternVisitor([=](const int& i) { return reinterpret_cast (smem_group_ptr + kNumStages * kNumDivisionBytes * 2 + i * kNumDivisionBytes); }); + + uint32_t tma_phase = 0; + EP_STATIC_ASSERT(kNumStages < 32, 
"Too many stages"); + if (decode_warp_idx == num_decode_warps) + tma_phase = (1 << kNumStages) - 1; + + // Initialize m-barriers + if (decode_warp_idx == num_decode_warps and lane_id < kNumStages) { + mbarrier_init(full_barriers[lane_id], 1); + mbarrier_init(empty_barriers[lane_id], num_decode_warps); + } + asm volatile("bar.sync %0, %1;" :: "r"(group_idx + 1), "r"((num_decode_warps + 1) * 32)); + + int stage_idx = 0, topk_idx_by_lane = 0; + EP_STATIC_ASSERT(kNumMaxTopk <= 32, "Invalid number of topks"); + if (decode_warp_idx == num_decode_warps) { + // TMA load warp + for (int token_idx = sm_id + num_sms * group_idx; token_idx < num_combined_tokens; token_idx += num_sms * num_groups) { + if (lane_id < num_topk) + topk_idx_by_lane = static_cast(__ldg(topk_idx + token_idx * num_topk + lane_id)); + for (int i = 0; i < num_topk; ++ i) { + int topk_idx_reg = __shfl_sync(0xffffffff, topk_idx_by_lane, i); + if (topk_idx_reg < 0) + continue; + + mbarrier_wait(empty_barriers[stage_idx], tma_phase, stage_idx); + auto buffer = static_cast(rdma_recv_x) + (topk_idx_reg * num_max_dispatch_tokens_per_rank + token_idx) * num_bytes_per_slot; + if constexpr (kUseLogFMT) { + logfmt_check_amaxmin( + buffer, reinterpret_cast(log_amax_buffers[stage_idx]), + reinterpret_cast(log_amin_buffers[stage_idx]), cast_info_buffers[stage_idx], lane_id); + } + if (elect_one_sync()) { + int num_casted = 0; + if constexpr (kUseLogFMT) { + const auto& info = cast_info_buffers[stage_idx][num_decode_warps - 1]; + num_casted = (info >> 1) + (info & 1); + } + int num_tma_bytes = num_casted * kNumLogFMTPerWarpBytes + (num_decode_warps - num_casted) * kNumBF16PerWarpBytes; + tma_load_1d(tma_ld_buffers[stage_idx], buffer + (kUseLogFMT ? kNumMetaBytes : 0), full_barriers[stage_idx], num_tma_bytes); + mbarrier_arrive_and_expect_tx(full_barriers[stage_idx], num_tma_bytes); + } + __syncwarp(); + stage_idx = (stage_idx + 1) % kNumStages; + } + } + } else { + // Reduction warps + float topk_weights_by_lane; + for (int token_idx = sm_id + num_sms * group_idx; token_idx < num_combined_tokens; token_idx += num_sms * num_groups) { + if (lane_id < num_topk) { + topk_idx_by_lane = static_cast(__ldg(topk_idx + token_idx * num_topk + lane_id)); + topk_weights_by_lane = __ldg(topk_weights + token_idx * num_topk + lane_id); + } + __syncwarp(); + + float combined_values[kNumElemsPerInt4 * kNumRecvUnrolls] = {0.0f}; + for (int i = 0; i < num_topk; ++ i) { + if (__shfl_sync(0xffffffff, topk_idx_by_lane, i) < 0) + continue; + const auto& topk_weight = __shfl_sync(0xffffffff, topk_weights_by_lane, i); + + mbarrier_wait(full_barriers[stage_idx], tma_phase, stage_idx); + if constexpr (kUseLogFMT) { + const auto& info = cast_info_buffers[stage_idx][decode_warp_idx]; + bool enable_cast = info & 1; + int num_casted_prefix = info >> 1; + int tma_offset = kNumLogFMTPerWarpBytes * num_casted_prefix + kNumBF16PerWarpBytes * (decode_warp_idx - num_casted_prefix); + int division_idx = decode_warp_idx * (kNumRecvUnrolls * 2) + lane_id * kNumRecvUnrolls / 16; + decode_and_accumulate( + reinterpret_cast(tma_ld_buffers[stage_idx] + tma_offset + (enable_cast ? 
kNumLogFMTPerWarpBytes : kNumBF16PerWarpBytes) / 32 * lane_id), + combined_values, log_amax_buffers[stage_idx][division_idx], log_amin_buffers[stage_idx][division_idx], enable_cast, topk_weight + ); + } else { + int tma_offset = kNumBF16PerWarpBytes * decode_warp_idx; + decode_and_accumulate( + reinterpret_cast(tma_ld_buffers[stage_idx] + tma_offset + kNumBF16PerWarpBytes / 32 * lane_id), + combined_values, 0, 0, false, topk_weight + ); + } + + if (elect_one_sync()) + mbarrier_arrive(empty_barriers[stage_idx]); + stage_idx = (stage_idx + 1) % kNumStages; + } + tma_store_wait<0>(); + + #pragma unroll + for (int k = 0; k < kNumRecvUnrolls * 4; ++ k) { + auto combined_pack = __nv_bfloat162(combined_values[k * 2], combined_values[k * 2 + 1]); + tma_st_buffers[decode_warp_idx][kNumRecvUnrolls * 4 * lane_id + k] = *reinterpret_cast(&combined_pack); + } + tma_store_fence(); + if (elect_one_sync()) { + tma_store_1d(tma_st_buffers[decode_warp_idx], + static_cast(combined_x) + token_idx * hidden_bf16_int4 + decode_warp_idx * kNumRecvUnrolls * 32, + kNumBF16PerWarpBytes); + } + __syncwarp(); + } + } + + // Flush all stores + tma_store_wait<0>(); + } +} + +void combine_overlappable(void* combined_x, + void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, + const void* x, const int64_t* topk_idx, const float* topk_weights, + const int* src_info, const int64_t* layout_range, + int64_t* combine_wait_recv_cost_stats, + int* next_clean, int num_next_clean_int, + int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, + int num_topk, int num_experts, int rank, int num_ranks, + bool use_logfmt, + void* workspace, int num_device_sms, + cudaStream_t stream, int phases, bool zero_copy, + uint32_t* src_signals, uint32_t src_signal_expect_value) { + // NOTE reduce combine_send num sm + if ((phases & LOW_LATENCY_RECV_PHASE) == 0) { + // TODO let it be configurable + num_device_sms = 32; + } + + constexpr int kNumMaxTopk = 9; + const int num_warp_groups = ceil_div(num_experts, num_device_sms); + const int num_warps_per_group = 32 / num_warp_groups; + const int num_recv_per_sm = ceil_div(num_combined_tokens, num_device_sms); + EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0 and ((num_combined_tokens == 0) or (num_recv_per_sm > 0))); + + const auto num_warps = num_warp_groups * num_warps_per_group; + const auto num_sms = max(ceil_div(num_experts, num_warp_groups), ceil_div(num_combined_tokens, num_recv_per_sm)); + + // Check workspace + auto atomic_clean_flag = static_cast(workspace); + EP_HOST_ASSERT(sizeof(int) <= NUM_WORKSPACE_BYTES); + EP_HOST_ASSERT(num_topk <= kNumMaxTopk); + + // Online cast cannot use zero-copy + EP_HOST_ASSERT(not (zero_copy and use_logfmt)); + + constexpr int kNumStages = 3; + constexpr int kNumMaxUnrolls = 4; + constexpr int kMaxNumGroups = 2; + + // Send buffer size + const int num_meta_bytes = hidden / 128 * 4; + const int num_send_tma_bytes = 32 * sizeof(int4) * kNumMaxUnrolls + 16; + const int smem_send_size = num_warps * (kNumStages * num_send_tma_bytes + num_meta_bytes); + + // Receive buffer size + const int num_recv_tma_bytes = 16 + hidden * 2; + const int smem_recv_size = kMaxNumGroups * (kNumStages * num_recv_tma_bytes + hidden * 2 + kNumStages * num_meta_bytes * 3); + + // Total requirement + const int smem_size = max(smem_send_size, smem_recv_size); + +#define COMBINE_LAUNCH_CASE(hidden) { \ +auto combine_func = use_logfmt ? 
\ + combine_overlappable : \ + combine_overlappable; \ +SET_SHARED_MEMORY_FOR_TMA(combine_func); \ +LAUNCH_KERNEL(&cfg, combine_func, \ + combined_x, \ + rdma_recv_x, rdma_recv_flag, rdma_send_x, \ + x, topk_idx, topk_weights, src_info, layout_range, \ + combine_wait_recv_cost_stats, \ + next_clean, num_next_clean_int, \ + atomic_clean_flag, \ + num_combined_tokens, hidden, num_topk, \ + num_max_dispatch_tokens_per_rank, \ + num_experts, rank, num_ranks, \ + num_warp_groups, num_warps_per_group, \ + phases, zero_copy, \ + src_signals, src_signal_expect_value); } break + + SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream); + SWITCH_HIDDEN(COMBINE_LAUNCH_CASE); +#undef COMBINE_LAUNCH_CASE +} + template __global__ __launch_bounds__(1024, 1) void combine(void* combined_x, @@ -1140,7 +1629,25 @@ void combine(void* combined_x, int num_topk, int num_experts, int rank, int num_ranks, bool use_logfmt, void* workspace, int num_device_sms, - cudaStream_t stream, int phases, bool zero_copy) { + cudaStream_t stream, int phases, bool zero_copy, + bool overlap, uint32_t* src_signals, uint32_t src_signal_expect_value) { + if (overlap) { + return combine_overlappable( + combined_x, + rdma_recv_x, rdma_recv_flag, rdma_send_x, + x, topk_idx, topk_weights, + src_info, layout_range, + combine_wait_recv_cost_stats, + next_clean, num_next_clean_int, + num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, + num_topk, num_experts, rank, num_ranks, + use_logfmt, + workspace, num_device_sms, + stream, phases, zero_copy, + src_signals, src_signal_expect_value + ); + } + constexpr int kNumMaxTopk = 11; const int num_warp_groups = ceil_div(num_experts, num_device_sms); const int num_warps_per_group = 32 / num_warp_groups; diff --git a/csrc/kernels/utils.cuh b/csrc/kernels/utils.cuh index da6c34fa..d36bd043 100644 --- a/csrc/kernels/utils.cuh +++ b/csrc/kernels/utils.cuh @@ -328,6 +328,10 @@ __device__ __forceinline__ uint32_t elect_one_sync() { #endif } +__device__ __forceinline__ void fence_view_async_shared() { + asm volatile("fence.proxy.async.shared::cta; \n" :: ); +} + // TMA PTX instructions #ifndef DISABLE_SM90_FEATURES @@ -599,4 +603,20 @@ __forceinline__ __device__ T warp_reduce_or(T value) { return warp_reduce(value, ReduceOr{}); } +__device__ __forceinline__ void wait_signal(uint32_t* addr, uint32_t expect_value) { + while (true) { + uint32_t ready = 0; + asm volatile("ld.acquire.gpu.global.u32 %0, [%1];" + : "=r"(ready) + : "l"(addr) + : "memory"); + + if (ready == expect_value) { + return; + } + + asm volatile("nanosleep.u32 20;"); + }; +} + } // namespace deep_ep diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py index 23ab8433..9c9c9e89 100644 --- a/deep_ep/buffer.py +++ b/deep_ep/buffer.py @@ -721,7 +721,8 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: tuple, use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False, return_recv_hook: bool = False, out: Optional[torch.Tensor] = None, - combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \ + combine_wait_recv_cost_stats: Optional[torch.Tensor] = None, + overlap: bool = False, src_signals: Optional[torch.Tensor] = None, src_signal_expect_value: int = 0) -> \ Tuple[torch.Tensor, EventOverlap, Callable]: """ A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA. 
@@ -761,7 +762,8 @@ def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weig combine_wait_recv_cost_stats, num_max_dispatch_tokens_per_rank, num_experts, use_logfmt, zero_copy, async_finish, return_recv_hook, - out) + out, + overlap, src_signals, src_signal_expect_value) tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x) return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook diff --git a/setup.py b/setup.py index 046b8d4c..f5e43f74 100644 --- a/setup.py +++ b/setup.py @@ -103,7 +103,7 @@ def get_extension_deep_ep_cpp(): include_dirs = ['csrc/'] library_dirs = [] nvcc_dlink = [] - extra_link_args = [] + extra_link_args = ['-lcuda'] # NVSHMEM flags if disable_nvshmem:
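
Usage sketch for the new `overlap` path (illustration only, not part of this diff): with `overlap=True`, the combine send kernel waits, per local expert, until `src_signals[local_expert_idx]` equals `src_signal_expect_value` (polled with a GPU-scope acquire load plus `nanosleep`) before sending that expert's tokens, and the send-only launch is pinned to 32 SMs (hard-coded for now, see the TODO) so the expert kernels can run on the remaining SMs. The snippet below assumes `buffer`, `handle`, `x`, `topk_idx`, `topk_weights`, `num_experts`, `num_ranks`, and `run_expert_ffn` already exist; the signalling helper, dtype, and stream layout are assumptions for illustration, not APIs defined by this change.

```python
import torch

num_local_experts = num_experts // num_ranks
EXPECT = 1  # value each slot must reach before that expert's tokens are sent

# One 4-byte slot per local expert in GPU global memory, polled by the
# combine-send kernel (dtype assumed here; match whatever the binding checks).
src_signals = torch.zeros(num_local_experts, dtype=torch.uint32, device='cuda')

# Launch the send phase on a side stream: the kernel busy-waits on the signals,
# so it must not share a stream with the producer kernels (that would deadlock).
comm_stream = torch.cuda.Stream()
with torch.cuda.stream(comm_stream):
    combined_x, event, hook = buffer.low_latency_combine(
        x, topk_idx, topk_weights, handle,
        return_recv_hook=True,
        overlap=True, src_signals=src_signals, src_signal_expect_value=EXPECT)

# Compute stream: produce each expert's rows of `x`, then release its slot.
# A plain fill after the expert's kernels is enough here because stream order
# guarantees the data is written before the signal store is issued; writing the
# signal from the GEMM epilogue with a GPU-scope release store is the tighter option.
for i in range(num_local_experts):
    run_expert_ffn(i, x)                  # hypothetical per-expert FFN producing x
    src_signals[i : i + 1].fill_(EXPECT)  # releases expert i for sending

with torch.cuda.stream(comm_stream):
    hook()  # receive phase: wait for peer flags, reduce, and write `combined_x`
```

Across micro-batches, one option (an assumption, not something this diff requires) is to pass a monotonically increasing `src_signal_expect_value` instead of re-zeroing `src_signals` each iteration, since the wait loop only compares the slot for equality with the expected value.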