
Commit a5439dc

christopherbate authored and yizhuoz004 committed
[tensorrt] Add shape inference for tensorrt resize ops
Implements `ReifyRankedShapedTypeOpInterface` for tensorrt resize ops to enable shape inference.
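
For context, implementing this interface lets shape-materialization passes rewrite `tensor.dim` queries on a resize result in terms of constants, the input operand, or the `output_shape` operand. Below is a minimal MLIR sketch of the intended effect, adapted from the regression tests added in this commit (which exercise it through the new `-test-tensorrt-shape-inference` test pass); the function name is illustrative only.

// A resize with a dynamic batch dimension and static spatial dimensions.
func.func @example(%arg0: tensor<?x1x10x10xf32>) -> (index, index) {
  %r = tensorrt.resize_linear {
    coordinateTransformation = #tensorrt.resize_coordinate_transformation<kALIGN_CORNERS>,
    selectorForSinglePixel = #tensorrt.resize_selector<kUPPER>
  } %arg0 : (tensor<?x1x10x10xf32>) -> tensor<?x1x20x20xf32>
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  // After reification, %d0 resolves to `tensor.dim %arg0, %c0` (the dynamic
  // batch extent comes from the input) and %d2 folds to `arith.constant 20`.
  %d0 = tensor.dim %r, %c0 : tensor<?x1x20x20xf32>
  %d2 = tensor.dim %r, %c2 : tensor<?x1x20x20xf32>
  return %d0, %d2 : index, index
}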
1 parent 2503d17 commit a5439dc

9 files changed (+337, −10 lines)

mlir-tensorrt/tensorrt/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -82,5 +82,5 @@ include_directories(${MLIR_TENSORRT_DIALECT_BINARY_DIR}/include)
 
 add_subdirectory(include/mlir-tensorrt-dialect)
 add_subdirectory(lib)
-add_subdirectory(tensorrt-opt)
 add_subdirectory(test)
+add_subdirectory(tensorrt-opt)

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td

Lines changed: 6 additions & 3 deletions
@@ -3463,7 +3463,8 @@ def TensorRT_ReshapeOp : TensorRT_Op<"reshape",
 //===----------------------------------------------------------------------===//
 
 def TensorRT_ResizeNearestOp : TensorRT_Op<"resize_nearest", [Pure,
-    TensorRTPartiallyInferTensorResultTypes]>{
+    TensorRTPartiallyInferTensorResultTypes,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>{
   let summary = "TensorRT Resize(IResizeLayer with NEAREST mode) operation";
   let description = [{
 
@@ -3541,7 +3542,8 @@ def TensorRT_ResizeNearestOp : TensorRT_Op<"resize_nearest", [Pure,
 //===----------------------------------------------------------------------===//
 
 def TensorRT_ResizeLinearOp : TensorRT_Op<"resize_linear", [Pure,
-    TensorRTPartiallyInferTensorResultTypes]>{
+    TensorRTPartiallyInferTensorResultTypes,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>{
   let summary = "TensorRT Resize(IResizeLayer with LINEAR mode) operation";
   let description = [{
 
@@ -3608,7 +3610,8 @@ def TensorRT_ResizeLinearOp : TensorRT_Op<"resize_linear", [Pure,
 //===----------------------------------------------------------------------===//
 
 def TensorRT_ResizeCubicOp : TensorRT_Op<"resize_cubic", [Pure,
-    TensorRTPartiallyInferTensorResultTypes]>{
+    TensorRTPartiallyInferTensorResultTypes,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>{
   let summary = "TensorRT Resize(IResizeLayer with CUBIC mode) operation";
   let description = [{
 

mlir-tensorrt/tensorrt/lib/TensorRT/IR/TypeInferenceInterfaceImpls.cpp

Lines changed: 154 additions & 3 deletions
@@ -24,6 +24,7 @@
 #include "EinsumHelper.h"
 #include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
 #include "mlir-tensorrt-dialect/Utils/ShapeUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Builders.h"
@@ -1211,7 +1212,7 @@ LogicalResult tensorrt::ResizeNearestOp::inferReturnTypeComponents(
                              inputType.getRank())
     return emitOptionalError(loc, "scales parameter must have same number of "
                                   "dimensions as input/output");
-  for (int i = 0; i < inputType.getRank() - resizeDims; i++)
+  for (int64_t i = 0; i < inputType.getRank() - resizeDims; i++)
     if (adaptor.getScales().value()[i] != 1)
       return emitOptionalError(
           loc,
@@ -1236,6 +1237,56 @@ LogicalResult tensorrt::ResizeNearestOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tensorrt::ResizeNearestOp::reifyResultShapes(
+    OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
+  Location loc = getLoc();
+  RankedTensorType resultType = getType();
+  int64_t rank = resultType.getRank();
+
+  // Case 1: if `output_shape` is specified, then we just extract the scalars
+  // from that shape.
+  if (TypedValue<TensorType> outputShape = getOutputShape()) {
+    // 'tensor.extract' %source [%index]
+    SmallVector<OpFoldResult> extents;
+    for (int64_t i = 0; i < rank; i++) {
+      Value index = b.create<arith::ConstantOp>(getLoc(), b.getIndexAttr(i));
+      Value extractedShape = b.create<tensor::ExtractOp>(loc, outputShape, index).getResult();
+      extents.push_back(
+          b.create<arith::IndexCastOp>(loc, b.getIndexType(), extractedShape).getResult());
+    }
+    result.emplace_back(std::move(extents));
+    return success();
+  }
+
+  SmallVector<OpFoldResult> extents;
+  extents.reserve(rank);
+
+  // This number of trailing dimensions are the special dimensions.
+  const int64_t resizeDims =
+      std::min(static_cast<int64_t>(3), resultType.getRank());
+
+  for (auto [idx, extent] : llvm::enumerate(resultType.getShape())) {
+
+    // If dimension is known, just materialize the extent as constant.
+    if (!ShapedType::isDynamic(extent)) {
+      extents.push_back(b.getIndexAttr(extent));
+      continue;
+    }
+
+    // Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
+    // then we use `tensor.dim` on the input operand.
+    // Batch dimensions can only be leading dim.
+    if (static_cast<int64_t>(idx) >= rank - resizeDims)
+      return failure();
+
+    Value index = b.create<arith::ConstantOp>(loc, b.getIndexAttr(idx));
+    extents.push_back(
+        b.create<tensor::DimOp>(loc, getInput(), index).getResult());
+  }
+  result.emplace_back(std::move(extents));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ResizeLinearOp
 //===----------------------------------------------------------------------===//
@@ -1253,7 +1304,7 @@ LogicalResult tensorrt::ResizeLinearOp::inferReturnTypeComponents(
                              inputType.getRank())
     return emitOptionalError(loc, "scales parameter must have same number of "
                                   "dimensions as input/output");
-  for (int i = 0; i < inputType.getRank() - resizeDims; i++)
+  for (int64_t i = 0; i < inputType.getRank() - resizeDims; i++)
     if (adaptor.getScales().value()[i] != 1)
       return emitOptionalError(
           loc,
@@ -1279,6 +1330,56 @@ LogicalResult tensorrt::ResizeLinearOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tensorrt::ResizeLinearOp::reifyResultShapes(
+    OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
+  Location loc = getLoc();
+  RankedTensorType resultType = getType();
+  int64_t rank = resultType.getRank();
+
+  // Case 1: if `output_shape` is specified, then we just extract the scalars
+  // from that shape.
+  if (TypedValue<TensorType> outputShape = getOutputShape()) {
+    // 'tensor.extract' %source [%index]
+    SmallVector<OpFoldResult> extents;
+    for (int64_t i = 0; i < rank; i++) {
+      Value index = b.create<arith::ConstantOp>(getLoc(), b.getIndexAttr(i));
+      Value extractedShape = b.create<tensor::ExtractOp>(loc, outputShape, index).getResult();
+      extents.push_back(
+          b.create<arith::IndexCastOp>(loc, b.getIndexType(), extractedShape).getResult());
+    }
+    result.emplace_back(std::move(extents));
+    return success();
+  }
+
+  SmallVector<OpFoldResult> extents;
+  extents.reserve(rank);
+
+  // This number of trailing dimensions are the special dimensions.
+  const int64_t resizeDims =
+      std::min(static_cast<int64_t>(3), resultType.getRank());
+
+  for (auto [idx, extent] : llvm::enumerate(resultType.getShape())) {
+
+    // If dimension is known, just materialize the extent as constant.
+    if (!ShapedType::isDynamic(extent)) {
+      extents.push_back(b.getIndexAttr(extent));
+      continue;
+    }
+
+    // Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
+    // then we use `tensor.dim` on the input operand.
+    // Batch dimensions can only be leading dim.
+    if (static_cast<int64_t>(idx) >= rank - resizeDims)
+      return failure();
+
+    Value index = b.create<arith::ConstantOp>(loc, b.getIndexAttr(idx));
+    extents.push_back(
+        b.create<tensor::DimOp>(loc, getInput(), index).getResult());
+  }
+  result.emplace_back(std::move(extents));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ResizeCubicOp
 //===----------------------------------------------------------------------===//
@@ -1298,7 +1399,7 @@ LogicalResult tensorrt::ResizeCubicOp::inferReturnTypeComponents(
                              inputType.getRank())
     return emitOptionalError(loc, "scales parameter must have same number of "
                                   "dimensions as input/output");
-  for (int i = 0; i < inputType.getRank() - 2; i++)
+  for (int64_t i = 0; i < inputType.getRank() - 2; i++)
     if (adaptor.getScales().value()[i] != 1)
       return emitOptionalError(
           loc, "all scale values except 2 innermost must be 1");
@@ -1323,6 +1424,56 @@ LogicalResult tensorrt::ResizeCubicOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tensorrt::ResizeCubicOp::reifyResultShapes(
+    OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
+  Location loc = getLoc();
+  RankedTensorType resultType = getType();
+  int64_t rank = resultType.getRank();
+
+  // Case 1: if `output_shape` is specified, then we just extract the scalars
+  // from that shape.
+  if (TypedValue<TensorType> outputShape = getOutputShape()) {
+    // 'tensor.extract' %source [%index]
+    SmallVector<OpFoldResult> extents;
+    for (int64_t i = 0; i < rank; i++) {
+      Value index = b.create<arith::ConstantOp>(getLoc(), b.getIndexAttr(i));
+      Value extractedShape = b.create<tensor::ExtractOp>(loc, outputShape, index).getResult();
+      extents.push_back(
+          b.create<arith::IndexCastOp>(loc, b.getIndexType(), extractedShape).getResult());
+    }
+    result.emplace_back(std::move(extents));
+    return success();
+  }
+
+  SmallVector<OpFoldResult> extents;
+  extents.reserve(rank);
+
+  // This number of trailing dimensions are the special dimensions.
+  const int64_t resizeDims =
+      std::min(static_cast<int64_t>(3), resultType.getRank());
+
+  for (auto [idx, extent] : llvm::enumerate(resultType.getShape())) {
+
+    // If dimension is known, just materialize the extent as constant.
+    if (!ShapedType::isDynamic(extent)) {
+      extents.push_back(b.getIndexAttr(extent));
+      continue;
+    }
+
+    // Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
+    // then we use `tensor.dim` on the input operand.
+    // Batch dimensions can only be leading dim.
+    if (static_cast<int64_t>(idx) >= rank - resizeDims)
+      return failure();
+
+    Value index = b.create<arith::ConstantOp>(loc, b.getIndexAttr(idx));
+    extents.push_back(
+        b.create<tensor::DimOp>(loc, getInput(), index).getResult());
+  }
+  result.emplace_back(std::move(extents));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//

mlir-tensorrt/tensorrt/tensorrt-opt/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,5 @@
 add_llvm_executable(tensorrt-opt tensorrt-opt.cpp)
-
+get_property(MLIR_TENSORRT_DIALECT_TEST_LIBS GLOBAL PROPERTY MLIR_TENSORRT_DIALECT_TEST_LIBS)
 llvm_update_compile_flags(tensorrt-opt)
 target_link_libraries(tensorrt-opt PRIVATE
   MLIRTensorRTDialect
@@ -10,8 +10,9 @@ target_link_libraries(tensorrt-opt PRIVATE
   MLIROptLib
   MLIRTensorDialect
   MLIRTransforms
-  MLIRTensorRTTestTensorKindAnalysis
   MLIRSCFDialect
+  ${MLIR_TENSORRT_DIALECT_TEST_LIBS}
+  MLIRTensorRTTestTensorKindAnalysis
   )
 
 mlir_check_all_link_libraries(tensorrt-opt)

mlir-tensorrt/tensorrt/tensorrt-opt/tensorrt-opt.cpp

Lines changed: 3 additions & 1 deletion
@@ -35,7 +35,8 @@
 
 namespace mlir {
 void registerTestTensorKindAnalysisPass();
-}
+void registerTestTensorRTShapeInferencePass();
+} // namespace mlir
 
 int main(int argc, char **argv) {
   mlir::DialectRegistry registry;
@@ -44,6 +45,7 @@ int main(int argc, char **argv) {
       mlir::affine::AffineDialect, mlir::quant::QuantizationDialect,
       mlir::scf::SCFDialect>();
   mlir::registerTestTensorKindAnalysisPass();
+  mlir::registerTestTensorRTShapeInferencePass();
   mlir::func::registerInlinerExtension(registry);
   mlir::tensorrt::registerTensorRTTranslationCLOpts();
   mlir::tensorrt::registerTensorRTPasses();
Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+// RUN: tensorrt-opt %s -split-input-file -test-tensorrt-shape-inference | FileCheck %s
+
+func.func @test_resize_linear(%arg0: tensor<10x10xf32>) -> (index, index) {
+  %result = tensorrt.resize_linear {
+    coordinateTransformation = #tensorrt.resize_coordinate_transformation<kALIGN_CORNERS>,
+    selectorForSinglePixel = #tensorrt.resize_selector<kUPPER>
+  } %arg0 : (tensor<10x10xf32>) -> tensor<20x20xf32>
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %result, %c0 : tensor<20x20xf32>
+  %d1 = tensor.dim %result, %c1 : tensor<20x20xf32>
+  return %d0, %d1 : index, index
+}
+
+// CHECK-LABEL: test_resize_linear
+// CHECK-NEXT: %[[c20:.+]] = arith.constant 20 : index
+// CHECK-NEXT: return %[[c20]], %[[c20]] : index, index
+
+// -----
+
+func.func @test_resize_dynamic_batch(%arg0: tensor<?x1x10x10xf32>) -> (index, index, index, index) {
+  %result = tensorrt.resize_linear {
+    coordinateTransformation = #tensorrt.resize_coordinate_transformation<kALIGN_CORNERS>,
+    selectorForSinglePixel = #tensorrt.resize_selector<kUPPER>
+  } %arg0 : (tensor<?x1x10x10xf32>) -> tensor<?x1x20x20xf32>
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %d0 = tensor.dim %result, %c0 : tensor<?x1x20x20xf32>
+  %d1 = tensor.dim %result, %c1 : tensor<?x1x20x20xf32>
+  %d2 = tensor.dim %result, %c2 : tensor<?x1x20x20xf32>
+  %d3 = tensor.dim %result, %c3 : tensor<?x1x20x20xf32>
+  return %d0, %d1, %d2, %d3 : index, index, index, index
+}
+
+// CHECK-LABEL: func.func @test_resize_dynamic_batch
+// CHECK-SAME: (%[[arg0:.+]]: tensor<?x1x10x10xf32>)
+// CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c0]] : tensor<?x1x10x10xf32>
+// CHECK: return %[[dim]], %[[c1]], %[[c20]], %[[c20]]
+
+// -----
+
+func.func @test_resize_output_shape(%arg0: tensor<4x4xf32>, %arg1: tensor<2xi32>) -> (index, index) {
+  %result = tensorrt.resize_linear {
+    coordinateTransformation = #tensorrt.resize_coordinate_transformation<kALIGN_CORNERS>,
+    selectorForSinglePixel = #tensorrt.resize_selector<kUPPER>
+  } %arg0, %arg1 : (tensor<4x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %result, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %result, %c1 : tensor<?x?xf32>
+  return %d0, %d1 : index, index
+}
+
+// CHECK-LABEL: func.func @test_resize_output_shape
+// CHECK-SAME: (%[[arg0:.+]]: tensor<4x4xf32>, %[[arg1:.+]]: tensor<2xi32>)
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[extracted:.+]] = tensor.extract %arg1[%c0] : tensor<2xi32>
+// CHECK-DAG: %[[v0:.+]] = arith.index_cast %extracted : i32 to index
+// CHECK-DAG: %[[extracted_0:.+]] = tensor.extract %arg1[%c1] : tensor<2xi32>
+// CHECK-DAG: %[[v1:.+]] = arith.index_cast %extracted_0 : i32 to index
+// CHECK: return %[[v0]], %[[v1]]
Lines changed: 7 additions & 0 deletions
@@ -1 +1,8 @@
+function(add_mlir_tensorrt_dialect_test_library target)
+  add_mlir_library(${target} ${ARGN}
+    EXCLUDE_FROM_LIBMLIR)
+  set_property(GLOBAL APPEND PROPERTY MLIR_TENSORRT_DIALECT_TEST_LIBS ${target})
+endfunction()
+
 add_subdirectory(Target)
+add_subdirectory(TensorRT)
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+add_mlir_tensorrt_dialect_test_library(MLIRTensorRTTestTypeInferencePass
+  TestTypeInferencePass.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRTensorRTDialect
+  MLIRTensorDialect
+  MLIRArithDialect
+  MLIRPass
+  MLIRTransformUtils
+)
