Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 42a3a0a

Browse files
tachyon77diyessi
authored andcommitted
Implements proveance tag propagation for reshape sinking pass (#3742)
* Implements proveance tag propagation for reshape sinking pass * Addresses code review feedback. * Applies style fixes.
1 parent 5a43c0f commit 42a3a0a

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

src/ngraph/pass/reshape_sinking.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ static void insert_reshape(shared_ptr<Node> target, shared_ptr<Node> reshape, si
105105
auto arg = target->input(input_index).get_source_output();
106106
NGRAPH_DEBUG << "Arg shape: " << arg.get_shape();
107107
auto new_reshape = reshape->copy_with_new_inputs({arg});
108+
new_reshape->merge_provenance_tags_from(reshape);
108109
NGRAPH_DEBUG << "Inserting reshape " << describe_reshape(new_reshape) << " at input "
109110
<< target->get_name() << " input index " << input_index;
110111
target->input(input_index).replace_source_output(new_reshape->output(0));
@@ -115,7 +116,8 @@ static void delete_reshape(shared_ptr<Node> reshape)
115116
NGRAPH_DEBUG << "Removing reshape " << reshape->get_name();
116117
if (!reshape->get_users().empty())
117118
{
118-
ngraph::replace_node(reshape, reshape->get_argument(0));
119+
ngraph::replace_node(
120+
reshape, reshape->input(0).get_source_output().get_node_shared_ptr(), true);
119121
}
120122
}
121123

@@ -130,6 +132,7 @@ static shared_ptr<op::Reshape> create_default_reshape(shared_ptr<Node> n)
130132
{
131133
auto default_order = ngraph::get_default_order(n->get_shape());
132134
auto default_reshape = make_reshape(n, default_order, n->get_shape());
135+
default_reshape->merge_provenance_tags_from(n);
133136
NGRAPH_DEBUG << "Default reshape: " << describe_reshape(default_reshape);
134137
return default_reshape;
135138
}
@@ -230,13 +233,15 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
230233

231234
auto new_broadcast = make_shared<op::Broadcast>(
232235
broadcast_input, broadcast_reshape->get_shape(), new_broadcast_axes);
236+
new_broadcast->merge_provenance_tags_from(old_broadcast);
233237
csw.input.replace_source_output(new_broadcast->output(0));
234238
}
235239
//TODO: Add cases to push through Reshape and BinaryElementwiseArithmetic
236240
else
237241
{
238242
//materialize
239243
auto new_reshape = csw.reshape->copy_with_new_args({n});
244+
new_reshape->merge_provenance_tags_from(n);
240245
NGRAPH_DEBUG << "Materializing new reshape " << describe_reshape(new_reshape);
241246
csw.input.replace_source_output(new_reshape->output(0));
242247
}

0 commit comments

Comments
 (0)