Skip to content

Commit fa086ce

Browse files
committed
Add type canonicalizer for tensorrt ops
1 parent 6e51621 commit fa086ce

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/RefineTypes.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ static void updateTypeInPlaceAndMaybeInsertCast(RewriterBase &rewriter,
5454

5555
// If all the users are StableHLO ops or plugins, then they all allow in-place
5656
// update of operand types.
57+
// Need to make exception for tensorrt dialect?
5758
auto isOpaquePlugin = [](Operation *op) {
5859
return llvm::isa<tensorrt::OpaquePluginOp>(op);
5960
};
@@ -294,6 +295,26 @@ struct StableHloRefineTypeFromWithShapeGeneric
294295
}
295296
};
296297

298+
struct TensorRTResizeAbsorbCastPattern: public OpRewritePattern<tensor::CastOp> {
299+
using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
300+
LogicalResult matchAndRewrite(tensor::CastOp castOp,
301+
PatternRewriter &rewriter) const override {
302+
auto resizeOp = castOp.getOperand().getDefiningOp();
303+
304+
// tensor.cast is the only user of the tensorrt op
305+
// tensor.cast must be a refining op
306+
if (!resizeOp || !resizeOp->hasOneUse() ||
307+
!isa<tensorrt::TensorRTDialect>(resizeOp->getDialect()) ||
308+
!tensorrt::isTargetRefinementOfSource(
309+
castOp.getType().getShape(),
310+
castOp.getSource().getType().getShape()))
311+
return failure();
312+
313+
// how to rewrite?
314+
return success();
315+
}
316+
};
317+
297318
class PlanRefineTypesPass
298319
: public plan::impl::PlanRefineTypesPassBase<PlanRefineTypesPass> {
299320
using Base::Base;

mlir-tensorrt/test/Dialect/Plan/refine-types.mlir

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,4 +130,18 @@ func.func @tensorrt_opaque_plugin_no_cast() -> tensor<?xf32> {
130130
// CHECK: %[[v1:.*]] = stablehlo.dynamic_reshape %{{.*}}, %{{.*}} : (tensor<64xf32>, tensor<1xi32>) -> tensor<64xf32>
131131
// CHECK: %[[v2:.*]] = tensorrt.opaque_plugin
132132
// CHECK-SAME: (%[[v1]]) : (tensor<64xf32>) -> tensor<?xf32>
133-
// CHECK: return %{{.*}} : tensor<?xf32>
133+
// CHECK: return %{{.*}} : tensor<?xf32>
134+
135+
// -----
136+
137+
func.func @tensorrt_resize_no_cast() -> tensor<?x?x?x?xf32> {
138+
%cst = stablehlo.constant dense<1.000000e+00> : tensor<1x1x4x4f32>
139+
%c = stablehlo.constant dense<[1, 1, 8, 8]> : tensor<4xi32>
140+
%result = tensorrt.resize_linear {
141+
coordinateTransformation = #tensorrt.resize_coordinate_transformation<kALIGN_CORNERS>,
142+
selectorForSinglePixel = #tensorrt.resize_selector<kUPPER>
143+
} %cst, %c : (tensor<1x1x4x4xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
144+
145+
return %result : tensor<?x?x?x?xf32>
146+
}
147+
// CHECK-LABEL: func.func @tensorrt_resize_no_cast

0 commit comments

Comments
 (0)