Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 5f83ff5

Browse files
tachyon77diyessi
authored andcommitted
Fixes provenance bug causing extra tags to be added during node replacement (#2950)
1 parent a02bfa4 commit 5f83ff5

File tree

3 files changed

+84
-1
lines changed

3 files changed

+84
-1
lines changed

src/ngraph/graph_util.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,36 @@ void ngraph::traverse_functions(std::shared_ptr<ngraph::Function> p,
135135
}
136136
}
137137

138+
NodeVector ngraph::find_common_args(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
139+
{
140+
std::unordered_set<std::shared_ptr<Node>> target_args;
141+
142+
auto compute_target_args = [&target_args](const std::shared_ptr<Node> node) {
143+
target_args.insert(node);
144+
};
145+
146+
traverse_nodes({target}, compute_target_args, false, NodeVector{});
147+
148+
std::unordered_set<std::shared_ptr<Node>> replacement_args;
149+
150+
auto compute_replacement_args = [&replacement_args](const std::shared_ptr<Node> node) {
151+
replacement_args.insert(node);
152+
};
153+
154+
traverse_nodes({replacement}, compute_replacement_args, false, NodeVector{});
155+
156+
NodeVector common_args;
157+
for (auto e : target_args)
158+
{
159+
if (replacement_args.count(e) > 0)
160+
{
161+
common_args.push_back(e);
162+
}
163+
}
164+
165+
return common_args;
166+
}
167+
138168
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
139169
{
140170
if (target->is_output())
@@ -156,7 +186,8 @@ void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> re
156186
replacement->merge_provenance_tags_from(node);
157187
};
158188

159-
traverse_nodes({target}, set_replacement_prov, false, replacement->get_arguments());
189+
traverse_nodes(
190+
{target}, set_replacement_prov, false, ngraph::find_common_args(target, replacement));
160191
}
161192

162193
// For each of target's output O with replacement output O_rep:

src/ngraph/graph_util.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ namespace ngraph
7575

7676
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
7777

78+
NodeVector find_common_args(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
79+
7880
template <typename T>
7981
std::list<std::shared_ptr<Node>> topological_sort(const T& nodes,
8082
bool include_control_deps = false)

test/provenance.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,4 +209,54 @@ TEST(provenance, provenance)
209209

210210
EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_a", "tag_b", "tag_c"}));
211211
}
212+
213+
//
214+
// Before:
215+
//
216+
// A{tag_a} B{tag_b}
217+
// | |
218+
// C{tag_c}
219+
//
220+
//
221+
// Replacement:
222+
//
223+
// A{tag_a} B{tag_b}
224+
// | |
225+
// E{tag_e} |
226+
// | |
227+
// C -> D{tag_d}
228+
//
229+
//
230+
// After:
231+
//
232+
// A{tag_a} B{tag_b}
233+
// | |
234+
// E{tag_e} |
235+
// | |
236+
// D{tag_c, tag_d}
237+
//
238+
// Comment:
239+
// * D is the replacement root replacing C and creating a new argument node E
240+
//
241+
{
242+
auto x = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
243+
auto y = make_shared<op::Parameter>(element::i32, PartialShape{2, 3, 4});
244+
245+
auto a = make_shared<op::Add>(x, y);
246+
a->add_provenance_tag("tag_a");
247+
auto b = make_shared<op::Multiply>(y, x);
248+
b->add_provenance_tag("tag_b");
249+
auto c = make_shared<op::Subtract>(a, b);
250+
c->add_provenance_tag("tag_c");
251+
252+
auto f = make_shared<Function>(c, ParameterVector{x, y});
253+
254+
auto e = make_shared<op::Subtract>(a, x);
255+
auto d = make_shared<op::Subtract>(e, b);
256+
d->add_provenance_tag("tag_d");
257+
258+
replace_node(c, d);
259+
260+
EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"}));
261+
}
212262
}

0 commit comments

Comments
 (0)