
Commit 3353e18

shelkesagar29 and Copybara Bot authored
[compiler] Fix stablehlo-ext-constant-folding bug in "absorb tensor.cast" pattern
This PR moves the following internal commit to OSS:

[compiler] Fix `stablehlo-ext-constant-folding` bug in "absorb tensor.cast" pattern

Fixes an issue where we incorrectly assumed that all StableHLO operations have tensor operands. Other types can be used by various ops, at least `stablehlo.token` and `tuple` (see https://openxla.org/stablehlo/spec#types).

GitOrigin-RevId: 5415c7a0db725232fa30086c27ca38e70d28d0eb
Co-authored-by: Copybara Bot <[email protected]>
1 parent: ff1b5e3

3 files changed (+50, -3 lines)

mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/ConstantFolding.cpp

Lines changed: 27 additions & 3 deletions
@@ -1038,7 +1038,11 @@ struct AbsorbTensorCastProducer : public RewritePattern {
     if (!canUpdateTypeWithoutCast(operand))
       return nullptr;
     Value value = operand.get();
-    auto rtt = cast<RankedTensorType>(value.getType());
+    // Not all stablehlo operands are tensors -- some can have types like
+    // 'tuple' or special quantized types.
+    auto rtt = dyn_cast<RankedTensorType>(value.getType());
+    if (!rtt)
+      return nullptr;
     auto castOp = value.getDefiningOp<tensor::CastOp>();
     if (!castOp)
       return nullptr;
@@ -1273,7 +1277,27 @@ class SimplifyConcatOfConcatPattern
   }
 };
 
+// Pattern: broadcast_in_dim(splat, _) -> constant(splat)
+struct FoldBroadcastInDimSplatPattern final
+    : OpRewritePattern<mlir::stablehlo::BroadcastInDimOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::BroadcastInDimOp op,
+                                PatternRewriter &rewriter) const override {
+    TypedValue<RankedTensorType> operand = op.getOperand();
+
+    if (SplatElementsAttr cstAttr;
+        matchPattern(operand, m_Constant(&cstAttr))) {
+      rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(
+          op, SplatElementsAttr::get(op.getType(),
+                                     cstAttr.getSplatValue<Attribute>()));
+      return success();
+    }
+    return failure();
+  }
+};
+
 void populateFutureUpstreamPatterns(RewritePatternSet &patterns) {
-  patterns.add<SimplifySliceOfConcat, SimplifyConcatOfConcatPattern>(
-      patterns.getContext());
+  patterns.add<SimplifySliceOfConcat, SimplifyConcatOfConcatPattern,
+               FoldBroadcastInDimSplatPattern>(patterns.getContext());
 }
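
The new FoldBroadcastInDimSplatPattern is sound because a splat constant is just one element value plus a type, so it can be rebuilt directly at the broadcast result type and the broadcast_in_dim dropped. Note the patterns above are only populated into a set, not applied; a hedged sketch of typical driver-side usage with MLIR's greedy rewriter follows (the actual pass wiring in this repo may differ):

    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/IR/PatternMatch.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    using namespace mlir;

    // Declared in the project's StableHloExt headers; forward-declared here
    // to keep the sketch self-contained (enclosing namespace elided).
    void populateFutureUpstreamPatterns(RewritePatternSet &patterns);

    // Sketch: apply the pattern set, now including
    // FoldBroadcastInDimSplatPattern, to a fixed point over one function.
    static LogicalResult runFolding(func::FuncOp func) {
      RewritePatternSet patterns(func.getContext());
      populateFutureUpstreamPatterns(patterns);
      return applyPatternsAndFoldGreedily(func, std::move(patterns));
    }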

mlir-tensorrt/tensorrt/lib/Target/TranslateToTensorRT.cpp

Lines changed: 7 additions & 0 deletions
@@ -803,6 +803,9 @@ class TranslateToTensorRTEnginePass
       continue;
     }
 
+    LLVM_DEBUG(DBGS() << "starting to build TensorRT engine for function "
+                      << func.getName() << "\n");
+
     FailureOr<TensorRTEngineResult> engineResult =
         buildFunction(func, *builderContext, *timingCache, translationOptions,
                       layerMetadataCallback);
@@ -811,6 +814,10 @@ class TranslateToTensorRTEnginePass
                  << "' to a TensorRT engine";
       return signalPassFailure();
     }
+
+    LLVM_DEBUG(DBGS() << "done building TensorRT engine for function "
+                      << func.getName() << "\n");
+
     const std::unique_ptr<nvinfer1::IHostMemory> &serializedEngine =
         engineResult->serializedEngine;
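
The added log lines use LLVM's debug facility: LLVM_DEBUG(...) compiles away unless the build enables assertions, and at runtime the output is gated by the -debug or -debug-only=<tag> flags, where the tag is the file's DEBUG_TYPE (DBGS() is a project-local wrapper around llvm::dbgs(); neither definition appears in this diff). A generic sketch with a hypothetical tag:

    #include "llvm/ADT/StringRef.h"
    #include "llvm/Support/Debug.h"
    #include "llvm/Support/raw_ostream.h"

    // Hypothetical tag for illustration; the real DEBUG_TYPE is defined in
    // TranslateToTensorRT.cpp outside this diff.
    #define DEBUG_TYPE "translate-to-tensorrt"

    void logEngineBuild(llvm::StringRef funcName) {
      // Printed only with -debug or -debug-only=translate-to-tensorrt in an
      // assertions-enabled build.
      LLVM_DEBUG(llvm::dbgs() << "starting to build TensorRT engine for function "
                              << funcName << "\n");
    }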

mlir-tensorrt/test/Dialect/StableHloExt/constant-folding.mlir

Lines changed: 16 additions & 0 deletions
@@ -1141,3 +1141,19 @@ func.func private @add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf
 // CHECK-DAG: %[[cast_0:.+]] = tensor.cast %[[arg1]] : tensor<4xf32> to tensor<?xf32>
 // CHECK-DAG: %[[v0:.+]] = stablehlo.composite "foo.bar" %[[cast]], %[[cast_0]] {decomposition = @add} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
 // CHECK-DAG: return %[[v0]] : tensor<?xf32>
+
+// -----
+
+// This is a regression check for a case where we previously had a crash/failure.
+// No change should be made.
+
+func.func @tuple_regression_check(%arg0: tuple<tensor<1xf32>, tensor<1xf32>>) -> tensor<1xf32> {
+  %0 = stablehlo.get_tuple_element %arg0[0] : (tuple<tensor<1xf32>, tensor<1xf32>>) -> tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+
+// CHECK-LABEL: func.func @tuple_regression_check
+//  CHECK-SAME: (%[[arg0:.+]]: tuple<tensor<1xf32>, tensor<1xf32>>)
+//       CHECK: %[[v0:.+]] = stablehlo.get_tuple_element %[[arg0]][0] : (tuple<tensor<1xf32>, tensor<1xf32>>) -> tensor<1xf32>
+//       CHECK: return %[[v0]] : tensor<1xf32>
