Skip to content

Commit c106c12

Browse files
committed
cleanup
1 parent 02e8b6f commit c106c12

File tree

2 files changed

+5
-16
lines changed

2 files changed

+5
-16
lines changed

csrc/kernels/internode_ll.cu

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -769,13 +769,11 @@ __forceinline__ __device__ void decode_and_accumulate(uint32_t* ld_buffer, float
769769
}
770770
}
771771

772-
// TODO unify with original code
773772
template <bool kUseLogFMT, int kHidden, int kNumMaxTopk, int kNumMaxUnrolls>
774773
__global__
775774
__launch_bounds__(1024, 1)
776-
// __maxnreg__(48) // rm
777775
void
778-
combine_v2(void* combined_x,
776+
combine_overlappable(void* combined_x,
779777
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
780778
const void* x, const int64_t* topk_idx, const float* topk_weights,
781779
const int* src_info, const int64_t* layout_range,
@@ -838,7 +836,6 @@ combine_v2(void* combined_x,
838836

839837
// Issue IBGDA sends
840838
if (responsible_expert_idx < num_experts) {
841-
// NOTE move tma-related to outside local_expert_idx loop
842839
// ------------------------------------------ START tma-related -------------------------------------------------
843840
// TMA stuffs
844841
constexpr int kNumTMABufferBytes = sizeof(int4) * 32 * kNumSendUnrolls;
@@ -897,12 +894,10 @@ combine_v2(void* combined_x,
897894
// NOTE added
898895
if (src_signals != nullptr) {
899896
// TODO shall we let 1st expert be separately computed and then do *not* wait for it
900-
// if ((threadIdx.x == 0) and (local_expert_idx > 0)) {
901897
if (threadIdx.x == 0) {
902898
wait_signal(src_signals + local_expert_idx, src_signal_expect_value);
903899
}
904900

905-
// TODO original code uses NamedBarrier, better than this?
906901
__syncthreads();
907902
}
908903

@@ -991,7 +986,6 @@ combine_v2(void* combined_x,
991986
}
992987
}
993988

994-
// TODO maybe move to above?
995989
// Put the finishing flag
996990
EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 16);
997991
asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(num_warps_per_group * 32));
@@ -1031,7 +1025,6 @@ combine_v2(void* combined_x,
10311025
}
10321026
}
10331027
}
1034-
// if (thread_id % 32 == 0) { printf("[R%d,S%d,T%d] combine phase=send END\n", rank, sm_id, thread_id); }
10351028

10361029
// Receiving phase
10371030
LOW_LATENCY_COMBINE_RECV:
@@ -1188,11 +1181,9 @@ combine_v2(void* combined_x,
11881181
// Flush all stores
11891182
tma_store_wait<0>();
11901183
}
1191-
1192-
// if (thread_id % 32 == 0) { printf("[R%d,S%d,T%d] combine phase=recv END\n", rank, sm_id, thread_id); }
11931184
}
11941185

1195-
void combine_v2(void* combined_x,
1186+
void combine_overlappable(void* combined_x,
11961187
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
11971188
const void* x, const int64_t* topk_idx, const float* topk_weights,
11981189
const int* src_info, const int64_t* layout_range,
@@ -1245,8 +1236,8 @@ void combine_v2(void* combined_x,
12451236

12461237
#define COMBINE_LAUNCH_CASE(hidden) { \
12471238
auto combine_func = use_logfmt ? \
1248-
combine_v2<true, hidden, kNumMaxTopk, kNumMaxUnrolls> : \
1249-
combine_v2<false, hidden, kNumMaxTopk, kNumMaxUnrolls>; \
1239+
combine_overlappable<true, hidden, kNumMaxTopk, kNumMaxUnrolls> : \
1240+
combine_overlappable<false, hidden, kNumMaxTopk, kNumMaxUnrolls>; \
12501241
SET_SHARED_MEMORY_FOR_TMA(combine_func); \
12511242
LAUNCH_KERNEL(&cfg, combine_func, \
12521243
combined_x, \
@@ -1641,7 +1632,7 @@ void combine(void* combined_x,
16411632
cudaStream_t stream, int phases, bool zero_copy,
16421633
bool overlap, uint32_t* src_signals, uint32_t src_signal_expect_value) {
16431634
if (overlap) {
1644-
return combine_v2(
1635+
return combine_overlappable(
16451636
combined_x,
16461637
rdma_recv_x, rdma_recv_flag, rdma_send_x,
16471638
x, topk_idx, topk_weights,

csrc/kernels/utils.cuh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -599,8 +599,6 @@ __forceinline__ __device__ T warp_reduce_or(T value) {
599599
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceOr<T>{});
600600
}
601601

602-
// TODO wait once per thread block, not per thread
603-
// TODO correct?
604602
__device__ __forceinline__ void wait_signal(uint32_t* addr, uint32_t expect_value) {
605603
while (true) {
606604
uint32_t ready = 0;

0 commit comments

Comments
 (0)