@@ -52,12 +52,12 @@ static void updateTypeInPlaceAndMaybeInsertCast(RewriterBase &rewriter,
5252 rewriter.modifyOpInPlace (toUpdate.getDefiningOp (),
5353 [&]() { toUpdate.setType (newType); });
5454
55- // If all the users are StableHLO ops or plugins , then they all allow in-place
56- // update of operand types.
57- auto isOpaquePlugin = [](Operation *op) {
58- return llvm::isa<tensorrt::OpaquePluginOp >(op);
55+ // If all the users are StableHLO or TensorRT ops , then they all allow
56+ // in-place update of operand types.
57+ auto isTensorRTOp = [](Operation *op) {
58+ return llvm::isa<tensorrt::TensorRTDialect >(op-> getDialect () );
5959 };
60- if (stablehlo::canUpdateTypeWithoutCast (toUpdate, isOpaquePlugin ))
60+ if (stablehlo::canUpdateTypeWithoutCast (toUpdate, isTensorRTOp ))
6161 return ;
6262
6363 OpBuilder::InsertionGuard g (rewriter);
@@ -294,6 +294,38 @@ struct StableHloRefineTypeFromWithShapeGeneric
294294 }
295295};
296296
297+ // / Given a pattern `plan.with_shape(tensorrt_op, dims...)`, if inspection of
298+ // / `dims` yields an opportunity to refine the type of `with_shape`, then
299+ // / `tensorrt_op` can also be refined. The refinements are made (and casts are
300+ // / inserted if required).
301+ struct TensorRTRefineTypeFromWithShapeGeneric
302+ : public OpRewritePattern<WithShapeOp> {
303+ using OpRewritePattern<WithShapeOp>::OpRewritePattern;
304+ LogicalResult matchAndRewrite (WithShapeOp withOp,
305+ PatternRewriter &rewriter) const override {
306+ auto producer = withOp.getOperand ().getDefiningOp ();
307+ if (!producer || !producer->hasOneUse () ||
308+ !isa<tensorrt::TensorRTDialect>(producer->getDialect ()))
309+ return failure ();
310+
311+ // Create a new shape and try to refine it.
312+ std::optional<SmallVector<int64_t >> newShape =
313+ getRefinedShape (withOp.getShape (), withOp.getOperand ().getType ());
314+ if (!newShape)
315+ return failure ();
316+
317+ // Update type of the producer.
318+ updateTypeInPlaceAndMaybeInsertCast (
319+ rewriter, withOp.getOperand (),
320+ withOp.getOperand ().getType ().clone (*newShape));
321+
322+ // Update type of the WithShapeOp.
323+ updateTypeInPlaceAndMaybeInsertCast (rewriter, withOp.getResult (),
324+ withOp.getType ().clone (*newShape));
325+ return success ();
326+ }
327+ };
328+
297329class PlanRefineTypesPass
298330 : public plan::impl::PlanRefineTypesPassBase<PlanRefineTypesPass> {
299331 using Base::Base;
@@ -315,10 +347,12 @@ class PlanRefineTypesPass
315347 RefineDynamicIota,
316348 SimplifyIdentityDynamicBroadcast,
317349 StableHloRefineTypeFromWithShapeGeneric,
318- WithShapeAbsorbCastPattern
350+ WithShapeAbsorbCastPattern,
351+ TensorRTRefineTypeFromWithShapeGeneric
319352 >(ctx);
320353 // clang-format on
321354 stablehlo::populateStablehloRefineShapesPatterns (&patterns, ctx);
355+ stablehlo::populateStablehloCanonicalizationPatterns (ctx, &patterns);
322356 if (failed (applyPatternsAndFoldGreedily (funcTarget, std::move (patterns),
323357 config))) {
324358 emitError (funcTarget.getLoc ())
0 commit comments