@@ -90,23 +90,28 @@ def _rms_norm_replacement(
 class FuseRMSNormConfig(TransformConfig):
     """Configuration for the RMSNorm fusion transform."""
 
-    backend: str = Field(
+    rmsnorm_backend: str = Field(
         default="flashinfer",
-        description="Backend to use for RMSNorm computation ('flashinfer' or 'triton').",
+        description="Backend to use for RMSNorm computation ('flashinfer', 'triton', or 'torch').",
+    )
+    gated_rmsnorm_backend: str = Field(
+        default="triton",
+        description="Backend to use for gated RMSNorm computation (currently only 'triton').",
     )
 
 
 @TransformRegistry.register("fuse_rmsnorm")
 class FuseRMSNorm(BaseTransform):
-    """Matches and replaces RMSNorm patterns in the graph with FlashInfer or Triton implementation.
+    """Matches and replaces RMSNorm patterns (regular and gated) in the graph with optimized implementations.
 
-    This function sets up pattern matching to identify RMSNorm operations in the graph
+    This transform sets up pattern matching to identify both regular and gated RMSNorm operations in the graph
     and replaces them with optimized implementations. It uses dummy tensors to register
     the pattern matching rules.
 
     Args:
         gm: Input graph module to transform.
-        backend: Backend to use for RMSNorm computation ("flashinfer" or "triton").
+        rmsnorm_backend: Backend to use for regular RMSNorm computation ("flashinfer", "triton", or "torch").
+        gated_rmsnorm_backend: Backend to use for gated RMSNorm computation (currently only "triton").
 
     Returns:
         Transformed graph module with optimized RMSNorm operations.
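
With `backend` split into two fields, anything that configured this transform by the old name needs updating. A minimal usage sketch, assuming `TransformConfig` subclasses construct like pydantic models (the surrounding pipeline wiring is not shown in this diff):

```python
# Hypothetical construction; field names come from the diff above.
cfg = FuseRMSNormConfig(rmsnorm_backend="triton")
print(cfg.rmsnorm_backend)        # "triton"
print(cfg.gated_rmsnorm_backend)  # "triton" (default)

# The old spelling would no longer be a declared field:
# FuseRMSNormConfig(backend="triton")
```
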
@@ -125,15 +130,23 @@ def _apply(
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
-        if self.config.backend.lower() not in _BACKEND_OPS:
+        # Validate rmsnorm_backend
+        if self.config.rmsnorm_backend.lower() not in _BACKEND_OPS:
+            raise ValueError(
+                f"Invalid rmsnorm_backend, must be one of {list(_BACKEND_OPS)}, got {self.config.rmsnorm_backend}"
+            )
+
+        # Validate gated_rmsnorm_backend (currently only 'triton' is supported)
+        if self.config.gated_rmsnorm_backend.lower() != "triton":
             raise ValueError(
-                f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {self.config.backend}"
142+ f"""Invalid gated_rmsnorm_backend, currently only 'triton' is supported,
143+ got { self .config .gated_rmsnorm_backend } """
             )
 
         graph = gm.graph
         patterns = ADPatternMatcherPass()
 
-        # Create dummy tensors for pattern matching
+        # Pattern matching for regular RMSNorm
         bs = 2
         hidden_size = 512
 
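
The `dummy_args` helper named in the next hunk's header is cut off by the diff context. A hedged reconstruction of what it plausibly builds, assuming it mirrors the `bs`/`hidden_size` dummies above (the body here is a guess, not the actual source):

```python
import torch

bs, hidden_size = 2, 512

# Hypothetical body: produce the (input, weight, eps) tracing arguments
# for one (input_dtype, weight_dtype) combination.
def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6) -> list:
    x = torch.randn(bs, hidden_size, dtype=input_dtype)
    w = torch.randn(hidden_size, dtype=weight_dtype)
    return [x, w, eps]
```
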
@@ -160,13 +173,42 @@ def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float =
         for input_dtype, weight_dtype in configs:
             register_ad_pattern(
                 search_fn=search_fn,
-                replace_fn=partial(_rms_norm_replacement, backend=self.config.backend),
+                replace_fn=partial(_rms_norm_replacement, backend=self.config.rmsnorm_backend),
                 patterns=patterns,
                 dummy_args=dummy_args(input_dtype, weight_dtype),
                 op_ignore_types={},
                 scalar_workaround={"eps": 1e-6},
             )
 
+        # Pattern matching for gated RMSNorm
+        B, S, H = 2, 3, 4096
+        group_size = 512
+        eps = 1e-5
+
+        def make_dummy_args_gated(group_size: int, eps: float) -> list:
+            x = torch.randn(B, S, H, dtype=torch.float32)
+            w = torch.randn(H, dtype=torch.float32)
+            g = torch.randn(B, S, H, dtype=torch.float32)
+            return [x, w, g, eps, group_size]
+
+        op_ignore_types = {
+            torch.ops.aten.reshape.default: (int, list, tuple),
+            torch.ops.aten.view.default: (int, list, tuple),
+            torch.ops.aten.mean.dim: (list, tuple),
+            torch.ops.aten.to.dtype: (torch.dtype,),
+        }
+
+        # Register pattern for gated RMSNorm
+        register_ad_pattern(
+            search_fn=_gated_rmsnorm_pattern_ref,
+            replace_fn=_gated_rmsnorm_replacement,
+            patterns=patterns,
+            dummy_args=make_dummy_args_gated(group_size, eps),
+            op_ignore_types=op_ignore_types,
+            scalar_workaround={"eps": eps, "group_size": group_size},
+            skip_duplicates=True,
+        )
+
         cnt = patterns.apply(graph)
 
         info = TransformInfo(
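
The `op_ignore_types` entries (reshape/view arguments, `mean.dim` dims, `to.dtype` targets) make the match shape- and dtype-agnostic, which fits a NemotronH/Mamba2-style grouped, gated RMSNorm. A hedged sketch of what `_gated_rmsnorm_pattern_ref` (defined elsewhere in this file) presumably computes; the exact op order in the real pattern may differ:

```python
import torch
import torch.nn.functional as F

def gated_rmsnorm_ref(x, weight, gate, eps: float, group_size: int) -> torch.Tensor:
    # Gate first, matching norm_before_gate=False in the replacement below.
    x = x * F.silu(gate)
    # RMS-normalize within groups of `group_size` channels.
    *lead, hidden = x.shape
    xg = x.reshape(*lead, hidden // group_size, group_size)
    var = xg.to(torch.float32).pow(2).mean(-1, keepdim=True)
    xg = xg * torch.rsqrt(var + eps)
    return xg.reshape(*lead, hidden).to(x.dtype) * weight
```

With the dummy shapes used here (H=4096, group_size=512), each token is normalized over 8 groups of 512 channels.
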
@@ -204,61 +246,6 @@ def _gated_rmsnorm_replacement(
     eps: float,
     group_size: int,
 ) -> torch.Tensor:
-    return torch.ops.auto_deploy.torch_rmsnorm_gated(
+    return torch.ops.auto_deploy.triton_rmsnorm_gated(
         x, weight, gate, float(eps), int(group_size), False
     )
-
-
-@TransformRegistry.register("fuse_gated_rmsnorm")
-class FuseGatedRMSNorm(BaseTransform):
-    """
-    Fuse the NemotronH-style gated RMSNorm subgraph into a single custom op:
-    auto_deploy::torch_rmsnorm_gated(x, weight, gate, eps, group_size, norm_before_gate=False)
-    """
-
-    def _apply(
-        self,
-        gm: GraphModule,
-        cm: CachedSequenceInterface,
-        factory: ModelFactory,
-        shared_config: SharedConfig,
-    ) -> Tuple[GraphModule, TransformInfo]:
-        graph = gm.graph
-        patterns = ADPatternMatcherPass()
-
-        B, S, H = 2, 3, 4096
-        group_size = 512
-        eps = 1e-5
-
-        def make_dummy_args(group_size: int, eps: float) -> list:
-            x = torch.randn(B, S, H, dtype=torch.float32)
-            w = torch.randn(H, dtype=torch.float32)
-            g = torch.randn(B, S, H, dtype=torch.float32)
-            return [x, w, g, eps, group_size]
-
-        op_ignore_types = {
-            torch.ops.aten.reshape.default: (int, list, tuple),
-            torch.ops.aten.view.default: (int, list, tuple),
-            torch.ops.aten.mean.dim: (list, tuple),
-            torch.ops.aten.to.dtype: (torch.dtype,),
-        }
-
-        register_ad_pattern(
-            search_fn=_gated_rmsnorm_pattern_ref,
-            replace_fn=partial(_gated_rmsnorm_replacement),
-            patterns=patterns,
-            dummy_args=make_dummy_args(group_size, eps),
-            op_ignore_types=op_ignore_types,
-            scalar_workaround={"eps": eps, "group_size": group_size},
-            skip_duplicates=True,
-        )
-
-        num = patterns.apply(graph)
-
-        info = TransformInfo(
-            skipped=False,
-            num_matches=num,
-            is_clean=num == 0,
-            has_valid_shapes=num == 0,
-        )
-        return gm, info
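
With `FuseGatedRMSNorm` removed, gated matches are handled by `fuse_rmsnorm` and rewritten straight to the Triton op. A hedged smoke test, assuming the custom op is registered and a CUDA device is available; the argument order (x, weight, gate, eps, group_size, norm_before_gate) is taken from the replacement above:

```python
import torch

x = torch.randn(2, 3, 4096, dtype=torch.float32, device="cuda")
w = torch.randn(4096, dtype=torch.float32, device="cuda")
g = torch.randn_like(x)

y = torch.ops.auto_deploy.triton_rmsnorm_gated(x, w, g, 1e-5, 512, False)
assert y.shape == x.shape
```
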