@@ -54,6 +54,71 @@ pnnx.Output output 1 0 out
5454 }
5555};
5656
57+ class fuse_rmsnorm_pass_without_gamma : public GraphRewriterPass
58+ {
59+ public:
60+ const char * match_pattern_graph () const
61+ {
62+ return R"PNNXIR( 7767517
63+ 5 4
64+ pnnx.Input input 0 1 input
65+ pnnx.Expression op_0 1 1 input sq expr=pow(@0,2)
66+ torch.mean op_1 1 1 sq sqmean dim=(-1) keepdim=True
67+ pnnx.Expression op_2 2 1 input sqmean out expr=mul(@0,rsqrt(add(@1,%eps)))
68+ pnnx.Output output 1 0 out
69+ )PNNXIR" ;
70+ }
71+
72+ const char * type_str () const
73+ {
74+ return " nn.RMSNorm" ;
75+ }
76+
77+ const char * name_str () const
78+ {
79+ return " t5ln" ;
80+ }
81+
82+ bool match (const std::map<std::string, const Operator*>& matched_operators, const std::map<std::string, Parameter>& /* captured_params*/ , const std::map<std::string, Attribute>& /* captured_attrs*/ ) const
83+ {
84+ const Operator* op_0 = matched_operators.at (" op_0" );
85+ const std::vector<int >& shape = op_0->inputs [0 ]->shape ;
86+ if (shape.empty ())
87+ {
88+ // unknown normalized_shape
89+ return false ;
90+ }
91+
92+ return true ;
93+ }
94+
95+ void write (Operator* op, const std::map<std::string, Parameter>& captured_params) const
96+ {
97+ const std::vector<int >& shape = op->inputs [0 ]->shape ;
98+ const int c = shape[shape.size () - 1 ];
99+
100+ op->params [" elementwise_affine" ] = false ;
101+ op->params [" eps" ] = captured_params.at (" eps" );
102+ op->params [" normalized_shape" ] = std::vector<int >{c};
103+ }
104+ };
105+
106+ class fuse_rmsnorm_pass_without_gamma_1 : public fuse_rmsnorm_pass_without_gamma
107+ {
108+ public:
109+ const char * match_pattern_graph () const
110+ {
111+ return R"PNNXIR( 7767517
112+ 5 4
113+ pnnx.Input input 0 1 input
114+ pnnx.Expression op_0 1 1 input sq expr=pow(@0,2)
115+ torch.mean op_1 1 1 sq sqmean dim=(-1) keepdim=True
116+ pnnx.Expression op_2 2 1 input sqmean out expr=mul(@0,reciprocal(sqrt(add(@1,%eps))))
117+ pnnx.Output output 1 0 out
118+ )PNNXIR" ;
119+ }
120+ };
121+
57122class fuse_rmsnorm_pass_onnx : public fuse_rmsnorm_pass
58123{
59124public:
@@ -75,11 +140,15 @@ void fuse_rmsnorm(Graph& graph)
75140{
76141 fuse_rmsnorm_pass a;
77142 fuse_rmsnorm_pass_1 a1;
143+ fuse_rmsnorm_pass_without_gamma a2;
144+ fuse_rmsnorm_pass_without_gamma_1 a3;
78145 fuse_rmsnorm_pass_onnx b;
79146 int opindex = 0 ;
80147
81148 pnnx_graph_rewrite (graph, &a, opindex);
82149 pnnx_graph_rewrite (graph, &a1, opindex);
150+ pnnx_graph_rewrite (graph, &a2, opindex);
151+ pnnx_graph_rewrite (graph, &a3, opindex);
83152 pnnx_graph_rewrite (graph, &b, opindex);
84153}
85154
0 commit comments