Skip to content

Commit 2eec484

Browse files
authored
Fine grained overlap (cleaned) (#468)
1 parent 1dddd19 commit 2eec484

File tree

8 files changed

+553
-8
lines changed

8 files changed

+553
-8
lines changed

csrc/deep_ep.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,7 +1277,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
12771277
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
12781278
int num_max_dispatch_tokens_per_rank, int num_experts,
12791279
bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
1280-
const std::optional<torch::Tensor>& out) {
1280+
const std::optional<torch::Tensor>& out,
1281+
bool overlap, const std::optional<torch::Tensor>& src_signals, uint32_t src_signal_expect_value) {
12811282
#ifndef DISABLE_NVSHMEM
12821283
EP_HOST_ASSERT(low_latency_mode);
12831284

@@ -1347,7 +1348,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
13471348
num_topk, num_experts, rank, num_ranks,
13481349
use_logfmt,
13491350
workspace, num_device_sms,
1350-
launch_stream, phases, zero_copy);
1351+
launch_stream, phases, zero_copy,
1352+
overlap, src_signals.has_value() ? src_signals->data_ptr<uint32_t>() : nullptr, src_signal_expect_value);
13511353
};
13521354
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));
13531355

csrc/deep_ep.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ struct Buffer {
180180
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
181181
int num_max_dispatch_tokens_per_rank, int num_experts,
182182
bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
183-
const std::optional<torch::Tensor>& out = std::nullopt);
183+
const std::optional<torch::Tensor>& out = std::nullopt,
184+
bool overlap = false, const std::optional<torch::Tensor>& src_signals = std::nullopt, uint32_t src_signal_expect_value = 0);
184185

185186
torch::Tensor
186187
get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const;

csrc/kernels/api.cuh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ void combine(void* combined_x,
166166
int num_topk, int num_experts, int rank, int num_ranks,
167167
bool use_logfmt,
168168
void* workspace, int num_device_sms,
169-
cudaStream_t stream, int phases, bool zero_copy);
169+
cudaStream_t stream, int phases, bool zero_copy,
170+
bool overlap, uint32_t* src_signals, uint32_t src_signal_expect_value);
170171

171172
} // namespace internode_ll
172173

csrc/kernels/exception.cuh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@ do { \
3131
} while (0)
3232
#endif
3333

34+
#ifndef CU_CHECK
// Evaluates a CUDA driver API call (`cmd`, yielding a `CUresult`) exactly once and
// throws `EPException("CU", file, line, message)` if the result is not `CUDA_SUCCESS`.
//
// The message comes from `cuGetErrorString`. That lookup can itself fail (e.g. for an
// unrecognized error code) and leave the string pointer null; constructing a
// `std::string` from a null pointer is undefined behaviour, so a fixed fallback
// message is used in that case instead.
#define CU_CHECK(cmd) \
do { \
    const CUresult cu_check_result = (cmd); \
    if (cu_check_result != CUDA_SUCCESS) { \
        const char* cu_check_error_str = nullptr; \
        if (cuGetErrorString(cu_check_result, &cu_check_error_str) != CUDA_SUCCESS || \
            cu_check_error_str == nullptr) { \
            cu_check_error_str = "unknown CUDA driver error"; \
        } \
        throw EPException("CU", __FILE__, __LINE__, std::string(cu_check_error_str)); \
    } \
} while (0)
#endif
45+
3446
#ifndef EP_HOST_ASSERT
3547
#define EP_HOST_ASSERT(cond) \
3648
do { \

csrc/kernels/internode_ll.cu

Lines changed: 508 additions & 1 deletion
Large diffs are not rendered by default.

csrc/kernels/utils.cuh

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,10 @@ __device__ __forceinline__ uint32_t elect_one_sync() {
328328
#endif
329329
}
330330

331+
// Issues the PTX instruction `fence.proxy.async.shared::cta` and nothing else.
// Takes no arguments and returns nothing.
//
// NOTE(review): per the PTX ISA this looks like a proxy fence for CTA shared memory
// (generic vs. async proxy) that is only available on SM90-class targets — confirm.
// This helper is defined before the `#ifndef DISABLE_SM90_FEATURES` guard that follows
// it, so verify it is never instantiated in builds that disable SM90 features.
__device__ __forceinline__ void fence_view_async_shared() {
    asm volatile("fence.proxy.async.shared::cta; \n");
}
334+
331335
// TMA PTX instructions
332336
#ifndef DISABLE_SM90_FEATURES
333337

@@ -599,4 +603,20 @@ __forceinline__ __device__ T warp_reduce_or(T value) {
599603
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceOr<T>{});
600604
}
601605

606+
// Blocks the calling thread until the 32-bit word at `addr` equals `expect_value`.
//
// Each poll reads `*addr` with `ld.acquire.gpu.global.u32` (so `addr` is accessed as a
// global-memory address with acquire semantics at GPU scope). When the observed value
// differs, the thread executes `nanosleep.u32 20` before polling again.
//
// There is no timeout: if `expect_value` is never observed, this function never returns.
//
// @param addr          address of the signal word to poll
// @param expect_value  value that must be observed for the wait to finish
__device__ __forceinline__ void wait_signal(uint32_t* addr, uint32_t expect_value) {
    uint32_t observed;
    for (;;) {
        // Acquire load of the current signal value
        asm volatile("ld.acquire.gpu.global.u32 %0, [%1];" : "=r"(observed) : "l"(addr) : "memory");
        if (observed == expect_value)
            break;

        // Not signalled yet: back off briefly before the next poll
        asm volatile("nanosleep.u32 20;");
    }
}
621+
602622
} // namespace deep_ep

deep_ep/buffer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,8 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
721721
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
722722
handle: tuple, use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False,
723723
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,
724-
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \
724+
combine_wait_recv_cost_stats: Optional[torch.Tensor] = None,
725+
overlap: bool = False, src_signals: Optional[torch.Tensor] = None, src_signal_expect_value: int = 0) -> \
725726
Tuple[torch.Tensor, EventOverlap, Callable]:
726727
"""
727728
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
@@ -761,7 +762,8 @@ def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weig
761762
combine_wait_recv_cost_stats,
762763
num_max_dispatch_tokens_per_rank, num_experts,
763764
use_logfmt, zero_copy, async_finish, return_recv_hook,
764-
out)
765+
out,
766+
overlap, src_signals, src_signal_expect_value)
765767
tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
766768
return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook
767769

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def get_extension_deep_ep_cpp():
172172
include_dirs = ['csrc/']
173173
library_dirs = []
174174
nvcc_dlink = []
175-
extra_link_args = []
175+
extra_link_args = ['-lcuda']
176176

177177
# NVSHMEM flags
178178
if disable_nvshmem:

0 commit comments

Comments
 (0)