Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 4015565

Browse files
silee2diyessi
authored andcommitted
Reshape sinking: fix issue with handling rank changing reshape. (#3313)
1 parent ad20f21 commit 4015565

File tree

1 file changed

+70
-35
lines changed

1 file changed

+70
-35
lines changed

src/ngraph/pass/reshape_sinking.cpp

Lines changed: 70 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,51 @@ static string describe_reshape(shared_ptr<Node> node)
5656
return ss.str();
5757
}
5858

59+
static shared_ptr<op::Reshape>
60+
make_reshape(shared_ptr<Node> arg, const AxisVector& input_order, const Shape& output_shape)
61+
{
62+
auto reshape = make_shared<op::Reshape>(arg, input_order, output_shape);
63+
NGRAPH_DEBUG << "Make Reshape " << describe_reshape(reshape);
64+
return reshape;
65+
}
66+
67+
static void
68+
write_reshapemap(ReshapeMap& reorders, shared_ptr<Node> target, shared_ptr<op::Reshape> reshape)
69+
{
70+
NGRAPH_DEBUG << "Write ReshapeMap[" << target->get_name()
71+
<< "] = " << describe_reshape(reshape);
72+
reorders[target] = reshape;
73+
}
74+
75+
static shared_ptr<op::Reshape> read_reshapemap(ReshapeMap& reorders, shared_ptr<Node> target)
76+
{
77+
auto reorder = reorders.at(target);
78+
NGRAPH_DEBUG << "Read ReshapeMap[" << target->get_name() << "] -> "
79+
<< describe_reshape(reorder);
80+
return reorder;
81+
}
82+
5983
static shared_ptr<op::Reshape> combine_reshapes(shared_ptr<op::Reshape> r1,
6084
shared_ptr<op::Reshape> r2)
6185
{
6286
auto default_order = ngraph::get_default_order(r1->get_shape());
6387
auto perm_r1 = apply_permutation(default_order, r1->get_input_order());
6488
auto perm_r2 = apply_permutation(perm_r1, r2->get_input_order());
65-
auto rreshape = make_shared<op::Reshape>(r2->get_argument(0), perm_r2, r2->get_shape());
89+
auto rreshape = make_reshape(r2->get_argument(0), perm_r2, r2->get_shape());
90+
NGRAPH_DEBUG << "Combining " << describe_reshape(r1) << " and " << describe_reshape(r2)
91+
<< " into " << describe_reshape(rreshape);
6692
return rreshape;
6793
}
6894

6995
static void insert_reshape(shared_ptr<Node> target, shared_ptr<Node> reshape, size_t input_index)
7096
{
97+
NGRAPH_DEBUG << "Inserting reshape at input " << target->get_name() << " input index "
98+
<< input_index;
7199
auto arg = target->input(input_index).get_source_output();
100+
NGRAPH_DEBUG << "Arg shape: " << arg.get_shape();
72101
auto new_reshape = reshape->copy_with_new_inputs({arg});
102+
NGRAPH_DEBUG << "Inserting reshape " << describe_reshape(new_reshape) << " at input "
103+
<< target->get_name() << " input index " << input_index;
73104
target->input(input_index).replace_source_output(new_reshape->output(0));
74105
}
75106

@@ -92,7 +123,8 @@ static void mark_reshape_for_deletion(shared_ptr<Node> reshape,
92123
static shared_ptr<op::Reshape> create_default_reshape(shared_ptr<Node> n)
93124
{
94125
auto default_order = ngraph::get_default_order(n->get_shape());
95-
auto default_reshape = make_shared<op::Reshape>(n, default_order, n->get_shape());
126+
auto default_reshape = make_reshape(n, default_order, n->get_shape());
127+
NGRAPH_DEBUG << "Default reshape: " << describe_reshape(default_reshape);
96128
return default_reshape;
97129
}
98130

@@ -187,7 +219,7 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
187219
auto new_arg_shape =
188220
ngraph::apply_permutation(broadcast_input->get_shape(), new_source_axis_order);
189221
broadcast_input =
190-
make_shared<op::Reshape>(broadcast_input, new_source_axis_order, new_arg_shape);
222+
make_reshape(broadcast_input, new_source_axis_order, new_arg_shape);
191223
}
192224

