Skip to content

Commit 44b0bd7

Browse files
authored
[Graph] Fix graph contains circle bug when enable SmartStage optimization. (#149)
1 parent 06964b3 commit 44b0bd7

File tree

2 files changed

+30
-56
lines changed

2 files changed

+30
-56
lines changed

tensorflow/core/graph/graph_constructor.cc

Lines changed: 28 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1882,21 +1882,16 @@ void StageGraph(Graph* dest, Node* stage_node, Node* unstage_node,
18821882
}
18831883
}
18841884

1885-
EdgeSet source_edge_set;
1885+
std::unordered_set<Node *> source_node_set;
18861886
for (const Edge* out_edge : out_edges) {
18871887
const Edge* in_edge = NULL;
18881888
int index = out_edge->src_output();
18891889
Status s = stage_node->input_edge(index, &in_edge);
18901890
TF_CHECK_OK(s);
1891-
const Node* dst = out_edge->dst();
1892-
int dst_input = out_edge->dst_input();
1891+
Node* dst = out_edge->dst();
18931892
s = dest->UpdateEdge(in_edge->src(), in_edge->src_output(), out_edge->dst(), out_edge->dst_input());
18941893
TF_CHECK_OK(s);
1895-
const Edge* e = NULL;
1896-
s = dst->input_edge(dst_input, &e);
1897-
TF_CHECK_OK(s);
1898-
CHECK(e != nullptr);
1899-
source_edge_set.insert(e);
1894+
source_node_set.insert(dst);
19001895
}
19011896

19021897
std::vector<const Edge*> in_edges;
@@ -1909,7 +1904,7 @@ void StageGraph(Graph* dest, Node* stage_node, Node* unstage_node,
19091904
}
19101905

19111906
std::vector<const Edge*> edge_vec;
1912-
GetStagingEdges(*dest, source_edge_set, target_nodes, edge_vec);
1907+
GetStagingEdges(*dest, source_node_set, target_nodes, edge_vec);
19131908

19141909
std::vector<DataType> type_vec;
19151910
int i = 0;
@@ -1918,6 +1913,11 @@ void StageGraph(Graph* dest, Node* stage_node, Node* unstage_node,
19181913
std::map<const Edge*, int64> edge_to_stage;
19191914
std::map<const Edge*, int64> edge_to_unstage;
19201915
for (const Edge* e : edge_vec) {
1916+
if (e->IsControlEdge()) {
1917+
// control flow is implemented by stage node and unstage node, remove control edge.
1918+
dest->RemoveEdge(e);
1919+
continue;
1920+
}
19211921
std::string name = e->src()->name() + std::to_string(e->src_output());
19221922
if (edge_map.find(name) == edge_map.end()) {
19231923
type_vec.push_back(e->src()->output_type(e->src_output()));
@@ -1969,17 +1969,18 @@ void StageGraph(Graph* dest, Node* stage_node, Node* unstage_node,
19691969
}
19701970
}
19711971

1972-
void GetStagingEdges(const Graph& dest, const EdgeSet& source_edge_set,
1972+
void GetStagingEdges(const Graph& dest, const std::unordered_set<Node *>& source_node_set,
19731973
const std::vector<std::string>& target_nodes,
19741974
std::vector<const Edge*>& edge_vec) {
19751975
std::queue<const Node*> q;
19761976
for (Node* n : dest.op_nodes()) {
1977-
if (n->IsVariable() || n->IsKvVarHandle() || n->IsPlaceholder() ||
1977+
if (n->IsVariable() || n->IsKvVarHandle() || n->IsPlaceholder() || n->IsControlFlow() ||
19781978
std::find(target_nodes.begin(), target_nodes.end(), n->name()) != target_nodes.end()) {
19791979
q.push(n);
19801980
}
19811981
}
1982-
std::vector<bool> is_var_relate(dest.num_nodes());
1982+
1983+
std::vector<bool> is_var_relate(dest.num_nodes(), false);
19831984
while (!q.empty()) {
19841985
const Node* node = q.front();
19851986
q.pop();
@@ -1991,53 +1992,25 @@ void GetStagingEdges(const Graph& dest, const EdgeSet& source_edge_set,
19911992
}
19921993
}
19931994

1994-
std::map<std::string, std::set<const Edge*>> edges_map;
1995-
for (const Edge* e : source_edge_set) {
1996-
std::string name = e->src()->name() + std::to_string(e->src_output());
1997-
edges_map[name].insert(e);;
1995+
std::queue<Node *> queue;
1996+
for (Node *n : source_node_set) {
1997+
queue.push(n);
19981998
}
19991999

2000-
std::vector<std::string> has_visit_node_output;
2001-
2002-
std::queue<std::set<const Edge*>> queue;
2003-
for (auto iter = edges_map.begin(); iter != edges_map.end(); ++iter) {
2004-
queue.push(iter->second);
2005-
has_visit_node_output.push_back(iter->first);
2006-
}
2000+
std::unordered_set<Node *> has_visit_node;
20072001
while (!queue.empty()) {
2008-
auto edges = queue.front();
2002+
Node *n = queue.front();
20092003
queue.pop();
2010-
bool stop = false;
2011-
for (const Edge* e : edges) {
2012-
if (is_var_relate[e->dst()->id()] ||
2013-
e->IsControlEdge() ||
2014-
e->dst()->IsControlFlow()) {
2015-
stop = true;
2016-
}
2017-
}
2018-
if (stop) {
2019-
for (const Edge* e : edges) {
2020-
if (!e->IsControlEdge()) {
2021-
if (std::find(edge_vec.begin(), edge_vec.end(), e) == edge_vec.end())
2022-
edge_vec.push_back(e);
2023-
}
2024-
}
2025-
} else {
2026-
for (const Edge* e : edges) {
2027-
const Node* node = e->dst();
2028-
std::map<std::string, std::set<const Edge*>> edges_map;
2029-
for (const Edge* e : node->out_edges()) {
2030-
std::string name = e->src()->name() + std::to_string(e->src_output());
2031-
edges_map[name].insert(e);
2032-
}
2033-
for (auto iter = edges_map.begin(); iter != edges_map.end(); ++iter) {
2034-
if (std::find(has_visit_node_output.begin(),
2035-
has_visit_node_output.end(),
2036-
iter->first) == has_visit_node_output.end()) {
2037-
queue.push(iter->second);
2038-
has_visit_node_output.push_back(iter->first);
2039-
}
2040-
}
2004+
if (has_visit_node.find(n) != has_visit_node.end())
2005+
continue;
2006+
2007+
has_visit_node.insert(n);
2008+
for (auto edge : n->out_edges()) {
2009+
Node *dst = edge->dst();
2010+
if (is_var_relate[dst->id()]) {
2011+
edge_vec.push_back(edge);
2012+
} else {
2013+
queue.push(dst);
20412014
}
20422015
}
20432016
}

tensorflow/core/graph/graph_constructor.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,8 @@ extern void ExtendGraph(Graph* dest, std::unordered_set<const Node*> excluded,
194194

195195
extern void StageGraph(Graph* dest, Node* stage_node, Node* unstage_node,
196196
const std::vector<std::string>& target_nodes);
197-
extern void GetStagingEdges(const Graph& dest, const EdgeSet& source_edge_set,
197+
extern void GetStagingEdges(const Graph& dest,
198+
const std::unordered_set<Node *>& source_node_set,
198199
const std::vector<std::string>& target_nodes,
199200
std::vector<const Edge*>& edge_vec);
200201
} // namespace tensorflow

0 commit comments

Comments
 (0)