2424#include " EinsumHelper.h"
2525#include " mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
2626#include " mlir-tensorrt-dialect/Utils/ShapeUtils.h"
27+ #include " mlir/Dialect/Arith/IR/Arith.h"
2728#include " mlir/Dialect/Tensor/IR/Tensor.h"
2829#include " mlir/Dialect/Utils/ReshapeOpsUtils.h"
2930#include " mlir/IR/Builders.h"
@@ -1211,7 +1212,7 @@ LogicalResult tensorrt::ResizeNearestOp::inferReturnTypeComponents(
12111212 inputType.getRank ())
12121213 return emitOptionalError (loc, " scales parameter must have same number of "
12131214 " dimensions as input/output" );
1214- for (int i = 0 ; i < inputType.getRank () - resizeDims; i++)
1215+ for (int64_t i = 0 ; i < inputType.getRank () - resizeDims; i++)
12151216 if (adaptor.getScales ().value ()[i] != 1 )
12161217 return emitOptionalError (
12171218 loc,
@@ -1236,6 +1237,55 @@ LogicalResult tensorrt::ResizeNearestOp::inferReturnTypeComponents(
12361237 return success ();
12371238}
12381239
1240+ LogicalResult tensorrt::ResizeNearestOp::reifyResultShapes (
1241+ OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
1242+ Location loc = getLoc ();
1243+ RankedTensorType resultType = getType ();
1244+ int64_t rank = resultType.getRank ();
1245+
1246+ // Case 1: if `output_shape` is specified, then we just extract the scalars
1247+ // from that shape.
1248+ if (TypedValue<TensorType> outputShape = getOutputShape ()) {
1249+ // 'tensor.extract' %source [%index]
1250+ SmallVector<OpFoldResult> extents;
1251+ for (int64_t i = 0 ; i < rank; i++) {
1252+ Value index = b.create <arith::ConstantOp>(getLoc (), b.getIndexAttr (i));
1253+ extents.push_back (
1254+ b.create <tensor::ExtractOp>(loc, outputShape, index).getResult ());
1255+ }
1256+ result.emplace_back (std::move (extents));
1257+ return success ();
1258+ }
1259+
1260+ SmallVector<OpFoldResult> extents;
1261+ extents.reserve (rank);
1262+
1263+ // This number of trailing dimensions are the special dimensions.
1264+ const int64_t resizeDims =
1265+ std::min (static_cast <int64_t >(3 ), resultType.getRank ());
1266+
1267+ for (auto [idx, extent] : llvm::enumerate (resultType.getShape ())) {
1268+
1269+ // If dimension is known, just materialize the extent as constant.
1270+ if (!ShapedType::isDynamic (extent)) {
1271+ extents.push_back (b.getIndexAttr (extent));
1272+ continue ;
1273+ }
1274+
1275+ // Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
1276+ // then we use `tensor.dim` on the input operand.
1277+ // Batch dimensions can only be leading dim.
1278+ if (static_cast <int64_t >(idx) >= rank - resizeDims)
1279+ return failure ();
1280+
1281+ Value index = b.create <arith::ConstantOp>(loc, b.getIndexAttr (idx));
1282+ extents.push_back (
1283+ b.create <tensor::DimOp>(loc, getInput (), index).getResult ());
1284+ }
1285+ result.emplace_back (std::move (extents));
1286+ return success ();
1287+ }
1288+
12391289// ===----------------------------------------------------------------------===//
12401290// ResizeLinearOp
12411291// ===----------------------------------------------------------------------===//
@@ -1253,7 +1303,7 @@ LogicalResult tensorrt::ResizeLinearOp::inferReturnTypeComponents(
12531303 inputType.getRank ())
12541304 return emitOptionalError (loc, " scales parameter must have same number of "
12551305 " dimensions as input/output" );
1256- for (int i = 0 ; i < inputType.getRank () - resizeDims; i++)
1306+ for (int64_t i = 0 ; i < inputType.getRank () - resizeDims; i++)
12571307 if (adaptor.getScales ().value ()[i] != 1 )
12581308 return emitOptionalError (
12591309 loc,
@@ -1279,6 +1329,55 @@ LogicalResult tensorrt::ResizeLinearOp::inferReturnTypeComponents(
12791329 return success ();
12801330}
12811331
1332+ LogicalResult tensorrt::ResizeLinearOp::reifyResultShapes (
1333+ OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
1334+ Location loc = getLoc ();
1335+ RankedTensorType resultType = getType ();
1336+ int64_t rank = resultType.getRank ();
1337+
1338+ // Case 1: if `output_shape` is specified, then we just extract the scalars
1339+ // from that shape.
1340+ if (TypedValue<TensorType> outputShape = getOutputShape ()) {
1341+ // 'tensor.extract' %source [%index]
1342+ SmallVector<OpFoldResult> extents;
1343+ for (int64_t i = 0 ; i < rank; i++) {
1344+ Value index = b.create <arith::ConstantOp>(getLoc (), b.getIndexAttr (i));
1345+ extents.push_back (
1346+ b.create <tensor::ExtractOp>(loc, outputShape, index).getResult ());
1347+ }
1348+ result.emplace_back (std::move (extents));
1349+ return success ();
1350+ }
1351+
1352+ SmallVector<OpFoldResult> extents;
1353+ extents.reserve (rank);
1354+
1355+ // This number of trailing dimensions are the special dimensions.
1356+ const int64_t resizeDims =
1357+ std::min (static_cast <int64_t >(3 ), resultType.getRank ());
1358+
1359+ for (auto [idx, extent] : llvm::enumerate (resultType.getShape ())) {
1360+
1361+ // If dimension is known, just materialize the extent as constant.
1362+ if (!ShapedType::isDynamic (extent)) {
1363+ extents.push_back (b.getIndexAttr (extent));
1364+ continue ;
1365+ }
1366+
1367+ // Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
1368+ // then we use `tensor.dim` on the input operand.
1369+ // Batch dimensions can only be leading dim.
1370+ if (static_cast <int64_t >(idx) >= rank - resizeDims)
1371+ return failure ();
1372+
1373+ Value index = b.create <arith::ConstantOp>(loc, b.getIndexAttr (idx));
1374+ extents.push_back (
1375+ b.create <tensor::DimOp>(loc, getInput (), index).getResult ());
1376+ }
1377+ result.emplace_back (std::move (extents));
1378+ return success ();
1379+ }
1380+
12821381// ===----------------------------------------------------------------------===//
12831382// ResizeCubicOp
12841383// ===----------------------------------------------------------------------===//
@@ -1298,7 +1397,7 @@ LogicalResult tensorrt::ResizeCubicOp::inferReturnTypeComponents(
12981397 inputType.getRank ())
12991398 return emitOptionalError (loc, " scales parameter must have same number of "
13001399 " dimensions as input/output" );
1301- for (int i = 0 ; i < inputType.getRank () - 2 ; i++)
1400+ for (int64_t i = 0 ; i < inputType.getRank () - 2 ; i++)
13021401 if (adaptor.getScales ().value ()[i] != 1 )
13031402 return emitOptionalError (
13041403 loc, " all scale values except 2 innermost must be 1" );
@@ -1323,6 +1422,55 @@ LogicalResult tensorrt::ResizeCubicOp::inferReturnTypeComponents(
13231422 return success ();
13241423}
13251424
1425+ LogicalResult tensorrt::ResizeCubicOp::reifyResultShapes (
1426+ OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
1427+ Location loc = getLoc ();
1428+ RankedTensorType resultType = getType ();
1429+ int64_t rank = resultType.getRank ();
1430+
1431+ // Case 1: if `output_shape` is specified, then we just extract the scalars
1432+ // from that shape.
1433+ if (TypedValue<TensorType> outputShape = getOutputShape ()) {
1434+ // 'tensor.extract' %source [%index]
1435+ SmallVector<OpFoldResult> extents;
1436+ for (int64_t i = 0 ; i < rank; i++) {
1437+ Value index = b.create <arith::ConstantOp>(getLoc (), b.getIndexAttr (i));
1438+ extents.push_back (
1439+ b.create <tensor::ExtractOp>(loc, outputShape, index).getResult ());
1440+ }
1441+ result.emplace_back (std::move (extents));
1442+ return success ();
1443+ }
1444+
1445+ SmallVector<OpFoldResult> extents;
1446+ extents.reserve (rank);
1447+
1448+ // This number of trailing dimensions are the special dimensions.
1449+ const int64_t resizeDims =
1450+ std::min (static_cast <int64_t >(3 ), resultType.getRank ());
1451+
1452+ for (auto [idx, extent] : llvm::enumerate (resultType.getShape ())) {
1453+
1454+ // If dimension is known, just materialize the extent as constant.
1455+ if (!ShapedType::isDynamic (extent)) {
1456+ extents.push_back (b.getIndexAttr (extent));
1457+ continue ;
1458+ }
1459+
1460+ // Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
1461+ // then we use `tensor.dim` on the input operand.
1462+ // Batch dimensions can only be leading dim.
1463+ if (static_cast <int64_t >(idx) >= rank - resizeDims)
1464+ return failure ();
1465+
1466+ Value index = b.create <arith::ConstantOp>(loc, b.getIndexAttr (idx));
1467+ extents.push_back (
1468+ b.create <tensor::DimOp>(loc, getInput (), index).getResult ());
1469+ }
1470+ result.emplace_back (std::move (extents));
1471+ return success ();
1472+ }
1473+
13261474// ===----------------------------------------------------------------------===//
13271475// ScatterOp
13281476// ===----------------------------------------------------------------------===//
0 commit comments