@@ -42,124 +42,147 @@ using namespace std;
4242
4343#define TI (x ) std::type_index(typeid (x))
4444
45- static const std::unordered_set<std::type_index> s_op_registry{
46- TI (ngraph::op::Add),
47- TI (ngraph::op::AvgPool),
48- TI (ngraph::op::AvgPoolBackprop),
49- TI (ngraph::op::BatchNorm),
50- TI (ngraph::op::BatchNormBackprop),
51- TI (ngraph::op::Concat),
52- TI (ngraph::op::Convolution),
53- TI (ngraph::op::ConvolutionBackpropData),
54- TI (ngraph::op::ConvolutionBackpropFilters),
55- TI (ngraph::op::ConvolutionBias),
56- TI (ngraph::op::ConvolutionRelu),
57- TI (ngraph::op::ConvolutionBiasBackpropFiltersBias),
58- TI (ngraph::op::MaxPool),
59- TI (ngraph::op::MaxPoolBackprop),
60- TI (ngraph::op::Relu),
61- TI (ngraph::op::ReluBackprop),
62- TI (ngraph::op::Reshape)};
63-
64- // Mapping from POD types to MKLDNN data types
65- static const std::map<element::Type, const mkldnn::memory::data_type> s_mkldnn_data_type_map{
66- {element::boolean, mkldnn::memory::data_type::s8},
67- {element::f32 , mkldnn::memory::data_type::f32 },
68- {element::f64 , mkldnn::memory::data_type::data_undef},
69- {element::i8 , mkldnn::memory::data_type::s8},
70- {element::i16 , mkldnn::memory::data_type::s16},
71- {element::i32 , mkldnn::memory::data_type::s32},
72- {element::i64 , mkldnn::memory::data_type::data_undef},
73- {element::u8 , mkldnn::memory::data_type::u8 },
74- {element::u16 , mkldnn::memory::data_type::data_undef},
75- {element::u32 , mkldnn::memory::data_type::data_undef},
76- {element::u64 , mkldnn::memory::data_type::data_undef}};
77-
78- static const std::map<element::Type, const std::string> s_mkldnn_data_type_string_map{
79- {element::boolean, " mkldnn::memory::data_type::s8" },
80- {element::f32 , " mkldnn::memory::data_type::f32" },
81- {element::f64 , " mkldnn::memory::data_type::data_undef" },
82- {element::i8 , " mkldnn::memory::data_type::s8" },
83- {element::i16 , " mkldnn::memory::data_type::s16" },
84- {element::i32 , " mkldnn::memory::data_type::s32" },
85- {element::i64 , " mkldnn::memory::data_type::data_undef" },
86- {element::u8 , " mkldnn::memory::data_type::u8" },
87- {element::u16 , " mkldnn::memory::data_type::data_undef" },
88- {element::u32 , " mkldnn::memory::data_type::data_undef" },
89- {element::u64 , " mkldnn::memory::data_type::data_undef" }};
90-
91- // TODO (jbobba): Add the rest of memory formats to this map as well
92- static const std::map<memory::format, const std::string> s_mkldnn_format_string_map{
93- {memory::format::format_undef, " memory::format::format_undef" },
94- {memory::format::any, " memory::format::any" },
95- {memory::format::blocked, " memory::format::blocked" },
96- {memory::format::x, " memory::format::x" },
97- {memory::format::nc, " memory::format::nc" },
98- {memory::format::nchw, " memory::format::nchw" },
99- {memory::format::nhwc, " memory::format::nhwc" },
100- {memory::format::chwn, " memory::format::chwn" },
101- {memory::format::nChw8c, " memory::format::nChw8c" },
102- {memory::format::nChw16c, " memory::format::nChw16c" },
103- {memory::format::ncdhw, " memory::format::ndhwc" },
104- {memory::format::ncdhw, " memory::format::ndhwc" },
105- {memory::format::nCdhw16c, " memory::format::nCdhw16c" },
106- {memory::format::oi, " memory::format::oi" },
107- {memory::format::io, " memory::format::io" },
108- {memory::format::oihw, " memory::format::oihw" },
109- {memory::format::ihwo, " memory::format::ihwo" },
110- {memory::format::hwio, " memory::format::hwio" },
111- // TODO (nishant): Uncomment after the next release of mkl-dnn"
112- // {memory::format::dhwio, "memory::format::dhwio"},
113- {memory::format::oidhw, " memory::format::oidhw" },
114- {memory::format::OIdhw16i16o, " memory::format::OIdhw16i16o" },
115- {memory::format::OIdhw16o16i, " memory::format::OIdhw16o16i" },
116- {memory::format::Oidhw16o, " memory::format::Oidhw16o" },
117- {memory::format::Odhwi16o, " memory::format::Odhwi16o" },
118- {memory::format::oIhw8i, " memory::format::oIhw8i" },
119- {memory::format::oIhw16i, " memory::format::oIhw16i" },
120- {memory::format::OIhw8i8o, " memory::format::OIhw8i8o" },
121- {memory::format::OIhw16i16o, " memory::format::OIhw16i16o" },
122- {memory::format::IOhw16o16i, " memory::format::IOhw16o16i" },
123- {memory::format::OIhw8o8i, " memory::format::OIhw8o8i" },
124- {memory::format::OIhw16o16i, " memory::format::OIhw16o16i" },
125- {memory::format::Oihw8o, " memory::format::Oihw8o" },
126- {memory::format::Oihw16o, " memory::format::Oihw16o" },
127- {memory::format::Ohwi8o, " memory::format::Ohwi8o" },
128- {memory::format::Ohwi16o, " memory::format::Ohwi16o" },
129- {memory::format::OhIw16o4i, " memory::format::OhIw16o4i" },
130- {memory::format::tnc, " memory::format::tnc" },
131- {memory::format::ldsnc, " memory::format::ldsnc" },
132- {memory::format::ldigo, " memory::format::ldigo" },
133- {memory::format::ldgo, " memory::format::ldgo" },
134- };
135-
136- static const std::set<memory::format> s_filter_formats{
137- memory::format::oihw,
138- memory::format::ihwo,
139- memory::format::hwio,
140- // TODO (nishant): Uncomment after the next release of mkl-dnn"
141- // memory::format::dhwio,
142- memory::format::oidhw,
143- memory::format::OIdhw16i16o,
144- memory::format::OIdhw16o16i,
145- memory::format::Oidhw16o,
146- memory::format::Odhwi16o,
147- // memory::format::oIhw8i, // These currently map to nChw8c and nChw16c
148- // memory::format::oIhw16i,
149- memory::format::OIhw8i8o,
150- memory::format::OIhw16i16o,
151- memory::format::IOhw16o16i,
152- memory::format::OIhw8o8i,
153- memory::format::OIhw16o16i,
154- memory::format::Oihw8o,
155- memory::format::Oihw16o,
156- memory::format::Ohwi8o,
157- memory::format::Ohwi16o,
158- memory::format::OhIw16o4i};
45+ std::unordered_set<std::type_index>& runtime::cpu::mkldnn_utils::get_op_registry ()
46+ {
47+ static std::unordered_set<std::type_index> s_op_registry{
48+ TI (ngraph::op::Add),
49+ TI (ngraph::op::AvgPool),
50+ TI (ngraph::op::AvgPoolBackprop),
51+ TI (ngraph::op::BatchNorm),
52+ TI (ngraph::op::BatchNormBackprop),
53+ TI (ngraph::op::Concat),
54+ TI (ngraph::op::Convolution),
55+ TI (ngraph::op::ConvolutionBackpropData),
56+ TI (ngraph::op::ConvolutionBackpropFilters),
57+ TI (ngraph::op::ConvolutionBias),
58+ TI (ngraph::op::ConvolutionRelu),
59+ TI (ngraph::op::ConvolutionBiasBackpropFiltersBias),
60+ TI (ngraph::op::MaxPool),
61+ TI (ngraph::op::MaxPoolBackprop),
62+ TI (ngraph::op::Relu),
63+ TI (ngraph::op::ReluBackprop),
64+ TI (ngraph::op::Reshape)};
65+ return s_op_registry;
66+ }
67+
68+ std::map<element::Type, const mkldnn::memory::data_type>&
69+ runtime::cpu::mkldnn_utils::get_mkldnn_data_type_map ()
70+ {
71+ // Mapping from POD types to MKLDNN data types
72+ static std::map<element::Type, const mkldnn::memory::data_type> s_mkldnn_data_type_map = {
73+ {element::boolean, mkldnn::memory::data_type::s8},
74+ {element::f32 , mkldnn::memory::data_type::f32 },
75+ {element::f64 , mkldnn::memory::data_type::data_undef},
76+ {element::i8 , mkldnn::memory::data_type::s8},
77+ {element::i16 , mkldnn::memory::data_type::s16},
78+ {element::i32 , mkldnn::memory::data_type::s32},
79+ {element::i64 , mkldnn::memory::data_type::data_undef},
80+ {element::u8 , mkldnn::memory::data_type::u8 },
81+ {element::u16 , mkldnn::memory::data_type::data_undef},
82+ {element::u32 , mkldnn::memory::data_type::data_undef},
83+ {element::u64 , mkldnn::memory::data_type::data_undef},
84+ };
85+ return s_mkldnn_data_type_map;
86+ }
87+
88+ std::map<element::Type, const std::string>&
89+ runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string_map ()
90+ {
91+ static std::map<element::Type, const std::string> s_mkldnn_data_type_string_map{
92+ {element::boolean, " mkldnn::memory::data_type::s8" },
93+ {element::f32 , " mkldnn::memory::data_type::f32" },
94+ {element::f64 , " mkldnn::memory::data_type::data_undef" },
95+ {element::i8 , " mkldnn::memory::data_type::s8" },
96+ {element::i16 , " mkldnn::memory::data_type::s16" },
97+ {element::i32 , " mkldnn::memory::data_type::s32" },
98+ {element::i64 , " mkldnn::memory::data_type::data_undef" },
99+ {element::u8 , " mkldnn::memory::data_type::u8" },
100+ {element::u16 , " mkldnn::memory::data_type::data_undef" },
101+ {element::u32 , " mkldnn::memory::data_type::data_undef" },
102+ {element::u64 , " mkldnn::memory::data_type::data_undef" }};
103+ return s_mkldnn_data_type_string_map;
104+ }
159105
106+ std::map<memory::format, const std::string>&
107+ runtime::cpu::mkldnn_utils::get_mkldnn_format_string_map ()
108+ {
109+ // TODO (jbobba): Add the rest of memory formats to this map as well
110+ static std::map<memory::format, const std::string> s_mkldnn_format_string_map{
111+ {memory::format::format_undef, " memory::format::format_undef" },
112+ {memory::format::any, " memory::format::any" },
113+ {memory::format::blocked, " memory::format::blocked" },
114+ {memory::format::x, " memory::format::x" },
115+ {memory::format::nc, " memory::format::nc" },
116+ {memory::format::nchw, " memory::format::nchw" },
117+ {memory::format::nhwc, " memory::format::nhwc" },
118+ {memory::format::chwn, " memory::format::chwn" },
119+ {memory::format::nChw8c, " memory::format::nChw8c" },
120+ {memory::format::nChw16c, " memory::format::nChw16c" },
121+ {memory::format::ncdhw, " memory::format::ndhwc" },
122+ {memory::format::ncdhw, " memory::format::ndhwc" },
123+ {memory::format::nCdhw16c, " memory::format::nCdhw16c" },
124+ {memory::format::oi, " memory::format::oi" },
125+ {memory::format::io, " memory::format::io" },
126+ {memory::format::oihw, " memory::format::oihw" },
127+ {memory::format::ihwo, " memory::format::ihwo" },
128+ {memory::format::hwio, " memory::format::hwio" },
129+ // TODO (nishant): Uncomment after the next release of mkl-dnn"
130+ // {memory::format::dhwio, "memory::format::dhwio"},
131+ {memory::format::oidhw, " memory::format::oidhw" },
132+ {memory::format::OIdhw16i16o, " memory::format::OIdhw16i16o" },
133+ {memory::format::OIdhw16o16i, " memory::format::OIdhw16o16i" },
134+ {memory::format::Oidhw16o, " memory::format::Oidhw16o" },
135+ {memory::format::Odhwi16o, " memory::format::Odhwi16o" },
136+ {memory::format::oIhw8i, " memory::format::oIhw8i" },
137+ {memory::format::oIhw16i, " memory::format::oIhw16i" },
138+ {memory::format::OIhw8i8o, " memory::format::OIhw8i8o" },
139+ {memory::format::OIhw16i16o, " memory::format::OIhw16i16o" },
140+ {memory::format::IOhw16o16i, " memory::format::IOhw16o16i" },
141+ {memory::format::OIhw8o8i, " memory::format::OIhw8o8i" },
142+ {memory::format::OIhw16o16i, " memory::format::OIhw16o16i" },
143+ {memory::format::Oihw8o, " memory::format::Oihw8o" },
144+ {memory::format::Oihw16o, " memory::format::Oihw16o" },
145+ {memory::format::Ohwi8o, " memory::format::Ohwi8o" },
146+ {memory::format::Ohwi16o, " memory::format::Ohwi16o" },
147+ {memory::format::OhIw16o4i, " memory::format::OhIw16o4i" },
148+ {memory::format::tnc, " memory::format::tnc" },
149+ {memory::format::ldsnc, " memory::format::ldsnc" },
150+ {memory::format::ldigo, " memory::format::ldigo" },
151+ {memory::format::ldgo, " memory::format::ldgo" },
152+ };
153+ return s_mkldnn_format_string_map;
154+ }
155+
156+ std::set<memory::format>& runtime::cpu::mkldnn_utils::get_filter_formats ()
157+ {
158+ static std::set<memory::format> s_filter_formats{
159+ memory::format::oihw,
160+ memory::format::ihwo,
161+ memory::format::hwio,
162+ // TODO (nishant): Uncomment after the next release of mkl-dnn"
163+ // memory::format::dhwio,
164+ memory::format::oidhw,
165+ memory::format::OIdhw16i16o,
166+ memory::format::OIdhw16o16i,
167+ memory::format::Oidhw16o,
168+ memory::format::Odhwi16o,
169+ // memory::format::oIhw8i, // These currently map to nChw8c and nChw16c
170+ // memory::format::oIhw16i,
171+ memory::format::OIhw8i8o,
172+ memory::format::OIhw16i16o,
173+ memory::format::IOhw16o16i,
174+ memory::format::OIhw8o8i,
175+ memory::format::OIhw16o16i,
176+ memory::format::Oihw8o,
177+ memory::format::Oihw16o,
178+ memory::format::Ohwi8o,
179+ memory::format::Ohwi16o,
180+ memory::format::OhIw16o4i};
181+ return s_filter_formats;
182+ }
160183bool runtime::cpu::mkldnn_utils::IsMKLDNNOp (ngraph::Node& op)
161184{
162- return (s_op_registry .find (TI (op)) != s_op_registry .end ());
185+ return (get_op_registry () .find (TI (op)) != get_op_registry () .end ());
163186}
164187
165188mkldnn::memory::format runtime::cpu::mkldnn_utils::CreateNativeDataFormat (
@@ -183,27 +206,31 @@ mkldnn::memory::format runtime::cpu::mkldnn_utils::CreateNativeDataFormat(const
183206const std::string&
184207 runtime::cpu::mkldnn_utils::get_mkldnn_data_type_string (const ngraph::element::Type& type)
185208{
186- auto it = s_mkldnn_data_type_string_map.find (type);
187- if (it == s_mkldnn_data_type_string_map.end () || it->second .empty ())
188- throw ngraph_error (" No MKLDNN data type exists for the given element type" );
209+ auto it = get_mkldnn_data_type_string_map ().find (type);
210+ if (it == get_mkldnn_data_type_string_map ().end () || it->second .empty ())
211+ {
212+ throw ngraph_error (" No MKLDNN data type exists for the given element type" +
213+ type.c_type_string ());
214+ }
189215 return it->second ;
190216}
191217
192218mkldnn::memory::data_type
193219 runtime::cpu::mkldnn_utils::get_mkldnn_data_type (const ngraph::element::Type& type)
194220{
195- auto it = s_mkldnn_data_type_map .find (type);
196- if (it == s_mkldnn_data_type_map .end ())
221+ auto it = get_mkldnn_data_type_map () .find (type);
222+ if (it == get_mkldnn_data_type_map () .end ())
197223 {
198- throw ngraph_error (" No MKLDNN data type exists for the given element type" );
224+ throw ngraph_error (" No MKLDNN data type exists for the given element type" +
225+ type.c_type_string ());
199226 }
200227 return it->second ;
201228}
202229
203230const std::string& runtime::cpu::mkldnn_utils::get_mkldnn_format_string (memory::format fmt)
204231{
205- auto it = s_mkldnn_format_string_map .find (fmt);
206- if (it == s_mkldnn_format_string_map .end ())
232+ auto it = get_mkldnn_format_string_map () .find (fmt);
233+ if (it == get_mkldnn_format_string_map () .end ())
207234 throw ngraph_error (" No MKLDNN format exists for the given format type " +
208235 std::to_string (fmt));
209236 return it->second ;
@@ -250,12 +277,13 @@ bool runtime::cpu::mkldnn_utils::can_create_mkldnn_md(const Shape& dims,
250277 const Strides& strides,
251278 const ngraph::element::Type type)
252279{
253- auto it = s_mkldnn_data_type_map .find (type);
280+ auto it = get_mkldnn_data_type_map () .find (type);
254281 if (dims.size () == 0 )
255282 {
256283 return false ;
257284 }
258- if (it == s_mkldnn_data_type_map.end () || it->second == mkldnn::memory::data_type::data_undef)
285+ if (it == get_mkldnn_data_type_map ().end () ||
286+ it->second == mkldnn::memory::data_type::data_undef)
259287 {
260288 return false ;
261289 }
@@ -450,7 +478,7 @@ bool runtime::cpu::mkldnn_utils::compare_mkldnn_mds(const mkldnn::memory::desc&
450478
451479bool runtime::cpu::mkldnn_utils::is_mkldnn_filter_format (mkldnn::memory::format fmt)
452480{
453- if (s_filter_formats .find (fmt) != s_filter_formats .end ())
481+ if (get_filter_formats () .find (fmt) != get_filter_formats () .end ())
454482 {
455483 return true ;
456484 }
0 commit comments