@@ -34,7 +34,15 @@ void ReplaceTransposeWithReshape(std::shared_ptr<ngraph::Node> transpose_node) {
3434 transpose_node->output (0 ).replace (reshape_node->output (0 ));
3535}
3636
37- void InsertTranspose (std::shared_ptr<ngraph::Node> prev_node, const std::string& base_name) {
37+ void InsertTranspose (std::shared_ptr<ngraph::Node> prev_node, const std::string& base_name, bool before_matmul) {
38+ auto create_reshape = [](const ngraph::Shape& shape, std::shared_ptr<ngraph::Node> input_node, const std::string& name) {  // Local helper: builds a Reshape of input_node to `shape`, sets its friendly name to `name`, and returns the new node.
39+ auto reshape_const = std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64 ,  // i64 constant holding the target dimensions
40+ ngraph::Shape{shape.size ()}, shape);  // 1-D constant: one element per dimension of `shape`
41+ auto node = std::make_shared<ngraph::opset8::Reshape>(input_node, reshape_const, false );  // NOTE(review): the trailing `false` is presumably Reshape's special_zero flag - confirm against the opset8 API
42+ node->set_friendly_name (name);
43+ return node;  // caller is responsible for wiring the returned Reshape into the graph
44+ };
45+
3846 auto consumers = prev_node->output (0 ).get_target_inputs ();
3947 const auto orig_shape = prev_node->get_output_shape (0 );
4048 std::vector<size_t > transpose_ids;
@@ -48,18 +56,29 @@ void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::string&
4856 std::iota (std::begin (permute_order), std::end (permute_order), 0 );
4957 std::swap (permute_order[transpose_ids[0 ]], permute_order[transpose_ids[1 ]]);
5058
59+ ngraph::NodeVector new_ops;
60+ std::shared_ptr<ngraph::Node> node = prev_node;
61+ if (!before_matmul) {
62+ auto shape = prev_node->get_output_shape (0 );
63+ std::swap (shape[0 ], shape[1 ]);
64+ node = create_reshape (shape, node, base_name + " /reshape_before_transpose" );
65+ new_ops.push_back (node);
66+ }
67+
5168 auto transpose_order = ngraph::opset8::Constant::create (ngraph::element::i64 , ngraph::Shape{permute_order.size ()}, permute_order);
52- auto transpose = std::make_shared<ngraph::opset8::Transpose>(prev_node, transpose_order);
53- transpose->set_friendly_name (base_name + " /in_transpose" );
69+ node = std::make_shared<ngraph::opset8::Transpose>(node, transpose_order);
70+ node->set_friendly_name (base_name + " /in_transpose" );
71+ new_ops.push_back (node);
5472
55- auto reshapeConstAfter = std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64 ,
56- ngraph::Shape{orig_shape.size ()}, orig_shape);
57- auto reshapeAfter = std::make_shared<ngraph::opset8::Reshape>(transpose, reshapeConstAfter, false );
58- reshapeAfter->set_friendly_name (base_name + " /reshape_after_transpose" );
59- ngraph::copy_runtime_info (prev_node, ngraph::NodeVector{transpose, reshapeAfter});
73+ if (before_matmul) {
74+ node = create_reshape (orig_shape, node, base_name + " /reshape_after_transpose" );
75+ new_ops.push_back (node);
76+ }
77+
78+ ngraph::copy_runtime_info (prev_node, new_ops);
6079
6180 for (auto input : consumers) {
62- input.replace_source_output (reshapeAfter );
81+ input.replace_source_output (node );
6382 }
6483}
6584
@@ -94,24 +113,25 @@ HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() {
94113 return false ;
95114 }
96115
116+ auto matmul_node = matmul_iter->second .get_node_shared_ptr ();
97117 auto transpose_reshape_it = pattern_map.find (transpose);
98118 if (transpose_reshape_it != std::end (pattern_map)) {
99119 ReplaceTransposeWithReshape (transpose_reshape_it->second .get_node_shared_ptr ());
100120 } else if ((transpose_reshape_it = pattern_map.find (reshape)) != std::end (pattern_map)) {
101121 auto reshape_node = pattern_map.at (reshape).get_node_shared_ptr ();
102122 if (GNALimitations::IsTransposeSupported (reshape_node->get_output_shape (0 ))) {
103- auto matmul_node = matmul_iter->second .get_node_shared_ptr ();
104- InsertTranspose (reshape_node, matmul_node->get_friendly_name ());
123+ InsertTranspose (reshape_node, matmul_node->get_friendly_name (), true );
105124 }
106125 }
107126
127+ // Transpose the constant input if it's the first input
108128 auto iter = pattern_map.find (fq);
109129 if (iter != pattern_map.end () ||
110130 (iter = pattern_map.find (constant)) != pattern_map.end ()) {
111131 auto prev_node = iter->second .get_node_shared_ptr ();
112- if (! GNALimitations::IsTransposeSupported (prev_node->get_output_shape (0 ))) return false ;
113- auto matmul_node = iter-> second . get_node_shared_ptr ( );
114- InsertTranspose (prev_node, matmul_node-> get_friendly_name ());
132+ if (GNALimitations::IsTranspose2d (prev_node->get_output_shape (0 ))) {
133+ InsertTranspose (prev_node, prev_node-> get_friendly_name (), true );
134+ }
115135 }
116136 return true ;
117137 };
@@ -129,7 +149,11 @@ HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
129149 auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add_left, add_right});
130150 auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_input, ngraph::pattern::any_input (),
131151 ngraph::pattern::any_input (), ngraph::pattern::any_input (), ngraph::pattern::any_input ()});
132- auto transpose_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fq_input, fq});
152+ auto act_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fq_input, fq});
153+ auto act = ngraph::pattern::wrap_type<ngraph::opset8::Relu, ngraph::opset8::Sigmoid,
154+ ngraph::opset8::Tanh, ngraph::opset8::Abs, ngraph::opset8::Log, ngraph::opset8::Exp,
155+ ngraph::opset8::Sign, ngraph::opset8::Clamp>({act_input});
156+ auto transpose_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{act_input, act});
133157 auto transpose = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({transpose_input, ngraph::pattern::any_input ()});
134158 auto reshape_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{transpose_input, transpose});
135159 auto reshape = ngraph::pattern::wrap_type<ngraph::opset8::Reshape>(
@@ -142,16 +166,17 @@ HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
142166 ReplaceTransposeWithReshape (transpose_it->second .get_node_shared_ptr ());
143167 } else {
144168 auto reshape_node = pattern_map.at (reshape).get_node_shared_ptr ();
145- if (!GNALimitations::IsTransposeSupported (reshape_node->get_output_shape (0 ))) return false ;
146- auto iter = pattern_map.find (fq );
169+ if (!GNALimitations::IsTransposeSupported (reshape_node->get_input_shape (0 ))) return false ;
170+ auto iter = pattern_map.find (act );
147171 if (iter == pattern_map.end () &&
172+ (iter = pattern_map.find (fq)) == pattern_map.end () &&
148173 (iter = pattern_map.find (add_left)) == pattern_map.end () &&
149174 (iter = pattern_map.find (add_right)) == pattern_map.end () &&
150175 (iter = pattern_map.find (matmul)) == pattern_map.end ()) {
151176 return false ;
152177 }
153178 auto node = iter->second .get_node_shared_ptr ();
154- InsertTranspose (node, node->get_friendly_name ());
179+ InsertTranspose (node, node->get_friendly_name (), false );
155180 }
156181 return true ;
157182 };
0 commit comments