Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 2069823

Browse files
authored
Migrate #2864 (#2901)
* Migrate #2864 * Remove extra changes
1 parent 8728d7c commit 2069823

File tree

2 files changed

+233
-37
lines changed

2 files changed

+233
-37
lines changed

src/ngraph/pass/manager.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,19 +160,18 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
160160
std::string index_str = std::to_string(index);
161161
index_str = std::string(num_digits_in_pass_index - index_str.length(), '0') + index_str;
162162
auto base_filename = f_array.at(0)->get_name() + std::string("_") + index_str +
163-
std::string("_") + m_pass_names.at(index) + std::string(".");
163+
std::string("_") + m_pass_names.at(index);
164164

165165
if (m_visualize)
166166
{
167-
pass::VisualizeTree vt(base_filename + pass::VisualizeTree::get_file_ext());
167+
pass::VisualizeTree vt(base_filename);
168168
vt.set_ops_to_details(get_state().get_visualize_tree_ops_map());
169169
vt.run_on_module(f_array);
170170
}
171171

172172
if (m_serialize)
173173
{
174-
// no "." in the extension
175-
pass::Serialization st(base_filename + "json");
174+
pass::Serialization st(base_filename + ".json");
176175
st.run_on_module(f_array);
177176
}
178177
}

src/ngraph/pass/visualize_tree.cpp

Lines changed: 230 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -29,33 +29,237 @@ using namespace std;
2929

3030
#define TI(x) type_index(typeid(x))
3131

