
Commit 3ac11a6

[NVIDIA#9152][fix] AutoDeploy fused_allreduce_residual_rmsnorm to support demollm mode (NVIDIA#9197)
Signed-off-by: Eran Geva <[email protected]>
1 parent f0b68e4 commit 3ac11a6

File tree

  • tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py

1 file changed: +36 −9 lines

tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py

Lines changed: 36 additions & 9 deletions
@@ -40,15 +40,42 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None):
 def fused_allreduce_residual_rmsnorm(
     tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    """Fusing allreduce, residual (add), and hf_rms_norm together."""
-    all_reduce_params = AllReduceParams(
-        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
-        bias=None,
-        residual=residual,
-        norm_weight=norm_weight,
-        eps=eps,
-    )
-    return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params)
+    """Fusing allreduce, residual (add), and hf_rms_norm together.
+
+    When TRT-LLM ops are available (MPI mode), uses the fused kernel.
+    Otherwise, falls back to separate operations using torch distributed.
+    """
+    # Only use TRT-LLM fused op when running with MPI
+    if is_trtllm_op_available():
+        all_reduce_params = AllReduceParams(
+            fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+            bias=None,
+            residual=residual,
+            norm_weight=norm_weight,
+            eps=eps,
+        )
+        return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params)
+    else:
+        # Fallback: unfused implementation using torch distributed
+        # This is used in demollm mode without MPI
+        from .common import all_reduce as torch_all_reduce
+
+        # 1. All-reduce the tensor
+        tensor_reduced = tensor.clone()
+        torch_all_reduce(tensor_reduced, op=ReduceOp.SUM)
+
+        # 2. Add residual
+        tensor_with_residual = tensor_reduced + residual
+
+        # 3. Apply RMSNorm using PyTorch's built-in function
+        norm_out = torch.nn.functional.rms_norm(
+            tensor_with_residual,
+            normalized_shape=(tensor_with_residual.size(-1),),
+            weight=norm_weight,
+            eps=eps,
+        )
+
+        return norm_out, tensor_with_residual
 
 @fused_allreduce_residual_rmsnorm.register_fake
 def fused_allreduce_residual_rmsnorm_fake(
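
For reference, here is a minimal single-process sketch (not part of the commit) of the math the demollm fallback performs. With world_size == 1 the SUM all-reduce is an identity, so the path reduces to a residual add followed by RMSNorm. It assumes PyTorch >= 2.4, where torch.nn.functional.rms_norm is available, and checks the result against a manual RMSNorm reference.

import torch

hidden = 64
tensor = torch.randn(2, hidden)
residual = torch.randn(2, hidden)
norm_weight = torch.randn(hidden)
eps = 1e-5

# With a single rank, the all-reduce (SUM) leaves the tensor unchanged,
# so the fallback boils down to: residual add, then RMSNorm.
tensor_with_residual = tensor + residual
norm_out = torch.nn.functional.rms_norm(
    tensor_with_residual,
    normalized_shape=(hidden,),
    weight=norm_weight,
    eps=eps,
)

# Manual RMSNorm reference: x * rsqrt(mean(x^2) + eps) * weight
ref = (
    tensor_with_residual
    * torch.rsqrt(tensor_with_residual.pow(2).mean(-1, keepdim=True) + eps)
    * norm_weight
)
assert torch.allclose(norm_out, ref, atol=1e-6)

# Like the fallback, keep both outputs: the normalized activations and
# the pre-norm sum, matching the two outputs of the fused TRT-LLM kernel.
outputs = (norm_out, tensor_with_residual)

The two-tuple return mirrors the fused kernel's contract: downstream layers consume the normalized output, while the pre-norm sum feeds the next residual connection.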
