Skip to content

Commit 1dff84b

Browse files
committed
fixed rebase, fixed rms norm fusion to use the correct strategy, enhanced rms test to check strategy
Signed-off-by: Eran Geva <[email protected]>
1 parent f08e8a3 commit 1dff84b

File tree

4 files changed

+96
-30
lines changed

4 files changed

+96
-30
lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_dist.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,22 @@ def trtllm_allreduce(tensor, op, strategy: str, all_reduce_params=None):
3030
rank, world_size = get_rank_world_size()
3131
assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
3232

33-
# Cache key includes rank, world_size, and dtype to handle different configurations
34-
cache_key = (rank, world_size, tensor.dtype)
33+
# Convert string strategy to enum
34+
try:
35+
strategy_enum = getattr(AllReduceStrategy, strategy)
36+
except AttributeError:
37+
raise ValueError(
38+
f"Invalid allreduce strategy: {strategy}. "
39+
f"Valid options: AUTO, NCCL, ONESHOT, TWOSHOT, MIN_LATENCY, "
40+
f"LOWPRECISION, UB, MNNVL, NCCL_SYMMETRIC"
41+
)
42+
43+
# Cache key includes rank, world_size, dtype, and strategy to handle different configurations
44+
cache_key = (rank, world_size, tensor.dtype, strategy_enum)
3545
if cache_key not in _allreduce_cache:
3646
p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
37-
# Use Strategy.AUTO for optimal performance
3847
_allreduce_cache[cache_key] = AllReduce(
39-
mapping=p_config, strategy=strategy, dtype=tensor.dtype
48+
mapping=p_config, strategy=strategy_enum, dtype=tensor.dtype
4049
)
4150

4251
torch_op = _allreduce_cache[cache_key]
@@ -87,7 +96,11 @@ def trtllm_dist_all_reduce_fake(tensor, strategy):
8796
"dist::trtllm_fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda"
8897
)
8998
def trtllm_fused_allreduce_residual_rmsnorm(
90-
tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
99+
tensor: torch.Tensor,
100+
residual: torch.Tensor,
101+
norm_weight: torch.Tensor,
102+
eps: float,
103+
strategy: str,
91104
) -> tuple[torch.Tensor, torch.Tensor]:
92105
"""Fused allreduce + residual + rmsnorm using TRT-LLM optimized kernel.
93106
@@ -100,12 +113,18 @@ def trtllm_fused_allreduce_residual_rmsnorm(
100113
norm_weight=norm_weight,
101114
eps=eps,
102115
)
103-
return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params)
116+
return trtllm_allreduce(
117+
tensor, ReduceOp.SUM, strategy=strategy, all_reduce_params=all_reduce_params
118+
)
104119

105120

106121
@trtllm_fused_allreduce_residual_rmsnorm.register_fake
107122
def trtllm_fused_allreduce_residual_rmsnorm_fake(
108-
tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
123+
tensor: torch.Tensor,
124+
residual: torch.Tensor,
125+
norm_weight: torch.Tensor,
126+
eps: float,
127+
strategy: str,
109128
) -> tuple[torch.Tensor, torch.Tensor]:
110129
return torch.empty_like(tensor), torch.empty_like(tensor)
111130

tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
The torch backend (demollm mode) does not benefit from fusion.
66
"""
77

8+
from functools import partial
89
from typing import Tuple
910

1011
import torch
@@ -28,11 +29,14 @@
2829
# ============================================================================
2930

3031

31-
def _make_allreduce_residual_rmsnorm_pattern(add_order: str = "residual_first"):
32+
def _make_allreduce_residual_rmsnorm_pattern(
33+
add_order: str = "residual_first", strategy: str = "AUTO"
34+
):
3235
"""Factory function to create pattern functions for allreduce+residual+rmsnorm fusion.
3336
3437
Args:
3538
add_order: Either "residual_first" (residual + x) or "x_first" (x + residual)
39+
strategy: AllReduce strategy to use in the pattern
3640
3741
Returns:
3842
A pattern function that can be used with register_ad_pattern
@@ -50,7 +54,7 @@ def pattern_fn(
5054
Returns (normed, z)
5155
"""
5256
input_dtype = x.dtype
53-
hidden_states = torch.ops.auto_deploy.trtllm_dist_all_reduce(x)
57+
hidden_states = torch.ops.auto_deploy.trtllm_dist_all_reduce(x, strategy)
5458

