Skip to content

Commit e491f40

Browse files
author
Matthew Francis-Landau
committed
fix bug with pushing up reshapes with 1 axes
1 parent a53bc6f commit e491f40

File tree

2 files changed: +31 −9 lines changed

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeReshapeElimination.cpp

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ using namespace mlir;
4646
using namespace mlir::tensorrt;
4747

4848
// Set the max size of tensors which can be constant-folded to 131072 (0.5 MB
49-
// for f32 constants).g
49+
// for f32 constants).
5050
constexpr int64_t kFoldOpEltLimit = 1 << 17;
5151

5252
static int64_t memoryCost(RankedTensorType type) {
@@ -667,7 +667,7 @@ class PushDownTransposeToEinsum : public OpRewritePattern<tensorrt::EinsumOp> {
667667
return failure();
668668

669669
std::string newEinsumEquation = einsumEquation.generateEquation();
670-
670+
assert(einsumEquation.rhs.size() == op.getType().getShape().size());
671671
rewriter.replaceOpWithNewOp<tensorrt::EinsumOp>(op, op.getType(), newInputs,
672672
newEinsumEquation);
673673
return success();
@@ -713,6 +713,7 @@ class PushUpTransposeToEinsum : public OpRewritePattern<tensorrt::TransposeOp> {
713713

714714
auto newEinsum = rewriter.create<tensorrt::EinsumOp>(
715715
op.getLoc(), op.getType(), einsum.getInputs(), newEinsumEquation);
716+
assert(einsumEquation.rhs.size() == newEinsum.getType().getShape().size());
716717
rewriter.replaceOp(op, newEinsum.getResult());
717718
return success();
718719
}
@@ -789,6 +790,7 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
789790
auto newEinsum = rewriter.create<tensorrt::EinsumOp>(
790791
op.getLoc(), op.getType().clone(newEinsumShape), op.getInputs(),
791792
newEinsumEquation);
793+
assert(equation.rhs.size() == newEinsum.getType().getShape().size());
792794

793795
auto forwardMap =
794796
AffineMap::getPermutationMap(forwardPerm, op.getLoc().getContext());
@@ -909,6 +911,7 @@ class EinsumPushUpTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
909911
return failure();
910912

911913
std::string newEquation = equation.generateEquation();
914+
assert(equation.rhs.size() == op.getType().getShape().size());
912915
rewriter.replaceOpWithNewOp<tensorrt::EinsumOp>(op, op.getType(), newInputs,
913916
newEquation);
914917
return success();
@@ -991,14 +994,18 @@ class EinsumEliminate1Axis : public OpRewritePattern<tensorrt::EinsumOp> {
991994
auto newEinsum = rewriter.create<tensorrt::EinsumOp>(
992995
op.getLoc(), outputType.clone(newOutputShape), newInputs,
993996
newEquation);
997+
assert(newEinsumEquation.rhs.size() ==
998+
newEinsum.getType().getShape().size());
994999
auto outReshape =
9951000
rewriter
9961001
.create<tensorrt::ReshapeOp>(op.getLoc(), op.getType(),
9971002
newEinsum.getResult())
9981003
.getResult();
1004+
assert(op.getType() == outReshape.getType());
9991005
rewriter.replaceOp(op, outReshape);
10001006
return success();
10011007
} else {
1008+
assert(newEinsumEquation.rhs.size() == op.getType().getShape().size());
10021009
rewriter.replaceOpWithNewOp<tensorrt::EinsumOp>(op, op.getType(),
10031010
newInputs, newEquation);
10041011
return success();
@@ -1069,7 +1076,7 @@ class EinsumMergeDown1Axis : public OpRewritePattern<tensorrt::EinsumOp> {
10691076
return failure();
10701077

10711078
std::string newEquation = equation.generateEquation();
1072-
1079+
assert(equation.rhs.size() == op.getType().getShape().size());
10731080
rewriter.replaceOpWithNewOp<tensorrt::EinsumOp>(op, op.getType(), newInputs,
10741081
newEquation);
10751082
return success();
@@ -1144,6 +1151,7 @@ class EinsumMergeUp1Axis : public OpRewritePattern<tensorrt::ExpandRankOp> {
11441151
}
11451152

11461153
std::string newEquation = equation.lhs + "->" + newRhs;
1154+
assert(newRhs.size() == op.getType().getShape().size());
11471155
rewriter.replaceOpWithNewOp<tensorrt::EinsumOp>(
11481156
op, op.getType(), einsum.getInputs(), newEquation);
11491157
return success();
@@ -1165,6 +1173,8 @@ class EinsumPushUp1AxisReshape : public OpRewritePattern<tensorrt::EinsumOp> {
11651173
if (op->getNumOperands() != 2)
11661174
return failure();
11671175

1176+
assert(equation.rhs.size() == op.getType().getShape().size());
1177+
11681178
char matrixAxes[2] = {0, 0};
11691179
char multipliedAxis = 0;
11701180

@@ -1265,6 +1275,7 @@ class EinsumPushUp1AxisReshape : public OpRewritePattern<tensorrt::EinsumOp> {
12651275
rewriter.createOrFold<ReshapeOp>(op.getLoc(), newInputTypes[1],
12661276
op.getInputs()[1])};
12671277

1278+
assert(newEquation.rhs.size() == op.getType().getShape().size());
12681279
rewriter.replaceOpWithNewOp<EinsumOp>(op, op.getType(), reshapes,
12691280
newEquation.generateEquation());
12701281

@@ -1341,9 +1352,9 @@ class PushReshapeUpThroughEinsum
13411352
inAxes += equation.rhs[j++];
13421353
}
13431354
if (inputNumElems == outputNumElems) {
1344-
if (reshapeOutShape[i] == 1 && outAxes.size() == 1 &&
1345-
inAxes.size() == 0) {
1346-
if (!prevInAxes.empty()) {
1355+
if (inAxes.empty()) {
1356+
if (!prevInAxes.empty() && reshapeOutShape[i] == 1 &&
1357+
outAxes.size() == 1) {
13471358
auto &p = inputToReshapedMap[prevInAxes];
13481359
p.first.push_back(c);
13491360
p.second.push_back(1);
@@ -1363,7 +1374,7 @@ class PushReshapeUpThroughEinsum
13631374
outAxes = "";
13641375
}
13651376
}
1366-
if (inputNumElems != outputNumElems)
1377+
if (inputNumElems != outputNumElems || !inAxes.empty() || !outAxes.empty())
13671378
return failure(/* should not happen, unexpected reshape */);
13681379
if (!hasNonTrivalReshape)
13691380
return failure(/* reshape is only expanding rank */);
@@ -1443,6 +1454,8 @@ class PushReshapeUpThroughEinsum
14431454

14441455
auto newEinsum = rewriter.create<tensorrt::EinsumOp>(
14451456
op.getLoc(), op.getType(), newInputs, newEquationStr);
1457+
assert(newEquation.rhs.size() == newEinsum.getType().getShape().size());
1458+
assert(op.getType() == newEinsum.getType());
14461459
rewriter.replaceOp(op, newEinsum.getResult());
14471460

14481461
return success();
@@ -1528,6 +1541,7 @@ class EinsumPushUpMultipleMulitipliedAxes
15281541
newInputs.push_back(reshape);
15291542
}
15301543

1544+
assert(newEquation.rhs.size() == op.getType().getShape().size());
15311545
rewriter.replaceOpWithNewOp<tensorrt::EinsumOp>(
15321546
op, op.getType(), newInputs, newEquation.generateEquation());
15331547
return success();
@@ -1848,6 +1862,7 @@ class PushReshapeDownThroughEinsum
18481862
auto newEinsum = rewriter.create<tensorrt::EinsumOp>(
18491863
op.getLoc(), outputType.clone(einsumOutputShape), newInputs,
18501864
newEinsumEquation);
1865+
assert(newEquation.rhs.size() == newEinsum.getType().getShape().size());
18511866

18521867
auto newReshape = rewriter.createOrFold<tensorrt::ReshapeOp>(
18531868
op.getLoc(), outputType.clone(afterEinsumReshape),
@@ -1857,6 +1872,7 @@ class PushReshapeDownThroughEinsum
18571872
op.getLoc(), newReshape,
18581873
AffineMap::getPermutationMap(afterReshapeTranspose, op.getContext()));
18591874

1875+
assert(op.getType() == newOut.getType());
18601876
rewriter.replaceOp(op, newOut);
18611877

18621878
return success();
@@ -2173,7 +2189,6 @@ class MoveTransposeBeforeReshape
21732189
op.getLoc(), reshapeInputType.clone(newReshape), newTransposeOp);
21742190

21752191
assert(op.getType() == newReshapeOp.getType());
2176-
21772192
rewriter.replaceOp(op, newReshapeOp);
21782193
return success();
21792194
}
@@ -2308,6 +2323,7 @@ class PushUpOpQuantizeDequantize : public OpRewritePattern<OpType> {
23082323
auto newDequantizeOp = rewriter.create<tensorrt::DequantizeOp>(
23092324
dequantizeOp.getLoc(), op.getResult().getType(),
23102325
newQuantizeOp.getResult(), scale, dequantizeOp.getAxisAttr());
2326+
assert(op.getType() == newDequantizeOp.getType());
23112327
rewriter.replaceOp(op, newDequantizeOp.getResult());
23122328
return success();
23132329
}
@@ -2354,6 +2370,7 @@ class PushDownOpQuantizeDequantize
23542370
auto newOp =
23552371
rewriter.create<OpType>(op.getLoc(), dequantizeOp.getResult().getType(),
23562372
newDequantizeOp.getResult(), op->getAttrs());
2373+
assert(dequantizeOp.getType() == newOp.getType());
23572374
rewriter.replaceOp(dequantizeOp, newOp.getResult());
23582375
return success();
23592376
}
@@ -2600,6 +2617,7 @@ class PushUpReshapeElementwise : public OpRewritePattern<tensorrt::ReshapeOp> {
26002617
auto newElementwiseOp = rewriter.create<tensorrt::ElementWiseOp>(
26012618
elementwiseOp.getLoc(), op.getResult().getType(), newLhs, newRhs,
26022619
elementwiseOp.getElementwiseOperation());
2620+
assert(op.getType() == newElementwiseOp.getType());
26032621
rewriter.replaceOp(op, newElementwiseOp.getResult());
26042622

26052623
return success();
@@ -2696,6 +2714,7 @@ class PushUpTransposeSoftmax : public OpRewritePattern<tensorrt::TransposeOp> {
26962714
op.getLoc(), softmax.getInput(), op.getPermutation());
26972715
auto newSoftmax = rewriter.create<tensorrt::SoftMaxOp>(
26982716
softmax.getLoc(), newTranspose, newAxis);
2717+
assert(op.getType() == newSoftmax.getType());
26992718
rewriter.replaceOp(op, newSoftmax.getResult());
27002719
return success();
27012720
}
@@ -2719,6 +2738,7 @@ class PushDownTransposeSoftmax : public OpRewritePattern<tensorrt::SoftMaxOp> {
27192738
op.getLoc(), transpose.getInput(), newAxis);
27202739
auto newTranspose = rewriter.create<tensorrt::TransposeOp>(
27212740
transpose.getLoc(), newSoftmax, transpose.getPermutation());
2741+
assert(op.getType() == newTranspose.getType());
27222742
rewriter.replaceOp(op, newTranspose.getResult());
27232743
return success();
27242744
}
@@ -2766,6 +2786,7 @@ class PushUpReshapeSoftmax : public OpRewritePattern<tensorrt::ReshapeOp> {
27662786
op.getLoc(), outputType, softmax.getInput());
27672787
auto newSoftmax = rewriter.create<tensorrt::SoftMaxOp>(
27682788
softmax.getLoc(), newReshape.getResult(), newAxis);
2789+
assert(op.getType() == newSoftmax.getType());
27692790
rewriter.replaceOp(op, newSoftmax.getResult());
27702791
return success();
27712792
}
@@ -2810,6 +2831,7 @@ class PushDownReshapeSoftmax : public OpRewritePattern<tensorrt::SoftMaxOp> {
28102831
op.getLoc(), reshapeOp.getInput(), newAxis);
28112832
auto newReshape = rewriter.create<tensorrt::ReshapeOp>(
28122833
reshapeOp.getLoc(), outputType, newSoftmax.getResult());
2834+
assert(op.getType() == newReshape.getType());
28132835
rewriter.replaceOp(op, newReshape.getResult());
28142836
return success();
28152837
}

mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-reshape-elimination.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,4 +335,4 @@ func.func @transpose_on_scalar(%arg0: tensor<4488x4x48xf32>, %arg1: tensor<f32>)
335335
func.func @einsum_multiply_two_axis(%arg0: tensor<10x11x12xf32>, %arg1: tensor<13x11x12xf32>) -> tensor<10x13xf32> {
336336
%0 = tensorrt.einsum {equation = "acd,bcd->ab"} ins(%arg0, %arg1: tensor<10x11x12xf32>, tensor<13x11x12xf32>) -> tensor<10x13xf32>
337337
return %0 : tensor<10x13xf32>
338-
}
338+
}

0 commit comments

Comments (0)