@@ -23,6 +23,7 @@
//===----------------------------------------------------------------------===//
#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
#include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h"
#include "mlir-tensorrt-dialect/TensorRT/Utils/Utils.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/Matchers.h"
@@ -50,8 +51,12 @@ class RaiseActivations
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);

if (targetTensorRTVersion >= TensorRTVersion(10, 0))
if (targetTensorRTVersion >= TensorRTVersion(10, 0)) {
patterns.add<RaiseToGeluTanh>(ctx);
patterns.add<RaiseToGeluTanh2>(ctx);
patterns.add<RaiseToGeluErf>(ctx);
}
patterns.add<RaiseMaxMinToClip>(ctx);

if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
emitError(getOperation()->getLoc())
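Note that the GELU-raising patterns are registered only when the target TensorRT version is at least 10.0, presumably because the kGELU_TANH/kGELU_ERF activation types are only available from TensorRT 10 onward (the TRT8 FileCheck prefixes in the tests below verify that no raising happens for older targets); the new RaiseMaxMinToClip pattern is registered unconditionally.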
165 changes: 153 additions & 12 deletions mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseActivations.pdll
@@ -47,6 +47,30 @@ Constraint TanhConstraintImpl(op: Op) [{
ActivationType::kTANH
);
}];
Constraint MaxConstraintImpl(op: Op) [{
return mlir::success(
cast<tensorrt::ElementWiseOp>(op).getElementwiseOperation() ==
ElementWiseOperation::kMAX
);
}];
Constraint MinConstraintImpl(op: Op) [{
return mlir::success(
cast<tensorrt::ElementWiseOp>(op).getElementwiseOperation() ==
ElementWiseOperation::kMIN
);
}];
Constraint PowConstraintImpl(op: Op) [{
return mlir::success(
cast<tensorrt::ElementWiseOp>(op).getElementwiseOperation() ==
ElementWiseOperation::kPOW
);
}];
Constraint ErfConstraintImpl(op: Op) [{
return mlir::success(
cast<tensorrt::UnaryOp>(op).getUnaryOperation() ==
UnaryOperation::kERF
);
}];
Constraint MulConstraint(op: Op<tensorrt.element_wise>) -> Op {
MulConstraintImpl(op);
return op;
@@ -59,6 +83,22 @@ Constraint TanhConstraint(op: Op<tensorrt.activation>) -> Op {
TanhConstraintImpl(op);
return op;
}
Constraint MaxConstraint(op: Op<tensorrt.element_wise>) -> Op {
MaxConstraintImpl(op);
return op;
}
Constraint MinConstraint(op: Op<tensorrt.element_wise>) -> Op {
MinConstraintImpl(op);
return op;
}
Constraint PowConstraint(op: Op<tensorrt.element_wise>) -> Op {
PowConstraintImpl(op);
return op;
}
Constraint ErfConstraint(op: Op<tensorrt.unary>) -> Op {
ErfConstraintImpl(op);
return op;
}
Constraint Mul(lhs: Value, rhs: Value) -> Op {
return MulConstraint(op<tensorrt.element_wise>(lhs, rhs));
}
@@ -68,17 +108,36 @@ Constraint Add(lhs: Value, rhs: Value) -> Op {
Constraint Tanh(x: Value) -> Op {
return TanhConstraint(op<tensorrt.activation>(x));
}
Constraint Max(lhs: Value, rhs: Value) -> Op {
return MaxConstraint(op<tensorrt.element_wise>(lhs, rhs));
}
Constraint Min(lhs: Value, rhs: Value) -> Op {
return MinConstraint(op<tensorrt.element_wise>(lhs, rhs));
}
Constraint Pow(lhs: Value, rhs: Value) -> Op {
return PowConstraint(op<tensorrt.element_wise>(lhs, rhs));
}
Constraint Erf(x: Value) -> Op {
return ErfConstraint(op<tensorrt.unary>(x));
}

Rewrite GetSplatElementAttr(x: Value) -> Attr [{
return *getSplatConstantElementAttribute(x);
}];

Constraint HasSplatElements(x: Value) [{
return LogicalResult(getSplatConstantElementAttribute(x));
}];

/// Succeeds if `x` is produced by a constant op holding a splat value
/// whose element is equal to `attr`.
Constraint SplatElements(x: Op, attr: Attr) [{
DenseElementsAttr els{};
if(!matchPattern(x, m_Constant(&els)))
return failure();
if(!els.isSplat())
return failure();
Attribute value = els.getSplatValue<Attribute>();
return success(value == attr);
Constraint SplatElements(x: Value, attr: Attr) [{
FailureOr<Attribute> value = getSplatConstantElementAttribute(x);
if(LogicalResult(value).failed()) return failure();
if(*value == attr) return success();
FloatAttr fvalue = dyn_cast<FloatAttr>(*value);
FloatAttr fattr = dyn_cast<FloatAttr>(attr);
// Use a small tolerance so constants stored in a different floating-point
// type (e.g. bf16) still match.
return success(fvalue && fattr &&
std::abs(fvalue.getValueAsDouble() - fattr.getValueAsDouble()) < .001);
}];
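The tolerance matters in practice: the patterns below write their literals in f32 (e.g. 4.471500e-02), while models lowered through bf16, such as the tests further down, carry bf16-rounded splats like 4.467770e-02 and 6.367190e-01. An exact attribute comparison would reject those, so any float splat within 1e-3 of the expected value is accepted.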

/// We need a native C++ function since we can't create the right
@@ -88,6 +147,16 @@ Rewrite CreateGeluTanh(x: Value) -> Op [{
x, ActivationType::kGELU_TANH, FloatAttr{}, FloatAttr{}
);
}];
Rewrite CreateGeluErf(x: Value) -> Op [{
return rewriter.create<tensorrt::ActivationOp>(x.getLoc(),
x, ActivationType::kGELU_ERF, FloatAttr{}, FloatAttr{}
);
}];

Rewrite CreateClipActivation(x: Value, min: Attr, max: Attr) -> Op [{
return rewriter.create<tensorrt::ActivationOp>(x.getLoc(),
x, ActivationType::kCLIP, cast<FloatAttr>(max), cast<FloatAttr>(min));
}];

Constraint TypesMatch(x: Value, y: Value) [{
return success(x.getType() == y.getType());
@@ -98,13 +167,13 @@ Constraint TypesMatch(x: Value, y: Value) [{
/// `https://github.com/google/jax/blob/main/jax/_src/nn/functions.py#L424-L455`.
Pattern RaiseToGeluTanh {
let x: Value;
let const0 = op<tensorrt.constant>();
let const0: Value;
SplatElements(const0, attr<"4.471500e-02 : f32">);
let rootPiOverTwo = op<tensorrt.constant>();
let rootPiOverTwo: Value;
SplatElements(rootPiOverTwo, attr<"0.797884583 : f32">);
let one = op<tensorrt.constant>();
let one: Value;
SplatElements(one, attr<"1.0 : f32">);
let half = op<tensorrt.constant>();
let half: Value;
SplatElements(half, attr<"0.5 : f32">);
let scaledCube = Mul(Mul(Mul(x, x), x), const0);
let tanArg = Mul(Add(x, scaledCube), rootPiOverTwo);
@@ -119,3 +188,75 @@ Pattern RaiseToGeluTanh {
replace root with replacement;
};
}
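For reference, the splat constants matched above encode the tanh approximation of GELU referenced in the JAX link:

gelu_tanh(x) = 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³)))

where √(2/π) ≈ 0.797884583 is the rootPiOverTwo splat and 4.471500e-02 is the 0.044715 coefficient.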

/// Raise a sequence of "approximate" GELU to `tensorrt.ext.gelu_tanh`.
/// Matches the op sequence produced by PyTorch/torch-mlir.
Pattern RaiseToGeluTanh2 {
let x: Value;
let half2: Value;
SplatElements(half2, attr<"0.5 : f32">);
let one: Value;
SplatElements(one, attr<"1.0 : f32">);
let three: Value;
SplatElements(three, attr<"3.0 : f32">);
let const0: Value;
SplatElements(const0, attr<"4.471500e-02 : f32">);
let scaledCube = Mul(Pow(x, three), const0);
let half1: Value;
SplatElements(half1, attr<"0.5 : f32">);
let twoOverPi: Value;
SplatElements(twoOverPi, attr<"0.63661977236 : f32">);
let sqrt2pi = Pow(twoOverPi, half1);
let tanArg = Mul(sqrt2pi, Add(x, scaledCube));
let inner = Add(Tanh(tanArg), one);
let root = Mul(Mul(x, half2), inner);

// Sanity check for cases where we could have broadcasted x.
TypesMatch(root, x);

rewrite root with {
let replacement = CreateGeluTanh(x);
replace root with replacement;
};
}
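This second pattern exists because torch-mlir emits the same approximation in a slightly different shape: √(2/π) shows up as pow(2/π, 0.5) with 2/π ≈ 0.63661977 (the twoOverPi splat), and the cube shows up as pow(x, 3) rather than x·x·x. Numerically it is identical to RaiseToGeluTanh, since pow(0.63661977, 0.5) ≈ 0.797885.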

/// Raise a sequence of GELU to `tensorrt.ext.gelu_none`.
/// Matches the op sequence produced by PyTorch/torch-mlir.
Pattern RaiseToGeluErf {
let x: Value;
let half: Value;
SplatElements(half, attr<"0.5 : f32">);
let one: Value;
SplatElements(one, attr<"1.0 : f32">);
let const0: Value;
SplatElements(const0, attr<"7.070310e-01 : f32">);
let erf = Erf(Mul(x, const0));
let normalCdf = Mul(Add(erf, one), half);
let root = Mul(x, normalCdf);

rewrite root with {
let replacement = CreateGeluErf(x);
replace root with replacement;
};
}
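Unlike the two patterns above, this one recognizes the exact, erf-based GELU:

gelu(x) = x · Φ(x) = 0.5 · x · (1 + erf(x / √2))

where const0 is 1/√2 ≈ 0.70711 (matched with tolerance against the bf16-rounded 7.070310e-01 in the test below).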

/// Raise an elementwise min/max pair to the CLIP activation.
/// Matches the op sequence produced by PyTorch/torch-mlir.
Pattern RaiseMaxMinToClip {
let x: Value;
let minValue: Value;
let maxValue: Value;
let root = Min(minValue, Max(x, maxValue));

TypesMatch(root, x);

HasSplatElements(minValue);
HasSplatElements(maxValue);

rewrite root with {
let min: Attr = GetSplatElementAttr(minValue);
let max: Attr = GetSplatElementAttr(maxValue);
let replacement = CreateClipActivation(x, min, max);
replace root with replacement;
};
}
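For reference, TensorRT's kCLIP activation computes min(max(x, alpha), beta). In the pattern the kMAX operand (maxValue, despite its name) is therefore the lower bound and the kMIN operand (minValue) the upper bound; CreateClipActivation forwards them so the builder receives the lower bound as alpha and the upper bound as beta. With bounds 0 and 6, as in the @raise_min_max test below, the raised activation is effectively ReLU6.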
@@ -25,3 +25,70 @@ func.func @raise_gelu(%arg0: tensor<12x128x4x12x1xf32>) -> (tensor<12x128x4x12x1

// TRT8-LABEL: func.func @raise_gelu
// TRT8-NOT: kGELU_TANH

// -----

// CHECK-LABEL: func.func @raise_gelu2
// CHECK-SAME: (%[[arg0:.+]]: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
// CHECK: %[[v0:.+]] = tensorrt.activation {activationType = #tensorrt.activation_type<kGELU_TANH>} %[[arg0]] : tensor<16x1024x1024xbf16>
// CHECK: return %[[v0]] : tensor<16x1024x1024xbf16>

// TRT8-LABEL: func.func @raise_gelu2
// TRT8-NOT: kGELU_TANH

func.func @raise_gelu2(%arg0: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
%cst_bf16 = tensorrt.constant dense<5.000000e-01> : tensor<1x1x1xbf16>
%cst_bf16_0 = tensorrt.constant dense<3.000000e+00> : tensor<1x1x1xbf16>
%cst_bf16_2 = tensorrt.constant dense<6.367190e-01> : tensor<1x1x1xbf16>
%cst_bf16_3 = tensorrt.constant dense<4.467770e-02> : tensor<1x1x1xbf16>
%cst_bf16_5 = tensorrt.constant dense<1.000000e+00> : tensor<1x1x1xbf16>
%0 = tensorrt.slice %cst_bf16[0, 0, 0][16, 1024, 1024][1, 1, 1] {mode = #tensorrt.slice_mode<kWRAP>} : tensor<1x1x1xbf16> to tensor<16x1024x1024xbf16>
%5 = tensorrt.element_wise <kPROD>(%arg0, %cst_bf16 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
%6 = tensorrt.element_wise <kPOW>(%cst_bf16_2, %0 : tensor<1x1x1xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
%7 = tensorrt.element_wise <kPOW>(%arg0, %cst_bf16_0 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
%8 = tensorrt.element_wise <kPROD>(%7, %cst_bf16_3 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
%9 = tensorrt.element_wise <kSUM>(%arg0, %8 : tensor<16x1024x1024xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
%10 = tensorrt.element_wise <kPROD>(%6, %9 : tensor<16x1024x1024xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
%11 = tensorrt.activation {activationType = #tensorrt.activation_type<kTANH>} %10 : tensor<16x1024x1024xbf16>
%12 = tensorrt.element_wise <kSUM>(%11, %cst_bf16_5 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
%13 = tensorrt.element_wise <kPROD>(%5, %12 : tensor<16x1024x1024xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
return %13 : tensor<16x1024x1024xbf16>
}

// -----

// CHECK-LABEL: func.func @raise_gelu_erf
// CHECK-SAME: (%[[arg0:.+]]: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
// CHECK: %[[v0:.+]] = tensorrt.activation {activationType = #tensorrt.activation_type<kGELU_ERF>} %[[arg0]] : tensor<16x1024x1024xbf16>
// CHECK: return %[[v0]] : tensor<16x1024x1024xbf16>

// TRT8-LABEL: func.func @raise_gelu_erf
// TRT8-NOT: kGELU_ERF

func.func @raise_gelu_erf(%arg0: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
%cst_bf16_1 = tensorrt.constant dense<1.000000e+00> : tensor<1x1x1xbf16>
%cst_bf16_2 = tensorrt.constant dense<5.000000e-01> : tensor<1x1x1xbf16>
%cst_bf16_3 = tensorrt.constant dense<7.070310e-01> : tensor<1x1x1xbf16>
%5 = tensorrt.element_wise <kPROD>(%arg0, %cst_bf16_3 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
%6 = tensorrt.unary {unaryOperation = #tensorrt.unary_operation<kERF>} %5 : tensor<16x1024x1024xbf16>
%7 = tensorrt.element_wise <kSUM>(%6, %cst_bf16_1 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
%8 = tensorrt.element_wise <kPROD>(%7, %cst_bf16_2 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
%9 = tensorrt.element_wise <kPROD>(%arg0, %8 : tensor<16x1024x1024xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
return %9 : tensor<16x1024x1024xbf16>
}

// -----

// CHECK: @raise_min_max(%[[arg0:.+]]: tensor<16x1024x1024xbf16>)
// CHECK: #tensorrt.activation_type<kCLIP>
func.func @raise_min_max(%arg0: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
%cst_f32 = tensorrt.constant dense<6.000000e+00> : tensor<f32>
%cst_f32_1 = tensorrt.constant dense<0.000000e+00> : tensor<f32>
%5 = tensorrt.cast %cst_f32_1 : tensor<f32> to tensor<bf16>
%6 = tensorrt.expand_rank %5 : tensor<bf16> to tensor<1x1x1xbf16>
%8 = tensorrt.cast %cst_f32 : tensor<f32> to tensor<bf16>
%9 = tensorrt.expand_rank %8 : tensor<bf16> to tensor<1x1x1xbf16>
%15 = tensorrt.element_wise <kMAX>(%arg0, %6 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
%16 = tensorrt.element_wise <kMIN>(%9, %15 : tensor<1x1x1xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
return %16 : tensor<16x1024x1024xbf16>
}