Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit ea40bc4

Browse files
author
Adam Procter
authored
Merge pull request #1708 from NervanaSystems/aprocter/cherry-pick-1663
Cherry-pick "zero dim elem fix (#1663)"
2 parents f3c8845 + 427bcc1 commit ea40bc4

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

src/ngraph/pass/zero_dim_tensor_elimination.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ static bool has_zero_dim(std::shared_ptr<Node> node)
3737
{
3838
throw ngraph_error("has_zero_dim is called on multi-output op");
3939
}
40-
return shape_size(node->get_shape()) == 0;
40+
41+
const auto& shape = node->get_shape();
42+
return std::find(shape.begin(), shape.end(), 0) != shape.end();
4143
}
4244

4345
static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function> f)
@@ -75,6 +77,7 @@ static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function>
7577
bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngraph::Function> f)
7678
{
7779
bool replaced = false;
80+
auto cvals = std::vector<std::string>(0);
7881
// we need to go over all nodes since we could have sum or any other 0-length-tensor-to scalar op
7982
// as an internal node (i.e. a node that isn't an argument to `op::Result`)
8083
for (auto n : f->get_ordered_ops())
@@ -93,7 +96,6 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
9396
{
9497
// we don't have to create constants every time but this is the easiest
9598
// and it's CSE's job to eliminate the same ones
96-
auto cvals = std::vector<std::string>(0);
9799
auto constant =
98100
std::make_shared<op::Constant>(n->get_element_type(), n->get_shape(), cvals);
99101
replace_node(n, constant);
@@ -102,8 +104,21 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
102104
continue;
103105
}
104106

107+
if (n->get_inputs().size() == 0)
108+
{
109+
continue;
110+
}
111+
112+
auto arg = n->get_inputs().at(0).get_output().get_node();
113+
114+
if (arg->get_outputs().size() != 1 || !has_zero_dim(arg))
115+
{
116+
continue;
117+
}
118+
105119
auto new_node = n->get_default_value();
106-
if (!new_node || !has_zero_dim(n->get_argument(0)))
120+
121+
if (!new_node)
107122
{
108123
continue;
109124
}

0 commit comments

Comments
 (0)