2424#include " EinsumHelper.h"
2525#include " mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
2626#include " mlir-tensorrt-dialect/Utils/ShapeUtils.h"
27+ #include " mlir/Dialect/Arith/IR/Arith.h"
2728#include " mlir/Dialect/Tensor/IR/Tensor.h"
2829#include " mlir/Dialect/Utils/ReshapeOpsUtils.h"
2930#include " mlir/IR/Builders.h"
@@ -1211,7 +1212,7 @@ LogicalResult tensorrt::ResizeNearestOp::inferReturnTypeComponents(
12111212 inputType.getRank ())
12121213 return emitOptionalError (loc, " scales parameter must have same number of "
12131214 " dimensions as input/output" );
1214- for (int i = 0 ; i < inputType.getRank () - resizeDims; i++)
1215+ for (int64_t i = 0 ; i < inputType.getRank () - resizeDims; i++)
12151216 if (adaptor.getScales ().value ()[i] != 1 )
12161217 return emitOptionalError (
12171218 loc,
@@ -1236,6 +1237,56 @@ LogicalResult tensorrt::ResizeNearestOp::inferReturnTypeComponents(
12361237 return success ();
12371238}
12381239
1240+ LogicalResult tensorrt::ResizeNearestOp::reifyResultShapes (
1241+ OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
1242+ Location loc = getLoc ();
1243+ RankedTensorType resultType = getType ();
1244+ int64_t rank = resultType.getRank ();
1245+
1246+ // Case 1: if `output_shape` is specified, then we just extract the scalars
1247+ // from that shape.
1248+ if (TypedValue<TensorType> outputShape = getOutputShape ()) {
1249+ // 'tensor.extract' %source [%index]
1250+ SmallVector<OpFoldResult> extents;
1251+ for (int64_t i = 0 ; i < rank; i++) {
1252+ Value index = b.create <arith::ConstantOp>(getLoc (), b.getIndexAttr (i));
1253+ Value extractedShape = b.create <tensor::ExtractOp>(loc, outputShape, index).getResult ();
1254+ extents.push_back (
1255+ b.create <arith::IndexCastOp>(loc, b.getIndexType (), extractedShape).getResult ());
1256+ }
1257+ result.emplace_back (std::move (extents));
1258+ return success ();
1259+ }
1260+
1261+ SmallVector<OpFoldResult> extents;
1262+ extents.reserve (rank);
1263+
1264+ // This number of trailing dimensions are the special dimensions.
1265+ const int64_t resizeDims =
1266+ std::min (static_cast <int64_t >(3 ), resultType.getRank ());
1267+
1268+ for (auto [idx, extent] : llvm::enumerate (resultType.getShape ())) {
1269+
1270+ // If dimension is known, just materialize the extent as constant.
1271+ if (!ShapedType::isDynamic (extent)) {
1272+ extents.push_back (b.getIndexAttr (extent));
1273+ continue ;
1274+ }
1275+
1276+ // Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
1277+ // then we use `tensor.dim` on the input operand.
1278+ // Batch dimensions can only be leading dim.
1279+ if (static_cast <int64_t >(idx) >= rank - resizeDims)
1280+ return failure ();
1281+
1282+ Value index = b.create <arith::ConstantOp>(loc, b.getIndexAttr (idx));
1283+ extents.push_back (
1284+ b.create <tensor::DimOp>(loc, getInput (), index).getResult ());
1285+ }
1286+ result.emplace_back (std::move (extents));
1287+ return success ();
1288+ }
1289+
12391290// ===----------------------------------------------------------------------===//
12401291// ResizeLinearOp
12411292// ===----------------------------------------------------------------------===//
@@ -1253,7 +1304,7 @@ LogicalResult tensorrt::ResizeLinearOp::inferReturnTypeComponents(
12531304 inputType.getRank ())
12541305 return emitOptionalError (loc, " scales parameter must have same number of "
12551306 " dimensions as input/output" );
1256- for (int i = 0 ; i < inputType.getRank () - resizeDims; i++)
1307+ for (int64_t i = 0 ; i < inputType.getRank () - resizeDims; i++)
12571308 if (adaptor.getScales ().value ()[i] != 1 )
12581309 return emitOptionalError (
12591310 loc,
@@ -1279,6 +1330,56 @@ LogicalResult tensorrt::ResizeLinearOp::inferReturnTypeComponents(
12791330 return success ();
12801331}
12811332
1333+ LogicalResult tensorrt::ResizeLinearOp::reifyResultShapes (
1334+ OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
1335+ Location loc = getLoc ();
1336+ RankedTensorType resultType = getType ();
1337+ int64_t rank = resultType.getRank ();
1338+
1339+ // Case 1: if `output_shape` is specified, then we just extract the scalars
1340+ // from that shape.
1341+ if (TypedValue<TensorType> outputShape = getOutputShape ()) {
1342+ // 'tensor.extract' %source [%index]
1343+ SmallVector<OpFoldResult> extents;
1344+ for (int64_t i = 0 ; i < rank; i++) {
1345+ Value index = b.create <arith::ConstantOp>(getLoc (), b.getIndexAttr (i));
1346+ Value extractedShape = b.create <tensor::ExtractOp>(loc, outputShape, index).getResult ();
1347+ extents.push_back (
1348+ b.create <arith::IndexCastOp>(loc, b.getIndexType (), extractedShape).getResult ());
1349+ }
1350+ result.emplace_back (std::move (extents));
1351+ return success ();
1352+ }
1353+
1354+ SmallVector<OpFoldResult> extents;
1355+ extents.reserve (rank);
1356+
1357+ // This number of trailing dimensions are the special dimensions.
1358+ const int64_t resizeDims =
1359+ std::min (static_cast <int64_t >(3 ), resultType.getRank ());
1360+
1361+ for (auto [idx, extent] : llvm::enumerate (resultType.getShape ())) {
1362+
1363+ // If dimension is known, just materialize the extent as constant.
1364+ if (!ShapedType::isDynamic (extent)) {
1365+ extents.push_back (b.getIndexAttr (extent));
1366+ continue ;
1367+ }
1368+
1369+ // Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
1370+ // then we use `tensor.dim` on the input operand.
1371+ // Batch dimensions can only be leading dim.
1372+ if (static_cast <int64_t >(idx) >= rank - resizeDims)
1373+ return failure ();
1374+
1375+ Value index = b.create <arith::ConstantOp>(loc, b.getIndexAttr (idx));
1376+ extents.push_back (
1377+ b.create <tensor::DimOp>(loc, getInput (), index).getResult ());
1378+ }
1379+ result.emplace_back (std::move (extents));
1380+ return success ();
1381+ }
1382+
12821383// ===----------------------------------------------------------------------===//
12831384// ResizeCubicOp
12841385// ===----------------------------------------------------------------------===//
@@ -1298,7 +1399,7 @@ LogicalResult tensorrt::ResizeCubicOp::inferReturnTypeComponents(
12981399 inputType.getRank ())
12991400 return emitOptionalError (loc, " scales parameter must have same number of "
13001401 " dimensions as input/output" );
1301- for (int i = 0 ; i < inputType.getRank () - 2 ; i++)
1402+ for (int64_t i = 0 ; i < inputType.getRank () - 2 ; i++)
13021403 if (adaptor.getScales ().value ()[i] != 1 )
13031404 return emitOptionalError (
13041405 loc, " all scale values except 2 innermost must be 1" );
@@ -1323,6 +1424,56 @@ LogicalResult tensorrt::ResizeCubicOp::inferReturnTypeComponents(
13231424 return success ();
13241425}
13251426
1427+ LogicalResult tensorrt::ResizeCubicOp::reifyResultShapes (
1428+ OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
1429+ Location loc = getLoc ();
1430+ RankedTensorType resultType = getType ();
1431+ int64_t rank = resultType.getRank ();
1432+
1433+ // Case 1: if `output_shape` is specified, then we just extract the scalars
1434+ // from that shape.
1435+ if (TypedValue<TensorType> outputShape = getOutputShape ()) {
1436+ // 'tensor.extract' %source [%index]
1437+ SmallVector<OpFoldResult> extents;
1438+ for (int64_t i = 0 ; i < rank; i++) {
1439+ Value index = b.create <arith::ConstantOp>(getLoc (), b.getIndexAttr (i));
1440+ Value extractedShape = b.create <tensor::ExtractOp>(loc, outputShape, index).getResult ();
1441+ extents.push_back (
1442+ b.create <arith::IndexCastOp>(loc, b.getIndexType (), extractedShape).getResult ());
1443+ }
1444+ result.emplace_back (std::move (extents));
1445+ return success ();
1446+ }
1447+
1448+ SmallVector<OpFoldResult> extents;
1449+ extents.reserve (rank);
1450+
1451+ // This number of trailing dimensions are the special dimensions.
1452+ const int64_t resizeDims =
1453+ std::min (static_cast <int64_t >(3 ), resultType.getRank ());
1454+
1455+ for (auto [idx, extent] : llvm::enumerate (resultType.getShape ())) {
1456+
1457+ // If dimension is known, just materialize the extent as constant.
1458+ if (!ShapedType::isDynamic (extent)) {
1459+ extents.push_back (b.getIndexAttr (extent));
1460+ continue ;
1461+ }
1462+
1463+ // Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
1464+ // then we use `tensor.dim` on the input operand.
1465+ // Batch dimensions can only be leading dim.
1466+ if (static_cast <int64_t >(idx) >= rank - resizeDims)
1467+ return failure ();
1468+
1469+ Value index = b.create <arith::ConstantOp>(loc, b.getIndexAttr (idx));
1470+ extents.push_back (
1471+ b.create <tensor::DimOp>(loc, getInput (), index).getResult ());
1472+ }
1473+ result.emplace_back (std::move (extents));
1474+ return success ();
1475+ }
1476+
13261477// ===----------------------------------------------------------------------===//
13271478// ScatterOp
13281479// ===----------------------------------------------------------------------===//
0 commit comments