Skip to content

Commit 9e7f578

Browse files
committed
Add in plan-refine-types pass, fix unit test
1 parent 41fb31c commit 9e7f578

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/RefineTypes.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@ static void updateTypeInPlaceAndMaybeInsertCast(RewriterBase &rewriter,
5454

5555
// If all the users are StableHLO ops or plugins, then they all allow in-place
5656
// update of operand types.
57-
auto isOpaquePlugin = [](Operation *op) {
57+
auto isTensorRTOp = [](Operation *op) {
5858
return llvm::isa<tensorrt::TensorRTDialect>(op->getDialect());
5959
};
60-
if (stablehlo::canUpdateTypeWithoutCast(toUpdate, isOpaquePlugin))
60+
if (stablehlo::canUpdateTypeWithoutCast(toUpdate, isTensorRTOp))
6161
return;
6262

6363
OpBuilder::InsertionGuard g(rewriter);
@@ -377,6 +377,7 @@ class PlanRefineTypesPass
377377
>(ctx);
378378
// clang-format on
379379
stablehlo::populateStablehloRefineShapesPatterns(&patterns, ctx);
380+
stablehlo::populateStablehloCanonicalizationPatterns(ctx, &patterns);
380381
if (failed(applyPatternsAndFoldGreedily(funcTarget, std::move(patterns),
381382
config))) {
382383
emitError(funcTarget.getLoc())

mlir-tensorrt/test/Dialect/Plan/refine-types.mlir

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,10 @@ func.func @tensorrt_opaque_plugin_no_cast() -> tensor<?xf32> {
126126
return %3 : tensor<?xf32>
127127
}
128128
// CHECK-LABEL: func.func @tensorrt_opaque_plugin_no_cast
129-
// CHECK-SAME: () -> tensor<?xf32>
130-
// CHECK: %[[v1:.*]] = stablehlo.dynamic_reshape %{{.*}}, %{{.*}} : (tensor<64xf32>, tensor<1xi32>) -> tensor<64xf32>
131-
// CHECK: %[[v2:.*]] = tensorrt.opaque_plugin
132-
// CHECK-SAME: (%[[v1]]) : (tensor<64xf32>) -> tensor<?xf32>
133-
// CHECK: return %{{.*}} : tensor<?xf32>
129+
// CHECK-SAME: () -> tensor<1xf32>
130+
// CHECK: %[[cst:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<64xf32>
131+
// CHECK: %[[v0:.*]] = tensorrt.opaque_plugin
132+
// CHECK: return %{{.*}} : tensor<1xf32>
134133

135134
// -----
136135

0 commit comments

Comments
 (0)