
Commit 18cda08

Changes to get video pipeline with multiple objects working
1 parent 1176dbd commit 18cda08

10 files changed (+463, -224 lines changed)

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeShapeCalculations.cpp

Lines changed: 8 additions & 6 deletions
@@ -356,34 +356,37 @@ struct SimplifyExtractOfReshape : public OpRewritePattern<tensor::ExtractOp> {
 
   LogicalResult matchAndRewrite(tensor::ExtractOp op,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<Value> operands;
+
     auto reshapeOp = op.getTensor().getDefiningOp<stablehlo::ReshapeOp>();
     if (!reshapeOp)
       return failure();
 
+    // Skip if either shape has dynamic dimensions
+    if (!reshapeOp.getOperand().getType().hasStaticShape())
+      return failure();
+
     std::optional<SmallVector<int64_t>> coords =
         getConstantIntValues(getAsOpFoldResult(op.getIndices()));
     if (!coords)
       return failure();
 
-    // Get lienar coords.
     SmallVector<int64_t> resultBasis =
         mlir::computeSuffixProduct(reshapeOp.getType().getShape());
     SmallVector<int64_t> operandBasis =
         mlir::computeSuffixProduct(reshapeOp.getOperand().getType().getShape());
 
-    int64_t lienarIndex = mlir::linearize(*coords, resultBasis);
+    int64_t linearIndex = mlir::linearize(*coords, resultBasis);
     SmallVector<int64_t> operandCoords =
-        mlir::delinearize(lienarIndex, operandBasis);
+        mlir::delinearize(linearIndex, operandBasis);
 
-    // Find linear offset within in the operand shape.
     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        op, reshapeOp.getOperand(),
        llvm::map_to_vector(operandCoords, [&](int64_t c) -> Value {
          return rewriter.create<arith::ConstantIndexOp>(op->getLoc(), c);
        }));
 
     return success();
+
   }
 };
 
@@ -858,7 +861,6 @@ class MaterializeShapeCalculationsPass
     memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns_);
     stablehlo_ext::populateStableHloAbsorbTensorCastPatterns(patterns_);
     stablehlo::populateStablehloCanonicalizeDynamismPatterns(&patterns_, ctx);
-
     // clang-format off
     addCanonicalizationPatterns<
       arith::AndIOp,
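
As context for the SimplifyExtractOfReshape hunk above, here is a minimal before/after sketch of what the pattern does once both shapes are static; the value names and shapes are hypothetical, not taken from the commit:

  // Before: extract element (1, 2) of a 3x4 reshape of a 2x6 tensor.
  %r = stablehlo.reshape %arg0 : (tensor<2x6xf32>) -> tensor<3x4xf32>
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %v = tensor.extract %r[%c1, %c2] : tensor<3x4xf32>

  // After: the linear index 1*4 + 2 = 6 delinearizes to (1, 0) in the 2x6
  // operand shape, so the reshape is bypassed entirely.
  %c1_0 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %v = tensor.extract %arg0[%c1_0, %c0] : tensor<2x6xf32>

The new hasStaticShape() guard keeps the pattern from firing when the operand shape is dynamic, where the suffix-product bases used for linearize/delinearize would not be meaningful.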

mlir-tensorrt/compiler/lib/Dialect/StableHloExt/IR/StableHloReifyTypeInterfaceImpl.cpp

Lines changed: 47 additions & 0 deletions
@@ -280,6 +280,47 @@ class ConvolutionReifyRankedShapedTypeOpInterfaceImpl
   }
 };
 
+class SelectReifyRankedShapedTypeOpInterfaceImpl
+    : public ReifyRankedShapedTypeOpInterface::ExternalModel<
+          SelectReifyRankedShapedTypeOpInterfaceImpl,
+          stablehlo::SelectOp> {
+
+public:
+  LogicalResult
+  reifyResultShapes(Operation *op_, OpBuilder &builder,
+                    ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
+
+    auto op = cast<stablehlo::SelectOp>(op_);
+    Location loc = op.getLoc();
+
+    // Get result type
+    auto resultType = cast<RankedTensorType>(op.getResult().getType());
+    int64_t rank = resultType.getRank();
+
+    // Collect dimension values
+    SmallVector<OpFoldResult> dims(rank);
+    for (int64_t i = 0; i < rank; ++i) {
+      // For each dimension, if it's static in the result type, use that
+      if (!resultType.isDynamicDim(i)) {
+        dims[i] = builder.getIndexAttr(resultType.getDimSize(i));
+        continue;
+      }
+
+      // For dynamic dimensions, we need to compute the broadcasted size
+      // The operands are: pred, on_true, on_false
+      Value trueVal = builder.createOrFold<tensor::DimOp>(loc, op.getOperand(1), i);
+      Value falseVal = builder.createOrFold<tensor::DimOp>(loc, op.getOperand(2), i);
+
+      // The result dimension should be the max of the two values
+      Value maxDim = builder.create<arith::MaxSIOp>(loc, trueVal, falseVal);
+      dims[i] = maxDim;
+    }
+    reifiedReturnShapes.emplace_back(std::move(dims));
+    return success();
+
+  }
+
+};
 class ReduceWindowReifyRankedShapedTypeOpInterfaceImpl
     : public ReifyRankedShapedTypeOpInterface::ExternalModel<
           ReduceWindowReifyRankedShapedTypeOpInterfaceImpl,
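
The new external model reifies each dynamic result dimension of stablehlo.select as the max of the corresponding on_true/on_false dimensions. A minimal sketch of the IR it would build, with hypothetical values and shapes:

  // A select whose result has a dynamic leading dimension.
  %sel = stablehlo.select %pred, %a, %b : tensor<?x4xi1>, tensor<?x4xf32>

  // Reifying the result shape emits, per dynamic dimension, roughly:
  %c0 = arith.constant 0 : index
  %da = tensor.dim %a, %c0 : tensor<?x4xf32>
  %db = tensor.dim %b, %c0 : tensor<?x4xf32>
  %d0 = arith.maxsi %da, %db : index
  // reified dims for %sel: [%d0, 4]

Static dimensions are returned directly as index attributes, so no IR is created for them.
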
@@ -353,4 +394,10 @@ void stablehlo::registerTypeInferenceExternalModels(DialectRegistry &registry) {
     stablehlo::ReduceWindowOp::attachInterface<
         ReduceWindowReifyRankedShapedTypeOpInterfaceImpl>(*ctx);
   });
+  registry.addExtension(
+      +[](MLIRContext *ctx, stablehlo::StablehloDialect *dialect) {
+        stablehlo::SelectOp::attachInterface<
+            SelectReifyRankedShapedTypeOpInterfaceImpl>(*ctx);
+      });
+
 }

mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/ConstantFolding.cpp

Lines changed: 154 additions & 0 deletions
@@ -29,6 +29,7 @@
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "stablehlo/dialect/StablehloOps.h"
 #include "stablehlo/dialect/TypeInference.h"
@@ -1064,6 +1065,11 @@ struct AbsorbTensorCastProducer : public RewritePattern {
 };
 } // namespace
 
+
+/// Populates patterns that are temporarily reproduced here from upstream
+/// commits we have not yet integrated.
+static void populateFutureUpstreamPatterns(RewritePatternSet &patterns);
+
 void stablehlo_ext::populateStableHloAbsorbTensorCastPatterns(
     RewritePatternSet &patterns) {
   patterns.add<AbsorbTensorCastProducer>(patterns.getContext());
@@ -1108,6 +1114,7 @@ class ConstantFoldingPass
         SqrtOpFolder
         >(ctx);
     // clang-format on
+    populateFutureUpstreamPatterns(patterns);
     populateStableHloAbsorbTensorCastPatterns(patterns);
     stablehlo::populateStablehloCanonicalizationPatterns(ctx, &patterns);
     tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx);
@@ -1124,3 +1131,150 @@ class ConstantFoldingPass
   }
 };
 } // namespace
+
+//===----------------------------------------------------------------------===//
+/// The patterns below this point are reproduced from
+/// https://github.com/openxla/stablehlo/commit/5d15ab064f165cc6773ef4ba949ac083ae8e1fea,
+/// which is in upstream, but our current pinned StableHlo commit is not there
+/// yet. The patterns can be removed in the next StableHLO upgrade.
+///
+//===----------------------------------------------------------------------===//
+
+///
+/// In cases where a concat is fed into a slice, it
+/// is possible the concat can be simplified or bypassed. This checks which
+/// inputs to the concat are used by the slice, either reducing the number of
+/// concatenated values or entirely removes the concat. Pattern:
+/// slice(concat(X,Y,Z,...),...) -> concat(slice(X),slice(Y),slice(Z))
+struct SimplifySliceOfConcat : public OpRewritePattern<SliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SliceOp slice,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType resultTy = slice.getType();
+    if (!resultTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(slice, "result shape not static");
+
+    auto concat = slice.getOperand().getDefiningOp<ConcatenateOp>();
+    if (!concat)
+      return rewriter.notifyMatchFailure(slice, "slice input not concat");
+
+    RankedTensorType concatType = concat.getType();
+    uint64_t dimension = concat.getDimension();
+
+    ArrayRef<int64_t> start = slice.getStartIndices();
+    ArrayRef<int64_t> limit = slice.getLimitIndices();
+
+    int64_t sliceStart = start[dimension];
+    int64_t sliceLimit = limit[dimension];
+
+    // We need to determine what inputs from the concat affect the slice, and
+    // how the bounds of the slice need to be updated for the minimally
+    // required inputs.
+    int64_t runningSize = 0;
+    int64_t frontOffset = concatType.getShape()[dimension];
+
+    auto subsetStart = concat.operand_end();
+    auto subsetEnd = concat.operand_end();
+    for (auto it = concat.operand_begin(); it < concat.operand_end(); ++it) {
+      Value input = *it;
+      auto inputTy = cast<RankedTensorType>(input.getType());
+      if (inputTy.isDynamicDim(dimension))
+        return rewriter.notifyMatchFailure(
+            slice, "concat input has dynamic dimension");
+
+      int64_t dimSize = inputTy.getShape()[dimension];
+
+      // If this position is in the slice its the start of the subset and we
+      // need to update the start and limit values.
+      if (runningSize + dimSize > sliceStart &&
+          subsetStart == concat.operand_end()) {
+        subsetStart = it;
+        frontOffset = runningSize;
+      }
+
+      // Determine the last required offset.
+      if (runningSize < sliceLimit) {
+        subsetEnd = it + 1;
+      }
+
+      runningSize += dimSize;
+    }
+
+    auto subsetSize = subsetEnd - subsetStart;
+    // We need all inputs so no optimization.
+    if (subsetSize == concat.getNumOperands())
+      return rewriter.notifyMatchFailure(slice,
+                                         "slice needs all concat inputs");
+
+    // If there's nothing to slice that means the output is an empty tensor and
+    // there is dead code. We do nothing here and rely on other passes to clean
+    // this up.
+    if (subsetSize == 0)
+      return rewriter.notifyMatchFailure(slice, "slice is empty");
+
+    if (subsetSize > 1 && !concat.getResult().hasOneUse())
+      return rewriter.notifyMatchFailure(slice,
+                                         "slice is not the only concat user");
+
+    auto concatRange = OperandRange(subsetStart, subsetEnd);
+    auto newConcat = rewriter.create<ConcatenateOp>(
+        concat.getLoc(), concatRange, concat.getDimension());
+
+    SmallVector<int64_t> newStart(start);
+    SmallVector<int64_t> newLimit(limit);
+    newStart[dimension] -= frontOffset;
+    newLimit[dimension] -= frontOffset;
+
+    rewriter.replaceOpWithNewOp<SliceOp>(
+        slice, newConcat, rewriter.getDenseI64ArrayAttr(newStart),
+        rewriter.getDenseI64ArrayAttr(newLimit), slice.getStrides());
+    return success();
+  }
+};
+
+/// Flatten sequential concatenations as long as the parent concatenation
+/// either has a single use or is <= 32 elements.
+class SimplifyConcatOfConcatPattern
+    : public OpRewritePattern<stablehlo::ConcatenateOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(ConcatenateOp op,
+                                PatternRewriter &rewriter) const override {
+    auto getFlattenedOperands = [&](const Value &val) -> ValueRange {
+      auto definingOp = dyn_cast_or_null<ConcatenateOp>(val.getDefiningOp());
+      if (!definingOp || definingOp.getDimension() != op.getDimension())
+        return val;
+      if (definingOp->hasOneUse())
+        return definingOp.getInputs();
+      if (!definingOp.getType().hasStaticShape())
+        return val;
+      if (definingOp.getType().getNumElements() <= 32)
+        return definingOp.getInputs();
+      return val;
+    };
+
+    bool needToFlatten = false;
+    int operandCount = 0;
+    for (Value val : op.getInputs()) {
+      ValueRange result = getFlattenedOperands(val);
+      if (result.size() != 1 || result[0] != val)
+        needToFlatten = true;
+      operandCount += result.size();
+    }
+    if (!needToFlatten)
+      return rewriter.notifyMatchFailure(op, "no need to flatten");
+
+    llvm::SmallVector<Value, 6> newOperands;
+    newOperands.reserve(operandCount);
+    for (Value operand : op.getInputs())
+      llvm::append_range(newOperands, getFlattenedOperands(operand));
+
+    rewriter.modifyOpInPlace(op, [&] { op->setOperands(newOperands); });
+    return success();
+  }
+};
+
+void populateFutureUpstreamPatterns(RewritePatternSet &patterns) {
+  patterns.add<SimplifySliceOfConcat, SimplifyConcatOfConcatPattern>(
+      patterns.getContext());
+}
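
A minimal before/after sketch of SimplifySliceOfConcat, with hypothetical values and shapes; here only the second concat input contributes to the slice, so the slice is re-based onto it (the single-input concat the pattern still creates folds away separately):

  // Before
  %cat = stablehlo.concatenate %a, %b, dim = 0 : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
  %s = stablehlo.slice %cat [2:4, 0:4] : (tensor<5x4xf32>) -> tensor<2x4xf32>

  // After: start/limit shifted down by the 2 rows contributed by %a
  %s = stablehlo.slice %b [0:2, 0:4] : (tensor<3x4xf32>) -> tensor<2x4xf32>

SimplifyConcatOfConcatPattern flattens a nested concatenate on the same dimension into its parent, for example:

  // Before
  %inner = stablehlo.concatenate %a, %b, dim = 0 : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
  %outer = stablehlo.concatenate %inner, %c, dim = 0 : (tensor<5x4xf32>, tensor<4x4xf32>) -> tensor<9x4xf32>

  // After (inner concat has one use, or is small): operands spliced in place
  %outer = stablehlo.concatenate %a, %b, %c, dim = 0 : (tensor<2x4xf32>, tensor<3x4xf32>, tensor<4x4xf32>) -> tensor<9x4xf32>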

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/BroadcastElimination.cpp

Lines changed: 64 additions & 2 deletions
@@ -120,6 +120,65 @@ struct PushDownBroadcastReduceRankOp : public OpRewritePattern<CollapseRankOp> {
 };
 } // namespace
 
+static Value expandRank(RewriterBase &rewriter, Location loc,
+                        TypedValue<RankedTensorType> input,
+                        ArrayRef<int64_t> reorderedBroadcastDims,
+                        RankedTensorType resultType) {
+  RankedTensorType inputType = input.getType();
+  // For <= 1 dynamic dims, no need to do dynamic reshape.
+  if (input.getType().getNumDynamicDims() <= 1) {
+    SmallVector<int64_t> staticShape(resultType.getRank());
+
+    unsigned inputIdx = 0;
+    for (unsigned i = 0, e = staticShape.size(); i < e; i++) {
+      if (inputIdx < reorderedBroadcastDims.size() &&
+          i == reorderedBroadcastDims[inputIdx]) {
+        staticShape[i] = inputType.getDimSize(inputIdx++);
+        continue;
+      }
+      staticShape[i] = 1;
+    }
+    return rewriter.create<ReshapeOp>(loc, resultType.clone(staticShape),
+                                      input);
+  }
+
+  // Otherwise, we need to do dynamic reshape.
+  auto shape = rewriter.create<tensorrt::ShapeOp>(loc, input);
+  SmallVector<Value> shapeComponents(resultType.getRank());
+  SmallVector<int64_t> staticShape(resultType.getRank());
+  unsigned inputIdx = 0;
+  for (unsigned i = 0, e = shapeComponents.size(); i < e; i++) {
+    if (inputIdx < reorderedBroadcastDims.size() &&
+        i == reorderedBroadcastDims[inputIdx]) {
+      if (!inputType.isDynamicDim(inputIdx)) {
+        staticShape[i] = inputType.getDimSize(inputIdx);
+        shapeComponents[i] = rewriter.create<tensorrt::ConstantOp>(
+            loc, rewriter.getI32TensorAttr(
+                     {static_cast<int32_t>(inputType.getDimSize(inputIdx++))}));
+        continue;
+      }
+      shapeComponents[i] = rewriter.create<tensorrt::SliceOp>(
+          loc, shape,
+          /*offset=*/ArrayRef<int32_t>{static_cast<int32_t>(inputIdx++)},
+          ArrayRef<int32_t>{1}, ArrayRef<int32_t>{1});
+      staticShape[i] = ShapedType::kDynamic;
+      continue;
+    }
+    staticShape[i] = 1;
+    shapeComponents[i] = rewriter.create<tensorrt::ConstantOp>(
+        loc, rewriter.getI32TensorAttr(
+                 {static_cast<int32_t>(inputType.getDimSize(1))}));
+  }
+  auto newShape = rewriter.create<tensorrt::ConcatenationOp>(
+      loc,
+      RankedTensorType::get(static_cast<int64_t>(shapeComponents.size()),
+                            rewriter.getI32Type()),
+      shapeComponents, /*axis=*/0);
+
+  return rewriter.create<ReshapeOp>(loc, resultType.clone(staticShape), input,
+                                    newShape);
+}
+
 namespace {
 /// Create transpose + expand_rank on the input of a `tensorrt.broadcast` so
 /// that the result has the same rank as the `tensorrt.broadcast` result and the
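
A worked sketch of the shape bookkeeping in the expandRank helper above, under hypothetical shapes. For an input of type tensor<?x4xf32> expanded toward a rank-4 broadcast with reorderedBroadcastDims = [0, 2], only one dimension is dynamic, so the static path applies and a plain reshape is emitted with unit dims in the non-broadcast positions:

  // staticShape is assembled as [?, 1, 4, 1]:
  // positions 0 and 2 come from the input dims, the rest become 1.
  tensor<?x4xf32>  -->  reshape  -->  tensor<?x1x4x1xf32>

With two or more dynamic input dimensions, the shape is instead built at runtime: a tensorrt.shape of the input, a one-element tensorrt.slice per dynamic dimension, i32 constants for the remaining positions, all concatenated into the shape operand of a dynamic reshape.
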
@@ -157,8 +216,9 @@ struct SimplifyBroadcast : public OpRewritePattern<BroadcastOp> {
       }
       expandedShape[i] = 1;
     }
-    Value expanded = rewriter.create<ExpandRankOp>(
-        loc, resultType.clone(expandedShape), transposeOp);
+
+    Value expanded = expandRank(rewriter, loc, transposeOp,
+                                reorderedBroadcastDims, resultType);
     rewriter.replaceOpWithNewOp<BroadcastOp>(
         op, op.getType(), expanded, op.getShape(),
         llvm::to_vector(llvm::seq<int64_t>(0, resultType.getRank())));
@@ -341,6 +401,8 @@ class BroadcastEliminationPass
     patterns.add<SimplifyBroadcast, ElementwiseAbsorbBroadcast,
                  PushDownBroadcastReduceRankOp, SelectAbsorbBroadcast,
                  MatMulAbsorbBroadcast>(&getContext());
+    tensorrt::ReshapeOp::getCanonicalizationPatterns(patterns,
+                                                     patterns.getContext());
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(patterns)))) {
       emitError(getOperation()->getLoc())
