Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir-tensorrt/tensorrt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,5 @@ include_directories(${MLIR_TENSORRT_DIALECT_BINARY_DIR}/include)

add_subdirectory(include/mlir-tensorrt-dialect)
add_subdirectory(lib)
add_subdirectory(tensorrt-opt)
add_subdirectory(test)
add_subdirectory(tensorrt-opt)
Original file line number Diff line number Diff line change
Expand Up @@ -3463,7 +3463,8 @@ def TensorRT_ReshapeOp : TensorRT_Op<"reshape",
//===----------------------------------------------------------------------===//

def TensorRT_ResizeNearestOp : TensorRT_Op<"resize_nearest", [Pure,
TensorRTPartiallyInferTensorResultTypes]>{
TensorRTPartiallyInferTensorResultTypes,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>{
let summary = "TensorRT Resize(IResizeLayer with NEAREST mode) operation";
let description = [{

Expand Down Expand Up @@ -3541,7 +3542,8 @@ def TensorRT_ResizeNearestOp : TensorRT_Op<"resize_nearest", [Pure,
//===----------------------------------------------------------------------===//

def TensorRT_ResizeLinearOp : TensorRT_Op<"resize_linear", [Pure,
TensorRTPartiallyInferTensorResultTypes]>{
TensorRTPartiallyInferTensorResultTypes,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>{
let summary = "TensorRT Resize(IResizeLayer with LINEAR mode) operation";
let description = [{

Expand Down Expand Up @@ -3608,7 +3610,8 @@ def TensorRT_ResizeLinearOp : TensorRT_Op<"resize_linear", [Pure,
//===----------------------------------------------------------------------===//

def TensorRT_ResizeCubicOp : TensorRT_Op<"resize_cubic", [Pure,
TensorRTPartiallyInferTensorResultTypes]>{
TensorRTPartiallyInferTensorResultTypes,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>{
let summary = "TensorRT Resize(IResizeLayer with CUBIC mode) operation";
let description = [{

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "EinsumHelper.h"
#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
#include "mlir-tensorrt-dialect/Utils/ShapeUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Builders.h"
Expand Down Expand Up @@ -1211,7 +1212,7 @@ LogicalResult tensorrt::ResizeNearestOp::inferReturnTypeComponents(
inputType.getRank())
return emitOptionalError(loc, "scales parameter must have same number of "
"dimensions as input/output");
for (int i = 0; i < inputType.getRank() - resizeDims; i++)
for (int64_t i = 0; i < inputType.getRank() - resizeDims; i++)
if (adaptor.getScales().value()[i] != 1)
return emitOptionalError(
loc,
Expand All @@ -1236,6 +1237,58 @@ LogicalResult tensorrt::ResizeNearestOp::inferReturnTypeComponents(
return success();
}

LogicalResult tensorrt::ResizeNearestOp::reifyResultShapes(
OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
Location loc = getLoc();
RankedTensorType resultType = getType();
int64_t rank = resultType.getRank();

// Case 1: if `output_shape` is specified, then we just extract the scalars
// from that shape.
if (TypedValue<TensorType> outputShape = getOutputShape()) {
// 'tensor.extract' %source [%index]
SmallVector<OpFoldResult> extents;
for (int64_t i = 0; i < rank; i++) {
Value index = b.create<arith::ConstantOp>(getLoc(), b.getIndexAttr(i));
Value extractedShape =
b.create<tensor::ExtractOp>(loc, outputShape, index).getResult();
extents.push_back(
b.create<arith::IndexCastOp>(loc, b.getIndexType(), extractedShape)
.getResult());
}
result.emplace_back(std::move(extents));
return success();
}

SmallVector<OpFoldResult> extents;
extents.reserve(rank);

// This number of trailing dimensions are the special dimensions.
const int64_t resizeDims =
std::min(static_cast<int64_t>(3), resultType.getRank());

for (auto [idx, extent] : llvm::enumerate(resultType.getShape())) {

// If dimension is known, just materialize the extent as constant.
if (!ShapedType::isDynamic(extent)) {
extents.push_back(b.getIndexAttr(extent));
continue;
}

// Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
// then we use `tensor.dim` on the input operand.
// Batch dimensions can only be leading dim.
if (static_cast<int64_t>(idx) >= rank - resizeDims)
return failure();

Value index = b.create<arith::ConstantOp>(loc, b.getIndexAttr(idx));
extents.push_back(
b.create<tensor::DimOp>(loc, getInput(), index).getResult());
}
result.emplace_back(std::move(extents));
return success();
}

//===----------------------------------------------------------------------===//
// ResizeLinearOp
//===----------------------------------------------------------------------===//
Expand All @@ -1253,7 +1306,7 @@ LogicalResult tensorrt::ResizeLinearOp::inferReturnTypeComponents(
inputType.getRank())
return emitOptionalError(loc, "scales parameter must have same number of "
"dimensions as input/output");
for (int i = 0; i < inputType.getRank() - resizeDims; i++)
for (int64_t i = 0; i < inputType.getRank() - resizeDims; i++)
if (adaptor.getScales().value()[i] != 1)
return emitOptionalError(
loc,
Expand All @@ -1279,6 +1332,58 @@ LogicalResult tensorrt::ResizeLinearOp::inferReturnTypeComponents(
return success();
}

LogicalResult tensorrt::ResizeLinearOp::reifyResultShapes(
OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
Location loc = getLoc();
RankedTensorType resultType = getType();
int64_t rank = resultType.getRank();

// Case 1: if `output_shape` is specified, then we just extract the scalars
// from that shape.
if (TypedValue<TensorType> outputShape = getOutputShape()) {
// 'tensor.extract' %source [%index]
SmallVector<OpFoldResult> extents;
for (int64_t i = 0; i < rank; i++) {
Value index = b.create<arith::ConstantOp>(getLoc(), b.getIndexAttr(i));
Value extractedShape =
b.create<tensor::ExtractOp>(loc, outputShape, index).getResult();
extents.push_back(
b.create<arith::IndexCastOp>(loc, b.getIndexType(), extractedShape)
.getResult());
}
result.emplace_back(std::move(extents));
return success();
}

SmallVector<OpFoldResult> extents;
extents.reserve(rank);

// This number of trailing dimensions are the special dimensions.
const int64_t resizeDims =
std::min(static_cast<int64_t>(3), resultType.getRank());

for (auto [idx, extent] : llvm::enumerate(resultType.getShape())) {

// If dimension is known, just materialize the extent as constant.
if (!ShapedType::isDynamic(extent)) {
extents.push_back(b.getIndexAttr(extent));
continue;
}

// Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
// then we use `tensor.dim` on the input operand.
// Batch dimensions can only be leading dim.
if (static_cast<int64_t>(idx) >= rank - resizeDims)
return failure();

Value index = b.create<arith::ConstantOp>(loc, b.getIndexAttr(idx));
extents.push_back(
b.create<tensor::DimOp>(loc, getInput(), index).getResult());
}
result.emplace_back(std::move(extents));
return success();
}

//===----------------------------------------------------------------------===//
// ResizeCubicOp
//===----------------------------------------------------------------------===//
Expand All @@ -1298,7 +1403,7 @@ LogicalResult tensorrt::ResizeCubicOp::inferReturnTypeComponents(
inputType.getRank())
return emitOptionalError(loc, "scales parameter must have same number of "
"dimensions as input/output");
for (int i = 0; i < inputType.getRank() - 2; i++)
for (int64_t i = 0; i < inputType.getRank() - 2; i++)
if (adaptor.getScales().value()[i] != 1)
return emitOptionalError(
loc, "all scale values except 2 innermost must be 1");
Expand All @@ -1323,6 +1428,58 @@ LogicalResult tensorrt::ResizeCubicOp::inferReturnTypeComponents(
return success();
}

LogicalResult tensorrt::ResizeCubicOp::reifyResultShapes(
OpBuilder &b, ReifiedRankedShapedTypeDims &result) {
Location loc = getLoc();
RankedTensorType resultType = getType();
int64_t rank = resultType.getRank();

// Case 1: if `output_shape` is specified, then we just extract the scalars
// from that shape.
if (TypedValue<TensorType> outputShape = getOutputShape()) {
// 'tensor.extract' %source [%index]
SmallVector<OpFoldResult> extents;
for (int64_t i = 0; i < rank; i++) {
Value index = b.create<arith::ConstantOp>(getLoc(), b.getIndexAttr(i));
Value extractedShape =
b.create<tensor::ExtractOp>(loc, outputShape, index).getResult();
extents.push_back(
b.create<arith::IndexCastOp>(loc, b.getIndexType(), extractedShape)
.getResult());
}
result.emplace_back(std::move(extents));
return success();
}

SmallVector<OpFoldResult> extents;
extents.reserve(rank);

// This number of trailing dimensions are the special dimensions.
const int64_t resizeDims =
std::min(static_cast<int64_t>(3), resultType.getRank());

for (auto [idx, extent] : llvm::enumerate(resultType.getShape())) {

// If dimension is known, just materialize the extent as constant.
if (!ShapedType::isDynamic(extent)) {
extents.push_back(b.getIndexAttr(extent));
continue;
}

// Otherwise, the extent is equal to sentinel value (ShapedType::kDynamic),
// then we use `tensor.dim` on the input operand.
// Batch dimensions can only be leading dim.
if (static_cast<int64_t>(idx) >= rank - resizeDims)
return failure();

Value index = b.create<arith::ConstantOp>(loc, b.getIndexAttr(idx));
extents.push_back(
b.create<tensor::DimOp>(loc, getInput(), index).getResult());
}
result.emplace_back(std::move(extents));
return success();
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 3 additions & 2 deletions mlir-tensorrt/tensorrt/tensorrt-opt/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
add_llvm_executable(tensorrt-opt tensorrt-opt.cpp)

get_property(MLIR_TENSORRT_DIALECT_TEST_LIBS GLOBAL PROPERTY MLIR_TENSORRT_DIALECT_TEST_LIBS)
llvm_update_compile_flags(tensorrt-opt)
target_link_libraries(tensorrt-opt PRIVATE
MLIRTensorRTDialect
Expand All @@ -10,8 +10,9 @@ target_link_libraries(tensorrt-opt PRIVATE
MLIROptLib
MLIRTensorDialect
MLIRTransforms
MLIRTensorRTTestTensorKindAnalysis
MLIRSCFDialect
${MLIR_TENSORRT_DIALECT_TEST_LIBS}
MLIRTensorRTTestTensorKindAnalysis
)

mlir_check_all_link_libraries(tensorrt-opt)
4 changes: 3 additions & 1 deletion mlir-tensorrt/tensorrt/tensorrt-opt/tensorrt-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@

namespace mlir {
void registerTestTensorKindAnalysisPass();
}
void registerTestTensorRTShapeInferencePass();
} // namespace mlir

int main(int argc, char **argv) {
mlir::DialectRegistry registry;
Expand All @@ -44,6 +45,7 @@ int main(int argc, char **argv) {
mlir::affine::AffineDialect, mlir::quant::QuantizationDialect,
mlir::scf::SCFDialect>();
mlir::registerTestTensorKindAnalysisPass();
mlir::registerTestTensorRTShapeInferencePass();
mlir::func::registerInlinerExtension(registry);
mlir::tensorrt::registerTensorRTTranslationCLOpts();
mlir::tensorrt::registerTensorRTPasses();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// RUN: tensorrt-opt %s -split-input-file -test-tensorrt-shape-inference | FileCheck %s

func.func @test_resize_linear(%arg0: tensor<10x10xf32>) -> (index, index) {
%result = tensorrt.resize_linear {
coordinateTransformation = #tensorrt.resize_coordinate_transformation<kALIGN_CORNERS>,
selectorForSinglePixel = #tensorrt.resize_selector<kUPPER>
} %arg0 : (tensor<10x10xf32>) -> tensor<20x20xf32>

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%d0 = tensor.dim %result, %c0 : tensor<20x20xf32>
%d1 = tensor.dim %result, %c1 : tensor<20x20xf32>
return %d0, %d1 : index, index
}

// CHECK-LABEL: test_resize_linear
// CHECK-NEXT: %[[c20:.+]] = arith.constant 20 : index
// CHECK-NEXT: return %[[c20]], %[[c20]] : index, index

// -----

func.func @test_resize_dynamic_batch(%arg0: tensor<?x1x10x10xf32>) -> (index, index, index, index) {
%result = tensorrt.resize_linear {
coordinateTransformation = #tensorrt.resize_coordinate_transformation<kALIGN_CORNERS>,
selectorForSinglePixel = #tensorrt.resize_selector<kUPPER>
} %arg0 : (tensor<?x1x10x10xf32>) -> tensor<?x1x20x20xf32>

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%d0 = tensor.dim %result, %c0 : tensor<?x1x20x20xf32>
%d1 = tensor.dim %result, %c1 : tensor<?x1x20x20xf32>
%d2 = tensor.dim %result, %c2 : tensor<?x1x20x20xf32>
%d3 = tensor.dim %result, %c3 : tensor<?x1x20x20xf32>
return %d0, %d1, %d2, %d3 : index, index, index, index
}

// CHECK-LABEL: func.func @test_resize_dynamic_batch
// CHECK-SAME: (%[[arg0:.+]]: tensor<?x1x10x10xf32>)
// CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c0]] : tensor<?x1x10x10xf32>
// CHECK: return %[[dim]], %[[c1]], %[[c20]], %[[c20]]

// -----

func.func @test_resize_output_shape(%arg0: tensor<4x4xf32>, %arg1: tensor<2xi32>) -> (index, index) {
%result = tensorrt.resize_linear {
coordinateTransformation = #tensorrt.resize_coordinate_transformation<kALIGN_CORNERS>,
selectorForSinglePixel = #tensorrt.resize_selector<kUPPER>
} %arg0, %arg1 : (tensor<4x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%d0 = tensor.dim %result, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %result, %c1 : tensor<?x?xf32>
return %d0, %d1 : index, index
}

// CHECK-LABEL: func.func @test_resize_output_shape
// CHECK-SAME: (%[[arg0:.+]]: tensor<4x4xf32>, %[[arg1:.+]]: tensor<2xi32>)
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[extracted:.+]] = tensor.extract %arg1[%c0] : tensor<2xi32>
// CHECK-DAG: %[[v0:.+]] = arith.index_cast %extracted : i32 to index
// CHECK-DAG: %[[extracted_0:.+]] = tensor.extract %arg1[%c1] : tensor<2xi32>
// CHECK-DAG: %[[v1:.+]] = arith.index_cast %extracted_0 : i32 to index
// CHECK: return %[[v0]], %[[v1]]
7 changes: 7 additions & 0 deletions mlir-tensorrt/tensorrt/test/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,8 @@
function(add_mlir_tensorrt_dialect_test_library target)
add_mlir_library(${target} ${ARGN}
EXCLUDE_FROM_LIBMLIR)
set_property(GLOBAL APPEND PROPERTY MLIR_TENSORRT_DIALECT_TEST_LIBS ${target})
endfunction()

add_subdirectory(Target)
add_subdirectory(TensorRT)
10 changes: 10 additions & 0 deletions mlir-tensorrt/tensorrt/test/lib/TensorRT/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
add_mlir_tensorrt_dialect_test_library(MLIRTensorRTTestTypeInferencePass
TestTypeInferencePass.cpp

LINK_LIBS PUBLIC
MLIRTensorRTDialect
MLIRTensorDialect
MLIRArithDialect
MLIRPass
MLIRTransformUtils
)
Loading