Skip to content

Commit 23851b8

Browse files
committed
Merge remote-tracking branch 'origin/move_internal_changes' into sam2
2 parents 43bcad3 + b7002a6 commit 23851b8

File tree

81 files changed

+2930
-1234
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

81 files changed

+2930
-1234
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/Passes.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ include "mlir/Pass/PassBase.td"
2525
//===----------------------------------------------------------------------===//
2626
// StablehloToTensorRT
2727
//===----------------------------------------------------------------------===//
28+
2829
#ifdef MLIR_TENSORRT_ENABLE_HLO
2930
def ConvertStablehloToTensorRTPass : Pass<"convert-stablehlo-to-tensorrt"> {
3031
let summary = "Convert Stable HLO dialect to TensorRT dialect";
@@ -44,7 +45,30 @@ def ConvertStablehloToTensorRTPass : Pass<"convert-stablehlo-to-tensorrt"> {
4445
"target TensorRT version for conversion">
4546
];
4647
}
48+
#endif // MLIR_TENSORRT_ENABLE_HLO
49+
50+
//===----------------------------------------------------------------------===//
51+
// ChloToStableHloExt
52+
//===----------------------------------------------------------------------===//
53+
54+
#ifdef MLIR_TENSORRT_ENABLE_HLO
55+
def ConvertChloToStableHloExtPass : Pass<"convert-chlo-to-stablehlo-ext"> {
56+
let summary = "Convert specific CHLO operations to stablehlo";
57+
let description = [{
58+
This pass converts a CHLO operations to StableHlo while also allowing
59+
for some CHLO operations to be preserved (see options).
60+
}];
61+
let dependentDialects = [
62+
"::mlir::stablehlo::StablehloDialect"
63+
];
4764

65+
let options = [
66+
Option<"preserveErf", "preserve-erf", "bool", "true",
67+
"do not convert chlo.erf ops">,
68+
Option<"preserveTopK", "preserve-topk", "bool", "true",
69+
"do not convert chlo.topk ops">,
70+
];
71+
}
4872
#endif // MLIR_TENSORRT_ENABLE_HLO
4973

5074
//===----------------------------------------------------------------------===//

mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/StablehloToTensorRT/StablehloToTensorRT.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#define MLIR_TENSORRT_CONVERSION_HLOTOTENSORRT_HLOTOTENSORRT_H
2626

2727
#include "mlir-tensorrt/Conversion/TensorRTCommon/ConvertToTensorRTCommon.h"
28+
#include "mlir-tensorrt/Dialect/StableHloExt/Utils/GatherScatterUtils.h"
2829
#include "mlir/IR/PatternMatch.h"
2930

3031
namespace mlir {
@@ -33,7 +34,8 @@ class ConversionTarget;
3334
// Collection of rewrite patterns for lowering of Stable HLO to TensorRT
3435
// dialect.
3536
void populateStablehloToTensorRtConversionPattern(
36-
TensorRTTypeConverter &typeConverter, RewritePatternSet &patterns);
37+
TensorRTTypeConverter &typeConverter, RewritePatternSet &patterns,
38+
ShapeInfoCallbacks shapeInfoCallbacks = {});
3739

