This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit 70b04f0

Introduce manually optimized CUDA block_reduce function and use it to generate a single reduce kernel (#622) (#637)
1 parent: de53fce

11 files changed: +821 −351 lines


cinn/frontend/cinn_builder.cc

Lines changed: 10 additions & 6 deletions
@@ -216,17 +216,21 @@ Variable CinnBuilder::Reverse(const Variable& operand, const std::vector<int>& axis
   return instr.GetOutput(0);
 }
 
-std::vector<Variable> CinnBuilder::BnMeanVarianceReduce(const Variable& x) {
-  Instruction instr("bn_mean_variance_reduce", {x});
+std::vector<Variable> CinnBuilder::BnMeanVariance(const Variable& x) {
+  Instruction instr("bn_mean_variance", {x});
+  // optimize bn forward reduce computation, set reduce dimension (NCHW support only, to be deprecated).
+  instr.SetAttr("dim", std::vector<int>{0, 2, 3});
+  instr.SetAttr("keep_dim", false);
   InferShape(instr);
   AppendInstruction(instr);
   return instr.GetOutputs();
 }
 
-std::vector<Variable> CinnBuilder::BnGradBiasScaleReduce(const Variable& x,
-                                                         const Variable& x_mean,
-                                                         const Variable& y_grad) {
-  Instruction instr("bn_grad_bias_scale_reduce", {x, x_mean, y_grad});
+std::vector<Variable> CinnBuilder::BnGradBiasScale(const Variable& x, const Variable& x_mean, const Variable& y_grad) {
+  Instruction instr("bn_grad_bias_scale", {x, x_mean, y_grad});
+  // optimize bn backward reduce computation, set reduce dimension (NCHW support only, to be deprecated).
+  instr.SetAttr("dim", std::vector<int>{0, 2, 3});
+  instr.SetAttr("keep_dim", false);
   InferShape(instr);
   AppendInstruction(instr);
   return instr.GetOutputs();
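
A note on the new attributes: both builder ops reduce an NCHW tensor over dims {0, 2, 3} with keep_dim = false, i.e. they collapse batch, height and width and leave one value per channel, which is the shape batch-norm statistics need. A minimal host-side sketch of that reduction semantics follows; the helper name is hypothetical and this is not the CINN API.

#include <array>
#include <cstddef>
#include <vector>

// Reduce-sum an NCHW tensor over dims {0, 2, 3}; the result has shape {C},
// matching what the "dim"/"keep_dim" attributes above request.
std::vector<float> ReduceSumNCHW(const std::vector<float>& x,
                                 const std::array<std::size_t, 4>& shape /* {N, C, H, W} */) {
  const std::size_t N = shape[0], C = shape[1], H = shape[2], W = shape[3];
  std::vector<float> per_channel(C, 0.0f);
  for (std::size_t n = 0; n < N; ++n)
    for (std::size_t c = 0; c < C; ++c)
      for (std::size_t h = 0; h < H; ++h)
        for (std::size_t w = 0; w < W; ++w)
          per_channel[c] += x[((n * C + c) * H + h) * W + w];
  return per_channel;
}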

cinn/frontend/cinn_builder.h

Lines changed: 2 additions & 2 deletions
@@ -179,9 +179,9 @@ class CinnBuilder : public BaseBuilder {
 
   Variable Reverse(const Variable& operand, const std::vector<int>& axis);
 
-  std::vector<Variable> BnMeanVarianceReduce(const Variable& x);
+  std::vector<Variable> BnMeanVariance(const Variable& x);
 
-  std::vector<Variable> BnGradBiasScaleReduce(const Variable& x, const Variable& x_mean, const Variable& y_grad);
+  std::vector<Variable> BnGradBiasScale(const Variable& x, const Variable& x_mean, const Variable& y_grad);
 
  private:
   Variable UnaryOp(const std::string& op_type, const Variable& operand);

cinn/frontend/decomposer/batch_norm.cc

Lines changed: 6 additions & 10 deletions
@@ -60,15 +60,13 @@ struct BatchNormHelper {
   std::vector<Variable> MeanAndVariance(Variable x) {
 #ifdef CINN_WITH_CUDA
     // To optimize the bn forward by merge the reduce computation of mean and variance,
-    // build a fusion op 'BnMeanVarianceReduce' by hand as the fusion pass is not support now.
+    // build a fusion op 'BnMeanVariance' by hand as the fusion pass is not support now.
     // When the fusion pass is rebuild, this op is to be removed.
-    auto vars = builder->BnMeanVarianceReduce(x);
+    auto vars = builder->BnMeanVariance(x);
     auto element_count_1d_0 = GetTensorFromScalar<float>(element_count, "element_count", param_shape);
     auto element_count_1d_1 = GetTensorFromScalar<float>(element_count, "element_count", param_shape);
-    auto mean = builder->Div(builder->Reduce(vars[0], ReduceKind::kSum, std::vector<int>(1, vars[0]->shape.size() - 1)),
-                             element_count_1d_0);
-    auto mean_squre = builder->Div(
-        builder->Reduce(vars[1], ReduceKind::kSum, std::vector<int>(1, vars[1]->shape.size() - 1)), element_count_1d_1);
+    auto mean = builder->Div(vars[0], element_count_1d_0);
+    auto mean_squre = builder->Div(vars[1], element_count_1d_1);
 
     auto variance = builder->Sub(mean_squre, builder->Mul(mean, builder->Identity(mean)));
 #else
@@ -82,11 +80,9 @@
 
   std::vector<Variable> GradBiasAndScale(Variable x, Variable x_mean, Variable y_grad) {
 #ifdef CINN_WITH_CUDA
-    // Using fusion op "BnGradBiasScaleReduce" as the same reason with "BnMeanVarianceReduce".
+    // Using fusion op "BnGradBiasScale" as the same reason with "BnMeanVariance".
     // It also will be removed.
-    auto vars = builder->BnGradBiasScaleReduce(x, x_mean, y_grad);
-    return {builder->Reduce(vars[0], ReduceKind::kSum, std::vector<int>(1, vars[0]->shape.size() - 1)),
-            builder->Reduce(vars[1], ReduceKind::kSum, std::vector<int>(1, vars[1]->shape.size() - 1))};
+    return builder->BnGradBiasScale(x, x_mean, y_grad);
 #else
     auto mean_4d = builder->BroadcastTo(x_mean, x->shape, {channel_dim});
     auto x_mean_diff = builder->Sub(x, mean_4d);
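
For reference, the forward path above relies on the identity Var(x) = E[x^2] - (E[x])^2: the two outputs of BnMeanVariance are divided by the element count and then combined. A minimal sketch of that final step, assuming the two reduced tensors hold the per-channel sums of x and x^2 (the names below are hypothetical, not CINN code):

struct MeanVar {
  float mean;
  float variance;
};

// Combine the two per-channel reductions into mean and variance.
MeanVar CombineReducedSums(float sum_x, float sum_x2, float element_count) {
  const float mean        = sum_x / element_count;    // E[x]
  const float mean_square = sum_x2 / element_count;   // E[x^2]
  return {mean, mean_square - mean * mean};           // Var(x) = E[x^2] - (E[x])^2
}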

cinn/hlir/op/reduction.cc

Lines changed: 287 additions & 155 deletions
Large diffs are not rendered by default.
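
Since the reduction.cc diff is not rendered here, the following is a rough sketch of what a manually optimized CUDA block_reduce typically looks like: a warp-level shuffle reduction staged through shared memory, so that one thread block produces one sum and a row can be reduced by a single kernel. It is illustrative only, assumes a float sum with blockDim.x a multiple of 32, and is not the code added in this commit; all names are hypothetical.

#include <cuda_runtime.h>

__device__ __forceinline__ float WarpReduceSum(float val) {
  // Reduce within one 32-thread warp using register shuffles, no shared memory.
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return val;
}

__device__ float BlockReduceSum(float val) {
  __shared__ float warp_sums[32];             // one slot per possible warp in the block
  const int lane    = threadIdx.x & 31;
  const int warp_id = threadIdx.x >> 5;

  val = WarpReduceSum(val);                   // 1) reduce inside each warp
  if (lane == 0) warp_sums[warp_id] = val;    // 2) stage one partial sum per warp
  __syncthreads();

  const int num_warps = blockDim.x >> 5;
  val = (threadIdx.x < num_warps) ? warp_sums[lane] : 0.0f;
  if (warp_id == 0) val = WarpReduceSum(val); // 3) first warp reduces the staged partials
  return val;                                 // valid in thread 0 of the block
}

__global__ void RowReduceSumKernel(const float* __restrict__ in, float* __restrict__ out, int row_len) {
  // One block reduces one row (e.g. one flattened channel) to a single value.
  float partial = 0.0f;
  for (int i = threadIdx.x; i < row_len; i += blockDim.x) {
    partial += in[blockIdx.x * static_cast<long long>(row_len) + i];
  }
  const float sum = BlockReduceSum(partial);
  if (threadIdx.x == 0) out[blockIdx.x] = sum;
}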
