@@ -60,15 +60,13 @@ struct BatchNormHelper {
6060 std::vector<Variable> MeanAndVariance (Variable x) {
6161#ifdef CINN_WITH_CUDA
6262 // To optimize the bn forward by merge the reduce computation of mean and variance,
63- // build a fusion op 'BnMeanVarianceReduce ' by hand as the fusion pass is not support now.
63+ // build a fusion op 'BnMeanVariance ' by hand as the fusion pass is not support now.
6464 // When the fusion pass is rebuild, this op is to be removed.
65- auto vars = builder->BnMeanVarianceReduce (x);
65+ auto vars = builder->BnMeanVariance (x);
6666 auto element_count_1d_0 = GetTensorFromScalar<float >(element_count, " element_count" , param_shape);
6767 auto element_count_1d_1 = GetTensorFromScalar<float >(element_count, " element_count" , param_shape);
68- auto mean = builder->Div (builder->Reduce (vars[0 ], ReduceKind::kSum , std::vector<int >(1 , vars[0 ]->shape .size () - 1 )),
69- element_count_1d_0);
70- auto mean_squre = builder->Div (
71- builder->Reduce (vars[1 ], ReduceKind::kSum , std::vector<int >(1 , vars[1 ]->shape .size () - 1 )), element_count_1d_1);
68+ auto mean = builder->Div (vars[0 ], element_count_1d_0);
69+ auto mean_squre = builder->Div (vars[1 ], element_count_1d_1);
7270
7371 auto variance = builder->Sub (mean_squre, builder->Mul (mean, builder->Identity (mean)));
7472#else
@@ -82,11 +80,9 @@ struct BatchNormHelper {
8280
8381 std::vector<Variable> GradBiasAndScale (Variable x, Variable x_mean, Variable y_grad) {
8482#ifdef CINN_WITH_CUDA
85- // Using fusion op "BnGradBiasScaleReduce " as the same reason with "BnMeanVarianceReduce ".
83+ // Using fusion op "BnGradBiasScale " as the same reason with "BnMeanVariance ".
8684 // It also will be removed.
87- auto vars = builder->BnGradBiasScaleReduce (x, x_mean, y_grad);
88- return {builder->Reduce (vars[0 ], ReduceKind::kSum , std::vector<int >(1 , vars[0 ]->shape .size () - 1 )),
89- builder->Reduce (vars[1 ], ReduceKind::kSum , std::vector<int >(1 , vars[1 ]->shape .size () - 1 ))};
85+ return builder->BnGradBiasScale (x, x_mean, y_grad);
9086#else
9187 auto mean_4d = builder->BroadcastTo (x_mean, x->shape , {channel_dim});
9288 auto x_mean_diff = builder->Sub (x, mean_4d);
0 commit comments