@@ -721,7 +721,8 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
721721 def low_latency_combine (self , x : torch .Tensor , topk_idx : torch .Tensor , topk_weights : torch .Tensor ,
722722 handle : tuple , use_logfmt : bool = False , zero_copy : bool = False , async_finish : bool = False ,
723723 return_recv_hook : bool = False , out : Optional [torch .Tensor ] = None ,
724- combine_wait_recv_cost_stats : Optional [torch .Tensor ] = None ) -> \
724+ combine_wait_recv_cost_stats : Optional [torch .Tensor ] = None ,
725+ overlap : bool = False , src_signals : Optional [torch .Tensor ] = None , src_signal_expect_value : int = 0 ) -> \
725726 Tuple [torch .Tensor , EventOverlap , Callable ]:
726727 """
727728 A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA.
@@ -761,7 +762,8 @@ def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weig
761762 combine_wait_recv_cost_stats ,
762763 num_max_dispatch_tokens_per_rank , num_experts ,
763764 use_logfmt , zero_copy , async_finish , return_recv_hook ,
764- out )
765+ out ,
766+ overlap , src_signals , src_signal_expect_value )
765767 tensors_to_record = (x , topk_idx , topk_weights , src_info , layout_range , combined_x )
766768 return combined_x , EventOverlap (event , tensors_to_record if async_finish else None ), hook
767769
0 commit comments