32+
//
33+
// As we are visualizing the graph, we will make some tweaks to the generated dot file to make
34+
// routing more tractable for Graphviz as well as (hopefully) more legible for the user.
35+
//
36+
// NOTE: It's possible, even likely, that better algorithms are available here. I just tried a
37+
// few different things without doing much research, and this seemed to work well. Please feel
38+
// free to improve on this. --amprocte
39+
//
40+
// -----------------
41+
//
42+
// The first tweak is to trim edges that, intuitively speaking, have long "skip distance". For
43+
// example:
44+
//
45+
// [Actual Graph Structure] [Visualization]
46+
// n0 n0
47+
// | \ | \
48+
// n1 \ n1 [to n50]
49+
// | | |
50+
// n2 | n2
51+
// | | |
52+
// n3 | n3
53+
// | | |
54+
// ... | ... [from n0]
55+
// | / | /
56+
// n50 n50
57+
//
58+
// This is useful for training graphs especially, which tend to have very long feed-forward edges
59+
// for intermediate values from fprop being stored for later reuse in the bprop phase.
60+
//
61+
// Efficiently detecting a "long skip" is a bit tricky. We want to come up with a metric that is
62+
// reasonably fast to compute, but does not result in cuts that will split the graph into multiple
63+
// components. The heuristic we are using for the jump distance between n and m is the maximum
64+
// difference in maximum path length from n and m to any result node that is reachable from both
65+
// n and m (or 0, if no such result node exists). Not sure if this is mathematically *guaranteed*
66+
// not to split graph components, but it seems to work well in practice.
67+
//
68+
// Formally:
69+
//
70+
// Compute-Heights-Above-Each-Parameter(N):
71+
// Inputs: nodes N; define R={n in N | n is a Result node}
72+
// Output: height_maps: map from N to (map from R to int)
73+
//
74+
// height_maps is initially empty
75+
//
76+
// for each r in R:
77+
// Insert into height_map the map {r -> 1}
78+
//
79+
// for each n in N in reverse topological ("results-first") order:
80+
// for each user m of n:
81+
// for each r in height_maps[m].keys:
82+
// height_maps[n][r] := max(height_maps[n][r], height_maps[m][r]+1)
83+
//
84+
// Jump-Distance(n,m,height_maps):
85+
// Inputs: n (source node), m (destination node), height_maps (pre-computed above)
86+
// Output: jump_distance: int
87+
//
88+
// jump_distance := 0
89+
//
90+
// for each r in height_maps[n].keys:
91+
// if r is in height_maps[m].keys:
92+
// jump_distance := max(jump_distance, abs(height_maps[n][r] - height_maps[m][r]))
93+
//
94+
// Later on, if E is an edge from n to m, and Jump-Distance(n,m,height_map) > K (where K is kind
95+
// of arbitrary but currently set to 20), we will "cut" the edge as illustrated above.
96+
//
97+
// -----------------
98+
//
99+
// The second tweak aims to eliminate routing pressure from nodes that have large outdegree and
100+
// are connected to many otherwise-distant places in the graph. For this, the only thing we are
101+
// doing at the moment is to "float" Parameter and Constant nodes. This means that rather than
102+
// visualizing them as a single node (which might have very large outdegree as in, e.g., a
103+
// learning rate parameter being fed to many different places), we make a "copy" of the node at
104+
// each occurrence site (with a dashed outline).
105+
//
106+
// NOTE: This tweak could probably be extended to float other kinds of nodes with high out-degree.
107+
// (This situation is likely to arise after constant subexpression elimination.) Here one has to
108+
// be careful to avoid splitting the components. I have some rough ideas on how this could be
109+
// dealt with, but have not had time to implement them yet. --amprocte
110+
//
111+
class HeightMap
112+
{
113+
public:
114+
HeightMap() {}
115+
HeightMap(std::set<Node*> initials)
116+
{
117+
for (auto& n : initials)
118+
{
119+
m_heights[n] = 0;
120+
}
121+
}
122+
void absorb(const HeightMap& other)
123+
{
124+
for (auto& p : other.m_heights)
125+
{
126+
auto k = p.first;
127+
auto v = p.second;
128+
m_heights[k] = std::max(m_heights[k], v + 1);
129+
}
130+
}
131+
int64_t max_jump_to(const HeightMap& target)
132+
{
133+
int64_t result = 0;
134+
for (auto& p : m_heights)
135+
{
136+
auto k = p.first;
137+
auto v = p.second;
138+
if (target.m_heights.count(k) != 0)
139+
{
140+
result = std::max(result, std::abs(target.m_heights.at(k) - v));
141+
}
142+
}
143+
return result;
144+
}
145+
146+
private:
147+
std::unordered_map<Node*, int64_t> m_heights;
148+
};
149+
150+
static std::string label_edge(const std::shared_ptr<Node>& src,
151+
const std::shared_ptr<Node>& dst,
152+
size_t arg_index,
153+
int64_t jump_distance)
154+
{
155+
std::stringstream ss;
156+
if (getenv("NGRAPH_VISUALIZE_EDGE_LABELS") != nullptr)
157+
{
158+
size_t output = 0;
159+
if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(dst))
160+
{
161+
output = goe->get_n();
162+
}
163+
stringstream label_edge;
164+
label_edge << "[label=\" " << output << " -> " << arg_index << " \"]";
165+
ss << label_edge.str();
166+
}
167+
168+
else if (getenv("NGRAPH_VISUALIZE_EDGE_JUMP_DISTANCE") != nullptr)
169+
{
170+
if (jump_distance > 1)
171+
{
172+
stringstream label_edge;
173+
label_edge << "[label=\"jump=" << jump_distance << "\"]";
174+
ss << label_edge.str();
175+
}
176+
}
177+
return ss.str();
178+
}
179+
32180
bool pass::VisualizeTree::run_on_module(vector<shared_ptr<Function>>& functions)
33181
{
34182
for (shared_ptr<Function> f : functions)
35183
{
36-
// map<size_t, list<node_ptr>> dependent_nodes;
184+
unordered_map<Node*, HeightMap> height_maps;
185+
186+
for (auto& node : f->get_ops())
187+
{
188+
if (node->description() == "Result")
189+
{
190+
height_maps[node.get()] = HeightMap({node.get()});
191+
}
192+
else
193+
{
194+
height_maps[node.get()] = HeightMap();
195+
}
196+
}
197+
198+
auto nodes = topological_sort(f->get_ops());
199+
nodes.reverse();
200+
201+
for (auto& node : nodes)
202+
{
203+
for (auto& output : node->outputs())
204+
{
205+
for (auto& input : output.get_target_inputs())
206+
{
207+
auto target_node = input.get_node();
208+
height_maps[node.get()].absorb(height_maps[target_node]);
209+
}
210+
}
211+
}
212+
213+
// TODO(amprocte): Maybe find a way to make this tunable.
214+
const int max_jump_distance = 20;
215+
216+
size_t fake_node_ctr = 0;
217+
37218
traverse_nodes(f, [&](shared_ptr<Node> node) {
38-
size_t i = 0;
219+
size_t arg_index = 0;
39220
for (auto arg : node->get_arguments())
40221
{
41-
m_ss << add_attributes(arg);
42-
m_ss << add_attributes(node);
43-
m_ss << " " << arg->get_name() << " -> " << node->get_name();
222+
size_t jump_distance = height_maps[arg.get()].max_jump_to(height_maps[node.get()]);
44223

45-
if (getenv("NGRAPH_VISUALIZE_EDGE_LABELS") != nullptr)
224+
if (arg->description() == "Constant" || arg->description() == "Parameter")
46225
{
47-
size_t output = 0;
48-
if (auto goe = dynamic_pointer_cast<op::GetOutputElement>(node))
49-
{
50-
output = goe->get_n();
51-
}
52-
stringstream label_edge;
53-
label_edge << "[label=\" " << output << " -> " << i << " \"]";
54-
m_ss << label_edge.str();
226+
auto clone_name = "CLONE_" + to_string(fake_node_ctr);
227+
auto color = (arg->description() == "Parameter" ? "blue" : "black");
228+
m_ss << " " << clone_name
229+
<< "[shape=\"box\" style=\"dashed,filled\" color=\"" << color
230+
<< "\" fillcolor=\"white\" label=\"" << arg->get_name() << "\"]\n";
231+
m_ss << " " << clone_name << " -> " << node->get_name()
232+
<< label_edge(arg, node, arg_index, jump_distance) << "\n";
233+
fake_node_ctr++;
55234
}
235+
else if (jump_distance > max_jump_distance)
236+
{
237+
m_ss << add_attributes(arg);
238+
m_ss << add_attributes(node);
239+
auto recv_node_name = "RECV_" + to_string(fake_node_ctr);
240+
auto send_node_name = "SEND_" + to_string(fake_node_ctr);
56241

57-
m_ss << ";\n";
58-
i++;
242+
m_ss << " " << recv_node_name << "[shape=\"box\" style=\"solid,filled\" "
243+
"fillcolor=\"#ffcccc\" label=\"Receive["
244+
<< arg->get_name() << "]\"]\n";
245+
m_ss << " " << send_node_name << "[shape=\"box\" style=\"solid,filled\" "
246+
"fillcolor=\"#ccffcc\" label=\"Send["
247+
<< node->get_name() << "]\"]\n";
248+
249+
m_ss << " " << arg->get_name() << " -> " << send_node_name
250+
<< label_edge(arg, node, arg_index, jump_distance) << "\n";
251+
m_ss << " " << recv_node_name << " -> " << node->get_name()
252+
<< label_edge(arg, node, arg_index, jump_distance) << "\n";
253+
fake_node_ctr++;
254+
}
255+
else
256+
{
257+
m_ss << add_attributes(arg);
258+
m_ss << add_attributes(node);
259+
m_ss << " " << arg->get_name() << " -> " << node->get_name()
260+
<< label_edge(arg, node, arg_index, jump_distance) << "\n";
261+
}
262+
arg_index++;
59263
}
60264
});
61265
}
@@ -86,30 +290,22 @@ string pass::VisualizeTree::add_attributes(shared_ptr<Node> node)
86290
string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
87291
{
88292
vector<string> attributes;
89-
if (node->is_parameter() || node->is_output())
293+
attributes.push_back("shape=box");
294+
295+
if (node->is_output())
90296
{
91-
attributes.push_back("shape=box");
92-
if (node->is_parameter())
93-
{
94-
attributes.push_back("color=blue");
95-
attributes.push_back("penwidth=1.5");
96-
}
97-
if (node->is_output())
98-
{
99-
attributes.push_back("color=crimson");
100-
attributes.push_back("penwidth=1.5");
101-
}
297+
attributes.push_back("color=crimson");
298+
attributes.push_back("penwidth=1.5");
102299
}
103300
else
104301
{
105-
attributes.push_back("shape=ellipse");
106302
attributes.push_back("color=black");
107303
}
108304

109305
// Construct the label attribute
110306
{
111307
stringstream label;
112-
label << "label=\"" << node->get_friendly_name();
308+
label << "label=\"" << node->get_name();
113309

114310
static const char* nvtos = getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_SHAPES");
115311
if (nvtos != nullptr)
@@ -156,7 +352,7 @@ string pass::VisualizeTree::get_file_ext()
156352
const char* format = getenv("NGRAPH_VISUALIZE_TREE_OUTPUT_FORMAT");
157353
if (!format)
158354
{
159-
format = "png";
355+
format = "dot";
160356
}
161357

162358
if (format[0] == '.')
@@ -178,11 +374,12 @@ void pass::VisualizeTree::render() const
178374
out << "}\n";
179375
out.close();
180376

181-
if (!m_dot_only)
377+
if (!m_dot_only && get_file_ext() != "dot")
182378
{
183379
#ifndef _WIN32
184380
stringstream ss;
185-
ss << "dot -T" << get_file_ext() << " " << dot_file << " -o " << m_name;
381+
ss << "dot -T" << get_file_ext() << " " << dot_file << " -o" << m_name << "."
382+
<< get_file_ext();
186383
auto cmd = ss.str();
187384
auto stream = popen(cmd.c_str(), "r");
188385
if (stream)

0 commit comments

Comments
 (0)