Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 5b9b098

Browse files
authored
Migrate concat fusion to r0.16 (#2658)
1 parent 6cf0012 commit 5b9b098

File tree

5 files changed

+648
-0
lines changed

5 files changed

+648
-0
lines changed

src/ngraph/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,8 @@ set (SRC
312312
pass/zero_dim_tensor_elimination.cpp
313313
pass/zero_dim_tensor_elimination.hpp
314314
pass/zero_dim_tensor_elimination.hpp
315+
pass/concat_fusion.hpp
316+
pass/concat_fusion.cpp
315317
pattern/matcher.cpp
316318
pattern/matcher.hpp
317319
pattern/op/any.hpp

src/ngraph/pass/concat_fusion.cpp

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
//*****************************************************************************
2+
// Copyright 2017-2019 Intel Corporation
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
//*****************************************************************************
16+
17+
#include "concat_fusion.hpp"
18+
#include <algorithm>
19+
#include <iostream>
20+
#include <numeric>
21+
#include <unordered_set>
22+
23+
#include "ngraph/graph_util.hpp"
24+
#include "ngraph/log.hpp"
25+
#include "ngraph/op/broadcast.hpp"
26+
#include "ngraph/op/concat.hpp"
27+
#include "ngraph/op/parameter.hpp"
28+
#include "ngraph/op/reshape.hpp"
29+
#include "ngraph/pattern/matcher.hpp"
30+
#include "ngraph/pattern/op/label.hpp"
31+
#include "ngraph/pattern/op/skip.hpp"
32+
#include "ngraph/util.hpp"
33+
34+
using namespace ngraph;
35+
36+
namespace
37+
{
38+
bool check_self_concat_op(const std::shared_ptr<Node>& op)
39+
{
40+
auto input_args = op->get_arguments();
41+
std::set<std::shared_ptr<Node>> input_args_set(input_args.begin(), input_args.end());
42+
return (input_args_set.size() == 1);
43+
}
44+
45+
bool check_concat_axis_dim_value(const std::shared_ptr<Node>& concat_op)
46+
{
47+
auto input_shape = concat_op->get_input_shape(0);
48+
size_t concat_axis =
49+
std::static_pointer_cast<op::Concat>(concat_op)->get_concatenation_axis();
50+
51+
return (input_shape[concat_axis] == 1);
52+
}
53+
54+
bool check_concat_has_no_fan_out(const std::shared_ptr<Node>& op)
55+
{
56+
auto users = op->get_users(true);
57+
std::set<std::shared_ptr<Node>> user_set(users.begin(), users.end());
58+
size_t num_unique_users = user_set.size();
59+
if (num_unique_users == 1)
60+
{
61+
return true;
62+
}
63+
else
64+
{
65+
NGRAPH_DEBUG << "self_concat_fusion: " << op->get_name() << " has fan out\n";
66+
return false;
67+
}
68+
}
69+
70+
bool valid_self_concat(const std::shared_ptr<Node>& Op)
71+
{
72+
if (!check_self_concat_op(Op))
73+
{
74+
NGRAPH_DEBUG << "self_concat_fusion: Matcher matched " << Op->get_name()
75+
<< " but it is not a self concat\n";
76+
return false;
77+
}
78+
79+
if (!check_concat_axis_dim_value(Op))
80+
{
81+
NGRAPH_DEBUG << "self_concat_fusion: Input shape value along concat axis of "
82+
<< Op->get_name() << " is not equal to 1\n";
83+
return false;
84+
}
85+
86+
return true;
87+
}
88+
89+
std::vector<size_t> get_concatenation_axis_vector(const NodeVector& bounded_concat_ops)
90+
{
91+
std::vector<size_t> concat_axis_vec;
92+
for (auto iter : bounded_concat_ops)
93+
{
94+
auto concat_op = std::static_pointer_cast<op::Concat>(iter);
95+
concat_axis_vec.push_back(concat_op->get_concatenation_axis());
96+
}
97+
return concat_axis_vec;
98+
}
99+
}
100+
101+
void pass::ConcatElimination::construct_concat_elimination()
102+
{
103+
auto op_label = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 3});
104+
auto concat = std::make_shared<op::Concat>(NodeVector{op_label}, 0);
105+
auto concat_label = std::make_shared<pattern::op::Label>(concat, nullptr, NodeVector{concat});
106+
107+
auto callback = [op_label](pattern::Matcher& m) {
108+
NGRAPH_DEBUG
109+
<< "concat_elimination: In callback for construct_concat_elimination against node = "
110+
<< m.get_match_root()->get_name();
111+
auto pattern_map = m.get_pattern_map();
112+
auto op = pattern_map[op_label];
113+
114+
auto root = std::dynamic_pointer_cast<op::Concat>(m.get_match_root());
115+
if (root && (root->get_input_shape(0) == root->get_output_shape(0)))
116+
{
117+
NGRAPH_DEBUG << " eliminated " << m.get_match_root() << "\n";
118+
replace_node(m.get_match_root(), op);
119+
120+
return true;
121+
}
122+
NGRAPH_DEBUG << " Incorrect match in callback\n";
123+
return false;
124+
};
125+
126+
auto m = std::make_shared<pattern::Matcher>(concat_label, callback);
127+
this->add_matcher(m);
128+
}
129+
130+
bool ngraph::pass::SelfConcatFusion::run_on_function(std::shared_ptr<Function> function)
131+
{
132+
bool modify_graph = false;
133+
auto has_multiple_inputs = [](std::shared_ptr<Node> n) {
134+
auto input_size = n->get_input_size();
135+
auto root = std::dynamic_pointer_cast<op::Concat>(n);
136+
return (root && input_size > 1);
137+
};
138+
139+
auto print_state_of_bounded_vectors = [this]() -> std::string {
140+
std::stringstream ss;
141+
ss << "-----------------------------------------------------------" << std::endl;
142+
ss << "State of bounded pattern node vectors: " << std::endl;
143+
ss << "-----------------------------------------------------------" << std::endl;
144+
ss << "Number of pattern node vectors: " << this->m_concat_pattern_vectors.size()
145+
<< std::endl;
146+
size_t c = 0;
147+
for (auto iter : this->m_concat_pattern_vectors)
148+
{
149+
ss << "For vector " << c << std::endl;
150+
auto iter_node_vec = iter;
151+
ss << "concat_op_vector: ";
152+
for (auto it : iter_node_vec)
153+
{
154+
ss << it->get_name() << " ";
155+
}
156+
ss << std::endl;
157+
c++;
158+
}
159+
ss << "-----------------------------" << std::endl;
160+
return ss.str();
161+
};
162+
163+
auto concat_op_label =
164+
std::make_shared<pattern::op::Label>(element::f32, Shape{1, 3}, has_multiple_inputs);
165+
auto matcher = std::make_shared<pattern::Matcher>(concat_op_label);
166+
for (auto n : function->get_ordered_ops())
167+
{
168+
construct_concat_patterns(matcher, concat_op_label, n);
169+
}
170+
171+
NGRAPH_DEBUG << print_state_of_bounded_vectors();
172+
173+
remove_single_concat_op_pattern();
174+
175+
for (auto concat_op_pattern_node_vector : this->m_concat_pattern_vectors)
176+
{
177+
modify_graph = replace_patterns(concat_op_pattern_node_vector);
178+
}
179+
180+
return modify_graph;
181+
}
182+
183+
void ngraph::pass::SelfConcatFusion::construct_concat_patterns(
184+
const std::shared_ptr<pattern::Matcher>& matcher,
185+
const std::shared_ptr<pattern::op::Label>& concat_op_label,
186+
const std::shared_ptr<Node>& n)
187+
{
188+
if (matcher->match(n))
189+
{
190+
auto concat_op = matcher->get_pattern_map()[concat_op_label];
191+
if (!std::dynamic_pointer_cast<op::Concat>(concat_op))
192+
{
193+
NGRAPH_DEBUG << "self_concat_fusion: Pattern matcher matched incorrect op. Matched "
194+
<< concat_op->get_name() << " instead of a self concat";
195+
return;
196+
}
197+
if (!valid_self_concat(concat_op))
198+
{
199+
NGRAPH_DEBUG << "self_concat_fusion: " << concat_op->get_name()
200+
<< " is not a valid self concat\n";
201+
return;
202+
}
203+
else
204+
{
205+
NGRAPH_DEBUG << "self_concat_fusion: " << concat_op->get_name()
206+
<< " is a VALID self concat\n";
207+
}
208+
209+
auto& concat_vectors = this->m_concat_pattern_vectors;
210+
if (concat_vectors.empty())
211+
{
212+
concat_vectors.push_back(NodeVector{concat_op});
213+
}
214+
else
215+
{
216+
update_concat_pattern_vectors(concat_op);
217+
}
218+
}
219+
}
220+
221+
void ngraph::pass::SelfConcatFusion::update_concat_pattern_vectors(
222+
const std::shared_ptr<Node>& concat_op)
223+
{
224+
bool concat_source_found = false;
225+
for (auto& concat_pattern_vec : this->m_concat_pattern_vectors)
226+
{
227+
auto last_op_in_pattern_vec = concat_pattern_vec.back();
228+
if ((concat_op->get_argument(0) == last_op_in_pattern_vec) &&
229+
(check_concat_has_no_fan_out(last_op_in_pattern_vec)))
230+
{
231+
concat_pattern_vec.push_back(concat_op);
232+
concat_source_found = true;
233+
break;
234+
}
235+
}
236+
237+
if (!concat_source_found)
238+
{
239+
this->m_concat_pattern_vectors.push_back(NodeVector{concat_op});
240+
}
241+
}
242+
243+
void ngraph::pass::SelfConcatFusion::remove_single_concat_op_pattern()
244+
{
245+
auto iter = m_concat_pattern_vectors.begin();
246+
while (iter != m_concat_pattern_vectors.end())
247+
{
248+
if (iter->size() == 1)
249+
{
250+
iter = m_concat_pattern_vectors.erase(iter);
251+
}
252+
else
253+
{
254+
iter++;
255+
}
256+
}
257+
}
258+
259+
bool ngraph::pass::SelfConcatFusion::replace_patterns(const NodeVector& bounded_concat_ops)
260+
{
261+
auto scalarize_dim = [](std::vector<size_t> concat_axis_vector,
262+
const Shape& input_shape) -> Shape {
263+
264+
Shape scalarized_shape;
265+
for (size_t i = 0; i < input_shape.size(); i++)
266+
{
267+
auto it = std::find(concat_axis_vector.begin(), concat_axis_vector.end(), i);
268+
if (it == concat_axis_vector.end())
269+
{
270+
scalarized_shape.push_back(input_shape[i]);
271+
}
272+
}
273+
return scalarized_shape;
274+
};
275+
276+
auto concat_axis_vector = get_concatenation_axis_vector(bounded_concat_ops);
277+
278+
auto& first_bounded_concat = (*bounded_concat_ops.begin());
279+
auto driver_op = first_bounded_concat->get_argument(0);
280+
const Shape& input_shape = first_bounded_concat->get_input_shape(0);
281+
282+
auto scalarized_shape = scalarize_dim(concat_axis_vector, input_shape);
283+
AxisVector axis_order = get_default_order(input_shape);
284+
auto reshape = std::make_shared<op::Reshape>(driver_op, axis_order, scalarized_shape);
285+
auto last_bounded_concat_op = bounded_concat_ops.back();
286+
auto broadcast_out_shape = last_bounded_concat_op->get_shape();
287+
auto broadcast =
288+
std::make_shared<op::Broadcast>(reshape, broadcast_out_shape, concat_axis_vector);
289+
290+
replace_node(last_bounded_concat_op, broadcast);
291+
return true;
292+
}

