@@ -40,15 +40,42 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None):
 def fused_allreduce_residual_rmsnorm(
     tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    """Fusing allreduce, residual (add), and hf_rms_norm together."""
-    all_reduce_params = AllReduceParams(
-        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-        bias=None,
-        residual=residual,
-        norm_weight=norm_weight,
-        eps=eps,
-    )
-    return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params)
+    """Fusing allreduce, residual (add), and hf_rms_norm together.
+
+    When TRT-LLM ops are available (MPI mode), uses the fused kernel.
+    Otherwise, falls back to separate operations using torch distributed.
+    """
+    # Only use the TRT-LLM fused op when running with MPI
+    if is_trtllm_op_available():
+        all_reduce_params = AllReduceParams(
+            fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+            bias=None,
+            residual=residual,
+            norm_weight=norm_weight,
+            eps=eps,
+        )
+        return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params)
+    else:
+        # Fallback: unfused implementation using torch distributed.
+        # This path is used in demollm mode without MPI.
+        from .common import all_reduce as torch_all_reduce
+
+        # 1. All-reduce the tensor
+        tensor_reduced = tensor.clone()
+        torch_all_reduce(tensor_reduced, op=ReduceOp.SUM)
+
+        # 2. Add the residual
+        tensor_with_residual = tensor_reduced + residual
+
+        # 3. Apply RMSNorm using PyTorch's built-in function
+        norm_out = torch.nn.functional.rms_norm(
+            tensor_with_residual,
+            normalized_shape=(tensor_with_residual.size(-1),),
+            weight=norm_weight,
+            eps=eps,
+        )
+
+        return norm_out, tensor_with_residual
 
 @fused_allreduce_residual_rmsnorm.register_fake
 def fused_allreduce_residual_rmsnorm_fake(
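
For reference, the fallback path's ordering (residual add first, then RMSNorm over the last dimension) is meant to mirror what the fused `RESIDUAL_RMS_NORM` kernel computes on the reduced tensor. Below is a minimal single-process sketch, with illustrative shapes and tensor names that are not part of this change, checking that the built-in `torch.nn.functional.rms_norm` (available in recent PyTorch releases) agrees with a hand-rolled RMSNorm after the residual add:

```python
import torch

hidden = 16
x = torch.randn(2, hidden)        # stands in for the already all-reduced tensor
residual = torch.randn(2, hidden)
weight = torch.randn(hidden)
eps = 1e-5

# Unfused ordering from the else-branch: add residual, then RMSNorm.
with_residual = x + residual
out = torch.nn.functional.rms_norm(
    with_residual, normalized_shape=(hidden,), weight=weight, eps=eps
)

# Hand-rolled RMSNorm for comparison: y = x / sqrt(mean(x^2) + eps) * weight
ref = with_residual * torch.rsqrt(with_residual.pow(2).mean(-1, keepdim=True) + eps) * weight

assert torch.allclose(out, ref, atol=1e-5)
```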