@@ -47,6 +47,30 @@ Constraint TanhConstraintImpl(op: Op) [{
     ActivationType::kTANH
   );
 }];
+Constraint MaxConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<tensorrt::ElementWiseOp>(op).getElementwiseOperation() ==
+    ElementWiseOperation::kMAX
+  );
+}];
+Constraint MinConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<tensorrt::ElementWiseOp>(op).getElementwiseOperation() ==
+    ElementWiseOperation::kMIN
+  );
+}];
+Constraint PowConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<tensorrt::ElementWiseOp>(op).getElementwiseOperation() ==
+    ElementWiseOperation::kPOW
+  );
+}];
+Constraint ErfConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<tensorrt::UnaryOp>(op).getUnaryOperation() ==
+    UnaryOperation::kERF
+  );
+}];
 Constraint MulConstraint(op: Op<tensorrt.element_wise>) -> Op {
   MulConstraintImpl(op);
   return op;
@@ -59,6 +83,22 @@ Constraint TanhConstraint(op: Op<tensorrt.activation>) -> Op {
   TanhConstraintImpl(op);
   return op;
 }
+Constraint MaxConstraint(op: Op<tensorrt.element_wise>) -> Op {
+  MaxConstraintImpl(op);
+  return op;
+}
+Constraint MinConstraint(op: Op<tensorrt.element_wise>) -> Op {
+  MinConstraintImpl(op);
+  return op;
+}
+Constraint PowConstraint(op: Op<tensorrt.element_wise>) -> Op {
+  PowConstraintImpl(op);
+  return op;
+}
+Constraint ErfConstraint(op: Op<tensorrt.unary>) -> Op {
+  ErfConstraintImpl(op);
+  return op;
+}
 Constraint Mul(lhs: Value, rhs: Value) -> Op {
   return MulConstraint(op<tensorrt.element_wise>(lhs, rhs));
 }
@@ -68,17 +108,36 @@ Constraint Add(lhs: Value, rhs: Value) -> Op {
 Constraint Tanh(x: Value) -> Op {
   return TanhConstraint(op<tensorrt.activation>(x));
 }
+Constraint Max(lhs: Value, rhs: Value) -> Op {
+  return MaxConstraint(op<tensorrt.element_wise>(lhs, rhs));
+}
+Constraint Min(lhs: Value, rhs: Value) -> Op {
+  return MinConstraint(op<tensorrt.element_wise>(lhs, rhs));
+}
+Constraint Pow(lhs: Value, rhs: Value) -> Op {
+  return PowConstraint(op<tensorrt.element_wise>(lhs, rhs));
+}
+Constraint Erf(x: Value) -> Op {
+  return ErfConstraint(op<tensorrt.unary>(x));
+}
+
+Rewrite GetSplatElementAttr(x: Value) -> Attr [{
+  return *getSplatConstantElementAttribute(x);
+}];
+
+Constraint HasSplatElements(x: Value) [{
+  return LogicalResult(getSplatConstantElementAttribute(x));
+}];
 
 /// Is true if `x` is a constant op that has a splat constant
 /// where splat element is equal to `attr`.
-Constraint SplatElements(x: Op, attr: Attr) [{
-  DenseElementsAttr els{};
-  if(!matchPattern(x, m_Constant(&els)))
-    return failure();
-  if(!els.isSplat())
-    return failure();
-  Attribute value = els.getSplatValue<Attribute>();
-  return success(value == attr);
+Constraint SplatElements(x: Value, attr: Attr) [{
+  FailureOr<Attribute> value = getSplatConstantElementAttribute(x);
+  if(LogicalResult(value).failed()) return failure();
+  if(*value == attr) return success();
+  // Fall back to an approximate comparison so splat constants stored with a
+  // different floating-point type (or slightly different rounding) still match.
+  FloatAttr fvalue = dyn_cast<FloatAttr>(*value);
+  FloatAttr fattr = dyn_cast<FloatAttr>(attr);
+  return success(fvalue && fattr &&
+                 std::abs(fvalue.getValueAsDouble() - fattr.getValueAsDouble()) < .001);
 }];
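
The new `GetSplatElementAttr`, `HasSplatElements`, and `SplatElements` constraints all rely on a native C++ helper, `getSplatConstantElementAttribute`, that is registered outside this diff. A minimal sketch of what such a helper could look like, assuming it wraps the same `matchPattern`/`m_Constant` logic that the old inline `SplatElements` body used (signature inferred from the uses above; header layout varies by MLIR version):

```cpp
// Hypothetical sketch only; the real helper lives in the C++ sources that
// register these PDLL modules.
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"

static mlir::FailureOr<mlir::Attribute>
getSplatConstantElementAttribute(mlir::Value v) {
  mlir::DenseElementsAttr els{};
  // Succeed only for values produced by a constant op holding a splat.
  if (!mlir::matchPattern(v, mlir::m_Constant(&els)) || !els.isSplat())
    return mlir::failure();
  return els.getSplatValue<mlir::Attribute>();
}
```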
 
 /// We need a native C++ function since we can't create the right
@@ -88,6 +147,16 @@ Rewrite CreateGeluTanh(x: Value) -> Op [{
     x, ActivationType::kGELU_TANH, FloatAttr{}, FloatAttr{}
   );
 }];
+Rewrite CreateGeluErf(x: Value) -> Op [{
+  return rewriter.create<tensorrt::ActivationOp>(x.getLoc(),
+    x, ActivationType::kGELU_ERF, FloatAttr{}, FloatAttr{}
+  );
+}];
+
+Rewrite CreateClipActivation(x: Value, min: Attr, max: Attr) -> Op [{
+  return rewriter.create<tensorrt::ActivationOp>(x.getLoc(),
+    x, ActivationType::kCLIP, cast<FloatAttr>(max), cast<FloatAttr>(min));
+}];
 
 Constraint TypesMatch(x: Value, y: Value) [{
   return success(x.getType() == y.getType());
@@ -98,13 +167,13 @@ Constraint TypesMatch(x: Value, y: Value) [{
 /// `https://github.com/google/jax/blob/main/jax/_src/nn/functions.py#L424-L455`.
 Pattern RaiseToGeluTanh {
   let x: Value;
-  let const0 = op<tensorrt.constant>() ;
+  let const0: Value;
   SplatElements(const0, attr<"4.471500e-02 : f32">);
-  let rootPiOverTwo = op<tensorrt.constant>() ;
+  let rootPiOverTwo: Value;
   SplatElements(rootPiOverTwo, attr<"0.797884583 : f32">);
-  let one = op<tensorrt.constant>() ;
+  let one: Value;
   SplatElements(one, attr<"1.0 : f32">);
-  let half = op<tensorrt.constant>() ;
+  let half: Value;
   SplatElements(half, attr<"0.5 : f32">);
   let scaledCube = Mul(Mul(Mul(x, x), x), const0);
   let tanArg = Mul(Add(x, scaledCube), rootPiOverTwo);
@@ -119,3 +188,75 @@ Pattern RaiseToGeluTanh {
     replace root with replacement;
   };
 }
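
For reference, `RaiseToGeluTanh` above and `RaiseToGeluTanh2` below both match framework lowerings of the standard tanh approximation of GELU, which is where the splat constants being tested come from:

```latex
\operatorname{gelu}(x) \approx \tfrac{1}{2}\, x \left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\, x^{3}\bigr)\right)\right),
\qquad \sqrt{2/\pi} \approx 0.797884583 .
```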
+
+/// Raise a sequence of "approximate" GELU to `tensorrt.ext.gelu_tanh`.
+/// Matches the op sequence emitted by PyTorch/torch-mlir.
+Pattern RaiseToGeluTanh2 {
+  let x: Value;
+  let half2: Value;
+  SplatElements(half2, attr<"0.5 : f32">);
+  let one: Value;
+  SplatElements(one, attr<"1.0 : f32">);
+  let three: Value;
+  SplatElements(three, attr<"3.0 : f32">);
+  let const0: Value;
+  SplatElements(const0, attr<"4.471500e-02 : f32">);
+  let scaledCube = Mul(Pow(x, three), const0);
+  let half1: Value;
+  SplatElements(half1, attr<"0.5 : f32">);
+  let twoOverPi: Value;
+  SplatElements(twoOverPi, attr<"0.63661977236 : f32">);
+  let sqrt2pi = Pow(twoOverPi, half1);
+  let tanArg = Mul(sqrt2pi, Add(x, scaledCube));
+  let inner = Add(Tanh(tanArg), one);
+  let root = Mul(Mul(x, half2), inner);
+
+  // Sanity check for cases where we could have broadcasted x.
+  TypesMatch(root, x);
+
+  rewrite root with {
+    let replacement = CreateGeluTanh(x);
+    replace root with replacement;
+  };
+}
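
The main difference from the JAX-style pattern above is that torch-mlir materializes the same constants through explicit power ops, which this pattern follows:

```latex
\sqrt{2/\pi} = \left(2/\pi\right)^{0.5} \approx 0.63661977^{\,0.5} \approx 0.7978846,
\qquad x^{3} = \operatorname{pow}(x, 3).
```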
+
+/// Raise a sequence of GELU to `tensorrt.ext.gelu_none`.
+/// Matches the op sequence emitted by PyTorch/torch-mlir.
+Pattern RaiseToGeluErf {
+  let x: Value;
+  let half: Value;
+  SplatElements(half, attr<"0.5 : f32">);
+  let one: Value;
+  SplatElements(one, attr<"1.0 : f32">);
+  let const0: Value;
+  SplatElements(const0, attr<"7.070310e-01 : f32">);
+  let erf = Erf(Mul(x, const0));
+  let normalCdf = Mul(Add(erf, one), half);
+  let root = Mul(x, normalCdf);
+
+  rewrite root with {
+    let replacement = CreateGeluErf(x);
+    replace root with replacement;
+  };
+}
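
This pattern targets the exact (erf-based) GELU, i.e. PyTorch's `approximate='none'` variant; the splat constant 7.070310e-01 differs from 1/√2 by less than the 1e-3 tolerance used in `SplatElements`:

```latex
\operatorname{gelu}(x) = x\,\Phi(x) = x \cdot \tfrac{1}{2}\left(1 + \operatorname{erf}\!\bigl(x/\sqrt{2}\bigr)\right),
\qquad 1/\sqrt{2} \approx 0.7071068 .
```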
+
+/// Raise an elementwise min/max sequence to the CLIP activation.
+/// Matches the op sequence emitted by PyTorch/torch-mlir.
+Pattern RaiseMaxMinToClip {
+  let x: Value;
+  let minValue: Value;
+  let maxValue: Value;
+  let root = Min(minValue, Max(x, maxValue));
+
+  TypesMatch(root, x);
+
+  HasSplatElements(minValue);
+  HasSplatElements(maxValue);
+
+  rewrite root with {
+    let min: Attr = GetSplatElementAttr(minValue);
+    let max: Attr = GetSplatElementAttr(maxValue);
+    let replacement = CreateClipActivation(x, min, max);
+    replace root with replacement;
+  };
+}
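
The raised CLIP activation expresses the usual clamping identity, so a ReLU6-style min/max chain, for instance, collapses into a single activation:

```latex
\min\bigl(u,\ \max(x, \ell)\bigr) = \operatorname{clip}(x, \ell, u),
\qquad \text{e.g.}\quad \min\bigl(6,\ \max(x, 0)\bigr) = \operatorname{clip}(x, 0, 6).
```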