diff --git a/.github/workflows/mlir-tensorrt-ci.yml b/.github/workflows/mlir-tensorrt-ci.yml
index f6854f952..54735a5f7 100644
--- a/.github/workflows/mlir-tensorrt-ci.yml
+++ b/.github/workflows/mlir-tensorrt-ci.yml
@@ -100,7 +100,7 @@ jobs:
           #!/bin/bash
           set -e
           python3 -m black --check --extend-exclude='.*\.pyi' mlir-tensorrt/compiler/
-          python3 -m black --check --extend-exclude='.*\.pyi' mlir-tensorrt/python/
+          python3 -m black --check --extend-exclude='.*\.pyi' mlir-tensorrt/integrations/python/
           git clang-format HEAD~1 --diff
           EOF
diff --git a/mlir-tensorrt/CMakeLists.txt b/mlir-tensorrt/CMakeLists.txt
index b00019933..2cb09c49c 100644
--- a/mlir-tensorrt/CMakeLists.txt
+++ b/mlir-tensorrt/CMakeLists.txt
@@ -249,4 +249,4 @@
 include_directories(${CMAKE_CURRENT_LIST_DIR}/tensorrt/include)
 include_directories(${CMAKE_CURRENT_BINARY_DIR}/tensorrt/include)
 add_subdirectory(compiler)
-add_subdirectory(python)
+add_subdirectory(integrations)
diff --git a/mlir-tensorrt/build_tools/docker/Dockerfile b/mlir-tensorrt/build_tools/docker/Dockerfile
index 04789ba6b..bb2996369 100644
--- a/mlir-tensorrt/build_tools/docker/Dockerfile
+++ b/mlir-tensorrt/build_tools/docker/Dockerfile
@@ -74,8 +74,8 @@ EOF
 ARG PYTHON_VERSION=3.10
 ENV PYENV_ROOT="/pyenv"
 ENV PATH="/pyenv/bin:/pyenv/shims:$PATH"
-COPY python/requirements-dev.txt /tmp/requirements-dev.txt
-COPY python/requirements.txt /tmp/requirements.txt
+COPY integrations/python/requirements-dev.txt /tmp/requirements-dev.txt
+COPY integrations/python/requirements.txt /tmp/requirements.txt
 RUN <<EOF
diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Analysis/BoundsAnalysis.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Analysis/BoundsAnalysis.h
-class BoundsArray {
-public:
-  BoundsArray(std::optional<SmallVector<ConstantIntRanges>> value = std::nullopt)
-      : value(std::move(value)) {}
-
-  bool isUninitialized() const { return !value.has_value(); }
-
-  bool operator==(const BoundsArray &rhs) const { return value == rhs.value; }
-
-  ArrayRef<ConstantIntRanges> getValue() const {
-    assert(!isUninitialized());
-    return *value;
-  }
-
-  /// Return the most conservative integer scalar bounds for an dynamic/unknown
-  /// dimension extent.
-  static ConstantIntRanges getMaxDimRange();
-
-  /// Create a BoundsValue from the min/max bounds of shape. Using this method
-  /// ensures that the `value` are created with the correct storage bitwidth
-  /// (an implementation detail of the analysis).
-  static BoundsArray fromShapeBounds(ArrayRef<int64_t> min,
-                                     ArrayRef<int64_t> max);
-
-  /// Create a `BoundsValue` using the given scalar values encoded as int64_t
-  /// values. However, when storing the bounds, use the given bitwidth.
-  /// TODO: remove this when we migrate away from using
-  /// `#tensorrt.shape_profile` for value bounds.
-  static BoundsArray fromIntegerValueBounds(unsigned bitwidth,
-                                            ArrayRef<int64_t> min,
-                                            ArrayRef<int64_t> max);
-  static BoundsArray fromIntegerValueBounds(ArrayRef<APInt> min,
-                                            ArrayRef<APInt> max);
-
-  /// For the given tensor-typed value, return the most conservative bounds for
-  /// the shape of `v`. For each unknown dimension of the shape of `v` the
-  /// `getMaxDimRange()` bound is used.
-  static BoundsArray getMaxRangeForShapeBounds(Value v);
-
-  /// For the given statically shaped integer tensor-typed value, return the
-  /// most conservative bounds for the value of `v`.
-  static BoundsArray getMaxRangeForValueBounds(Value v);
-
-  /// For the given DenseIntElementsAttr, return a corresponding BoudnsValue
-  /// representing constant bounds as indicated by the attribute.
-  static BoundsArray getFromConstantValue(DenseIntElementsAttr attr);
-
-  /// Join two BoundsValues by performing a pointwise union of the integer
-  /// scalar a ranges.
-  static BoundsArray join(const BoundsArray &lhs, const BoundsArray &rhs);
-
-  /// Meet two BoundsValues by performing a pointwise intersection of the
-  /// integer scalar a ranges.
-  static BoundsArray meet(const BoundsArray &lhs, const BoundsArray &rhs);
-
-  /// Print a human-readable representation of the bounds.
-  void print(raw_ostream &os) const;
-
-  /// Return the min/max bounds representation as two DenseElementsAttrs.
-  std::pair<DenseElementsAttr, DenseElementsAttr>
-  getAsElementsAttr(RankedTensorType type) const;
-
-  /// Returns DenseElementsAttr representation if the element ranges are all
-  /// constant (single-value) ranges, otherwise nullopt.
-  std::optional<DenseElementsAttr>
-  getConstantValues(RankedTensorType type) const;
-
-private:
-  std::optional<SmallVector<ConstantIntRanges>> value;
-};
-
-llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const BoundsArray &v);
+using BoundsArray = mlirtrt::compiler::BoundsArray;
 
 //===----------------------------------------------------------------------===//
 // Shape Bounds Analyses
diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/CMakeLists.txt b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/CMakeLists.txt
index d9fe0c8fb..14f6f9a3d 100644
--- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/CMakeLists.txt
+++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/CMakeLists.txt
@@ -17,4 +17,6 @@ add_public_tablegen_target(MLIRTensorRTPlanDialectAttributesIncGen)
 set(LLVM_TARGET_DEFINITIONS PlanInterfaces.td)
 mlir_tablegen(PlanAttrInterfaces.h.inc -gen-attr-interface-decls)
 mlir_tablegen(PlanAttrInterfaces.cpp.inc -gen-attr-interface-defs)
+mlir_tablegen(PlanOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(PlanOpInterfaces.cpp.inc -gen-op-interface-defs)
 add_public_tablegen_target(MLIRTensorRTPlanDialectAttrInterfacesIncGen)
diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/Plan.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/Plan.h
index 5f95faf3f..ff11571a2 100644
--- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/Plan.h
+++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/Plan.h
@@ -27,6 +27,7 @@
 #include "mlir-tensorrt-dialect/Interface/TensorKindOpInterface.h"
 #include "mlir-tensorrt/Compiler/Extension.h"
 #include "mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.h"
+#include "mlir-tensorrt/Interfaces/InferTensorValueRangeInterface.h"
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanOps.td b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanOps.td
index 2932ab6d0..2fab72fd3 100644
--- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanOps.td
+++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanOps.td
@@ -8,6 +8,7 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir-tensorrt/Interfaces/InferTensorValueRangeInterface.td"
 include "mlir/IR/OpAsmInterface.td"
 
 class Plan_NativeOpTrait<string name> : NativeOpTrait<name>;
@@ ... @@
 def Plan_WithValuesOp : Plan_Op<"with_values",
     [Pure, DeclareOpInterfaceMethods<TensorKindOpInterface>,
+     DeclareOpInterfaceMethods<InferTensorValueRangeInterface>,
      AllTypesMatch<["operand", "result"]>]> {
 
   let summary = "Ties a tensor value with index SSA values representing its element values";
diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Transforms/Passes.td b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Transforms/Passes.td
index cf22de6c3..e7538edea 100644
--- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Transforms/Passes.td
+++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Transforms/Passes.td
@@ -258,9 +258,9 @@ def ClusteringPass : Pass<"plan-clustering", "::mlir::ModuleOp"> {
     operations will be compiled.
 
     The kinds of clusters that can be formed and the specific rules for
-    clustering are defined by the clustering configuration specified 
+    clustering are defined by the clustering configuration specified
     by the module's `plan.cluster_kinds` attribute. This is an array of
-    attributes which all implement the 
+    attributes which all implement the
     [ClusterKindAttrInterface](../IR/PlanInterfaces.td).
   }];
 
@@ -585,5 +585,35 @@ def PlanOwnershipBasedBufferDeallocationPass : Pass<
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// PlanOutlineConstantFoldableSubgraphs
+//===----------------------------------------------------------------------===//
+
+def PlanOutlineConstantFoldableSubgraphsPass : Pass<
+    "plan-outline-constant-foldable-subgraphs",
+    "::mlir::ModuleOp"> {
+  let summary = "Analyze and outline constant foldable subgraphs";
+
+  let description = [{
+    This pass implements a forward dataflow analysis (named
+    `SparseConstantFoldabilityAnalysis`) that identifies constant-foldable
+    ops. Unlike the upstream `ConstantPropagationAnalysis`, this analysis is
+    deliberately simple and works only for pure ops: if all operands of an
+    operation are constant foldable, all of its results are marked as
+    constant foldable.
+    The constant-foldability information is then combined with clustering to
+    find constant foldable subgraphs. Each such subgraph is finally outlined
+    into a private function carrying the `plan.constant_foldable` attribute.
+  }];
+
+  let options = [
+    Option<"skipClustering", "skip-clustering",
+      "std::function<bool(mlir::Operation *)>", /*default=*/"nullptr",
+      "This option lets the user extend the default pass behavior and "
+      "exclude additional ops from clustering. If this callback returns "
+      "true for an op, the op is not clustered, and an op that is not "
+      "clustered is not outlined for constant folding. This is useful for "
+      "preventing clustering of ops that cannot be run end-to-end at "
+      "compile time in the user's workflow.">,
+  ];
+}
 
 #endif // MLIR_TENSORRT_DIALECT_PLAN_TRANSFORMS_PASSES_TD
diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StablehloExt/IR/StableHloExt.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StablehloExt/IR/StableHloExt.h
index fe61e8c56..3637d027c 100644
--- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StablehloExt/IR/StableHloExt.h
+++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StablehloExt/IR/StableHloExt.h
@@ -34,6 +34,10 @@ void registerTensorKindOpInterfaceExternalModels(DialectRegistry &registry);
 /// Register StableHlo op implementations for ReifyRankedShapedTypeOpInterface.
 void registerTypeInferenceExternalModels(DialectRegistry &registry);
 
+/// Register StableHlo op implementations for InferTensorValueRangeInterface.
+void registerInferTensorValueRangeInterfaceExternalModels(
+    DialectRegistry &registry);
+
 } // namespace mlir::stablehlo
 
 #endif // MLIR_TENSORRT_DIALECT_STABLEHLOEXT_IR_STABLEHLOEXT_H
diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/InitAllDialects.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/InitAllDialects.h
index 51850b4ba..17cc3125a 100644
--- a/mlir-tensorrt/compiler/include/mlir-tensorrt/InitAllDialects.h
+++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/InitAllDialects.h
@@ -189,6 +189,8 @@ inline void registerAllDialects(mlir::DialectRegistry &registry) {
   mlir::vector::registerValueBoundsOpInterfaceExternalModels(registry);
 
   IF_MLIR_TRT_ENABLE_HLO({
+    mlir::stablehlo::registerInferTensorValueRangeInterfaceExternalModels(
+        registry);
     mlir::stablehlo::registerTensorKindOpInterfaceExternalModels(registry);
     mlir::stablehlo::registerTypeInferenceExternalModels(registry);
   });
diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Interfaces/InferTensorValueRangeInterface.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Interfaces/InferTensorValueRangeInterface.h
new file mode 100644
index 000000000..8904296f3
--- /dev/null
+++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Interfaces/InferTensorValueRangeInterface.h
@@ -0,0 +1,154 @@
+//===- InferTensorValueRangeInterface.h -------------------------*- C++-*-===//
+//
+// SPDX-FileCopyrightText: Copyright 2025 NVIDIA CORPORATION & AFFILIATES.
+// All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+///
+/// Declarations for InferTensorValueRangeInterface.
+///
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_TENSORRT_INTERFACES_INFERTENSORVALUERANGEINTERFACE
+#define MLIR_TENSORRT_INTERFACES_INFERTENSORVALUERANGEINTERFACE
+
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include <optional>
+
+namespace mlirtrt::compiler {
+
+//===----------------------------------------------------------------------===//
+// BoundsArray
+//===----------------------------------------------------------------------===//
+
+/// A BoundsArray is simply an array of mlir::ConstantIntRanges used to
+/// represent either the bounds on the shape of a tensor-typed SSA value or the
+/// bounds of the element values of a statically shaped integer tensor-typed
+/// SSA value. When it is used to represent the bounds for the value of a
+/// tensor, we use a canonical packed generalized row-major layout mapping from
+/// tensor coordinates to storage index.
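+///
+/// For example, the element bounds of a `tensor<2x2xi32>` value whose elements
+/// are all known to lie in [0, 5] are stored as four ranges in row-major order
+/// (a usage sketch, illustrative only):
+///
+///   llvm::SmallVector<mlir::ConstantIntRanges> ranges(
+///       4, mlir::ConstantIntRanges::fromSigned(llvm::APInt(32, 0),
+///                                              llvm::APInt(32, 5)));
+///   BoundsArray bounds(ranges); // ranges for (0,0), (0,1), (1,0), (1,1)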
+class BoundsArray {
+public:
+  BoundsArray() : value(std::nullopt) {}
+
+  BoundsArray(llvm::ArrayRef<mlir::ConstantIntRanges> value)
+      : value(std::make_optional(llvm::to_vector(value))) {}
+
+  bool isUninitialized() const { return !value.has_value(); }
+
+  bool operator==(const BoundsArray &rhs) const { return value == rhs.value; }
+
+  llvm::ArrayRef<mlir::ConstantIntRanges> getValue() const {
+    assert(!isUninitialized());
+    return *value;
+  }
+
+  /// Return the most conservative integer scalar bounds for a dynamic/unknown
+  /// dimension extent.
+  static mlir::ConstantIntRanges getMaxDimRange();
+
+  /// Create a BoundsArray from the min/max bounds of a shape. Using this
+  /// method ensures that the `value` is created with the correct storage
+  /// bitwidth (an implementation detail of the analysis).
+  static BoundsArray fromShapeBounds(llvm::ArrayRef<int64_t> min,
+                                     llvm::ArrayRef<int64_t> max);
+
+  /// Create a `BoundsArray` using the given scalar values encoded as int64_t
+  /// values. However, when storing the bounds, use the given bitwidth.
+  /// TODO: remove this when we migrate away from using
+  /// `#tensorrt.shape_profile` for value bounds.
+  static BoundsArray fromIntegerValueBounds(unsigned bitwidth,
+                                            llvm::ArrayRef<int64_t> min,
+                                            llvm::ArrayRef<int64_t> max);
+  static BoundsArray fromIntegerValueBounds(llvm::ArrayRef<llvm::APInt> min,
+                                            llvm::ArrayRef<llvm::APInt> max);
+
+  /// For the given tensor-typed value, return the most conservative bounds for
+  /// the shape of `v`. For each unknown dimension of the shape of `v` the
+  /// `getMaxDimRange()` bound is used.
+  static BoundsArray getMaxRangeForShapeBounds(mlir::Value v);
+
+  /// For the given statically shaped integer tensor-typed value, return the
+  /// most conservative bounds for the value of `v`.
+  static BoundsArray getMaxRangeForValueBounds(mlir::Value v);
+
+  /// For the given DenseIntElementsAttr, return a corresponding BoundsArray
+  /// representing constant bounds as indicated by the attribute.
+  static BoundsArray getFromConstantValue(mlir::DenseIntElementsAttr attr);
+
+  /// Join two BoundsArrays by performing a pointwise union of the integer
+  /// scalar ranges.
+  static BoundsArray join(const BoundsArray &lhs, const BoundsArray &rhs);
+
+  /// Meet two BoundsArrays by performing a pointwise intersection of the
+  /// integer scalar ranges.
+  static BoundsArray meet(const BoundsArray &lhs, const BoundsArray &rhs);
+
+  /// Print a human-readable representation of the bounds.
+  void print(llvm::raw_ostream &os) const;
+
+  /// Return the min/max bounds representation as two DenseElementsAttrs.
+  std::pair<mlir::DenseElementsAttr, mlir::DenseElementsAttr>
+  getAsElementsAttr(mlir::RankedTensorType type) const;
+
+  /// Returns a DenseElementsAttr representation if the element ranges are all
+  /// constant (single-value) ranges, otherwise nullopt.
+  std::optional<mlir::DenseElementsAttr>
+  getConstantValues(mlir::RankedTensorType type) const;
+
+  /// The maximum allowed volume of a tensor whose element values we track.
+  /// This is used to avoid edge cases where tracking the bounds would require
+  /// an excessive amount of memory.
+  static constexpr int64_t kMaxVolumeThreshold = 32;
+
+  /// Whether the analysis should consider a type. To be considered, the type
+  /// must be a ranked tensor of static shape with signless-or-index integer
+  /// element type and a total volume <= kMaxVolumeThreshold.
+  static bool shouldAnalyzeValueBounds(mlir::Type type);
+
+  /// Whether the analysis should consider a value. To be considered, the
+  /// value must be a ranked tensor of static shape with signless-or-index
+  /// integer element type and a total volume <= kMaxVolumeThreshold.
+  static bool shouldAnalyzeValueBounds(mlir::Value value);
+
+private:
+  std::optional<llvm::SmallVector<mlir::ConstantIntRanges>> value;
+};
+
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const BoundsArray &v);
+
+/// Represents either a BoundsArray lattice value or an IntegerValueRange
+/// lattice value.
+struct IntOrTensorValueRange
+    : public llvm::PointerUnion<const BoundsArray *,
+                                const mlir::IntegerValueRange *> {
+  using PointerUnion::PointerUnion;
+};
+
+/// Similar to SetIntRangeFn, but operating on BoundsArray lattice values.
+/// This is the `setResultRanges` callback for the BoundsArray-based
+/// interface method.
+using SetTensorValueLatticeFn =
+    llvm::function_ref<void(mlir::Value, BoundsArray)>;
+
+class InferTensorValueRangeInterface;
+
+namespace detail {} // namespace detail
+
+} // namespace mlirtrt::compiler
+
+#include "mlir-tensorrt/Interfaces/InferTensorValueRangeInterface.h.inc"
+
+#endif // MLIR_TENSORRT_INTERFACES_INFERTENSORVALUERANGEINTERFACE
diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Interfaces/InferTensorValueRangeInterface.td b/mlir-tensorrt/compiler/include/mlir-tensorrt/Interfaces/InferTensorValueRangeInterface.td
new file mode 100644
index 000000000..b728cecf6
--- /dev/null
+++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Interfaces/InferTensorValueRangeInterface.td
@@ -0,0 +1,40 @@
+#ifndef MLIR_TENSORRT_INTERFACES_INFERTENSORVALUERANGEINTERFACE
+#define MLIR_TENSORRT_INTERFACES_INFERTENSORVALUERANGEINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// InferTensorValueRangeInterface
+//===----------------------------------------------------------------------===//
+
+def InferTensorValueRangeInterface : OpInterface<"InferTensorValueRangeInterface"> {
+  let description = [{
+    Allows operations to participate in range analysis for tensor values by
+    providing a method that allows them to specify lower and upper bounds on
+    their result(s) given lower and upper bounds on their input(s) if known.
+  }];
+  let cppNamespace = "::mlirtrt::compiler";
+
+  let methods = [
+    InterfaceMethod<[{
+      Infer the bounds on the results of this op given the lattice
+      representation of the bounds for its arguments. For each result value or
+      block argument (that isn't a branch argument, since the dataflow
+      analysis handles those cases), the method should call `setResultRanges`
+      with that `Value` as an argument. When implemented, `setResultRanges`
+      should be called on all result values of the operation.
+
+      This method allows for more precise implementations when operations
+      want to reason about inputs which may be undefined during the analysis.
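+
+      For example, an identity-like unary op could implement this method as in
+      the following sketch (op and variable names are illustrative only):
+
+        void MyUnaryOp::inferResultRangesFromOptional(
+            ArrayRef<IntOrTensorValueRange> argRanges,
+            SetTensorValueLatticeFn setResultRanges) {
+          const auto *operandBounds =
+              argRanges[0].dyn_cast<const BoundsArray *>();
+          if (!operandBounds || operandBounds->isUninitialized()) {
+            // Report "no information" for the result.
+            setResultRanges(getResult(), BoundsArray());
+            return;
+          }
+          // Forward the operand's element bounds to the result.
+          setResultRanges(getResult(), *operandBounds);
+        }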
+ }], + /*retTy=*/"void", + /*methodName=*/"inferResultRangesFromOptional", + /*args=*/(ins "::llvm::ArrayRef<::mlirtrt::compiler::IntOrTensorValueRange>":$argRanges, + "::mlirtrt::compiler::SetTensorValueLatticeFn":$setResultRanges), + /*methodBody=*/"", + /*defaultImplementation=*/[{}]>, + ]; +} + + +#endif // MLIR_TENSORRT_INTERFACES_INFERTENSORVALUERANGEINTERFACE diff --git a/mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/Passes.cpp b/mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/Passes.cpp index 45353ef7d..2d1462bf7 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/Passes.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/TensorRTToExecutable/Passes.cpp @@ -216,6 +216,44 @@ static SmallVector makeRegionIsolatedFromAboveImpl( static FailureOr outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule, const Cluster &cluster) { + auto parentFunc = cluster.getRoot()->getParentOfType(); + + auto reorderYieldValues = [&](SetVector &yieldValues, + SmallVectorImpl &yieldTypes) { + auto term = dyn_cast( + parentFunc.getFunctionBody().front().getTerminator()); + if (!term) + return; + if (term->getNumOperands() != yieldValues.size()) + return; + + DenseMap termValueOrder; + DenseMap termTypeOrder; + for (const auto &it : llvm::enumerate(term.getOperands())) { + termValueOrder[it.value()] = it.index(); + termTypeOrder[it.value().getType()] = it.index(); + } + + // Make sure each yielded value is terminator operand. + if (llvm::any_of(yieldValues, + [&](Value v) { return !termValueOrder.contains(v); })) + return; + + SmallVector valuesVector(yieldValues.begin(), yieldValues.end()); + + // Sort both values and type. + llvm::stable_sort(valuesVector, [&](Value lhs, Value rhs) { + return termValueOrder[lhs] < termValueOrder[rhs]; + }); + llvm::stable_sort(yieldTypes, [&](Type lhs, Type rhs) { + return termTypeOrder[lhs] < termTypeOrder[rhs]; + }); + + // Write sorted values back to set vector + yieldValues.clear(); + yieldValues.insert(valuesVector.begin(), valuesVector.end()); + }; + auto inlineGroupOp = cast(mlir::createRegionOpFromCluster( cluster, rewriter, @@ -224,10 +262,10 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule, b.setInsertionPointToStart(®ionOp.getRegion().emplaceBlock()); b.create(loc); return regionOp; - })); + }, + reorderYieldValues)); // Make the region isolated from above. This captures the input operands. 
-  auto parentFunc = cluster.getRoot()->getParentOfType<func::FuncOp>();
   SmallVector<Value> inputs = makeRegionIsolatedFromAboveImpl(
       rewriter, inlineGroupOp.getRegion(), parentFunc, {});
 
@@ -240,8 +278,6 @@ outlineOp(RewriterBase &rewriter, tensorrt::TensorRTModuleOp trtModule,
   StringRef tensorrtShapeBoundsAttrName =
       mlir::tensorrt::TensorRTDialect::getShapeProfileArgAttrName();
 
-  func::FuncOp funcContainingCluster =
-      cluster.back()->getParentOfType<func::FuncOp>();
   SmallVector<tensorrt::ShapeProfileAttr> profileAttrsPerInput;
   for (Value v : inputs) {
     auto rtt = dyn_cast<RankedTensorType>(v.getType());
     }
 
     auto blockArg = dyn_cast<BlockArgument>(v);
-    if (!blockArg ||
-        blockArg.getOwner()->getParentOp() != funcContainingCluster) {
+    if (!blockArg || blockArg.getOwner()->getParentOp() != parentFunc) {
       return emitError(blockArg.getLoc())
              << "Block argument is not part of the signature of the function "
                 "containing this TRT cluster";
     }
 
     int64_t argIndex = blockArg.getArgNumber();
     profileAttrsPerInput.push_back(
-        funcContainingCluster.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
+        parentFunc.getArgAttrOfType<tensorrt::ShapeProfileAttr>(
             argIndex, tensorrtShapeBoundsAttrName));
 
     if (!profileAttrsPerInput.back()) {
diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Analysis/BoundsAnalysis.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Analysis/BoundsAnalysis.cpp
index 3b348dc4e..98280450f 100644
--- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Analysis/BoundsAnalysis.cpp
+++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Analysis/BoundsAnalysis.cpp
@@ -25,11 +25,13 @@
 #include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
 #include "mlir-tensorrt/Dialect/Plan/IR/Plan.h"
 #include "mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.h"
+#include "mlir-tensorrt/Interfaces/InferTensorValueRangeInterface.h"
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
@@ -39,6 +41,7 @@
 using namespace mlir;
 using namespace mlir::dataflow;
 using namespace mlir::plan;
+using namespace mlirtrt::compiler;
 
 #define DEBUG_TYPE "plan-bounds-analysis"
 #define DBGS(x) llvm::dbgs() << " [" DEBUG_TYPE "][" x "] "
@@ -70,154 +73,6 @@ static bool hasShapeFuncMarker(Value value, StringRef attrName) {
   return false;
 }
 
-//===----------------------------------------------------------------------===//
-// BoundsValue
-//===----------------------------------------------------------------------===//
-
-ConstantIntRanges BoundsArray::getMaxDimRange() {
-  APInt smin = APInt(IndexType::kInternalStorageBitWidth, 0);
-  APInt smax = APInt(IndexType::kInternalStorageBitWidth,
-                     std::numeric_limits<int64_t>::max());
-  return ConstantIntRanges::fromSigned(smin, smax);
-}
-
-BoundsArray BoundsArray::getMaxRangeForShapeBounds(Value v) {
-  auto type = cast<RankedTensorType>(v.getType());
-  SmallVector<ConstantIntRanges> ranges;
-  ranges.reserve(type.getRank());
-  for (int64_t dim : type.getShape()) {
-    if (ShapedType::isDynamic(dim)) {
-      ranges.push_back(getMaxDimRange());
-      continue;
-    }
-    ranges.push_back(ConstantIntRanges::constant(APInt(64, dim)));
-  }
-  return BoundsArray(std::move(ranges));
-}
-
-BoundsArray BoundsArray::getMaxRangeForValueBounds(Value v) {
unsuitable for analysis"); - Type elementType = mlir::getElementTypeOrSelf(v); - unsigned numBits = ConstantIntRanges::getStorageBitwidth(elementType); - APInt smin = APInt::getSignedMinValue(numBits); - APInt smax = APInt::getSignedMaxValue(numBits); - SmallVector ranges( - cast(v.getType()).getNumElements(), - ConstantIntRanges::fromSigned(smin, smax)); - return BoundsArray(std::move(ranges)); -} - -BoundsArray BoundsArray::getFromConstantValue(DenseIntElementsAttr v) { - assert(TensorValueBoundsAnalysis::shouldAnalyzeValueBounds(v.getType()) && - "attribute type is unsuitable for creating value bound state"); - SmallVector ranges; - ranges.reserve(cast(v.getType()).getNumElements()); - for (const APInt &element : v.getValues()) - ranges.push_back(ConstantIntRanges::constant(element)); - return BoundsArray(std::move(ranges)); -} - -BoundsArray BoundsArray::fromShapeBounds(ArrayRef min, - ArrayRef max) { - SmallVector res; - for (auto [l, r] : llvm::zip_equal(min, max)) - res.push_back(ConstantIntRanges::fromSigned(APInt(64, l), APInt(64, r))); - return BoundsArray(std::move(res)); -} - -BoundsArray BoundsArray::fromIntegerValueBounds(unsigned bitWidth, - ArrayRef min, - ArrayRef max) { - SmallVector res; - for (auto [l, r] : llvm::zip_equal(min, max)) - res.push_back( - ConstantIntRanges::fromSigned(APInt(64, l).sextOrTrunc(bitWidth), - APInt(64, r).sextOrTrunc(bitWidth))); - return BoundsArray(std::move(res)); -} - -BoundsArray BoundsArray::fromIntegerValueBounds(ArrayRef min, - ArrayRef max) { - SmallVector res; - for (auto [l, r] : llvm::zip_equal(min, max)) - res.push_back(ConstantIntRanges::fromSigned(l, r)); - return BoundsArray(std::move(res)); -} - -BoundsArray BoundsArray::join(const BoundsArray &lhs, const BoundsArray &rhs) { - if (lhs.isUninitialized()) - return rhs; - if (rhs.isUninitialized()) - return lhs; - SmallVector res; - for (auto [l, r] : llvm::zip_equal(lhs.getValue(), rhs.getValue())) - res.push_back(l.rangeUnion(r)); - return BoundsArray(std::move(res)); -} - -BoundsArray BoundsArray::meet(const BoundsArray &lhs, const BoundsArray &rhs) { - if (lhs.isUninitialized()) - return rhs; - if (rhs.isUninitialized()) - return lhs; - SmallVector res; - for (auto [l, r] : llvm::zip_equal(lhs.getValue(), rhs.getValue())) - res.push_back(l.intersection(r)); - return BoundsArray(std::move(res)); -} - -void BoundsArray::print(raw_ostream &os) const { - if (!value) { - os << "<>"; - return; - } - os << "<"; - llvm::interleaveComma(*value, os, [&](const ConstantIntRanges &r) { - os << "[" << r.smin() << ", " << r.smax() << "]"; - }); - os << ">"; -} - -llvm::raw_ostream &plan::operator<<(llvm::raw_ostream &os, - const BoundsArray &v) { - v.print(os); - return os; -} - -std::pair -BoundsArray::getAsElementsAttr(RankedTensorType type) const { - assert(!isUninitialized() && "expected initialized value"); - assert(type.getNumElements() == static_cast(value->size()) && - "specified tensor type's volume does not match lattice value volume"); - SmallVector lbs; - lbs.reserve(type.getNumElements()); - SmallVector ubs; - ubs.reserve(type.getNumElements()); - for (const ConstantIntRanges &r : *value) { - lbs.push_back(r.smin()); - ubs.push_back(r.smax()); - } - return std::make_pair(DenseElementsAttr::get(type, lbs), - DenseElementsAttr::get(type, ubs)); -} - -/// Returns true if the element ranges are constant (single-value) ranges. 
-std::optional<DenseElementsAttr>
-BoundsArray::getConstantValues(RankedTensorType type) const {
-  assert(!isUninitialized() && "expected initialized value");
-  assert(type.getNumElements() == static_cast<int64_t>(value->size()) &&
-         "specified tensor type's volume does not match lattice value volume");
-  SmallVector<APInt> lbs;
-  lbs.reserve(type.getNumElements());
-  for (const ConstantIntRanges &r : *value) {
-    if (r.smin() != r.smax())
-      return {};
-    lbs.push_back(r.smin());
-  }
-  return DenseElementsAttr::get(type, lbs);
-}
-
 //===----------------------------------------------------------------------===//
 // Utilities
 //===----------------------------------------------------------------------===//
@@ -709,13 +564,11 @@ LogicalResult ShapeIntegerRangeAnalysis::visitOperation(
 //===----------------------------------------------------------------------===//
 
 bool TensorValueBoundsAnalysis::shouldAnalyzeValueBounds(Type type) {
-  if (auto rtt = dyn_cast<RankedTensorType>(type))
-    return rtt.getElementType().isSignlessIntOrIndex() &&
-           rtt.hasStaticShape() && rtt.getNumElements() <= kMaxVolumeThreshold;
-  return false;
+  return BoundsArray::shouldAnalyzeValueBounds(type);
 }
+
 bool TensorValueBoundsAnalysis::shouldAnalyzeValueBounds(Value value) {
-  return shouldAnalyzeValueBounds(value.getType());
+  return BoundsArray::shouldAnalyzeValueBounds(value);
 }
 
 void TensorValueBoundsAnalysis::setToEntryState(
@@ -739,55 +592,46 @@ void TensorValueBoundsAnalysis::setToEntryState(
           shapeProfile->getMaxValues().getValues<APInt>()))));
 }
 
-static void maybePopulateConstantValueBounds(
-    Value point,
-    llvm::function_ref<void(Value, ArrayRef<ConstantIntRanges>)> joinCallback) {
-  if (!TensorValueBoundsAnalysis::shouldAnalyzeValueBounds(point))
-    return;
-  DenseIntElementsAttr attr;
-  if (!matchPattern(point, m_Constant(&attr)))
-    return;
-  BoundsArray val = BoundsArray::getFromConstantValue(attr);
-  joinCallback(point, val.getValue());
-}
+LogicalResult TensorValueBoundsAnalysis::visitOperation(
+    Operation *op, ArrayRef<const TensorValueBoundsLattice *> operands,
+    ArrayRef<TensorValueBoundsLattice *> results) {
 
-static FailureOr<SmallVector<ConstantIntRanges>>
-intersectTensorValueBoundsAndScalarBounds(
-    ArrayRef<const IntegerValueRangeLattice *> scalarBounds,
-    const TensorValueBoundsLattice *tensorBounds) {
-  SmallVector<ConstantIntRanges> ranges;
-  ranges.reserve(scalarBounds.size());
-  for (unsigned i = 0, e = scalarBounds.size(); i < e; i++) {
-    const IntegerValueRangeLattice *scalar = scalarBounds[i];
-    bool scalarIsInvalid = !scalar || scalar->getValue().isUninitialized();
-    if (tensorBounds && !tensorBounds->getValue().isUninitialized()) {
-      if (!scalarIsInvalid) {
-        ranges.push_back(scalar->getValue().getValue().intersection(
-            tensorBounds->getValue().getValue()[i]));
-        continue;
-      }
-      ranges.push_back(tensorBounds->getValue().getValue()[i]);
+  if (!isa<InferTensorValueRangeInterface>(op) &&
+      !op->hasTrait<OpTrait::ConstantLike>()) {
+    setAllToEntryStates(results);
+    return success();
+  }
+
+  LLVM_DEBUG(DBGS("TensorValueBoundsAnalysis") << "visiting " << *op << "\n");
+
+  SmallVector<IntOrTensorValueRange> argRanges;
+  for (auto [idx, operand] : llvm::enumerate(op->getOperands())) {
+    if (isa<RankedTensorType>(operand.getType()) &&
+        shouldAnalyzeValueBounds(operand)) {
+      if (operands[idx])
+        argRanges.emplace_back(&operands[idx]->getValue());
+      else
+        argRanges.emplace_back(nullptr);
       continue;
     }
-    if (!scalarIsInvalid) {
-      ranges.push_back(scalarBounds[i]->getValue().getValue());
+    if (isa<IntegerType, IndexType>(operand.getType())) {
+      const auto *scalarLattice =
+          this->getOrCreateFor<IntegerValueRangeLattice>(
+              getProgramPointAfter(op), operand);
+      if (scalarLattice) {
+        argRanges.emplace_back(&scalarLattice->getValue());
+      } else {
+        argRanges.emplace_back(nullptr);
+      }
       continue;
     }
-    return failure();
+    setAllToEntryStates(results);
+    return success();
   }
-  return ranges;
-}
-
-LogicalResult TensorValueBoundsAnalysis::visitOperation(
-    Operation *op, ArrayRef<const TensorValueBoundsLattice *> operands,
-    ArrayRef<TensorValueBoundsLattice *> results) {
-  LLVM_DEBUG(DBGS("TensorValueBoundsAnalysis") << "visiting " << *op << "\n");
-
-  auto joinCallback = [&](Value v, ArrayRef<ConstantIntRanges> attrs) {
-    assert(shouldAnalyzeValueBounds(v) && "value is unsuitable for analysis");
+  auto joinCallback = [&](Value v, BoundsArray newRange) {
     auto result = dyn_cast<OpResult>(v);
     if (!result)
       return;
@@ -795,7 +639,6 @@ LogicalResult TensorValueBoundsAnalysis::visitOperation(
 
     TensorValueBoundsLattice *lattice = results[result.getResultNumber()];
     const BoundsArray &oldRanges = lattice->getValue();
-    BoundsArray newRange{llvm::to_vector(attrs)};
 
     LLVM_DEBUG(DBGS("TensorValueBoundsAnalysis")
                << "inferred " << newRange << " for\n\t" << v << "\n");
@@ -818,33 +661,30 @@ LogicalResult TensorValueBoundsAnalysis::visitOperation(
     propagateIfChanged(lattice, changed);
   };
 
-  // If the value is produced by constant op, populate ranges appropriately.
-  // NOTE: we should instead use the mechanism from ConstantIntRanges lattice?
-  for (TensorValueBoundsLattice *lattice : results) {
-    Value point = lattice->getAnchor();
-    if (!shouldAnalyzeValueBounds(point))
-      continue;
-    maybePopulateConstantValueBounds(point, joinCallback);
+  if (op->hasTrait<OpTrait::ConstantLike>() && op->getNumResults() == 1) {
+    // If the value is produced by a constant op, populate ranges accordingly.
+    // NOTE: we should instead use the mechanism from ConstantIntRanges lattice?
+    Value point = results[0]->getAnchor();
+    if (!shouldAnalyzeValueBounds(point)) {
+      setAllToEntryStates(results);
+      return success();
+    }
+    DenseIntElementsAttr attr;
+    if (!matchPattern(point, m_Constant(&attr))) {
+      setAllToEntryStates(results);
+      return success();
+    }
+    joinCallback(point, BoundsArray::getFromConstantValue(attr));
+    return success();
   }
 
-  auto getScalarLatticeValues = [&](ValueRange scalars) {
-    SmallVector<const IntegerValueRangeLattice *> result;
-    result.reserve(scalars.size());
-    for (Value v : scalars)
-      result.emplace_back(this->getOrCreateFor<IntegerValueRangeLattice>(
-          getProgramPointAfter(op), v));
-    return result;
-  };
-
-  if (auto withOp = dyn_cast<WithValuesOp>(op)) {
-    FailureOr<SmallVector<ConstantIntRanges>> ranges =
-        intersectTensorValueBoundsAndScalarBounds(
-            getScalarLatticeValues(withOp.getElements()), operands[0]);
-    if (failed(ranges))
-      return success();
-    joinCallback(withOp.getResult(), *ranges);
+  auto inferrable = dyn_cast<InferTensorValueRangeInterface>(op);
+  if (!inferrable) {
+    setAllToEntryStates(results);
     return success();
   }
+
+  inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
   return success();
 }
diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Analysis/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Dialect/Plan/Analysis/CMakeLists.txt
index 991414861..4d318bbf7 100644
--- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Analysis/CMakeLists.txt
+++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Analysis/CMakeLists.txt
@@ -2,8 +2,9 @@ add_mlir_tensorrt_library(MLIRTensorRTPlanAnalysis
   BoundsAnalysis.cpp
 
   LINK_LIBS PUBLIC
-  MLIRTensorRTPlanDialect
-  MLIRTensorRTDialect
   MLIRAnalysis
+  MLIRTensorRTDialect
+  MLIRTensorRTInferTensorValueRangeInterface
+  MLIRTensorRTPlanDialect
   MLIRValueBoundsOpInterface
 )
diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/CMakeLists.txt
index eeef55830..7a53fe5c4 100644
--- a/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/CMakeLists.txt
+++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_tensorrt_dialect_library(MLIRTensorRTPlanDialect
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRSupport
+
+  MLIRTensorRTInferTensorValueRangeInterface
   MLIRTensorRTInterfaces
   MLIRTensorRTSupportStatus
   MLIRTensorRTUtilsShapeInfo
diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/PlanOps.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/PlanOps.cpp
index e051c202b..bf8a130c8 100644
--- a/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/PlanOps.cpp
+++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/IR/PlanOps.cpp
@@ -23,8 +23,10 @@
 //===----------------------------------------------------------------------===//
 #include "mlir-tensorrt-dialect/Interface/TensorKindOpInterface.h"
 #include "mlir-tensorrt/Dialect/Plan/IR/Plan.h"
+#include "mlir-tensorrt/Interfaces/InferTensorValueRangeInterface.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -42,6 +44,7 @@
 using namespace mlir;
 using namespace mlir::plan;
+using namespace mlirtrt::compiler;
 
 //===----------------------------------------------------------------------===//
 // MemorySpaceAttr
@@ -663,6 +666,48 @@ void WithShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // WithValuesOp
 //===----------------------------------------------------------------------===//
 
+void WithValuesOp::inferResultRangesFromOptional(
+    ArrayRef<IntOrTensorValueRange> argBounds,
+    SetTensorValueLatticeFn setResultRanges) {
+  if (!BoundsArray::shouldAnalyzeValueBounds(getResult())) {
+    setResultRanges(getResult(), BoundsArray());
+    return;
+  }
+
+  const auto *tensorBounds = argBounds.front().dyn_cast<const BoundsArray *>();
+
+  SmallVector<ConstantIntRanges> ranges;
+  ArrayRef<IntOrTensorValueRange> scalarBounds = argBounds.drop_front();
+  ranges.reserve(scalarBounds.size());
+  for (unsigned i = 0, e = scalarBounds.size(); i < e; i++) {
+    const auto *scalarBound =
+        scalarBounds[i].dyn_cast<const mlir::IntegerValueRange *>();
+    bool scalarIsInvalid = !scalarBound || scalarBound->isUninitialized();
+    if (tensorBounds && !tensorBounds->isUninitialized()) {
+      assert(
+          tensorBounds->getValue().size() == scalarBounds.size() &&
+          "expected number of tensor bounds to equal number of scalar bounds");
+      if (!scalarIsInvalid) {
+        ranges.push_back(
+            scalarBound->getValue().rangeUnion(tensorBounds->getValue()[i]));
+        continue;
+      }
+      ranges.push_back(tensorBounds->getValue()[i]);
+      continue;
+    }
+
+    if (!scalarIsInvalid) {
+      ranges.push_back(scalarBound->getValue());
+      continue;
+    }
+
+    setResultRanges(getResult(), BoundsArray());
+    return;
+  }
+
+  setResultRanges(getResult(), BoundsArray(ranges));
+}
+
 static ParseResult
 parseWithValuesTypes(OpAsmParser &parser,
                      ArrayRef<OpAsmParser::UnresolvedOperand> elements,
diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt
index be4d4ffc3..b9705d975 100644
--- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt
+++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_tensorrt_library(MLIRTensorRTPlanTransforms
   ModuleBufferization/ModuleBufferizationUtils.cpp
   ModuleBufferization/RemoveEquivalentBufferResults.cpp
   OutlineClusters.cpp
+  OutlineConstantFoldableSubgraphs.cpp
   Passes.cpp
   PopulateFunctionBoundsAttributes.cpp
   PostClusteringValidation.cpp
diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineConstantFoldableSubgraphs.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineConstantFoldableSubgraphs.cpp
new file mode 100644
index 000000000..453814c80
--- /dev/null
+++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/OutlineConstantFoldableSubgraphs.cpp
@@ -0,0 +1,516 @@
+//===- OutlineConstantFoldableSubgraphs.cpp -------------------------------===//
+//
+// SPDX-FileCopyrightText: Copyright 2025 NVIDIA CORPORATION & AFFILIATES.
+// All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+///
+/// Implementation of the `outline-constant-foldable-subgraphs` pass.
+///
+//===----------------------------------------------------------------------===//
+
+#include "mlir-executor/Transforms/Clustering/Clustering.h"
+#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h"
+#include "mlir-tensorrt/Utils/DataFlowUtils.h"
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+
+namespace mlir::plan {
+#define GEN_PASS_DEF_PLANOUTLINECONSTANTFOLDABLESUBGRAPHSPASS
+#include "mlir-tensorrt/Dialect/Plan/Transforms/Passes.h.inc"
+} // namespace mlir::plan
+
+using namespace mlir;
+
+/// State for constant foldability analysis:
+///
+///        true
+///          |
+///   uninitialized (⊥)
+///
+/// We don't have a top (⊤) element (generally, unknown information)
+/// because constant foldability analysis is definitive on pure ops.
+/// Note that there is no `false` state.
+namespace {
+class ConstantFoldabilityState {
+public:
+  ConstantFoldabilityState(std::optional<bool> value = std::nullopt)
+      : value(std::move(value)) {}
+
+  bool isInitialized() const { return value.has_value(); }
+  bool isUninitialized() const { return !value.has_value(); }
+  static ConstantFoldabilityState getUninitialized() {
+    return ConstantFoldabilityState{};
+  }
+  bool getKnownState() const {
+    assert(isInitialized());
+    return *value;
+  }
+
+  bool operator==(const ConstantFoldabilityState &rhs) const {
+    return value == rhs.value;
+  }
+
+  static ConstantFoldabilityState join(const ConstantFoldabilityState &lhs,
+                                       const ConstantFoldabilityState &rhs) {
+    if (lhs.isUninitialized())
+      return rhs;
+    if (rhs.isUninitialized())
+      return lhs;
+    return lhs;
+  }
+
+  void print(llvm::raw_ostream &os) const {
+    if (isUninitialized()) {
+      os << "uninitialized";
+      return;
+    }
+    os << *value;
+  }
+
+private:
+  std::optional<bool> value;
+};
+
+class ConstantFoldabilityLattice
+    : public dataflow::Lattice<ConstantFoldabilityState> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantFoldabilityLattice)
+  using Lattice::Lattice;
+};
+
+/// Implements a forward dataflow analysis that finds constant-foldable
+/// values. This is a simple analysis that works only for pure ops. Operation
+/// results are considered constant foldable if all of the operation's
+/// operands are constant foldable and it has no memory effects.
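+/// For example (hand-derived under these rules; illustrative IR only):
+///
+///   %c = stablehlo.constant dense<1.0> : tensor<4xf32> // foldable: ConstantLike
+///   %0 = stablehlo.add %c, %c : tensor<4xf32>          // foldable: all operands
+///                                                      // are foldable
+///   %1 = stablehlo.add %0, %arg0 : tensor<4xf32>       // not foldable: %arg0 is
+///                                                      // a block argument and
+///                                                      // keeps the ⊥ state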
+class SparseConstantFoldabilityAnalysis
+    : public dataflow::SparseForwardDataFlowAnalysis<
+          ConstantFoldabilityLattice> {
+public:
+  using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
+
+  LogicalResult
+  visitOperation(Operation *op,
+                 ArrayRef<const ConstantFoldabilityLattice *> operands,
+                 ArrayRef<ConstantFoldabilityLattice *> results) override {
+
+    if (!isPure(op)) {
+      setAllToEntryStates(results);
+      return success();
+    }
+
+    // If the op is constant, it is constant foldable.
+    if (op->hasTrait<OpTrait::ConstantLike>()) {
+      ConstantFoldabilityLattice *lattice = results.front();
+      propagateIfChanged(lattice,
+                         lattice->join(ConstantFoldabilityState(true)));
+      return success();
+    }
+
+    // For other operations, check if all operands are constant foldable.
+    bool areAllOperandsConstantFoldable = true;
+    for (auto *operandLattice : operands) {
+      if (operandLattice->getValue().isUninitialized())
+        return success();
+      areAllOperandsConstantFoldable &=
+          operandLattice->getValue().getKnownState();
+    }
+
+    // If all operands are constant foldable, the results are constant
+    // foldable.
+    for (auto *resultLattice : results)
+      propagateIfChanged(resultLattice,
+                         resultLattice->join(ConstantFoldabilityState(
+                             areAllOperandsConstantFoldable)));
+
+    return success();
+  }
+
+  // Set up the entry state for lattices to be uninitialized.
+  void setToEntryState(ConstantFoldabilityLattice *lattice) override {
+    propagateIfChanged(
+        lattice, lattice->join(ConstantFoldabilityState::getUninitialized()));
+  }
+};
+} // namespace
+
+/// Given `cluster`, this function creates a new private `FuncOp` containing
+/// all ops from `cluster` and returns it after adding it to
+/// `moduleSymbolTable`. The returned function has a single block with no
+/// arguments, and its return types match the types of the values in
+/// `clusterValuesUsedOutsideCluster`. During outlining, first, constant ops
+/// consumed by ops inside the cluster (represented by
+/// `constantsUsedInsideCluster`) are cloned into the newly created function
+/// body. Then the cluster ops are moved inside the function body. Finally,
+/// uses of the original (outside) constants by operations inside the cluster
+/// are replaced with the newly cloned constants.
+static func::FuncOp
+outlineClusterToFunction(IRRewriter &rewriter, Location loc,
+                         SymbolTable &moduleSymbolTable, const Cluster &cluster,
+                         ArrayRef<Value> clusterValuesUsedOutsideCluster,
+                         ArrayRef<Operation *> constantsUsedInsideCluster) {
+
+  // Create the `func::FuncOp`.
+  FunctionType funcType = FunctionType::get(
+      rewriter.getContext(), {},
+      llvm::to_vector(llvm::map_range(clusterValuesUsedOutsideCluster,
+                                      [](Value v) { return v.getType(); })));
+  func::FuncOp funcOp =
+      rewriter.create<func::FuncOp>(loc, "constant_subgraph", funcType);
+  funcOp->setAttr("plan.constant_foldable", rewriter.getUnitAttr());
+  funcOp.setPrivate();
+
+  // Create the function body.
+  Block *entryBlock = funcOp.addEntryBlock();
+
+  {
+    // Create the `func::ReturnOp`.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToEnd(entryBlock);
+    rewriter.create<func::ReturnOp>(loc, clusterValuesUsedOutsideCluster);
+  }
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(entryBlock);
+    Operation *term = funcOp.getBody().getBlocks().front().getTerminator();
+
+    // First clone the constant ops used by the cluster.
+    IRMapping constantMapping;
+    for (Operation *op : constantsUsedInsideCluster)
+      rewriter.clone(*op, constantMapping);
+
+    // Move the non-constant ops.
+    for (Operation *op : cluster) {
+      if (!op->hasTrait<OpTrait::ConstantLike>())
+        rewriter.moveOpBefore(op, term);
+    }
+
+    // Update uses of the outside constants to the cloned constants.
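+    // Note: `constantMapping` was populated by `rewriter.clone` above. The
+    // filter below restricts the rewrite to uses nested inside `funcOp`, so
+    // uses of the original constants outside the outlined cluster are left
+    // intact.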
+    for (Operation *op : constantsUsedInsideCluster)
+      rewriter.replaceUsesWithIf(
+          op->getResults().front(),
+          constantMapping.lookup(op->getResults().front()),
+          [&](OpOperand &user) {
+            return funcOp->isProperAncestor(user.getOwner());
+          });
+  }
+
+  if (funcOp->getParentOp())
+    funcOp->remove();
+  moduleSymbolTable.insert(funcOp);
+  return funcOp;
+}
+
+/// Given `op`, for each of its operands, if the producer has the
+/// `ConstantLike` trait, push the producer to `constantsUsedByCluster`. The
+/// regions of `op` are traversed recursively, doing the same.
+static void collectConstantParentsOfOperands(
+    Operation *op, SmallVectorImpl<Operation *> &constantsUsedByCluster) {
+  for (Value operand : op->getOperands()) {
+    Operation *defOp = operand.getDefiningOp();
+    if (!defOp)
+      continue;
+    if (defOp->hasTrait<OpTrait::ConstantLike>())
+      constantsUsedByCluster.push_back(defOp);
+    continue;
+  }
+
+  // Visit the regions of the op, if applicable.
+  for (Region &region : op->getRegions()) {
+    for (Block &block : region) {
+      for (Operation &nestedOp : block)
+        collectConstantParentsOfOperands(&nestedOp, constantsUsedByCluster);
+    }
+  }
+}
+
+/// Visits every op in the cluster and, if the producer of one of its operands
+/// has the `ConstantLike` trait, adds that producer to
+/// `constantsUsedByCluster`. The cluster outlining function uses
+/// `constantsUsedByCluster` to first copy constant ops into the cluster before
+/// moving other ops.
+static void collectConstantOpsUsedInsideCluster(
+    const Cluster &cluster,
+    SmallVectorImpl<Operation *> &constantsUsedByCluster) {
+  for (Operation *op : cluster)
+    collectConstantParentsOfOperands(op, constantsUsedByCluster);
+}
+
+/// Returns true if `op` is a pure region op that doesn't implement
+/// `FunctionOpInterface`.
+static bool isPureAndNonFuncRegionOp(Operation *op) {
+  return (op->getNumRegions() > 0 && !isa<FunctionOpInterface>(op));
+}
+
+/// Traverse `op` recursively and return `true` if the op is standalone. An op
+/// is standalone if every operand is either the result of a constant op OR
+/// the result of another op that is inside the cluster.
+/// There is a special case when the op is inside the body of a single-region
+/// op (for example, `stablehlo.reduce` or `linalg.generic`). In this case, if
+/// the parent region op is standalone, ops using the entry block arguments of
+/// that region are also standalone. However, one exception to this is
+/// `func::FuncOp`.
+static bool isOpStandalone(Operation *op, DenseSet<Operation *> &clusterOps) {
+
+  auto isBlockArgOfPureAndNonFuncRegionOp = [&](Value v) {
+    // First check whether it's a block argument.
+    if (!isa<BlockArgument>(v))
+      return false;
+    // Get the parent operation of the block.
+    Operation *parentOp = cast<BlockArgument>(v).getOwner()->getParentOp();
+    if (!parentOp)
+      return false;
+    return isPureAndNonFuncRegionOp(parentOp);
+  };
+
+  auto checkIfOpInCluster = [&](Operation *op) {
+    if (isPureAndNonFuncRegionOp(op->getParentOp()))
+      return clusterOps.contains(op->getParentOp());
+    return clusterOps.contains(op);
+  };
+
+  for (Value operand : op->getOperands()) {
+    Operation *defOp = operand.getDefiningOp();
+    // `operand` is a block argument.
+    if (!defOp) {
+      if (isBlockArgOfPureAndNonFuncRegionOp(operand))
+        continue;
+      return false;
+    }
+    // `defOp` should be either a constant OR a member of this cluster.
+    if (defOp->hasTrait<OpTrait::ConstantLike>() || checkIfOpInCluster(defOp))
+      continue;
+    return false;
+  }
+
+  // Visit the regions of the op, if applicable.
+ for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (Operation &nestedOp : block) { + if (!isOpStandalone(&nestedOp, clusterOps)) + return false; + } + } + } + return true; +} + +/// Returns true if cluster is standalone, given ops in the cluster. Cluster +/// is standalone if, for each op, every operand is either result of +/// constant op OR result of another op which is inside cluster. To keep +/// things simple, we only outline standalone clusters. We ignore clusters +/// where one of the ops need result of other cluster as operand. +/// TODO: @Sagar Remove this limitation. +static bool isClusterStandalone(DenseSet &clusterOps) { + for (Operation *op : clusterOps) { + if (!isOpStandalone(op, clusterOps)) + return false; + } + return true; +} + +/// Find constant foldable clusters by running clustering on `func` with +/// given clustering `opts` and outline each cluster to a function inside +/// module with `symbolTable`. +static LogicalResult findClustersAndOutlineToFuncs(func::FuncOp func, + ModuleOp moduleOp, + IRRewriter &rewriter, + SymbolTable &symbolTable, + const ClusteringOpts &opts, + DataFlowSolver &solver) { + FailureOr> clusters = + analyzeAndClusterOperations(func, opts); + if (failed(clusters)) + return failure(); + + for (const Cluster &cluster : *clusters) { + DenseSet clusterOpSet; + for (Operation *op : cluster) + clusterOpSet.insert(op); + + // Check if cluster is standalone. + if (!isClusterStandalone(clusterOpSet)) + continue; + + // It is still possible for clusters to have only non-compute ops. + // For example, `tensor.empty()` followed by `linalg.generic` where + // later one is not constant foldable. Outlining of such clusters is + // skipped. + if (llvm::all_of(clusterOpSet, + [](Operation *op) { return op->getNumOperands() == 0; })) + continue; + + Block *clusterRootBlock = cluster.getRoot()->getBlock(); + + // Find cluster values used outside the cluster. These values + // should be returned from the outlined function. + SetVector valuesUsedOutsideCluster; + for (Operation *op : cluster) { + for (Value result : op->getResults()) { + for (Operation *user : result.getUsers()) { + if (clusterOpSet.contains(user)) + continue; + if (valuesUsedOutsideCluster.contains(result)) + continue; + valuesUsedOutsideCluster.insert(result); + } + } + } + + // Find constant ops used inside the cluster. Remember, constant ops + // are not part of any cluster. + SmallVector constantOpsUsedInsideCluster; + collectConstantOpsUsedInsideCluster(cluster, constantOpsUsedInsideCluster); + + func::FuncOp outlinedFunc = outlineClusterToFunction( + rewriter, moduleOp->getLoc(), symbolTable, cluster, + valuesUsedOutsideCluster.getArrayRef(), constantOpsUsedInsideCluster); + + // Insert call to the outline function. + rewriter.setInsertionPointToStart(clusterRootBlock); + auto callOp = + rewriter.create(moduleOp->getLoc(), outlinedFunc); + // Set the call result values to 'uninitialized'. + for (Value v : callOp->getResults()) + solver.getOrCreateState(v); + + // Replace uses of cluster values used outside with the result of call op. + for (auto [originalValue, callResult] : + llvm::zip(valuesUsedOutsideCluster, callOp.getResults())) + rewriter.replaceUsesWithIf( + originalValue, callResult, [&](OpOperand &operand) { + return !outlinedFunc->isProperAncestor(operand.getOwner()); + }); + } + return success(); +} + +/// Returns true if `op` is should be clustered. 
+static bool shouldClusterOp(Operation *op, const DataFlowSolver &solver) {
+
+  // Don't cluster terminators; otherwise a constant-foldable terminator
+  // would be outlined.
+  // Don't cluster constants since they might be shared across clusters and
+  // will be cloned later.
+  // Don't cluster a control-flow op itself. If a control-flow op is
+  // clusterable (i.e. added to `ClusteringState`), the clustering algorithm
+  // doesn't visit the ops in its regions. This causes issues when not all
+  // regions of the control-flow op are standalone enough to outline.
+  // Consider the example:
+  //
+  //  %0 = scf.if %true -> (tensor<4xf32>) {
+  //    %cst_1 = stablehlo.constant dense<3.000000e+00> : tensor<4xf32>
+  //    %1 = stablehlo.add %cst, %cst_0 : tensor<4xf32>
+  //    %2 = stablehlo.subtract %cst_1, %1 : tensor<4xf32>
+  //    scf.yield %2 : tensor<4xf32>
+  //  } else {
+  //    %1 = stablehlo.add %arg0, %cst : tensor<4xf32>
+  //    %2 = stablehlo.subtract %1, %cst_0 : tensor<4xf32>
+  //    scf.yield %2 : tensor<4xf32>
+  //  }
+  //
+  // Here, DFA correctly decides that `%0` is constant foldable. However, our
+  // logic for checking whether an op is standalone (within a cluster) doesn't
+  // consider which region of a control-flow op will actually be taken. It
+  // says an op is standalone only if all regions (and thus the ops inside
+  // those regions) have their dependencies inside the cluster.
+  // In the example above this causes an issue: even though only the `then`
+  // region is going to execute, the `else` region depends on %arg0 (an
+  // argument of the top-level `func.func`), so we say this `scf.if` is not
+  // standalone.
+  // Skipping control-flow ops altogether solves this issue as follows:
+  // 1. Clustering analyzes ops within all regions and outlines constant
+  //    foldable clusters.
+  // 2. If a control-flow op is constant foldable, like above, its own
+  //    canonicalizer kicks in to keep clusters only in executable regions, as
+  //    shown below.
+  //
+  //  %0 = call @constant_subgraph() : () -> tensor<4xf32>
+  //  func.func private @constant_subgraph() -> tensor<4xf32> attributes
+  //      {plan.constant_foldable} {
+  //    %cst = stablehlo.constant dense<2.000000e+00> : tensor<4xf32>
+  //    %cst_0 = stablehlo.constant dense<3.000000e+00> : tensor<4xf32>
+  //    %0 = stablehlo.add %cst, %cst : tensor<4xf32>
+  //    %1 = stablehlo.subtract %cst_0, %0 : tensor<4xf32>
+  //    return %1 : tensor<4xf32>
+  //  }
+  if (op->hasTrait<OpTrait::IsTerminator>() ||
+      op->hasTrait<OpTrait::ConstantLike>() ||
+      isa<RegionBranchOpInterface>(op))
+    return false;
+
+  // Don't cluster ops inside pure region ops.
+  if (isPureAndNonFuncRegionOp(op->getParentOp()))
+    return false;
+
+  bool areAllResultsConstantFoldable = true;
+  for (Value result : op->getResults()) {
+    const ConstantFoldabilityLattice *lattice =
+        solver.lookupState<ConstantFoldabilityLattice>(result);
+    if (!lattice || lattice->getValue().isUninitialized())
+      return false;
+    areAllResultsConstantFoldable &= lattice->getValue().getKnownState();
+  }
+  return areAllResultsConstantFoldable;
+}
+
+/// Returns clustering options for constant-foldable cluster generation.
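+/// Since `skipClustering` is a `std::function` pass option, it can only be
+/// populated when the pass is constructed from C++. A hypothetical setup
+/// (assuming the tablegen-generated options struct and factory names):
+///
+///   PlanOutlineConstantFoldableSubgraphsPassOptions opts;
+///   opts.skipClustering = [](Operation *op) {
+///     // E.g. keep ops that cannot run at compile time out of clusters.
+///     return op->getName().getDialectNamespace() == "tensorrt";
+///   };
+///   pm.addPass(createPlanOutlineConstantFoldableSubgraphsPass(opts));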
+static ClusteringOpts
+getClusteringOpts(const DataFlowSolver &solver,
+                  const std::function<bool(Operation *)> &skipClustering) {
+  ClusteringOpts opts;
+  opts.clusterTarget = Attribute{};
+  opts.isClusterableOp = [&solver, &skipClustering](Operation *op) {
+    if (skipClustering && skipClustering(op))
+      return false;
+    return shouldClusterOp(op, solver);
+  };
+  opts.mergeIndependentClusters = [](Operation *, ClusterRange, Operation *,
+                                     ClusterRange) { return true; };
+  return opts;
+}
+
+namespace {
+class PlanOutlineConstantFoldableSubgraphsPass
+    : public mlir::plan::impl::PlanOutlineConstantFoldableSubgraphsPassBase<
+          PlanOutlineConstantFoldableSubgraphsPass> {
+public:
+  using Base::Base;
+  void runOnOperation() override {
+    ModuleOp moduleOp = getOperation();
+
+    DataFlowSolver solver(DataFlowConfig().setInterprocedural(false));
+    IRRewriter rewriter(&getContext());
+    SymbolTable symbolTable(moduleOp);
+
+    // Initialize and run the dataflow analysis that determines
+    // constant-foldable ops.
+    solver.load<dataflow::DeadCodeAnalysis>();
+    solver.load<dataflow::SparseConstantPropagation>();
+    solver.load<SparseConstantFoldabilityAnalysis>();
+    if (failed(solver.initializeAndRun(moduleOp)))
+      return signalPassFailure();
+
+    ClusteringOpts opts = getClusteringOpts(solver, skipClustering);
+
+    SmallVector<func::FuncOp> originalFuncs;
+    for (func::FuncOp func : moduleOp.getOps<func::FuncOp>())
+      originalFuncs.push_back(func);
+
+    for (func::FuncOp func : originalFuncs) {
+      if (failed(findClustersAndOutlineToFuncs(func, moduleOp, rewriter,
+                                               symbolTable, opts, solver))) {
+        emitError(moduleOp->getLoc()) << " failed to process clusters\n";
+        return signalPassFailure();
+      }
+    }
+  }
+};
+} // namespace
\ No newline at end of file
diff --git a/mlir-tensorrt/compiler/lib/Dialect/StablehloExt/IR/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Dialect/StablehloExt/IR/CMakeLists.txt
index 5ce042f1b..70827efdb 100644
--- a/mlir-tensorrt/compiler/lib/Dialect/StablehloExt/IR/CMakeLists.txt
+++ b/mlir-tensorrt/compiler/lib/Dialect/StablehloExt/IR/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_tensorrt_library(MLIRTensorRTStableHloExtIR
+  StablehloInferTensorValueRangeImpl.cpp
+  StablehloReifyTypeInterfaceImpl.cpp
   StablehloTensorKindOpInterfaceImpl.cpp
-  StableHloReifyTypeInterfaceImpl.cpp
 
   LINK_LIBS PUBLIC
   MLIRAffineDialect
@@ -9,6 +10,7 @@ add_mlir_tensorrt_library(MLIRTensorRTStableHloExtIR
   MLIRDialectUtils
   MLIRIR
   MLIRTensorDialect
+  MLIRTensorRTInferTensorValueRangeInterface
   MLIRTensorRTInterfaces
   StablehloOps
 )
diff --git a/mlir-tensorrt/compiler/lib/Dialect/StablehloExt/IR/StablehloInferTensorValueRangeImpl.cpp b/mlir-tensorrt/compiler/lib/Dialect/StablehloExt/IR/StablehloInferTensorValueRangeImpl.cpp
new file mode 100644
index 000000000..9d97cf87f
--- /dev/null
+++ b/mlir-tensorrt/compiler/lib/Dialect/StablehloExt/IR/StablehloInferTensorValueRangeImpl.cpp
@@ -0,0 +1,128 @@
+//===- StablehloInferTensorValueRangeImpl.cpp -----------------------------===//
+//
+// SPDX-FileCopyrightText: Copyright 2025 NVIDIA CORPORATION & AFFILIATES.
+// All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+///
+/// Implementation of InferTensorValueRangeInterface for specific StableHLO
+/// ops.
+//===----------------------------------------------------------------------===//
+#include "mlir-tensorrt/Dialect/StablehloExt/IR/StableHloExt.h"
+#include "mlir-tensorrt/Interfaces/InferTensorValueRangeInterface.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
+#include "stablehlo/dialect/StablehloOps.h"
+
+using namespace mlir;
+using namespace mlir::stablehlo;
+using namespace mlirtrt::compiler;
+
+static BoundsArray extUIRanges(ArrayRef<ConstantIntRanges> ranges,
+                               unsigned destWidth) {
+  SmallVector<ConstantIntRanges> result;
+  for (const auto &l : ranges)
+    result.push_back(intrange::extUIRange(l, destWidth));
+  return BoundsArray(result);
+}
+
+static BoundsArray truncRanges(ArrayRef<ConstantIntRanges> ranges,
+                               unsigned destWidth) {
+  SmallVector<ConstantIntRanges> result;
+  for (const auto &l : ranges)
+    result.push_back(intrange::truncRange(l, destWidth));
+  return BoundsArray(result);
+}
+
+static BoundsArray extRanges(ArrayRef<ConstantIntRanges> ranges,
+                             unsigned destWidth) {
+  SmallVector<ConstantIntRanges> result;
+  for (const auto &l : ranges)
+    result.push_back(intrange::extRange(l, destWidth));
+  return BoundsArray(result);
+}
+
+namespace {
+class ConvertOpImpl : public InferTensorValueRangeInterface::ExternalModel<
+                          ConvertOpImpl, stablehlo::ConvertOp> {
+public:
+  void
+  inferResultRangesFromOptional(Operation *op_,
+                                ArrayRef argRanges,
+                                SetTensorValueLatticeFn setResultRanges) const {
+    auto op = cast<stablehlo::ConvertOp>(op_);
+    Type sourceElementType = op.getOperand().getType().getElementType();
+    Type resultElementType = op.getType().getElementType();
+
+    if (!isa<IntegerType>(sourceElementType) ||
+        !isa<IntegerType>(resultElementType) ||
+        !BoundsArray::shouldAnalyzeValueBounds(op.getResult())) {
+      setResultRanges(op.getResult(), BoundsArray());
+      return;
+    }
+
+    unsigned sourceWidth =
+        ConstantIntRanges::getStorageBitwidth(sourceElementType);
+    unsigned destWidth =
+        ConstantIntRanges::getStorageBitwidth(resultElementType);
+
+    const auto *argRange0 = argRanges[0].dyn_cast();
+    if (!argRange0 || argRange0->isUninitialized()) {
+      setResultRanges(op.getResult(), BoundsArray());
+      return;
+    }
+
+    // Per the StableHLO spec:
+    // "For boolean-to-any-supported-type conversions, the value false is
+    // converted to zero, and the value true is converted to one. For
+    // any-supported-type-to-boolean conversions, a zero value is converted to
+    // false, and non-zero values are converted to true."
+    // See https://openxla.org/stablehlo/spec#convert.
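+    // The dispatch below follows from that rule: i1 sources are zero-extended
+    // (so `true` becomes exactly 1), conversions onto i1 and narrowing
+    // conversions truncate the per-element ranges, and the remaining widening
+    // conversions use `intrange::extRange`.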
+    if (sourceWidth == 1 && destWidth > 1) {
+      setResultRanges(op.getResult(),
+                      extUIRanges(argRange0->getValue(), destWidth));
+      return;
+    }
+    if (destWidth == 1 && sourceWidth > 1) {
+      setResultRanges(op.getResult(),
+                      truncRanges(argRange0->getValue(), destWidth));
+      return;
+    }
+
+    if (sourceWidth < destWidth) {
+      setResultRanges(op.getResult(),
+                      extRanges(argRange0->getValue(), destWidth));
+      return;
+    }
+
+    if (sourceWidth > destWidth) {
+      setResultRanges(op.getResult(),
+                      truncRanges(argRange0->getValue(), destWidth));
+      return;
+    }
+
+    setResultRanges(op.getResult(), BoundsArray());
+  }
+};
+
+} // namespace
+
+void mlir::stablehlo::registerInferTensorValueRangeInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(
+      +[](MLIRContext *ctx, stablehlo::StablehloDialect *dialect) {
+        stablehlo::ConvertOp::attachInterface<ConvertOpImpl>(*ctx);
+      });
+}
diff --git a/mlir-tensorrt/compiler/lib/Dialect/StablehloExt/IR/StableHloReifyTypeInterfaceImpl.cpp b/mlir-tensorrt/compiler/lib/Dialect/StablehloExt/IR/StablehloReifyTypeInterfaceImpl.cpp
similarity index 98%
rename from mlir-tensorrt/compiler/lib/Dialect/StablehloExt/IR/StableHloReifyTypeInterfaceImpl.cpp
rename to mlir-tensorrt/compiler/lib/Dialect/StablehloExt/IR/StablehloReifyTypeInterfaceImpl.cpp
index ccd48571b..8fe6ec0cd 100644
--- a/mlir-tensorrt/compiler/lib/Dialect/StablehloExt/IR/StableHloReifyTypeInterfaceImpl.cpp
+++ b/mlir-tensorrt/compiler/lib/Dialect/StablehloExt/IR/StableHloReifyTypeInterfaceImpl.cpp
@@ -1,6 +1,6 @@
-//===- StablehloTensorKindOpInterfaceImpl.cpp -----------------------------===//
+//===- StablehloReifyTypeInterfaceImpl.cpp --------------------------------===//
 //
-// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
+// SPDX-FileCopyrightText: Copyright 2025 NVIDIA CORPORATION & AFFILIATES.
 // All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // @@ -347,9 +347,6 @@ void stablehlo::registerTypeInferenceExternalModels(DialectRegistry ®istry) { +[](MLIRContext *ctx, stablehlo::StablehloDialect *dialect) { stablehlo::ConvolutionOp::attachInterface< ConvolutionReifyRankedShapedTypeOpInterfaceImpl>(*ctx); - }); - registry.addExtension( - +[](MLIRContext *ctx, stablehlo::StablehloDialect *dialect) { stablehlo::ReduceWindowOp::attachInterface< ReduceWindowReifyRankedShapedTypeOpInterfaceImpl>(*ctx); }); diff --git a/mlir-tensorrt/compiler/lib/Interfaces/CMakeLists.txt b/mlir-tensorrt/compiler/lib/Interfaces/CMakeLists.txt index 3f4d7c2da..cf9f70213 100644 --- a/mlir-tensorrt/compiler/lib/Interfaces/CMakeLists.txt +++ b/mlir-tensorrt/compiler/lib/Interfaces/CMakeLists.txt @@ -48,3 +48,15 @@ add_mlir_tensorrt_interface_library(MLIRTensorRTBufferizationScopeInterface MLIRBufferizationDialect MLIRBufferizationTransforms ) + +add_mlir_tensorrt_interface_library(MLIRTensorRTInferTensorValueRangeInterface + InferTensorValueRangeInterface.cpp + + TD + "mlir-tensorrt/Interfaces/InferTensorValueRangeInterface.td" + OP + "mlir-tensorrt/Interfaces/InferTensorValueRangeInterface" + + LINK_LIBS PUBLIC + MLIRInferIntRangeInterface + ) diff --git a/mlir-tensorrt/compiler/lib/Interfaces/InferTensorValueRangeInterface.cpp b/mlir-tensorrt/compiler/lib/Interfaces/InferTensorValueRangeInterface.cpp new file mode 100644 index 000000000..2c69780bd --- /dev/null +++ b/mlir-tensorrt/compiler/lib/Interfaces/InferTensorValueRangeInterface.cpp @@ -0,0 +1,194 @@ +//===- InferTensorValueRangeInterface.cpp -------------------------------===// +// +// SPDX-FileCopyrightText: Copyright 2025 NVIDIA CORPORATION & AFFILIATES. +// All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +/// +/// InferTensorValueRangeInterface definitions. 
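+/// BoundsArray attaches one signed `ConstantIntRanges` entry to each element
+/// of a statically shaped integer tensor; an empty optional is the
+/// uninitialized lattice state. As an illustration (values are hypothetical),
+/// `BoundsArray::fromShapeBounds({1, 2}, {4, 8})` describes a two-element
+/// shape whose extents lie in [1, 4] and [2, 8], respectively.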
+///
+//===----------------------------------------------------------------------===//
+#include "mlir-tensorrt/Interfaces/InferTensorValueRangeInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Support/LLVM.h"
+
+using namespace mlirtrt::compiler;
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// BoundsArray
+//===----------------------------------------------------------------------===//
+
+bool BoundsArray::shouldAnalyzeValueBounds(Type type) {
+  if (auto rtt = dyn_cast<RankedTensorType>(type))
+    return rtt.getElementType().isSignlessIntOrIndex() &&
+           rtt.hasStaticShape() && rtt.getNumElements() <= kMaxVolumeThreshold;
+  return false;
+}
+bool BoundsArray::shouldAnalyzeValueBounds(Value value) {
+  return shouldAnalyzeValueBounds(value.getType());
+}
+
+ConstantIntRanges BoundsArray::getMaxDimRange() {
+  APInt smin = APInt(IndexType::kInternalStorageBitWidth, 0);
+  APInt smax = APInt(IndexType::kInternalStorageBitWidth,
+                     std::numeric_limits<int64_t>::max());
+  return ConstantIntRanges::fromSigned(smin, smax);
+}
+
+BoundsArray BoundsArray::getMaxRangeForShapeBounds(Value v) {
+  auto type = cast<RankedTensorType>(v.getType());
+  SmallVector<ConstantIntRanges> ranges;
+  ranges.reserve(type.getRank());
+  for (int64_t dim : type.getShape()) {
+    if (ShapedType::isDynamic(dim)) {
+      ranges.push_back(getMaxDimRange());
+      continue;
+    }
+    ranges.push_back(ConstantIntRanges::constant(APInt(64, dim)));
+  }
+  return BoundsArray(std::move(ranges));
+}
+
+BoundsArray BoundsArray::getMaxRangeForValueBounds(Value v) {
+  assert(shouldAnalyzeValueBounds(v) && "value is unsuitable for analysis");
+  Type elementType = mlir::getElementTypeOrSelf(v);
+  unsigned numBits = ConstantIntRanges::getStorageBitwidth(elementType);
+  APInt smin = APInt::getSignedMinValue(numBits);
+  APInt smax = APInt::getSignedMaxValue(numBits);
+  SmallVector<ConstantIntRanges> ranges(
+      cast<RankedTensorType>(v.getType()).getNumElements(),
+      ConstantIntRanges::fromSigned(smin, smax));
+  return BoundsArray(std::move(ranges));
+}
+
+BoundsArray BoundsArray::getFromConstantValue(DenseIntElementsAttr v) {
+  assert(shouldAnalyzeValueBounds(v.getType()) &&
+         "attribute type is unsuitable for creating value bound state");
+  SmallVector<ConstantIntRanges> ranges;
+  ranges.reserve(cast<RankedTensorType>(v.getType()).getNumElements());
+  for (const APInt &element : v.getValues<APInt>())
+    ranges.push_back(ConstantIntRanges::constant(element));
+  return BoundsArray(std::move(ranges));
+}
+
+BoundsArray BoundsArray::fromShapeBounds(ArrayRef<int64_t> min,
+                                         ArrayRef<int64_t> max) {
+  SmallVector<ConstantIntRanges> res;
+  for (auto [l, r] : llvm::zip_equal(min, max))
+    res.push_back(ConstantIntRanges::fromSigned(APInt(64, l), APInt(64, r)));
+  return BoundsArray(std::move(res));
+}
+
+BoundsArray BoundsArray::fromIntegerValueBounds(unsigned bitWidth,
+                                                ArrayRef<int64_t> min,
+                                                ArrayRef<int64_t> max) {
+  SmallVector<ConstantIntRanges> res;
+  for (auto [l, r] : llvm::zip_equal(min, max))
+    res.push_back(
+        ConstantIntRanges::fromSigned(APInt(64, l).sextOrTrunc(bitWidth),
+                                      APInt(64, r).sextOrTrunc(bitWidth)));
+  return BoundsArray(std::move(res));
+}
+
+BoundsArray BoundsArray::fromIntegerValueBounds(ArrayRef<APInt> min,
+                                                ArrayRef<APInt> max) {
+  SmallVector<ConstantIntRanges> res;
+  for (auto [l, r] : llvm::zip_equal(min, max))
+    res.push_back(ConstantIntRanges::fromSigned(l, r));
+  return BoundsArray(std::move(res));
+}
+
+BoundsArray BoundsArray::join(const BoundsArray &lhs, const BoundsArray &rhs) {
+  if (lhs.isUninitialized())
+    return rhs;
+  if (rhs.isUninitialized())
+    return lhs;
+  SmallVector<ConstantIntRanges> res;
+  for (auto [l, r] :
llvm::zip_equal(lhs.getValue(), rhs.getValue()))
+    res.push_back(l.rangeUnion(r));
+  return BoundsArray(std::move(res));
+}
+
+BoundsArray BoundsArray::meet(const BoundsArray &lhs, const BoundsArray &rhs) {
+  if (lhs.isUninitialized())
+    return rhs;
+  if (rhs.isUninitialized())
+    return lhs;
+  SmallVector<ConstantIntRanges> res;
+  for (auto [l, r] : llvm::zip_equal(lhs.getValue(), rhs.getValue()))
+    res.push_back(l.intersection(r));
+  return BoundsArray(std::move(res));
+}
+
+void BoundsArray::print(raw_ostream &os) const {
+  if (!value) {
+    os << "<>";
+    return;
+  }
+  os << "<";
+  llvm::interleaveComma(*value, os, [&](const ConstantIntRanges &r) {
+    os << "[" << r.smin() << ", " << r.smax() << "]";
+  });
+  os << ">";
+}
+
+llvm::raw_ostream &mlirtrt::compiler::operator<<(llvm::raw_ostream &os,
+                                                 const BoundsArray &v) {
+  v.print(os);
+  return os;
+}
+
+std::pair<DenseElementsAttr, DenseElementsAttr>
+BoundsArray::getAsElementsAttr(RankedTensorType type) const {
+  assert(!isUninitialized() && "expected initialized value");
+  assert(type.getNumElements() == static_cast<int64_t>(value->size()) &&
+         "specified tensor type's volume does not match lattice value volume");
+  SmallVector<APInt> lbs;
+  lbs.reserve(type.getNumElements());
+  SmallVector<APInt> ubs;
+  ubs.reserve(type.getNumElements());
+  for (const ConstantIntRanges &r : *value) {
+    lbs.push_back(r.smin());
+    ubs.push_back(r.smax());
+  }
+  return std::make_pair(DenseElementsAttr::get(type, lbs),
+                        DenseElementsAttr::get(type, ubs));
+}
+
+/// Returns a DenseElementsAttr if the element ranges are all constant
+/// (single-value) ranges, otherwise std::nullopt.
+std::optional<DenseElementsAttr>
+BoundsArray::getConstantValues(RankedTensorType type) const {
+  assert(!isUninitialized() && "expected initialized value");
+  assert(type.getNumElements() == static_cast<int64_t>(value->size()) &&
+         "specified tensor type's volume does not match lattice value volume");
+  SmallVector<APInt> lbs;
+  lbs.reserve(type.getNumElements());
+  for (const ConstantIntRanges &r : *value) {
+    if (r.smin() != r.smax())
+      return {};
+    lbs.push_back(r.smin());
+  }
+  return DenseElementsAttr::get(type, lbs);
+}
+
+//===----------------------------------------------------------------------===//
+// Generated interface class implementations.
+//===----------------------------------------------------------------------===// + +#include "mlir-tensorrt/Interfaces/InferTensorValueRangeInterface.cpp.inc" diff --git a/mlir-tensorrt/compiler/test/Dialect/Plan/outline-constant-foldable-subgraphs.mlir b/mlir-tensorrt/compiler/test/Dialect/Plan/outline-constant-foldable-subgraphs.mlir new file mode 100644 index 000000000..36f8031f8 --- /dev/null +++ b/mlir-tensorrt/compiler/test/Dialect/Plan/outline-constant-foldable-subgraphs.mlir @@ -0,0 +1,397 @@ +// RUN: mlir-tensorrt-opt %s --plan-outline-constant-foldable-subgraphs --split-input-file --canonicalize -allow-unregistered-dialect | FileCheck %s + +func.func @simple_case1(%arg0: tensor<4xf32>) -> tensor<4xf32>{ + %c0 = stablehlo.constant dense<2.0> : tensor<4xf32> + %c1 = stablehlo.constant dense<3.0> : tensor<4xf32> + %add0 = stablehlo.add %c0, %c1 : tensor<4xf32> + %add1 = stablehlo.add %arg0, %add0 : tensor<4xf32> + return %add1 : tensor<4xf32> +} + +// CHECK-LABEL: @simple_case1 +// CHECK-SAME: (%[[arg0:.+]]: tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: %[[v0:.+]] = call @constant_subgraph() : () -> tensor<4xf32> +// CHECK-NEXT: %[[v1:.+]] = stablehlo.add %[[arg0]], %[[v0]] : tensor<4xf32> +// CHECK-NEXT: return %[[v1]] : tensor<4xf32> +// CHECK: func.func private @constant_subgraph() -> tensor<4xf32> attributes {plan.constant_foldable} + +// ----- + +func.func @simple_case2(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>{ + %c0 = stablehlo.constant dense<2.0> : tensor<4xf32> + %c1 = stablehlo.constant dense<3.0> : tensor<4xf32> + %add0 = stablehlo.add %c0, %arg0 : tensor<4xf32> + %add1 = stablehlo.add %c0, %c1 : tensor<4xf32> + %add2 = stablehlo.add %arg1, %add1 : tensor<4xf32> + %sub0 = stablehlo.subtract %c0, %c1 : tensor<4xf32> + %add3 = stablehlo.add %add2, %add0 : tensor<4xf32> + %sub1 = stablehlo.subtract %add3, %sub0 : tensor<4xf32> + return %sub1 : tensor<4xf32> +} + +// CHECK-LABEL: @simple_case2 +// CHECK-SAME: (%[[arg0:.+]]: tensor<4xf32>, %[[arg1:.+]]: tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: %[[v0:.+]] = stablehlo.constant dense<2.000000e+00> : tensor<4xf32> +// CHECK-NEXT: %[[v1:.+]] = call @constant_subgraph_0() : () -> tensor<4xf32> +// CHECK-NEXT: %[[v2:.+]] = call @constant_subgraph() : () -> tensor<4xf32> +// CHECK-NEXT: %[[v3:.+]] = stablehlo.add %[[v0]], %[[arg0]] : tensor<4xf32> +// CHECK-NEXT: %[[v4:.+]] = stablehlo.add %[[arg1]], %[[v2]] : tensor<4xf32> +// CHECK-NEXT: %[[v5:.+]] = stablehlo.add %[[v4]], %[[v3]] : tensor<4xf32> +// CHECK-NEXT: %[[v6:.+]] = stablehlo.subtract %[[v5]], %[[v1]] : tensor<4xf32> +// CHECK-NEXT: return %[[v6]] : tensor<4xf32> +// CHECK: func.func private @constant_subgraph() -> tensor<4xf32> attributes {plan.constant_foldable} +// CHECK: func.func private @constant_subgraph_0() -> tensor<4xf32> attributes {plan.constant_foldable} + +// ----- + +func.func @simple_case3(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>{ + %c0 = stablehlo.constant dense<2.0> : tensor<4xf32> + %add0 = stablehlo.add %arg0, %c0 : tensor<4xf32> + %c1 = stablehlo.constant dense<3.0> : tensor<4xf32> + %add1 = stablehlo.add %c0, %arg1 : tensor<4xf32> + %add2 = stablehlo.add %add0, %c1 : tensor<4xf32> + %add3 = stablehlo.add %arg0, %add2 : tensor<4xf32> + return %add3 : tensor<4xf32> +} + +// CHECK-LABEL: @simple_case3 +// CHECK-SAME: (%[[arg0:.+]]: tensor<4xf32>, %[[arg1:.+]]: tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: %[[v0:.+]] = stablehlo.constant dense<3.000000e+00> : tensor<4xf32> +// CHECK-NEXT: %[[v1:.+]] = 
stablehlo.constant dense<2.000000e+00> : tensor<4xf32> +// CHECK-NEXT: %[[v2:.+]] = stablehlo.add %[[arg0]], %[[v1]] : tensor<4xf32> +// CHECK-NEXT: %[[v3:.+]] = stablehlo.add %[[v2]], %[[v0]] : tensor<4xf32> +// CHECK-NEXT: %[[v4:.+]] = stablehlo.add %[[arg0]], %[[v3]] : tensor<4xf32> +// CHECK-NEXT: return %[[v4]] : tensor<4xf32> + +// ----- + +func.func @scf_for_case1(%arg0: tensor<4xf32>) -> tensor<4xf32>{ + %c0 = stablehlo.constant dense<2.0> : tensor<4xf32> + %c1 = stablehlo.constant dense<3.0> : tensor<4xf32> + %add0 = stablehlo.add %c0, %arg0 : tensor<4xf32> + %for_start = arith.constant 0 : index + %for_end = arith.constant 5 : index + %for_step = arith.constant 1 : index + %sum = scf.for %iv = %for_start to %for_end step %for_step + iter_args(%sum_iter = %add0) -> (tensor<4xf32>) { + %add1 = stablehlo.add %c0, %c1 : tensor<4xf32> + %sub1 = stablehlo.subtract %add1, %c1 : tensor<4xf32> + %sum_next = stablehlo.add %add0, %sub1 : tensor<4xf32> + scf.yield %sum_next : tensor<4xf32> + } + return %sum : tensor<4xf32> +} + +// CHECK-LABEL: @scf_for_case1 +// CHECK-SAME: (%[[arg0:.+]]: tensor<4xf32>) -> tensor<4xf32> +// CHECK: %[[cst:.+]] = stablehlo.constant dense<2.000000e+00> : tensor<4xf32> +// CHECK: %[[v0:.+]] = stablehlo.add %[[cst]], %[[arg0]] : tensor<4xf32> +// CHECK: %[[v1:.+]] = scf.for +// CHECK-NEXT: func.call @constant_subgraph() : () -> tensor<4xf32> +// CHECK-NEXT: stablehlo.add +// CHECK-NEXT: scf.yield +// CHECK: return %[[v1]] : tensor<4xf32> +// CHECK: func.func private @constant_subgraph() -> tensor<4xf32> attributes {plan.constant_foldable} + +// ----- + +func.func @scf_for_case2(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>{ + %c0 = stablehlo.constant dense<2.0> : tensor<4xf32> + %c1 = stablehlo.constant dense<3.0> : tensor<4xf32> + %add0 = stablehlo.add %c0, %arg0 : tensor<4xf32> + %for_start = arith.constant 0 : index + %for_end = arith.constant 5 : index + %for_step = arith.constant 1 : index + %sum = scf.for %iv = %for_start to %for_end step %for_step + iter_args(%sum_iter = %add0) -> (tensor<4xf32>) { + %add1 = stablehlo.add %c0, %c1 : tensor<4xf32> + %add2 = stablehlo.add %add1, %arg1 : tensor<4xf32> + %sub1 = stablehlo.subtract %add2, %c1 : tensor<4xf32> + %sum_next = stablehlo.add %sum_iter, %sub1 : tensor<4xf32> + scf.yield %sum_next : tensor<4xf32> + } + return %sum : tensor<4xf32> +} + +// CHECK-LABEL: @scf_for_case2 +// CHECK-SAME: (%[[arg0:.+]]: tensor<4xf32>, %[[arg1:.+]]: tensor<4xf32>) -> tensor<4xf32> +// CHECK: %[[cst:.+]] = stablehlo.constant dense<2.000000e+00> : tensor<4xf32> +// CHECK: %[[v0:.+]] = stablehlo.add %[[cst]], %[[arg0]] : tensor<4xf32> +// CHECK: %[[v1:.+]] = scf.for +// CHECK-NEXT: func.call @constant_subgraph() : () -> tensor<4xf32> +// CHECK-NEXT: stablehlo.add +// CHECK-NEXT: stablehlo.subtract +// CHECK-NEXT: stablehlo.add +// CHECK-NEXT: scf.yield +// CHECK: return %[[v1]] : tensor<4xf32> +// CHECK: func.func private @constant_subgraph() -> tensor<4xf32> attributes {plan.constant_foldable} + +// ----- + +func.func @scf_for_case3(%arg0: tensor<4xf32>) -> tensor<4xf32>{ + %c0 = stablehlo.constant dense<2.0> : tensor<4xf32> + %add0 = stablehlo.add %c0, %arg0 : tensor<4xf32> + %for_start = arith.constant 0 : index + %for_end = arith.constant 5 : index + %for_step = arith.constant 1 : index + %sum = scf.for %iv = %for_start to %for_end step %for_step + iter_args(%sum_iter = %add0) -> (tensor<4xf32>) { + %sum_next = stablehlo.add %sum_iter, %c0 : tensor<4xf32> + scf.yield %sum_next : tensor<4xf32> + } + return %sum 
: tensor<4xf32> +} + +// CHECK-LABEL: @scf_for_case3 +// CHECK-SAME: (%[[arg0:.+]]: tensor<4xf32>) -> tensor<4xf32> +// CHECK: %[[cst:.+]] = stablehlo.constant dense<2.000000e+00> : tensor<4xf32> +// CHECK: %[[v0:.+]] = stablehlo.add %[[cst]], %[[arg0]] : tensor<4xf32> +// CHECK: %[[v1:.+]] = scf.for +// CHECK-NEXT: stablehlo.add +// CHECK-NEXT: scf.yield +// CHECK: return %[[v1]] : tensor<4xf32> + +// ----- + +func.func @scf_while_case1(%init: tensor<4xf32>) -> tensor<4xf32>{ + %cst0 = stablehlo.constant dense<2.0> : tensor<4xf32> + %cst1 = stablehlo.constant dense<3.0> : tensor<4xf32> + %cst2 = stablehlo.constant dense<4.0> : tensor<4xf32> + %count = arith.constant 5: i32 + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %r, %c = scf.while(%arg0 = %init, %arg1 = %count) : (tensor<4xf32>, i32) -> (tensor<4xf32>, i32){ + %0 = arith.subi %arg1, %c1 : i32 + %1 = arith.cmpi eq, %0, %c0 : i32 + scf.condition(%1) %arg0, %0 : tensor<4xf32>, i32 + } do { + ^bb0(%base: tensor<4xf32>, %new_count: i32): + %add0 = stablehlo.add %cst0, %cst1 : tensor<4xf32> + %sub0 = stablehlo.subtract %cst2, %add0 : tensor<4xf32> + %new = stablehlo.add %base, %sub0 : tensor<4xf32> + scf.yield %new, %new_count : tensor<4xf32>, i32 + } + return %r : tensor<4xf32> +} + +// CHECK-LABEL: @scf_while_case1 +// CHECK: %[[v0:.+]]:2 = scf.while +// CHECK: scf.condition +// CHECK: ^bb0(%[[arg1:.+]]: tensor<4xf32>, %[[arg2:.+]]: i32) +// CHECK-NEXT: stablehlo.add +// CHECK-NEXT: stablehlo.subtract +// CHECK-NEXT: stablehlo.add +// CHECK-NEXT: scf.yield +// CHECK: return %[[v0]]#0 + +// ----- + +func.func @scf_while_case2(%init: tensor<4xf32>) -> tensor<4xf32>{ + %cst0 = stablehlo.constant dense<2.0> : tensor<4xf32> + %cst1 = stablehlo.constant dense<3.0> : tensor<4xf32> + %cst2 = stablehlo.constant dense<4.0> : tensor<4xf32> + %count = arith.constant 2: i32 + %c1 = arith.constant 1 : i32 + %r, %c = scf.while(%arg0 = %init, %arg1 = %count) : (tensor<4xf32>, i32) -> (tensor<4xf32>, i32){ + %0 = arith.subi %arg1, %c1 : i32 + %1 = arith.cmpi eq, %0, %c1 : i32 + scf.condition(%1) %arg0, %0 : tensor<4xf32>, i32 + } do { + ^bb0(%base: tensor<4xf32>, %new_count: i32): + %add0 = stablehlo.add %cst0, %cst1 : tensor<4xf32> + %sub0 = stablehlo.subtract %cst2, %add0 : tensor<4xf32> + %new = stablehlo.add %base, %sub0 : tensor<4xf32> + scf.yield %new, %new_count : tensor<4xf32>, i32 + } + return %r : tensor<4xf32> +} + +// CHECK-LABEL: @scf_while_case2 +// CHECK: %[[v0:.+]]:2 = scf.while +// CHECK: scf.condition +// CHECK: ^bb0(%[[arg1:.+]]: tensor<4xf32>, %[[arg2:.+]]: i32) +// CHECK-NEXT: %[[v1:.+]] = func.call @constant_subgraph() : () -> tensor<4xf32> +// CHECK-NEXT: %[[v2:.+]] = stablehlo.add %[[arg1]], %[[v1]] : tensor<4xf32> +// CHECK-NEXT: scf.yield %[[v2]], %[[arg2]] : tensor<4xf32>, i32 +// CHECK: return %[[v0]]#0 +// CHECK: func.func private @constant_subgraph() -> tensor<4xf32> attributes {plan.constant_foldable} + +// ----- + +func.func @scf_if_case1(%arg0: tensor<4xf32>, %arg1: i1) -> tensor<4xf32>{ + %c0 = stablehlo.constant dense<2.0> : tensor<4xf32> + %c1 = stablehlo.constant dense<2.0> : tensor<4xf32> + %r = scf.if %arg1 -> (tensor<4xf32>){ + %c2 = stablehlo.constant dense<3.0> : tensor<4xf32> + %add0 = stablehlo.add %c0, %c1 : tensor<4xf32> + %sub0 = stablehlo.subtract %c2, %add0 : tensor<4xf32> + scf.yield %sub0 : tensor<4xf32> + } else { + %add0 = stablehlo.add %arg0, %c0 : tensor<4xf32> + %sub0 = stablehlo.subtract %add0, %c1 : tensor<4xf32> + scf.yield %sub0 : tensor<4xf32> + } + return %r : tensor<4xf32> 
+}
+
+// CHECK-LABEL: @scf_if_case1
+// CHECK: %[[v0:.+]] = scf.if
+// CHECK-NEXT: func.call @constant_subgraph() : () -> tensor<4xf32>
+// CHECK-NEXT: scf.yield
+// CHECK-NEXT: } else {
+// CHECK-NEXT: stablehlo.add
+// CHECK-NEXT: stablehlo.subtract
+// CHECK-NEXT: scf.yield
+// CHECK: return %[[v0]] : tensor<4xf32>
+// CHECK: func.func private @constant_subgraph() -> tensor<4xf32> attributes {plan.constant_foldable}
+
+// -----
+
+func.func @scf_if_case2(%arg0: tensor<4xf32>) -> tensor<4xf32>{
+  %cond = arith.constant 1 : i1
+  %c0 = stablehlo.constant dense<2.0> : tensor<4xf32>
+  %c1 = stablehlo.constant dense<2.0> : tensor<4xf32>
+  %r = scf.if %cond -> (tensor<4xf32>){
+    %c2 = stablehlo.constant dense<3.0> : tensor<4xf32>
+    %add0 = stablehlo.add %c0, %c1 : tensor<4xf32>
+    %sub0 = stablehlo.subtract %c2, %add0 : tensor<4xf32>
+    scf.yield %sub0 : tensor<4xf32>
+  } else {
+    %add0 = stablehlo.add %arg0, %c0 : tensor<4xf32>
+    %sub0 = stablehlo.subtract %add0, %c1 : tensor<4xf32>
+    scf.yield %sub0 : tensor<4xf32>
+  }
+  return %r : tensor<4xf32>
+}
+
+// CHECK-LABEL: @scf_if_case2
+// CHECK-SAME: (%[[arg0:.+]]: tensor<4xf32>) -> tensor<4xf32>
+// CHECK-NEXT: %[[v0:.+]] = call @constant_subgraph() : () -> tensor<4xf32>
+// CHECK-NEXT: return %[[v0]] : tensor<4xf32>
+// CHECK: func.func private @constant_subgraph() -> tensor<4xf32> attributes {plan.constant_foldable}
+
+// -----
+
+func.func @foldable_terminator() -> (tensor<4xf32>, tensor<4xf32>){
+  %c0 = stablehlo.constant dense_resource<__elided__> : tensor<4xf32>
+  %c1 = stablehlo.constant dense_resource<__elided__> : tensor<4xf32>
+  %0 = stablehlo.add %c0, %c1 : tensor<4xf32>
+  %2 = stablehlo.subtract %0, %c1 : tensor<4xf32>
+  return %2, %0 : tensor<4xf32>, tensor<4xf32>
+}
+
+// CHECK-LABEL: @foldable_terminator
+// CHECK-NEXT: %[[v0:.+]]:2 = call @constant_subgraph() : () -> (tensor<4xf32>, tensor<4xf32>)
+// CHECK-NEXT: return %[[v0]]#1, %[[v0]]#0 : tensor<4xf32>, tensor<4xf32>
+// CHECK: func.func private @constant_subgraph() -> (tensor<4xf32>, tensor<4xf32>) attributes {plan.constant_foldable}
+
+// -----
+
+func.func @skip_outlining() -> (tensor<4xf32>, tensor<4xf32>){
+  %c0 = stablehlo.constant dense_resource<__elided__> : tensor<4xf32>
+  %c1 = stablehlo.constant dense_resource<__elided__> : tensor<4xf32>
+  %0 = stablehlo.add %c0, %c1 : tensor<4xf32>
+  %1 = "some.op"(%0) : (tensor<4xf32>) -> tensor<4xf32>
+  %2 = stablehlo.subtract %0, %c1 : tensor<4xf32>
+  return %1, %2 : tensor<4xf32>, tensor<4xf32>
+}
+
+// CHECK-LABEL: @skip_outlining
+// CHECK-NEXT: %[[cst:.+]] = stablehlo.constant
+// CHECK-NEXT: %[[v0:.+]] = call @constant_subgraph() : () -> tensor<4xf32>
+// CHECK-NEXT: %[[v1:.+]] = "some.op"(%[[v0]]) : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK-NEXT: %[[v2:.+]] = stablehlo.subtract %[[v0]], %[[cst]] : tensor<4xf32>
+// CHECK-NEXT: return %[[v1]], %[[v2]] : tensor<4xf32>, tensor<4xf32>
+// CHECK: func.func private @constant_subgraph() -> tensor<4xf32> attributes {plan.constant_foldable}
+
+// -----
+
+func.func @reduce_const_foldable_negative(%arg0: tensor<f32>) -> tensor<1x10xf32> {
+  %cst_0 = stablehlo.constant dense<1.0> : tensor<1x10x20xf32>
+  %cst = stablehlo.constant dense<0.0> : tensor<f32>
+  %cst_1 = stablehlo.constant dense<2.0> : tensor<f32>
+  %0 = stablehlo.reduce(%cst_0 init: %cst)
+    across dimensions = [2] : (tensor<1x10x20xf32>, tensor<f32>) -> tensor<1x10xf32>
+    reducer(%accum: tensor<f32>, %curr: tensor<f32>) {
+    %first = stablehlo.add %accum, %curr : tensor<f32>
+    %second = stablehlo.multiply %first, %cst_1 : tensor<f32>
+    %third =
stablehlo.subtract %second, %arg0 : tensor<f32>
+    stablehlo.return %third : tensor<f32>
+  }
+  return %0 : tensor<1x10xf32>
+}
+
+// CHECK-LABEL: @reduce_const_foldable_negative
+// CHECK-NEXT: stablehlo.constant
+// CHECK-NEXT: stablehlo.constant
+// CHECK-NEXT: stablehlo.constant
+// CHECK-NEXT: %[[v1:.+]] = stablehlo.reduce
+// CHECK: return %[[v1]] : tensor<1x10xf32>
+
+// -----
+
+func.func @reduce_const_foldable(%arg0: tensor<f32>) -> tensor<1x10xf32> {
+  %cst_0 = stablehlo.constant dense<1.0> : tensor<1x10x20xf32>
+  %add_0 = stablehlo.add %cst_0, %cst_0 : tensor<1x10x20xf32>
+  %cst = stablehlo.constant dense<0.0> : tensor<f32>
+  %cst_1 = stablehlo.constant dense<2.0> : tensor<f32>
+  %0 = stablehlo.reduce(%cst_0 init: %cst)
+    across dimensions = [2] : (tensor<1x10x20xf32>, tensor<f32>) -> tensor<1x10xf32>
+    reducer(%accum: tensor<f32>, %curr: tensor<f32>) {
+    %first = stablehlo.add %accum, %curr : tensor<f32>
+    %second = stablehlo.multiply %first, %cst_1 : tensor<f32>
+    stablehlo.return %second : tensor<f32>
+  }
+  return %0 : tensor<1x10xf32>
+}
+
+// CHECK-LABEL: @reduce_const_foldable
+// CHECK-NEXT: %[[v0:.+]] = call @constant_subgraph() : () -> tensor<1x10xf32>
+// CHECK-NEXT: return %[[v0]] : tensor<1x10xf32>
+// CHECK: func.func private @constant_subgraph() -> tensor<1x10xf32> attributes {plan.constant_foldable}
+
+// -----
+
+#map = affine_map<(d0)->(d0)>
+func.func @linalg_generic_neg(%arg0: tensor<10xf32>) -> tensor<10xf32> {
+  %empty = tensor.empty () : tensor<10xf32>
+  %0 = linalg.generic {
+    iterator_types = ["parallel"],
+    indexing_maps = [#map, #map]
+  } ins(%arg0: tensor<10xf32>) outs(%empty: tensor<10xf32>) {
+  ^bb0(%a: f32, %b: f32):
+    %r = arith.negf %a : f32
+    linalg.yield %r : f32
+  } -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @linalg_generic_neg
+// CHECK-SAME: (%[[arg0:.+]]: tensor<10xf32>) -> tensor<10xf32>
+// CHECK-NEXT: %[[v0:.+]] = tensor.empty() : tensor<10xf32>
+// CHECK-NEXT: %[[v1:.+]] = linalg.generic {{.*}} ins(%[[arg0]] : tensor<10xf32>) outs(%[[v0]] : tensor<10xf32>)
+// CHECK: return %[[v1]] : tensor<10xf32>
+
+// -----
+
+#map = affine_map<(d0)->(d0)>
+func.func @linalg_generic() -> tensor<10xf32> {
+  %empty = tensor.empty () : tensor<10xf32>
+  %cst = stablehlo.constant dense<4.0> : tensor<10xf32>
+  %cst_1 = stablehlo.constant dense<4.0> : tensor<10xf32>
+  %add = stablehlo.add %cst, %cst_1 : tensor<10xf32>
+  %0 = linalg.generic {
+    iterator_types = ["parallel"],
+    indexing_maps = [#map, #map]
+  } ins(%add: tensor<10xf32>) outs(%empty: tensor<10xf32>) {
+  ^bb0(%a: f32, %b: f32):
+    %r = arith.negf %a : f32
+    linalg.yield %r : f32
+  } -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @linalg_generic
+// CHECK-NEXT: %[[v0:.+]] = call @constant_subgraph() : () -> tensor<10xf32>
+// CHECK-NEXT: return %[[v0]] : tensor<10xf32>
+// CHECK: func.func private @constant_subgraph() -> tensor<10xf32> attributes {plan.constant_foldable}
\ No newline at end of file
diff --git a/mlir-tensorrt/compiler/test/Dialect/StableHloExt/tensor-value-bounds-analysis.mlir b/mlir-tensorrt/compiler/test/Dialect/StableHloExt/tensor-value-bounds-analysis.mlir
new file mode 100644
index 000000000..d5d800656
--- /dev/null
+++ b/mlir-tensorrt/compiler/test/Dialect/StableHloExt/tensor-value-bounds-analysis.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-tensorrt-opt %s -test-tensor-value-bounds-analysis -split-input-file 2>&1 | FileCheck %s
+
+func.func @stablehlo_convert_bool_to_int() -> tensor<2xi32> {
+  %cst = arith.constant dense<[true, false]> : tensor<2xi1>
+  %0 = stablehlo.convert %cst {tag = "convert"} : (tensor<2xi1>) -> tensor<2xi32>
+  return %0 : tensor<2xi32>
+}
+
+// CHECK-LABEL: func stablehlo_convert_bool_to_int:
+// CHECK: test_tag: convert:
+// CHECK-NEXT: operand #0: <[-1, -1], [0, 0]>
+// CHECK-NEXT: result #0: <[1, 1], [0, 0]>
+
+// -----
+
+func.func @stablehlo_convert_int_to_bool() -> tensor<2xi1> {
+  %cst = arith.constant dense<[0, -53]> : tensor<2xi32>
+  %0 = stablehlo.convert %cst {tag = "convert"} : (tensor<2xi32>) -> tensor<2xi1>
+  return %0 : tensor<2xi1>
+}
+
+// CHECK-LABEL: func stablehlo_convert_int_to_bool:
+// CHECK: test_tag: convert:
+// CHECK-NEXT: operand #0: <[0, 0], [-53, -53]>
+// CHECK-NEXT: result #0: <[0, 0], [-1, -1]>
+
+// -----
+
+func.func @stablehlo_convert_float_to_int() -> tensor<2xi32> {
+  %cst = arith.constant dense<[1.0, 2.0]> : tensor<2xf32>
+  %0 = stablehlo.convert %cst {tag = "convert"} : (tensor<2xf32>) -> tensor<2xi32>
+  return %0 : tensor<2xi32>
+}
+
+// CHECK-LABEL: func stablehlo_convert_float_to_int:
+// CHECK: test_tag: convert:
+// CHECK-NEXT: operand #0: <>
+// CHECK-NEXT: result #0: <>
+
+// -----
+
+// The semantics of signed-to-unsigned conversion are currently undefined in
+// the StableHLO spec.
+
+func.func @stablehlo_convert_int_to_uint() -> tensor<2xui32> {
+  %cst = arith.constant dense<[2, -1]> : tensor<2xi32>
+  %0 = stablehlo.convert %cst {tag = "convert"} : (tensor<2xi32>) -> tensor<2xui32>
+  return %0 : tensor<2xui32>
+}
+
+// CHECK-LABEL: func stablehlo_convert_int_to_uint:
+// CHECK: test_tag: convert:
+// CHECK-NEXT: operand #0: <[2, 2], [-1, -1]>
+// CHECK-NEXT: result #0: <>
diff --git a/mlir-tensorrt/compiler/test/Pipelines/TensorRTClustering/tensorrt-clustering.mlir b/mlir-tensorrt/compiler/test/Pipelines/TensorRTClustering/tensorrt-clustering.mlir
index d73326b07..56a5e7d14 100644
--- a/mlir-tensorrt/compiler/test/Pipelines/TensorRTClustering/tensorrt-clustering.mlir
+++ b/mlir-tensorrt/compiler/test/Pipelines/TensorRTClustering/tensorrt-clustering.mlir
@@ -102,4 +102,42 @@ func.func @reorder_engine_arguments(%arg0: tensor<2x3x4xf32>, %arg1: tensor<4x2x
 // CHECK-NEXT: return %[[v0:.+]] : tensor<2x3x?xf32>
 // CHECK-LABEL: tensorrt.module @trt_engines
 // CHECK-LABEL: func.func @tensorrt_cluster
-// CHECK-SAME: (%[[arg2:.+]]: tensor<2x3x4xf32>, %[[arg3:.+]]: tensor<4x2xf32>) -> tensor<2x3x?xf32>
\ No newline at end of file
+// CHECK-SAME: (%[[arg2:.+]]: tensor<2x3x4xf32>, %[[arg3:.+]]: tensor<4x2xf32>) -> tensor<2x3x?xf32>
+
+// -----
+
+func.func @maintain_output_order() -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+  %cst_f32 = tensorrt.constant dense<0.000000e+00> : tensor<f32>
+  %cst_i32 = tensorrt.constant dense<[1, 2]> : tensor<2xi32>
+  %0 = tensorrt.broadcast %cst_f32 broadcast_dims<> shape(%cst_i32 : tensor<2xi32>) : tensor<f32> to tensor<?x?xf32>
+  %cst_f32_0 = tensorrt.constant dense<1.000000e+00> : tensor<f32>
+  %cst_i32_1 = tensorrt.constant dense<[3, 4]> : tensor<2xi32>
+  %1 = tensorrt.broadcast %cst_f32_0 broadcast_dims<> shape(%cst_i32_1 : tensor<2xi32>) : tensor<f32> to tensor<?x?xf32>
+  %cst_f32_2 = tensorrt.constant dense<0.000000e+00> : tensor<f32>
+  %cst_i32_3 = tensorrt.constant dense<[5, 6]> : tensor<2xi32>
+  %2 = tensorrt.broadcast %cst_f32_2 broadcast_dims<> shape(%cst_i32_3 : tensor<2xi32>) : tensor<f32> to tensor<?x?xf32>
+  %cst_f32_4 = tensorrt.constant dense<1.000000e+00> : tensor<f32>
+  %cst_i32_5 = tensorrt.constant dense<[7, 8]> : tensor<2xi32>
+  %3 = tensorrt.broadcast %cst_f32_4 broadcast_dims<> shape(%cst_i32_5 : tensor<2xi32>) : tensor<f32> to tensor<?x?xf32>
+  return %0, %1, %2, %3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// CHECK-LABEL: @maintain_output_order
+// CHECK-SAME:
() -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-NEXT: %[[v0:.+]]:4 = tensorrt.call_alloc @trt_engines::@tensorrt_cluster()
+// CHECK-NEXT: return %[[v0]]#0, %[[v0]]#1, %[[v0]]#2, %[[v0]]#3
+// CHECK-LABEL: tensorrt.module @trt_engines
+// CHECK-LABEL: func.func @tensorrt_cluster
+// CHECK-NEXT: %[[v0:.+]] = tensorrt.constant dense<0.000000e+00> : tensor<f32>
+// CHECK-NEXT: %[[v1:.+]] = tensorrt.constant dense<[1, 2]> : tensor<2xi32>
+// CHECK-NEXT: %[[v2:.+]] = tensorrt.broadcast %[[v0]] broadcast_dims<> shape(%[[v1]] : tensor<2xi32>)
+// CHECK-NEXT: %[[v3:.+]] = tensorrt.constant dense<1.000000e+00> : tensor<f32>
+// CHECK-NEXT: %[[v4:.+]] = tensorrt.constant dense<[3, 4]> : tensor<2xi32>
+// CHECK-NEXT: %[[v5:.+]] = tensorrt.broadcast %[[v3]] broadcast_dims<> shape(%[[v4]] : tensor<2xi32>)
+// CHECK-NEXT: %[[v6:.+]] = tensorrt.constant dense<0.000000e+00> : tensor<f32>
+// CHECK-NEXT: %[[v7:.+]] = tensorrt.constant dense<[5, 6]> : tensor<2xi32>
+// CHECK-NEXT: %[[v8:.+]] = tensorrt.broadcast %[[v6]] broadcast_dims<> shape(%[[v7]] : tensor<2xi32>)
+// CHECK-NEXT: %[[v9:.+]] = tensorrt.constant dense<1.000000e+00> : tensor<f32>
+// CHECK-NEXT: %[[v10:.+]] = tensorrt.constant dense<[7, 8]> : tensor<2xi32>
+// CHECK-NEXT: %[[v11:.+]] = tensorrt.broadcast %[[v9]] broadcast_dims<> shape(%[[v10]] : tensor<2xi32>)
+// CHECK-NEXT: return %[[v2]], %[[v5]], %[[v8]], %[[v11]]
\ No newline at end of file
diff --git a/mlir-tensorrt/compiler/test/lit.cfg.py b/mlir-tensorrt/compiler/test/lit.cfg.py
index 56504d3c9..240fdc4f3 100644
--- a/mlir-tensorrt/compiler/test/lit.cfg.py
+++ b/mlir-tensorrt/compiler/test/lit.cfg.py
@@ -31,7 +31,7 @@ config.test_source_root = os.path.dirname(__file__)
 
 config.gpu_tools_script = os.path.join(
     config.test_source_root,
-    "../../python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py",
+    "../../integrations/python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py",
 )
diff --git a/mlir-tensorrt/executor/include/mlir-executor/Transforms/Clustering/Clustering.h b/mlir-tensorrt/executor/include/mlir-executor/Transforms/Clustering/Clustering.h
index ae9a3fbc0..d71f490e4 100644
--- a/mlir-tensorrt/executor/include/mlir-executor/Transforms/Clustering/Clustering.h
+++ b/mlir-tensorrt/executor/include/mlir-executor/Transforms/Clustering/Clustering.h
@@ -261,14 +261,24 @@ analyzeAndClusterOperations(Operation *op, const ClusteringOpts &opts);
 
 using ClusterRegionOpBuilderFunc = std::function<Operation *(
     RewriterBase &rewriter, Location loc, TypeRange types, Attribute target)>;
 
+/// When a region op is created from a cluster, the values yielded from the
+/// region op are the cluster's values that are used outside the cluster. The
+/// outside uses are collected with `op->getUses()`, whose order is not
+/// deterministic. However, a particular output order is sometimes desired.
+/// `ReorderRegionOpYieldValues` is a handle to a function that reorders the
+/// yielded values (and their types) in the desired way.
+using ReorderRegionOpYieldValues = std::function<void(
+    SmallVectorImpl<Value> &yieldValues, SmallVectorImpl<Type> &yieldTypes)>;
+
 /// Creates a "region op" from the given cluster. See above for the
 /// definition of "region op". When an operation located outside of the cluster
 /// uses an SSA value produced by an operation in the cluster, the use is
 /// replaced by the result of the region op. It is assumed that the root is
 /// located at the back of the cluster.
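+/// A minimal sketch of a reorder callback (hypothetical; any permutation works
+/// as long as it is applied to the values and types consistently):
+///
+///   ReorderRegionOpYieldValues reorder =
+///       [](SmallVectorImpl<Value> &vals, SmallVectorImpl<Type> &types) {
+///         std::reverse(vals.begin(), vals.end());
+///         std::reverse(types.begin(), types.end());
+///       };
+///   createRegionOpFromCluster(cluster, rewriter, builderFn, reorder);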
-Operation *createRegionOpFromCluster(const Cluster &cluster,
-                                     RewriterBase &rewriter,
-                                     ClusterRegionOpBuilderFunc createRegionOp);
+Operation *createRegionOpFromCluster(
+    const Cluster &cluster, RewriterBase &rewriter,
+    ClusterRegionOpBuilderFunc createRegionOp,
+    ReorderRegionOpYieldValues reorderYieldValues = nullptr);
 
 template <typename OpType>
 OpType createRegionOpFromCluster(const Cluster &cluster,
diff --git a/mlir-tensorrt/executor/lib/Transforms/Clustering/Clustering.cpp b/mlir-tensorrt/executor/lib/Transforms/Clustering/Clustering.cpp
index 9b4344e54..2e0ed9d01 100644
--- a/mlir-tensorrt/executor/lib/Transforms/Clustering/Clustering.cpp
+++ b/mlir-tensorrt/executor/lib/Transforms/Clustering/Clustering.cpp
@@ -393,7 +393,8 @@ mlir::analyzeAndClusterOperations(Operation *op, const ClusteringOpts &opts) {
 Operation *
 mlir::createRegionOpFromCluster(const Cluster &cluster, RewriterBase &rewriter,
-                                ClusterRegionOpBuilderFunc createRegionOp) {
+                                ClusterRegionOpBuilderFunc createRegionOp,
+                                ReorderRegionOpYieldValues reorderYieldValues) {
   // insert the region to the last Op to because of dominance property
   Operation *insertionOp = cluster.getRoot();
@@ -419,6 +420,9 @@ mlir::createRegionOpFromCluster(const Cluster &cluster, RewriterBase &rewriter,
     }
   }
 
+  if (reorderYieldValues)
+    reorderYieldValues(yieldValues, yieldTypes);
+
   rewriter.setInsertionPoint(insertionOp);
   Operation *regionOp = createRegionOp(rewriter, insertionOp->getLoc(),
                                        yieldTypes, cluster.getTarget());
diff --git a/mlir-tensorrt/executor/test/lit.site.cfg.py.in b/mlir-tensorrt/executor/test/lit.site.cfg.py.in
index 422022a22..9f27e81c5 100644
--- a/mlir-tensorrt/executor/test/lit.site.cfg.py.in
+++ b/mlir-tensorrt/executor/test/lit.site.cfg.py.in
@@ -19,7 +19,7 @@ config.enable_assertions = @LLVM_ENABLE_ASSERTIONS@
 
 config.gpu_tools_script = os.path.join(
     "@MLIR_EXECUTOR_SOURCE_DIR@",
-    "../python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py",
+    "../integrations/python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py",
 )
 
 import lit.llvm
diff --git a/mlir-tensorrt/integrations/CMakeLists.txt b/mlir-tensorrt/integrations/CMakeLists.txt
new file mode 100644
index 000000000..8e5f91a37
--- /dev/null
+++ b/mlir-tensorrt/integrations/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(python)
diff --git a/mlir-tensorrt/python/CMakeLists.txt b/mlir-tensorrt/integrations/python/CMakeLists.txt
similarity index 100%
rename from mlir-tensorrt/python/CMakeLists.txt
rename to mlir-tensorrt/integrations/python/CMakeLists.txt
diff --git a/mlir-tensorrt/python/CompilerPackage.cmake b/mlir-tensorrt/integrations/python/CompilerPackage.cmake
similarity index 100%
rename from mlir-tensorrt/python/CompilerPackage.cmake
rename to mlir-tensorrt/integrations/python/CompilerPackage.cmake
diff --git a/mlir-tensorrt/python/CompilerPackageUtils.cmake b/mlir-tensorrt/integrations/python/CompilerPackageUtils.cmake
similarity index 100%
rename from mlir-tensorrt/python/CompilerPackageUtils.cmake
rename to mlir-tensorrt/integrations/python/CompilerPackageUtils.cmake
diff --git a/mlir-tensorrt/python/bindings/CPyBindInterop.h b/mlir-tensorrt/integrations/python/bindings/CPyBindInterop.h
similarity index 100%
rename from mlir-tensorrt/python/bindings/CPyBindInterop.h
rename to mlir-tensorrt/integrations/python/bindings/CPyBindInterop.h
diff --git a/mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp b/mlir-tensorrt/integrations/python/bindings/Compiler/CompilerPyBind.cpp
similarity index 100%
rename from
mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp rename to mlir-tensorrt/integrations/python/bindings/Compiler/CompilerPyBind.cpp diff --git a/mlir-tensorrt/python/bindings/Compiler/SiteInitializer.cpp b/mlir-tensorrt/integrations/python/bindings/Compiler/SiteInitializer.cpp similarity index 100% rename from mlir-tensorrt/python/bindings/Compiler/SiteInitializer.cpp rename to mlir-tensorrt/integrations/python/bindings/Compiler/SiteInitializer.cpp diff --git a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp b/mlir-tensorrt/integrations/python/bindings/Runtime/RuntimePyBind.cpp similarity index 100% rename from mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp rename to mlir-tensorrt/integrations/python/bindings/Runtime/RuntimePyBind.cpp diff --git a/mlir-tensorrt/python/bindings/Utils.h b/mlir-tensorrt/integrations/python/bindings/Utils.h similarity index 100% rename from mlir-tensorrt/python/bindings/Utils.h rename to mlir-tensorrt/integrations/python/bindings/Utils.h diff --git a/mlir-tensorrt/python/mlir_tensorrt_compiler/README.md b/mlir-tensorrt/integrations/python/mlir_tensorrt_compiler/README.md similarity index 100% rename from mlir-tensorrt/python/mlir_tensorrt_compiler/README.md rename to mlir-tensorrt/integrations/python/mlir_tensorrt_compiler/README.md diff --git a/mlir-tensorrt/python/mlir_tensorrt_compiler/mlir_tensorrt/compiler/_mlir_libs/_api.pyi b/mlir-tensorrt/integrations/python/mlir_tensorrt_compiler/mlir_tensorrt/compiler/_mlir_libs/_api.pyi similarity index 100% rename from mlir-tensorrt/python/mlir_tensorrt_compiler/mlir_tensorrt/compiler/_mlir_libs/_api.pyi rename to mlir-tensorrt/integrations/python/mlir_tensorrt_compiler/mlir_tensorrt/compiler/_mlir_libs/_api.pyi diff --git a/mlir-tensorrt/python/mlir_tensorrt_compiler/mlir_tensorrt/compiler/api.py b/mlir-tensorrt/integrations/python/mlir_tensorrt_compiler/mlir_tensorrt/compiler/api.py similarity index 100% rename from mlir-tensorrt/python/mlir_tensorrt_compiler/mlir_tensorrt/compiler/api.py rename to mlir-tensorrt/integrations/python/mlir_tensorrt_compiler/mlir_tensorrt/compiler/api.py diff --git a/mlir-tensorrt/python/mlir_tensorrt_compiler/mlir_tensorrt/compiler/torch_bridge.py b/mlir-tensorrt/integrations/python/mlir_tensorrt_compiler/mlir_tensorrt/compiler/torch_bridge.py similarity index 100% rename from mlir-tensorrt/python/mlir_tensorrt_compiler/mlir_tensorrt/compiler/torch_bridge.py rename to mlir-tensorrt/integrations/python/mlir_tensorrt_compiler/mlir_tensorrt/compiler/torch_bridge.py diff --git a/mlir-tensorrt/python/mlir_tensorrt_compiler/pyproject.toml b/mlir-tensorrt/integrations/python/mlir_tensorrt_compiler/pyproject.toml similarity index 100% rename from mlir-tensorrt/python/mlir_tensorrt_compiler/pyproject.toml rename to mlir-tensorrt/integrations/python/mlir_tensorrt_compiler/pyproject.toml diff --git a/mlir-tensorrt/python/mlir_tensorrt_compiler/setup.py b/mlir-tensorrt/integrations/python/mlir_tensorrt_compiler/setup.py similarity index 100% rename from mlir-tensorrt/python/mlir_tensorrt_compiler/setup.py rename to mlir-tensorrt/integrations/python/mlir_tensorrt_compiler/setup.py diff --git a/mlir-tensorrt/python/mlir_tensorrt_runtime/README.md b/mlir-tensorrt/integrations/python/mlir_tensorrt_runtime/README.md similarity index 100% rename from mlir-tensorrt/python/mlir_tensorrt_runtime/README.md rename to mlir-tensorrt/integrations/python/mlir_tensorrt_runtime/README.md diff --git a/mlir-tensorrt/python/mlir_tensorrt_runtime/mlir_tensorrt/runtime/_mlir_libs/_api.pyi 
b/mlir-tensorrt/integrations/python/mlir_tensorrt_runtime/mlir_tensorrt/runtime/_mlir_libs/_api.pyi similarity index 100% rename from mlir-tensorrt/python/mlir_tensorrt_runtime/mlir_tensorrt/runtime/_mlir_libs/_api.pyi rename to mlir-tensorrt/integrations/python/mlir_tensorrt_runtime/mlir_tensorrt/runtime/_mlir_libs/_api.pyi diff --git a/mlir-tensorrt/python/mlir_tensorrt_runtime/mlir_tensorrt/runtime/api.py b/mlir-tensorrt/integrations/python/mlir_tensorrt_runtime/mlir_tensorrt/runtime/api.py similarity index 100% rename from mlir-tensorrt/python/mlir_tensorrt_runtime/mlir_tensorrt/runtime/api.py rename to mlir-tensorrt/integrations/python/mlir_tensorrt_runtime/mlir_tensorrt/runtime/api.py diff --git a/mlir-tensorrt/python/mlir_tensorrt_runtime/pyproject.toml b/mlir-tensorrt/integrations/python/mlir_tensorrt_runtime/pyproject.toml similarity index 100% rename from mlir-tensorrt/python/mlir_tensorrt_runtime/pyproject.toml rename to mlir-tensorrt/integrations/python/mlir_tensorrt_runtime/pyproject.toml diff --git a/mlir-tensorrt/python/mlir_tensorrt_runtime/setup.py b/mlir-tensorrt/integrations/python/mlir_tensorrt_runtime/setup.py similarity index 100% rename from mlir-tensorrt/python/mlir_tensorrt_runtime/setup.py rename to mlir-tensorrt/integrations/python/mlir_tensorrt_runtime/setup.py diff --git a/mlir-tensorrt/python/mlir_tensorrt_tools/README.md b/mlir-tensorrt/integrations/python/mlir_tensorrt_tools/README.md similarity index 100% rename from mlir-tensorrt/python/mlir_tensorrt_tools/README.md rename to mlir-tensorrt/integrations/python/mlir_tensorrt_tools/README.md diff --git a/mlir-tensorrt/python/mlir_tensorrt_tools/mlir_tensorrt/tools/__init__.py b/mlir-tensorrt/integrations/python/mlir_tensorrt_tools/mlir_tensorrt/tools/__init__.py similarity index 100% rename from mlir-tensorrt/python/mlir_tensorrt_tools/mlir_tensorrt/tools/__init__.py rename to mlir-tensorrt/integrations/python/mlir_tensorrt_tools/mlir_tensorrt/tools/__init__.py diff --git a/mlir-tensorrt/python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py b/mlir-tensorrt/integrations/python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py similarity index 100% rename from mlir-tensorrt/python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py rename to mlir-tensorrt/integrations/python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py diff --git a/mlir-tensorrt/python/mlir_tensorrt_tools/pyproject.toml b/mlir-tensorrt/integrations/python/mlir_tensorrt_tools/pyproject.toml similarity index 100% rename from mlir-tensorrt/python/mlir_tensorrt_tools/pyproject.toml rename to mlir-tensorrt/integrations/python/mlir_tensorrt_tools/pyproject.toml diff --git a/mlir-tensorrt/python/mlir_tensorrt_tools/setup.py b/mlir-tensorrt/integrations/python/mlir_tensorrt_tools/setup.py similarity index 100% rename from mlir-tensorrt/python/mlir_tensorrt_tools/setup.py rename to mlir-tensorrt/integrations/python/mlir_tensorrt_tools/setup.py diff --git a/mlir-tensorrt/python/requirements-dev.txt b/mlir-tensorrt/integrations/python/requirements-dev.txt similarity index 100% rename from mlir-tensorrt/python/requirements-dev.txt rename to mlir-tensorrt/integrations/python/requirements-dev.txt diff --git a/mlir-tensorrt/python/requirements.txt b/mlir-tensorrt/integrations/python/requirements.txt similarity index 100% rename from mlir-tensorrt/python/requirements.txt rename to mlir-tensorrt/integrations/python/requirements.txt diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp 
b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp
index 7fa53a845..1f934a7e0 100644
--- a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp
+++ b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp
@@ -38,6 +38,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Parser/Parser.h"
+#include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -2749,6 +2750,23 @@ class TensorRTDialectOpAsmInterface : public OpAsmDialectInterface {
     return AliasResult::NoAlias;
   }
 };
+
+//===----------------------------------------------------------------------===//
+// TensorRTDialectInlinerInterface
+//===----------------------------------------------------------------------===//
+
+struct TensorRTDialectInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
+                       IRMapping &valueMapping) const final {
+    return true;
+  }
+  // Pure operations in the TensorRT dialect are always legal to inline.
+  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
+                       IRMapping &valueMapping) const final {
+    return isPure(op);
+  }
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -2778,6 +2796,7 @@ void TensorRTDialect::initialize() {
   >();
 
   addInterface<TensorRTDialectOpAsmInterface>();
+  addInterface<TensorRTDialectInlinerInterface>();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/inline.mlir b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/inline.mlir
new file mode 100644
index 000000000..47c62a79e
--- /dev/null
+++ b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/inline.mlir
@@ -0,0 +1,56 @@
+// RUN: tensorrt-opt -inline -split-input-file %s | FileCheck %s
+
+func.func @outlined() -> tensor<2x2xf32> {
+  %0 = tensorrt.constant dense<2.0> : tensor<2x2xf32>
+  %1 = tensorrt.element_wise <kSUM>(%0, %0 : tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %1 : tensor<2x2xf32>
+}
+
+func.func @valid_inline() -> tensor<2x2xf32> {
+  %0 = call @outlined() : () -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// CHECK-LABEL: @valid_inline
+// CHECK-NEXT: %[[v0:.+]] = tensorrt.constant
+// CHECK-NEXT: %[[v1:.+]] = tensorrt.constant
+// CHECK-NEXT: %[[v2:.+]] = tensorrt.element_wise <kSUM>(%[[v0]], %[[v1]] : {{.*}})
+// CHECK-NEXT: return %[[v2]] : tensor<2x2xf32>
+
+// -----
+
+tensorrt.module @engines {
+  func.func @trt_callee(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+    return %arg0: tensor<?xf32>
+  }
+}
+
+func.func @invalid_inline(%arg0: tensor<10xf32>) -> tensor<10xf32> {
+  %0 = tensor.empty() : tensor<10xf32>
+  %1 = tensorrt.call @engines::@trt_callee(%arg0 : tensor<10xf32>) outs(%0: tensor<10xf32>)
+    -> tensor<10xf32>
+  return %1 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @invalid_inline
+// CHECK-NEXT: tensor.empty()
+// CHECK-NEXT: %[[v1:.+]] = tensorrt.call
+// CHECK-NEXT: return %[[v1]] : tensor<10xf32>
+
+// -----
+
+func.func @trt_callee(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  return %arg0: tensor<?xf32>
+}
+
+func.func @invalid_inline_same_sym_table(%arg0: tensor<10xf32>) -> tensor<10xf32> {
+  %0 = tensor.empty() : tensor<10xf32>
+  %1 = tensorrt.call @trt_callee(%arg0 : tensor<10xf32>) outs(%0: tensor<10xf32>)
+    -> tensor<10xf32>
+  return %1 : tensor<10xf32>
+}
+
+// CHECK-LABEL: @invalid_inline_same_sym_table
+// CHECK-NEXT: tensor.empty()
+// CHECK-NEXT: %[[v1:.+]] = tensorrt.call
+// CHECK-NEXT: return
%[[v1]] : tensor<10xf32> \ No newline at end of file diff --git a/mlir-tensorrt/tensorrt/test/lit.site.cfg.py.in b/mlir-tensorrt/tensorrt/test/lit.site.cfg.py.in index 4a857f75a..28be315d0 100644 --- a/mlir-tensorrt/tensorrt/test/lit.site.cfg.py.in +++ b/mlir-tensorrt/tensorrt/test/lit.site.cfg.py.in @@ -26,11 +26,11 @@ config.mlir_tensorrt_compile_time_version = "@MLIR_TRT_TENSORRT_VERSION@" config.gpu_tools_package_path = os.path.join( "@MLIR_TENSORRT_DIALECT_SOURCE_DIR@", - "../python/mlir_tensorrt_tools" + "../integrations/python/mlir_tensorrt_tools" ) config.gpu_tools_script = os.path.join( "@MLIR_TENSORRT_DIALECT_SOURCE_DIR@", - "../python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py", + "../integrations/python/mlir_tensorrt_tools/mlir_tensorrt/tools/gpu_tools.py", ) def load_gpu_tools_module():