3840
/// Populate patterns for convert Chlo ops to TensorRT ops.
3941
void populateChloToTensorRtLegalityAndPatterns(

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/Analysis/BoundsAnalysis.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ class BoundsArray {
7474
static BoundsArray fromIntegerValueBounds(unsigned bitwidth,
7575
ArrayRef<int64_t> min,
7676
ArrayRef<int64_t> max);
77+
static BoundsArray fromIntegerValueBounds(ArrayRef<llvm::APInt> min,
78+
ArrayRef<llvm::APInt> max);
7779

7880
/// For the given tensor-typed value, return the most conservative bounds for
7981
/// the shape of `v`. For each unknown dimension of the shape of `v` the

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanAttributes.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ def Plan_HostClusterKindAttr : Plan_Attr<"HostClusterKind", "host_cluster",
2424
}
2525

2626

27-
def Plan_BoundsAttr : Plan_Attr<"Bounds", "bounds">{
27+
def Plan_BoundsAttr : Plan_Attr<"Bounds", "bounds", [
28+
DeclareAttrInterfaceMethods<TensorBoundsAttrInterface>]>{
2829
let parameters = (ins
2930
EnumParameter<Plan_BoundsKind>:$kind,
3031
OptionalParameter<"DenseI64ArrayAttr">:$min_shape,
@@ -46,17 +47,17 @@ def Plan_BoundsAttr : Plan_Attr<"Bounds", "bounds">{
4647

4748
let extraClassDeclaration = [{
4849
/// Returns true if this bounds is for shape dimension extents.
49-
bool isShapeBound() {
50+
bool isShapeBound() const {
5051
return getKind() == BoundsKind::Shape;
5152
}
5253

5354
/// Returns true if this bounds is a 'none' bounds kind.
54-
bool isNone() {
55+
bool isNone() const {
5556
return getKind() == BoundsKind::None;
5657
}
5758

5859
/// Returns true if this bounds is for values of a tensor.
59-
bool isValueBound() {
60+
bool isValueBound() const {
6061
return getKind() == BoundsKind::Value;
6162
}
6263

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanDialect.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ def Plan_Dialect : Dialect {
1515

1616
}];
1717
let cppNamespace = "::mlir::plan";
18+
19+
let hasRegionArgAttrVerify = 1;
1820

1921
let extraClassDeclaration = [{
2022

@@ -72,6 +74,18 @@ def Plan_Dialect : Dialect {
7274
(addExtensionOperation<Ops>(), ...);
7375
}
7476

77+
/// Return the name of the function arg/result attributes that encode
78+
/// host tensor value bounds. It should have a type `plan::BoundsAttr`.
79+
static StringRef getValueBoundsAttrName() {
80+
return "plan.value_bounds";
81+
}
82+
83+
/// Return the name of the function arg/result attributes that encode
84+
/// the shape bounds. It should have a type `plan::BoundsAttr`.
85+
static StringRef getShapeBoundsAttrName() {
86+
return "plan.shape_profile";
87+
}
88+
7589
private:
7690
::llvm::StringMap<AttrParsingHook> attrParsingHooks;
7791
::llvm::DenseMap<::mlir::TypeID, AttrPrintingHook> attrPrintingHooks;

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/Plan/IR/PlanInterfaces.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,34 @@
33

44
include "mlir/IR/OpBase.td"
55

6+
//===----------------------------------------------------------------------===//
7+
// TensorBoundsAttrInterface
8+
//===----------------------------------------------------------------------===//
9+
10+
def TensorBoundsAttrInterface : AttrInterface<"TensorBoundsAttrInterface"> {
11+
let cppNamespace = "::mlir::plan";
12+
let methods = [
13+
InterfaceMethod<
14+
/*desc=*/"Return the shape bounds associated with the attribute",
15+
/*retTy=*/"LogicalResult",
16+
"getShapeBounds",
17+
(ins "llvm::SmallVectorImpl<int64_t> &":$min,
18+
"llvm::SmallVectorImpl<int64_t> &":$max),
19+
/*body=*/"",
20+
""
21+
>,
22+
InterfaceMethod<
23+
/*desc=*/"Return the integer value bounds associated with the attribute",
24+
/*retTy=*/"LogicalResult",
25+
"getIntegerValueBounds",
26+
(ins "llvm::SmallVectorImpl<llvm::APInt> &":$min,
27+
"llvm::SmallVectorImpl<llvm::APInt> &":$max),
28+
/*body=*/"",
29+
""
30+
>
31+
];
32+
}
33+
634
//===----------------------------------------------------------------------===//
735
// ClusterKindInterface
836
//===----------------------------------------------------------------------===//

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StableHloExt/Transforms/Passes.td

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -202,32 +202,11 @@ def LowerSpecialCustomCalls : Pass<"stablehlo-ext-lower-special-custom-calls"> {
202202
}
203203

204204
//===----------------------------------------------------------------------===//
205-
// StablehloInputPreprocessingPass
205+
// CanonicalizeConvolutionPass
206206
//===----------------------------------------------------------------------===//
207207

208-
def StablehloInputPreprocessingPass : Pass<"tensorrt-stablehlo-input-preprocessing"> {
209-
let summary = "Prepares Stable HLO dialect operations for conversion to TensorRT";
210-
211-
let description = [{
212-
This pass contains a set of patterns for simplifying or transforming Stable HLO
213-
input IR so that conversion to the TensorRT dialect is more straightforward.
214-
215-
In particular:
216-
217-
- Simplify certain patterns commonly found in IR emitted for JAX programs
218-
but not covered by existing Stable HLO canonicalizations/transforms.
219-
220-
- Prepare convolutions to be NCHW/FCRS format and have at least two
221-
"spatial" dimensions.
222-
223-
- Canonicalize `stablehlo.scatter` operations so that they can be converted to
224-
`tensorrt.scatter` in a straightforward manner.
225-
}];
226-
227-
let dependentDialects = [
228-
"::mlir::tensor::TensorDialect",
229-
"::mlir::stablehlo::StablehloDialect"
230-
];
208+
def CanonicalizeConvolutionPass : Pass<"stablehlo-ext-canonicalize-convolution"> {
209+
let summary = "Canonicalizes stablehlo convolution operations";
231210
}
232211

233212

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StableHloExt/Utils/GatherScatterUtils.h

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
#ifndef MLIR_TENSORRT_DIALECT_STABLEHLOEXT_UTILS_GATHERSCATTERUTILS_H
2727
#define MLIR_TENSORRT_DIALECT_STABLEHLOEXT_UTILS_GATHERSCATTERUTILS_H
2828

29+
#include "mlir-tensorrt/Utils/ShapeInfo.h"
30+
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
2931
#include "mlir/IR/Value.h"
3032
#include <optional>
3133

@@ -34,6 +36,7 @@ namespace mlir {
3436
class OpBuilder;
3537

3638
namespace stablehlo {
39+
class DynamicGatherOp;
3740
class GatherOp;
3841
class ScatterOp;
3942

@@ -124,6 +127,15 @@ namespace stablehlo_ext {
124127
std::optional<int64_t>
125128
isSingleDimSimpleGatherWithImplicitIndexDim(stablehlo::GatherOp op);
126129

130+
/// Returns the "gather dimension" if `op` is a 'simple, single dimension'
131+
/// gather op with implicit index vector dimension (see above for definition).
132+
/// This version works for `stablehlo.dynamic_gather` using pattern matching
133+
/// against the expected canonical form when the operand shape along some
134+
/// "offset dimensions" is dynamic.
135+
std::optional<int64_t> isSingleDimSimpleGatherWithImplicitIndexDim(
136+
stablehlo::DynamicGatherOp op,
137+
const ShapeInfoCallbacks &shapeInfoCallbacks);
138+
127139
/// Returns the "gather dimension" if `op` is a 'simple, single dimension'
128140
/// gather op with explicit size-1 index vector dimension (see above for
129141
/// definition).
@@ -138,6 +150,21 @@ bool isSimpleLeadingMultiDimGather(stablehlo::GatherOp op);
138150
/// gather' (see definition above).
139151
bool isSimpleLeadingMultiDimGatherWithDegenerateDims(stablehlo::GatherOp op);
140152

153+
/// Attempts to construct a `stablehlo.reshape` if result type is statically
154+
/// shaped, otherwise creates `stablehlo.dynamic_reshape`.
155+
Value createCollapsingReshape(OpBuilder &b, Location loc, Value input,
156+
ArrayRef<ReassociationIndices> reassociation);
157+
158+
/// Attempts to construct a `stablehlo.reshape` if `resultType` is statically
159+
/// shaped, otherwise creates a `stablehlo.dynamic_reshape`.
160+
Value createExpandingReshape(OpBuilder &b, Location loc,
161+
RankedTensorType resultType, Value input,
162+
ArrayRef<ReassociationIndices> reassociation);
163+
164+
/// Returns true if the `scatterOp` has a configuration that corresponds to the
165+
/// ONNX ScatterNd operation semantic.
166+
bool isCanonicalScatterNd(stablehlo::ScatterOp scatterOp);
167+
141168
//===----------------------------------------------------------------------===//
142169
// Code below this point was adapted from the MLIR-HLO project (part of OpenXLA
143170
// project) `xla/mlir_hlo/mhlo/utils/mhlo_scatter_gather_utils.h` and has the
@@ -155,10 +182,6 @@ bool isSimpleLeadingMultiDimGatherWithDegenerateDims(stablehlo::GatherOp op);
155182
// - scatter_dims_to_operand_dims is [0, 1, ...]
156183
bool isCanonicalScatter(stablehlo::ScatterOp scatterOp);
157184

158-
/// Returns true if the `scatterOp` has a configuration that corresponds to the
159-
/// ONNX ScatterNd operation semantic.
160-
bool isCanonicalScatterNd(stablehlo::ScatterOp scatterOp);
161-
162185
// Checks if the gather has the following characteristics:
163186
// - start_indices is a two-dimensional tensor
164187
// - index_vector_dim is 1

mlir-tensorrt/compiler/include/mlir-tensorrt/Pipelines/StableHloInputPipelines.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,19 @@ class OpPassManager;
3232
struct StableHloInputOptions {
3333
/// Whether to lower Stablehlo control flow ops to SCF dialect ops.
3434
bool legalizeControlFlowToSCF = false;
35-
/// Whether to lower chlo.erf into primitive stablehlo operations.
36-
bool legalizeChloErfToStablehlo = false;
35+
36+
/// Whether to preserve 'chlo.erf' ops or lower them to 'stablehlo' ops.
37+
/// By default, we preserve since it has a 1-1 correspondence with a TensorRT
38+
/// op.
39+
bool preserveChloErf = true;
40+
41+
/// Whether to preserve 'chlo.top_k' ops or lower them to 'stablehlo' ops.
42+
/// By default, we preserve since it has a 1-1 correspondence with a TensorRT
43+
/// op.
44+
bool preserveChloTopK = true;
45+
3746
/// Whether to disable running the inliner.
3847
bool disableInliner = false;
39-
/// Whether to lower chlo to stablehlo.
40-
bool convertChloToStablehlo = false;
4148
};
4249

4350
/// Construct a pipeline for preprocessing StableHLO IR to convert it into the
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
//===- ShapeInfo.h ---------------------------------------------*- C++ -*-===//
2+
//
3+
// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
4+
// All rights reserved.
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
///
21+
/// Declarations for callback types are used to abstract away how to infer
22+
/// shape knowledge from a pass or transformation. For example, a pass operating
23+
/// on StableHlo IR may need to check whether the *values* of tensor A represent
24+
/// the actual *shape* of tensor B, whose shape may not be known statically at
25+
/// compile time.
26+
///
27+
/// The specific mechanism that one may use to determine the validity of a
28+
/// specific proposition like the example above (which must be reported as
29+
/// "unknown", "true", or "false") may depend on the context. In the case
30+
/// of the StableHlo example above, we could try to naively pattern match
31+
/// whether tensor A is the result of `stablehlo.concat` of appropriate
32+
/// `stablehlo.get_dimensions_size %A, dim = ...` results. In other cases,
33+
/// we may have access to an analysis that assists with more robustly
34+
/// checking the proposition.
35+
///
36+
/// This file just contains callback types that a Pass or rewrite/transform can
37+
/// accept as a parameter, allowing the creator or caller to hand in a
38+
/// particular implementation.
39+
///
40+
//===----------------------------------------------------------------------===//
41+
#ifndef MLIR_TENSORRT_UTILS_SHAPEINFO
42+
#define MLIR_TENSORRT_UTILS_SHAPEINFO
43+
44+
#include "mlir/IR/BuiltinAttributes.h"
45+
#include "mlir/IR/Value.h"
46+
47+
namespace mlir {
48+
49+
/// TensorElementValue identifies a particular scalar element value of a
50+
/// statically-shaped tensor.
51+
struct TensorElementValue {
52+
TensorElementValue(Value value, ArrayRef<int64_t> coord);
53+
54+
TypedValue<RankedTensorType> getTensor() const { return tensor; }
55+
int64_t getLinearIndex() const { return linearIndex; }
56+
57+
/// A value of type (must be statically-shaped) RankedTensorType.
58+
TypedValue<RankedTensorType> tensor;
59+
60+
/// The linear coordinate of the value.
61+
int64_t linearIndex;
62+
};
63+
64+
/// TensorShapeDimExtent identifies a (potentially dynamically shaped) size
65+
/// of a particular dimension of a tensor's shape.
66+
struct TensorShapeDimExtent {
67+
TensorShapeDimExtent(Value value, int64_t dim);
68+
69+
std::optional<int64_t> getConstantSize() const;
70+
71+
/// A value of type RankedTensorType.
72+
TypedValue<RankedTensorType> tensor;
73+
74+
/// The dimension.
75+
int64_t dim;
76+
};
77+
78+
struct ShapeInfoCallbacks {
79+
// Check whether 'tensorElementValue' is provably equivalent to
80+
// `tensorShapeDimExtent`. Returning 'nullopt' means "unknown", true means
81+
// "equal", false means "not equal".
82+
std::function<std::optional<bool>(TensorElementValue tensorElementValue,
83+
TensorShapeDimExtent tensorShapeDimExtent)>
84+
isElementValueEqualToShapeDimExtent;
85+
86+
// Check whether 'tensorElementValue' is provably equivalent to the given
87+
// static value. Returning 'nullopt' means "unknown", true means "equal",
88+
// false means "not equal".
89+
std::function<std::optional<bool>(TensorElementValue tensorElementValue,
90+
Attribute constantValue)>
91+
isElementValueEqualToConstant;
92+
};
93+
94+
} // namespace mlir
95+
96+
#endif // MLIR_TENSORRT_UTILS_SHAPEINFO

0 commit comments

Comments
 (0)