 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "stablehlo/dialect/StablehloOps.h"
 #include "stablehlo/dialect/TypeInference.h"
@@ -1064,6 +1065,11 @@ struct AbsorbTensorCastProducer : public RewritePattern {
 };
 } // namespace

+
+/// Populates patterns that are temporarily reproduced here from upstream
+/// commits we have not yet integrated.
+static void populateFutureUpstreamPatterns(RewritePatternSet &patterns);
+
 void stablehlo_ext::populateStableHloAbsorbTensorCastPatterns(
     RewritePatternSet &patterns) {
   patterns.add<AbsorbTensorCastProducer>(patterns.getContext());
@@ -1108,6 +1114,7 @@ class ConstantFoldingPass
       SqrtOpFolder
       >(ctx);
   // clang-format on
+  populateFutureUpstreamPatterns(patterns);
   populateStableHloAbsorbTensorCastPatterns(patterns);
   stablehlo::populateStablehloCanonicalizationPatterns(ctx, &patterns);
   tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx);
@@ -1124,3 +1131,150 @@ class ConstantFoldingPass
   }
 };
 } // namespace
+
+//===----------------------------------------------------------------------===//
+/// The patterns below this point are reproduced from
+/// https://github.com/openxla/stablehlo/commit/5d15ab064f165cc6773ef4ba949ac083ae8e1fea,
+/// which has landed upstream but is not yet in our pinned StableHLO commit.
+/// These patterns can be removed in the next StableHLO upgrade.
+///
+//===----------------------------------------------------------------------===//
+
+/// When a concatenation feeds a slice, the concatenation can often be
+/// simplified or bypassed. This pattern checks which inputs of the concat are
+/// used by the slice, either reducing the number of concatenated values or
+/// removing the concat entirely. Pattern:
+/// slice(concat(X,Y,Z,...),...) -> concat(slice(X),slice(Y),slice(Z))
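+///
+/// As an illustration (hypothetical IR and shapes, not taken from a real
+/// test), a slice that only touches the second concat operand collapses to a
+/// slice of that operand alone:
+///
+///   %c = stablehlo.concatenate %a, %b, dim = 0
+///          : (tensor<2xf32>, tensor<3xf32>) -> tensor<5xf32>
+///   %s = stablehlo.slice %c [3:5] : (tensor<5xf32>) -> tensor<2xf32>
+/// becomes:
+///   %s = stablehlo.slice %b [1:3] : (tensor<3xf32>) -> tensor<2xf32>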
+struct SimplifySliceOfConcat : public OpRewritePattern<SliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SliceOp slice,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType resultTy = slice.getType();
+    if (!resultTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(slice, "result shape not static");
+
+    auto concat = slice.getOperand().getDefiningOp<ConcatenateOp>();
+    if (!concat)
+      return rewriter.notifyMatchFailure(slice, "slice input not concat");
+
+    RankedTensorType concatType = concat.getType();
+    uint64_t dimension = concat.getDimension();
+
+    ArrayRef<int64_t> start = slice.getStartIndices();
+    ArrayRef<int64_t> limit = slice.getLimitIndices();
+
+    int64_t sliceStart = start[dimension];
+    int64_t sliceLimit = limit[dimension];
+
+    // We need to determine which concat inputs affect the slice, and how the
+    // slice bounds must be updated for the minimally required inputs.
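+    // Worked example (hypothetical sizes, for illustration only): with concat
+    // inputs of sizes [2, 3, 4] along `dimension` and a slice of [4, 7), the
+    // loop below picks inputs 1 and 2 (covering offsets [2, 9)) and sets
+    // frontOffset = 2, so the rewritten slice becomes [2, 5) on the new,
+    // smaller concat.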
+    int64_t runningSize = 0;
+    int64_t frontOffset = concatType.getShape()[dimension];
+
+    auto subsetStart = concat.operand_end();
+    auto subsetEnd = concat.operand_end();
+    for (auto it = concat.operand_begin(); it < concat.operand_end(); ++it) {
+      Value input = *it;
+      auto inputTy = cast<RankedTensorType>(input.getType());
+      if (inputTy.isDynamicDim(dimension))
+        return rewriter.notifyMatchFailure(
+            slice, "concat input has dynamic dimension");
+
+      int64_t dimSize = inputTy.getShape()[dimension];
+
+      // If this position is in the slice, it's the start of the subset, and
+      // we need to update the start and limit values.
+      if (runningSize + dimSize > sliceStart &&
+          subsetStart == concat.operand_end()) {
+        subsetStart = it;
+        frontOffset = runningSize;
+      }
+
+      // Determine the last required offset.
+      if (runningSize < sliceLimit) {
+        subsetEnd = it + 1;
+      }
+
+      runningSize += dimSize;
+    }
+
+    auto subsetSize = subsetEnd - subsetStart;
+    // We need all inputs, so there is nothing to optimize.
+    if (subsetSize == concat.getNumOperands())
+      return rewriter.notifyMatchFailure(slice,
+                                         "slice needs all concat inputs");
+
+    // If there's nothing to slice, the output is an empty tensor and the
+    // concat is dead code. We do nothing here and rely on other passes to
+    // clean this up.
+    if (subsetSize == 0)
+      return rewriter.notifyMatchFailure(slice, "slice is empty");
+
+    if (subsetSize > 1 && !concat.getResult().hasOneUse())
+      return rewriter.notifyMatchFailure(slice,
+                                         "slice is not the only concat user");
+
+    auto concatRange = OperandRange(subsetStart, subsetEnd);
+    auto newConcat = rewriter.create<ConcatenateOp>(
+        concat.getLoc(), concatRange, concat.getDimension());
+
+    SmallVector<int64_t> newStart(start);
+    SmallVector<int64_t> newLimit(limit);
+    newStart[dimension] -= frontOffset;
+    newLimit[dimension] -= frontOffset;
+
+    rewriter.replaceOpWithNewOp<SliceOp>(
+        slice, newConcat, rewriter.getDenseI64ArrayAttr(newStart),
+        rewriter.getDenseI64ArrayAttr(newLimit), slice.getStrides());
+    return success();
+  }
+};
+
+/// Flatten sequential concatenations as long as the inner (operand-defining)
+/// concatenation either has a single use or has at most 32 elements.
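+///
+/// For instance (illustrative pseudo-IR with hypothetical operand names),
+/// when the inner concat qualifies:
+///   concat(concat(%a, %b, dim = 0), %c, dim = 0)
+///     -> concat(%a, %b, %c, dim = 0)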
+class SimplifyConcatOfConcatPattern
+    : public OpRewritePattern<stablehlo::ConcatenateOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(ConcatenateOp op,
+                                PatternRewriter &rewriter) const override {
+    auto getFlattenedOperands = [&](const Value &val) -> ValueRange {
+      auto definingOp = dyn_cast_or_null<ConcatenateOp>(val.getDefiningOp());
+      if (!definingOp || definingOp.getDimension() != op.getDimension())
+        return val;
+      if (definingOp->hasOneUse())
+        return definingOp.getInputs();
+      if (!definingOp.getType().hasStaticShape())
+        return val;
+      if (definingOp.getType().getNumElements() <= 32)
+        return definingOp.getInputs();
+      return val;
+    };
+
+    bool needToFlatten = false;
+    int operandCount = 0;
+    for (Value val : op.getInputs()) {
+      ValueRange result = getFlattenedOperands(val);
+      if (result.size() != 1 || result[0] != val)
+        needToFlatten = true;
+      operandCount += result.size();
+    }
+    if (!needToFlatten)
+      return rewriter.notifyMatchFailure(op, "no need to flatten");
+
+    llvm::SmallVector<Value, 6> newOperands;
+    newOperands.reserve(operandCount);
+    for (Value operand : op.getInputs())
+      llvm::append_range(newOperands, getFlattenedOperands(operand));
+
+    rewriter.modifyOpInPlace(op, [&] { op->setOperands(newOperands); });
+    return success();
+  }
+};
+
+void populateFutureUpstreamPatterns(RewritePatternSet &patterns) {
+  patterns.add<SimplifySliceOfConcat, SimplifyConcatOfConcatPattern>(
+      patterns.getContext());
+}