|
| 1 | +//***************************************************************************** |
| 2 | +// Copyright 2017-2019 Intel Corporation |
| 3 | +// |
| 4 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +// you may not use this file except in compliance with the License. |
| 6 | +// You may obtain a copy of the License at |
| 7 | +// |
| 8 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +// |
| 10 | +// Unless required by applicable law or agreed to in writing, software |
| 11 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +// See the License for the specific language governing permissions and |
| 14 | +// limitations under the License. |
| 15 | +//***************************************************************************** |
| 16 | + |
| 17 | +#include "concat_fusion.hpp" |
| 18 | +#include <algorithm> |
| 19 | +#include <iostream> |
| 20 | +#include <numeric> |
| 21 | +#include <unordered_set> |
| 22 | + |
| 23 | +#include "ngraph/graph_util.hpp" |
| 24 | +#include "ngraph/log.hpp" |
| 25 | +#include "ngraph/op/broadcast.hpp" |
| 26 | +#include "ngraph/op/concat.hpp" |
| 27 | +#include "ngraph/op/parameter.hpp" |
| 28 | +#include "ngraph/op/reshape.hpp" |
| 29 | +#include "ngraph/pattern/matcher.hpp" |
| 30 | +#include "ngraph/pattern/op/label.hpp" |
| 31 | +#include "ngraph/pattern/op/skip.hpp" |
| 32 | +#include "ngraph/util.hpp" |
| 33 | + |
| 34 | +using namespace ngraph; |
| 35 | + |
| 36 | +namespace |
| 37 | +{ |
| 38 | + bool check_self_concat_op(const std::shared_ptr<Node>& op) |
| 39 | + { |
| 40 | + auto input_args = op->get_arguments(); |
| 41 | + std::set<std::shared_ptr<Node>> input_args_set(input_args.begin(), input_args.end()); |
| 42 | + return (input_args_set.size() == 1); |
| 43 | + } |
| 44 | + |
| 45 | + bool check_concat_axis_dim_value(const std::shared_ptr<Node>& concat_op) |
| 46 | + { |
| 47 | + auto input_shape = concat_op->get_input_shape(0); |
| 48 | + size_t concat_axis = |
| 49 | + std::static_pointer_cast<op::Concat>(concat_op)->get_concatenation_axis(); |
| 50 | + |
| 51 | + return (input_shape[concat_axis] == 1); |
| 52 | + } |
| 53 | + |
| 54 | + bool check_concat_has_no_fan_out(const std::shared_ptr<Node>& op) |
| 55 | + { |
| 56 | + auto users = op->get_users(true); |
| 57 | + std::set<std::shared_ptr<Node>> user_set(users.begin(), users.end()); |
| 58 | + size_t num_unique_users = user_set.size(); |
| 59 | + if (num_unique_users == 1) |
| 60 | + { |
| 61 | + return true; |
| 62 | + } |
| 63 | + else |
| 64 | + { |
| 65 | + NGRAPH_DEBUG << "self_concat_fusion: " << op->get_name() << " has fan out\n"; |
| 66 | + return false; |
| 67 | + } |
| 68 | + } |
| 69 | + |
| 70 | + bool valid_self_concat(const std::shared_ptr<Node>& Op) |
| 71 | + { |
| 72 | + if (!check_self_concat_op(Op)) |
| 73 | + { |
| 74 | + NGRAPH_DEBUG << "self_concat_fusion: Matcher matched " << Op->get_name() |
| 75 | + << " but it is not a self concat\n"; |
| 76 | + return false; |
| 77 | + } |
| 78 | + |
| 79 | + if (!check_concat_axis_dim_value(Op)) |
| 80 | + { |
| 81 | + NGRAPH_DEBUG << "self_concat_fusion: Input shape value along concat axis of " |
| 82 | + << Op->get_name() << " is not equal to 1\n"; |
| 83 | + return false; |
| 84 | + } |
| 85 | + |
| 86 | + return true; |
| 87 | + } |
| 88 | + |
| 89 | + std::vector<size_t> get_concatenation_axis_vector(const NodeVector& bounded_concat_ops) |
| 90 | + { |
| 91 | + std::vector<size_t> concat_axis_vec; |
| 92 | + for (auto iter : bounded_concat_ops) |
| 93 | + { |
| 94 | + auto concat_op = std::static_pointer_cast<op::Concat>(iter); |
| 95 | + concat_axis_vec.push_back(concat_op->get_concatenation_axis()); |
| 96 | + } |
| 97 | + return concat_axis_vec; |
| 98 | + } |
| 99 | +} |
| 100 | + |
| 101 | +void pass::ConcatElimination::construct_concat_elimination() |
| 102 | +{ |
| 103 | + auto op_label = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 3}); |
| 104 | + auto concat = std::make_shared<op::Concat>(NodeVector{op_label}, 0); |
| 105 | + auto concat_label = std::make_shared<pattern::op::Label>(concat, nullptr, NodeVector{concat}); |
| 106 | + |
| 107 | + auto callback = [op_label](pattern::Matcher& m) { |
| 108 | + NGRAPH_DEBUG |
| 109 | + << "concat_elimination: In callback for construct_concat_elimination against node = " |
| 110 | + << m.get_match_root()->get_name(); |
| 111 | + auto pattern_map = m.get_pattern_map(); |
| 112 | + auto op = pattern_map[op_label]; |
| 113 | + |
| 114 | + auto root = std::dynamic_pointer_cast<op::Concat>(m.get_match_root()); |
| 115 | + if (root && (root->get_input_shape(0) == root->get_output_shape(0))) |
| 116 | + { |
| 117 | + NGRAPH_DEBUG << " eliminated " << m.get_match_root() << "\n"; |
| 118 | + replace_node(m.get_match_root(), op); |
| 119 | + |
| 120 | + return true; |
| 121 | + } |
| 122 | + NGRAPH_DEBUG << " Incorrect match in callback\n"; |
| 123 | + return false; |
| 124 | + }; |
| 125 | + |
| 126 | + auto m = std::make_shared<pattern::Matcher>(concat_label, callback); |
| 127 | + this->add_matcher(m); |
| 128 | +} |
| 129 | + |
| 130 | +bool ngraph::pass::SelfConcatFusion::run_on_function(std::shared_ptr<Function> function) |
| 131 | +{ |
| 132 | + bool modify_graph = false; |
| 133 | + auto has_multiple_inputs = [](std::shared_ptr<Node> n) { |
| 134 | + auto input_size = n->get_input_size(); |
| 135 | + auto root = std::dynamic_pointer_cast<op::Concat>(n); |
| 136 | + return (root && input_size > 1); |
| 137 | + }; |
| 138 | + |
| 139 | + auto print_state_of_bounded_vectors = [this]() -> std::string { |
| 140 | + std::stringstream ss; |
| 141 | + ss << "-----------------------------------------------------------" << std::endl; |
| 142 | + ss << "State of bounded pattern node vectors: " << std::endl; |
| 143 | + ss << "-----------------------------------------------------------" << std::endl; |
| 144 | + ss << "Number of pattern node vectors: " << this->m_concat_pattern_vectors.size() |
| 145 | + << std::endl; |
| 146 | + size_t c = 0; |
| 147 | + for (auto iter : this->m_concat_pattern_vectors) |
| 148 | + { |
| 149 | + ss << "For vector " << c << std::endl; |
| 150 | + auto iter_node_vec = iter; |
| 151 | + ss << "concat_op_vector: "; |
| 152 | + for (auto it : iter_node_vec) |
| 153 | + { |
| 154 | + ss << it->get_name() << " "; |
| 155 | + } |
| 156 | + ss << std::endl; |
| 157 | + c++; |
| 158 | + } |
| 159 | + ss << "-----------------------------" << std::endl; |
| 160 | + return ss.str(); |
| 161 | + }; |
| 162 | + |
| 163 | + auto concat_op_label = |
| 164 | + std::make_shared<pattern::op::Label>(element::f32, Shape{1, 3}, has_multiple_inputs); |
| 165 | + auto matcher = std::make_shared<pattern::Matcher>(concat_op_label); |
| 166 | + for (auto n : function->get_ordered_ops()) |
| 167 | + { |
| 168 | + construct_concat_patterns(matcher, concat_op_label, n); |
| 169 | + } |
| 170 | + |
| 171 | + NGRAPH_DEBUG << print_state_of_bounded_vectors(); |
| 172 | + |
| 173 | + remove_single_concat_op_pattern(); |
| 174 | + |
| 175 | + for (auto concat_op_pattern_node_vector : this->m_concat_pattern_vectors) |
| 176 | + { |
| 177 | + modify_graph = replace_patterns(concat_op_pattern_node_vector); |
| 178 | + } |
| 179 | + |
| 180 | + return modify_graph; |
| 181 | +} |
| 182 | + |
| 183 | +void ngraph::pass::SelfConcatFusion::construct_concat_patterns( |
| 184 | + const std::shared_ptr<pattern::Matcher>& matcher, |
| 185 | + const std::shared_ptr<pattern::op::Label>& concat_op_label, |
| 186 | + const std::shared_ptr<Node>& n) |
| 187 | +{ |
| 188 | + if (matcher->match(n)) |
| 189 | + { |
| 190 | + auto concat_op = matcher->get_pattern_map()[concat_op_label]; |
| 191 | + if (!std::dynamic_pointer_cast<op::Concat>(concat_op)) |
| 192 | + { |
| 193 | + NGRAPH_DEBUG << "self_concat_fusion: Pattern matcher matched incorrect op. Matched " |
| 194 | + << concat_op->get_name() << " instead of a self concat"; |
| 195 | + return; |
| 196 | + } |
| 197 | + if (!valid_self_concat(concat_op)) |
| 198 | + { |
| 199 | + NGRAPH_DEBUG << "self_concat_fusion: " << concat_op->get_name() |
| 200 | + << " is not a valid self concat\n"; |
| 201 | + return; |
| 202 | + } |
| 203 | + else |
| 204 | + { |
| 205 | + NGRAPH_DEBUG << "self_concat_fusion: " << concat_op->get_name() |
| 206 | + << " is a VALID self concat\n"; |
| 207 | + } |
| 208 | + |
| 209 | + auto& concat_vectors = this->m_concat_pattern_vectors; |
| 210 | + if (concat_vectors.empty()) |
| 211 | + { |
| 212 | + concat_vectors.push_back(NodeVector{concat_op}); |
| 213 | + } |
| 214 | + else |
| 215 | + { |
| 216 | + update_concat_pattern_vectors(concat_op); |
| 217 | + } |
| 218 | + } |
| 219 | +} |
| 220 | + |
| 221 | +void ngraph::pass::SelfConcatFusion::update_concat_pattern_vectors( |
| 222 | + const std::shared_ptr<Node>& concat_op) |
| 223 | +{ |
| 224 | + bool concat_source_found = false; |
| 225 | + for (auto& concat_pattern_vec : this->m_concat_pattern_vectors) |
| 226 | + { |
| 227 | + auto last_op_in_pattern_vec = concat_pattern_vec.back(); |
| 228 | + if ((concat_op->get_argument(0) == last_op_in_pattern_vec) && |
| 229 | + (check_concat_has_no_fan_out(last_op_in_pattern_vec))) |
| 230 | + { |
| 231 | + concat_pattern_vec.push_back(concat_op); |
| 232 | + concat_source_found = true; |
| 233 | + break; |
| 234 | + } |
| 235 | + } |
| 236 | + |
| 237 | + if (!concat_source_found) |
| 238 | + { |
| 239 | + this->m_concat_pattern_vectors.push_back(NodeVector{concat_op}); |
| 240 | + } |
| 241 | +} |
| 242 | + |
| 243 | +void ngraph::pass::SelfConcatFusion::remove_single_concat_op_pattern() |
| 244 | +{ |
| 245 | + auto iter = m_concat_pattern_vectors.begin(); |
| 246 | + while (iter != m_concat_pattern_vectors.end()) |
| 247 | + { |
| 248 | + if (iter->size() == 1) |
| 249 | + { |
| 250 | + iter = m_concat_pattern_vectors.erase(iter); |
| 251 | + } |
| 252 | + else |
| 253 | + { |
| 254 | + iter++; |
| 255 | + } |
| 256 | + } |
| 257 | +} |
| 258 | + |
| 259 | +bool ngraph::pass::SelfConcatFusion::replace_patterns(const NodeVector& bounded_concat_ops) |
| 260 | +{ |
| 261 | + auto scalarize_dim = [](std::vector<size_t> concat_axis_vector, |
| 262 | + const Shape& input_shape) -> Shape { |
| 263 | + |
| 264 | + Shape scalarized_shape; |
| 265 | + for (size_t i = 0; i < input_shape.size(); i++) |
| 266 | + { |
| 267 | + auto it = std::find(concat_axis_vector.begin(), concat_axis_vector.end(), i); |
| 268 | + if (it == concat_axis_vector.end()) |
| 269 | + { |
| 270 | + scalarized_shape.push_back(input_shape[i]); |
| 271 | + } |
| 272 | + } |
| 273 | + return scalarized_shape; |
| 274 | + }; |
| 275 | + |
| 276 | + auto concat_axis_vector = get_concatenation_axis_vector(bounded_concat_ops); |
| 277 | + |
| 278 | + auto& first_bounded_concat = (*bounded_concat_ops.begin()); |
| 279 | + auto driver_op = first_bounded_concat->get_argument(0); |
| 280 | + const Shape& input_shape = first_bounded_concat->get_input_shape(0); |
| 281 | + |
| 282 | + auto scalarized_shape = scalarize_dim(concat_axis_vector, input_shape); |
| 283 | + AxisVector axis_order = get_default_order(input_shape); |
| 284 | + auto reshape = std::make_shared<op::Reshape>(driver_op, axis_order, scalarized_shape); |
| 285 | + auto last_bounded_concat_op = bounded_concat_ops.back(); |
| 286 | + auto broadcast_out_shape = last_bounded_concat_op->get_shape(); |
| 287 | + auto broadcast = |
| 288 | + std::make_shared<op::Broadcast>(reshape, broadcast_out_shape, concat_axis_vector); |
| 289 | + |
| 290 | + replace_node(last_bounded_concat_op, broadcast); |
| 291 | + return true; |
| 292 | +} |
0 commit comments