6 changes: 4 additions & 2 deletions csrc/deep_ep.cpp
@@ -1277,7 +1277,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
-const std::optional<torch::Tensor>& out) {
+const std::optional<torch::Tensor>& out,
+bool overlap, const std::optional<torch::Tensor>& src_signals, uint32_t src_signal_expect_value) {
#ifndef DISABLE_NVSHMEM
EP_HOST_ASSERT(low_latency_mode);

@@ -1347,7 +1348,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id
num_topk, num_experts, rank, num_ranks,
use_logfmt,
workspace, num_device_sms,
-launch_stream, phases, zero_copy);
+launch_stream, phases, zero_copy,
+overlap, src_signals.has_value() ? src_signals->data_ptr<uint32_t>() : nullptr, src_signal_expect_value);
};
launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE));

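Read together with the kernel-side changes further down, the three new arguments appear to let the combine send phase overlap with an upstream producer: when overlap is set, the kernel waits on src_signals until each watched slot reaches src_signal_expect_value before touching the corresponding data. The exact per-slot semantics live in csrc/kernels/internode_ll.cu, whose diff is not rendered below.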
3 changes: 2 additions & 1 deletion csrc/deep_ep.hpp
@@ -180,7 +180,8 @@ struct Buffer {
const std::optional<torch::Tensor>& combine_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_logfmt, bool zero_copy, bool async, bool return_recv_hook,
-const std::optional<torch::Tensor>& out = std::nullopt);
+const std::optional<torch::Tensor>& out = std::nullopt,
+bool overlap = false, const std::optional<torch::Tensor>& src_signals = std::nullopt, uint32_t src_signal_expect_value = 0);

torch::Tensor
get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const;
3 changes: 2 additions & 1 deletion csrc/kernels/api.cuh
@@ -166,7 +166,8 @@ void combine(void* combined_x,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_logfmt,
void* workspace, int num_device_sms,
-cudaStream_t stream, int phases, bool zero_copy);
+cudaStream_t stream, int phases, bool zero_copy,
+bool overlap, uint32_t* src_signals, uint32_t src_signal_expect_value);

} // namespace internode_ll

12 changes: 12 additions & 0 deletions csrc/kernels/exception.cuh
@@ -31,6 +31,18 @@ do { \
} while (0)
#endif

+#ifndef CU_CHECK
+#define CU_CHECK(cmd) \
+do { \
+    CUresult e = (cmd); \
+    if (e != CUDA_SUCCESS) { \
+        const char *error_str = NULL; \
+        cuGetErrorString(e, &error_str); \
+        throw EPException("CU", __FILE__, __LINE__, std::string(error_str)); \
+    } \
+} while (0)
+#endif
+
#ifndef EP_HOST_ASSERT
#define EP_HOST_ASSERT(cond) \
do { \
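CU_CHECK follows the same shape as the file's existing check macros, but wraps CUDA driver-API calls, which return CUresult rather than cudaError_t. A minimal host-side usage sketch, assuming a translation unit that includes <cuda.h> and this header; the include path and the particular driver calls are illustrative, not taken from the PR.

#include <cuda.h>                    // CUDA driver API: CUresult, cuInit, cuDeviceGetAttribute, ...
#include "kernels/exception.cuh"     // CU_CHECK / EPException (path assumed for illustration)

void query_sm_count_sketch() {
    // Any driver-API call returning CUresult can be wrapped; on failure,
    // CU_CHECK looks up the error string and throws an EPException.
    CU_CHECK(cuInit(0));
    CUdevice dev;
    CU_CHECK(cuDeviceGet(&dev, 0));
    int num_sms = 0;
    CU_CHECK(cuDeviceGetAttribute(&num_sms, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, dev));
}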
509 changes: 508 additions & 1 deletion csrc/kernels/internode_ll.cu

Large diffs are not rendered by default.

20 changes: 20 additions & 0 deletions csrc/kernels/utils.cuh
@@ -328,6 +328,10 @@ __device__ __forceinline__ uint32_t elect_one_sync() {
#endif
}

+__device__ __forceinline__ void fence_view_async_shared() {
+    asm volatile("fence.proxy.async.shared::cta; \n" :: );
+}
+
// TMA PTX instructions
#ifndef DISABLE_SM90_FEATURES

@@ -599,4 +603,20 @@ __forceinline__ __device__ T warp_reduce_or(T value) {
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceOr<T>{});
}

+__device__ __forceinline__ void wait_signal(uint32_t* addr, uint32_t expect_value) {
+    while (true) {
+        uint32_t ready = 0;
+        asm volatile("ld.acquire.gpu.global.u32 %0, [%1];"
+                     : "=r"(ready)
+                     : "l"(addr)
+                     : "memory");
+
+        if (ready == expect_value) {
+            return;
+        }
+
+        asm volatile("nanosleep.u32 20;");
+    };
+}
+
} // namespace deep_ep
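wait_signal polls a global flag with an acquire load, backing off with nanosleep between attempts, and fence_view_async_shared issues fence.proxy.async.shared::cta, which orders shared-memory accesses between the generic and the async (TMA) proxy within a block. The acquire in wait_signal only pays off if the writer publishes the flag with a matching release store. Below is a minimal sketch of that writer side, assuming one signal slot per producing block; set_signal, producer_kernel and the ready value 1 are illustrative names, not code from this PR.

#include <cstdint>

// Release-store counterpart to wait_signal(): publish the data first, then
// flip the flag, so the waiter's acquire load also sees the data.
__device__ __forceinline__ void set_signal(uint32_t* addr, uint32_t value) {
    asm volatile("st.release.gpu.global.u32 [%0], %1;"
                 :: "l"(addr), "r"(value)
                 : "memory");
}

// One block per slot: write the slot's chunk to global memory, then release
// the flag. A consumer calling wait_signal(src_signals + slot, 1) before
// reading out[slot * chunk_size ...] observes the finished chunk.
__global__ void producer_kernel(float* out, uint32_t* src_signals, int chunk_size) {
    const int slot = blockIdx.x;
    for (int i = threadIdx.x; i < chunk_size; i += blockDim.x)
        out[slot * chunk_size + i] = static_cast<float>(slot + i);  // placeholder payload
    __syncthreads();
    if (threadIdx.x == 0)
        set_signal(src_signals + slot, 1u);
}

Plain polling on a volatile flag would eventually see the flip too, but without the acquire/release pair the waiter could still read stale payload; the nanosleep keeps spinning warps from saturating the memory pipeline while they wait.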
6 changes: 4 additions & 2 deletions deep_ep/buffer.py
@@ -721,7 +721,8 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor,
handle: tuple, use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False,
return_recv_hook: bool = False, out: Optional[torch.Tensor] = None,
-combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \
+combine_wait_recv_cost_stats: Optional[torch.Tensor] = None,
+overlap: bool = False, src_signals: Optional[torch.Tensor] = None, src_signal_expect_value: int = 0) -> \
Tuple[torch.Tensor, EventOverlap, Callable]:
"""
A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
@@ -761,7 +762,8 @@ def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weig
combine_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank, num_experts,
use_logfmt, zero_copy, async_finish, return_recv_hook,
-out)
+out,
+overlap, src_signals, src_signal_expect_value)
tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x)
return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook

2 changes: 1 addition & 1 deletion setup.py
@@ -103,7 +103,7 @@ def get_extension_deep_ep_cpp():
include_dirs = ['csrc/']
library_dirs = []
nvcc_dlink = []
-extra_link_args = []
+extra_link_args = ['-lcuda']

# NVSHMEM flags
if disable_nvshmem:
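The extra -lcuda is presumably required because the new CU_CHECK macro in csrc/kernels/exception.cuh calls cuGetErrorString, a CUDA driver-API symbol that lives in libcuda rather than in the runtime library the extension already links against.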