@@ -644,9 +644,31 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
644644 std::shared_ptr<Node> nn;
645645
646646 auto conv = std::dynamic_pointer_cast<op::Convolution>(m.match_root ()->get_input_op (0 ));
647- auto bias = m.match_root ()->get_input_op (1 )->get_input_op (0 );
648- auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias (conv, bias));
649- return conv_bias;
647+ if (conv->get_input_shape (0 ).size () == 4 )
648+ {
649+ auto bias = m.match_root ()->get_input_op (1 )->get_input_op (0 );
650+ auto bias_shape = bias->get_shape ();
651+ if (bias_shape.size () > 1 )
652+ {
653+ NGRAPH_DEBUG
654+ << " mpattern = " << m.match_root ()->get_name ()
655+ << " conv_bias bias shape != 1, requires reshape to match filter count." ;
656+ ngraph::AxisVector order (bias_shape.size ());
657+ std::iota (begin (order), end (order), 0 );
658+ auto bias_reshape =
659+ std::make_shared<op::Reshape>(bias, order, Shape{conv->get_input_shape (1 )[0 ]});
660+ auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias (conv, bias_reshape));
661+ return conv_bias;
662+ }
663+ else
664+ {
665+ auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias (conv, bias));
666+ return conv_bias;
667+ }
668+ }
669+ NGRAPH_DEBUG << " mpattern = " << m.match_root ()->get_name ()
670+ << " conv_bias fusion skipped due to input rank size != 4." ;
671+ return std::shared_ptr<Node>(nullptr );
650672 };
651673
652674 auto m = std::make_shared<ngraph::pattern::Matcher>(p_conv_bias, callback);
0 commit comments