@@ -314,7 +314,9 @@ def get_dispatch_layout(self, topk_idx: torch.Tensor, num_experts: int,
         num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \
             self.runtime.get_dispatch_layout(topk_idx, num_experts, getattr(previous_event, 'event', None),
                                              async_finish, allocate_on_comm_stream)
-        return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event)
+        tensors_to_record = (topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank) if async_finish else None
+
+        return num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, EventOverlap(event, tensors_to_record)
 
     # noinspection PyTypeChecker
     def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
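The change is the same in every hunk below: when `async_finish` is set, the communication kernels are still running on a separate stream when the Python call returns, so each function now threads every tensor those kernels may still touch into `EventOverlap` instead of letting the caller's references lapse. A minimal sketch of what the receiving side presumably looks like (the `extra_tensors` parameter name and the use of a plain `torch.cuda.Event` are assumptions for illustration; the real class wraps the runtime's native event handle):

from typing import Optional, Tuple
import torch

class EventOverlap:
    # Sketch: pairs a CUDA event with references to tensors that the
    # still-running communication kernels may read or write, so the caching
    # allocator cannot recycle their memory before the event completes.
    def __init__(self, event: Optional[torch.cuda.Event] = None,
                 extra_tensors: Optional[Tuple[torch.Tensor, ...]] = None) -> None:
        self.event = event
        self.extra_tensors = extra_tensors  # held only to extend lifetime, never read

    def current_stream_wait(self) -> None:
        # Block the current stream (not the host) until the event fires;
        # afterwards the retained tensors are safe to release and reuse.
        assert self.event is not None
        torch.cuda.current_stream().wait_event(self.event)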
@@ -386,7 +388,9 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
             recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch(
                 x, x_scales, None, None, None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix,
                 expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
-            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
+
+            tensors_to_record = (x, x_scales, is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_x, recv_x_scales, recv_src_idx) if async_finish else None
+            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event, tensors_to_record)
         else:
             assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
             recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \
@@ -395,10 +399,10 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                     expert_alignment, num_worst_tokens, config,
                     getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
             handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head)
-            return (
-                recv_x, recv_x_scales
-            ) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(
-                event)
+            tensors_to_record = (x, x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert,
+                                 is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix,
+                                 recv_x, recv_x_scales, recv_src_idx, recv_topk_idx, recv_topk_weights, send_head) if async_finish else None
+            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event, tensors_to_record)
 
     # noinspection PyTypeChecker
     def combine(self, x: torch.Tensor, handle: Tuple,
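Both dispatch paths now return an `EventOverlap` that owns its inputs and outputs for the duration of the transfer. A hypothetical caller-side flow for the async path (keyword argument names follow the variables visible in the hunks but are assumed, and `buffer` is assumed to be an already-initialized Buffer instance):

num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event = \
    buffer.get_dispatch_layout(topk_idx, num_experts, async_finish=True)
event.current_stream_wait()  # layout tensors are safe to use from here on

recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
    buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
                    num_tokens_per_rank=num_tokens_per_rank,
                    num_tokens_per_expert=num_tokens_per_expert,
                    is_token_in_rank=is_token_in_rank, async_finish=True)
# ... enqueue unrelated GPU work here to overlap with the dispatch ...
event.current_stream_wait()  # recv_x is valid once the current stream resumes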
@@ -446,7 +450,8 @@ def combine(self, x: torch.Tensor, handle: Tuple,
                                                        channel_prefix_matrix, send_head, config,
                                                        getattr(previous_event, 'event',
                                                                None), async_finish, allocate_on_comm_stream)
-        return recv_x, recv_topk_weights, EventOverlap(event)
+        tensors_to_record = (x, topk_weights, bias_0, bias_1, src_idx, rank_prefix_matrix, channel_prefix_matrix, send_head, recv_x, recv_topk_weights) if async_finish else None
+        return recv_x, recv_topk_weights, EventOverlap(event, tensors_to_record)
 
     # noinspection PyTypeChecker
     def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
@@ -479,7 +484,11 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                 x, x_scales, topk_idx, topk_weights, None, None, is_token_in_rank, None, num_recv_tokens, num_rdma_recv_tokens,
                 rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
                 expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
-            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event)
+
+            tensors_to_record = (x, x_scales, is_token_in_rank, recv_x, recv_x_scales,
+                                 rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
+                                 recv_rdma_channel_prefix_matrix, recv_src_meta, send_rdma_head, send_nvl_head) if async_finish else None
+            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event, tensors_to_record)
         else:
             assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None
             recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, \
@@ -494,10 +503,15 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
             handle = (is_token_in_rank, rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, recv_rdma_channel_prefix_matrix,
                       recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, recv_src_meta, send_rdma_head,
                       send_nvl_head)
-            return (
-                recv_x, recv_x_scales
-            ) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(
-                event)
+            tensors_to_record = (x, x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert,
+                                 is_token_in_rank, recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights,
+                                 rdma_channel_prefix_matrix, gbl_channel_prefix_matrix,
+                                 recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum,
+                                 recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum,
+                                 recv_src_meta, send_rdma_head, send_nvl_head) if async_finish else None
+
+            return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event, tensors_to_record)
+
 
     # noinspection PyTypeChecker
     def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
@@ -527,7 +541,10 @@ def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list],
                                                             send_rdma_head, send_nvl_head, config,
                                                             getattr(previous_event, 'event',
                                                                     None), async_finish, allocate_on_comm_stream)
-        return combined_x, combined_topk_weights, EventOverlap(event)
+        tensors_to_record = (x, topk_weights, bias_0, bias_1, src_meta, is_combined_token_in_rank,
+                             rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix,
+                             send_rdma_head, send_nvl_head, combined_x, combined_topk_weights) if async_finish else None
+        return combined_x, combined_topk_weights, EventOverlap(event, tensors_to_record)
 
     def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> None:
         """