
Commit f6380e8

Author: Matthew Francis-Landau (committed)
Raise Normalization for PyTorch (torch-mlir) layer norm.
This adds a raise-normalization pattern that matches the TensorRT MLIR generated by torch-mlir when it lowers the layer norm layer.

Signed-off-by: Matthew Francis-Landau <[email protected]>
1 parent 54efd4d commit f6380e8
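For reference, the computation the new RaisePytorchLayerNorm pattern recognizes is ordinary layer normalization; in the torch-mlir output each piece (mean, variance, epsilon add, reciprocal square root, and the gamma/beta affine terms) shows up as a separate TensorRT op:

\[
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta
\]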

File tree

5 files changed, +188 −4 lines changed


mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td

Lines changed: 1 addition & 0 deletions
@@ -401,6 +401,7 @@ def TensorRT_CastOp : TensorRT_Op<"cast", [
   let arguments = (ins TensorRT_RankedTensorOf<[I1, UI8, TensorRT_I8, I32, I64, F16, BF16, F32]>:$input);
   let results = (outs TensorRT_RankedTensorOf<[I1, UI8, TensorRT_I8, I32, I64, F16, BF16, F32]>:$result);
   let assemblyFormat = "attr-dict $input `:` type($input) `to` type($result)";
+  let hasFolder = 1;

   let extraClassDeclaration = [{
     /// Returns true if created op is valid for TensorRT major version.

mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp

Lines changed: 11 additions & 0 deletions
@@ -2380,6 +2380,17 @@ OpFoldResult IdentityOp::fold(FoldAdaptor adaptor) {
   return foldIdentity(getType(), getInput(), adaptor);
 }

+//===----------------------------------------------------------------------===//
+// CastOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
+  if (getInput().getType() == getType()) {
+    return getInput();
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // ReduceOp
 //===----------------------------------------------------------------------===//

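For context, the new folder simply removes casts whose source and result types already match. A minimal sketch of the effect on the IR (the shape is illustrative, borrowed from the test below):

  // Before folding: a cast whose input and result types are identical is a no-op.
  %0 = tensorrt.cast %arg0 : tensor<16x1024x1024xf32> to tensor<16x1024x1024xf32>
  // fold() returns the input, so uses of %0 become uses of %arg0 and the cast is erased.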
mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseNormalizations.cpp

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,8 @@ class RaiseNormalizations
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
     patterns.add<RaiseInstanceNormalization_NCHW>(ctx);
+    patterns.add<RaisePytorchLayerNorm>(ctx);
+    patterns.add<RemoveLayerNormCast>(ctx);

     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
       emitError(getOperation()->getLoc())

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseNormalizations.pdll

Lines changed: 139 additions & 4 deletions
@@ -146,6 +146,10 @@ Constraint ReduceSumImpl(val: Value)[{
                    (reduceOp.getInput().getType().getRank() - 1)));
 }];

+Constraint AvgImpl(op: Op) [{
+  return success(cast<tensorrt::ReduceOp>(op).getReduceOperation() == ReduceOperation::kAVG);
+}];
+
 Constraint CheckRank4(val: Value)[{
   RankedTensorType rtt = cast<RankedTensorType>(val.getType());
   return success(rtt.getRank() == 4);
@@ -196,10 +200,10 @@ Constraint ReverseSqrt(val : Value) -> Value{
 }

 Constraint FlattenTailDims(val: Value) -> Value {
-CheckRank4(val);
-let reshapeRes = op<tensorrt.reshape>(val);
-FlattenConstraintImpl(reshapeRes);
-return reshapeRes;
+  CheckRank4(val);
+  let reshapeRes = op<tensorrt.reshape>(val);
+  FlattenConstraintImpl(reshapeRes);
+  return reshapeRes;
 }

 Constraint ReduceSum(val: Value) -> Value{
@@ -219,6 +223,83 @@ Constraint Mean(input: Value, numHW: Value){
   return Div(ExpandTailDims(ReduceSum(FlattenTailDims(input))), numHW);
 }

+Constraint ReduceAvg(input: Value, reduceAxes: Attr) {
+  let avgOp = op<tensorrt.reduce>(input) {keepDimensions = attr<"true">, reduceAxes = reduceAxes};
+  AvgImpl(avgOp);
+  return avgOp;
+}
+
+
+Rewrite GetSplatElementAttr(x: Value) -> Attr [{
+  while(true) {
+    if(auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
+      x = expandRank.getInput();
+    else if(auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
+      x = reshape.getInput();
+    else if(auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
+      x = broadcast.getInput();
+    else if(auto cast = x.getDefiningOp<tensorrt::CastOp>())
+      x = cast.getInput();
+    else if(auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
+      x = identity.getInput();
+    else if(auto slice = x.getDefiningOp<tensorrt::SliceOp>())
+      x = slice.getInput();
+    else if(auto constant = x.getDefiningOp<tensorrt::ConstantOp>()) {
+      DenseElementsAttr els{};
+      if(!matchPattern(x, m_Constant(&els)))
+        return {};
+      if(!els.isSplat())
+        return {};
+      Attribute value = els.getSplatValue<Attribute>();
+      return value;
+    } else
+      return {};
+  }
+  return {};
+}];
+
+Constraint HasSplatElements(x: Value) [{
+  while(true) {
+    if(auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
+      x = expandRank.getInput();
+    else if(auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
+      x = reshape.getInput();
+    else if(auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
+      x = broadcast.getInput();
+    else if(auto cast = x.getDefiningOp<tensorrt::CastOp>())
+      x = cast.getInput();
+    else if(auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
+      x = identity.getInput();
+    else if(auto slice = x.getDefiningOp<tensorrt::SliceOp>())
+      x = slice.getInput();
+    else if(auto constant = x.getDefiningOp<tensorrt::ConstantOp>()) {
+      DenseElementsAttr els{};
+      if(!matchPattern(x, m_Constant(&els)))
+        return failure();
+      if(!els.isSplat())
+        return failure();
+      Attribute value = els.getSplatValue<Attribute>();
+      return success(isa<FloatAttr, IntegerAttr>(value));
+    } else
+      return failure();
+  }
+  return failure();
+}];
+
+Constraint SameElementType(a: Value, b: Value) [{
+  return success(cast<RankedTensorType>(a.getType()).getElementType() == cast<RankedTensorType>(b.getType()).getElementType());
+}];
+
+Rewrite CreateCast(x: Value, refValue: Value) -> Value [{
+  Type retType = RankedTensorType::Builder(cast<RankedTensorType>(x.getType())).setElementType(cast<RankedTensorType>(refValue.getType()).getElementType());
+  return rewriter.createOrFold<tensorrt::CastOp>(
+      x.getLoc(),
+      retType,
+      x
+  );
+}];
+
+
 Pattern RaiseInstanceNormalization_NCHW {
   let inputType : Type;
   let input : Value<inputType>;
@@ -240,3 +321,57 @@ Pattern RaiseInstanceNormalization_NCHW {
   CheckRank4(addOffset);
   replace addOffset with op<tensorrt.normalization>(input, scale, offset){axis = attr<"array<i64: 2,3>">};
 }
+
+Pattern RaisePytorchLayerNorm {
+  let x: Value;
+  let beta: Value;
+  let gamma: Value;
+  let axis: Attr;
+  let epsilon: Value;
+
+  let mean = ReduceAvg(x, axis);
+  let diffMean = Sub(x, mean);
+
+  let varianceDenominator: Value;
+  let varianceMean = Div(ReduceSum(x), varianceDenominator); // for some reason Pytorch's lowering computes the mean in 2 different ways....
+  let varianceDiff = Sub(x, varianceMean);
+  let varianceDiffSquared = Mul(varianceDiff, varianceDiff);
+  let varianceNumerator = ReduceSum(varianceDiffSquared);
+  let variance = Div(varianceNumerator, varianceDenominator);
+  let varianceEps = Add(variance, epsilon);
+
+  let inverseSqrt = ReverseSqrt(varianceEps);
+  let normed = Mul(diffMean, inverseSqrt);
+  let prod = Mul(normed, gamma);
+  let root = Add(prod, beta);
+
+  HasSplatElements(epsilon);
+  HasSplatElements(varianceDenominator);
+
+  rewrite root with {
+    let epsilonAttr = GetSplatElementAttr(epsilon);
+    let replacement = op<tensorrt.normalization>(x, gamma, beta) {axis = axis, eps = epsilonAttr};
+    replace root with replacement;
+  };
+}
+
+Pattern RemoveLayerNormCast {
+  let x: Value;
+  let gamma: Value;
+  let beta: Value;
+  let axis: Attr;
+  let epsilonAttr: Attr;
+
+  let castInput = op<tensorrt.cast>(x);
+  let norm = op<tensorrt.normalization>(castInput, gamma, beta) {axis = axis, eps = epsilonAttr};
+  let root = op<tensorrt.cast>(norm);
+
+  SameElementType(x, root);
+
+  rewrite root with {
+    let newGamma = CreateCast(gamma, x);
+    let newBeta = CreateCast(beta, x);
+    let replacement = op<tensorrt.normalization>(x, newGamma, newBeta) {axis = axis, eps = epsilonAttr};
+    replace root with replacement;
+  };
+}
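RemoveLayerNormCast appears to be aimed at the case where the layer norm arithmetic is computed in a wider type than the surrounding program (for example a bf16 network whose normalization was raised in f32): once RaisePytorchLayerNorm has produced a tensorrt.normalization on the casted input, the pattern moves the normalization back onto the original value and casts gamma/beta instead. A rough sketch of the IR shape it matches, using MLIR's generic op form for tensorrt.normalization since its pretty-printed form isn't shown in this diff (shapes and attribute values are illustrative):

  // Input cast up to the compute type, normalization, cast back to the original element type.
  %a = tensorrt.cast %x : tensor<16x1024x1024xbf16> to tensor<16x1024x1024xf32>
  %n = "tensorrt.normalization"(%a, %gamma, %beta) {axis = array<i64: 2>, eps = 9.99999974E-6 : f32}
         : (tensor<16x1024x1024xf32>, tensor<1x1x1024xf32>, tensor<1x1x1024xf32>) -> tensor<16x1024x1024xf32>
  %r = tensorrt.cast %n : tensor<16x1024x1024xf32> to tensor<16x1024x1024xbf16>

After the rewrite the normalization consumes %x directly, %gamma and %beta are cast to the bf16 element type via CreateCast, and the two casts on the activation path become dead.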

mlir-tensorrt/tensorrt/test/Dialect/TensorRT/raise-normalizations.mlir

Lines changed: 35 additions & 0 deletions
@@ -53,3 +53,38 @@ func.func @raise_inst_norm_nchw(%arg0: tensor<1x3x1x1xf32>, %arg1: tensor<1x3x1x

 // CHECK-LABEL: @neg_raise_nhwc
 // CHECK-NOT: tensorrt.normalization
+
+// -----
+
+// CHECK: @raise_layer_norm_pytorch(%[[arg0:.+]]: tensor<16x1024x1024xf32>)
+// CHECK: %[[ret:.+]] = tensorrt.normalization [[attr:.+]](%[[arg0]] : tensor<16x1024x1024xf32>, %[[gamma:.+]] : tensor<1x1x1024xf32>, %[[beta:.+]] : tensor<1x1x1024xf32>)
+// CHECK: return %[[ret]]
+func.func @raise_layer_norm_pytorch(%arg0: tensor<16x1024x1024xf32>) -> tensor<16x1024x1024xf32> {
+  %cst_i64 = tensorrt.constant dense<1024> : tensor<i64>
+  %cst_f32 = tensorrt.constant dense<9.99999974E-6> : tensor<1x1x1xf32>
+
+  %cst_bf16_1 = tensorrt.constant dense_resource<__elided__> : tensor<1024xbf16> // beta (added)
+  %cst_bf16_2 = tensorrt.constant dense_resource<__elided__> : tensor<1024xbf16> // gamma (multiplied)
+
+  %6 = tensorrt.reduce <kSUM> %arg0 {keepDimensions = true, reduceAxes = array<i64: 2>} : tensor<16x1024x1024xf32> -> tensor<16x1024x1xf32>
+  %7 = tensorrt.cast %cst_i64 : tensor<i64> to tensor<f32>
+  %8 = tensorrt.expand_rank %7 : tensor<f32> to tensor<1x1x1xf32>
+  %9 = tensorrt.element_wise <kDIV>(%6, %8 : tensor<16x1024x1xf32>, tensor<1x1x1xf32>) -> tensor<16x1024x1xf32>
+  %10 = tensorrt.element_wise <kSUB>(%arg0, %9 : tensor<16x1024x1024xf32>, tensor<16x1024x1xf32>) -> tensor<16x1024x1024xf32>
+  %11 = tensorrt.element_wise <kPROD>(%10, %10 : tensor<16x1024x1024xf32>, tensor<16x1024x1024xf32>) -> tensor<16x1024x1024xf32>
+  %12 = tensorrt.reduce <kSUM> %11 {keepDimensions = true, reduceAxes = array<i64: 2>} : tensor<16x1024x1024xf32> -> tensor<16x1024x1xf32>
+  %13 = tensorrt.element_wise <kDIV>(%12, %8 : tensor<16x1024x1xf32>, tensor<1x1x1xf32>) -> tensor<16x1024x1xf32> // Var[x]
+  %15 = tensorrt.reduce <kAVG> %arg0 {keepDimensions = true, reduceAxes = array<i64: 2>} : tensor<16x1024x1024xf32> -> tensor<16x1024x1xf32> // E[x]
+  %16 = tensorrt.element_wise <kSUM>(%13, %cst_f32 : tensor<16x1024x1xf32>, tensor<1x1x1xf32>) -> tensor<16x1024x1xf32> // Var[x] + epsilon
+  %17 = tensorrt.unary {unaryOperation = #tensorrt.unary_operation<kRECIP>} %16 : tensor<16x1024x1xf32>
+  %18 = tensorrt.unary {unaryOperation = #tensorrt.unary_operation<kSQRT>} %17 : tensor<16x1024x1xf32> // compute 1/sqrt(...)
+  %19 = tensorrt.element_wise <kSUB>(%arg0, %15 : tensor<16x1024x1024xf32>, tensor<16x1024x1xf32>) -> tensor<16x1024x1024xf32>
+  %20 = tensorrt.element_wise <kPROD>(%19, %18 : tensor<16x1024x1024xf32>, tensor<16x1024x1xf32>) -> tensor<16x1024x1024xf32> // multiply for division
+  %21 = tensorrt.cast %cst_bf16_2 : tensor<1024xbf16> to tensor<1024xf32>
+  %22 = tensorrt.expand_rank %21 : tensor<1024xf32> to tensor<1x1x1024xf32>
+  %23 = tensorrt.element_wise <kPROD>(%20, %22 : tensor<16x1024x1024xf32>, tensor<1x1x1024xf32>) -> tensor<16x1024x1024xf32> // multiply gamma
+  %24 = tensorrt.cast %cst_bf16_1 : tensor<1024xbf16> to tensor<1024xf32>
+  %25 = tensorrt.expand_rank %24 : tensor<1024xf32> to tensor<1x1x1024xf32>
+  %26 = tensorrt.element_wise <kSUM>(%23, %25 : tensor<16x1024x1024xf32>, tensor<1x1x1024xf32>) -> tensor<16x1024x1024xf32> // add beta
+  return %26 : tensor<16x1024x1024xf32>
+}
