Commit 537cc2c

christopherbate and yizhuoz004 authored and committed
[tensorrt] Add shape inference for tensorrt resize ops
Implements `ReifyRankedShapedTypeOpInterface` for tensorrt resize ops to enable shape inference.
1 parent 2503d17 commit 537cc2c
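For context, `ReifyRankedShapedTypeOpInterface` lets a caller materialize each result dimension as either a constant or an SSA value, which is what makes `tensor.dim` on a resize result foldable. A minimal sketch of how a consumer might query the interface (not part of this commit; the helper name `resolveResultExtents` is hypothetical):

// Sketch only: querying the interface implemented by this commit.
// `resolveResultExtents` is a hypothetical helper, not code from the repo.
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/InferTypeOpInterface.h" // ReifyRankedShapedTypeOpInterface

using namespace mlir;

static LogicalResult resolveResultExtents(OpBuilder &b, Operation *op) {
  auto reifiable = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
  if (!reifiable)
    return failure();
  // One vector of OpFoldResult per result: each entry is either an IndexAttr
  // (static extent) or a Value, e.g. `tensor.dim` on the input or
  // `tensor.extract` from the `output_shape` operand.
  ReifiedRankedShapedTypeDims reified;
  if (failed(reifiable.reifyResultShapes(b, reified)))
    return failure();
  // `reified[0][d]` can now stand in for `tensor.dim` on result 0, dimension d.
  return success();
}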

9 files changed: +309 -10 lines changed


mlir-tensorrt/tensorrt/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -82,5 +82,5 @@ include_directories(${MLIR_TENSORRT_DIALECT_BINARY_DIR}/include)
 
 add_subdirectory(include/mlir-tensorrt-dialect)
 add_subdirectory(lib)
-add_subdirectory(tensorrt-opt)
 add_subdirectory(test)
+add_subdirectory(tensorrt-opt)

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td

Lines changed: 6 additions & 3 deletions
@@ -3463,7 +3463,8 @@ def TensorRT_ReshapeOp : TensorRT_Op<"reshape",
 //===----------------------------------------------------------------------===//
 
 def TensorRT_ResizeNearestOp : TensorRT_Op<"resize_nearest", [Pure,
-    TensorRTPartiallyInferTensorResultTypes]>{
+    TensorRTPartiallyInferTensorResultTypes,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>{
   let summary = "TensorRT Resize(IResizeLayer with NEAREST mode) operation";
   let description = [{

@@ -3541,7 +3542,8 @@ def TensorRT_ResizeNearestOp : TensorRT_Op<"resize_nearest", [Pure,
 //===----------------------------------------------------------------------===//
 
 def TensorRT_ResizeLinearOp : TensorRT_Op<"resize_linear", [Pure,
-    TensorRTPartiallyInferTensorResultTypes]>{
+    TensorRTPartiallyInferTensorResultTypes,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>{
   let summary = "TensorRT Resize(IResizeLayer with LINEAR mode) operation";
   let description = [{

@@ -3608,7 +3610,8 @@ def TensorRT_ResizeLinearOp : TensorRT_Op<"resize_linear", [Pure,
 //===----------------------------------------------------------------------===//
 
 def TensorRT_ResizeCubicOp : TensorRT_Op<"resize_cubic", [Pure,
-    TensorRTPartiallyInferTensorResultTypes]>{
+    TensorRTPartiallyInferTensorResultTypes,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>{
   let summary = "TensorRT Resize(IResizeLayer with CUBIC mode) operation";
   let description = [{

mlir-tensorrt/tensorrt/lib/TensorRT/IR/TypeInferenceInterfaceImpls.cpp

Lines changed: 151 additions & 3 deletions
@@ -24,6 +24,7 @@
 #include "EinsumHelper.h"
 #include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
 #include "mlir-tensorrt-dialect/Utils/ShapeUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Builders.h"

@@ -1211,7 +1212,7 @@ LogicalResult tensorrt::ResizeNearestOp::inferReturnTypeComponents(
                            inputType.getRank())
     return emitOptionalError(loc, "scales parameter must have same number of "
                                   "dimensions as input/output");
-  for (int i = 0; i < inputType.getRank() - resizeDims; i++)
+  for (int64_t i = 0; i < inputType.getRank() - resizeDims; i++)
     if (adaptor.getScales().value()[i] != 1)
       return emitOptionalError(
           loc,
@@ -1236,6 +1237,55 @@ LogicalResult tensorrt::ResizeNearestOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tensorrt::ResizeNearestOp::reifyResultShapes(
+    OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
+  Location loc = getLoc();
+  RankedTensorType resultType = getType();
+  int64_t rank = resultType.getRank();
+
+  // Case 1: if `output_shape` is specified, then we just extract the scalars
+  // from that shape.
+  if (TypedValue<TensorType> outputShape = getOutputShape()) {
+    // 'tensor.extract' %source [%index]
+    SmallVector<OpFoldResult> extents;
+    for (int64_t i = 0; i < rank; i++) {
+      Value index = b.create<arith::ConstantOp>(getLoc(), b.getIndexAttr(i));
+      extents.push_back(
+          b.create<tensor::ExtractOp>(loc, outputShape, index).getResult());
+    }
+    result.emplace_back(std::move(extents));
+    return success();
+  }
+
+  SmallVector<OpFoldResult> extents;
+  extents.reserve(rank);
+
+  // This number of trailing dimensions are the special (resized) dimensions.
+  const int64_t resizeDims =
+      std::min(static_cast<int64_t>(3), resultType.getRank());
+
+  for (auto [idx, extent] : llvm::enumerate(resultType.getShape())) {
+
+    // If the dimension is known, just materialize the extent as a constant.
+    if (!ShapedType::isDynamic(extent)) {
+      extents.push_back(b.getIndexAttr(extent));
+      continue;
+    }
+
+    // Otherwise the extent is the dynamic sentinel (ShapedType::kDynamic).
+    // Only leading (batch) dimensions may be dynamic; for those we use
+    // `tensor.dim` on the input operand.
+    if (static_cast<int64_t>(idx) >= rank - resizeDims)
+      return failure();
+
+    Value index = b.create<arith::ConstantOp>(loc, b.getIndexAttr(idx));
+    extents.push_back(
+        b.create<tensor::DimOp>(loc, getInput(), index).getResult());
+  }
+  result.emplace_back(std::move(extents));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ResizeLinearOp
 //===----------------------------------------------------------------------===//
@@ -1253,7 +1303,7 @@ LogicalResult tensorrt::ResizeLinearOp::inferReturnTypeComponents(
                            inputType.getRank())
     return emitOptionalError(loc, "scales parameter must have same number of "
                                   "dimensions as input/output");
-  for (int i = 0; i < inputType.getRank() - resizeDims; i++)
+  for (int64_t i = 0; i < inputType.getRank() - resizeDims; i++)
     if (adaptor.getScales().value()[i] != 1)
       return emitOptionalError(
           loc,
@@ -1279,6 +1329,55 @@ LogicalResult tensorrt::ResizeLinearOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tensorrt::ResizeLinearOp::reifyResultShapes(
+    OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
+  Location loc = getLoc();
+  RankedTensorType resultType = getType();
+  int64_t rank = resultType.getRank();
+
+  // Case 1: if `output_shape` is specified, then we just extract the scalars
+  // from that shape.
+  if (TypedValue<TensorType> outputShape = getOutputShape()) {
+    // 'tensor.extract' %source [%index]
+    SmallVector<OpFoldResult> extents;
+    for (int64_t i = 0; i < rank; i++) {
+      Value index = b.create<arith::ConstantOp>(getLoc(), b.getIndexAttr(i));
+      extents.push_back(
+          b.create<tensor::ExtractOp>(loc, outputShape, index).getResult());
+    }
+    result.emplace_back(std::move(extents));
+    return success();
+  }
+
+  SmallVector<OpFoldResult> extents;
+  extents.reserve(rank);
+
+  // This number of trailing dimensions are the special (resized) dimensions.
+  const int64_t resizeDims =
+      std::min(static_cast<int64_t>(3), resultType.getRank());
+
+  for (auto [idx, extent] : llvm::enumerate(resultType.getShape())) {
+
+    // If the dimension is known, just materialize the extent as a constant.
+    if (!ShapedType::isDynamic(extent)) {
+      extents.push_back(b.getIndexAttr(extent));
+      continue;
+    }
+
+    // Otherwise the extent is the dynamic sentinel (ShapedType::kDynamic).
+    // Only leading (batch) dimensions may be dynamic; for those we use
+    // `tensor.dim` on the input operand.
+    if (static_cast<int64_t>(idx) >= rank - resizeDims)
+      return failure();
+
+    Value index = b.create<arith::ConstantOp>(loc, b.getIndexAttr(idx));
+    extents.push_back(
+        b.create<tensor::DimOp>(loc, getInput(), index).getResult());
+  }
+  result.emplace_back(std::move(extents));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ResizeCubicOp
 //===----------------------------------------------------------------------===//
@@ -1298,7 +1397,7 @@ LogicalResult tensorrt::ResizeCubicOp::inferReturnTypeComponents(
                            inputType.getRank())
     return emitOptionalError(loc, "scales parameter must have same number of "
                                   "dimensions as input/output");
-  for (int i = 0; i < inputType.getRank() - 2; i++)
+  for (int64_t i = 0; i < inputType.getRank() - 2; i++)
     if (adaptor.getScales().value()[i] != 1)
       return emitOptionalError(
           loc, "all scale values except 2 innermost must be 1");
@@ -1323,6 +1422,55 @@ LogicalResult tensorrt::ResizeCubicOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tensorrt::ResizeCubicOp::reifyResultShapes(
+    OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
+  Location loc = getLoc();
+  RankedTensorType resultType = getType();
+  int64_t rank = resultType.getRank();
+
+  // Case 1: if `output_shape` is specified, then we just extract the scalars
+  // from that shape.
+  if (TypedValue<TensorType> outputShape = getOutputShape()) {
+    // 'tensor.extract' %source [%index]
+    SmallVector<OpFoldResult> extents;
+    for (int64_t i = 0; i < rank; i++) {
+      Value index = b.create<arith::ConstantOp>(getLoc(), b.getIndexAttr(i));
+      extents.push_back(
+          b.create<tensor::ExtractOp>(loc, outputShape, index).getResult());
+    }
+    result.emplace_back(std::move(extents));
+    return success();
+  }
+
+  SmallVector<OpFoldResult> extents;
+  extents.reserve(rank);
+
+  // This number of trailing dimensions are the special (resized) dimensions.
+  const int64_t resizeDims =
+      std::min(static_cast<int64_t>(3), resultType.getRank());
+
+  for (auto [idx, extent] : llvm::enumerate(resultType.getShape())) {
+
+    // If the dimension is known, just materialize the extent as a constant.
+    if (!ShapedType::isDynamic(extent)) {
+      extents.push_back(b.getIndexAttr(extent));
+      continue;
+    }
+
+    // Otherwise the extent is the dynamic sentinel (ShapedType::kDynamic).
+    // Only leading (batch) dimensions may be dynamic; for those we use
+    // `tensor.dim` on the input operand.
+    if (static_cast<int64_t>(idx) >= rank - resizeDims)
+      return failure();
+
+    Value index = b.create<arith::ConstantOp>(loc, b.getIndexAttr(idx));
+    extents.push_back(
+        b.create<tensor::DimOp>(loc, getInput(), index).getResult());
+  }
+  result.emplace_back(std::move(extents));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//

mlir-tensorrt/tensorrt/tensorrt-opt/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,5 @@
 add_llvm_executable(tensorrt-opt tensorrt-opt.cpp)
-
+get_property(MLIR_TENSORRT_DIALECT_TEST_LIBS GLOBAL PROPERTY MLIR_TENSORRT_DIALECT_TEST_LIBS)
 llvm_update_compile_flags(tensorrt-opt)
 target_link_libraries(tensorrt-opt PRIVATE
   MLIRTensorRTDialect

@@ -10,8 +10,9 @@ target_link_libraries(tensorrt-opt PRIVATE
   MLIROptLib
   MLIRTensorDialect
   MLIRTransforms
-  MLIRTensorRTTestTensorKindAnalysis
   MLIRSCFDialect
+  ${MLIR_TENSORRT_DIALECT_TEST_LIBS}
+  MLIRTensorRTTestTensorKindAnalysis
 )
 
 mlir_check_all_link_libraries(tensorrt-opt)

mlir-tensorrt/tensorrt/tensorrt-opt/tensorrt-opt.cpp

Lines changed: 3 additions & 1 deletion
@@ -35,7 +35,8 @@
 
 namespace mlir {
 void registerTestTensorKindAnalysisPass();
-}
+void registerTestTensorRTShapeInferencePass();
+} // namespace mlir
 
 int main(int argc, char **argv) {
   mlir::DialectRegistry registry;

@@ -44,6 +45,7 @@ int main(int argc, char **argv) {
       mlir::affine::AffineDialect, mlir::quant::QuantizationDialect,
       mlir::scf::SCFDialect>();
   mlir::registerTestTensorKindAnalysisPass();
+  mlir::registerTestTensorRTShapeInferencePass();
   mlir::func::registerInlinerExtension(registry);
   mlir::tensorrt::registerTensorRTTranslationCLOpts();
   mlir::tensorrt::registerTensorRTPasses();
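The `-test-tensorrt-shape-inference` pass registered above is implemented in TestTypeInferencePass.cpp, whose contents are not shown in this diff view. As a rough sketch only (the struct name and pass body below are assumptions, not the commit's code), such a test pass is conventionally a thin wrapper around the upstream patterns that resolve `tensor.dim` through `ReifyRankedShapedTypeOpInterface`:

// Hypothetical sketch, not the file from this commit: a minimal test pass
// that exercises reifyResultShapes by resolving `tensor.dim` on op results.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
struct TestShapeInferencePass
    : public PassWrapper<TestShapeInferencePass, OperationPass<func::FuncOp>> {
  StringRef getArgument() const final {
    return "test-tensorrt-shape-inference";
  }
  StringRef getDescription() const final {
    return "Resolve tensor.dim ops via ReifyRankedShapedTypeOpInterface";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    // Upstream patterns that rewrite `tensor.dim` of an op result into the
    // shape reified by the defining op's interface implementation.
    memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

namespace mlir {
void registerTestTensorRTShapeInferencePass() {
  PassRegistration<TestShapeInferencePass>();
}
} // namespace mlir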
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+// RUN: tensorrt-opt %s -split-input-file -test-tensorrt-shape-inference | FileCheck %s
+
+func.func @test_resize_linear(%arg0: tensor<10x10xf32>) -> (index, index) {
+  %result = tensorrt.resize_linear {
+    coordinateTransformation = #tensorrt.resize_coordinate_transformation<kALIGN_CORNERS>,
+    selectorForSinglePixel = #tensorrt.resize_selector<kUPPER>
+  } %arg0 : (tensor<10x10xf32>) -> tensor<20x20xf32>
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %result, %c0 : tensor<20x20xf32>
+  %d1 = tensor.dim %result, %c1 : tensor<20x20xf32>
+  return %d0, %d1 : index, index
+}
+
+// CHECK-LABEL: test_resize_linear
+// CHECK-NEXT: %[[c20:.+]] = arith.constant 20 : index
+// CHECK-NEXT: return %[[c20]], %[[c20]] : index, index
+
+// -----
+
+func.func @test_resize_dynamic_batch(%arg0: tensor<?x1x10x10xf32>) -> (index, index, index, index) {
+  %result = tensorrt.resize_linear {
+    coordinateTransformation = #tensorrt.resize_coordinate_transformation<kALIGN_CORNERS>,
+    selectorForSinglePixel = #tensorrt.resize_selector<kUPPER>
+  } %arg0 : (tensor<?x1x10x10xf32>) -> tensor<?x1x20x20xf32>
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %d0 = tensor.dim %result, %c0 : tensor<?x1x20x20xf32>
+  %d1 = tensor.dim %result, %c1 : tensor<?x1x20x20xf32>
+  %d2 = tensor.dim %result, %c2 : tensor<?x1x20x20xf32>
+  %d3 = tensor.dim %result, %c3 : tensor<?x1x20x20xf32>
+  return %d0, %d1, %d2, %d3 : index, index, index, index
+}
+
+// CHECK-LABEL: func.func @test_resize_dynamic_batch
+// CHECK-SAME: (%[[arg0:.+]]: tensor<?x1x10x10xf32>)
+// CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c0]] : tensor<?x1x10x10xf32>
+// CHECK: return %[[dim]], %[[c1]], %[[c20]], %[[c20]]
Lines changed: 7 additions & 0 deletions
@@ -1 +1,8 @@
+function(add_mlir_tensorrt_dialect_test_library target)
+  add_mlir_library(${target} ${ARGN}
+    EXCLUDE_FROM_LIBMLIR)
+  set_property(GLOBAL APPEND PROPERTY MLIR_TENSORRT_DIALECT_TEST_LIBS ${target})
+endfunction()
+
 add_subdirectory(Target)
+add_subdirectory(TensorRT)
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+add_mlir_tensorrt_dialect_test_library(MLIRTensorRTTestTypeInferencePass
+  TestTypeInferencePass.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRTensorRTDialect
+  MLIRTensorDialect
+  MLIRArithDialect
+  MLIRPass
+  MLIRTransformUtils
+)
