Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 3340cbc

Browse files
authored
Migrate #3028 to r0.21 (#3032)
1 parent 002aef5 commit 3340cbc

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

src/ngraph/runtime/cpu/mkldnn_emitter.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -937,8 +937,9 @@ namespace ngraph
937937

938938
mkldnn::algorithm convolution_algo = mkldnn_utils::get_conv_algo();
939939

940-
if (node->get_input_element_type(0) != element::f32 &&
941-
convolution_algo != mkldnn::algorithm::convolution_direct)
940+
if ((node->get_input_element_type(0) != element::f32 &&
941+
convolution_algo != mkldnn::algorithm::convolution_direct) ||
942+
convolution->get_argument(0)->get_shape()[1] <= 8)
942943
{
943944
convolution_algo = mkldnn::algorithm::convolution_direct;
944945
}

src/ngraph/runtime/cpu/pass/cpu_layout.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,13 @@ namespace ngraph
395395
std::unique_ptr<convolution_forward::desc> fwd_desc{nullptr};
396396
auto convolution_algo = mkldnn_utils::get_conv_algo();
397397

398-
if (node->get_input_element_type(0) != element::f32 &&
399-
convolution_algo != mkldnn::algorithm::convolution_direct)
398+
// I/p channels less than 8 & convolution_algo = convolution_auto
399+
// forces src format to be nChw16c & the weight format to be
400+
// OIhw16i16o which invokes mkldnn reference implementation of conv
401+
// which crashes as it has no support for post ops
402+
if ((node->get_input_element_type(0) != element::f32 &&
403+
convolution_algo != mkldnn::algorithm::convolution_direct) ||
404+
arg0_shape[1] <= 8)
400405
{
401406
convolution_algo = mkldnn::algorithm::convolution_direct;
402407
}

0 commit comments

Comments
 (0)