Commit 7524556
[tensorrt] Refine types for tensorrt ops with plan.with_shape
Adds a `TensorRTRefineTypeFromWithShapeGeneric` pattern to `PlanRefineTypesPass` that refines the types of TensorRT ops consumed by `plan.with_shape`.
1 parent e73ba59
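In short: given `plan.with_shape(tensorrt_op, dims...)`, when inspection of the `dims` operands yields static extents, both the `with_shape` result type and the producing TensorRT op's result type are refined. A minimal before/after sketch in MLIR, distilled from the test added below (it assumes the `%d0`..`%d3` dims resolve to the constants 1, 1, 8, 8):

  // Before: the TensorRT op produces a fully dynamic shape.
  %r = tensorrt.resize_linear ... : (tensor<1x1x4x4xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
  %1 = plan.with_shape %r(%d0, %d1, %d2, %d3) : (tensor<?x?x?x?xf32>, index, index, index, index) -> tensor<?x?x?x?xf32>

  // After: the statically known dims refine both result types in place.
  %r = tensorrt.resize_linear ... : (tensor<1x1x4x4xf32>, tensor<4xi32>) -> tensor<1x1x8x8xf32>
  %1 = plan.with_shape %r(%d0, %d1, %d2, %d3) : (tensor<1x1x8x8xf32>, index, index, index, index) -> tensor<1x1x8x8xf32>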

File tree: 2 files changed (+73 −11 lines)

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/RefineTypes.cpp

Lines changed: 40 additions & 6 deletions
@@ -52,12 +52,12 @@ static void updateTypeInPlaceAndMaybeInsertCast(RewriterBase &rewriter,
   rewriter.modifyOpInPlace(toUpdate.getDefiningOp(),
                            [&]() { toUpdate.setType(newType); });
 
-  // If all the users are StableHLO ops or plugins, then they all allow in-place
-  // update of operand types.
-  auto isOpaquePlugin = [](Operation *op) {
-    return llvm::isa<tensorrt::OpaquePluginOp>(op);
+  // If all the users are StableHLO or TensorRT ops, then they all allow
+  // in-place update of operand types.
+  auto isTensorRTOp = [](Operation *op) {
+    return llvm::isa<tensorrt::TensorRTDialect>(op->getDialect());
   };
-  if (stablehlo::canUpdateTypeWithoutCast(toUpdate, isOpaquePlugin))
+  if (stablehlo::canUpdateTypeWithoutCast(toUpdate, isTensorRTOp))
     return;
 
   OpBuilder::InsertionGuard g(rewriter);
@@ -294,6 +294,38 @@ struct StableHloRefineTypeFromWithShapeGeneric
   }
 };
 
+/// Given a pattern `plan.with_shape(tensorrt_op, dims...)`, if inspection of
+/// `dims` yields an opportunity to refine the type of `with_shape`, then
+/// `tensorrt_op` can also be refined. The refinements are made (and casts are
+/// inserted if required).
+struct TensorRTRefineTypeFromWithShapeGeneric
+    : public OpRewritePattern<WithShapeOp> {
+  using OpRewritePattern<WithShapeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WithShapeOp withOp,
+                                PatternRewriter &rewriter) const override {
+    auto producer = withOp.getOperand().getDefiningOp();
+    if (!producer || !producer->hasOneUse() ||
+        !isa<tensorrt::TensorRTDialect>(producer->getDialect()))
+      return failure();
+
+    // Create a new shape and try to refine it.
+    std::optional<SmallVector<int64_t>> newShape =
+        getRefinedShape(withOp.getShape(), withOp.getOperand().getType());
+    if (!newShape)
+      return failure();
+
+    // Update type of the producer.
+    updateTypeInPlaceAndMaybeInsertCast(
+        rewriter, withOp.getOperand(),
+        withOp.getOperand().getType().clone(*newShape));
+
+    // Update type of the WithShapeOp.
+    updateTypeInPlaceAndMaybeInsertCast(rewriter, withOp.getResult(),
+                                        withOp.getType().clone(*newShape));
+    return success();
+  }
+};
+
 class PlanRefineTypesPass
     : public plan::impl::PlanRefineTypesPassBase<PlanRefineTypesPass> {
   using Base::Base;
@@ -315,10 +347,12 @@ class PlanRefineTypesPass
         RefineDynamicIota,
         SimplifyIdentityDynamicBroadcast,
         StableHloRefineTypeFromWithShapeGeneric,
-        WithShapeAbsorbCastPattern
+        WithShapeAbsorbCastPattern,
+        TensorRTRefineTypeFromWithShapeGeneric
       >(ctx);
     // clang-format on
     stablehlo::populateStablehloRefineShapesPatterns(&patterns, ctx);
+    stablehlo::populateStablehloCanonicalizationPatterns(ctx, &patterns);
     if (failed(applyPatternsAndFoldGreedily(funcTarget, std::move(patterns),
                                             config))) {
       emitError(funcTarget.getLoc())
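Note on the cast fallback: if some user of the refined value cannot absorb the new type in place (`canUpdateTypeWithoutCast` returns false), `updateTypeInPlaceAndMaybeInsertCast` materializes a cast after the producer so those users still see the original type. A hedged sketch of the resulting IR, using `tensor.cast` purely for illustration (the concrete cast op built after the `InsertionGuard` is not shown in this diff):

  // The producer's result type is refined to a static shape...
  %r = tensorrt.resize_linear ... : (tensor<1x1x4x4xf32>, tensor<4xi32>) -> tensor<1x1x8x8xf32>
  // ...and any user that cannot be updated in place consumes a cast back
  // to the original dynamic type instead (illustrative cast op).
  %c = tensor.cast %r : tensor<1x1x8x8xf32> to tensor<?x?x?x?xf32>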

mlir-tensorrt/test/Dialect/Plan/refine-types.mlir

Lines changed: 33 additions & 5 deletions
@@ -126,8 +126,36 @@ func.func @tensorrt_opaque_plugin_no_cast() -> tensor<?xf32> {
   return %3 : tensor<?xf32>
 }
 // CHECK-LABEL: func.func @tensorrt_opaque_plugin_no_cast
-// CHECK-SAME: () -> tensor<?xf32>
-// CHECK: %[[v1:.*]] = stablehlo.dynamic_reshape %{{.*}}, %{{.*}} : (tensor<64xf32>, tensor<1xi32>) -> tensor<64xf32>
-// CHECK: %[[v2:.*]] = tensorrt.opaque_plugin
-// CHECK-SAME: (%[[v1]]) : (tensor<64xf32>) -> tensor<?xf32>
-// CHECK: return %{{.*}} : tensor<?xf32>
+// CHECK-SAME: () -> tensor<1xf32>
+// CHECK: %[[cst:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<64xf32>
+// CHECK: %[[v0:.*]] = tensorrt.opaque_plugin
+// CHECK: return %{{.*}} : tensor<1xf32>
+
+// -----
+
+func.func @refine_tensorrt_resize_with_shape() -> tensor<?x?x?x?xf32> {
+  %c3 = arith.constant 3 : index
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %cst = stablehlo.constant dense<1.000000e+00> : tensor<1x1x4x4xf32>
+  %c = stablehlo.constant dense<[1, 1, 8, 8]> : tensor<4xi32>
+  %result = tensorrt.resize_linear {
+    coordinateTransformation = #tensorrt.resize_coordinate_transformation<kALIGN_CORNERS>,
+    selectorForSinglePixel = #tensorrt.resize_selector<kUPPER>
+  } %cst, %c : (tensor<1x1x4x4xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
+  %dim_i32_0 = tensor.extract %c[%c0] : tensor<4xi32>
+  %dim_i32_1 = tensor.extract %c[%c1] : tensor<4xi32>
+  %dim_i32_2 = tensor.extract %c[%c2] : tensor<4xi32>
+  %dim_i32_3 = tensor.extract %c[%c3] : tensor<4xi32>
+  %dim_0 = arith.index_cast %dim_i32_0 : i32 to index
+  %dim_1 = arith.index_cast %dim_i32_1 : i32 to index
+  %dim_2 = arith.index_cast %dim_i32_2 : i32 to index
+  %dim_3 = arith.index_cast %dim_i32_3 : i32 to index
+  %1 = plan.with_shape %result(%dim_0, %dim_1, %dim_2, %dim_3) : (tensor<?x?x?x?xf32>, index, index, index, index) -> tensor<?x?x?x?xf32>
+  return %1 : tensor<?x?x?x?xf32>
+}
+// CHECK-LABEL: func.func @refine_tensorrt_resize_with_shape
+// CHECK-SAME: -> tensor<1x1x8x8xf32>
+// CHECK: %[[v0:.*]] = tensorrt.resize_linear {coordinateTransformation = #tensorrt.resize_coordinate_transformation<kALIGN_CORNERS>, selectorForSinglePixel = #tensorrt.resize_selector<kUPPER>} %cst, %c : (tensor<1x1x4x4xf32>, tensor<4xi32>) -> tensor<1x1x8x8xf32>
+// CHECK: return %[[v0]] : tensor<1x1x8x8xf32>
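The `// -----` separators indicate the test file is run with `-split-input-file`. The RUN line itself sits outside this diff; it would look roughly like the following (the tool name and pass flag here are assumptions, not shown in the commit):

  // RUN: mlir-tensorrt-opt %s -split-input-file -plan-refine-types | FileCheck %s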
