@@ -56,20 +56,51 @@ static string describe_reshape(shared_ptr<Node> node)
56
56
return ss.str ();
57
57
}
58
58
59
+ static shared_ptr<op::Reshape>
60
+ make_reshape (shared_ptr<Node> arg, const AxisVector& input_order, const Shape& output_shape)
61
+ {
62
+ auto reshape = make_shared<op::Reshape>(arg, input_order, output_shape);
63
+ NGRAPH_DEBUG << " Make Reshape " << describe_reshape (reshape);
64
+ return reshape;
65
+ }
66
+
67
+ static void
68
+ write_reshapemap (ReshapeMap& reorders, shared_ptr<Node> target, shared_ptr<op::Reshape> reshape)
69
+ {
70
+ NGRAPH_DEBUG << " Write ReshapeMap[" << target->get_name ()
71
+ << " ] = " << describe_reshape (reshape);
72
+ reorders[target] = reshape;
73
+ }
74
+
75
+ static shared_ptr<op::Reshape> read_reshapemap (ReshapeMap& reorders, shared_ptr<Node> target)
76
+ {
77
+ auto reorder = reorders.at (target);
78
+ NGRAPH_DEBUG << " Read ReshapeMap[" << target->get_name () << " ] -> "
79
+ << describe_reshape (reorder);
80
+ return reorder;
81
+ }
82
+
59
83
static shared_ptr<op::Reshape> combine_reshapes (shared_ptr<op::Reshape> r1,
60
84
shared_ptr<op::Reshape> r2)
61
85
{
62
86
auto default_order = ngraph::get_default_order (r1->get_shape ());
63
87
auto perm_r1 = apply_permutation (default_order, r1->get_input_order ());
64
88
auto perm_r2 = apply_permutation (perm_r1, r2->get_input_order ());
65
- auto rreshape = make_shared<op::Reshape>(r2->get_argument (0 ), perm_r2, r2->get_shape ());
89
+ auto rreshape = make_reshape (r2->get_argument (0 ), perm_r2, r2->get_shape ());
90
+ NGRAPH_DEBUG << " Combining " << describe_reshape (r1) << " and " << describe_reshape (r2)
91
+ << " into " << describe_reshape (rreshape);
66
92
return rreshape;
67
93
}
68
94
69
95
static void insert_reshape (shared_ptr<Node> target, shared_ptr<Node> reshape, size_t input_index)
70
96
{
97
+ NGRAPH_DEBUG << " Inserting reshape at input " << target->get_name () << " input index "
98
+ << input_index;
71
99
auto arg = target->input (input_index).get_source_output ();
100
+ NGRAPH_DEBUG << " Arg shape: " << arg.get_shape ();
72
101
auto new_reshape = reshape->copy_with_new_inputs ({arg});
102
+ NGRAPH_DEBUG << " Inserting reshape " << describe_reshape (new_reshape) << " at input "
103
+ << target->get_name () << " input index " << input_index;
73
104
target->input (input_index).replace_source_output (new_reshape->output (0 ));
74
105
}
75
106
@@ -92,7 +123,8 @@ static void mark_reshape_for_deletion(shared_ptr<Node> reshape,
92
123
static shared_ptr<op::Reshape> create_default_reshape (shared_ptr<Node> n)
93
124
{
94
125
auto default_order = ngraph::get_default_order (n->get_shape ());
95
- auto default_reshape = make_shared<op::Reshape>(n, default_order, n->get_shape ());
126
+ auto default_reshape = make_reshape (n, default_order, n->get_shape ());
127
+ NGRAPH_DEBUG << " Default reshape: " << describe_reshape (default_reshape);
96
128
return default_reshape;
97
129
}
98
130
@@ -187,7 +219,7 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
187
219
auto new_arg_shape =
188
220
ngraph::apply_permutation (broadcast_input->get_shape (), new_source_axis_order);
189
221
broadcast_input =
190
- make_shared<op::Reshape> (broadcast_input, new_source_axis_order, new_arg_shape);
222
+ make_reshape (broadcast_input, new_source_axis_order, new_arg_shape);
191
223
}
192
224
193
225
auto new_broadcast = make_shared<op::Broadcast>(
@@ -209,26 +241,25 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
209
241
// of a binary op isn't in the default format (i.e. nhwc instead of nchw)
210
242
// We have to normalize this other argument to nchw by swimming nchw towards parameters
211
243
// as far as we can
212
- static void convert_binary_to_default_order (
213
- shared_ptr<Node> binary,
214
- const Input<Node>& input,
215
- shared_ptr<Node> right,
216
- unordered_map<shared_ptr<Node>, shared_ptr<op::Reshape>>& reorders,
217
- set<shared_ptr<Node>>& reshapes_to_delete)
244
+ static void convert_binary_to_default_order (shared_ptr<Node> binary,
245
+ const Input<Node>& input,
246
+ shared_ptr<Node> right,
247
+ ReshapeMap& reorders,
248
+ set<shared_ptr<Node>>& reshapes_to_delete)
218
249
{
219
250
auto left = input.get_source_output ().get_node_shared_ptr ();
220
251
auto perm_to_def =
221
252
ngraph::get_permutation_to_default_order (reorders.at (right)->get_input_order ());
222
253
auto new_shape = apply_permutation (left->get_shape (), perm_to_def);
223
254
NGRAPH_DEBUG << " right = " << ngraph::vector_to_string (right->get_shape ()) << " , "
224
255
<< right->get_name ();
225
- auto new_reshape = make_shared<op::Reshape> (left, perm_to_def, new_shape);
256
+ auto new_reshape = make_reshape (left, perm_to_def, new_shape);
226
257
NGRAPH_DEBUG << " left : About to swim " << describe_reshape (new_reshape) << " up to "
227
258
<< left->get_name ();
228
259
// this should now insert and swim reshape on right
229
260
swim (input, new_reshape);
230
261
mark_reshape_for_deletion (reorders.at (right), reshapes_to_delete);
231
- reorders[ binary] = reorders. at ( right);
262
+ write_reshapemap ( reorders, binary, read_reshapemap (reorders, right) );
232
263
}
233
264
234
265
static void materialize_shapes (shared_ptr<Node> n,
@@ -247,32 +278,37 @@ static void materialize_shapes(shared_ptr<Node> n,
247
278
auto arg = n->get_argument (i);
248
279
if (reorders.count (arg) != 0 )
249
280
{
250
- NGRAPH_DEBUG << " Materializing " << describe_reshape (reorders.at (arg)) << " for "
281
+ auto arg_reshape = reorders.at (arg);
282
+ NGRAPH_DEBUG << " Materializing " << describe_reshape (arg_reshape) << " for "
251
283
<< arg->get_name ();
252
- mark_reshape_for_deletion (reorders.at (arg), reshapes_to_delete);
253
- if (reorders.at (arg)->get_input_order () != get_default_order (arg->get_shape ()))
284
+ mark_reshape_for_deletion (arg_reshape, reshapes_to_delete);
285
+ auto arg_shape = arg->get_shape ();
286
+ if (arg_reshape->get_input_order () != get_default_order (arg->get_shape ()))
254
287
{
255
288
// Insert if arg needs to be transposed.
256
- insert_reshape (n, reorders. at (arg) , i);
289
+ insert_reshape (n, arg_reshape , i);
257
290
}
258
291
// no swimming up
259
292
}
260
293
}
261
- reorders[n] = create_default_reshape (n);
294
+ write_reshapemap ( reorders, n, create_default_reshape (n) );
262
295
}
263
296
264
297
static void sink_reshape (shared_ptr<op::Reshape> reshape,
265
298
ReshapeMap& reorders,
266
299
set<shared_ptr<Node>>& reshapes_to_delete)
267
300
{
301
+ NGRAPH_DEBUG << " Sinking Reshape :" << describe_reshape (reshape);
268
302
auto orig_reshape = reorders.at (reshape->get_argument (0 ));
269
- if (!reshape->get_is_transpose ())
303
+ // 1) Not a Transpose or 2) Rank changing operation.
304
+ if ((reshape->get_output_shape ().size () != reshape->get_input_order ().size ()) ||
305
+ (!reshape->get_is_transpose ()))
270
306
{
271
307
NGRAPH_DEBUG << " Materializing " << describe_reshape (orig_reshape) << " for reshape "
272
- << reshape-> get_name ( );
308
+ << describe_reshape (reshape );
273
309
insert_reshape (reshape, orig_reshape, 0 );
274
310
mark_reshape_for_deletion (orig_reshape, reshapes_to_delete);
275
- reorders[ reshape] = create_default_reshape (reshape);
311
+ write_reshapemap ( reorders, reshape, create_default_reshape (reshape) );
276
312
}
277
313
else
278
314
{
@@ -284,19 +320,17 @@ static void sink_reshape(shared_ptr<op::Reshape> reshape,
284
320
// replace reshape with combined one
285
321
ngraph::replace_node (reshape, new_reshape);
286
322
mark_reshape_for_deletion (new_reshape, reshapes_to_delete);
287
- reorders[new_reshape] = new_reshape;
288
- NGRAPH_DEBUG << " Combining " << describe_reshape (orig_reshape) << " and"
289
- << describe_reshape (reshape) << " into " << describe_reshape (new_reshape);
323
+ write_reshapemap (reorders, new_reshape, new_reshape);
290
324
}
291
325
}
292
326
293
327
static void sink_unary (shared_ptr<op::util::UnaryElementwiseArithmetic> n,
294
328
ReshapeMap& reorders,
295
329
set<shared_ptr<Node>>& reshapes_to_delete)
296
330
{
297
- auto arg_reshape = reorders. at ( n->get_argument (0 ));
331
+ auto arg_reshape = read_reshapemap (reorders, n->get_argument (0 ));
298
332
NGRAPH_DEBUG << " Propagating " << describe_reshape (arg_reshape) << " for " << n->get_name ();
299
- reorders[n] = reorders[n-> get_argument ( 0 )] ;
333
+ write_reshapemap ( reorders, n, arg_reshape) ;
300
334
}
301
335
302
336
static void sink_binary (shared_ptr<op::util::BinaryElementwiseArithmetic> binary,
@@ -310,7 +344,7 @@ static void sink_binary(shared_ptr<op::util::BinaryElementwiseArithmetic> binary
310
344
{
311
345
NGRAPH_DEBUG << " Propagating " << describe_reshape (reorders.at (left)) << " for "
312
346
<< binary->get_name ();
313
- reorders[ binary] = reorders. at ( left);
347
+ write_reshapemap ( reorders, binary, read_reshapemap (reorders, left) );
314
348
// at this point, both reshapes will be eventually removed
315
349
mark_reshape_for_deletion (reorders.at (left), reshapes_to_delete);
316
350
mark_reshape_for_deletion (reorders.at (right), reshapes_to_delete);
@@ -360,9 +394,9 @@ static void sink_slice(shared_ptr<op::Slice> n,
360
394
NGRAPH_DEBUG << " Replacing " << n->get_name () << " with " << new_slice->get_name ();
361
395
ngraph::replace_node (n, new_slice);
362
396
363
- auto new_reshape = make_shared<op::Reshape> (new_slice, order, n->get_shape ());
397
+ auto new_reshape = make_reshape (new_slice, order, n->get_shape ());
364
398
NGRAPH_DEBUG << " Propagating " << describe_reshape (new_reshape) << " for " << n->get_name ();
365
- reorders[ new_slice] = new_reshape;
399
+ write_reshapemap ( reorders, new_slice, new_reshape) ;
366
400
}
367
401
368
402
static void
@@ -385,9 +419,9 @@ static void
385
419
ngraph::replace_node (dummy_correct_shape, n->get_argument (0 ));
386
420
NGRAPH_DEBUG << " Replacing " << n->get_name () << " with " << new_pad->get_name ();
387
421
ngraph::replace_node (n, new_pad);
388
- auto new_reshape = make_shared<op::Reshape> (new_pad, order, n->get_shape ());
422
+ auto new_reshape = make_reshape (new_pad, order, n->get_shape ());
389
423
NGRAPH_DEBUG << " Propagating " << describe_reshape (new_reshape) << " for " << n->get_name ();
390
- reorders[ new_pad] = new_reshape;
424
+ write_reshapemap ( reorders, new_pad, new_reshape) ;
391
425
}
392
426
static void sink_quantize (shared_ptr<op::Quantize> quantize,
393
427
ReshapeMap& reorders,
@@ -404,7 +438,7 @@ static void sink_quantize(shared_ptr<op::Quantize> quantize,
404
438
quantize->get_round_mode ());
405
439
406
440
ngraph::replace_node (quantize, new_quantize);
407
- reorders[ new_quantize] = arg_reshape;
441
+ write_reshapemap ( reorders, new_quantize, arg_reshape) ;
408
442
}
409
443
410
444
static void sink_concat (shared_ptr<op::Concat> n,
@@ -451,9 +485,9 @@ static void sink_concat(shared_ptr<op::Concat> n,
451
485
NGRAPH_DEBUG << " Replacing " << n->get_name () << " with " << new_concat->get_name ();
452
486
ngraph::replace_node (n, new_concat);
453
487
454
- auto new_reshape = make_shared<op::Reshape> (new_concat, order, n->get_shape ());
488
+ auto new_reshape = make_reshape (new_concat, order, n->get_shape ());
455
489
NGRAPH_DEBUG << " Propagating " << describe_reshape (new_reshape) << " for " << n->get_name ();
456
- reorders[ new_concat] = new_reshape;
490
+ write_reshapemap ( reorders, new_concat, new_reshape) ;
457
491
}
458
492
459
493
static void sink_dequantize (shared_ptr<op::Dequantize> dequantize,
@@ -470,7 +504,7 @@ static void sink_dequantize(shared_ptr<op::Dequantize> dequantize,
470
504
axes_in_def_order);
471
505
472
506
ngraph::replace_node (dequantize, new_dequantize);
473
- reorders[ new_dequantize] = arg_reshape;
507
+ write_reshapemap ( reorders, new_dequantize, arg_reshape) ;
474
508
}
475
509
476
510
// The goal of ReshapeSinking is to remove
@@ -491,7 +525,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
491
525
// STEP 1 : Sink or Swim reshapes away for op clusters
492
526
for (auto n : f->get_ordered_ops ())
493
527
{
494
- NGRAPH_DEBUG << " Processing node " << n->get_name ();
528
+ NGRAPH_DEBUG << " Start: Processing node " << n->get_name ();
495
529
// collect all Result nodes for a sanity check
496
530
if (n->is_output ())
497
531
{
@@ -512,7 +546,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
512
546
}
513
547
else if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(n))
514
548
{
515
- reorders[ goe] = create_default_reshape (goe);
549
+ write_reshapemap ( reorders, goe, create_default_reshape (goe) );
516
550
}
517
551
else if (auto quantize = dynamic_pointer_cast<op::Quantize>(n))
518
552
{
@@ -555,6 +589,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
555
589
{
556
590
materialize_shapes (n, reorders, reshapes_to_delete);
557
591
}
592
+ NGRAPH_DEBUG << " End: Processing node " << n->get_name ();
558
593
}
559
594
560
595
// STEP 2: purge all the reshapes we either sunk or swam.
0 commit comments