5559
# Handle addition order
5660
if add_order == "residual_first":
@@ -70,10 +74,12 @@ def pattern_fn(
7074

7175

7276
def _allreduce_residual_rmsnorm_replacement(
73-
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float
77+
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float, strategy: str
7478
):
7579
"""Replacement using TRT-LLM fused kernel."""
76-
return torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm(x, residual, weight, eps)
80+
return torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm(
81+
x, residual, weight, eps, strategy
82+
)
7783

7884

7985
# ============================================================================
@@ -115,19 +121,22 @@ def _apply(
115121
# Instantiate Pattern Functions
116122
# ============================================================================
117123

124+
# Get the allreduce strategy from shared_config
125+
strategy = shared_config.sharding_config.allreduce_strategy.name
126+
118127
# TRT-LLM backend (MPI mode) - two patterns for different addition orders
119128
_allreduce_residual_rmsnorm_pattern_trtllm = _make_allreduce_residual_rmsnorm_pattern(
120-
add_order="residual_first"
129+
add_order="residual_first", strategy=strategy
121130
)
122131
_allreduce_residual_rmsnorm_pattern2_trtllm = _make_allreduce_residual_rmsnorm_pattern(
123-
add_order="x_first"
132+
add_order="x_first", strategy=strategy
124133
)
125134

126135
# Register TRT-LLM backend patterns only (no torch backend fusion)
127136
# Pattern 1: residual + allreduce(x)
128137
register_ad_pattern(
129138
search_fn=_allreduce_residual_rmsnorm_pattern_trtllm,
130-
replace_fn=_allreduce_residual_rmsnorm_replacement,
139+
replace_fn=partial(_allreduce_residual_rmsnorm_replacement, strategy=strategy),
131140
patterns=patterns,
132141
dummy_args=dummy_args,
133142
op_ignore_types=op_ignore_types,
@@ -137,7 +146,7 @@ def _apply(
137146
# Pattern 2: allreduce(x) + residual
138147
register_ad_pattern(
139148
search_fn=_allreduce_residual_rmsnorm_pattern2_trtllm,
140-
replace_fn=_allreduce_residual_rmsnorm_replacement,
149+
replace_fn=partial(_allreduce_residual_rmsnorm_replacement, strategy=strategy),
141150
patterns=patterns,
142151
dummy_args=dummy_args,
143152
op_ignore_types=op_ignore_types,

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ def validate_allreduce_strategy(v):
5858
if isinstance(v, int):
5959
return AllReduceStrategy(v)
6060
return v # Let Pydantic handle other types
61+
62+
6163
def _get_dist_ops(backend: str):
6264
"""Get the appropriate distributed ops based on backend availability.
6365
@@ -585,7 +587,7 @@ def _shard_parameter_node(
585587

586588
# add reduction node
587589
with gm.graph.inserting_after(node):
588-
dist_node = gm.graph.call_function(fn_dist, args=dist_args)
590+
dist_node = gm.graph.call_function(fn_dist, args=(node,) + tuple(dist_args))
589591
node.replace_all_uses_with(dist_node)
590592
dist_node.replace_input_with(dist_node, node)
591593

@@ -1232,7 +1234,9 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
12321234

12331235
def apply(self, gm: GraphModule, node: Node) -> None:
12341236
"""Apply EP sharding transformation to the graph module."""
1235-
_insert_sharded_moe(gm, node, self.rank, self.world_size, self.allreduce_strategy, self.dist_backend, [])
1237+
_insert_sharded_moe(
1238+
gm, node, self.rank, self.world_size, self.allreduce_strategy, self.dist_backend, []
1239+
)
12361240

12371241

12381242
class MXFP4EPShardingInfo(EPShardingInfo):
@@ -1246,7 +1250,9 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool:
12461250
return True
12471251

12481252
def apply(self, gm: GraphModule, node: Node) -> None:
1249-
_insert_sharded_mxfp4_mlp_ep(gm, node, self.rank, self.world_size, self.allreduce_strategy, self.dist_backend)
1253+
_insert_sharded_mxfp4_mlp_ep(
1254+
gm, node, self.rank, self.world_size, self.allreduce_strategy, self.dist_backend
1255+
)
12501256

12511257

12521258
class FP8EPShardingInfo(EPShardingInfo, QuantizationShardingMixin):
@@ -1263,7 +1269,13 @@ def scale_names(self) -> List[str]:
12631269

12641270
def apply(self, gm: GraphModule, node: Node) -> None:
12651271
_insert_sharded_moe(
1266-
gm, node, self.rank, self.world_size, self.allreduce_strategy, self.dist_backend, self.scale_names()
1272+
gm,
1273+
node,
1274+
self.rank,
1275+
self.world_size,
1276+
self.allreduce_strategy,
1277+
self.dist_backend,
1278+
self.scale_names(),
12671279
)
12681280

12691281

@@ -1281,7 +1293,13 @@ def scale_names(self) -> List[str]:
12811293

12821294
def apply(self, gm: GraphModule, node: Node) -> None:
12831295
_insert_sharded_moe(
1284-
gm, node, self.rank, self.world_size, self.allreduce_strategy, self.dist_backend, self.scale_names()
1296+
gm,
1297+
node,
1298+
self.rank,
1299+
self.world_size,
1300+
self.allreduce_strategy,
1301+
self.dist_backend,
1302+
self.scale_names(),
12851303
)
12861304

12871305

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,13 @@ def forward(self, hidden_states: torch.Tensor):
3232
class AllreduceResidualNorm(torch.nn.Module):
3333
"""AllreduceResidualNorm pattern model that do residual plus x"""
3434

35-
def __init__(self, hidden_size, dtype):
35+
def __init__(self, hidden_size, dtype, strategy):
3636
super().__init__()
3737
self.norm = RMSNorm(hidden_size, 1e-5, dtype)
38+
self.strategy = strategy
3839

3940
def forward(self, x, residual):
40-
x = torch.ops.auto_deploy.torch_dist_all_reduce.default(x, "AUTO")
41+
x = torch.ops.auto_deploy.trtllm_dist_all_reduce.default(x, self.strategy)
4142
y = residual + x
4243
normed = self.norm(y)
4344
return normed, y
@@ -46,18 +47,19 @@ def forward(self, x, residual):
4647
class AllreduceResidualNorm2(torch.nn.Module):
4748
"""AllreduceResidualNorm pattern model that do x plus residual"""
4849

49-
def __init__(self, hidden_size, dtype):
50+
def __init__(self, hidden_size, dtype, strategy):
5051
super().__init__()
5152
self.norm = RMSNorm(hidden_size, 1e-5, dtype)
53+
self.strategy = strategy
5254

5355
def forward(self, x, residual):
54-
x = torch.ops.auto_deploy.torch_dist_all_reduce.default(x, "AUTO")
56+
x = torch.ops.auto_deploy.trtllm_dist_all_reduce.default(x, self.strategy)
5557
y = x + residual
5658
normed = self.norm(y)
5759
return normed, y
5860

5961

60-
def _test_allreduce_fusion(port: int, ModuleCls):
62+
def _test_allreduce_fusion(port: int, ModuleCls, strategy: str):
6163
if not is_trtllm_op_available():
6264
pytest.skip("Require trtllm ops to run test_allreduce_fusion.")
6365

@@ -69,7 +71,7 @@ def _test_allreduce_fusion(port: int, ModuleCls):
6971
residual = torch.randn(16, 16).to(dtype).cuda()
7072

7173
# Trace the original model
72-
model = ModuleCls(16, dtype)
74+
model = ModuleCls(16, dtype, strategy=strategy)
7375
args = (
7476
x,
7577
residual,
@@ -78,10 +80,14 @@ def _test_allreduce_fusion(port: int, ModuleCls):
7880
# Run the original
7981
original_outputs, residual_original = gm(x, residual)
8082

81-
# Fuse ops
83+
# Fuse ops with the specified strategy
8284
gm_transformed = InferenceOptimizer(
8385
None,
8486
{
87+
"detect_sharding": {
88+
"stage": "post_export",
89+
"allreduce_strategy": strategy,
90+
},
8591
"fuse_allreduce_residual_rmsnorm": {
8692
"stage": "post_load_fusion",
8793
},
@@ -91,12 +97,21 @@ def _test_allreduce_fusion(port: int, ModuleCls):
9197
# Run the fused graph
9298
fused_outputs, residual_fused = gm_transformed(x, residual)
9399

94-
# Check if fused node in the graph
100+
# Check if fused node in the graph and verify strategy
95101
has_fused_node = False
102+
fused_node_strategy = None
96103
for node in gm_transformed.graph.nodes:
97104
if is_op(node, torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm):
98105
has_fused_node = True
106+
# The fused node should have the strategy as the last argument
107+
# args: (x, residual, weight, eps, strategy)
108+
if len(node.args) >= 5:
109+
fused_node_strategy = node.args[4]
110+
99111
assert has_fused_node, "Fused node not found."
112+
assert fused_node_strategy == strategy, (
113+
f"Fused node strategy mismatch: expected '{strategy}', got '{fused_node_strategy}'"
114+
)
100115

101116
# Verify outputs are consistent
102117
assert torch.allclose(residual_original, residual_fused, atol=1e-5), (
@@ -117,11 +132,16 @@ def _test_allreduce_fusion(port: int, ModuleCls):
117132
[AllreduceResidualNorm, AllreduceResidualNorm2],
118133
ids=["residual_plus_x", "x_plus_residual"],
119134
)
120-
def test_allreduce_fusion(device_count, ModuleCls):
135+
@pytest.mark.parametrize(
136+
"strategy",
137+
["AUTO", "NCCL", "ONESHOT"],
138+
ids=["strategy_auto", "strategy_nccl", "strategy_oneshot"],
139+
)
140+
def test_allreduce_fusion(device_count, ModuleCls, strategy):
121141
if device_count <= 1:
122142
pytest.skip("Require multi GPUs to run test_allreduce_fusion.")
123143
port = dist.get_free_port()
124144

125145
n_workers = device_count
126146
mpi_pool = MpiPoolSession(n_workers=n_workers)
127-
mpi_pool.submit_sync(_test_allreduce_fusion, port=port, ModuleCls=ModuleCls)
147+
mpi_pool.submit_sync(_test_allreduce_fusion, port=port, ModuleCls=ModuleCls, strategy=strategy)

0 commit comments

Comments
 (0)