This repository was archived by the owner on Jan 3, 2023. It is now read-only.
File tree Expand file tree Collapse file tree 2 files changed +10
-4
lines changed Expand file tree Collapse file tree 2 files changed +10
-4
lines changed Original file line number Diff line number Diff line change @@ -937,8 +937,9 @@ namespace ngraph
937937
938938 mkldnn::algorithm convolution_algo = mkldnn_utils::get_conv_algo ();
939939
940- if (node->get_input_element_type (0 ) != element::f32 &&
941- convolution_algo != mkldnn::algorithm::convolution_direct)
940+ if ((node->get_input_element_type (0 ) != element::f32 &&
941+ convolution_algo != mkldnn::algorithm::convolution_direct) ||
942+ convolution->get_argument (0 )->get_shape ()[1 ] <= 8 )
942943 {
943944 convolution_algo = mkldnn::algorithm::convolution_direct;
944945 }
Original file line number Diff line number Diff line change @@ -395,8 +395,13 @@ namespace ngraph
395395 std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr };
396396 auto convolution_algo = mkldnn_utils::get_conv_algo ();
397397
398- if (node->get_input_element_type (0 ) != element::f32 &&
399- convolution_algo != mkldnn::algorithm::convolution_direct)
398+ // I/p channels less than 8 & convolution_algo = convolution_auto
399+ // forces src format to be nChw16c & the weight format to be
400+ // OIhw16i16o which invokes mkldnn reference implementation of conv
401+ // which crashes as it has no support for post ops
402+ if ((node->get_input_element_type (0 ) != element::f32 &&
403+ convolution_algo != mkldnn::algorithm::convolution_direct) ||
404+ arg0_shape[1 ] <= 8 )
400405 {
401406 convolution_algo = mkldnn::algorithm::convolution_direct;
402407 }
You can’t perform that action at this time.
0 commit comments