@@ -43,6 +43,25 @@ class ngraph::pass::Manager
4343
4444 template <typename T, class ... Args>
4545 void register_pass (Args&&... args)
46+ {
47+ push_pass<T>(std::forward<Args>(args)...);
48+ if (m_per_pass_validation)
49+ {
50+ push_pass<Validate>();
51+ }
52+ }
53+
54+ void run_passes (std::shared_ptr<Function>, bool transitive = true );
55+
56+ ManagerState& get_state ();
57+ PassConfig& get_pass_config () { return m_pass_config; }
58+ void set_pass_config (const PassConfig& pass_config) { m_pass_config = pass_config; }
59+ void set_pass_visualization (bool new_state) { m_visualize = new_state; }
60+ void set_pass_serialization (bool new_state) { m_serialize = new_state; }
61+ void set_per_pass_validation (bool new_state) { m_per_pass_validation = new_state; }
62+ private:
63+ template <typename T, class ... Args>
64+ void push_pass (Args&&... args)
4665 {
4766 static_assert (std::is_base_of<pass::PassBase, T>::value, " pass not derived from pass base" );
4867 auto pass = std::make_shared<T>(std::forward<Args>(args)...);
@@ -61,23 +80,8 @@ class ngraph::pass::Manager
6180 m_pass_names.push_back (typeid (T).name ());
6281#endif
6382 }
64- if (m_per_pass_validation)
65- {
66- auto validate = std::make_shared<Validate>();
67- auto validate_base = std::static_pointer_cast<PassBase>(validate);
68- m_pass_list.push_back (validate_base);
69- }
7083 }
7184
72- void run_passes (std::shared_ptr<Function>, bool transitive = true );
73-
74- ManagerState& get_state ();
75- PassConfig& get_pass_config () { return m_pass_config; }
76- void set_pass_config (const PassConfig& pass_config) { m_pass_config = pass_config; }
77- void set_pass_visualization (bool new_state) { m_visualize = new_state; }
78- void set_pass_serialization (bool new_state) { m_serialize = new_state; }
79- void set_per_pass_validation (bool new_state) { m_per_pass_validation = new_state; }
80- private:
8185 std::vector<std::string> m_pass_names;
8286 std::vector<std::shared_ptr<PassBase>> m_pass_list;
8387 ManagerState m_state;
0 commit comments