Commit e96a3d2

[None][autodeploy] minor refactor to rmsnorm transforms (#8657)
Signed-off-by: Fridah-nv <[email protected]>
1 parent 12f339f commit e96a3d2

File tree

5 files changed: 62 additions & 78 deletions

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 2 additions & 6 deletions
@@ -122,15 +122,11 @@ transforms:
   fuse_allreduce_residual_rmsnorm:
     stage: post_load_fusion
   # TODO (lucaslie): add backend selection as part of configurable inference optimizers
-  # check if we can fuse rmsnorm
   fuse_rmsnorm:
-    # TODO (lucaslie): add backend selection as part of configurable inference optimizers
-    # check if we can fuse rmsnorm
     stage: post_load_fusion
-    backend: triton
+    rmsnorm_backend: triton
+    gated_rmsnorm_backend: triton
     requires_shape_prop: true
-  fuse_gated_rmsnorm:
-    stage: post_load_fusion
 
   ############################################################################################
   # VISUALIZE GRAPH
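
For reference, the two knobs introduced above can also be set per-transform from Python, mirroring how the updated unit test below passes them; a minimal sketch (the backend values shown are the documented choices, not necessarily this commit's defaults):

# Hypothetical override of the fuse_rmsnorm entry above, in the dict form
# the updated unit test below uses.
transform_config = {
    "fuse_rmsnorm": {
        "stage": "post_load_fusion",
        "rmsnorm_backend": "flashinfer",  # 'flashinfer', 'triton', or 'torch'
        "gated_rmsnorm_backend": "triton",  # currently only 'triton' is supported
    },
}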

tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py

Lines changed: 4 additions & 4 deletions
@@ -83,8 +83,8 @@ def _(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
     return torch.empty_like(input)
 
 
-@torch.library.custom_op("auto_deploy::torch_rmsnorm_gated", mutates_args=())
-def torch_rmsnorm_gated(
+@torch.library.custom_op("auto_deploy::triton_rmsnorm_gated", mutates_args=())
+def triton_rmsnorm_gated(
     x: torch.Tensor,
     weight: torch.Tensor,
     gate: torch.Tensor | None,
@@ -140,8 +140,8 @@ def torch_rmsnorm_gated(
     return out2.reshape(x_shape)
 
 
-@torch_rmsnorm_gated.register_fake
-def _torch_rmsnorm_gated_meta(
+@triton_rmsnorm_gated.register_fake
+def _triton_rmsnorm_gated_meta(
     x,
     weight,
     gate,
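
After the rename, call sites address the op under its new name; a minimal invocation sketch based on the signature visible in this diff and in the updated unit test below (shapes, device, and the registering import are assumptions):

import torch

# Registering the op requires importing the module that defines it; the path
# below follows the file shown in this diff (import side effects assumed).
import tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm  # noqa: F401

# Shapes are illustrative; the hidden dim must be divisible by group_size,
# and the Triton kernel is assumed to need a CUDA device.
x = torch.randn(2, 3, 4096, device="cuda")
w = torch.randn(4096, device="cuda")
g = torch.randn(2, 3, 4096, device="cuda")

# Signature per this diff: (x, weight, gate, eps, group_size, norm_before_gate).
y = torch.ops.auto_deploy.triton_rmsnorm_gated(x, w, g, 1e-5, 512, False)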

tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py

Lines changed: 52 additions & 65 deletions
@@ -90,23 +90,28 @@ def _rms_norm_replacement(
 class FuseRMSNormConfig(TransformConfig):
     """Configuration for the RMSNorm fusion transform."""
 
-    backend: str = Field(
+    rmsnorm_backend: str = Field(
         default="flashinfer",
-        description="Backend to use for RMSNorm computation ('flashinfer' or 'triton').",
+        description="Backend to use for RMSNorm computation ('flashinfer', 'triton', or 'torch').",
+    )
+    gated_rmsnorm_backend: str = Field(
+        default="triton",
+        description="Backend to use for gated RMSNorm computation (currently only 'triton').",
     )
 
 
 @TransformRegistry.register("fuse_rmsnorm")
 class FuseRMSNorm(BaseTransform):
-    """Matches and replaces RMSNorm patterns in the graph with FlashInfer or Triton implementation.
+    """Matches and replaces RMSNorm patterns (regular and gated) in the graph with optimized implementations.
 
-    This function sets up pattern matching to identify RMSNorm operations in the graph
+    This function sets up pattern matching to identify both regular and gated RMSNorm operations in the graph
     and replaces them with optimized implementations. It uses dummy tensors to register
     the pattern matching rules.
 
     Args:
         gm: Input graph module to transform.
-        backend: Backend to use for RMSNorm computation ("flashinfer" or "triton").
+        rmsnorm_backend: Backend to use for regular RMSNorm computation ("flashinfer", "triton", or "torch").
+        gated_rmsnorm_backend: Backend to use for gated RMSNorm computation (currently only "triton").
 
     Returns:
         Transformed graph module with optimized RMSNorm operations.
@@ -125,15 +130,23 @@ def _apply(
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
-        if self.config.backend.lower() not in _BACKEND_OPS:
+        # Validate rmsnorm_backend
+        if self.config.rmsnorm_backend.lower() not in _BACKEND_OPS:
+            raise ValueError(
+                f"Invalid rmsnorm_backend, must be one of {list(_BACKEND_OPS)}, got {self.config.rmsnorm_backend}"
+            )
+
+        # Validate gated_rmsnorm_backend (currently only triton is supported)
+        if self.config.gated_rmsnorm_backend.lower() != "triton":
             raise ValueError(
-                f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {self.config.backend}"
+                f"""Invalid gated_rmsnorm_backend, currently only 'triton' is supported,
+                got {self.config.gated_rmsnorm_backend}"""
             )
 
         graph = gm.graph
         patterns = ADPatternMatcherPass()
 
-        # Create dummy tensors for pattern matching
+        # Pattern matching for regular RMSNorm
         bs = 2
         hidden_size = 512
 
@@ -160,13 +173,42 @@ def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float =
         for input_dtype, weight_dtype in configs:
             register_ad_pattern(
                 search_fn=search_fn,
-                replace_fn=partial(_rms_norm_replacement, backend=self.config.backend),
+                replace_fn=partial(_rms_norm_replacement, backend=self.config.rmsnorm_backend),
                 patterns=patterns,
                 dummy_args=dummy_args(input_dtype, weight_dtype),
                 op_ignore_types={},
                 scalar_workaround={"eps": 1e-6},
             )
 
+        # Pattern matching for gated RMSNorm
+        B, S, H = 2, 3, 4096
+        group_size = 512
+        eps = 1e-5
+
+        def make_dummy_args_gated(group_size: int, eps: float) -> list:
+            x = torch.randn(B, S, H, dtype=torch.float32)
+            w = torch.randn(H, dtype=torch.float32)
+            g = torch.randn(B, S, H, dtype=torch.float32)
+            return [x, w, g, eps, group_size]
+
+        op_ignore_types = {
+            torch.ops.aten.reshape.default: (int, list, tuple),
+            torch.ops.aten.view.default: (int, list, tuple),
+            torch.ops.aten.mean.dim: (list, tuple),
+            torch.ops.aten.to.dtype: (torch.dtype,),
+        }
+
+        # Register pattern for gated RMSNorm
+        register_ad_pattern(
+            search_fn=_gated_rmsnorm_pattern_ref,
+            replace_fn=_gated_rmsnorm_replacement,
+            patterns=patterns,
+            dummy_args=make_dummy_args_gated(group_size, eps),
+            op_ignore_types=op_ignore_types,
+            scalar_workaround={"eps": eps, "group_size": group_size},
+            skip_duplicates=True,
+        )
+
         cnt = patterns.apply(graph)
 
         info = TransformInfo(
@@ -204,61 +246,6 @@ def _gated_rmsnorm_replacement(
     eps: float,
     group_size: int,
 ) -> torch.Tensor:
-    return torch.ops.auto_deploy.torch_rmsnorm_gated(
+    return torch.ops.auto_deploy.triton_rmsnorm_gated(
         x, weight, gate, float(eps), int(group_size), False
     )
-
-
-@TransformRegistry.register("fuse_gated_rmsnorm")
-class FuseGatedRMSNorm(BaseTransform):
-    """
-    Fuse the NemotronH-style gated RMSNorm subgraph into a single custom op:
-        auto_deploy::torch_rmsnorm_gated(x, weight, gate, eps, group_size, norm_before_gate=False)
-    """
-
-    def _apply(
-        self,
-        gm: GraphModule,
-        cm: CachedSequenceInterface,
-        factory: ModelFactory,
-        shared_config: SharedConfig,
-    ) -> Tuple[GraphModule, TransformInfo]:
-        graph = gm.graph
-        patterns = ADPatternMatcherPass()
-
-        B, S, H = 2, 3, 4096
-        group_size = 512
-        eps = 1e-5
-
-        def make_dummy_args(group_size: int, eps: float) -> list:
-            x = torch.randn(B, S, H, dtype=torch.float32)
-            w = torch.randn(H, dtype=torch.float32)
-            g = torch.randn(B, S, H, dtype=torch.float32)
-            return [x, w, g, eps, group_size]
-
-        op_ignore_types = {
-            torch.ops.aten.reshape.default: (int, list, tuple),
-            torch.ops.aten.view.default: (int, list, tuple),
-            torch.ops.aten.mean.dim: (list, tuple),
-            torch.ops.aten.to.dtype: (torch.dtype,),
-        }
-
-        register_ad_pattern(
-            search_fn=_gated_rmsnorm_pattern_ref,
-            replace_fn=partial(_gated_rmsnorm_replacement),
-            patterns=patterns,
-            dummy_args=make_dummy_args(group_size, eps),
-            op_ignore_types=op_ignore_types,
-            scalar_workaround={"eps": eps, "group_size": group_size},
-            skip_duplicates=True,
-        )
-
-        num = patterns.apply(graph)
-
-        info = TransformInfo(
-            skipped=False,
-            num_matches=num,
-            is_clean=num == 0,
-            has_valid_shapes=num == 0,
-        )
-        return gm, info
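
For context on what the merged gated branch matches, here is a rough pure-PyTorch reference of NemotronH-style gated RMSNorm with norm_before_gate=False, as described in the removed docstring. This is a sketch only; the canonical pattern is _gated_rmsnorm_pattern_ref in this file, and the group-wise reduction semantics are an assumption:

import torch
import torch.nn.functional as F


def gated_rmsnorm_ref(x, weight, gate, eps=1e-5, group_size=512):
    # norm_before_gate=False: apply the SiLU gate first, then normalize.
    if gate is not None:
        x = x * F.silu(gate)
    shape = x.shape
    # Assumed group-wise RMS: split the hidden dim into groups of group_size.
    xg = x.reshape(*shape[:-1], -1, group_size).float()
    rms = torch.rsqrt(xg.pow(2).mean(dim=-1, keepdim=True) + eps)
    # Output stays fp32, matching the "currently returns fp32" note in the test below.
    return (xg * rms).reshape(shape) * weight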

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_mamba_rms_norm.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ def test_custom_op_matches_ref(B, T, H, group, use_gate, dtype):
     )
 
     # Custom op (currently returns fp32). Cast it back to x.dtype for apples-to-apples with ref.
-    y_op_fp32 = torch.ops.auto_deploy.torch_rmsnorm_gated(x, w, z, 1e-5, group, False)
+    y_op_fp32 = torch.ops.auto_deploy.triton_rmsnorm_gated(x, w, z, 1e-5, group, False)
     y_op = y_op_fp32.to(x.dtype)
 
     assert y_ref.dtype == x.dtype and y_op.dtype == x.dtype

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fuse_rmsnorm.py

Lines changed: 3 additions & 2 deletions
@@ -66,7 +66,8 @@ def checker(gm):
         {
             "fuse_rmsnorm": {
                 "stage": "post_load_fusion",
-                "backend": variant,
+                "gated_rmsnorm_backend": "triton",
+                "rmsnorm_backend": variant,
             },
         },
     )(None, gm)
@@ -102,4 +103,4 @@ def test_rmsnorm_fusion(eps, variant, op):
 def test_rmsnorm_fusion_nemotron_h():
     # Only the triton backend supports the nemotron h rmsnorm
     model = TestModel(eps=1e-6, use_nemotron_h=True)
-    _run_test(model, torch.ops.auto_deploy.triton_rms_norm, "triton")
+    _run_test(model, torch.ops.auto_deploy.triton_rms_norm, variant="triton")
