@@ -146,6 +146,10 @@ Constraint ReduceSumImpl(val: Value)[{
       (reduceOp.getInput().getType().getRank() - 1)));
 }];
 
+Constraint AvgImpl(op: Op) [{
+  return success(cast<tensorrt::ReduceOp>(op).getReduceOperation() ==
+                 tensorrt::ReduceOperation::kAVG);
+}];
+
 Constraint CheckRank4(val: Value)[{
   RankedTensorType rtt = cast<RankedTensorType>(val.getType());
   return success(rtt.getRank() == 4);
@@ -196,10 +200,10 @@ Constraint ReverseSqrt(val : Value) -> Value{
 }
 
 Constraint FlattenTailDims(val: Value) -> Value {
-    CheckRank4(val);
-    let reshapeRes = op<tensorrt.reshape>(val);
-    FlattenConstraintImpl(reshapeRes);
-    return reshapeRes;
+  CheckRank4(val);
+  let reshapeRes = op<tensorrt.reshape>(val);
+  FlattenConstraintImpl(reshapeRes);
+  return reshapeRes;
 }
 
 Constraint ReduceSum(val: Value) -> Value{
@@ -219,6 +223,83 @@ Constraint Mean(input: Value, numHW: Value){
   return Div(ExpandTailDims(ReduceSum(FlattenTailDims(input))), numHW);
 }
 
+Constraint ReduceAvg(input: Value, reduceAxes: Attr) {
+  let avgOp = op<tensorrt.reduce>(input) {keepDimensions = attr<"true">, reduceAxes = reduceAxes};
+  AvgImpl(avgOp);
+  return avgOp;
+}
+
+Rewrite GetSplatElementAttr(x: Value) -> Attr [{
+  // Walk through shape- and type-preserving views to the defining constant
+  // and return its splat element; HasSplatElements guarantees this succeeds.
+  while (true) {
+    if (auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
+      x = expandRank.getInput();
+    else if (auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
+      x = reshape.getInput();
+    else if (auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
+      x = broadcast.getInput();
+    else if (auto castOp = x.getDefiningOp<tensorrt::CastOp>())
+      x = castOp.getInput();
+    else if (auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
+      x = identity.getInput();
+    else if (auto slice = x.getDefiningOp<tensorrt::SliceOp>())
+      x = slice.getInput();
+    else if (x.getDefiningOp<tensorrt::ConstantOp>()) {
+      DenseElementsAttr els{};
+      if (!matchPattern(x, m_Constant(&els)) || !els.isSplat())
+        return {};
+      return els.getSplatValue<Attribute>();
+    } else {
+      return {};
+    }
+  }
+}];
+
+Constraint HasSplatElements(x: Value) [{
+  // Same view walk as GetSplatElementAttr, but only checks that the value
+  // bottoms out in a scalar float or integer splat constant.
+  while (true) {
+    if (auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
+      x = expandRank.getInput();
+    else if (auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
+      x = reshape.getInput();
+    else if (auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
+      x = broadcast.getInput();
+    else if (auto castOp = x.getDefiningOp<tensorrt::CastOp>())
+      x = castOp.getInput();
+    else if (auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
+      x = identity.getInput();
+    else if (auto slice = x.getDefiningOp<tensorrt::SliceOp>())
+      x = slice.getInput();
+    else if (x.getDefiningOp<tensorrt::ConstantOp>()) {
+      DenseElementsAttr els{};
+      if (!matchPattern(x, m_Constant(&els)) || !els.isSplat())
+        return failure();
+      Attribute value = els.getSplatValue<Attribute>();
+      return success(isa<FloatAttr, IntegerAttr>(value));
+    } else {
+      return failure();
+    }
+  }
+}];
+
+Constraint SameElementType(a: Value, b: Value) [{
+  return success(cast<RankedTensorType>(a.getType()).getElementType() ==
+                 cast<RankedTensorType>(b.getType()).getElementType());
+}];
+
+Rewrite CreateCast(x: Value, refValue: Value) -> Value [{
+  // Keep x's shape but adopt refValue's element type, then cast x to it.
+  Type retType =
+      RankedTensorType::Builder(cast<RankedTensorType>(x.getType()))
+          .setElementType(cast<RankedTensorType>(refValue.getType()).getElementType());
+  return rewriter.createOrFold<tensorrt::CastOp>(x.getLoc(), retType, x);
+}];
+
 Pattern RaiseInstanceNormalization_NCHW {
   let inputType : Type;
   let input : Value<inputType>;
@@ -240,3 +321,57 @@ Pattern RaiseInstanceNormalization_NCHW {
   CheckRank4(addOffset);
   replace addOffset with op<tensorrt.normalization>(input, scale, offset){axis = attr<"array<i64: 2,3>">};
 }
+
+Pattern RaisePytorchLayerNorm {
+  let x: Value;
+  let beta: Value;
+  let gamma: Value;
+  let axis: Attr;
+  let epsilon: Value;
+
+  let mean = ReduceAvg(x, axis);
+  let diffMean = Sub(x, mean);
+
+  let varianceDenominator: Value;
+  // PyTorch's decomposition computes the mean a second time for the
+  // variance, as ReduceSum / N rather than reusing the average reduce
+  // above, so both forms must be matched.
+  let varianceMean = Div(ReduceSum(x), varianceDenominator);
+  let varianceDiff = Sub(x, varianceMean);
+  let varianceDiffSquared = Mul(varianceDiff, varianceDiff);
+  let varianceNumerator = ReduceSum(varianceDiffSquared);
+  let variance = Div(varianceNumerator, varianceDenominator);
+  let varianceEps = Add(variance, epsilon);
+
+  let inverseSqrt = ReverseSqrt(varianceEps);
+  let normed = Mul(diffMean, inverseSqrt);
+  let prod = Mul(normed, gamma);
+  let root = Add(prod, beta);
+
+  HasSplatElements(epsilon);
+  HasSplatElements(varianceDenominator);
+
+  rewrite root with {
+    let epsilonAttr = GetSplatElementAttr(epsilon);
+    let replacement = op<tensorrt.normalization>(x, gamma, beta) {axis = axis, eps = epsilonAttr};
+    replace root with replacement;
+  };
+}
+
+Pattern RemoveLayerNormCast {
+  let x: Value;
+  let gamma: Value;
+  let beta: Value;
+  let axis: Attr;
+  let epsilonAttr: Attr;
+
+  let castInput = op<tensorrt.cast>(x);
+  let norm = op<tensorrt.normalization>(castInput, gamma, beta) {axis = axis, eps = epsilonAttr};
+  let root = op<tensorrt.cast>(norm);
+
+  SameElementType(x, root);
+
+  rewrite root with {
+    let newGamma = CreateCast(gamma, x);
+    let newBeta = CreateCast(beta, x);
+    let replacement = op<tensorrt.normalization>(x, newGamma, newBeta) {axis = axis, eps = epsilonAttr};
+    replace root with replacement;
+  };
+}