@@ -46,7 +46,7 @@ using namespace mlir;
4646using namespace mlir ::tensorrt;
4747
4848// Set the max size of tensors which can be constant-folded to 131072 (0.5 MB
49- // for f32 constants).g
49+ // for f32 constants).
5050constexpr int64_t kFoldOpEltLimit = 1 << 17 ;
5151
5252static int64_t memoryCost (RankedTensorType type) {
@@ -667,7 +667,7 @@ class PushDownTransposeToEinsum : public OpRewritePattern<tensorrt::EinsumOp> {
667667 return failure ();
668668
669669 std::string newEinsumEquation = einsumEquation.generateEquation ();
670-
670+ assert (einsumEquation. rhs . size () == op. getType (). getShape (). size ());
671671 rewriter.replaceOpWithNewOp <tensorrt::EinsumOp>(op, op.getType (), newInputs,
672672 newEinsumEquation);
673673 return success ();
@@ -713,6 +713,7 @@ class PushUpTransposeToEinsum : public OpRewritePattern<tensorrt::TransposeOp> {
713713
714714 auto newEinsum = rewriter.create <tensorrt::EinsumOp>(
715715 op.getLoc (), op.getType (), einsum.getInputs (), newEinsumEquation);
716+ assert (einsumEquation.rhs .size () == newEinsum.getType ().getShape ().size ());
716717 rewriter.replaceOp (op, newEinsum.getResult ());
717718 return success ();
718719 }
@@ -789,6 +790,7 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
789790 auto newEinsum = rewriter.create <tensorrt::EinsumOp>(
790791 op.getLoc (), op.getType ().clone (newEinsumShape), op.getInputs (),
791792 newEinsumEquation);
793+ assert (equation.rhs .size () == newEinsum.getType ().getShape ().size ());
792794
793795 auto forwardMap =
794796 AffineMap::getPermutationMap (forwardPerm, op.getLoc ().getContext ());
@@ -909,6 +911,7 @@ class EinsumPushUpTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
909911 return failure ();
910912
911913 std::string newEquation = equation.generateEquation ();
914+ assert (equation.rhs .size () == op.getType ().getShape ().size ());
912915 rewriter.replaceOpWithNewOp <tensorrt::EinsumOp>(op, op.getType (), newInputs,
913916 newEquation);
914917 return success ();
@@ -991,14 +994,18 @@ class EinsumEliminate1Axis : public OpRewritePattern<tensorrt::EinsumOp> {
991994 auto newEinsum = rewriter.create <tensorrt::EinsumOp>(
992995 op.getLoc (), outputType.clone (newOutputShape), newInputs,
993996 newEquation);
997+ assert (newEinsumEquation.rhs .size () ==
998+ newEinsum.getType ().getShape ().size ());
994999 auto outReshape =
9951000 rewriter
9961001 .create <tensorrt::ReshapeOp>(op.getLoc (), op.getType (),
9971002 newEinsum.getResult ())
9981003 .getResult ();
1004+ assert (op.getType () == outReshape.getType ());
9991005 rewriter.replaceOp (op, outReshape);
10001006 return success ();
10011007 } else {
1008+ assert (newEinsumEquation.rhs .size () == op.getType ().getShape ().size ());
10021009 rewriter.replaceOpWithNewOp <tensorrt::EinsumOp>(op, op.getType (),
10031010 newInputs, newEquation);
10041011 return success ();
@@ -1069,7 +1076,7 @@ class EinsumMergeDown1Axis : public OpRewritePattern<tensorrt::EinsumOp> {
10691076 return failure ();
10701077
10711078 std::string newEquation = equation.generateEquation ();
1072-
1079+ assert (equation. rhs . size () == op. getType (). getShape (). size ());
10731080 rewriter.replaceOpWithNewOp <tensorrt::EinsumOp>(op, op.getType (), newInputs,
10741081 newEquation);
10751082 return success ();
@@ -1144,6 +1151,7 @@ class EinsumMergeUp1Axis : public OpRewritePattern<tensorrt::ExpandRankOp> {
11441151 }
11451152
11461153 std::string newEquation = equation.lhs + " ->" + newRhs;
1154+ assert (newRhs.size () == op.getType ().getShape ().size ());
11471155 rewriter.replaceOpWithNewOp <tensorrt::EinsumOp>(
11481156 op, op.getType (), einsum.getInputs (), newEquation);
11491157 return success ();
@@ -1165,6 +1173,8 @@ class EinsumPushUp1AxisReshape : public OpRewritePattern<tensorrt::EinsumOp> {
11651173 if (op->getNumOperands () != 2 )
11661174 return failure ();
11671175
1176+ assert (equation.rhs .size () == op.getType ().getShape ().size ());
1177+
11681178 char matrixAxes[2 ] = {0 , 0 };
11691179 char multipliedAxis = 0 ;
11701180
@@ -1265,6 +1275,7 @@ class EinsumPushUp1AxisReshape : public OpRewritePattern<tensorrt::EinsumOp> {
12651275 rewriter.createOrFold <ReshapeOp>(op.getLoc (), newInputTypes[1 ],
12661276 op.getInputs ()[1 ])};
12671277
1278+ assert (newEquation.rhs .size () == op.getType ().getShape ().size ());
12681279 rewriter.replaceOpWithNewOp <EinsumOp>(op, op.getType (), reshapes,
12691280 newEquation.generateEquation ());
12701281
@@ -1341,9 +1352,9 @@ class PushReshapeUpThroughEinsum
13411352 inAxes += equation.rhs [j++];
13421353 }
13431354 if (inputNumElems == outputNumElems) {
1344- if (reshapeOutShape[i] == 1 && outAxes. size () == 1 &&
1345- inAxes. size () == 0 ) {
1346- if (!prevInAxes. empty () ) {
1355+ if (inAxes. empty ()) {
1356+ if (!prevInAxes. empty () && reshapeOutShape[i] == 1 &&
1357+ outAxes. size () == 1 ) {
13471358 auto &p = inputToReshapedMap[prevInAxes];
13481359 p.first .push_back (c);
13491360 p.second .push_back (1 );
@@ -1363,7 +1374,7 @@ class PushReshapeUpThroughEinsum
13631374 outAxes = " " ;
13641375 }
13651376 }
1366- if (inputNumElems != outputNumElems)
1377+ if (inputNumElems != outputNumElems || !inAxes. empty () || !outAxes. empty () )
13671378 return failure (/* should not happen, unexpected reshape */ );
13681379 if (!hasNonTrivalReshape)
13691380 return failure (/* reshape is only expanding rank */ );
@@ -1443,6 +1454,8 @@ class PushReshapeUpThroughEinsum
14431454
14441455 auto newEinsum = rewriter.create <tensorrt::EinsumOp>(
14451456 op.getLoc (), op.getType (), newInputs, newEquationStr);
1457+ assert (newEquation.rhs .size () == newEinsum.getType ().getShape ().size ());
1458+ assert (op.getType () == newEinsum.getType ());
14461459 rewriter.replaceOp (op, newEinsum.getResult ());
14471460
14481461 return success ();
@@ -1528,6 +1541,7 @@ class EinsumPushUpMultipleMulitipliedAxes
15281541 newInputs.push_back (reshape);
15291542 }
15301543
1544+ assert (newEquation.rhs .size () == op.getType ().getShape ().size ());
15311545 rewriter.replaceOpWithNewOp <tensorrt::EinsumOp>(
15321546 op, op.getType (), newInputs, newEquation.generateEquation ());
15331547 return success ();
@@ -1848,6 +1862,7 @@ class PushReshapeDownThroughEinsum
18481862 auto newEinsum = rewriter.create <tensorrt::EinsumOp>(
18491863 op.getLoc (), outputType.clone (einsumOutputShape), newInputs,
18501864 newEinsumEquation);
1865+ assert (newEquation.rhs .size () == newEinsum.getType ().getShape ().size ());
18511866
18521867 auto newReshape = rewriter.createOrFold <tensorrt::ReshapeOp>(
18531868 op.getLoc (), outputType.clone (afterEinsumReshape),
@@ -1857,6 +1872,7 @@ class PushReshapeDownThroughEinsum
18571872 op.getLoc (), newReshape,
18581873 AffineMap::getPermutationMap (afterReshapeTranspose, op.getContext ()));
18591874
1875+ assert (op.getType () == newOut.getType ());
18601876 rewriter.replaceOp (op, newOut);
18611877
18621878 return success ();
@@ -2173,7 +2189,6 @@ class MoveTransposeBeforeReshape
21732189 op.getLoc (), reshapeInputType.clone (newReshape), newTransposeOp);
21742190
21752191 assert (op.getType () == newReshapeOp.getType ());
2176-
21772192 rewriter.replaceOp (op, newReshapeOp);
21782193 return success ();
21792194 }
@@ -2308,6 +2323,7 @@ class PushUpOpQuantizeDequantize : public OpRewritePattern<OpType> {
23082323 auto newDequantizeOp = rewriter.create <tensorrt::DequantizeOp>(
23092324 dequantizeOp.getLoc (), op.getResult ().getType (),
23102325 newQuantizeOp.getResult (), scale, dequantizeOp.getAxisAttr ());
2326+ assert (op.getType () == newDequantizeOp.getType ());
23112327 rewriter.replaceOp (op, newDequantizeOp.getResult ());
23122328 return success ();
23132329 }
@@ -2354,6 +2370,7 @@ class PushDownOpQuantizeDequantize
23542370 auto newOp =
23552371 rewriter.create <OpType>(op.getLoc (), dequantizeOp.getResult ().getType (),
23562372 newDequantizeOp.getResult (), op->getAttrs ());
2373+ assert (dequantizeOp.getType () == newOp.getType ());
23572374 rewriter.replaceOp (dequantizeOp, newOp.getResult ());
23582375 return success ();
23592376 }
@@ -2600,6 +2617,7 @@ class PushUpReshapeElementwise : public OpRewritePattern<tensorrt::ReshapeOp> {
26002617 auto newElementwiseOp = rewriter.create <tensorrt::ElementWiseOp>(
26012618 elementwiseOp.getLoc (), op.getResult ().getType (), newLhs, newRhs,
26022619 elementwiseOp.getElementwiseOperation ());
2620+ assert (op.getType () == newElementwiseOp.getType ());
26032621 rewriter.replaceOp (op, newElementwiseOp.getResult ());
26042622
26052623 return success ();
@@ -2696,6 +2714,7 @@ class PushUpTransposeSoftmax : public OpRewritePattern<tensorrt::TransposeOp> {
26962714 op.getLoc (), softmax.getInput (), op.getPermutation ());
26972715 auto newSoftmax = rewriter.create <tensorrt::SoftMaxOp>(
26982716 softmax.getLoc (), newTranspose, newAxis);
2717+ assert (op.getType () == newSoftmax.getType ());
26992718 rewriter.replaceOp (op, newSoftmax.getResult ());
27002719 return success ();
27012720 }
@@ -2719,6 +2738,7 @@ class PushDownTransposeSoftmax : public OpRewritePattern<tensorrt::SoftMaxOp> {
27192738 op.getLoc (), transpose.getInput (), newAxis);
27202739 auto newTranspose = rewriter.create <tensorrt::TransposeOp>(
27212740 transpose.getLoc (), newSoftmax, transpose.getPermutation ());
2741+ assert (op.getType () == newTranspose.getType ());
27222742 rewriter.replaceOp (op, newTranspose.getResult ());
27232743 return success ();
27242744 }
@@ -2766,6 +2786,7 @@ class PushUpReshapeSoftmax : public OpRewritePattern<tensorrt::ReshapeOp> {
27662786 op.getLoc (), outputType, softmax.getInput ());
27672787 auto newSoftmax = rewriter.create <tensorrt::SoftMaxOp>(
27682788 softmax.getLoc (), newReshape.getResult (), newAxis);
2789+ assert (op.getType () == newSoftmax.getType ());
27692790 rewriter.replaceOp (op, newSoftmax.getResult ());
27702791 return success ();
27712792 }
@@ -2810,6 +2831,7 @@ class PushDownReshapeSoftmax : public OpRewritePattern<tensorrt::SoftMaxOp> {
28102831 op.getLoc (), reshapeOp.getInput (), newAxis);
28112832 auto newReshape = rewriter.create <tensorrt::ReshapeOp>(
28122833 reshapeOp.getLoc (), outputType, newSoftmax.getResult ());
2834+ assert (op.getType () == newReshape.getType ());
28132835 rewriter.replaceOp (op, newReshape.getResult ());
28142836 return success ();
28152837 }
0 commit comments