2424#include " EinsumHelper.h"
2525#include " mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
2626#include " mlir-tensorrt-dialect/Utils/ShapeUtils.h"
27+ #include " mlir/Dialect/Arith/IR/Arith.h"
2728#include " mlir/Dialect/Tensor/IR/Tensor.h"
2829#include " mlir/Dialect/Utils/ReshapeOpsUtils.h"
2930#include " mlir/IR/Builders.h"
@@ -1211,7 +1212,7 @@ LogicalResult tensorrt::ResizeNearestOp::inferReturnTypeComponents(
12111212 inputType.getRank ())
12121213 return emitOptionalError (loc, " scales parameter must have same number of "
12131214 " dimensions as input/output" );
1214- for (int i = 0 ; i < inputType.getRank () - resizeDims; i++)
1215+ for (int64_t i = 0 ; i < inputType.getRank () - resizeDims; i++)
12151216 if (adaptor.getScales ().value ()[i] != 1 )
12161217 return emitOptionalError (
12171218 loc,
@@ -1236,6 +1237,58 @@ LogicalResult tensorrt::ResizeNearestOp::inferReturnTypeComponents(
12361237 return success ();
12371238}
12381239
1240+ LogicalResult tensorrt::ResizeNearestOp::reifyResultShapes (
1241+ OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
1242+ Location loc = getLoc ();
1243+ RankedTensorType resultType = getType ();
1244+ int64_t rank = resultType.getRank ();
1245+
1246+ // Case 1: if `output_shape` is specified, then we just extract the scalars
1247+ // from that shape.
1248+ if (TypedValue<TensorType> outputShape = getOutputShape ()) {
1249+ // 'tensor.extract' %source [%index]
1250+ SmallVector<OpFoldResult> extents;
1251+ for (int64_t i = 0 ; i < rank; i++) {
1252+ Value index = b.create <arith::ConstantOp>(getLoc (), b.getIndexAttr (i));
1253+ Value extractedShape =
1254+ b.create <tensor::ExtractOp>(loc, outputShape, index).getResult ();
1255+ extents.push_back (
1256+ b.create <arith::IndexCastOp>(loc, b.getIndexType (), extractedShape)
1257+ .getResult ());
1258+ }
1259+ result.emplace_back (std::move (extents));
1260+ return success ();
1261+ }
1262+
1263+ SmallVector<OpFoldResult> extents;
1264+ extents.reserve (rank);
1265+
1266+ // This number of trailing dimensions are the special dimensions.
1267+ const int64_t resizeDims =
1268+ std::min (static_cast <int64_t >(3 ), resultType.getRank ());
1269+
1270+ for (auto [idx, extent] : llvm::enumerate (resultType.getShape ())) {
1271+
1272+ // If dimension is known, just materialize the extent as constant.
1273+ if (!ShapedType::isDynamic (extent)) {
1274+ extents.push_back (b.getIndexAttr (extent));
1275+ continue ;
1276+ }
1277+
1278+ // Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
1279+ // then we use `tensor.dim` on the input operand.
1280+ // Batch dimensions can only be leading dim.
1281+ if (static_cast <int64_t >(idx) >= rank - resizeDims)
1282+ return failure ();
1283+
1284+ Value index = b.create <arith::ConstantOp>(loc, b.getIndexAttr (idx));
1285+ extents.push_back (
1286+ b.create <tensor::DimOp>(loc, getInput (), index).getResult ());
1287+ }
1288+ result.emplace_back (std::move (extents));
1289+ return success ();
1290+ }
1291+
12391292// ===----------------------------------------------------------------------===//
12401293// ResizeLinearOp
12411294// ===----------------------------------------------------------------------===//
@@ -1253,7 +1306,7 @@ LogicalResult tensorrt::ResizeLinearOp::inferReturnTypeComponents(
12531306 inputType.getRank ())
12541307 return emitOptionalError (loc, " scales parameter must have same number of "
12551308 " dimensions as input/output" );
1256- for (int i = 0 ; i < inputType.getRank () - resizeDims; i++)
1309+ for (int64_t i = 0 ; i < inputType.getRank () - resizeDims; i++)
12571310 if (adaptor.getScales ().value ()[i] != 1 )
12581311 return emitOptionalError (
12591312 loc,
@@ -1279,6 +1332,58 @@ LogicalResult tensorrt::ResizeLinearOp::inferReturnTypeComponents(
12791332 return success ();
12801333}
12811334
1335+ LogicalResult tensorrt::ResizeLinearOp::reifyResultShapes (
1336+ OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
1337+ Location loc = getLoc ();
1338+ RankedTensorType resultType = getType ();
1339+ int64_t rank = resultType.getRank ();
1340+
1341+ // Case 1: if `output_shape` is specified, then we just extract the scalars
1342+ // from that shape.
1343+ if (TypedValue<TensorType> outputShape = getOutputShape ()) {
1344+ // 'tensor.extract' %source [%index]
1345+ SmallVector<OpFoldResult> extents;
1346+ for (int64_t i = 0 ; i < rank; i++) {
1347+ Value index = b.create <arith::ConstantOp>(getLoc (), b.getIndexAttr (i));
1348+ Value extractedShape =
1349+ b.create <tensor::ExtractOp>(loc, outputShape, index).getResult ();
1350+ extents.push_back (
1351+ b.create <arith::IndexCastOp>(loc, b.getIndexType (), extractedShape)
1352+ .getResult ());
1353+ }
1354+ result.emplace_back (std::move (extents));
1355+ return success ();
1356+ }
1357+
1358+ SmallVector<OpFoldResult> extents;
1359+ extents.reserve (rank);
1360+
1361+ // This number of trailing dimensions are the special dimensions.
1362+ const int64_t resizeDims =
1363+ std::min (static_cast <int64_t >(3 ), resultType.getRank ());
1364+
1365+ for (auto [idx, extent] : llvm::enumerate (resultType.getShape ())) {
1366+
1367+ // If dimension is known, just materialize the extent as constant.
1368+ if (!ShapedType::isDynamic (extent)) {
1369+ extents.push_back (b.getIndexAttr (extent));
1370+ continue ;
1371+ }
1372+
1373+ // Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
1374+ // then we use `tensor.dim` on the input operand.
1375+ // Batch dimensions can only be leading dim.
1376+ if (static_cast <int64_t >(idx) >= rank - resizeDims)
1377+ return failure ();
1378+
1379+ Value index = b.create <arith::ConstantOp>(loc, b.getIndexAttr (idx));
1380+ extents.push_back (
1381+ b.create <tensor::DimOp>(loc, getInput (), index).getResult ());
1382+ }
1383+ result.emplace_back (std::move (extents));
1384+ return success ();
1385+ }
1386+
12821387// ===----------------------------------------------------------------------===//
12831388// ResizeCubicOp
12841389// ===----------------------------------------------------------------------===//
@@ -1298,7 +1403,7 @@ LogicalResult tensorrt::ResizeCubicOp::inferReturnTypeComponents(
12981403 inputType.getRank ())
12991404 return emitOptionalError (loc, " scales parameter must have same number of "
13001405 " dimensions as input/output" );
1301- for (int i = 0 ; i < inputType.getRank () - 2 ; i++)
1406+ for (int64_t i = 0 ; i < inputType.getRank () - 2 ; i++)
13021407 if (adaptor.getScales ().value ()[i] != 1 )
13031408 return emitOptionalError (
13041409 loc, " all scale values except 2 innermost must be 1" );
@@ -1323,6 +1428,58 @@ LogicalResult tensorrt::ResizeCubicOp::inferReturnTypeComponents(
13231428 return success ();
13241429}
13251430
1431+ LogicalResult tensorrt::ResizeCubicOp::reifyResultShapes (
1432+ OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
1433+ Location loc = getLoc ();
1434+ RankedTensorType resultType = getType ();
1435+ int64_t rank = resultType.getRank ();
1436+
1437+ // Case 1: if `output_shape` is specified, then we just extract the scalars
1438+ // from that shape.
1439+ if (TypedValue<TensorType> outputShape = getOutputShape ()) {
1440+ // 'tensor.extract' %source [%index]
1441+ SmallVector<OpFoldResult> extents;
1442+ for (int64_t i = 0 ; i < rank; i++) {
1443+ Value index = b.create <arith::ConstantOp>(getLoc (), b.getIndexAttr (i));
1444+ Value extractedShape =
1445+ b.create <tensor::ExtractOp>(loc, outputShape, index).getResult ();
1446+ extents.push_back (
1447+ b.create <arith::IndexCastOp>(loc, b.getIndexType (), extractedShape)
1448+ .getResult ());
1449+ }
1450+ result.emplace_back (std::move (extents));
1451+ return success ();
1452+ }
1453+
1454+ SmallVector<OpFoldResult> extents;
1455+ extents.reserve (rank);
1456+
1457+ // This number of trailing dimensions are the special dimensions.
1458+ const int64_t resizeDims =
1459+ std::min (static_cast <int64_t >(3 ), resultType.getRank ());
1460+
1461+ for (auto [idx, extent] : llvm::enumerate (resultType.getShape ())) {
1462+
1463+ // If dimension is known, just materialize the extent as constant.
1464+ if (!ShapedType::isDynamic (extent)) {
1465+ extents.push_back (b.getIndexAttr (extent));
1466+ continue ;
1467+ }
1468+
1469+ // Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
1470+ // then we use `tensor.dim` on the input operand.
1471+ // Batch dimensions can only be leading dim.
1472+ if (static_cast <int64_t >(idx) >= rank - resizeDims)
1473+ return failure ();
1474+
1475+ Value index = b.create <arith::ConstantOp>(loc, b.getIndexAttr (idx));
1476+ extents.push_back (
1477+ b.create <tensor::DimOp>(loc, getInput (), index).getResult ());
1478+ }
1479+ result.emplace_back (std::move (extents));
1480+ return success ();
1481+ }
1482+
13261483// ===----------------------------------------------------------------------===//
13271484// ScatterOp
13281485// ===----------------------------------------------------------------------===//
0 commit comments