@@ -29,33 +29,237 @@ using namespace std;
2929
3030#define TI (x ) type_index(typeid (x))
3131
32+ //
33+ // As we are visualizing the graph, we will make some tweaks to the generated dot file to make
34+ // routing more tractable for Graphviz as well as (hopefully) more legible for the user.
35+ //
36+ // NOTE: It's possible, even likely, that better algorithms are available here. I just tried a
37+ // few different things without doing much research, and this seemed to work well. Please feel
38+ // free to improve on this. --amprocte
39+ //
40+ // -----------------
41+ //
42+ // The first tweak is to trim edges that, intuitively speaking, have long "skip distance". For
43+ // example:
44+ //
45+ //    [Actual Graph Structure]      [Visualization]
46+ //    n0                             n0
47+ //    | \                            |  \
48+ //    n1 \                           n1  [to n50]
49+ //    |   |                          |
50+ //    n2  |                          n2
51+ //    |   |                          |
52+ //    n3  |                          n3
53+ //    |   |                          |
54+ //   ...  |                         ...  [from n0]
55+ //    |  /                           |  /
56+ //   n50                            n50
57+ //
58+ // This is useful for training graphs especially, which tend to have very long feed-forward edges
59+ // for intermediate values from fprop being stored for later reuse in the bprop phase.
60+ //
61+ // Efficiently detecting a "long skip" is a bit tricky. We want to come up with a metric that is
62+ // reasonably fast to compute, but does not result in cuts that will split the graph into multiple
63+ // components. The heuristic we are using for the jump distance between n and m is the maximum
64+ // difference in maximum path length from n and m to any result node that is reachable from both
65+ // n and m (or 0, if no such result node exists). Not sure if this is mathematically *guaranteed*
66+ // not to split graph components, but it seems to work well in practice.
67+ //
68+ // Formally:
69+ //
70+ // Compute-Heights-Above-Each-Result(N):
71+ //    Inputs: nodes N; define R={n in N | n is a Result node}
72+ //    Output: height_maps: map from N to (map from R to int)
73+ //
74+ //    height_maps is initially empty
75+ //
76+ //    for each r in R:
77+ //        Insert into height_maps the entry r -> {r -> 0}
78+ //
79+ //    for each n in N in reverse topological ("results-first") order:
80+ //        for each user m of n:
81+ //            for each r in height_maps[m].keys:
82+ //                height_maps[n][r] := max(height_maps[n][r], height_maps[m][r]+1)
83+ //
84+ // Jump-Distance(n,m,height_maps):
85+ //    Inputs: n (source node), m (destination node), height_maps (pre-computed above)
86+ //    Output: jump_distance: int
87+ //
88+ //    jump_distance := 0
89+ //
90+ //    for each r in height_maps[n].keys:
91+ //        if r is in height_maps[m].keys:
92+ //            jump_distance := max(jump_distance, abs(height_maps[n][r] - height_maps[m][r]))
93+ //
94+ // Later on, if E is an edge from n to m, and Jump-Distance(n,m,height_maps) > K (where K is kind
95+ // of arbitrary but currently set to 20), we will "cut" the edge as illustrated above.
96+ //
97+ // -----------------
98+ //
99+ // The second tweak aims to eliminate routing pressure from nodes that have large outdegree and
100+ // are connected to many otherwise-distant places in the graph. For this, the only thing we are
101+ // doing at the moment is to "float" Parameter and Constant nodes. This means that rather than
102+ // visualizing them as a single node (which might have very large outdegree as in, e.g., a
103+ // learning rate parameter being fed to many different places), we make a "copy" of the node at
104+ // each occurrence site (with a dashed outline).
105+ //
106+ // NOTE: This tweak could probably be extended to float other kinds of nodes with high out-degree.
107+ // (This situation is likely to arise after constant subexpression elimination.) Here one has to
108+ // be careful to avoid splitting the components. I have some rough ideas on how this could be
109+ // dealt with, but have not had time to implement them yet. --amprocte
110+ //
111+ class HeightMap
112+ {
113+ public:
114+ HeightMap () {}
115+ HeightMap (std::set<Node*> initials)
116+ {
117+ for (auto & n : initials)
118+ {
119+ m_heights[n] = 0 ;
120+ }
121+ }
122+ void absorb (const HeightMap& other)
123+ {
124+ for (auto & p : other.m_heights )
125+ {
126+ auto k = p.first ;
127+ auto v = p.second ;
128+ m_heights[k] = std::max (m_heights[k], v + 1 );
129+ }
130+ }
131+ int64_t max_jump_to (const HeightMap& target)
132+ {
133+ int64_t result = 0 ;
134+ for (auto & p : m_heights)
135+ {
136+ auto k = p.first ;
137+ auto v = p.second ;
138+ if (target.m_heights .count (k) != 0 )
139+ {
140+ result = std::max (result, std::abs (target.m_heights .at (k) - v));
141+ }
142+ }
143+ return result;
144+ }
145+
146+ private:
147+ std::unordered_map<Node*, int64_t > m_heights;
148+ };
149+
150+ static std::string label_edge (const std::shared_ptr<Node>& src,
151+ const std::shared_ptr<Node>& dst,
152+ size_t arg_index,
153+ int64_t jump_distance)
154+ {
155+ std::stringstream ss;
156+ if (getenv (" NGRAPH_VISUALIZE_EDGE_LABELS" ) != nullptr )
157+ {
158+ size_t output = 0 ;
159+ if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(dst))
160+ {
161+ output = goe->get_n ();
162+ }
163+ stringstream label_edge;
164+ label_edge << " [label=\" " << output << " -> " << arg_index << " \" ]" ;
165+ ss << label_edge.str ();
166+ }
167+
168+ else if (getenv (" NGRAPH_VISUALIZE_EDGE_JUMP_DISTANCE" ) != nullptr )
169+ {
170+ if (jump_distance > 1 )
171+ {
172+ stringstream label_edge;
173+ label_edge << " [label=\" jump=" << jump_distance << " \" ]" ;
174+ ss << label_edge.str ();
175+ }
176+ }
177+ return ss.str ();
178+ }
179+
32180bool pass::VisualizeTree::run_on_module (vector<shared_ptr<Function>>& functions)
33181{
34182 for (shared_ptr<Function> f : functions)
35183 {
36- // map<size_t, list<node_ptr>> dependent_nodes;
184+ unordered_map<Node*, HeightMap> height_maps;
185+
186+ for (auto & node : f->get_ops ())
187+ {
188+ if (node->description () == " Result" )
189+ {
190+ height_maps[node.get ()] = HeightMap ({node.get ()});
191+ }
192+ else
193+ {
194+ height_maps[node.get ()] = HeightMap ();
195+ }
196+ }
197+
198+ auto nodes = topological_sort (f->get_ops ());
199+ nodes.reverse ();
200+
201+ for (auto & node : nodes)
202+ {
203+ for (auto & output : node->outputs ())
204+ {
205+ for (auto & input : output.get_target_inputs ())
206+ {
207+ auto target_node = input.get_node ();
208+ height_maps[node.get ()].absorb (height_maps[target_node]);
209+ }
210+ }
211+ }
212+
213+ // TODO(amprocte): Maybe find a way to make this tunable.
214+ const int max_jump_distance = 20 ;
215+
216+ size_t fake_node_ctr = 0 ;
217+
37218 traverse_nodes (f, [&](shared_ptr<Node> node) {
38- size_t i = 0 ;
219+ size_t arg_index = 0 ;
39220 for (auto arg : node->get_arguments ())
40221 {
41- m_ss << add_attributes (arg);
42- m_ss << add_attributes (node);
43- m_ss << " " << arg->get_name () << " -> " << node->get_name ();
222+ size_t jump_distance = height_maps[arg.get ()].max_jump_to (height_maps[node.get ()]);
44223
45- if (getenv ( " NGRAPH_VISUALIZE_EDGE_LABELS " ) != nullptr )
224+ if (arg-> description () == " Constant " || arg-> description () == " Parameter " )
46225 {
47- size_t output = 0 ;
48- if ( auto goe = dynamic_pointer_cast<op::GetOutputElement>(node))
49- {
50- output = goe-> get_n ();
51- }
52- stringstream label_edge;
53- label_edge << " [label= \" " << output << " -> " << i << " \" ] " ;
54- m_ss << label_edge. str () ;
226+ auto clone_name = " CLONE_ " + to_string (fake_node_ctr) ;
227+ auto color = (arg-> description () == " Parameter " ? " blue " : " black " );
228+ m_ss << " " << clone_name
229+ << " [shape= \" box \" style= \" dashed,filled \" color= \" " << color
230+ << " \" fillcolor= \" white \" label= \" " << arg-> get_name () << " \" ] \n " ;
231+ m_ss << " " << clone_name << " -> " << node-> get_name ()
232+ << label_edge (arg, node, arg_index, jump_distance) << " \n " ;
233+ fake_node_ctr++ ;
55234 }
235+ else if (jump_distance > max_jump_distance)
236+ {
237+ m_ss << add_attributes (arg);
238+ m_ss << add_attributes (node);
239+ auto recv_node_name = " RECV_" + to_string (fake_node_ctr);
240+ auto send_node_name = " SEND_" + to_string (fake_node_ctr);
56241
57- m_ss << " ;\n " ;
58- i++;
242+ m_ss << " " << recv_node_name << " [shape=\" box\" style=\" solid,filled\" "
243+ " fillcolor=\" #ffcccc\" label=\" Receive["
244+ << arg->get_name () << " ]\" ]\n " ;
245+ m_ss << " " << send_node_name << " [shape=\" box\" style=\" solid,filled\" "
246+ " fillcolor=\" #ccffcc\" label=\" Send["
247+ << node->get_name () << " ]\" ]\n " ;
248+
249+ m_ss << " " << arg->get_name () << " -> " << send_node_name
250+ << label_edge (arg, node, arg_index, jump_distance) << " \n " ;
251+ m_ss << " " << recv_node_name << " -> " << node->get_name ()
252+ << label_edge (arg, node, arg_index, jump_distance) << " \n " ;
253+ fake_node_ctr++;
254+ }
255+ else
256+ {
257+ m_ss << add_attributes (arg);
258+ m_ss << add_attributes (node);
259+ m_ss << " " << arg->get_name () << " -> " << node->get_name ()
260+ << label_edge (arg, node, arg_index, jump_distance) << " \n " ;
261+ }
262+ arg_index++;
59263 }
60264 });
61265 }
@@ -86,30 +290,22 @@ string pass::VisualizeTree::add_attributes(shared_ptr<Node> node)
86290string pass::VisualizeTree::get_attributes (shared_ptr<Node> node)
87291{
88292 vector<string> attributes;
89- if (node->is_parameter () || node->is_output ())
293+ attributes.push_back (" shape=box" );
294+
295+ if (node->is_output ())
90296 {
91- attributes.push_back (" shape=box" );
92- if (node->is_parameter ())
93- {
94- attributes.push_back (" color=blue" );
95- attributes.push_back (" penwidth=1.5" );
96- }
97- if (node->is_output ())
98- {
99- attributes.push_back (" color=crimson" );
100- attributes.push_back (" penwidth=1.5" );
101- }
297+ attributes.push_back (" color=crimson" );
298+ attributes.push_back (" penwidth=1.5" );
102299 }
103300 else
104301 {
105- attributes.push_back (" shape=ellipse" );
106302 attributes.push_back (" color=black" );
107303 }
108304
109305 // Construct the label attribute
110306 {
111307 stringstream label;
112- label << " label=\" " << node->get_friendly_name ();
308+ label << " label=\" " << node->get_name ();
113309
114310 static const char * nvtos = getenv (" NGRAPH_VISUALIZE_TREE_OUTPUT_SHAPES" );
115311 if (nvtos != nullptr )
@@ -156,7 +352,7 @@ string pass::VisualizeTree::get_file_ext()
156352 const char * format = getenv (" NGRAPH_VISUALIZE_TREE_OUTPUT_FORMAT" );
157353 if (!format)
158354 {
159- format = " png " ;
355+ format = " dot " ;
160356 }
161357
162358 if (format[0 ] == ' .' )
@@ -178,11 +374,12 @@ void pass::VisualizeTree::render() const
178374 out << " }\n " ;
179375 out.close ();
180376
181- if (!m_dot_only)
377+ if (!m_dot_only && get_file_ext () != " dot " )
182378 {
183379#ifndef _WIN32
184380 stringstream ss;
185- ss << " dot -T" << get_file_ext () << " " << dot_file << " -o " << m_name;
381+ ss << " dot -T" << get_file_ext () << " " << dot_file << " -o" << m_name << " ."
382+ << get_file_ext ();
186383 auto cmd = ss.str ();
187384 auto stream = popen (cmd.c_str (), " r" );
188385 if (stream)
0 commit comments