193225
auto new_broadcast = make_shared<op::Broadcast>(
@@ -209,26 +241,25 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
209241
//of a binary op isn't in the default format (i.e. nhwc instead of nchw)
210242
//We have to normalize this other argument to nchw by swimming nchw towards parameters
211243
//as far as we can
212-
static void convert_binary_to_default_order(
213-
shared_ptr<Node> binary,
214-
const Input<Node>& input,
215-
shared_ptr<Node> right,
216-
unordered_map<shared_ptr<Node>, shared_ptr<op::Reshape>>& reorders,
217-
set<shared_ptr<Node>>& reshapes_to_delete)
244+
static void convert_binary_to_default_order(shared_ptr<Node> binary,
245+
const Input<Node>& input,
246+
shared_ptr<Node> right,
247+
ReshapeMap& reorders,
248+
set<shared_ptr<Node>>& reshapes_to_delete)
218249
{
219250
auto left = input.get_source_output().get_node_shared_ptr();
220251
auto perm_to_def =
221252
ngraph::get_permutation_to_default_order(reorders.at(right)->get_input_order());
222253
auto new_shape = apply_permutation(left->get_shape(), perm_to_def);
223254
NGRAPH_DEBUG << "right = " << ngraph::vector_to_string(right->get_shape()) << ", "
224255
<< right->get_name();
225-
auto new_reshape = make_shared<op::Reshape>(left, perm_to_def, new_shape);
256+
auto new_reshape = make_reshape(left, perm_to_def, new_shape);
226257
NGRAPH_DEBUG << "left : About to swim " << describe_reshape(new_reshape) << " up to "
227258
<< left->get_name();
228259
//this should now insert and swim reshape on right
229260
swim(input, new_reshape);
230261
mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete);
231-
reorders[binary] = reorders.at(right);
262+
write_reshapemap(reorders, binary, read_reshapemap(reorders, right));
232263
}
233264

234265
static void materialize_shapes(shared_ptr<Node> n,
@@ -247,32 +278,37 @@ static void materialize_shapes(shared_ptr<Node> n,
247278
auto arg = n->get_argument(i);
248279
if (reorders.count(arg) != 0)
249280
{
250-
NGRAPH_DEBUG << "Materializing " << describe_reshape(reorders.at(arg)) << " for "
281+
auto arg_reshape = reorders.at(arg);
282+
NGRAPH_DEBUG << "Materializing " << describe_reshape(arg_reshape) << " for "
251283
<< arg->get_name();
252-
mark_reshape_for_deletion(reorders.at(arg), reshapes_to_delete);
253-
if (reorders.at(arg)->get_input_order() != get_default_order(arg->get_shape()))
284+
mark_reshape_for_deletion(arg_reshape, reshapes_to_delete);
285+
auto arg_shape = arg->get_shape();
286+
if (arg_reshape->get_input_order() != get_default_order(arg->get_shape()))
254287
{
255288
// Insert if arg needs to be transposed.
256-
insert_reshape(n, reorders.at(arg), i);
289+
insert_reshape(n, arg_reshape, i);
257290
}
258291
//no swimming up
259292
}
260293
}
261-
reorders[n] = create_default_reshape(n);
294+
write_reshapemap(reorders, n, create_default_reshape(n));
262295
}
263296

