
Commit d8f3603

Author: Matthew Francis-Landau
[mlir-tensorrt] Raise activations from their elementwise representation to tensorrt.activation Op (#679)
This PR adds new raisings from `tensorrt.element_wise` ops to `tensorrt.activation` for GELU Tanh (as created by torch-mlir) and GELU Erf. It also raises `min(a, max(x, y))` to the CLIP activation (which torch-mlir uses for clip and ReLU6). Merging these elementwise ops into a single `tensorrt.activation` op lets TensorRT fuse the activation into the preceding linear layer's kernel (matrix multiply + elementwise sum). Signed-off-by: Matthew Francis-Landau <[email protected]>
1 parent deb6e79 commit d8f3603
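
For context, the two GELU formulations these patterns recognize are the tanh approximation GELU(x) ≈ 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))) and the erf form GELU(x) = 0.5·x·(1 + erf(x/√2)); the splat constants matched in the PDLL patterns below (4.4715e-02, 0.797884583 ≈ √(2/π), 0.63661977 ≈ 2/π, 7.07031e-01 ≈ 1/√2) come straight from these formulas. Below is a minimal standalone C++ sketch, illustrative only and not part of the commit (helper names geluTanh/geluErf are mine), that evaluates both forms:

// Illustrative only (not part of this commit): evaluate the tanh-approximate
// and erf-based GELU that the raising patterns recognize.
#include <cmath>
#include <cstdio>

static double geluTanh(double x) {
  // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
  return 0.5 * x * (1.0 + std::tanh(0.797884583 * (x + 0.044715 * x * x * x)));
}

static double geluErf(double x) {
  // 0.5 * x * (1 + erf(x / sqrt(2)))
  return 0.5 * x * (1.0 + std::erf(x * 0.7071067811865476));
}

int main() {
  for (double x : {-3.0, -1.0, 0.0, 0.5, 2.0})
    std::printf("x=%+.2f  gelu_tanh=%+.6f  gelu_erf=%+.6f\n", x, geluTanh(x),
                geluErf(x));
  return 0;
}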

3 files changed: +226 −13 lines

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseActivations.cpp

Lines changed: 6 additions & 1 deletion
@@ -23,6 +23,7 @@
 //===----------------------------------------------------------------------===//
 #include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
 #include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h"
+#include "mlir-tensorrt-dialect/TensorRT/Utils/Utils.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/IR/Matchers.h"
@@ -50,8 +51,12 @@ class RaiseActivations
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);

-    if (targetTensorRTVersion >= TensorRTVersion(10, 0))
+    if (targetTensorRTVersion >= TensorRTVersion(10, 0)) {
       patterns.add<RaiseToGeluTanh>(ctx);
+      patterns.add<RaiseToGeluTanh2>(ctx);
+      patterns.add<RaiseToGeluErf>(ctx);
+    }
+    patterns.add<RaiseMaxMinToClip>(ctx);

     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
       emitError(getOperation()->getLoc())

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseActivations.pdll

Lines changed: 153 additions & 12 deletions
@@ -47,6 +47,30 @@ Constraint TanhConstraintImpl(op: Op) [{
     ActivationType::kTANH
   );
 }];
+Constraint MaxConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<tensorrt::ElementWiseOp>(op).getElementwiseOperation() ==
+      ElementWiseOperation::kMAX
+  );
+}];
+Constraint MinConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<tensorrt::ElementWiseOp>(op).getElementwiseOperation() ==
+      ElementWiseOperation::kMIN
+  );
+}];
+Constraint PowConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<tensorrt::ElementWiseOp>(op).getElementwiseOperation() ==
+      ElementWiseOperation::kPOW
+  );
+}];
+Constraint ErfConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<tensorrt::UnaryOp>(op).getUnaryOperation() ==
+      UnaryOperation::kERF
+  );
+}];
 Constraint MulConstraint(op: Op<tensorrt.element_wise>) -> Op {
   MulConstraintImpl(op);
   return op;
@@ -59,6 +83,22 @@ Constraint TanhConstraint(op: Op<tensorrt.activation>) -> Op {
   TanhConstraintImpl(op);
   return op;
 }
+Constraint MaxConstraint(op: Op<tensorrt.element_wise>) -> Op {
+  MaxConstraintImpl(op);
+  return op;
+}
+Constraint MinConstraint(op: Op<tensorrt.element_wise>) -> Op {
+  MinConstraintImpl(op);
+  return op;
+}
+Constraint PowConstraint(op: Op<tensorrt.element_wise>) -> Op {
+  PowConstraintImpl(op);
+  return op;
+}
+Constraint ErfConstraint(op: Op<tensorrt.unary>) -> Op {
+  ErfConstraintImpl(op);
+  return op;
+}
 Constraint Mul(lhs: Value, rhs: Value) -> Op {
   return MulConstraint(op<tensorrt.element_wise>(lhs, rhs));
 }
@@ -68,17 +108,36 @@ Constraint Add(lhs: Value, rhs: Value) -> Op {
 Constraint Tanh(x: Value) -> Op {
   return TanhConstraint(op<tensorrt.activation>(x));
 }
+Constraint Max(lhs: Value, rhs: Value) -> Op {
+  return MaxConstraint(op<tensorrt.element_wise>(lhs, rhs));
+}
+Constraint Min(lhs: Value, rhs: Value) -> Op {
+  return MinConstraint(op<tensorrt.element_wise>(lhs, rhs));
+}
+Constraint Pow(lhs: Value, rhs: Value) -> Op {
+  return PowConstraint(op<tensorrt.element_wise>(lhs, rhs));
+}
+Constraint Erf(x: Value) -> Op {
+  return ErfConstraint(op<tensorrt.unary>(x));
+}
+
+Rewrite GetSplatElementAttr(x: Value) -> Attr [{
+  return *getSplatConstantElementAttribute(x);
+}];
+
+Constraint HasSplatElements(x: Value) [{
+  return LogicalResult(getSplatConstantElementAttribute(x));
+}];

 /// Is true if `x` is a constant op that has a splat constant
 /// where splat element is equal to `attr`.
-Constraint SplatElements(x: Op, attr: Attr) [{
-  DenseElementsAttr els{};
-  if(!matchPattern(x, m_Constant(&els)))
-    return failure();
-  if(!els.isSplat())
-    return failure();
-  Attribute value = els.getSplatValue<Attribute>();
-  return success(value == attr);
+Constraint SplatElements(x: Value, attr: Attr) [{
+  FailureOr<Attribute> value = getSplatConstantElementAttribute(x);
+  if(LogicalResult(value).failed()) return failure();
+  if(*value == attr) return success();
+  FloatAttr fvalue = dyn_cast<FloatAttr>(*value);
+  FloatAttr fattr = dyn_cast<FloatAttr>(attr);
+  return success(fvalue && fattr && std::abs(fvalue.getValueAsDouble() - fattr.getValueAsDouble()) < .001); // handle different floating point type
 }];

 /// We need a native C++ function since we can't create the right
@@ -88,6 +147,16 @@ Rewrite CreateGeluTanh(x: Value) -> Op [{
     x, ActivationType::kGELU_TANH, FloatAttr{}, FloatAttr{}
   );
 }];
+Rewrite CreateGeluErf(x: Value) -> Op [{
+  return rewriter.create<tensorrt::ActivationOp>(x.getLoc(),
+    x, ActivationType::kGELU_ERF, FloatAttr{}, FloatAttr{}
+  );
+}];
+
+Rewrite CreateClipActivation(x: Value, min: Attr, max: Attr) -> Op [{
+  return rewriter.create<tensorrt::ActivationOp>(x.getLoc(),
+    x, ActivationType::kCLIP, cast<FloatAttr>(max), cast<FloatAttr>(min));
+}];

 Constraint TypesMatch(x: Value, y: Value) [{
   return success(x.getType() == y.getType());
@@ -98,13 +167,13 @@ Constraint TypesMatch(x: Value, y: Value) [{
 /// `https://github.com/google/jax/blob/main/jax/_src/nn/functions.py#L424-L455`.
 Pattern RaiseToGeluTanh {
   let x: Value;
-  let const0 = op<tensorrt.constant>();
+  let const0: Value;
   SplatElements(const0, attr<"4.471500e-02 : f32">);
-  let rootPiOverTwo = op<tensorrt.constant>();
+  let rootPiOverTwo: Value;
   SplatElements(rootPiOverTwo, attr<"0.797884583 : f32">);
-  let one = op<tensorrt.constant>();
+  let one: Value;
   SplatElements(one, attr<"1.0 : f32">);
-  let half = op<tensorrt.constant>();
+  let half: Value;
   SplatElements(half, attr<"0.5 : f32">);
   let scaledCube = Mul(Mul(Mul(x, x), x), const0);
   let tanArg = Mul(Add(x, scaledCube), rootPiOverTwo);
@@ -119,3 +188,75 @@ Pattern RaiseToGeluTanh {
     replace root with replacement;
   };
 }
+
+/// Raise a sequence of "approximate" GELU to `tensorrt.ext.gelu_tanh`.
+/// Matching pattern of Ops from PyTorch/torch-mlir
+Pattern RaiseToGeluTanh2 {
+  let x: Value;
+  let half2: Value;
+  SplatElements(half2, attr<"0.5 : f32">);
+  let one: Value;
+  SplatElements(one, attr<"1.0 : f32">);
+  let three: Value;
+  SplatElements(three, attr<"3.0 : f32">);
+  let const0: Value;
+  SplatElements(const0, attr<"4.471500e-02 : f32">);
+  let scaledCube = Mul(Pow(x, three), const0);
+  let half1: Value;
+  SplatElements(half1, attr<"0.5 : f32">);
+  let twoOverPi: Value;
+  SplatElements(twoOverPi, attr<"0.63661977236 : f32">);
+  let sqrt2pi = Pow(twoOverPi, half1);
+  let tanArg = Mul(sqrt2pi, Add(x, scaledCube));
+  let inner = Add(Tanh(tanArg), one);
+  let root = Mul(Mul(x, half2), inner);
+
+  // Sanity check for cases where we could have broadcasted x.
+  TypesMatch(root, x);
+
+  rewrite root with {
+    let replacement = CreateGeluTanh(x);
+    replace root with replacement;
+  };
+}
+
+/// Raise a sequence of GELU to `tensorrt.ext.gelu_none`.
+/// Matching pattern of Ops from PyTorch/torch-mlir
+Pattern RaiseToGeluErf {
+  let x: Value;
+  let half: Value;
+  SplatElements(half, attr<"0.5 : f32">);
+  let one: Value;
+  SplatElements(one, attr<"1.0 : f32">);
+  let const0: Value;
+  SplatElements(const0, attr<"7.070310e-01 : f32">);
+  let erf = Erf(Mul(x, const0));
+  let normalCdf = Mul(Add(erf, one), half);
+  let root = Mul(x, normalCdf);
+
+  rewrite root with {
+    let replacement = CreateGeluErf(x);
+    replace root with replacement;
+  };
+}
+
+/// Raise a elementwise min/max to the CLIP activation
+/// Matching pattern of Ops from PyTorch/torch-mlir
+Pattern RaiseMaxMinToClip {
+  let x: Value;
+  let minValue: Value;
+  let maxValue: Value;
+  let root = Min(minValue, Max(x, maxValue));
+
+  TypesMatch(root, x);
+
+  HasSplatElements(minValue);
+  HasSplatElements(maxValue);
+
+  rewrite root with {
+    let min: Attr = GetSplatElementAttr(minValue);
+    let max: Attr = GetSplatElementAttr(maxValue);
+    let replacement = CreateClipActivation(x, min, max);
+    replace root with replacement;
+  };
+}
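
A note on RaiseMaxMinToClip above: `minValue` is the splat constant feeding the kMIN op (so it acts as the upper clip bound) and `maxValue` is the one feeding the kMAX op (the lower bound), which is why `CreateClipActivation` receives them in swapped order when building the activation attributes. The following scalar-level C++ sketch of the equivalence is illustrative only and not part of the commit (the helper name `clip` is mine); ReLU6 is the `clip(x, 0, 6)` case exercised by the test below.

// Illustrative only (not part of this commit): min(upper, max(x, lower)) is
// the CLIP activation; ReLU6 is the special case lower = 0, upper = 6.
#include <algorithm>
#include <cstdio>

static float clip(float x, float lower, float upper) {
  return std::min(upper, std::max(x, lower));
}

int main() {
  for (float x : {-2.0f, 0.0f, 3.0f, 8.0f})
    std::printf("relu6(%+.1f) = %.1f\n", x, clip(x, 0.0f, 6.0f));
  return 0;
}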

mlir-tensorrt/tensorrt/test/Dialect/TensorRT/raise-activations.mlir

Lines changed: 67 additions & 0 deletions
@@ -25,3 +25,70 @@ func.func @raise_gelu(%arg0: tensor<12x128x4x12x1xf32>) -> (tensor<12x128x4x12x1

 // TRT8-LABEL: func.func @raise_gelu
 // TRT8-NOT: kGELU_TANH
+
+// -----
+
+// CHECK-LABEL: func.func @raise_gelu2
+// CHECK-SAME: (%[[arg0:.+]]: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
+// CHECK: %[[v0:.+]] = tensorrt.activation {activationType = #tensorrt.activation_type<kGELU_TANH>} %[[arg0]] : tensor<16x1024x1024xbf16>
+// CHECK: return %[[v0]] : tensor<16x1024x1024xbf16>
+
+// TRT8-LABEL: func.func @raise_gelu2
+// TRT8-NOT: kGELU_TANH
+
+func.func @raise_gelu2(%arg0: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
+  %cst_bf16 = tensorrt.constant dense<5.000000e-01> : tensor<1x1x1xbf16>
+  %cst_bf16_0 = tensorrt.constant dense<3.000000e+00> : tensor<1x1x1xbf16>
+  %cst_bf16_2 = tensorrt.constant dense<6.367190e-01> : tensor<1x1x1xbf16>
+  %cst_bf16_3 = tensorrt.constant dense<4.467770e-02> : tensor<1x1x1xbf16>
+  %cst_bf16_5 = tensorrt.constant dense<1.000000e+00> : tensor<1x1x1xbf16>
+  %0 = tensorrt.slice %cst_bf16[0, 0, 0][16, 1024, 1024][1, 1, 1] {mode = #tensorrt.slice_mode<kWRAP>} : tensor<1x1x1xbf16> to tensor<16x1024x1024xbf16>
+  %5 = tensorrt.element_wise <kPROD>(%arg0, %cst_bf16 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %6 = tensorrt.element_wise <kPOW>(%cst_bf16_2, %0 : tensor<1x1x1xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  %7 = tensorrt.element_wise <kPOW>(%arg0, %cst_bf16_0 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %8 = tensorrt.element_wise <kPROD>(%7, %cst_bf16_3 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %9 = tensorrt.element_wise <kSUM>(%arg0, %8 : tensor<16x1024x1024xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  %10 = tensorrt.element_wise <kPROD>(%6, %9 : tensor<16x1024x1024xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  %11 = tensorrt.activation {activationType = #tensorrt.activation_type<kTANH>} %10 : tensor<16x1024x1024xbf16>
+  %12 = tensorrt.element_wise <kSUM>(%11, %cst_bf16_5 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %13 = tensorrt.element_wise <kPROD>(%5, %12 : tensor<16x1024x1024xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  return %13 : tensor<16x1024x1024xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @raise_gelu_erf
+// CHECK-SAME: (%[[arg0:.+]]: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
+// CHECK: %[[v0:.+]] = tensorrt.activation {activationType = #tensorrt.activation_type<kGELU_ERF>} %[[arg0]] : tensor<16x1024x1024xbf16>
+// CHECK: return %[[v0]] : tensor<16x1024x1024xbf16>
+
+// TRT8-LABEL: func.func @raise_gelu_erf
+// TRT8-NOT: kGELU_ERF
+
+func.func @raise_gelu_erf(%arg0: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
+  %cst_bf16_1 = tensorrt.constant dense<1.000000e+00> : tensor<1x1x1xbf16>
+  %cst_bf16_2 = tensorrt.constant dense<5.000000e-01> : tensor<1x1x1xbf16>
+  %cst_bf16_3 = tensorrt.constant dense<7.070310e-01> : tensor<1x1x1xbf16>
+  %5 = tensorrt.element_wise <kPROD>(%arg0, %cst_bf16_3 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %6 = tensorrt.unary {unaryOperation = #tensorrt.unary_operation<kERF>} %5 : tensor<16x1024x1024xbf16>
+  %7 = tensorrt.element_wise <kSUM>(%6, %cst_bf16_1 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %8 = tensorrt.element_wise <kPROD>(%7, %cst_bf16_2 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %9 = tensorrt.element_wise <kPROD>(%arg0, %8 : tensor<16x1024x1024xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  return %9 : tensor<16x1024x1024xbf16>
+}
+
+// -----
+
+// CHECK: @raise_min_max(%[[arg0:.+]]: tensor<16x1024x1024xbf16>)
+// CHECK: #tensorrt.activation_type<kCLIP>
+func.func @raise_min_max(%arg0: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
+  %cst_f32 = tensorrt.constant dense<6.000000e+00> : tensor<f32>
+  %cst_f32_1 = tensorrt.constant dense<0.000000e+00> : tensor<f32>
+  %5 = tensorrt.cast %cst_f32_1 : tensor<f32> to tensor<bf16>
+  %6 = tensorrt.expand_rank %5 : tensor<bf16> to tensor<1x1x1xbf16>
+  %8 = tensorrt.cast %cst_f32 : tensor<f32> to tensor<bf16>
+  %9 = tensorrt.expand_rank %8 : tensor<bf16> to tensor<1x1x1xbf16>
+  %15 = tensorrt.element_wise <kMAX>(%arg0, %6 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %16 = tensorrt.element_wise <kMIN>(%9, %15 : tensor<1x1x1xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  return %16 : tensor<16x1024x1024xbf16>
+}
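
One detail worth noting in these tests: the constants are stored as bf16 (e.g. 6.367190e-01 rather than 0.63661977236, and 7.070310e-01 rather than 0.70710678), so the relaxed SplatElements comparison with its 1e-3 tolerance ("handle different floating point type" in the PDLL diff) is what lets the f32 attributes in the patterns still match. The following standalone C++ sketch is illustrative only and not part of the commit (the helper name `roundToBf16` is mine); it shows roughly how much bf16 rounding shifts such constants.

// Illustrative only (not part of this commit): rounding an f32 constant to
// bf16 (keep the top 16 bits, round to nearest) shifts 2/pi from 0.63661977
// to 0.63671875, well within the 1e-3 tolerance used by SplatElements.
#include <cstdint>
#include <cstdio>
#include <cstring>

static float roundToBf16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += 0x7FFFu + ((bits >> 16) & 1u); // round to nearest, ties to even
  bits &= 0xFFFF0000u;                   // keep sign, exponent, top 7 mantissa bits
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

int main() {
  for (float c : {0.63661977236f, 0.70710678f, 0.044715f}) {
    float b = roundToBf16(c);
    std::printf("f32 %.8f -> bf16 %.8f (diff %.2e)\n", c, b, c - b);
  }
  return 0;
}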
