@@ -3490,6 +3490,35 @@ TEST_F(ArithmeticOptimizerTest,
34903490 VerifyGraphsMatch (item.graph , output, __LINE__);
34913491}
34923492
3493+ TEST_F (ArithmeticOptimizerTest,
3494+ OptimizeMaxOrMinOfMonotonicElementWiseDoNotChangeFetchNodeReduction) {
3495+ tensorflow::Scope s = tensorflow::Scope::NewRootScope ();
3496+ auto x = ops::Const (s.WithOpName (" x" ), {2 , 3 }, {1 , 2 });
3497+ Output reshape = ops::Reshape (s.WithOpName (" reshape" ), x, {-1 });
3498+ Output y = ops::Neg (s.WithOpName (" y" ), reshape);
3499+ Output z = ops::Max (s.WithOpName (" z" ), y, {0 });
3500+
3501+ GrapplerItem item;
3502+ item.fetch = {" z" };
3503+ TF_CHECK_OK (s.ToGraphDef (&item.graph ));
3504+ auto tensors_expected = EvaluateNodes (item.graph , item.fetch );
3505+ ASSERT_EQ (1 , tensors_expected.size ());
3506+
3507+ GraphDef output;
3508+ ArithmeticOptimizer optimizer;
3509+ EnableOnlyOptimizeMaxOrMinOfMonotonic (&optimizer);
3510+ OptimizeTwice (&optimizer, &item, &output);
3511+
3512+ // Should be a NoOp since we are not allowed to change the output of fetch
3513+ // nodes.
3514+ VerifyGraphsMatch (item.graph , output, __LINE__);
3515+
3516+ auto tensors = EvaluateNodes (output, item.fetch );
3517+ ASSERT_EQ (1 , tensors.size ());
3518+ test::ExpectTensorEqual<int >(tensors[0 ], tensors_expected[0 ]);
3519+ test::ExpectTensorEqual<int >(tensors[0 ], Tensor (-2 ));
3520+ }
3521+
34933522TEST_F (ArithmeticOptimizerTest,
34943523 OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) {
34953524 tensorflow::Scope s = tensorflow::Scope::NewRootScope ();
0 commit comments