264297
static void sink_reshape(shared_ptr<op::Reshape> reshape,
265298
ReshapeMap& reorders,
266299
set<shared_ptr<Node>>& reshapes_to_delete)
267300
{
301+
NGRAPH_DEBUG << "Sinking Reshape :" << describe_reshape(reshape);
268302
auto orig_reshape = reorders.at(reshape->get_argument(0));
269-
if (!reshape->get_is_transpose())
303+
// 1) Not a Transpose or 2) Rank changing operation.
304+
if ((reshape->get_output_shape().size() != reshape->get_input_order().size()) ||
305+
(!reshape->get_is_transpose()))
270306
{
271307
NGRAPH_DEBUG << "Materializing " << describe_reshape(orig_reshape) << " for reshape "
272-
<< reshape->get_name();
308+
<< describe_reshape(reshape);
273309
insert_reshape(reshape, orig_reshape, 0);
274310
mark_reshape_for_deletion(orig_reshape, reshapes_to_delete);
275-
reorders[reshape] = create_default_reshape(reshape);
311+
write_reshapemap(reorders, reshape, create_default_reshape(reshape));
276312
}
277313
else
278314
{
@@ -284,19 +320,17 @@ static void sink_reshape(shared_ptr<op::Reshape> reshape,
284320
//replace reshape with combined one
285321
ngraph::replace_node(reshape, new_reshape);
286322
mark_reshape_for_deletion(new_reshape, reshapes_to_delete);
287-
reorders[new_reshape] = new_reshape;
288-
NGRAPH_DEBUG << "Combining " << describe_reshape(orig_reshape) << " and"
289-
<< describe_reshape(reshape) << " into " << describe_reshape(new_reshape);
323+
write_reshapemap(reorders, new_reshape, new_reshape);
290324
}
291325
}
292326

293327
static void sink_unary(shared_ptr<op::util::UnaryElementwiseArithmetic> n,
294328
ReshapeMap& reorders,
295329
set<shared_ptr<Node>>& reshapes_to_delete)
296330
{
297-
auto arg_reshape = reorders.at(n->get_argument(0));
331+
auto arg_reshape = read_reshapemap(reorders, n->get_argument(0));
298332
NGRAPH_DEBUG << "Propagating " << describe_reshape(arg_reshape) << " for " << n->get_name();
299-
reorders[n] = reorders[n->get_argument(0)];
333+
write_reshapemap(reorders, n, arg_reshape);
300334
}
301335

302336
static void sink_binary(shared_ptr<op::util::BinaryElementwiseArithmetic> binary,
@@ -310,7 +344,7 @@ static void sink_binary(shared_ptr<op::util::BinaryElementwiseArithmetic> binary
310344
{
311345
NGRAPH_DEBUG << "Propagating " << describe_reshape(reorders.at(left)) << " for "
312346
<< binary->get_name();
313-
reorders[binary] = reorders.at(left);
347+
write_reshapemap(reorders, binary, read_reshapemap(reorders, left));
314348
//at this point, both reshapes will be eventually removed
315349
mark_reshape_for_deletion(reorders.at(left), reshapes_to_delete);
316350
mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete);
@@ -360,9 +394,9 @@ static void sink_slice(shared_ptr<op::Slice> n,
360394
NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_slice->get_name();
361395
ngraph::replace_node(n, new_slice);
362396

363-
auto new_reshape = make_shared<op::Reshape>(new_slice, order, n->get_shape());
397+
auto new_reshape = make_reshape(new_slice, order, n->get_shape());
364398
NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
365-
reorders[new_slice] = new_reshape;
399+
write_reshapemap(reorders, new_slice, new_reshape);
366400
}
367401

368402
static void
@@ -385,9 +419,9 @@ static void
385419
ngraph::replace_node(dummy_correct_shape, n->get_argument(0));
386420
NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name();
387421
ngraph::replace_node(n, new_pad);
388-
auto new_reshape = make_shared<op::Reshape>(new_pad, order, n->get_shape());
422+
auto new_reshape = make_reshape(new_pad, order, n->get_shape());
389423
NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
390-
reorders[new_pad] = new_reshape;
424+
write_reshapemap(reorders, new_pad, new_reshape);
391425
}
392426
static void sink_quantize(shared_ptr<op::Quantize> quantize,
393427
ReshapeMap& reorders,
@@ -404,7 +438,7 @@ static void sink_quantize(shared_ptr<op::Quantize> quantize,
404438
quantize->get_round_mode());
405439

406440
ngraph::replace_node(quantize, new_quantize);
407-
reorders[new_quantize] = arg_reshape;
441+
write_reshapemap(reorders, new_quantize, arg_reshape);
408442
}
409443

410444
static void sink_concat(shared_ptr<op::Concat> n,
@@ -451,9 +485,9 @@ static void sink_concat(shared_ptr<op::Concat> n,
451485
NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name();
452486
ngraph::replace_node(n, new_concat);
453487

454-
auto new_reshape = make_shared<op::Reshape>(new_concat, order, n->get_shape());
488+
auto new_reshape = make_reshape(new_concat, order, n->get_shape());
455489
NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
456-
reorders[new_concat] = new_reshape;
490+
write_reshapemap(reorders, new_concat, new_reshape);
457491
}
458492

459493
static void sink_dequantize(shared_ptr<op::Dequantize> dequantize,
@@ -470,7 +504,7 @@ static void sink_dequantize(shared_ptr<op::Dequantize> dequantize,
470504
axes_in_def_order);
471505

472506
ngraph::replace_node(dequantize, new_dequantize);
473-
reorders[new_dequantize] = arg_reshape;
507+
write_reshapemap(reorders, new_dequantize, arg_reshape);
474508
}
475509

476510
//The goal of ReshapeSinking is to remove
@@ -491,7 +525,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
491525
//STEP 1 : Sink or Swim reshapes away for op clusters
492526
for (auto n : f->get_ordered_ops())
493527
{
494-
NGRAPH_DEBUG << "Processing node " << n->get_name();
528+
NGRAPH_DEBUG << "Start: Processing node " << n->get_name();
495529
//collect all Result nodes for a sanity check
496530
if (n->is_output())
497531
{
@@ -512,7 +546,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
512546
}
513547
else if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(n))
514548
{
515-
reorders[goe] = create_default_reshape(goe);
549+
write_reshapemap(reorders, goe, create_default_reshape(goe));
516550
}
517551
else if (auto quantize = dynamic_pointer_cast<op::Quantize>(n))
518552
{
@@ -555,6 +589,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
555589
{
556590
materialize_shapes(n, reorders, reshapes_to_delete);
557591
}
592+
NGRAPH_DEBUG << "End: Processing node " << n->get_name();
558593
}
559594

560595
//STEP 2: purge all the reshapes we either sunk or swam.

0 commit comments

Comments
 (0)