
Commit 6feb981

[GNA] Fixed handling of transposes around MatMul (openvinotoolkit#8600)
* [GNA] Fixed handling of transposes around MatMul
* [GNA] Fixed swap matmul inputs tests
* [GNA] Comments applying
1 parent f06aa99 commit 6feb981


9 files changed: +423 -95 lines changed


inference-engine/src/gna_plugin/backend/gna_limitations.hpp

Lines changed: 5 additions & 1 deletion
@@ -31,10 +31,14 @@ constexpr uint32_t maxPoolMaxWindowSize = 6;
 constexpr uint32_t copyMaxGrouping = 8;
 constexpr uint32_t transposeMaxSize = 65528;
 
+inline bool IsTranspose2d(const std::vector<size_t>& shape) {
+    return std::count_if(std::begin(shape), std::end(shape), [](size_t dim) { return dim != 1; }) == 2;
+}
+
 inline bool IsTransposeSupported(const std::vector<size_t>& shape) {
+    if (!IsTranspose2d(shape)) return false;
     auto shape_no_1 = shape;
     shape_no_1.erase(std::remove(shape_no_1.begin(), shape_no_1.end(), 1), shape_no_1.end());
-    if (shape_no_1.size() != 2) return false;
     size_t min, max;
     std::tie(min, max) = std::minmax(shape_no_1[0], shape_no_1[1]);
     return min <= 8 && max % 8 == 0 && max >= 8 && max <= transposeMaxSize;
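A rough illustration, not part of the commit, of how the two checks behave; the shapes and the helper function name are made up for the example, and the namespace qualification follows the calls used elsewhere in the plugin code:

#include <cassert>

void IllustrateTransposeLimits() {
    // Exactly two non-unit dimensions -> counts as a 2D transpose.
    assert(GNALimitations::IsTranspose2d({1, 8, 1, 64}));
    // Three non-unit dimensions -> rejected by both checks.
    assert(!GNALimitations::IsTranspose2d({2, 8, 64}));
    // 2D, smaller dim <= 8, larger dim a multiple of 8 and <= transposeMaxSize -> supported.
    assert(GNALimitations::IsTransposeSupported({1, 8, 1, 64}));
    // Larger dim not a multiple of 8 -> not supported.
    assert(!GNALimitations::IsTransposeSupported({8, 65}));
}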

inference-engine/src/gna_plugin/gna_plugin.cpp

Lines changed: 2 additions & 0 deletions
@@ -711,6 +711,8 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
     manager.register_pass<InsertReshapeAroundMatmulWithFq>();
     manager.register_pass<InsertReshapeAroundMatmulWithAdd>();
     manager.register_pass<InsertReshapeAroundMatmul>();
+    manager.register_pass<SwapInputMatMulWithTrailingTranspose>();
+    manager.register_pass<SwapInputMatMulWithAct>();
     manager.register_pass<SwapInputMatMulWithFq>();
     manager.register_pass<SwapInputMatMulWithBias>();
     manager.register_pass<SwapInputMatMul>();
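A minimal sketch, not from the commit, of how these matcher passes are driven: they are run in registration order by the same ngraph::pass::Manager that LoadNetwork already uses, so the more specific Swap* variants are tried before the generic SwapInputMatMul. Here 'func' is an assumed std::shared_ptr<ngraph::Function>:

// Sketch only: registration order determines match priority.
ngraph::pass::Manager manager;
manager.register_pass<SwapInputMatMulWithTrailingTranspose>();
manager.register_pass<SwapInputMatMulWithAct>();
manager.register_pass<SwapInputMatMulWithFq>();
manager.register_pass<SwapInputMatMulWithBias>();
manager.register_pass<SwapInputMatMul>();
manager.run_passes(func);  // 'func' is an assumed std::shared_ptr<ngraph::Function>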

inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.cpp

Lines changed: 43 additions & 18 deletions
@@ -34,7 +34,15 @@ void ReplaceTransposeWithReshape(std::shared_ptr<ngraph::Node> transpose_node) {
     transpose_node->output(0).replace(reshape_node->output(0));
 }
 
-void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::string& base_name) {
+void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::string& base_name, bool before_matmul) {
+    auto create_reshape = [](const ngraph::Shape& shape, std::shared_ptr<ngraph::Node> input_node, const std::string& name) {
+        auto reshape_const = std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
+            ngraph::Shape{shape.size()}, shape);
+        auto node = std::make_shared<ngraph::opset8::Reshape>(input_node, reshape_const, false);
+        node->set_friendly_name(name);
+        return node;
+    };
+
     auto consumers = prev_node->output(0).get_target_inputs();
     const auto orig_shape = prev_node->get_output_shape(0);
     std::vector<size_t> transpose_ids;
@@ -48,18 +56,29 @@ void InsertTranspose(std::shared_ptr<ngraph::Node> prev_node, const std::string&
     std::iota(std::begin(permute_order), std::end(permute_order), 0);
     std::swap(permute_order[transpose_ids[0]], permute_order[transpose_ids[1]]);
 
+    ngraph::NodeVector new_ops;
+    std::shared_ptr<ngraph::Node> node = prev_node;
+    if (!before_matmul) {
+        auto shape = prev_node->get_output_shape(0);
+        std::swap(shape[0], shape[1]);
+        node = create_reshape(shape, node, base_name + "/reshape_before_transpose");
+        new_ops.push_back(node);
+    }
+
     auto transpose_order = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{permute_order.size()}, permute_order);
-    auto transpose = std::make_shared<ngraph::opset8::Transpose>(prev_node, transpose_order);
-    transpose->set_friendly_name(base_name + "/in_transpose");
+    node = std::make_shared<ngraph::opset8::Transpose>(node, transpose_order);
+    node->set_friendly_name(base_name + "/in_transpose");
+    new_ops.push_back(node);
 
-    auto reshapeConstAfter = std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
-        ngraph::Shape{orig_shape.size()}, orig_shape);
-    auto reshapeAfter = std::make_shared<ngraph::opset8::Reshape>(transpose, reshapeConstAfter, false);
-    reshapeAfter->set_friendly_name(base_name + "/reshape_after_transpose");
-    ngraph::copy_runtime_info(prev_node, ngraph::NodeVector{transpose, reshapeAfter});
+    if (before_matmul) {
+        node = create_reshape(orig_shape, node, base_name + "/reshape_after_transpose");
+        new_ops.push_back(node);
+    }
+
+    ngraph::copy_runtime_info(prev_node, new_ops);
 
     for (auto input : consumers) {
-        input.replace_source_output(reshapeAfter);
+        input.replace_source_output(node);
     }
 }
 
@@ -94,24 +113,25 @@ HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() {
             return false;
         }
 
+        auto matmul_node = matmul_iter->second.get_node_shared_ptr();
         auto transpose_reshape_it = pattern_map.find(transpose);
         if (transpose_reshape_it != std::end(pattern_map)) {
             ReplaceTransposeWithReshape(transpose_reshape_it->second.get_node_shared_ptr());
         } else if ((transpose_reshape_it = pattern_map.find(reshape)) != std::end(pattern_map)) {
             auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
             if (GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) {
-                auto matmul_node = matmul_iter->second.get_node_shared_ptr();
-                InsertTranspose(reshape_node, matmul_node->get_friendly_name());
+                InsertTranspose(reshape_node, matmul_node->get_friendly_name(), true);
             }
         }
 
+        // Transpose the constant input if it's the first input
         auto iter = pattern_map.find(fq);
         if (iter != pattern_map.end() ||
             (iter = pattern_map.find(constant)) != pattern_map.end()) {
             auto prev_node = iter->second.get_node_shared_ptr();
-            if (!GNALimitations::IsTransposeSupported(prev_node->get_output_shape(0))) return false;
-            auto matmul_node = iter->second.get_node_shared_ptr();
-            InsertTranspose(prev_node, matmul_node->get_friendly_name());
+            if (GNALimitations::IsTranspose2d(prev_node->get_output_shape(0))) {
+                InsertTranspose(prev_node, prev_node->get_friendly_name(), true);
+            }
         }
         return true;
     };
@@ -129,7 +149,11 @@ HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
     auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add_left, add_right});
     auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_input, ngraph::pattern::any_input(),
         ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
-    auto transpose_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fq_input, fq});
+    auto act_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fq_input, fq});
+    auto act = ngraph::pattern::wrap_type<ngraph::opset8::Relu, ngraph::opset8::Sigmoid,
+        ngraph::opset8::Tanh, ngraph::opset8::Abs, ngraph::opset8::Log, ngraph::opset8::Exp,
+        ngraph::opset8::Sign, ngraph::opset8::Clamp>({act_input});
+    auto transpose_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{act_input, act});
     auto transpose = ngraph::pattern::wrap_type<ngraph::opset8::Transpose>({transpose_input, ngraph::pattern::any_input()});
     auto reshape_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{transpose_input, transpose});
     auto reshape = ngraph::pattern::wrap_type<ngraph::opset8::Reshape>(
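To illustrate the extended pattern (an assumed example, not a test from this commit): with the activation alternatives added, a graph where an activation separates the MatMul from the trailing Reshape can now be matched as well, e.g.:

// Sketch only: MatMul -> Relu -> Reshape, covered by the pattern above because
// Relu is one of the wrapped activation types. Shapes and values are assumptions.
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::f32, ngraph::Shape{8, 16});
auto weights = ngraph::opset8::Constant::create(ngraph::element::f32, ngraph::Shape{16, 64}, std::vector<float>(16 * 64, 0.01f));
auto matmul = std::make_shared<ngraph::opset8::MatMul>(input, weights);                       // {8, 64}
auto relu = std::make_shared<ngraph::opset8::Relu>(matmul);
auto target = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{2}, std::vector<int64_t>{1, 512});
auto reshape = std::make_shared<ngraph::opset8::Reshape>(relu, target, false);                // changes the "batch" dim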
@@ -142,16 +166,17 @@ HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
             ReplaceTransposeWithReshape(transpose_it->second.get_node_shared_ptr());
         } else {
             auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
-            if (!GNALimitations::IsTransposeSupported(reshape_node->get_output_shape(0))) return false;
-            auto iter = pattern_map.find(fq);
+            if (!GNALimitations::IsTransposeSupported(reshape_node->get_input_shape(0))) return false;
+            auto iter = pattern_map.find(act);
             if (iter == pattern_map.end() &&
+                (iter = pattern_map.find(fq)) == pattern_map.end() &&
                 (iter = pattern_map.find(add_left)) == pattern_map.end() &&
                 (iter = pattern_map.find(add_right)) == pattern_map.end() &&
                 (iter = pattern_map.find(matmul)) == pattern_map.end()) {
                 return false;
             }
             auto node = iter->second.get_node_shared_ptr();
-            InsertTranspose(node, node->get_friendly_name());
+            InsertTranspose(node, node->get_friendly_name(), false);
         }
         return true;
     };

inference-engine/src/gna_plugin/transformations/handle_transposes_around_matmul.hpp

Lines changed: 16 additions & 5 deletions
@@ -11,15 +11,18 @@ namespace GNAPluginNS {
 /**
  * @brief Inserts Transpose before MatMul or removes it (if it exists) if there is Reshape
  * before MatMul which changes the batch size:
- *      [1, A*B]                 [1, A*B]
+ *      [1, A*B]                 [1, A*B]
  *         |                        |
  *      Reshape                  Reshape
  *         |                        |
- *   [1, A, 1, B]             [1, A, 1, B]
+ *       [A, B]                   [A, B]
  *         |                        |
  *         |                    Transpose
  *         | ->                     |
- *         | <-               [1, B, 1, A]
+ *         | <-                  [B, A]
+ *         |                        |
+ *         |                     Reshape
+ *         |                     [A, B]
  *         |                        |
  *      MatMul                   MatMul
  */
@@ -33,12 +36,20 @@ class HandleTransposeBeforeMatMul : public ngraph::pass::MatcherPass {
  * @brief Inserts Transpose after MatMul or removes it (if it exists) if there is Reshape
  * after MatMul which changes the batch size:
  *      MatMul                   MatMul
+ *      [A, B]                   [A, B]
+ *         |                        |
+ *       [Add]                    [Add]
+ *         |                        |
+ *  [FakeQuantize]           [FakeQuantize]
+ *         |                        |
+ *   [Activation]             [Activation]
  *         |                        |
- *   [1, A, 1, B]             [1, A, 1, B]
+ *         |                     Reshape
+ *         |                     [B, A]
  *         |                        |
  *         |                    Transpose
  *         | ->                     |
- *         | <-               [1, B, 1, A]
+ *         | <-                  [A, B]
  *         |                        |
  *      Reshape                  Reshape
  *         |                        |
