@@ -37,7 +37,9 @@ static bool has_zero_dim(std::shared_ptr<Node> node)
3737 {
3838 throw ngraph_error (" has_zero_dim is called on multi-output op" );
3939 }
40- return shape_size (node->get_shape ()) == 0 ;
40+
41+ const auto & shape = node->get_shape ();
42+ return std::find (shape.begin (), shape.end (), 0 ) != shape.end ();
4143}
4244
4345static bool verify_no_internal_zero_length_ops (std::shared_ptr<ngraph::Function> f)
@@ -75,6 +77,7 @@ static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function>
7577bool ngraph::pass::ZeroDimTensorElimination::run_on_function (std::shared_ptr<ngraph::Function> f)
7678{
7779 bool replaced = false ;
80+ auto cvals = std::vector<std::string>(0 );
7881 // we need to go over all nodes since we could have sum or any other 0-length-tensor-to scalar op
7982 // as an internal node (i.e. a node that isn't an argument to `op::Result`)
8083 for (auto n : f->get_ordered_ops ())
@@ -93,7 +96,6 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
9396 {
9497 // we don't have to create constants every time but this is the easiest
9598 // and it's CSE's job to eliminate the same ones
96- auto cvals = std::vector<std::string>(0 );
9799 auto constant =
98100 std::make_shared<op::Constant>(n->get_element_type (), n->get_shape (), cvals);
99101 replace_node (n, constant);
@@ -102,8 +104,21 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
102104 continue ;
103105 }
104106
107+ if (n->get_inputs ().size () == 0 )
108+ {
109+ continue ;
110+ }
111+
112+ auto arg = n->get_inputs ().at (0 ).get_output ().get_node ();
113+
114+ if (arg->get_outputs ().size () != 1 || !has_zero_dim (arg))
115+ {
116+ continue ;
117+ }
118+
105119 auto new_node = n->get_default_value ();
106- if (!new_node || !has_zero_dim (n->get_argument (0 )))
120+
121+ if (!new_node)
107122 {
108123 continue ;
109124 }
0 commit comments