Skip to content

Commit 95e1c8a

Browse files
committed
Add type canonicalizer for tensorrt ops
1 parent 6e51621 commit 95e1c8a

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/RefineTypes.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ static void updateTypeInPlaceAndMaybeInsertCast(RewriterBase &rewriter,
5454

5555
// If all the users are StableHLO ops or plugins, then they all allow in-place
5656
// update of operand types.
57+
// Need to make exception for tensorrt dialect?
5758
auto isOpaquePlugin = [](Operation *op) {
5859
return llvm::isa<tensorrt::OpaquePluginOp>(op);
5960
};
@@ -294,6 +295,30 @@ struct StableHloRefineTypeFromWithShapeGeneric
294295
}
295296
};
296297

298+
struct TensorRTResizeAbsorbCastPattern: public OpRewritePattern<tensor::CastOp> {
299+
using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
300+
LogicalResult matchAndRewrite(tensor::CastOp castOp,
301+
PatternRewriter &rewriter) const override {
302+
auto resizeOp = castOp.getOperand().getDefiningOp();
303+
304+
// tensor.cast is the only user of the tensorrt op
305+
// tensor.cast must be a refining op
306+
if (!resizeOp || !resizeOp->hasOneUse() ||
307+
!isa<tensorrt::TensorRTDialect>(resizeOp->getDialect()) ||
308+
!tensorrt::isTargetRefinementOfSource(
309+
castOp.getSource().getType().getShape(),
310+
castOp.getType().getShape()))
311+
return failure();
312+
313+
// how to rewrite?
314+
rewriter.modifyOpInPlace(resizeOp, [&]() {
315+
resizeOp->getResult(0).setType(
316+
cast<RankedTensorType>(castOp.getType()));
317+
});
318+
return success();
319+
}
320+
};
321+
297322
class PlanRefineTypesPass
298323
: public plan::impl::PlanRefineTypesPassBase<PlanRefineTypesPass> {
299324
using Base::Base;
@@ -315,7 +340,8 @@ class PlanRefineTypesPass
315340
RefineDynamicIota,
316341
SimplifyIdentityDynamicBroadcast,
317342
StableHloRefineTypeFromWithShapeGeneric,
318-
WithShapeAbsorbCastPattern
343+
WithShapeAbsorbCastPattern,
344+
TensorRTResizeAbsorbCastPattern
319345
>(ctx);
320346
// clang-format on
321347
stablehlo::populateStablehloRefineShapesPatterns(&patterns, ctx);

mlir-tensorrt/test/Dialect/Plan/refine-types.mlir

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,4 +130,21 @@ func.func @tensorrt_opaque_plugin_no_cast() -> tensor<?xf32> {
130130
// CHECK: %[[v1:.*]] = stablehlo.dynamic_reshape %{{.*}}, %{{.*}} : (tensor<64xf32>, tensor<1xi32>) -> tensor<64xf32>
131131
// CHECK: %[[v2:.*]] = tensorrt.opaque_plugin
132132
// CHECK-SAME: (%[[v1]]) : (tensor<64xf32>) -> tensor<?xf32>
133-
// CHECK: return %{{.*}} : tensor<?xf32>
133+
// CHECK: return %{{.*}} : tensor<?xf32>
134+
135+
// -----
136+
137+
func.func @tensorrt_resize_no_cast() -> tensor<1x1x8x8xf32> {
138+
%cst = stablehlo.constant dense<1.000000e+00> : tensor<1x1x4x4xf32>
139+
%c = stablehlo.constant dense<[1, 1, 8, 8]> : tensor<4xi32>
140+
%result = tensorrt.resize_linear {
141+
coordinateTransformation = #tensorrt.resize_coordinate_transformation<kALIGN_CORNERS>,
142+
selectorForSinglePixel = #tensorrt.resize_selector<kUPPER>
143+
} %cst, %c : (tensor<1x1x4x4xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
144+
%cast = tensor.cast %result : tensor<?x?x?x?xf32> to tensor<1x1x8x8xf32>
145+
return %cast : tensor<1x1x8x8xf32>
146+
}
147+
// CHECK-LABEL: func.func @tensorrt_resize_no_cast
148+
// CHECK-SAME: -> tensor<1x1x8x8xf32>
149+
// CHECK: %[[v1:.*]] = tensorrt.resize_linear {coordinateTransformation = #tensorrt.resize_coordinate_transformation<kALIGN_CORNERS>, selectorForSinglePixel = #tensorrt.resize_selector<kUPPER>} %cst, %c : (tensor<1x1x4x4xf32>, tensor<4xi32>) -> tensor<1x1x8x8xf32>
150+
// CHECK: return %[[v1]] : tensor<1x1x8x8xf32>

0 commit comments

Comments
 (0)