@@ -54,6 +54,7 @@ static void updateTypeInPlaceAndMaybeInsertCast(RewriterBase &rewriter,
5454
5555 // If all the users are StableHLO ops or plugins, then they all allow in-place
5656 // update of operand types.
57+  // TODO: determine whether tensorrt-dialect ops must be excluded here,
57+  // since they may not all support in-place operand type updates.
5758 auto isOpaquePlugin = [](Operation *op) {
5859 return llvm::isa<tensorrt::OpaquePluginOp>(op);
5960 };
@@ -294,6 +295,30 @@ struct StableHloRefineTypeFromWithShapeGeneric
294295 }
295296};
296297
298+ struct TensorRTResizeAbsorbCastPattern : public OpRewritePattern <tensor::CastOp> {
299+ using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
300+ LogicalResult matchAndRewrite (tensor::CastOp castOp,
301+ PatternRewriter &rewriter) const override {
302+ auto resizeOp = castOp.getOperand ().getDefiningOp ();
303+
304+ // tensor.cast is the only user of the tensorrt op
305+ // tensor.cast must be a refining op
306+ if (!resizeOp || !resizeOp->hasOneUse () ||
307+ !isa<tensorrt::TensorRTDialect>(resizeOp->getDialect ()) ||
308+ !tensorrt::isTargetRefinementOfSource (
309+ castOp.getSource ().getType ().getShape (),
310+ castOp.getType ().getShape ()))
311+ return failure ();
312+
313+ // how to rewrite?
314+ rewriter.modifyOpInPlace (resizeOp, [&]() {
315+ resizeOp->getResult (0 ).setType (
316+ cast<RankedTensorType>(castOp.getType ()));
317+ });
318+ return success ();
319+ }
320+ };
321+
297322class PlanRefineTypesPass
298323 : public plan::impl::PlanRefineTypesPassBase<PlanRefineTypesPass> {
299324 using Base::Base;
@@ -315,7 +340,8 @@ class PlanRefineTypesPass
315340 RefineDynamicIota,
316341 SimplifyIdentityDynamicBroadcast,
317342 StableHloRefineTypeFromWithShapeGeneric,
318- WithShapeAbsorbCastPattern
343+ WithShapeAbsorbCastPattern,
344+ TensorRTResizeAbsorbCastPattern
319345 >(ctx);
320346 // clang-format on
321347 stablehlo::populateStablehloRefineShapesPatterns (&patterns, ctx);
0 commit comments