diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseActivations.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseActivations.cpp
index 3e9714288..73422c1ae 100644
--- a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseActivations.cpp
+++ b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseActivations.cpp
@@ -23,6 +23,7 @@
 //===----------------------------------------------------------------------===//
 #include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
 #include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h"
+#include "mlir-tensorrt-dialect/TensorRT/Utils/Utils.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/IR/Matchers.h"
@@ -50,8 +51,12 @@ class RaiseActivations
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
 
-    if (targetTensorRTVersion >= TensorRTVersion(10, 0))
+    if (targetTensorRTVersion >= TensorRTVersion(10, 0)) {
       patterns.add<RaiseToGeluTanh>(ctx);
+      patterns.add<RaiseToGeluTanh2>(ctx);
+      patterns.add<RaiseToGeluErf>(ctx);
+    }
+    patterns.add<RaiseMaxMinToClip>(ctx);
 
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
       emitError(getOperation()->getLoc())
diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseActivations.pdll b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseActivations.pdll
index dac4c042c..32892f40e 100644
--- a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseActivations.pdll
+++ b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseActivations.pdll
@@ -47,6 +47,30 @@ Constraint TanhConstraintImpl(op: Op) [{
       ActivationType::kTANH
   );
 }];
+Constraint MaxConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<ElementWiseOp>(op).getElementwiseOperation() ==
+      ElementWiseOperation::kMAX
+  );
+}];
+Constraint MinConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<ElementWiseOp>(op).getElementwiseOperation() ==
+      ElementWiseOperation::kMIN
+  );
+}];
+Constraint PowConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<ElementWiseOp>(op).getElementwiseOperation() ==
+      ElementWiseOperation::kPOW
+  );
+}];
+Constraint ErfConstraintImpl(op: Op) [{
+  return mlir::success(
+    cast<UnaryOp>(op).getUnaryOperation() ==
+      UnaryOperation::kERF
+  );
+}];
 Constraint MulConstraint(op: Op) -> Op {
   MulConstraintImpl(op);
   return op;
 }
@@ -59,6 +83,22 @@ Constraint TanhConstraint(op: Op) -> Op {
   TanhConstraintImpl(op);
   return op;
 }
+Constraint MaxConstraint(op: Op) -> Op {
+  MaxConstraintImpl(op);
+  return op;
+}
+Constraint MinConstraint(op: Op) -> Op {
+  MinConstraintImpl(op);
+  return op;
+}
+Constraint PowConstraint(op: Op) -> Op {
+  PowConstraintImpl(op);
+  return op;
+}
+Constraint ErfConstraint(op: Op) -> Op {
+  ErfConstraintImpl(op);
+  return op;
+}
 Constraint Mul(lhs: Value, rhs: Value) -> Op {
   return MulConstraint(op<tensorrt.element_wise>(lhs, rhs));
 }
@@ -68,17 +108,36 @@ Constraint Add(lhs: Value, rhs: Value) -> Op {
   return AddConstraint(op<tensorrt.element_wise>(lhs, rhs));
 }
 Constraint Tanh(x: Value) -> Op {
   return TanhConstraint(op<tensorrt.activation>(x));
 }
+Constraint Max(lhs: Value, rhs: Value) -> Op {
+  return MaxConstraint(op<tensorrt.element_wise>(lhs, rhs));
+}
+Constraint Min(lhs: Value, rhs: Value) -> Op {
+  return MinConstraint(op<tensorrt.element_wise>(lhs, rhs));
+}
+Constraint Pow(lhs: Value, rhs: Value) -> Op {
+  return PowConstraint(op<tensorrt.element_wise>(lhs, rhs));
+}
+Constraint Erf(x: Value) -> Op {
+  return ErfConstraint(op<tensorrt.unary>(x));
+}
+
+Rewrite GetSplatElementAttr(x: Value) -> Attr [{
+  return *getSplatConstantElementAttribute(x);
+}];
+
+Constraint HasSplatElements(x: Value) [{
+  return LogicalResult(getSplatConstantElementAttribute(x));
+}];
 
 /// Is true if `x` is a constant op that has a splat constant
 /// where splat element is equal to `attr`.
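+/// Floating point splat values are compared with a small absolute tolerance so
+/// that constants materialized in a different float type (e.g. bf16 vs f32)
+/// still match the f32 value given in the pattern.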
-Constraint SplatElements(x: Op, attr: Attr) [{
-  DenseElementsAttr els{};
-  if(!matchPattern(x, m_Constant(&els)))
-    return failure();
-  if(!els.isSplat())
-    return failure();
-  Attribute value = els.getSplatValue<Attribute>();
-  return success(value == attr);
+Constraint SplatElements(x: Value, attr: Attr) [{
+  FailureOr<Attribute> value = getSplatConstantElementAttribute(x);
+  if(LogicalResult(value).failed()) return failure();
+  if(*value == attr) return success();
+  FloatAttr fvalue = dyn_cast<FloatAttr>(*value);
+  FloatAttr fattr = dyn_cast<FloatAttr>(attr);
+  return success(fvalue && fattr && std::abs(fvalue.getValueAsDouble() - fattr.getValueAsDouble()) < .001); // handle different floating point types
 }];
 /// We need a native C++ function since we can't create the right
@@ -88,6 +147,16 @@ Rewrite CreateGeluTanh(x: Value) -> Op [{
     x, ActivationType::kGELU_TANH, FloatAttr{}, FloatAttr{}
   );
 }];
+Rewrite CreateGeluErf(x: Value) -> Op [{
+  return rewriter.create<ActivationOp>(x.getLoc(),
+    x, ActivationType::kGELU_ERF, FloatAttr{}, FloatAttr{}
+  );
+}];
+
+Rewrite CreateClipActivation(x: Value, min: Attr, max: Attr) -> Op [{
+  return rewriter.create<ActivationOp>(x.getLoc(),
+    x, ActivationType::kCLIP, cast<FloatAttr>(max), cast<FloatAttr>(min));
+}];
 
 Constraint TypesMatch(x: Value, y: Value) [{
   return success(x.getType() == y.getType());
@@ -98,13 +167,13 @@
 /// `https://github.com/google/jax/blob/main/jax/_src/nn/functions.py#L424-L455`.
 Pattern RaiseToGeluTanh {
   let x: Value;
-  let const0 = op<tensorrt.constant>();
+  let const0: Value;
   SplatElements(const0, attr<"4.471500e-02 : f32">);
-  let rootPiOverTwo = op<tensorrt.constant>();
+  let rootPiOverTwo: Value;
   SplatElements(rootPiOverTwo, attr<"0.797884583 : f32">);
-  let one = op<tensorrt.constant>();
+  let one: Value;
   SplatElements(one, attr<"1.0 : f32">);
-  let half = op<tensorrt.constant>();
+  let half: Value;
   SplatElements(half, attr<"0.5 : f32">);
   let scaledCube = Mul(Mul(Mul(x, x), x), const0);
   let tanArg = Mul(Add(x, scaledCube), rootPiOverTwo);
@@ -119,3 +188,75 @@ Pattern RaiseToGeluTanh {
     replace root with replacement;
   };
 }
+
+/// Raise a sequence of "approximate" GELU to `tensorrt.ext.gelu_tanh`.
+/// Matches the op sequence produced by PyTorch/torch-mlir.
+Pattern RaiseToGeluTanh2 {
+  let x: Value;
+  let half2: Value;
+  SplatElements(half2, attr<"0.5 : f32">);
+  let one: Value;
+  SplatElements(one, attr<"1.0 : f32">);
+  let three: Value;
+  SplatElements(three, attr<"3.0 : f32">);
+  let const0: Value;
+  SplatElements(const0, attr<"4.471500e-02 : f32">);
+  let scaledCube = Mul(Pow(x, three), const0);
+  let half1: Value;
+  SplatElements(half1, attr<"0.5 : f32">);
+  let twoOverPi: Value;
+  SplatElements(twoOverPi, attr<"0.63661977236 : f32">);
+  let sqrt2pi = Pow(twoOverPi, half1);
+  let tanArg = Mul(sqrt2pi, Add(x, scaledCube));
+  let inner = Add(Tanh(tanArg), one);
+  let root = Mul(Mul(x, half2), inner);
+
+  // Sanity check for cases where we could have broadcasted x.
+  TypesMatch(root, x);
+
+  rewrite root with {
+    let replacement = CreateGeluTanh(x);
+    replace root with replacement;
+  };
+}
+
+/// Raise a sequence of GELU to `tensorrt.ext.gelu_none`.
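+/// i.e. GELU(x) = x * 0.5 * (1 + erf(x / sqrt(2))); the 7.070310e-01 constant
+/// below approximates 1/sqrt(2) and is matched via the tolerant SplatElements.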
+/// Matches the op sequence produced by PyTorch/torch-mlir.
+Pattern RaiseToGeluErf {
+  let x: Value;
+  let half: Value;
+  SplatElements(half, attr<"0.5 : f32">);
+  let one: Value;
+  SplatElements(one, attr<"1.0 : f32">);
+  let const0: Value;
+  SplatElements(const0, attr<"7.070310e-01 : f32">);
+  let erf = Erf(Mul(x, const0));
+  let normalCdf = Mul(Add(erf, one), half);
+  let root = Mul(x, normalCdf);
+
+  rewrite root with {
+    let replacement = CreateGeluErf(x);
+    replace root with replacement;
+  };
+}
+
+/// Raise an elementwise min/max to the CLIP activation.
+/// Matches the op sequence produced by PyTorch/torch-mlir.
+Pattern RaiseMaxMinToClip {
+  let x: Value;
+  let minValue: Value;
+  let maxValue: Value;
+  let root = Min(minValue, Max(x, maxValue));
+
+  TypesMatch(root, x);
+
+  HasSplatElements(minValue);
+  HasSplatElements(maxValue);
+
+  rewrite root with {
+    let min: Attr = GetSplatElementAttr(minValue);
+    let max: Attr = GetSplatElementAttr(maxValue);
+    let replacement = CreateClipActivation(x, min, max);
+    replace root with replacement;
+  };
+}
diff --git a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/raise-activations.mlir b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/raise-activations.mlir
index b032b027e..7c088fa57 100644
--- a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/raise-activations.mlir
+++ b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/raise-activations.mlir
@@ -25,3 +25,70 @@ func.func @raise_gelu(%arg0: tensor<12x128x4x12x1xf32>) -> (tensor<12x128x4x12x1xf32>) {
 
 // TRT8-LABEL: func.func @raise_gelu
 // TRT8-NOT: kGELU_TANH
+
+// -----
+
+// CHECK-LABEL: func.func @raise_gelu2
+// CHECK-SAME: (%[[arg0:.+]]: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
+// CHECK: %[[v0:.+]] = tensorrt.activation {activationType = #tensorrt.activation_type<kGELU_TANH>} %[[arg0]] : tensor<16x1024x1024xbf16>
+// CHECK: return %[[v0]] : tensor<16x1024x1024xbf16>
+
+// TRT8-LABEL: func.func @raise_gelu2
+// TRT8-NOT: kGELU_TANH
+
+func.func @raise_gelu2(%arg0: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
+  %cst_bf16 = tensorrt.constant dense<5.000000e-01> : tensor<1x1x1xbf16>
+  %cst_bf16_0 = tensorrt.constant dense<3.000000e+00> : tensor<1x1x1xbf16>
+  %cst_bf16_2 = tensorrt.constant dense<6.367190e-01> : tensor<1x1x1xbf16>
+  %cst_bf16_3 = tensorrt.constant dense<4.467770e-02> : tensor<1x1x1xbf16>
+  %cst_bf16_5 = tensorrt.constant dense<1.000000e+00> : tensor<1x1x1xbf16>
+  %0 = tensorrt.slice %cst_bf16[0, 0, 0][16, 1024, 1024][1, 1, 1] {mode = #tensorrt.slice_mode<kWRAP>} : tensor<1x1x1xbf16> to tensor<16x1024x1024xbf16>
+  %5 = tensorrt.element_wise <kPROD>(%arg0, %cst_bf16 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %6 = tensorrt.element_wise <kPOW>(%cst_bf16_2, %0 : tensor<1x1x1xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  %7 = tensorrt.element_wise <kPOW>(%arg0, %cst_bf16_0 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %8 = tensorrt.element_wise <kPROD>(%7, %cst_bf16_3 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %9 = tensorrt.element_wise <kSUM>(%arg0, %8 : tensor<16x1024x1024xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  %10 = tensorrt.element_wise <kPROD>(%6, %9 : tensor<16x1024x1024xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  %11 = tensorrt.activation {activationType = #tensorrt.activation_type<kTANH>} %10 : tensor<16x1024x1024xbf16>
+  %12 = tensorrt.element_wise <kSUM>(%11, %cst_bf16_5 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %13 = tensorrt.element_wise <kPROD>(%5, %12 : tensor<16x1024x1024xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  return %13 : tensor<16x1024x1024xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @raise_gelu_erf
+// CHECK-SAME: (%[[arg0:.+]]: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
+// CHECK: %[[v0:.+]] = tensorrt.activation {activationType = #tensorrt.activation_type<kGELU_ERF>} %[[arg0]] : tensor<16x1024x1024xbf16>
+// CHECK: return %[[v0]] : tensor<16x1024x1024xbf16>
+
+// TRT8-LABEL: func.func @raise_gelu_erf
+// TRT8-NOT: kGELU_ERF
+
+func.func @raise_gelu_erf(%arg0: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
+  %cst_bf16_1 = tensorrt.constant dense<1.000000e+00> : tensor<1x1x1xbf16>
+  %cst_bf16_2 = tensorrt.constant dense<5.000000e-01> : tensor<1x1x1xbf16>
+  %cst_bf16_3 = tensorrt.constant dense<7.070310e-01> : tensor<1x1x1xbf16>
+  %5 = tensorrt.element_wise <kPROD>(%arg0, %cst_bf16_3 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %6 = tensorrt.unary {unaryOperation = #tensorrt.unary_operation<kERF>} %5 : tensor<16x1024x1024xbf16>
+  %7 = tensorrt.element_wise <kSUM>(%6, %cst_bf16_1 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %8 = tensorrt.element_wise <kPROD>(%7, %cst_bf16_2 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %9 = tensorrt.element_wise <kPROD>(%arg0, %8 : tensor<16x1024x1024xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  return %9 : tensor<16x1024x1024xbf16>
+}
+
+// -----
+
+// CHECK: @raise_min_max(%[[arg0:.+]]: tensor<16x1024x1024xbf16>)
+// CHECK: #tensorrt.activation_type<kCLIP>
+func.func @raise_min_max(%arg0: tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16> {
+  %cst_f32 = tensorrt.constant dense<6.000000e+00> : tensor<f32>
+  %cst_f32_1 = tensorrt.constant dense<0.000000e+00> : tensor<f32>
+  %5 = tensorrt.cast %cst_f32_1 : tensor<f32> to tensor<bf16>
+  %6 = tensorrt.expand_rank %5 : tensor<bf16> to tensor<1x1x1xbf16>
+  %8 = tensorrt.cast %cst_f32 : tensor<f32> to tensor<bf16>
+  %9 = tensorrt.expand_rank %8 : tensor<bf16> to tensor<1x1x1xbf16>
+  %15 = tensorrt.element_wise <kMAX>(%arg0, %6 : tensor<16x1024x1024xbf16>, tensor<1x1x1xbf16>) -> tensor<16x1024x1024xbf16>
+  %16 = tensorrt.element_wise <kMIN>(%9, %15 : tensor<1x1x1xbf16>, tensor<16x1024x1024xbf16>) -> tensor<16x1024x1024xbf16>
+  return %16 : tensor<16x1024x1024xbf16>
+}