@@ -47,6 +47,30 @@ Constraint TanhConstraintImpl(op: Op) [{
     ActivationType::kTANH
   );
 }];
+Constraint MaxConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<tensorrt::ElementWiseOp>(op).getElementwiseOperation() ==
+    ElementWiseOperation::kMAX
+  );
+}];
+Constraint MinConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<tensorrt::ElementWiseOp>(op).getElementwiseOperation() ==
+    ElementWiseOperation::kMIN
+  );
+}];
+Constraint PowConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<tensorrt::ElementWiseOp>(op).getElementwiseOperation() ==
+    ElementWiseOperation::kPOW
+  );
+}];
+Constraint ErfConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<tensorrt::UnaryOp>(op).getUnaryOperation() ==
+    UnaryOperation::kERF
+  );
+}];
 Constraint MulConstraint(op: Op<tensorrt.element_wise>) -> Op {
   MulConstraintImpl(op);
   return op;
@@ -59,6 +83,22 @@ Constraint TanhConstraint(op: Op<tensorrt.activation>) -> Op {
   TanhConstraintImpl(op);
   return op;
 }
+Constraint MaxConstraint(op: Op<tensorrt.element_wise>) -> Op {
+  MaxConstraintImpl(op);
+  return op;
+}
+Constraint MinConstraint(op: Op<tensorrt.element_wise>) -> Op {
+  MinConstraintImpl(op);
+  return op;
+}
+Constraint PowConstraint(op: Op<tensorrt.element_wise>) -> Op {
+  PowConstraintImpl(op);
+  return op;
+}
+Constraint ErfConstraint(op: Op<tensorrt.unary>) -> Op {
+  ErfConstraintImpl(op);
+  return op;
+}
 Constraint Mul(lhs: Value, rhs: Value) -> Op {
   return MulConstraint(op<tensorrt.element_wise>(lhs, rhs));
 }
@@ -68,17 +108,107 @@ Constraint Add(lhs: Value, rhs: Value) -> Op {
 Constraint Tanh(x: Value) -> Op {
   return TanhConstraint(op<tensorrt.activation>(x));
 }
+Constraint Max(lhs: Value, rhs: Value) -> Op {
+  return MaxConstraint(op<tensorrt.element_wise>(lhs, rhs));
+}
+Constraint Min(lhs: Value, rhs: Value) -> Op {
+  return MinConstraint(op<tensorrt.element_wise>(lhs, rhs));
+}
+Constraint Pow(lhs: Value, rhs: Value) -> Op {
+  return PowConstraint(op<tensorrt.element_wise>(lhs, rhs));
+}
+Constraint Erf(x: Value) -> Op {
+  return ErfConstraint(op<tensorrt.unary>(x));
+}
+
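+/// Returns the splat element of the constant feeding `x`, looking through ops
+/// that only reshape, broadcast, cast, or copy the value; returns a null
+/// attribute when no such splat constant is found.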
+Rewrite GetSplatElementAttr(x: Value) -> Attr [{
+  // Walk up the chain of shape/type wrapper ops to the producing constant.
+  while (true) {
+    if (auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
+      x = expandRank.getInput();
+    else if (auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
+      x = reshape.getInput();
+    else if (auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
+      x = broadcast.getInput();
+    else if (auto castOp = x.getDefiningOp<tensorrt::CastOp>())
+      x = castOp.getInput();
+    else if (auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
+      x = identity.getInput();
+    else if (auto slice = x.getDefiningOp<tensorrt::SliceOp>())
+      x = slice.getInput();
+    else if (auto constant = x.getDefiningOp<tensorrt::ConstantOp>()) {
+      DenseElementsAttr els{};
+      if (!matchPattern(x, m_Constant(&els)) || !els.isSplat())
+        return {};
+      return els.getSplatValue<Attribute>();
+    } else {
+      return {};
+    }
+  }
+  return {};
+}];
+
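+/// Succeeds when `x` (looking through the same wrapper ops as above) is
+/// produced by a constant with a float splat value.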
+Constraint HasSplatElements(x: Value) [{
+  while (true) {
+    if (auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
+      x = expandRank.getInput();
+    else if (auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
+      x = reshape.getInput();
+    else if (auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
+      x = broadcast.getInput();
+    else if (auto castOp = x.getDefiningOp<tensorrt::CastOp>())
+      x = castOp.getInput();
+    else if (auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
+      x = identity.getInput();
+    else if (auto slice = x.getDefiningOp<tensorrt::SliceOp>())
+      x = slice.getInput();
+    else if (auto constant = x.getDefiningOp<tensorrt::ConstantOp>()) {
+      DenseElementsAttr els{};
+      if (!matchPattern(x, m_Constant(&els)) || !els.isSplat())
+        return failure();
+      // Only a float splat can become an activation parameter.
+      return success(isa<FloatAttr>(els.getSplatValue<Attribute>()));
+    } else {
+      return failure();
+    }
+  }
+  return failure();
+}];
 
-/// Is true if `x` is a constant op that has a splat constant
-/// where splat element is equal to `attr`.
+/// Succeeds if `x` is produced (possibly through shape/cast wrapper ops) by a
+/// constant whose splat element equals `attr`; float values are compared
+/// numerically to tolerate different float element types.
-Constraint SplatElements(x: Op, attr: Attr) [{
-  DenseElementsAttr els{};
-  if(!matchPattern(x, m_Constant(&els)))
-    return failure();
-  if(!els.isSplat())
-    return failure();
-  Attribute value = els.getSplatValue<Attribute>();
-  return success(value == attr);
+Constraint SplatElements(x: Value, attr: Attr) [{
+  while (true) {
+    if (auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
+      x = expandRank.getInput();
+    else if (auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
+      x = reshape.getInput();
+    else if (auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
+      x = broadcast.getInput();
+    else if (auto castOp = x.getDefiningOp<tensorrt::CastOp>())
+      x = castOp.getInput();
+    else if (auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
+      x = identity.getInput();
+    else if (auto slice = x.getDefiningOp<tensorrt::SliceOp>())
+      x = slice.getInput();
+    else if (auto constant = x.getDefiningOp<tensorrt::ConstantOp>()) {
+      DenseElementsAttr els{};
+      if (!matchPattern(x, m_Constant(&els)) || !els.isSplat())
+        return failure();
+      Attribute value = els.getSplatValue<Attribute>();
+      if (!value)
+        return failure();
+      if (value == attr)
+        return success();
+      // The splat and the expected attribute may use different floating-point
+      // types (e.g. f16 vs f32), so compare the values numerically.
+      FloatAttr fvalue = dyn_cast<FloatAttr>(value);
+      FloatAttr fattr = dyn_cast<FloatAttr>(attr);
+      return success(fvalue && fattr &&
+                     std::abs(fvalue.getValueAsDouble() -
+                              fattr.getValueAsDouble()) < 0.001);
+    } else {
+      return failure();
+    }
+  }
+  return failure();
 }];
 
 /// We need a native C++ function since we can't create the right
@@ -88,6 +218,16 @@ Rewrite CreateGeluTanh(x: Value) -> Op [{
     x, ActivationType::kGELU_TANH, FloatAttr{}, FloatAttr{}
   );
 }];
+Rewrite CreateGeluErf(x: Value) -> Op [{
+  return rewriter.create<tensorrt::ActivationOp>(x.getLoc(),
+    x, ActivationType::kGELU_ERF, FloatAttr{}, FloatAttr{}
+  );
+}];
+
+Rewrite CreateClipActivation(x: Value, min: Attr, max: Attr) -> Op [{
+  // The CLIP activation clamps to [alpha, beta]: alpha comes from `max` (the
+  // operand of the matched Max op, i.e. the lower bound) and beta from `min`
+  // (the Min operand, i.e. the upper bound), hence the swapped arguments.
+  return rewriter.create<tensorrt::ActivationOp>(x.getLoc(),
+    x, ActivationType::kCLIP, cast<FloatAttr>(max), cast<FloatAttr>(min));
+}];
 
 Constraint TypesMatch(x: Value, y: Value) [{
   return success(x.getType() == y.getType());
@@ -98,13 +238,13 @@ Constraint TypesMatch(x: Value, y: Value) [{
 /// `https://github.com/google/jax/blob/main/jax/_src/nn/functions.py#L424-L455`.
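+/// tanh GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
+/// matched below via the splat constants 0.044715, 0.797884583 ~= sqrt(2/pi),
+/// 1.0, and 0.5.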
 Pattern RaiseToGeluTanh {
   let x: Value;
-  let const0 = op<tensorrt.constant>() ;
+  let const0: Value;
   SplatElements(const0, attr<"4.471500e-02 : f32">);
-  let rootPiOverTwo = op<tensorrt.constant>() ;
+  let rootPiOverTwo: Value;
   SplatElements(rootPiOverTwo, attr<"0.797884583 : f32">);
-  let one = op<tensorrt.constant>() ;
+  let one: Value;
   SplatElements(one, attr<"1.0 : f32">);
-  let half = op<tensorrt.constant>() ;
+  let half: Value;
   SplatElements(half, attr<"0.5 : f32">);
   let scaledCube = Mul(Mul(Mul(x, x), x), const0);
   let tanArg = Mul(Add(x, scaledCube), rootPiOverTwo);
@@ -119,3 +259,75 @@ Pattern RaiseToGeluTanh {
     replace root with replacement;
   };
 }
+
+/// Raise a sequence of "approximate" (tanh) GELU ops to `tensorrt.ext.gelu_tanh`.
+/// Matches the op sequence produced by PyTorch/torch-mlir.
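+/// Same formula as above, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
+/// except that sqrt(2/pi) appears as pow(2/pi, 0.5) with 0.63661977236 ~= 2/pi,
+/// and the cube as pow(x, 3).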
+Pattern RaiseToGeluTanh2 {
+  let x: Value;
+  let half2: Value;
+  SplatElements(half2, attr<"0.5 : f32">);
+  let one: Value;
+  SplatElements(one, attr<"1.0 : f32">);
+  let three: Value;
+  SplatElements(three, attr<"3.0 : f32">);
+  let const0: Value;
+  SplatElements(const0, attr<"4.471500e-02 : f32">);
+  let scaledCube = Mul(Pow(x, three), const0);
+  let half1: Value;
+  SplatElements(half1, attr<"0.5 : f32">);
+  let twoOverPi: Value;
+  SplatElements(twoOverPi, attr<"0.63661977236 : f32">);
+  let sqrt2pi = Pow(twoOverPi, half1);
+  let tanArg = Mul(sqrt2pi, Add(x, scaledCube));
+  let inner = Add(Tanh(tanArg), one);
+  let root = Mul(Mul(x, half2), inner);
+
+  // Sanity check for cases where we could have broadcasted x.
+  TypesMatch(root, x);
+
+  rewrite root with {
+    let replacement = CreateGeluTanh(x);
+    replace root with replacement;
+  };
+}
+
+/// Raise a sequence of exact (erf-based) GELU ops, i.e. PyTorch's GELU with
+/// `approximate="none"`, to the kGELU_ERF activation.
+/// Matches the op sequence produced by PyTorch/torch-mlir.
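+/// exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))); the splat constant
+/// 7.070310e-01 ~= 1/sqrt(2).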
+Pattern RaiseToGeluErf {
+  let x: Value;
+  let half: Value;
+  SplatElements(half, attr<"0.5 : f32">);
+  let one: Value;
+  SplatElements(one, attr<"1.0 : f32">);
+  let const0: Value;
+  SplatElements(const0, attr<"7.070310e-01 : f32">);
+  let erf = Erf(Mul(x, const0));
+  let normalCdf = Mul(Add(erf, one), half);
+  let root = Mul(x, normalCdf);
+
+  rewrite root with {
+    let replacement = CreateGeluErf(x);
+    replace root with replacement;
+  };
+}
+
+/// Raise an elementwise min/max pair to the CLIP activation.
+/// Matches the op sequence produced by PyTorch/torch-mlir.
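+/// Note the naming: `minValue` is the operand of the Min op and therefore the
+/// upper clip bound, while `maxValue` (the Max operand) is the lower bound:
+/// min(minValue, max(x, maxValue)) == clip(x, maxValue, minValue).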
+Pattern RaiseMaxMinToClip {
+  let x: Value;
+  let minValue: Value;
+  let maxValue: Value;
+  let root = Min(minValue, Max(x, maxValue));
+
+  TypesMatch(root, x);
+
+  HasSplatElements(minValue);
+  HasSplatElements(maxValue);
+
+  rewrite root with {
+    let min: Attr = GetSplatElementAttr(minValue);
+    let max: Attr = GetSplatElementAttr(maxValue);
+    let replacement = CreateClipActivation(x, min, max);
+    replace root with replacement;
+  };
+}