Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 2612994

Browse files
authored
Merge pull request #3403 from NervanaSystems/aprocter/visualize-crash-fix-r025
Cherry-pick "Fix crash when NGRAPH_ENABLE_{VISUALIZE,SERIALIZE}_TRACING=1" to r0.25
2 parents 9937f8b + 6696c14 commit 2612994

File tree

2 files changed

+45
-15
lines changed

2 files changed

+45
-15
lines changed

src/ngraph/pass/manager.hpp

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,25 @@ class ngraph::pass::Manager
4343

4444
template <typename T, class... Args>
4545
void register_pass(Args&&... args)
46+
{
47+
push_pass<T>(std::forward<Args>(args)...);
48+
if (m_per_pass_validation)
49+
{
50+
push_pass<Validate>();
51+
}
52+
}
53+
54+
void run_passes(std::shared_ptr<Function>, bool transitive = true);
55+
56+
ManagerState& get_state();
57+
PassConfig& get_pass_config() { return m_pass_config; }
58+
void set_pass_config(const PassConfig& pass_config) { m_pass_config = pass_config; }
59+
void set_pass_visualization(bool new_state) { m_visualize = new_state; }
60+
void set_pass_serialization(bool new_state) { m_serialize = new_state; }
61+
void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
62+
private:
63+
template <typename T, class... Args>
64+
void push_pass(Args&&... args)
4665
{
4766
static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base");
4867
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
@@ -61,23 +80,8 @@ class ngraph::pass::Manager
6180
m_pass_names.push_back(typeid(T).name());
6281
#endif
6382
}
64-
if (m_per_pass_validation)
65-
{
66-
auto validate = std::make_shared<Validate>();
67-
auto validate_base = std::static_pointer_cast<PassBase>(validate);
68-
m_pass_list.push_back(validate_base);
69-
}
7083
}
7184

72-
void run_passes(std::shared_ptr<Function>, bool transitive = true);
73-
74-
ManagerState& get_state();
75-
PassConfig& get_pass_config() { return m_pass_config; }
76-
void set_pass_config(const PassConfig& pass_config) { m_pass_config = pass_config; }
77-
void set_pass_visualization(bool new_state) { m_visualize = new_state; }
78-
void set_pass_serialization(bool new_state) { m_serialize = new_state; }
79-
void set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; }
80-
private:
8185
std::vector<std::string> m_pass_names;
8286
std::vector<std::shared_ptr<PassBase>> m_pass_list;
8387
ManagerState m_state;

test/pass_manager.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,29 @@ TEST(pass_manager, add)
4141
EXPECT_EQ(node_count, sorted.size());
4242
EXPECT_TRUE(validate_list(sorted));
4343
}
44+
45+
namespace
46+
{
47+
class DummyPass : public pass::FunctionPass
48+
{
49+
public:
50+
DummyPass()
51+
: FunctionPass()
52+
{
53+
}
54+
bool run_on_function(std::shared_ptr<ngraph::Function> f) override { return false; }
55+
};
56+
}
57+
58+
// Regression test: We've had an issue in the past where enabling per-pass validation and
59+
// per-pass serialization at the same time causes a crash.
60+
TEST(pass_manager, serialize_with_revalidate_does_not_crash)
61+
{
62+
pass::Manager pass_manager;
63+
pass_manager.set_per_pass_validation(true);
64+
pass_manager.set_pass_serialization(true);
65+
pass_manager.register_pass<DummyPass>();
66+
67+
auto graph = make_test_graph();
68+
pass_manager.run_passes(graph);
69+
}

0 commit comments

Comments
 (0)