@@ -769,13 +769,11 @@ __forceinline__ __device__ void decode_and_accumulate(uint32_t* ld_buffer, float
769769 }
770770}
771771
772- // TODO unify with original code
773772template <bool kUseLogFMT , int kHidden , int kNumMaxTopk , int kNumMaxUnrolls >
774773__global__
775774__launch_bounds__ (1024 , 1 )
776- // __maxnreg__(48) // rm
777775void
778- combine_v2 (void * combined_x,
776+ combine_overlappable (void * combined_x,
779777 void * rdma_recv_x, int * rdma_recv_flag, void * rdma_send_x,
780778 const void * x, const int64_t * topk_idx, const float * topk_weights,
781779 const int * src_info, const int64_t * layout_range,
@@ -838,7 +836,6 @@ combine_v2(void* combined_x,
838836
839837 // Issue IBGDA sends
840838 if (responsible_expert_idx < num_experts) {
841- // NOTE move tma-related to outside local_expert_idx loop
842839 // ------------------------------------------ START tma-related -------------------------------------------------
843840 // TMA stuff
844841 constexpr int kNumTMABufferBytes = sizeof (int4 ) * 32 * kNumSendUnrolls ;
@@ -897,12 +894,10 @@ combine_v2(void* combined_x,
897894 // NOTE added
898895 if (src_signals != nullptr ) {
899896 // TODO: should the 1st expert be computed separately so that we do *not* need to wait for it?
900- // if ((threadIdx.x == 0) and (local_expert_idx > 0)) {
901897 if (threadIdx .x == 0 ) {
902898 wait_signal (src_signals + local_expert_idx, src_signal_expect_value);
903899 }
904900
905- // TODO original code uses NamedBarrier, better than this?
906901 __syncthreads ();
907902 }
908903
@@ -991,7 +986,6 @@ combine_v2(void* combined_x,
991986 }
992987 }
993988
994- // TODO maybe move to above?
995989 // Put the finishing flag
996990 EP_DEVICE_ASSERT (num_warps_per_group > 1 and num_warp_groups < 16 );
997991 asm volatile (" bar.sync %0, %1;" :: " r" (warp_group_id + 1 ), " r" (num_warps_per_group * 32 ));
@@ -1031,7 +1025,6 @@ combine_v2(void* combined_x,
10311025 }
10321026 }
10331027 }
1034- // if (thread_id % 32 == 0) { printf("[R%d,S%d,T%d] combine phase=send END\n", rank, sm_id, thread_id); }
10351028
10361029 // Receiving phase
10371030 LOW_LATENCY_COMBINE_RECV:
@@ -1188,11 +1181,9 @@ combine_v2(void* combined_x,
11881181 // Flush all stores
11891182 tma_store_wait<0 >();
11901183 }
1191-
1192- // if (thread_id % 32 == 0) { printf("[R%d,S%d,T%d] combine phase=recv END\n", rank, sm_id, thread_id); }
11931184}
11941185
1195- void combine_v2 (void * combined_x,
1186+ void combine_overlappable (void * combined_x,
11961187 void * rdma_recv_x, int * rdma_recv_flag, void * rdma_send_x,
11971188 const void * x, const int64_t * topk_idx, const float * topk_weights,
11981189 const int * src_info, const int64_t * layout_range,
@@ -1245,8 +1236,8 @@ void combine_v2(void* combined_x,
12451236
12461237#define COMBINE_LAUNCH_CASE (hidden ) { \
12471238auto combine_func = use_logfmt ? \
1248- combine_v2 <true , hidden, kNumMaxTopk , kNumMaxUnrolls > : \
1249- combine_v2 <false , hidden, kNumMaxTopk , kNumMaxUnrolls >; \
1239+ combine_overlappable <true , hidden, kNumMaxTopk , kNumMaxUnrolls > : \
1240+ combine_overlappable <false , hidden, kNumMaxTopk , kNumMaxUnrolls >; \
12501241SET_SHARED_MEMORY_FOR_TMA (combine_func); \
12511242LAUNCH_KERNEL (&cfg, combine_func, \
12521243 combined_x, \
@@ -1641,7 +1632,7 @@ void combine(void* combined_x,
16411632 cudaStream_t stream, int phases, bool zero_copy,
16421633 bool overlap, uint32_t * src_signals, uint32_t src_signal_expect_value) {
16431634 if (overlap) {
1644- return combine_v2 (
1635+ return combine_overlappable (
16451636 combined_x,
16461637 rdma_recv_x, rdma_recv_flag, rdma_send_x,
16471638 x, topk_idx, topk_weights,
0 commit comments