Commit 4b67064
Add diagnosis module for efficient and precise location of slow rank (#311)
* Add diagnosis module for precise identification of slow ranks
* Add tests for the slow diagnosis module
* Update some comments for diagnose
* Update test case for diagnose
* Strip the diagnose module, thx LyricZhao and sphish
* Update variable name and cumulative wait recv cost, thx sphish
* Remove invalid comments

Signed-off-by: wangfakang <[email protected]>
1 parent b92d0d4 commit 4b67064

6 files changed: +151, -3 lines changed

csrc/deep_ep.cpp (17 additions, 0 deletions)

@@ -1090,6 +1090,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
 std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
 Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                              const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
+                             const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
                              int num_max_dispatch_tokens_per_rank, int num_experts,
                              bool use_fp8, bool round_scale, bool use_ue8m0,
                              bool async, bool return_recv_hook) {
@@ -1110,6 +1111,12 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
         EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks);
     }
 
+    if (dispatch_wait_recv_cost_stats.has_value()) {
+        EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->scalar_type() == torch::kInt64);
+        EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->dim() == 1 and dispatch_wait_recv_cost_stats->is_contiguous());
+        EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->size(0) == num_ranks);
+    }
+
     auto num_tokens = static_cast<int>(x.size(0)), hidden = static_cast<int>(x.size(1));
     auto num_topk = static_cast<int>(topk_idx.size(1));
     auto num_local_experts = num_experts / num_ranks;
@@ -1162,6 +1169,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
                  packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(),
                  packed_recv_count.data_ptr<int>(),
                  cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
+                 dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
                  buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer,
                  buffer.dispatch_rdma_send_buffer,
                  x.data_ptr(), topk_idx.data_ptr<int64_t>(),
@@ -1200,6 +1208,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
 std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
 Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
                             const torch::Tensor& src_info, const torch::Tensor& layout_range,
+                            const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
                             int num_max_dispatch_tokens_per_rank, int num_experts,
                             bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
                             const std::optional<torch::Tensor>& out) {
@@ -1222,6 +1231,13 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
     EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous());
     EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64);
     EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks);
+
+    if (combine_wait_recv_cost_stats.has_value()) {
+        EP_HOST_ASSERT(combine_wait_recv_cost_stats->scalar_type() == torch::kInt64);
+        EP_HOST_ASSERT(combine_wait_recv_cost_stats->dim() == 1 and combine_wait_recv_cost_stats->is_contiguous());
+        EP_HOST_ASSERT(combine_wait_recv_cost_stats->size(0) == num_ranks);
+    }
+
     auto hidden = static_cast<int>(x.size(2));
     auto num_topk = static_cast<int>(topk_weights.size(1));
     auto num_combined_tokens = static_cast<int>(topk_weights.size(0));
@@ -1259,6 +1275,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
                 buffer.combine_rdma_send_buffer,
                 x.data_ptr(), topk_idx.data_ptr<int64_t>(), topk_weights.data_ptr<float>(),
                 src_info.data_ptr<int>(), layout_range.data_ptr<int64_t>(),
+                combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
                 next_clean_meta.first, next_clean_meta.second,
                 num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank,
                 num_topk, num_experts, rank, num_ranks,
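
The host-side asserts above define the contract for the two new optional tensors: 1-D, contiguous, `torch.int64`, one slot per source rank. A minimal allocation sketch that would satisfy them (the group size and device placement are assumptions, not part of this commit):

import torch

num_ranks = 16  # assumption: the EP group size of the running job

# One int64 accumulator per source rank, matching the EP_HOST_ASSERT checks above.
dispatch_wait_recv_cost_stats = torch.zeros(num_ranks, dtype=torch.int64, device='cuda')
combine_wait_recv_cost_stats = torch.zeros(num_ranks, dtype=torch.int64, device='cuda')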

csrc/deep_ep.hpp (2 additions, 0 deletions)

@@ -146,13 +146,15 @@ struct Buffer {
     std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
     low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                          const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
+                         const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
                          int num_max_dispatch_tokens_per_rank, int num_experts,
                          bool use_fp8, bool round_scale, bool use_ue8m0,
                          bool async, bool return_recv_hook);
 
     std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
     low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
                         const torch::Tensor& src_info, const torch::Tensor& layout_range,
+                        const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
                         int num_max_dispatch_tokens_per_rank, int num_experts,
                         bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
                         const std::optional<torch::Tensor>& out = std::nullopt);

csrc/kernels/api.cuh (2 additions, 0 deletions)

@@ -143,6 +143,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int* packed_recv_src_info, int64_t* packed_recv_layout_range,
               int* packed_recv_count,
               int* cumulative_local_expert_recv_stats,
+              int64_t* dispatch_wait_recv_cost_stats,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
@@ -156,6 +157,7 @@ void combine(void* combined_x,
              void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
              const void* x, const int64_t* topk_idx, const float* topk_weights,
              const int* src_info, const int64_t* layout_range,
+             int64_t* combine_wait_recv_cost_stats,
              int* next_clean, int num_next_clean_int,
              int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
              int num_topk, int num_experts, int rank, int num_ranks,

csrc/kernels/internode_ll.cu (17 additions, 0 deletions)

@@ -42,6 +42,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range,
          int* packed_recv_count,
          int* cumulative_local_expert_recv_stats,
+         int64_t* dispatch_wait_recv_cost_stats,
          void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
          const void* x, const int64_t* topk_idx,
          int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
@@ -272,14 +273,20 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         int num_recv_tokens, recv_token_begin_idx;
         EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15);
         if (sub_warp_id == 1 and lane_id == 0) {
+            auto start_time = clock64();
             while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0);
+            auto wait_recv_cost = clock64() - start_time;
             num_recv_tokens = -num_recv_tokens - 1;
             recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
             shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
             shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx;
             recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
             if (cumulative_local_expert_recv_stats != nullptr)
                 atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens);
+
+            if (dispatch_wait_recv_cost_stats != nullptr)
+                atomicAdd(reinterpret_cast<unsigned long long*>(dispatch_wait_recv_cost_stats + src_rank),
+                          wait_recv_cost);
         }
         asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(num_warps_per_group * 32));
         num_recv_tokens = shared_num_recv_tokens[warp_group_id];
@@ -330,6 +337,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
               int* packed_recv_src_info, int64_t* packed_recv_layout_range,
               int* packed_recv_count,
               int* cumulative_local_expert_recv_stats,
+              int64_t* dispatch_wait_recv_cost_stats,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
@@ -368,6 +376,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
                   packed_recv_src_info, packed_recv_layout_range, \
                   packed_recv_count, \
                   cumulative_local_expert_recv_stats, \
+                  dispatch_wait_recv_cost_stats, \
                   rdma_recv_x, rdma_recv_count, rdma_x, \
                   x, topk_idx, \
                   atomic_counter_per_expert, atomic_finish_counter_per_expert, \
@@ -388,6 +397,7 @@ combine(void* combined_x,
         void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
         const void* x, const int64_t* topk_idx, const float* topk_weights,
         const int* src_info, const int64_t* layout_range,
+        int64_t* combine_wait_recv_cost_stats,
         int* next_clean, int num_next_clean_int,
         int* atomic_clean_flag,
         int num_combined_tokens, int hidden, int num_topk,
@@ -618,7 +628,12 @@ combine(void* combined_x,
     if (responsible_expert_idx < num_experts) {
         EP_DEVICE_ASSERT(num_warps_per_group > 1);
         if (sub_warp_id == 0 and lane_id == 0) {
+            auto start_time = clock64();
             while (ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0);
+            auto wait_recv_cost = clock64() - start_time;
+            if (combine_wait_recv_cost_stats != nullptr)
+                atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats
+                          + responsible_expert_idx / num_local_experts), wait_recv_cost);
         }
     }
     cg::this_grid().sync();
@@ -667,6 +682,7 @@ void combine(void* combined_x,
              void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
              const void* x, const int64_t* topk_idx, const float* topk_weights,
              const int* src_info, const int64_t* layout_range,
+             int64_t* combine_wait_recv_cost_stats,
              int* next_clean, int num_next_clean_int,
              int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
              int num_topk, int num_experts, int rank, int num_ranks,
@@ -701,6 +717,7 @@ LAUNCH_KERNEL(&cfg, combine_func, \
               combined_x, \
               rdma_recv_x, rdma_recv_flag, rdma_send_x, \
               x, topk_idx, topk_weights, src_info, layout_range, \
+              combine_wait_recv_cost_stats, \
               next_clean, num_next_clean_int, \
               atomic_clean_flag, \
               num_combined_tokens, hidden, num_topk, \
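
The two kernel hunks above are the heart of the diagnosis module: the receiving lane samples `clock64()` before and after its spin-wait on the RDMA flag, then `atomicAdd`s the cycle delta into the per-source-rank slot (for combine, `responsible_expert_idx / num_local_experts` maps the expert index back to its owning rank). Because the counters hold raw GPU cycle counts rather than wall time, they are best compared relatively. A hedged post-processing sketch reusing the tensors allocated earlier (the 2x-of-baseline threshold is an assumed heuristic, not from this commit):

# Counters hold clock64() cycle deltas, not seconds: compare shares, not absolutes.
total = int(dispatch_wait_recv_cost_stats.sum().item())
if total > 0:
    shares = dispatch_wait_recv_cost_stats.double() / total
    suspect = int(shares.argmax().item())
    suspect_share = float(shares[suspect])
    # Assumed heuristic: a source rank consuming far more than the 1/num_ranks
    # baseline of the total wait cycles is a slow-rank candidate.
    if suspect_share > 2.0 / num_ranks:
        print(f'rank {suspect} looks slow: {suspect_share:.1%} of dispatch wait cycles')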

deep_ep/buffer.py (11 additions, 1 deletion)

@@ -515,6 +515,7 @@ def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden
     def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
                              num_max_dispatch_tokens_per_rank: int, num_experts: int,
                              cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
+                             dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None,
                              use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False,
                              async_finish: bool = False, return_recv_hook: bool = False) -> \
             Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
@@ -535,6 +536,9 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
             cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape
                 `[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance
                 monitoring.
+            dispatch_wait_recv_cost_stats: a cumulative tensor recording the time spent waiting to receive tokens from
+                each rank, which should have shape `[num_ranks]` and be typed as `torch.int64`.
+                This is useful for detecting and precisely localizing slow-rank anomalies.
             use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors.
             round_scale: whether round the scaling factors into power of 2.
             use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`).
@@ -565,6 +569,7 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
         packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \
             self.runtime.low_latency_dispatch(x, topk_idx,
                                               cumulative_local_expert_recv_stats,
+                                              dispatch_wait_recv_cost_stats,
                                               num_max_dispatch_tokens_per_rank, num_experts,
                                               use_fp8, round_scale, use_ue8m0,
                                               async_finish, return_recv_hook)
@@ -579,7 +584,8 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
     # noinspection PyTypeChecker
     def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
                             handle: tuple, use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False,
-                            return_recv_hook: bool = False, out: Optional[torch.Tensor] = None) -> \
+                            return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,
+                            combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \
             Tuple[torch.Tensor, EventOverlap, Callable]:
         """
         A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
@@ -605,6 +611,9 @@ def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weig
                 but **without actually receiving the data**. You must call the received hook to make sure the data's arrival.
                 If you do not set this flag, the kernel will ensure the data's arrival.
             out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly.
+            combine_wait_recv_cost_stats: a cumulative tensor recording the time spent waiting to receive tokens from
+                each rank, which should have shape `[num_ranks]` and be typed as `torch.int64`.
+                This is useful for detecting and precisely localizing slow-rank anomalies.
 
         Returns:
             combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`.
@@ -613,6 +622,7 @@ def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weig
         """
         src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle
         combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range,
+                                                                   combine_wait_recv_cost_stats,
                                                                    num_max_dispatch_tokens_per_rank, num_experts,
                                                                    use_logfmt, zero_copy, async_finish, return_recv_hook,
                                                                    out)
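
Putting the Python API changes together, a hedged end-to-end sketch (`buffer`, `x`, `topk_idx`, `topk_weights`, and `expert_out` are assumed to exist as in the docstrings above; the all-gather step is an illustration, not part of this commit):

import torch
import torch.distributed as dist

# Pass the same accumulators on every iteration; the kernels add into them.
recv_x, recv_count, handle, event, hook = buffer.low_latency_dispatch(
    x, topk_idx, num_max_dispatch_tokens_per_rank, num_experts,
    dispatch_wait_recv_cost_stats=dispatch_wait_recv_cost_stats)
combined_x, event, hook = buffer.low_latency_combine(
    expert_out, topk_idx, topk_weights, handle,
    combine_wait_recv_cost_stats=combine_wait_recv_cost_stats)

# Illustration: gather every rank's local [num_ranks] vector into a full wait
# matrix; a column that is large across all rows points at a globally slow rank.
rows = [torch.empty_like(dispatch_wait_recv_cost_stats) for _ in range(num_ranks)]
dist.all_gather(rows, dispatch_wait_recv_cost_stats)
wait_matrix = torch.stack(rows)  # [num_ranks, num_ranks]: row = waiter, column = sender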
