Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 0ea01ea

Browse files
author
Adam Procter
authored
Merge pull request #1662 from NervanaSystems/aprocter/cherry-picks
Cherry-pick "Common pass registration for codegen and Dex (#1642)" to r0.8
2 parents 2822885 + f117269 commit 0ea01ea

File tree

4 files changed

+28
-36
lines changed

4 files changed

+28
-36
lines changed

src/ngraph/runtime/cpu/cpu_external_function.cpp

Lines changed: 17 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -368,20 +368,8 @@ static void
368368
writer << "}\n";
369369
}
370370

371-
void runtime::cpu::CPU_ExternalFunction::compile()
371+
void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Manager& pass_manager)
372372
{
373-
if (m_is_compiled)
374-
{
375-
return;
376-
}
377-
378-
m_mkldnn_emitter.reset(new MKLDNNEmitter());
379-
380-
ngraph::pass::Manager pass_manager;
381-
382-
// nv_cwi is required only by some frontends
383-
// in which case they should run this pass(CPUWorkspaceInsertion) explicitly
384-
NodeVector nv_cwi;
385373
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
386374
pass_manager.register_pass<ngraph::pass::NopElimination>();
387375
// TODO (pruthvi): Enable all the disabeled RNN fusion graph pass after fixing
@@ -396,11 +384,25 @@ void runtime::cpu::CPU_ExternalFunction::compile()
396384
pass_manager.register_pass<ngraph::pass::CoreFusion>();
397385
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
398386
pass_manager.register_pass<runtime::cpu::pass::CPUCollapseDims>();
399-
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi);
387+
NodeVector nv_cwi; // We dont need CPUWorkspaceInsertion to return list of indices
388+
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi, false);
400389
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
401390
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
402391
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
403392
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
393+
}
394+
395+
void runtime::cpu::CPU_ExternalFunction::compile()
396+
{
397+
if (m_is_compiled)
398+
{
399+
return;
400+
}
401+
402+
m_mkldnn_emitter.reset(new MKLDNNEmitter());
403+
404+
ngraph::pass::Manager pass_manager;
405+
register_common_passes(pass_manager);
404406
unordered_map<Node*, Node*> node_function_map;
405407
string common_function_string;
406408
auto femitter = bind(&ngraph::runtime::cpu::CPU_ExternalFunction::emit_op_as_function,
@@ -1132,27 +1134,8 @@ void runtime::cpu::CPU_ExternalFunction::build()
11321134
m_mkldnn_emitter.reset(new MKLDNNEmitter());
11331135

11341136
ngraph::pass::Manager pass_manager;
1137+
register_common_passes(pass_manager);
11351138

1136-
// nv_cwi is required only by some frontends
1137-
// in which case they should run this pass(CPUWorkspaceInsertion) explicitly
1138-
NodeVector nv_cwi;
1139-
pass_manager.register_pass<ngraph::pass::NopElimination>();
1140-
// TODO (pruthvi): Enable all the disabeled RNN fusion graph pass after fixing
1141-
// failing mxnet unit tests.
1142-
// pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
1143-
// pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
1144-
// pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
1145-
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
1146-
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
1147-
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
1148-
pass_manager.register_pass<ngraph::pass::CoreFusion>();
1149-
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
1150-
pass_manager.register_pass<runtime::cpu::pass::CPUCollapseDims>();
1151-
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi);
1152-
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
1153-
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
1154-
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
1155-
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
11561139
pass_manager.register_pass<ngraph::pass::Liveness>();
11571140
pass_manager.register_pass<ngraph::pass::MemoryLayout>(size_t(s_memory_pool_alignment), true);
11581141
pass_manager.run_passes(m_function, false);

src/ngraph/runtime/cpu/cpu_external_function.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#endif
3737

3838
#include "ngraph/function.hpp"
39+
#include "ngraph/pass/manager.hpp"
3940
#include "ngraph/runtime/cpu/cpu_call_frame.hpp"
4041
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
4142
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
@@ -139,6 +140,9 @@ namespace ngraph
139140
#endif
140141

141142
private:
143+
// Register passes that are common to codegen and DEX
144+
void register_common_passes(ngraph::pass::Manager& pass_manager);
145+
142146
// For non-destructive passthrough kernels, propagate function
143147
// input buffers to internal ops
144148
void propagate_in_place_input(ngraph::descriptor::Output* output,

src/ngraph/runtime/cpu/pass/cpu_workspace_insertion.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@ bool runtime::cpu::pass::CPUWorkspaceInsertion::transform(pattern::Matcher& m)
168168
m_max_pool->get_padding_above());
169169

170170
ngraph::replace_node(m_max_pool_bprop, max_pool_with_indices_bprop);
171-
m_indices_list.push_back(max_pool_with_indices_indices);
171+
if (m_return_indices)
172+
{
173+
m_indices_list.push_back(max_pool_with_indices_indices);
174+
}
172175
return true;
173176
}

src/ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,17 @@ namespace ngraph
3636
class ngraph::runtime::cpu::pass::CPUWorkspaceInsertion : public ngraph::pass::FunctionPass
3737
{
3838
public:
39-
CPUWorkspaceInsertion(ngraph::NodeVector& indices_list)
39+
CPUWorkspaceInsertion(ngraph::NodeVector& indices_list, bool return_indices = true)
4040
: FunctionPass()
4141
, m_indices_list(indices_list)
42+
, m_return_indices(return_indices)
4243
{
4344
}
4445

4546
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
4647

4748
private:
4849
ngraph::NodeVector& m_indices_list;
50+
bool m_return_indices;
4951
bool transform(ngraph::pattern::Matcher& m);
5052
};

0 commit comments

Comments
 (0)