Skip to content

Commit d21ea65

Browse files
andylytensorflower-gardener
authored andcommitted
[Grappler] Don't rewrite reduction(inner_function(foo)) to inner_function(opposite_reduction(foo)) if reduction is a fetch node.
PiperOrigin-RevId: 225558983
1 parent 40934f0 commit d21ea65

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2722,6 +2722,9 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
27222722

27232723
Status TrySimplify(NodeDef* reduction_node,
27242724
string* simplified_node_name) override {
2725+
if (IsInPreserveSet(*reduction_node)) {
2726+
return Status::OK();
2727+
}
27252728
NodeDef* inner_function;
27262729
TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
27272730
// Optimize only if:

tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3490,6 +3490,35 @@ TEST_F(ArithmeticOptimizerTest,
34903490
VerifyGraphsMatch(item.graph, output, __LINE__);
34913491
}
34923492

3493+
TEST_F(ArithmeticOptimizerTest,
3494+
OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeFetchNodeReduction) {
3495+
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
3496+
auto x = ops::Const(s.WithOpName("x"), {2, 3}, {1, 2});
3497+
Output reshape = ops::Reshape(s.WithOpName("reshape"), x, {-1});
3498+
Output y = ops::Neg(s.WithOpName("y"), reshape);
3499+
Output z = ops::Max(s.WithOpName("z"), y, {0});
3500+
3501+
GrapplerItem item;
3502+
item.fetch = {"z"};
3503+
TF_CHECK_OK(s.ToGraphDef(&item.graph));
3504+
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
3505+
ASSERT_EQ(1, tensors_expected.size());
3506+
3507+
GraphDef output;
3508+
ArithmeticOptimizer optimizer;
3509+
EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
3510+
OptimizeTwice(&optimizer, &item, &output);
3511+
3512+
// Should be a NoOp since we are not allowed to change the output of fetch
3513+
// nodes.
3514+
VerifyGraphsMatch(item.graph, output, __LINE__);
3515+
3516+
auto tensors = EvaluateNodes(output, item.fetch);
3517+
ASSERT_EQ(1, tensors.size());
3518+
test::ExpectTensorEqual<int>(tensors[0], tensors_expected[0]);
3519+
test::ExpectTensorEqual<int>(tensors[0], Tensor(-2));
3520+
}
3521+
34933522
TEST_F(ArithmeticOptimizerTest,
34943523
OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) {
34953524
tensorflow::Scope s = tensorflow::Scope::NewRootScope();

0 commit comments

Comments
 (0)