
Commit 10e89ee

feat: remove record_stream of normal mode
1 parent 4623c67 commit 10e89ee
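
For context: `Tensor.record_stream` tells PyTorch's caching allocator that a tensor is still in use on a stream other than the one it was allocated on, so its memory is not handed out again too early. This commit replaces that idiom on the normal (non-low-latency) path with explicit reference holding: the Python layer keeps the tensors alive in `EventOverlap.extra_tensors` until the recorded event is synchronized. A minimal sketch of both idioms in plain PyTorch (the stream and tensor names here are illustrative, not DeepEP's):

import torch

comm_stream = torch.cuda.Stream()      # stand-in for DeepEP's comm stream
x = torch.randn(4096, device='cuda')   # allocated on the current stream

comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(comm_stream):
    y = x * 2                          # comm_stream consumes x

# Idiom removed by this commit: per-tensor allocator bookkeeping.
x.record_stream(comm_stream)

# Idiom introduced by this commit: record an event on the producer stream
# and keep Python references alive until the consumer has waited on it.
event = torch.cuda.Event()
event.record(comm_stream)
extra_tensors = (x, y)                 # pins the storage
event.wait()                           # current stream now orders after comm_stream
extra_tensors = None                   # safe to release; allocator may reuse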

3 files changed: +40 −111 lines changed

csrc/deep_ep.cpp

Lines changed: 5 additions & 97 deletions
@@ -320,16 +320,7 @@ Buffer::get_dispatch_layout(
     std::optional<EventHandle> event;
     if (async) {
         event = EventHandle(comm_stream);
-        for (auto& t : {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) {
-            t.record_stream(comm_stream);
-            if (allocate_on_comm_stream)
-                t.record_stream(compute_stream);
-        }
-        for (auto& to : {num_tokens_per_rdma_rank}) {
-            to.has_value() ? to->record_stream(comm_stream) : void();
-            if (allocate_on_comm_stream)
-                to.has_value() ? to->record_stream(compute_stream) : void();
-        }
+        // NOTES: record_stream removed, tensors are now held by Python-layer extra_tensors
     } else {
         stream_wait(compute_stream, comm_stream);
     }
@@ -606,32 +597,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x,
     std::optional<EventHandle> event;
     if (async) {
         event = EventHandle(comm_stream);
-        for (auto& t : {x,
-                        is_token_in_rank,
-                        rank_prefix_matrix,
-                        channel_prefix_matrix,
-                        recv_x,
-                        recv_src_idx,
-                        recv_channel_prefix_matrix,
-                        send_head}) {
-            t.record_stream(comm_stream);
-            if (allocate_on_comm_stream)
-                t.record_stream(compute_stream);
-        }
-        for (auto& to : {x_scales,
-                         topk_idx,
-                         topk_weights,
-                         num_tokens_per_rank,
-                         num_tokens_per_expert,
-                         cached_channel_prefix_matrix,
-                         cached_rank_prefix_matrix,
-                         recv_topk_idx,
-                         recv_topk_weights,
-                         recv_x_scales}) {
-            to.has_value() ? to->record_stream(comm_stream) : void();
-            if (allocate_on_comm_stream)
-                to.has_value() ? to->record_stream(compute_stream) : void();
-        }
+        // NOTES: record_stream removed, tensors are now held by Python-layer extra_tensors
     } else {
         stream_wait(compute_stream, comm_stream);
     }
@@ -774,16 +740,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
     std::optional<EventHandle> event;
     if (async) {
         event = EventHandle(comm_stream);
-        for (auto& t : {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) {
-            t.record_stream(comm_stream);
-            if (allocate_on_comm_stream)
-                t.record_stream(compute_stream);
-        }
-        for (auto& to : {topk_weights, recv_topk_weights, bias_0, bias_1}) {
-            to.has_value() ? to->record_stream(comm_stream) : void();
-            if (allocate_on_comm_stream)
-                to.has_value() ? to->record_stream(compute_stream) : void();
-        }
+        // NOTES: record_stream removed, tensors are now held by Python-layer extra_tensors
     } else {
         stream_wait(compute_stream, comm_stream);
     }
@@ -1121,39 +1078,7 @@ Buffer::internode_dispatch(const torch::Tensor& x,
     std::optional<EventHandle> event;
     if (async) {
         event = EventHandle(comm_stream);
-        for (auto& t : {x,
-                        is_token_in_rank,
-                        recv_x,
-                        rdma_channel_prefix_matrix,
-                        recv_rdma_rank_prefix_sum,
-                        gbl_channel_prefix_matrix,
-                        recv_gbl_rank_prefix_sum}) {
-            t.record_stream(comm_stream);
-            if (allocate_on_comm_stream)
-                t.record_stream(compute_stream);
-        }
-        for (auto& to : {x_scales,
-                         topk_idx,
-                         topk_weights,
-                         num_tokens_per_rank,
-                         num_tokens_per_rdma_rank,
-                         num_tokens_per_expert,
-                         cached_rdma_channel_prefix_matrix,
-                         cached_recv_rdma_rank_prefix_sum,
-                         cached_gbl_channel_prefix_matrix,
-                         cached_recv_gbl_rank_prefix_sum,
-                         recv_topk_idx,
-                         recv_topk_weights,
-                         recv_x_scales,
-                         recv_rdma_channel_prefix_matrix,
-                         recv_gbl_channel_prefix_matrix,
-                         send_rdma_head,
-                         send_nvl_head,
-                         recv_src_meta}) {
-            to.has_value() ? to->record_stream(comm_stream) : void();
-            if (allocate_on_comm_stream)
-                to.has_value() ? to->record_stream(compute_stream) : void();
-        }
+        // NOTES: record_stream removed, tensors are now held by Python-layer extra_tensors
     } else {
         stream_wait(compute_stream, comm_stream);
     }
@@ -1338,24 +1263,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<EventHandle>>
     std::optional<EventHandle> event;
     if (async) {
         event = EventHandle(comm_stream);
-        for (auto& t : {x,
-                        src_meta,
-                        is_combined_token_in_rank,
-                        rdma_channel_prefix_matrix,
-                        rdma_rank_prefix_sum,
-                        gbl_channel_prefix_matrix,
-                        combined_x,
-                        combined_rdma_head,
-                        combined_nvl_head}) {
-            t.record_stream(comm_stream);
-            if (allocate_on_comm_stream)
-                t.record_stream(compute_stream);
-        }
-        for (auto& to : {topk_weights, combined_topk_weights, bias_0, bias_1}) {
-            to.has_value() ? to->record_stream(comm_stream) : void();
-            if (allocate_on_comm_stream)
-                to.has_value() ? to->record_stream(compute_stream) : void();
-        }
+        // NOTES: record_stream removed, tensors are now held by Python-layer extra_tensors
     } else {
         stream_wait(compute_stream, comm_stream);
     }

deep_ep/buffer.py

Lines changed: 30 additions & 13 deletions
@@ -314,7 +314,9 @@ def get_dispatch_layout(self, topk_idx: torch.Tensor, num_experts: int,
         num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \
             self.runtime.get_dispatch_layout(topk_idx, num_experts, getattr(previous_event, 'event', None),
                                              async_finish, allocate_on_comm_stream)
-        return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event)
+        tensors_to_record = (topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank) if async_finish else None
+
+        return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event, tensors_to_record)
 
     # noinspection PyTypeChecker
     def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -386,7 +388,9 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
             recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(
                 x, x_scales, None, None, None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix,
                 expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
-            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
+
+            tensors_to_record = (x, x_scales, is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_x, recv_x_scales, recv_src_idx) if async_finish else None
+            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event, tensors_to_record)
         else:
             assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
             recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \
@@ -395,10 +399,10 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                     expert_alignment, num_worst_tokens, config,
                     getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
             handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)
-            return (
-                recv_x, recv_x_scales
-            ) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(
-                event)
+            tensors_to_record = (x, x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert,
+                                 is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix,
+                                 recv_x, recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, send_head) if async_finish else None
+            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event, tensors_to_record)
 
     # noinspection PyTypeChecker
     def combine(self, x: torch.Tensor, handle: Tuple,
@@ -446,7 +450,8 @@ def combine(self, x: torch.Tensor, handle: Tuple,
                                          channel_prefix_matrix, send_head, config,
                                          getattr(previous_event, 'event',
                                                  None), async_finish, allocate_on_comm_stream)
-        return recv_x, recv_topk_weights, EventOverlap(event)
+        tensors_to_record = (x, topk_weights, bias_0, bias_1, src_idx, rank_prefix_matrix, channel_prefix_matrix, send_head, recv_x, recv_topk_weights) if async_finish else None
+        return recv_x, recv_topk_weights, EventOverlap(event, tensors_to_record)
 
     # noinspection PyTypeChecker
     def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -479,7 +484,11 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                 x, x_scales, topk_idx, topk_weights, None, None, is_token_in_rank, None, num_recv_tokens, num_rdma_recv_tokens,
                 rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
                 expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
-            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
+
+            tensors_to_record = (x, x_scales, is_token_in_rank, recv_x, recv_x_scales,
+                                 rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
+                                 recv_rdma_channel_prefix_matrix, recv_src_meta, send_rdma_head, send_nvl_head) if async_finish else None
+            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event, tensors_to_record)
         else:
             assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
             recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, \
@@ -494,10 +503,15 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
             handle = (is_token_in_rank, rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, recv_rdma_channel_prefix_matrix,
                       recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, recv_src_meta, send_rdma_head,
                       send_nvl_head)
-            return (
-                recv_x, recv_x_scales
-            ) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(
-                event)
+            tensors_to_record = (x, x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert,
+                                 is_token_in_rank, recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights,
+                                 rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
+                                 recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
+                                 recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
+                                 recv_src_meta, send_rdma_head, send_nvl_head) if async_finish else None
+
+            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event, tensors_to_record)
+
 
     # noinspection PyTypeChecker
     def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
@@ -527,7 +541,10 @@ def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
                                          send_rdma_head, send_nvl_head, config,
                                          getattr(previous_event, 'event',
                                                  None), async_finish, allocate_on_comm_stream)
-        return combined_x, combined_topk_weights, EventOverlap(event)
+        tensors_to_record = (x, topk_weights, bias_0, bias_1, src_meta, is_combined_token_in_rank,
+                             rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
+                             send_rdma_head, send_nvl_head, combined_x, combined_topk_weights) if async_finish else None
+        return combined_x, combined_topk_weights, EventOverlap(event, tensors_to_record)
 
     def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> None:
         """

deep_ep/utils.py

Lines changed: 5 additions & 1 deletion
@@ -33,9 +33,12 @@ def __init__(self, event: Optional[EventHandle] = None, extra_tensors: Optional[
     def current_stream_wait(self) -> None:
         """
         The current stream `torch.cuda.current_stream()` waits for the event to be finished.
+        After synchronization completes, tensor references are released to allow memory reuse.
         """
         assert self.event is not None
         self.event.current_stream_wait()
+        # Release tensor references after synchronization is complete
+        self.extra_tensors = None
 
     def __enter__(self) -> Any:
         """
@@ -56,9 +59,10 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
         Utility for overlapping and Python `with` syntax.
 
         Please follow the example in the `__enter__` function.
+        After synchronization completes, tensor references are released to allow memory reuse.
         """
         if self.event is not None:
-            self.event.current_stream_wait()
+            self.current_stream_wait()
 
 
 def check_nvlink_connections(group: dist.ProcessGroup):
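
Why releasing in `current_stream_wait` is sufficient: once the current stream has waited on the event, any later allocation that reuses the freed blocks is ordered after the communication work that last touched them. A minimal pure-PyTorch sketch of this retention pattern, using `torch.cuda.Event` as a stand-in for the C++ `EventHandle`; the class name `_EventOverlapSketch` is made up for illustration:

import torch
from typing import Optional, Tuple

class _EventOverlapSketch:
    """Illustrative stand-in for EventOverlap's retention behavior."""

    def __init__(self, event: Optional[torch.cuda.Event] = None,
                 extra_tensors: Optional[Tuple[torch.Tensor, ...]] = None):
        self.event = event
        # Holding the tuple keeps the tensors' storage from being reused
        # by the caching allocator while communication may still touch it.
        self.extra_tensors = extra_tensors

    def current_stream_wait(self) -> None:
        assert self.event is not None
        self.event.wait()           # current stream orders after the event
        self.extra_tensors = None   # now safe to let the allocator reuse

comm_stream = torch.cuda.Stream()
x = torch.randn(4096, device='cuda')
comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(comm_stream):
    y = x + 1                       # "communication" work on a side stream
ev = torch.cuda.Event()
ev.record(comm_stream)
overlap = _EventOverlapSketch(ev, (x, y))
del x, y                            # caller may drop its references early
overlap.current_stream_wait()       # safe release point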