src/ngraph/pass/concat_fusion.hpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
//*****************************************************************************
2+
// Copyright 2017-2019 Intel Corporation
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
//*****************************************************************************
16+
17+
#pragma once
18+
19+
#include "ngraph/pass/graph_rewrite.hpp"
20+
#include "ngraph/pass/pass.hpp"
21+
#include "ngraph/pattern/matcher.hpp"
22+
#include "ngraph/pattern/op/label.hpp"
23+
24+
namespace ngraph
25+
{
26+
namespace pass
27+
{
28+
class ConcatElimination;
29+
class SelfConcatFusion;
30+
}
31+
}
32+
33+
class ngraph::pass::ConcatElimination : public ngraph::pass::GraphRewrite
34+
{
35+
public:
36+
ConcatElimination()
37+
: GraphRewrite()
38+
{
39+
construct_concat_elimination();
40+
}
41+
42+
private:
43+
void construct_concat_elimination();
44+
};
45+
46+
class ngraph::pass::SelfConcatFusion : public ngraph::pass::FunctionPass
47+
{
48+
public:
49+
virtual bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
50+
51+
private:
52+
void update_concat_pattern_vectors(const std::shared_ptr<Node>&);
53+
void remove_single_concat_op_pattern();
54+
void construct_concat_patterns(const std::shared_ptr<pattern::Matcher>&,
55+
const std::shared_ptr<pattern::op::Label>&,
56+
const std::shared_ptr<Node>&);
57+
bool replace_patterns(const NodeVector&);
58+
std::vector<NodeVector> m_concat_pattern_vectors;
59+
};

test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ set(SRC
3333
build_graph.cpp
3434
builder_autobroadcast.cpp
3535
constant_folding.cpp
36+
concat_fusion.cpp
3637
control_dependencies.cpp
3738
coordinate.cpp
3839
copy.cpp

0 commit comments

Comments
 (0)