
Commit efe8de8

Author: Matthew Francis-Landau (committed)
Raise Normalization for Pytorch (torch-mlir) Layer norm.
This adds a raise-normalization pattern that matches the TensorRT MLIR generated by torch-mlir when lowering the layer norm layer.
Signed-off-by: Matthew Francis-Landau <[email protected]>

Also add a utils function, getSplatConstantElementAttribute.
Signed-off-by: Matthew Francis-Landau <[email protected]>
1 parent f821499 commit efe8de8
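
For orientation (a restatement of the computation the pattern matches, not part of the commit message): torch-mlir expands a layer norm into explicit reductions and element-wise ops, and this commit raises that expansion back into a single tensorrt.normalization op computing

    y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta

where the mean E[x] and variance Var[x] are reductions over the normalized axes, gamma and beta are the scale and offset operands, and epsilon is the splat constant matched by the pattern.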

File tree

7 files changed (+178, -4 lines)


mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td

Lines changed: 1 addition & 0 deletions
@@ -401,6 +401,7 @@ def TensorRT_CastOp : TensorRT_Op<"cast", [
   let arguments = (ins TensorRT_RankedTensorOf<[I1, UI8, TensorRT_I8, I32, I64, F16, BF16, F32]>:$input);
   let results = (outs TensorRT_RankedTensorOf<[I1, UI8, TensorRT_I8, I32, I64, F16, BF16, F32]>:$result);
   let assemblyFormat = "attr-dict $input `:` type($input) `to` type($result)";
+  let hasFolder = 1;
 
   let extraClassDeclaration = [{
     /// Returns true if created op is valid for TensorRT major version.

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/Utils/Utils.h

Lines changed: 5 additions & 0 deletions
@@ -68,6 +68,11 @@ TypedValue<RankedTensorType>
 scatterShapeTensor(RewriterBase &b, Location loc, ArrayRef<int64_t> baseShape,
                    int32_t scatterDim, TypedValue<RankedTensorType> update);
 
+/// Get a splatted constant's attribute by going up a chain of reshape and cast
+/// operations to find the original constant. The constant can be a different
+/// data type if there is a cast operation in the chain.
+FailureOr<Attribute> getSplatConstantElementAttribute(Value x);
+
 } // namespace tensorrt
 } // namespace mlir
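
A minimal usage sketch of the new utility (illustrative only; the wrapper getSplatAsDouble below is hypothetical and not part of this commit), showing how a C++ rewrite could read a scalar such as the layer-norm epsilon through reshape/cast chains:

#include "mlir-tensorrt-dialect/TensorRT/Utils/Utils.h"

// Hypothetical helper: walk through expand_rank/reshape/cast ops feeding `v`
// and return the splat scalar as a double, if there is one.
static mlir::FailureOr<double> getSplatAsDouble(mlir::Value v) {
  mlir::FailureOr<mlir::Attribute> splat =
      mlir::tensorrt::getSplatConstantElementAttribute(v);
  if (mlir::failed(splat))
    return mlir::failure();
  if (auto floatAttr = mlir::dyn_cast<mlir::FloatAttr>(*splat))
    return floatAttr.getValueAsDouble();
  if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(*splat))
    return static_cast<double>(intAttr.getInt());
  return mlir::failure();
}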

mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp

Lines changed: 11 additions & 0 deletions
@@ -2380,6 +2380,17 @@ OpFoldResult IdentityOp::fold(FoldAdaptor adaptor) {
   return foldIdentity(getType(), getInput(), adaptor);
 }
 
+//===----------------------------------------------------------------------===//
+// CastOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
+  if (getInput().getType() == getType()) {
+    return getInput();
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // ReduceOp
 //===----------------------------------------------------------------------===//
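
One practical note (a sketch under assumptions, not code from this commit): because CastOp now has a folder, OpBuilder::createOrFold, which is the call used by the CreateCast rewrite added below, returns the input value directly instead of materializing a no-op cast when the requested type already matches.

// Hedged sketch: assumes `rewriter` is a RewriterBase and `x` is a tensor
// Value; casting x to its own type now folds away rather than creating an op.
mlir::Value maybeCast = rewriter.createOrFold<mlir::tensorrt::CastOp>(
    x.getLoc(), x.getType(), x);
assert(maybeCast == x && "identity tensorrt.cast should fold to its input");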

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseNormalizations.cpp

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,7 @@
 //===----------------------------------------------------------------------===//
 #include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
 #include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h"
+#include "mlir-tensorrt-dialect/TensorRT/Utils/Utils.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/IR/Matchers.h"
@@ -51,6 +52,8 @@ class RaiseNormalizations
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
     patterns.add<RaiseInstanceNormalization_NCHW>(ctx);
+    patterns.add<RaisePytorchLayerNorm>(ctx);
+    patterns.add<RemoveLayerNormCast>(ctx);
 
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
       emitError(getOperation()->getLoc())

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseNormalizations.pdll

Lines changed: 91 additions & 4 deletions
@@ -146,6 +146,10 @@ Constraint ReduceSumImpl(val: Value)[{
         (reduceOp.getInput().getType().getRank() - 1)));
 }];
 
+Constraint AvgImpl(op: Op) [{
+  return success(cast<tensorrt::ReduceOp>(op).getReduceOperation() == ReduceOperation::kAVG);
+}];
+
 Constraint CheckRank4(val: Value)[{
   RankedTensorType rtt = cast<RankedTensorType>(val.getType());
   return success(rtt.getRank() == 4);
@@ -196,10 +200,10 @@ Constraint ReverseSqrt(val : Value) -> Value{
 }
 
 Constraint FlattenTailDims(val: Value) -> Value {
-  CheckRank4(val);
-  let reshapeRes = op<tensorrt.reshape>(val);
-  FlattenConstraintImpl(reshapeRes);
-  return reshapeRes;
+  CheckRank4(val);
+  let reshapeRes = op<tensorrt.reshape>(val);
+  FlattenConstraintImpl(reshapeRes);
+  return reshapeRes;
 }
 
 Constraint ReduceSum(val: Value) -> Value{
@@ -219,6 +223,35 @@ Constraint Mean(input: Value, numHW: Value){
   return Div(ExpandTailDims(ReduceSum(FlattenTailDims(input))), numHW);
 }
 
+Constraint ReduceAvg(input: Value, reduceAxes: Attr) {
+  let avgOp = op<tensorrt.reduce>(input) {keepDimensions = attr<"true">, reduceAxes = reduceAxes};
+  AvgImpl(avgOp);
+  return avgOp;
+}
+
+
+Rewrite GetSplatElementAttr(x: Value) -> Attr [{
+  return *getSplatConstantElementAttribute(x);
+}];
+
+Constraint HasSplatElements(x: Value) [{
+  return LogicalResult(getSplatConstantElementAttribute(x));
+}];
+
+Constraint SameElementType(a: Value, b: Value) [{
+  return success(cast<RankedTensorType>(a.getType()).getElementType() == cast<RankedTensorType>(b.getType()).getElementType());
+}];
+
+Rewrite CreateCast(x: Value, refValue: Value) -> Value [{
+  Type retType = RankedTensorType::Builder(cast<RankedTensorType>(x.getType())).setElementType(cast<RankedTensorType>(refValue.getType()).getElementType());
+  return rewriter.createOrFold<tensorrt::CastOp>(
+    x.getLoc(),
+    retType,
+    x
+  );
+}];
+
+
 Pattern RaiseInstanceNormalization_NCHW {
   let inputType : Type;
   let input : Value<inputType>;
@@ -240,3 +273,57 @@ Pattern RaiseInstanceNormalization_NCHW {
   CheckRank4(addOffset);
   replace addOffset with op<tensorrt.normalization>(input, scale, offset){axis = attr<"array<i64: 2,3>">};
 }
+
+Pattern RaisePytorchLayerNorm {
+  let x: Value;
+  let beta: Value;
+  let gamma: Value;
+  let axis: Attr;
+  let epsilon: Value;
+
+  let mean = ReduceAvg(x, axis);
+  let diffMean = Sub(x, mean);
+
+  let varianceDenominator: Value;
+  let varianceMean = Div(ReduceSum(x), varianceDenominator); // for some reason Pytorch's lowering computes the mean in 2 different ways....
+  let varianceDiff = Sub(x, varianceMean);
+  let varianceDiffSquared = Mul(varianceDiff, varianceDiff);
+  let varianceNumerator = ReduceSum(varianceDiffSquared);
+  let variance = Div(varianceNumerator, varianceDenominator);
+  let varianceEps = Add(variance, epsilon);
+
+  let inverseSqrt = ReverseSqrt(varianceEps);
+  let normed = Mul(diffMean, inverseSqrt);
+  let prod = Mul(normed, gamma);
+  let root = Add(prod, beta);
+
+  HasSplatElements(epsilon);
+  HasSplatElements(varianceDenominator);
+
+  rewrite root with {
+    let epsilonAttr = GetSplatElementAttr(epsilon);
+    let replacement = op<tensorrt.normalization>(x, gamma, beta) {axis = axis, eps = epsilonAttr};
+    replace root with replacement;
+  };
+}
+
+Pattern RemoveLayerNormCast {
+  let x: Value;
+  let gamma: Value;
+  let beta: Value;
+  let axis: Attr;
+  let epsilonAttr: Attr;
+
+  let castInput = op<tensorrt.cast>(x);
+  let norm = op<tensorrt.normalization>(castInput, gamma, beta) {axis = axis, eps = epsilonAttr};
+  let root = op<tensorrt.cast>(norm);
+
+  SameElementType(x, root);
+
+  rewrite root with {
+    let newGamma = CreateCast(gamma, x);
+    let newBeta = CreateCast(beta, x);
+    let replacement = op<tensorrt.normalization>(x, newGamma, newBeta) {axis = axis, eps = epsilonAttr};
+    replace root with replacement;
+  };
+}
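
A small consistency note on the two ways the matched IR computes the mean (this only restates what the pattern's inline comment observes): a kAVG reduction and a kSUM reduction divided by the splat element count N yield the same quantity,

    \mathrm{E}[x] = \frac{1}{N}\sum_{i=1}^{N} x_i

which is why RaisePytorchLayerNorm takes `mean` from ReduceAvg while the variance path rebuilds the mean from ReduceSum and varianceDenominator.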

mlir-tensorrt/tensorrt/lib/TensorRT/Utils/Utils.cpp

Lines changed: 32 additions & 0 deletions
@@ -22,6 +22,7 @@
 #include "mlir-tensorrt-dialect/TensorRT/Utils/Utils.h"
 
 #include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 
 using namespace mlir;
@@ -158,3 +159,34 @@ tensorrt::scatterShapeTensor(RewriterBase &b, Location loc,
 
   return b.create<tensorrt::ConcatenationOp>(loc, parts, 0);
 }
+
+FailureOr<Attribute> tensorrt::getSplatConstantElementAttribute(Value x) {
+  while (true) {
+    if (auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
+      x = expandRank.getInput();
+    else if (auto collapseRank = x.getDefiningOp<tensorrt::CollapseRankOp>())
+      x = collapseRank.getInput();
+    else if (auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
+      x = reshape.getInput();
+    else if (auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
+      x = broadcast.getInput();
+    else if (auto cast = x.getDefiningOp<tensorrt::CastOp>())
+      x = cast.getInput();
+    else if (auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
+      x = identity.getInput();
+    else if (auto slice = x.getDefiningOp<tensorrt::SliceOp>())
+      x = slice.getInput();
+    else if (auto constant = x.getDefiningOp<tensorrt::ConstantOp>()) {
+      SplatElementsAttr els{};
+      if (!matchPattern(x, m_Constant(&els)))
+        return failure();
+      Attribute value = els.getSplatValue<Attribute>();
+      if (!isa<FloatAttr, IntegerAttr>(value))
+        return failure();
+      return value;
+    } else {
+      return failure();
+    }
+  }
+  return failure();
+}

mlir-tensorrt/tensorrt/test/Dialect/TensorRT/raise-normalizations.mlir

Lines changed: 35 additions & 0 deletions
@@ -53,3 +53,38 @@ func.func @raise_inst_norm_nchw(%arg0: tensor<1x3x1x1xf32>, %arg1: tensor<1x3x1x
 
 // CHECK-LABEL: @neg_raise_nhwc
 // CHECK-NOT: tensorrt.normalization
+
+// -----
+
+// CHECK: @raise_layer_norm_pytorch(%[[arg0:.+]]: tensor<16x1024x1024xf32>)
+// CHECK: %[[ret:.+]] = tensorrt.normalization [[attr:.+]](%[[arg0]] : tensor<16x1024x1024xf32>, %[[gamma:.+]] : tensor<1x1x1024xf32>, %[[beta:.+]] : tensor<1x1x1024xf32>)
+// CHECK: return %[[ret]]
+func.func @raise_layer_norm_pytorch(%arg0: tensor<16x1024x1024xf32>) -> tensor<16x1024x1024xf32> {
+  %cst_i64 = tensorrt.constant dense<1024> : tensor<i64>
+  %cst_f32 = tensorrt.constant dense<9.99999974E-6> : tensor<1x1x1xf32>
+
+  %cst_bf16_1 = tensorrt.constant dense_resource<__elided__> : tensor<1024xbf16> // beta (added)
+  %cst_bf16_2 = tensorrt.constant dense_resource<__elided__> : tensor<1024xbf16> // gamma (multiplied)
+
+  %6 = tensorrt.reduce <kSUM> %arg0 {keepDimensions = true, reduceAxes = array<i64: 2>} : tensor<16x1024x1024xf32> -> tensor<16x1024x1xf32>
+  %7 = tensorrt.cast %cst_i64 : tensor<i64> to tensor<f32>
+  %8 = tensorrt.expand_rank %7 : tensor<f32> to tensor<1x1x1xf32>
+  %9 = tensorrt.element_wise <kDIV>(%6, %8 : tensor<16x1024x1xf32>, tensor<1x1x1xf32>) -> tensor<16x1024x1xf32>
+  %10 = tensorrt.element_wise <kSUB>(%arg0, %9 : tensor<16x1024x1024xf32>, tensor<16x1024x1xf32>) -> tensor<16x1024x1024xf32>
+  %11 = tensorrt.element_wise <kPROD>(%10, %10 : tensor<16x1024x1024xf32>, tensor<16x1024x1024xf32>) -> tensor<16x1024x1024xf32>
+  %12 = tensorrt.reduce <kSUM> %11 {keepDimensions = true, reduceAxes = array<i64: 2>} : tensor<16x1024x1024xf32> -> tensor<16x1024x1xf32>
+  %13 = tensorrt.element_wise <kDIV>(%12, %8 : tensor<16x1024x1xf32>, tensor<1x1x1xf32>) -> tensor<16x1024x1xf32> // Var[x]
+  %15 = tensorrt.reduce <kAVG> %arg0 {keepDimensions = true, reduceAxes = array<i64: 2>} : tensor<16x1024x1024xf32> -> tensor<16x1024x1xf32> // E[x]
+  %16 = tensorrt.element_wise <kSUM>(%13, %cst_f32 : tensor<16x1024x1xf32>, tensor<1x1x1xf32>) -> tensor<16x1024x1xf32> // Var[x] + epsilon
+  %17 = tensorrt.unary {unaryOperation = #tensorrt.unary_operation<kRECIP>} %16 : tensor<16x1024x1xf32>
+  %18 = tensorrt.unary {unaryOperation = #tensorrt.unary_operation<kSQRT>} %17 : tensor<16x1024x1xf32> // compute 1/sqrt(...)
+  %19 = tensorrt.element_wise <kSUB>(%arg0, %15 : tensor<16x1024x1024xf32>, tensor<16x1024x1xf32>) -> tensor<16x1024x1024xf32>
+  %20 = tensorrt.element_wise <kPROD>(%19, %18 : tensor<16x1024x1024xf32>, tensor<16x1024x1xf32>) -> tensor<16x1024x1024xf32> // multiply for division
+  %21 = tensorrt.cast %cst_bf16_2 : tensor<1024xbf16> to tensor<1024xf32>
+  %22 = tensorrt.expand_rank %21 : tensor<1024xf32> to tensor<1x1x1024xf32>
+  %23 = tensorrt.element_wise <kPROD>(%20, %22 : tensor<16x1024x1024xf32>, tensor<1x1x1024xf32>) -> tensor<16x1024x1024xf32> // multiply gamma
+  %24 = tensorrt.cast %cst_bf16_1 : tensor<1024xbf16> to tensor<1024xf32>
+  %25 = tensorrt.expand_rank %24 : tensor<1024xf32> to tensor<1x1x1024xf32>
+  %26 = tensorrt.element_wise <kSUM>(%23, %25 : tensor<16x1024x1024xf32>, tensor<1x1x1024xf32>) -> tensor<16x1024x1024xf32> // add beta
+  return %26 : tensor<16x1024x1024xf32>
+}
