Commit 6e51621

[tensorrt] Add shape inference for resize ops in TRT dialect (#232)
Implements `ReifyRankedShapedTypeOpInterface` for tensorrt resize ops to enable shape inference.
1 parent af99134 commit 6e51621
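
For orientation, the sketch below shows how a consumer typically queries this interface once an op implements it. This is standard upstream MLIR interface usage, not code from this commit; the helper name `reifyIfPossible` and the surrounding function are illustrative assumptions.

```cpp
// Sketch only (not part of this commit): how a caller can use an op's
// ReifyRankedShapedTypeOpInterface implementation, such as the resize ops
// changed below. The helper name `reifyIfPossible` is hypothetical.
#include "mlir/Interfaces/InferTypeOpInterface.h"

using namespace mlir;

static LogicalResult reifyIfPossible(Operation *op, OpBuilder &b) {
  auto reifiable = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
  if (!reifiable)
    return failure();
  ReifiedRankedShapedTypeDims reifiedShapes;
  if (failed(reifiable.reifyResultShapes(b, reifiedShapes)))
    return failure();
  // reifiedShapes[i][d] holds an OpFoldResult for dimension `d` of result `i`:
  // an IndexAttr for static extents, or a materialized index Value (e.g. from
  // tensor.dim / tensor.extract) for dynamic extents.
  return success();
}
```

Dim-of-result resolution patterns use exactly this hook to fold `tensor.dim` of a resize result, which is what the new test file in this commit checks.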

9 files changed (+343, -10 lines)

mlir-tensorrt/tensorrt/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -82,5 +82,5 @@ include_directories(${MLIR_TENSORRT_DIALECT_BINARY_DIR}/include)
 
 add_subdirectory(include/mlir-tensorrt-dialect)
 add_subdirectory(lib)
-add_subdirectory(tensorrt-opt)
 add_subdirectory(test)
+add_subdirectory(tensorrt-opt)

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td

Lines changed: 6 additions & 3 deletions
@@ -3463,7 +3463,8 @@ def TensorRT_ReshapeOp : TensorRT_Op<"reshape",
 //===----------------------------------------------------------------------===//
 
 def TensorRT_ResizeNearestOp : TensorRT_Op<"resize_nearest", [Pure,
-      TensorRTPartiallyInferTensorResultTypes]>{
+      TensorRTPartiallyInferTensorResultTypes,
+      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>{
   let summary = "TensorRT Resize(IResizeLayer with NEAREST mode) operation";
   let description = [{
 
@@ -3541,7 +3542,8 @@ def TensorRT_ResizeNearestOp : TensorRT_Op<"resize_nearest", [Pure,
 //===----------------------------------------------------------------------===//
 
 def TensorRT_ResizeLinearOp : TensorRT_Op<"resize_linear", [Pure,
-      TensorRTPartiallyInferTensorResultTypes]>{
+      TensorRTPartiallyInferTensorResultTypes,
+      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>{
   let summary = "TensorRT Resize(IResizeLayer with LINEAR mode) operation";
   let description = [{
 
@@ -3608,7 +3610,8 @@ def TensorRT_ResizeLinearOp : TensorRT_Op<"resize_linear", [Pure,
 //===----------------------------------------------------------------------===//
 
 def TensorRT_ResizeCubicOp : TensorRT_Op<"resize_cubic", [Pure,
-      TensorRTPartiallyInferTensorResultTypes]>{
+      TensorRTPartiallyInferTensorResultTypes,
+      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>{
   let summary = "TensorRT Resize(IResizeLayer with CUBIC mode) operation";
   let description = [{
 

mlir-tensorrt/tensorrt/lib/TensorRT/IR/TypeInferenceInterfaceImpls.cpp

Lines changed: 160 additions & 3 deletions
@@ -24,6 +24,7 @@
 #include "EinsumHelper.h"
 #include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
 #include "mlir-tensorrt-dialect/Utils/ShapeUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Builders.h"
@@ -1211,7 +1212,7 @@ LogicalResult tensorrt::ResizeNearestOp::inferReturnTypeComponents(
                       inputType.getRank())
     return emitOptionalError(loc, "scales parameter must have same number of "
                                   "dimensions as input/output");
-  for (int i = 0; i < inputType.getRank() - resizeDims; i++)
+  for (int64_t i = 0; i < inputType.getRank() - resizeDims; i++)
     if (adaptor.getScales().value()[i] != 1)
       return emitOptionalError(
           loc,
@@ -1236,6 +1237,58 @@ LogicalResult tensorrt::ResizeNearestOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tensorrt::ResizeNearestOp::reifyResultShapes(
+    OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
+  Location loc = getLoc();
+  RankedTensorType resultType = getType();
+  int64_t rank = resultType.getRank();
+
+  // Case 1: if `output_shape` is specified, then we just extract the scalars
+  // from that shape.
+  if (TypedValue<TensorType> outputShape = getOutputShape()) {
+    // 'tensor.extract' %source [%index]
+    SmallVector<OpFoldResult> extents;
+    for (int64_t i = 0; i < rank; i++) {
+      Value index = b.create<arith::ConstantOp>(getLoc(), b.getIndexAttr(i));
+      Value extractedShape =
+          b.create<tensor::ExtractOp>(loc, outputShape, index).getResult();
+      extents.push_back(
+          b.create<arith::IndexCastOp>(loc, b.getIndexType(), extractedShape)
+              .getResult());
+    }
+    result.emplace_back(std::move(extents));
+    return success();
+  }
+
+  SmallVector<OpFoldResult> extents;
+  extents.reserve(rank);
+
+  // This number of trailing dimensions are the special dimensions.
+  const int64_t resizeDims =
+      std::min(static_cast<int64_t>(3), resultType.getRank());
+
+  for (auto [idx, extent] : llvm::enumerate(resultType.getShape())) {
+
+    // If dimension is known, just materialize the extent as constant.
+    if (!ShapedType::isDynamic(extent)) {
+      extents.push_back(b.getIndexAttr(extent));
+      continue;
+    }
+
+    // Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
+    // then we use `tensor.dim` on the input operand.
+    // Batch dimensions can only be leading dim.
+    if (static_cast<int64_t>(idx) >= rank - resizeDims)
+      return failure();
+
+    Value index = b.create<arith::ConstantOp>(loc, b.getIndexAttr(idx));
+    extents.push_back(
+        b.create<tensor::DimOp>(loc, getInput(), index).getResult());
+  }
+  result.emplace_back(std::move(extents));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ResizeLinearOp
 //===----------------------------------------------------------------------===//
@@ -1253,7 +1306,7 @@ LogicalResult tensorrt::ResizeLinearOp::inferReturnTypeComponents(
                       inputType.getRank())
     return emitOptionalError(loc, "scales parameter must have same number of "
                                   "dimensions as input/output");
-  for (int i = 0; i < inputType.getRank() - resizeDims; i++)
+  for (int64_t i = 0; i < inputType.getRank() - resizeDims; i++)
     if (adaptor.getScales().value()[i] != 1)
       return emitOptionalError(
           loc,
@@ -1279,6 +1332,58 @@ LogicalResult tensorrt::ResizeLinearOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tensorrt::ResizeLinearOp::reifyResultShapes(
+    OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
+  Location loc = getLoc();
+  RankedTensorType resultType = getType();
+  int64_t rank = resultType.getRank();
+
+  // Case 1: if `output_shape` is specified, then we just extract the scalars
+  // from that shape.
+  if (TypedValue<TensorType> outputShape = getOutputShape()) {
+    // 'tensor.extract' %source [%index]
+    SmallVector<OpFoldResult> extents;
+    for (int64_t i = 0; i < rank; i++) {
+      Value index = b.create<arith::ConstantOp>(getLoc(), b.getIndexAttr(i));
+      Value extractedShape =
+          b.create<tensor::ExtractOp>(loc, outputShape, index).getResult();
+      extents.push_back(
+          b.create<arith::IndexCastOp>(loc, b.getIndexType(), extractedShape)
+              .getResult());
+    }
+    result.emplace_back(std::move(extents));
+    return success();
+  }
+
+  SmallVector<OpFoldResult> extents;
+  extents.reserve(rank);
+
+  // This number of trailing dimensions are the special dimensions.
+  const int64_t resizeDims =
+      std::min(static_cast<int64_t>(3), resultType.getRank());
+
+  for (auto [idx, extent] : llvm::enumerate(resultType.getShape())) {
+
+    // If dimension is known, just materialize the extent as constant.
+    if (!ShapedType::isDynamic(extent)) {
+      extents.push_back(b.getIndexAttr(extent));
+      continue;
+    }
+
+    // Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
+    // then we use `tensor.dim` on the input operand.
+    // Batch dimensions can only be leading dim.
+    if (static_cast<int64_t>(idx) >= rank - resizeDims)
+      return failure();
+
+    Value index = b.create<arith::ConstantOp>(loc, b.getIndexAttr(idx));
+    extents.push_back(
+        b.create<tensor::DimOp>(loc, getInput(), index).getResult());
+  }
+  result.emplace_back(std::move(extents));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ResizeCubicOp
 //===----------------------------------------------------------------------===//
@@ -1298,7 +1403,7 @@ LogicalResult tensorrt::ResizeCubicOp::inferReturnTypeComponents(
                       inputType.getRank())
     return emitOptionalError(loc, "scales parameter must have same number of "
                                   "dimensions as input/output");
-  for (int i = 0; i < inputType.getRank() - 2; i++)
+  for (int64_t i = 0; i < inputType.getRank() - 2; i++)
     if (adaptor.getScales().value()[i] != 1)
       return emitOptionalError(
           loc, "all scale values except 2 innermost must be 1");
@@ -1323,6 +1428,58 @@ LogicalResult tensorrt::ResizeCubicOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tensorrt::ResizeCubicOp::reifyResultShapes(
+    OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
+  Location loc = getLoc();
+  RankedTensorType resultType = getType();
+  int64_t rank = resultType.getRank();
+
+  // Case 1: if `output_shape` is specified, then we just extract the scalars
+  // from that shape.
+  if (TypedValue<TensorType> outputShape = getOutputShape()) {
+    // 'tensor.extract' %source [%index]
+    SmallVector<OpFoldResult> extents;
+    for (int64_t i = 0; i < rank; i++) {
+      Value index = b.create<arith::ConstantOp>(getLoc(), b.getIndexAttr(i));
+      Value extractedShape =
+          b.create<tensor::ExtractOp>(loc, outputShape, index).getResult();
+      extents.push_back(
+          b.create<arith::IndexCastOp>(loc, b.getIndexType(), extractedShape)
+              .getResult());
+    }
+    result.emplace_back(std::move(extents));
+    return success();
+  }
+
+  SmallVector<OpFoldResult> extents;
+  extents.reserve(rank);
+
+  // This number of trailing dimensions are the special dimensions.
+  const int64_t resizeDims =
+      std::min(static_cast<int64_t>(3), resultType.getRank());
+
+  for (auto [idx, extent] : llvm::enumerate(resultType.getShape())) {
+
+    // If dimension is known, just materialize the extent as constant.
+    if (!ShapedType::isDynamic(extent)) {
+      extents.push_back(b.getIndexAttr(extent));
+      continue;
+    }
+
+    // Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
+    // then we use `tensor.dim` on the input operand.
+    // Batch dimensions can only be leading dim.
+    if (static_cast<int64_t>(idx) >= rank - resizeDims)
+      return failure();
+
+    Value index = b.create<arith::ConstantOp>(loc, b.getIndexAttr(idx));
+    extents.push_back(
+        b.create<tensor::DimOp>(loc, getInput(), index).getResult());
+  }
+  result.emplace_back(std::move(extents));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
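
The three `reifyResultShapes` bodies above are identical except for the op they are attached to. A natural follow-up, sketched below, would be to factor the logic into one static helper templated over the op type. This is not part of the commit: the name `reifyResizeResultShapes` is hypothetical, and the sketch assumes the same headers and namespaces as `TypeInferenceInterfaceImpls.cpp`.

```cpp
// Hypothetical refactor sketch (not in this commit): the shared logic of the
// three identical reifyResultShapes bodies, templated over the resize op type.
// Assumes the same includes/using-directives as TypeInferenceInterfaceImpls.cpp.
template <typename ResizeOp>
static LogicalResult
reifyResizeResultShapes(ResizeOp op, OpBuilder &b,
                        ReifiedRankedShapedTypeDims &result) {
  Location loc = op.getLoc();
  RankedTensorType resultType = op.getType();
  int64_t rank = resultType.getRank();

  // If `output_shape` is specified, extract the scalar extents from it.
  if (TypedValue<TensorType> outputShape = op.getOutputShape()) {
    SmallVector<OpFoldResult> extents;
    for (int64_t i = 0; i < rank; i++) {
      Value index = b.create<arith::ConstantOp>(loc, b.getIndexAttr(i));
      Value extracted =
          b.create<tensor::ExtractOp>(loc, outputShape, index).getResult();
      extents.push_back(
          b.create<arith::IndexCastOp>(loc, b.getIndexType(), extracted)
              .getResult());
    }
    result.emplace_back(std::move(extents));
    return success();
  }

  // Otherwise static extents become IndexAttrs, and dynamic leading (batch)
  // extents are read off the input with tensor.dim; dynamic resize dims fail.
  SmallVector<OpFoldResult> extents;
  extents.reserve(rank);
  const int64_t resizeDims = std::min(static_cast<int64_t>(3), rank);
  for (auto [idx, extent] : llvm::enumerate(resultType.getShape())) {
    if (!ShapedType::isDynamic(extent)) {
      extents.push_back(b.getIndexAttr(extent));
      continue;
    }
    if (static_cast<int64_t>(idx) >= rank - resizeDims)
      return failure();
    Value index = b.create<arith::ConstantOp>(loc, b.getIndexAttr(idx));
    extents.push_back(
        b.create<tensor::DimOp>(loc, op.getInput(), index).getResult());
  }
  result.emplace_back(std::move(extents));
  return success();
}

// Each op's interface method would then simply delegate:
LogicalResult tensorrt::ResizeNearestOp::reifyResultShapes(
    OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
  return reifyResizeResultShapes(*this, b, result);
}
```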

mlir-tensorrt/tensorrt/tensorrt-opt/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,5 @@
 add_llvm_executable(tensorrt-opt tensorrt-opt.cpp)
-
+get_property(MLIR_TENSORRT_DIALECT_TEST_LIBS GLOBAL PROPERTY MLIR_TENSORRT_DIALECT_TEST_LIBS)
 llvm_update_compile_flags(tensorrt-opt)
 target_link_libraries(tensorrt-opt PRIVATE
   MLIRTensorRTDialect
@@ -10,8 +10,9 @@ target_link_libraries(tensorrt-opt PRIVATE
   MLIROptLib
   MLIRTensorDialect
   MLIRTransforms
-  MLIRTensorRTTestTensorKindAnalysis
   MLIRSCFDialect
+  ${MLIR_TENSORRT_DIALECT_TEST_LIBS}
+  MLIRTensorRTTestTensorKindAnalysis
 )
 
 mlir_check_all_link_libraries(tensorrt-opt)

mlir-tensorrt/tensorrt/tensorrt-opt/tensorrt-opt.cpp

Lines changed: 3 additions & 1 deletion
@@ -35,7 +35,8 @@
 
 namespace mlir {
 void registerTestTensorKindAnalysisPass();
-}
+void registerTestTensorRTShapeInferencePass();
+} // namespace mlir
 
 int main(int argc, char **argv) {
   mlir::DialectRegistry registry;
@@ -44,6 +45,7 @@ int main(int argc, char **argv) {
                     mlir::affine::AffineDialect, mlir::quant::QuantizationDialect,
                     mlir::scf::SCFDialect>();
   mlir::registerTestTensorKindAnalysisPass();
+  mlir::registerTestTensorRTShapeInferencePass();
   mlir::func::registerInlinerExtension(registry);
   mlir::tensorrt::registerTensorRTTranslationCLOpts();
   mlir::tensorrt::registerTensorRTPasses();

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+// RUN: tensorrt-opt %s -split-input-file -test-tensorrt-shape-inference | FileCheck %s
+
+func.func @test_resize_linear(%arg0: tensor<10x10xf32>) -> (index, index) {
+  %result = tensorrt.resize_linear {
+    coordinateTransformation = #tensorrt.resize_coordinate_transformation<kALIGN_CORNERS>,
+    selectorForSinglePixel = #tensorrt.resize_selector<kUPPER>
+  } %arg0 : (tensor<10x10xf32>) -> tensor<20x20xf32>
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %result, %c0 : tensor<20x20xf32>
+  %d1 = tensor.dim %result, %c1 : tensor<20x20xf32>
+  return %d0, %d1 : index, index
+}
+
+// CHECK-LABEL: test_resize_linear
+// CHECK-NEXT: %[[c20:.+]] = arith.constant 20 : index
+// CHECK-NEXT: return %[[c20]], %[[c20]] : index, index
+
+// -----
+
+func.func @test_resize_dynamic_batch(%arg0: tensor<?x1x10x10xf32>) -> (index, index, index, index) {
+  %result = tensorrt.resize_linear {
+    coordinateTransformation = #tensorrt.resize_coordinate_transformation<kALIGN_CORNERS>,
+    selectorForSinglePixel = #tensorrt.resize_selector<kUPPER>
+  } %arg0 : (tensor<?x1x10x10xf32>) -> tensor<?x1x20x20xf32>
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %d0 = tensor.dim %result, %c0 : tensor<?x1x20x20xf32>
+  %d1 = tensor.dim %result, %c1 : tensor<?x1x20x20xf32>
+  %d2 = tensor.dim %result, %c2 : tensor<?x1x20x20xf32>
+  %d3 = tensor.dim %result, %c3 : tensor<?x1x20x20xf32>
+  return %d0, %d1, %d2, %d3 : index, index, index, index
+}
+
+// CHECK-LABEL: func.func @test_resize_dynamic_batch
+// CHECK-SAME: (%[[arg0:.+]]: tensor<?x1x10x10xf32>)
+// CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c0]] : tensor<?x1x10x10xf32>
+// CHECK: return %[[dim]], %[[c1]], %[[c20]], %[[c20]]
+
+// -----
+
+func.func @test_resize_output_shape(%arg0: tensor<4x4xf32>, %arg1: tensor<2xi32>) -> (index, index) {
+  %result = tensorrt.resize_linear {
+    coordinateTransformation = #tensorrt.resize_coordinate_transformation<kALIGN_CORNERS>,
+    selectorForSinglePixel = #tensorrt.resize_selector<kUPPER>
+  } %arg0, %arg1 : (tensor<4x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %result, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %result, %c1 : tensor<?x?xf32>
+  return %d0, %d1 : index, index
+}
+
+// CHECK-LABEL: func.func @test_resize_output_shape
+// CHECK-SAME: (%[[arg0:.+]]: tensor<4x4xf32>, %[[arg1:.+]]: tensor<2xi32>)
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[extracted:.+]] = tensor.extract %arg1[%c0] : tensor<2xi32>
+// CHECK-DAG: %[[v0:.+]] = arith.index_cast %extracted : i32 to index
+// CHECK-DAG: %[[extracted_0:.+]] = tensor.extract %arg1[%c1] : tensor<2xi32>
+// CHECK-DAG: %[[v1:.+]] = arith.index_cast %extracted_0 : i32 to index
+// CHECK: return %[[v0]], %[[v1]]

Lines changed: 7 additions & 0 deletions
@@ -1 +1,8 @@
+function(add_mlir_tensorrt_dialect_test_library target)
+  add_mlir_library(${target} ${ARGN}
+    EXCLUDE_FROM_LIBMLIR)
+  set_property(GLOBAL APPEND PROPERTY MLIR_TENSORRT_DIALECT_TEST_LIBS ${target})
+endfunction()
+
 add_subdirectory(Target)
+add_subdirectory(TensorRT)

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+add_mlir_tensorrt_dialect_test_library(MLIRTensorRTTestTypeInferencePass
+  TestTypeInferencePass.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRTensorRTDialect
+  MLIRTensorDialect
+  MLIRArithDialect
+  MLIRPass
+  MLIRTransformUtils
+)
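
The commit registers `registerTestTensorRTShapeInferencePass` in tensorrt-opt and builds `TestTypeInferencePass.cpp`, but the pass source itself is not shown in this diff. Below is a hedged sketch of what such a test pass could look like: the class name, the rewrite pattern, and the overall structure are assumptions rather than the committed code; only the `-test-tensorrt-shape-inference` flag is taken from the RUN line in the new test file.

```cpp
// Hypothetical sketch of a shape-inference test pass; the committed
// TestTypeInferencePass.cpp is not shown in this diff, so everything below
// except the "test-tensorrt-shape-inference" flag is an assumption.
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>

using namespace mlir;

namespace {
/// Rewrites `tensor.dim` of an op result into the extent reified through the
/// op's ReifyRankedShapedTypeOpInterface implementation (e.g. the resize ops).
struct ResolveDimOfReifiableResult : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto reifiable =
        dimOp.getSource().getDefiningOp<ReifyRankedShapedTypeOpInterface>();
    std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
    if (!reifiable || !dimIndex)
      return failure();
    ReifiedRankedShapedTypeDims shapes;
    if (failed(reifiable.reifyResultShapes(rewriter, shapes)))
      return failure();
    unsigned resultNo = cast<OpResult>(dimOp.getSource()).getResultNumber();
    // Materialize the reified OpFoldResult (attribute or value) as an index.
    Value replacement = getValueOrCreateConstantIndexOp(
        rewriter, dimOp.getLoc(), shapes[resultNo][*dimIndex]);
    rewriter.replaceOp(dimOp, replacement);
    return success();
  }
};

struct TestTensorRTShapeInferencePass
    : public PassWrapper<TestTensorRTShapeInferencePass,
                         OperationPass<func::FuncOp>> {
  StringRef getArgument() const final {
    // Matches the flag used by the RUN line in the new test file.
    return "test-tensorrt-shape-inference";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    patterns.add<ResolveDimOfReifiableResult>(&getContext());
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

namespace mlir {
void registerTestTensorRTShapeInferencePass() {
  PassRegistration<TestTensorRTShapeInferencePass>();
}
} // namespace mlir
```

With a pattern like this, the greedy rewrite both folds `tensor.dim` of statically known extents to constants and resolves dynamic extents to `tensor.dim`/`tensor.extract` chains, which is the behavior the FileCheck tests above verify.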
