@@ -1882,21 +1882,16 @@ void StageGraph(Graph* dest, Node* stage_node, Node* unstage_node,
18821882 }
18831883 }
18841884
1885- EdgeSet source_edge_set ;
1885+ std::unordered_set<Node *> source_node_set ;
18861886 for (const Edge* out_edge : out_edges) {
18871887 const Edge* in_edge = NULL ;
18881888 int index = out_edge->src_output ();
18891889 Status s = stage_node->input_edge (index, &in_edge);
18901890 TF_CHECK_OK (s);
1891- const Node* dst = out_edge->dst ();
1892- int dst_input = out_edge->dst_input ();
1891+ Node* dst = out_edge->dst ();
18931892 s = dest->UpdateEdge (in_edge->src (), in_edge->src_output (), out_edge->dst (), out_edge->dst_input ());
18941893 TF_CHECK_OK (s);
1895- const Edge* e = NULL ;
1896- s = dst->input_edge (dst_input, &e);
1897- TF_CHECK_OK (s);
1898- CHECK (e != nullptr );
1899- source_edge_set.insert (e);
1894+ source_node_set.insert (dst);
19001895 }
19011896
19021897 std::vector<const Edge*> in_edges;
@@ -1909,7 +1904,7 @@ void StageGraph(Graph* dest, Node* stage_node, Node* unstage_node,
19091904 }
19101905
19111906 std::vector<const Edge*> edge_vec;
1912- GetStagingEdges (*dest, source_edge_set , target_nodes, edge_vec);
1907+ GetStagingEdges (*dest, source_node_set , target_nodes, edge_vec);
19131908
19141909 std::vector<DataType> type_vec;
19151910 int i = 0 ;
@@ -1918,6 +1913,11 @@ void StageGraph(Graph* dest, Node* stage_node, Node* unstage_node,
19181913 std::map<const Edge*, int64> edge_to_stage;
19191914 std::map<const Edge*, int64> edge_to_unstage;
19201915 for (const Edge* e : edge_vec) {
1916+ if (e->IsControlEdge ()) {
1917+ // control flow is implemented by stage node and unstage node, remove control edge.
1918+ dest->RemoveEdge (e);
1919+ continue ;
1920+ }
19211921 std::string name = e->src ()->name () + std::to_string (e->src_output ());
19221922 if (edge_map.find (name) == edge_map.end ()) {
19231923 type_vec.push_back (e->src ()->output_type (e->src_output ()));
@@ -1969,17 +1969,18 @@ void StageGraph(Graph* dest, Node* stage_node, Node* unstage_node,
19691969 }
19701970}
19711971
1972- void GetStagingEdges (const Graph& dest, const EdgeSet& source_edge_set ,
1972+ void GetStagingEdges (const Graph& dest, const std::unordered_set<Node *>& source_node_set ,
19731973 const std::vector<std::string>& target_nodes,
19741974 std::vector<const Edge*>& edge_vec) {
19751975 std::queue<const Node*> q;
19761976 for (Node* n : dest.op_nodes ()) {
1977- if (n->IsVariable () || n->IsKvVarHandle () || n->IsPlaceholder () ||
1977+ if (n->IsVariable () || n->IsKvVarHandle () || n->IsPlaceholder () || n-> IsControlFlow () ||
19781978 std::find (target_nodes.begin (), target_nodes.end (), n->name ()) != target_nodes.end ()) {
19791979 q.push (n);
19801980 }
19811981 }
1982- std::vector<bool > is_var_relate (dest.num_nodes ());
1982+
1983+ std::vector<bool > is_var_relate (dest.num_nodes (), false );
19831984 while (!q.empty ()) {
19841985 const Node* node = q.front ();
19851986 q.pop ();
@@ -1991,53 +1992,25 @@ void GetStagingEdges(const Graph& dest, const EdgeSet& source_edge_set,
19911992 }
19921993 }
19931994
1994- std::map<std::string, std::set<const Edge*>> edges_map;
1995- for (const Edge* e : source_edge_set) {
1996- std::string name = e->src ()->name () + std::to_string (e->src_output ());
1997- edges_map[name].insert (e);;
1995+ std::queue<Node *> queue;
1996+ for (Node *n : source_node_set) {
1997+ queue.push (n);
19981998 }
19991999
2000- std::vector<std::string> has_visit_node_output;
2001-
2002- std::queue<std::set<const Edge*>> queue;
2003- for (auto iter = edges_map.begin (); iter != edges_map.end (); ++iter) {
2004- queue.push (iter->second );
2005- has_visit_node_output.push_back (iter->first );
2006- }
2000+ std::unordered_set<Node *> has_visit_node;
20072001 while (!queue.empty ()) {
2008- auto edges = queue.front ();
2002+ Node *n = queue.front ();
20092003 queue.pop ();
2010- bool stop = false ;
2011- for (const Edge* e : edges) {
2012- if (is_var_relate[e->dst ()->id ()] ||
2013- e->IsControlEdge () ||
2014- e->dst ()->IsControlFlow ()) {
2015- stop = true ;
2016- }
2017- }
2018- if (stop) {
2019- for (const Edge* e : edges) {
2020- if (!e->IsControlEdge ()) {
2021- if (std::find (edge_vec.begin (), edge_vec.end (), e) == edge_vec.end ())
2022- edge_vec.push_back (e);
2023- }
2024- }
2025- } else {
2026- for (const Edge* e : edges) {
2027- const Node* node = e->dst ();
2028- std::map<std::string, std::set<const Edge*>> edges_map;
2029- for (const Edge* e : node->out_edges ()) {
2030- std::string name = e->src ()->name () + std::to_string (e->src_output ());
2031- edges_map[name].insert (e);
2032- }
2033- for (auto iter = edges_map.begin (); iter != edges_map.end (); ++iter) {
2034- if (std::find (has_visit_node_output.begin (),
2035- has_visit_node_output.end (),
2036- iter->first ) == has_visit_node_output.end ()) {
2037- queue.push (iter->second );
2038- has_visit_node_output.push_back (iter->first );
2039- }
2040- }
2004+ if (has_visit_node.find (n) != has_visit_node.end ())
2005+ continue ;
2006+
2007+ has_visit_node.insert (n);
2008+ for (auto edge : n->out_edges ()) {
2009+ Node *dst = edge->dst ();
2010+ if (is_var_relate[dst->id ()]) {
2011+ edge_vec.push_back (edge);
2012+ } else {
2013+ queue.push (dst);
20412014 }
20422015 }
20432016 }
0 commit comments