Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit 19899f4

Browse files
louisfengMatthew Brookhart
authored andcommitted
fixed conv+bias pattern match causing mxnet tests to fail. (#647)
1 parent 5ec9e25 commit 19899f4

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

src/ngraph/runtime/cpu/pass/cpu_fusion.cpp

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -644,9 +644,31 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_conv_bias()
644644
std::shared_ptr<Node> nn;
645645

646646
auto conv = std::dynamic_pointer_cast<op::Convolution>(m.match_root()->get_input_op(0));
647-
auto bias = m.match_root()->get_input_op(1)->get_input_op(0);
648-
auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias));
649-
return conv_bias;
647+
if (conv->get_input_shape(0).size() == 4)
648+
{
649+
auto bias = m.match_root()->get_input_op(1)->get_input_op(0);
650+
auto bias_shape = bias->get_shape();
651+
if (bias_shape.size() > 1)
652+
{
653+
NGRAPH_DEBUG
654+
<< "mpattern = " << m.match_root()->get_name()
655+
<< "conv_bias bias shape != 1, requires reshape to match filter count.";
656+
ngraph::AxisVector order(bias_shape.size());
657+
std::iota(begin(order), end(order), 0);
658+
auto bias_reshape =
659+
std::make_shared<op::Reshape>(bias, order, Shape{conv->get_input_shape(1)[0]});
660+
auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias_reshape));
661+
return conv_bias;
662+
}
663+
else
664+
{
665+
auto conv_bias = std::shared_ptr<Node>(new op::ConvolutionBias(conv, bias));
666+
return conv_bias;
667+
}
668+
}
669+
NGRAPH_DEBUG << "mpattern = " << m.match_root()->get_name()
670+
<< "conv_bias fusion skipped due to input rank size != 4.";
671+
return std::shared_ptr<Node>(nullptr);
650672
};
651673

652674
auto m = std::make_shared<ngraph::pattern::Matcher>(p_conv_bias, callback);

0 commit comments

Comments
 (0)