Commit 1d549c7

Author: Matthew Francis-Landau (committed)
[mlir-tensorrt] Raise activations from their elementwise representation to tensorrt.activation Op
Signed-off-by: Matthew Francis-Landau <[email protected]>
1 parent 54efd4d commit 1d549c7
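
For context, the raisings in this commit target the exact erf form of GELU and its tanh approximation; the splat constants matched by the patterns (0.5, 0.797884583 ≈ √(2/π), 0.044715, and 7.070310e-01 ≈ 1/√2) are the literals appearing in these formulas:

\[
\mathrm{GELU}(x) \;=\; x\,\Phi(x) \;=\; \tfrac{1}{2}\,x\left(1 + \operatorname{erf}\!\left(\tfrac{x}{\sqrt{2}}\right)\right)
\]
\[
\mathrm{GELU}(x) \;\approx\; \tfrac{1}{2}\,x\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)
\]

The second GELU-tanh pattern additionally recognizes √(2/π) written as pow(2/π, 0.5), with 0.63661977236 ≈ 2/π.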

File tree

3 files changed: +296 −13 lines


mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseActivations.cpp

Lines changed: 5 additions & 1 deletion

@@ -50,8 +50,12 @@ class RaiseActivations
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
 
-    if (targetTensorRTVersion >= TensorRTVersion(10, 0))
+    if (targetTensorRTVersion >= TensorRTVersion(10, 0)) {
       patterns.add<RaiseToGeluTanh>(ctx);
+      patterns.add<RaiseToGeluTanh2>(ctx);
+      patterns.add<RaiseToGeluErf>(ctx);
+    }
+    patterns.add<RaiseMaxMinToClip>(ctx);
 
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
       emitError(getOperation()->getLoc())
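
The gate above follows TensorRT feature availability: the GELU activation kinds used by the new patterns are only targeted for TensorRT 10 and later, while the CLIP raising is registered unconditionally. Below is a minimal sketch of the same logic factored into a populate-style helper; the name populateRaiseActivationPatterns is hypothetical and not an API introduced by this commit.

// Sketch only: a populate-style helper mirroring the version gate above.
// `populateRaiseActivationPatterns` is a hypothetical name, not part of
// this commit; TensorRTVersion and the pattern classes are as in the diff.
static void populateRaiseActivationPatterns(RewritePatternSet &patterns,
                                            const TensorRTVersion &target) {
  MLIRContext *ctx = patterns.getContext();
  // The GELU raisings are only registered when targeting TensorRT >= 10.
  if (target >= TensorRTVersion(10, 0))
    patterns.add<RaiseToGeluTanh, RaiseToGeluTanh2, RaiseToGeluErf>(ctx);
  // The CLIP raising is registered for every target version.
  patterns.add<RaiseMaxMinToClip>(ctx);
}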

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseActivations.pdll

Lines changed: 224 additions & 12 deletions

@@ -47,6 +47,30 @@ Constraint TanhConstraintImpl(op: Op) [{
     ActivationType::kTANH
   );
 }];
+Constraint MaxConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<tensorrt::ElementWiseOp>(op).getElementwiseOperation() ==
+    ElementWiseOperation::kMAX
+  );
+}];
+Constraint MinConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<tensorrt::ElementWiseOp>(op).getElementwiseOperation() ==
+    ElementWiseOperation::kMIN
+  );
+}];
+Constraint PowConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<tensorrt::ElementWiseOp>(op).getElementwiseOperation() ==
+    ElementWiseOperation::kPOW
+  );
+}];
+Constraint ErfConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<tensorrt::UnaryOp>(op).getUnaryOperation() ==
+    UnaryOperation::kERF
+  );
+}];
 Constraint MulConstraint(op: Op<tensorrt.element_wise>) -> Op {
   MulConstraintImpl(op);
   return op;
@@ -59,6 +83,22 @@ Constraint TanhConstraint(op: Op<tensorrt.activation>) -> Op {
   TanhConstraintImpl(op);
   return op;
 }
+Constraint MaxConstraint(op: Op<tensorrt.element_wise>) -> Op {
+  MaxConstraintImpl(op);
+  return op;
+}
+Constraint MinConstraint(op: Op<tensorrt.element_wise>) -> Op {
+  MinConstraintImpl(op);
+  return op;
+}
+Constraint PowConstraint(op: Op<tensorrt.element_wise>) -> Op {
+  PowConstraintImpl(op);
+  return op;
+}
+Constraint ErfConstraint(op: Op<tensorrt.unary>) -> Op {
+  ErfConstraintImpl(op);
+  return op;
+}
 Constraint Mul(lhs: Value, rhs: Value) -> Op {
   return MulConstraint(op<tensorrt.element_wise>(lhs, rhs));
 }
@@ -68,17 +108,107 @@ Constraint Add(lhs: Value, rhs: Value) -> Op {
 Constraint Tanh(x: Value) -> Op {
   return TanhConstraint(op<tensorrt.activation>(x));
 }
+Constraint Max(lhs: Value, rhs: Value) -> Op {
+  return MaxConstraint(op<tensorrt.element_wise>(lhs, rhs));
+}
+Constraint Min(lhs: Value, rhs: Value) -> Op {
+  return MinConstraint(op<tensorrt.element_wise>(lhs, rhs));
+}
+Constraint Pow(lhs: Value, rhs: Value) -> Op {
+  return PowConstraint(op<tensorrt.element_wise>(lhs, rhs));
+}
+Constraint Erf(x: Value) -> Op {
+  return ErfConstraint(op<tensorrt.unary>(x));
+}
+
+Rewrite GetSplatElementAttr(x: Value) -> Attr [{
+  while(true) {
+    if(auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
+      x = expandRank.getInput();
+    else if(auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
+      x = reshape.getInput();
+    else if(auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
+      x = broadcast.getInput();
+    else if(auto cast = x.getDefiningOp<tensorrt::CastOp>())
+      x = cast.getInput();
+    else if(auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
+      x = identity.getInput();
+    else if(auto slice = x.getDefiningOp<tensorrt::SliceOp>())
+      x = slice.getInput();
+    else if(auto constant = x.getDefiningOp<tensorrt::ConstantOp>()) {
+      DenseElementsAttr els{};
+      if(!matchPattern(x, m_Constant(&els)))
+        return {};
+      if(!els.isSplat())
+        return {};
+      Attribute value = els.getSplatValue<Attribute>();
+      return value;
+    } else
+      return {};
+  }
+  return {};
+}];
+
+Constraint HasSplatElements(x: Value) [{
+  while(true) {
+    if(auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
+      x = expandRank.getInput();
+    else if(auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
+      x = reshape.getInput();
+    else if(auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
+      x = broadcast.getInput();
+    else if(auto cast = x.getDefiningOp<tensorrt::CastOp>())
+      x = cast.getInput();
+    else if(auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
+      x = identity.getInput();
+    else if(auto slice = x.getDefiningOp<tensorrt::SliceOp>())
+      x = slice.getInput();
+    else if(auto constant = x.getDefiningOp<tensorrt::ConstantOp>()) {
+      DenseElementsAttr els{};
+      if(!matchPattern(x, m_Constant(&els)))
+        return failure();
+      if(!els.isSplat())
+        return failure();
+      Attribute value = els.getSplatValue<Attribute>();
+      return success(isa<FloatAttr>(value));
+    } else
+      return failure();
+  }
+  return failure();
+}];
 
 /// Is true if `x` is a constant op that has a splat constant
 /// where splat element is equal to `attr`.
-Constraint SplatElements(x: Op, attr: Attr) [{
-  DenseElementsAttr els{};
-  if(!matchPattern(x, m_Constant(&els)))
-    return failure();
-  if(!els.isSplat())
-    return failure();
-  Attribute value = els.getSplatValue<Attribute>();
-  return success(value == attr);
+Constraint SplatElements(x: Value, attr: Attr) [{
+  while(true) {
+    if(auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
+      x = expandRank.getInput();
+    else if(auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
+      x = reshape.getInput();
+    else if(auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
+      x = broadcast.getInput();
+    else if(auto cast = x.getDefiningOp<tensorrt::CastOp>())
+      x = cast.getInput();
+    else if(auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
+      x = identity.getInput();
+    else if(auto slice = x.getDefiningOp<tensorrt::SliceOp>())
+      x = slice.getInput();
+    else if(auto constant = x.getDefiningOp<tensorrt::ConstantOp>()) {
+      DenseElementsAttr els{};
+      if(!matchPattern(x, m_Constant(&els)))
+        return failure();
+      if(!els.isSplat())
+        return failure();
+      Attribute value = els.getSplatValue<Attribute>();
+      if(!value) return failure();
+      if(value == attr) return success();
+      FloatAttr fvalue = dyn_cast<FloatAttr>(value);
+      FloatAttr fattr = dyn_cast<FloatAttr>(attr);
+      return success(fvalue && fattr && std::abs(fvalue.getValueAsDouble() - fattr.getValueAsDouble()) < .001); // handle different floating point type
+    } else
+      return failure();
+  }
+  return failure();
 }];
 
 /// We need a native C++ function since we can't create the right
@@ -88,6 +218,16 @@ Rewrite CreateGeluTanh(x: Value) -> Op [{
     x, ActivationType::kGELU_TANH, FloatAttr{}, FloatAttr{}
   );
 }];
+Rewrite CreateGeluErf(x: Value) -> Op [{
+  return rewriter.create<tensorrt::ActivationOp>(x.getLoc(),
+    x, ActivationType::kGELU_ERF, FloatAttr{}, FloatAttr{}
+  );
+}];
+
+Rewrite CreateClipActivation(x: Value, min: Attr, max: Attr) -> Op [{
+  return rewriter.create<tensorrt::ActivationOp>(x.getLoc(),
+    x, ActivationType::kCLIP, cast<FloatAttr>(max), cast<FloatAttr>(min));
+}];
 
 Constraint TypesMatch(x: Value, y: Value) [{
   return success(x.getType() == y.getType());
@@ -98,13 +238,13 @@ Constraint TypesMatch(x: Value, y: Value) [{
 /// `https://github.com/google/jax/blob/main/jax/_src/nn/functions.py#L424-L455`.
 Pattern RaiseToGeluTanh {
   let x: Value;
-  let const0 = op<tensorrt.constant>();
+  let const0: Value;
   SplatElements(const0, attr<"4.471500e-02 : f32">);
-  let rootPiOverTwo = op<tensorrt.constant>();
+  let rootPiOverTwo: Value;
   SplatElements(rootPiOverTwo, attr<"0.797884583 : f32">);
-  let one = op<tensorrt.constant>();
+  let one: Value;
   SplatElements(one, attr<"1.0 : f32">);
-  let half = op<tensorrt.constant>();
+  let half: Value;
   SplatElements(half, attr<"0.5 : f32">);
   let scaledCube = Mul(Mul(Mul(x, x), x), const0);
   let tanArg = Mul(Add(x, scaledCube), rootPiOverTwo);
@@ -119,3 +259,75 @@ Pattern RaiseToGeluTanh {
     replace root with replacement;
   };
 }
+
+/// Raise a sequence of "approximate" GELU to `tensorrt.ext.gelu_tanh`.
+/// Matching pattern of Ops from PyTorch/torch-mlir
+Pattern RaiseToGeluTanh2 {
+  let x: Value;
+  let half2: Value;
+  SplatElements(half2, attr<"0.5 : f32">);
+  let one: Value;
+  SplatElements(one, attr<"1.0 : f32">);
+  let three: Value;
+  SplatElements(three, attr<"3.0 : f32">);
+  let const0: Value;
+  SplatElements(const0, attr<"4.471500e-02 : f32">);
+  let scaledCube = Mul(Pow(x, three), const0);
+  let half1: Value;
+  SplatElements(half1, attr<"0.5 : f32">);
+  let twoOverPi: Value;
+  SplatElements(twoOverPi, attr<"0.63661977236 : f32">);
+  let sqrt2pi = Pow(twoOverPi, half1);
+  let tanArg = Mul(sqrt2pi, Add(x, scaledCube));
+  let inner = Add(Tanh(tanArg), one);
+  let root = Mul(Mul(x, half2), inner);
+
+  // Sanity check for cases where we could have broadcasted x.
+  TypesMatch(root, x);
+
+  rewrite root with {
+    let replacement = CreateGeluTanh(x);
+    replace root with replacement;
+  };
+}
+
+/// Raise a sequence of GELU to `tensorrt.ext.gelu_none`.
+/// Matching pattern of Ops from PyTorch/torch-mlir
+Pattern RaiseToGeluErf {
+  let x: Value;
+  let half: Value;
+  SplatElements(half, attr<"0.5 : f32">);
+  let one: Value;
+  SplatElements(one, attr<"1.0 : f32">);
+  let const0: Value;
+  SplatElements(const0, attr<"7.070310e-01 : f32">);
+  let erf = Erf(Mul(x, const0));
+  let normalCdf = Mul(Add(erf, one), half);
+  let root = Mul(x, normalCdf);
+
+  rewrite root with {
+    let replacement = CreateGeluErf(x);
+    replace root with replacement;
+  };
+}
+
+/// Raise a elementwise min/max to the CLIP activation
+/// Matching pattern of Ops from PyTorch/torch-mlir
+Pattern RaiseMaxMinToClip {
+  let x: Value;
+  let minValue: Value;
+  let maxValue: Value;
+  let root = Min(minValue, Max(x, maxValue));
+
+  TypesMatch(root, x);
+
+  HasSplatElements(minValue);
+  HasSplatElements(maxValue);
+
+  rewrite root with {
+    let min: Attr = GetSplatElementAttr(minValue);
+    let max: Attr = GetSplatElementAttr(maxValue);
+    let replacement = CreateClipActivation(x, min, max);
+    replace root with replacement;
+  };
+}
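
GetSplatElementAttr, HasSplatElements, and SplatElements above all repeat the same walk through shape-only producers (expand_rank, reshape, broadcast, cast, identity, slice) to reach the defining constant, and SplatElements accepts a small (< 0.001) difference so that an f32 literal written in a pattern, e.g. 4.471500e-02, still matches a bf16-rounded constant such as 4.467770e-02. Below is a minimal sketch of that shared walk as a single C++ helper; the helper names are illustrative and not part of the commit, and the tensorrt dialect op classes are assumed to be those used in the diff.

// Sketch: the producer walk shared by the three PDLL native bodies above.
// `lookThroughToSplat` and `splatEquals` are illustrative names only.
static Attribute lookThroughToSplat(Value x) {
  while (true) {
    if (auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
      x = expandRank.getInput();
    else if (auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
      x = reshape.getInput();
    else if (auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
      x = broadcast.getInput();
    else if (auto castOp = x.getDefiningOp<tensorrt::CastOp>())
      x = castOp.getInput();
    else if (auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
      x = identity.getInput();
    else if (auto slice = x.getDefiningOp<tensorrt::SliceOp>())
      x = slice.getInput();
    else if (x.getDefiningOp<tensorrt::ConstantOp>()) {
      // Reached the constant: return its splat value, if it is a splat.
      DenseElementsAttr els;
      if (!matchPattern(x, m_Constant(&els)) || !els.isSplat())
        return {};
      return els.getSplatValue<Attribute>();
    } else {
      return {};
    }
  }
}

// Approximate equality used by SplatElements, tolerating bf16 rounding of
// the f32 literals written in the patterns.
static bool splatEquals(Attribute value, Attribute attr) {
  if (value == attr)
    return true;
  auto fvalue = dyn_cast<FloatAttr>(value);
  auto fattr = dyn_cast<FloatAttr>(attr);
  return fvalue && fattr &&
         std::abs(fvalue.getValueAsDouble() - fattr.getValueAsDouble()) < 0.001;
}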

mlir-tensorrt/tensorrt/test/Dialect/TensorRT/raise-activations.mlir

Lines changed: 67 additions & 0 deletions

@@ -25,3 +25,70 @@ func.func @raise_gelu(%arg0: tensor<12x128x4x12x1xf32>) -> (tensor<12x128x4x12x1
 
 // TRT8-LABEL: func.func @raise_gelu
 // TRT8-NOT: kGELU_TANH
+
+// -----
+
+// CHECK-LABEL: func.func @raise_gelu2
+// CHECK-SAME: (%[[arg0:.+]]: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
+// CHECK: %[[v0:.+]] = tensorrt.activation {activationType = #tensorrt.activation_type<kGELU_TANH>} %[[arg0]] : tensor<16x1024x1024xbf16>
+// CHECK: return %[[v0]] : tensor<16x1024x1024xbf16>
+
+// TRT8-LABEL: func.func @raise_gelu2
+// TRT8-NOT: kGELU_TANH
+
+func.func @raise_gelu2(%arg0: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
+  %cst_bf16 = tensorrt.constant dense<5.000000e-01> : tensor<1x1x1xbf16>
+  %cst_bf16_0 = tensorrt.constant dense<3.000000e+00> : tensor<1x1x1xbf16>
+  %cst_bf16_2 = tensorrt.constant dense<6.367190e-01> : tensor<1x1x1xbf16>
+  %cst_bf16_3 = tensorrt.constant dense<4.467770e-02> : tensor<1x1x1xbf16>
+  %cst_bf16_5 = tensorrt.constant dense<1.000000e+00> : tensor<1x1x1xbf16>
+  %0 = tensorrt.slice %cst_bf16[0, 0, 0][16, 1024, 1024][1, 1, 1] {mode = #tensorrt.slice_mode<kWRAP>} : tensor<1x1x1xbf16> to tensor<16x1024x1024xbf16>
+  %5 = tensorrt.element_wise <kPROD>(%arg0, %cst_bf16 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %6 = tensorrt.element_wise <kPOW>(%cst_bf16_2, %0 : tensor<1x1x1xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  %7 = tensorrt.element_wise <kPOW>(%arg0, %cst_bf16_0 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %8 = tensorrt.element_wise <kPROD>(%7, %cst_bf16_3 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %9 = tensorrt.element_wise <kSUM>(%arg0, %8 : tensor<16x1024x1024xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  %10 = tensorrt.element_wise <kPROD>(%6, %9 : tensor<16x1024x1024xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  %11 = tensorrt.activation {activationType = #tensorrt.activation_type<kTANH>} %10 : tensor<16x1024x1024xbf16>
+  %12 = tensorrt.element_wise <kSUM>(%11, %cst_bf16_5 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %13 = tensorrt.element_wise <kPROD>(%5, %12 : tensor<16x1024x1024xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  return %13 : tensor<16x1024x1024xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @raise_gelu_erf
+// CHECK-SAME: (%[[arg0:.+]]: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
+// CHECK: %[[v0:.+]] = tensorrt.activation {activationType = #tensorrt.activation_type<kGELU_ERF>} %[[arg0]] : tensor<16x1024x1024xbf16>
+// CHECK: return %[[v0]] : tensor<16x1024x1024xbf16>
+
+// TRT8-LABEL: func.func @raise_gelu_erf
+// TRT8-NOT: kGELU_ERF
+
+func.func @raise_gelu_erf(%arg0: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
+  %cst_bf16_1 = tensorrt.constant dense<1.000000e+00> : tensor<1x1x1xbf16>
+  %cst_bf16_2 = tensorrt.constant dense<5.000000e-01> : tensor<1x1x1xbf16>
+  %cst_bf16_3 = tensorrt.constant dense<7.070310e-01> : tensor<1x1x1xbf16>
+  %5 = tensorrt.element_wise <kPROD>(%arg0, %cst_bf16_3 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %6 = tensorrt.unary {unaryOperation = #tensorrt.unary_operation<kERF>} %5 : tensor<16x1024x1024xbf16>
+  %7 = tensorrt.element_wise <kSUM>(%6, %cst_bf16_1 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %8 = tensorrt.element_wise <kPROD>(%7, %cst_bf16_2 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %9 = tensorrt.element_wise <kPROD>(%arg0, %8 : tensor<16x1024x1024xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  return %9 : tensor<16x1024x1024xbf16>
+}
+
+// -----
+
+// CHECK: @raise_min_max(%[[arg0:.+]]: tensor<16x1024x1024xbf16>)
+// CHECK: #tensorrt.activation_type<kCLIP>
+func.func @raise_min_max(%arg0: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
+  %cst_f32 = tensorrt.constant dense<6.000000e+00> : tensor<f32>
+  %cst_f32_1 = tensorrt.constant dense<0.000000e+00> : tensor<f32>
+  %5 = tensorrt.cast %cst_f32_1 : tensor<f32> to tensor<bf16>
+  %6 = tensorrt.expand_rank %5 : tensor<bf16> to tensor<1x1x1xbf16>
+  %8 = tensorrt.cast %cst_f32 : tensor<f32> to tensor<bf16>
+  %9 = tensorrt.expand_rank %8 : tensor<bf16> to tensor<1x1x1xbf16>
+  %15 = tensorrt.element_wise <kMAX>(%arg0, %6 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %16 = tensorrt.element_wise <kMIN>(%9, %15 : tensor<1x1x1xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  return %16 : tensor<16x1024x1024xbf16>
+}
