@@ -1453,7 +1453,7 @@ class PushReshapeUpThroughEinsum
14531453 std::string newEquationStr = newEquation.generateEquation ();
14541454
14551455 auto newEinsum = rewriter.create <tensorrt::EinsumOp>(
1456- op .getLoc (), op.getType (), newInputs, newEquationStr);
1456+ einsum .getLoc (), op.getType (), newInputs, newEquationStr);
14571457 assert (newEquation.rhs .size () == newEinsum.getType ().getShape ().size ());
14581458 assert (op.getType () == newEinsum.getType ());
14591459 rewriter.replaceOp (op, newEinsum.getResult ());
@@ -1601,6 +1601,7 @@ class PushReshapeDownThroughEinsum
16011601 // get pushed down as reshapes might need to get added to other inputs to
16021602 // make the shapes work
16031603 bool hasReshapeInput = false ;
1604+ Location reshapeLoc = op.getLoc ();
16041605 for (auto input : op.getInputs ()) {
16051606 if (!cast<RankedTensorType>(input.getType ()).hasStaticShape ()) {
16061607 return failure (/* dynamic input not supported */ );
@@ -1609,6 +1610,7 @@ class PushReshapeDownThroughEinsum
16091610 if (!reshape.getInput ().getType ().hasStaticShape ())
16101611 return failure (/* dynamic reshape input not supported */ );
16111612 hasReshapeInput = true ;
1613+ reshapeLoc = reshape.getLoc ();
16121614 }
16131615 }
16141616 if (!hasReshapeInput)
@@ -1865,11 +1867,11 @@ class PushReshapeDownThroughEinsum
18651867 assert (newEquation.rhs .size () == newEinsum.getType ().getShape ().size ());
18661868
18671869 auto newReshape = rewriter.createOrFold <tensorrt::ReshapeOp>(
1868- op. getLoc () , outputType.clone (afterEinsumReshape),
1870+ reshapeLoc , outputType.clone (afterEinsumReshape),
18691871 newEinsum.getResult ());
18701872
18711873 Value newOut = rewriter.createOrFold <tensorrt::TransposeOp>(
1872- op. getLoc () , newReshape,
1874+ reshapeLoc , newReshape,
18731875 AffineMap::getPermutationMap (afterReshapeTranspose, op.getContext ()));
18741876
18751877 assert (op.getType () == newOut.getType ());
@@ -2186,7 +2188,7 @@ class MoveTransposeBeforeReshape
21862188 AffineMap::getPermutationMap (newTranspose, op.getContext ()));
21872189 }
21882190 Value newReshapeOp = rewriter.createOrFold <tensorrt::ReshapeOp>(
2189- op .getLoc (), reshapeInputType.clone (newReshape), newTransposeOp);
2191+ reshape .getLoc (), reshapeInputType.clone (newReshape), newTransposeOp);
21902192
21912193 assert (op.getType () == newReshapeOp.getType ());
21922194 rewriter.replaceOp (op, newReshapeOp);
@@ -2210,8 +2212,11 @@ class PushDownReshapeActivationRewriter
22102212 auto activationOp = rewriter.create <tensorrt::ActivationOp>(
22112213 op.getLoc (), producer.getInput (), op.getActivationType (),
22122214 op.getAlphaAttr (), op.getBetaAttr ());
2213- rewriter.replaceOpWithNewOp <tensorrt::ReshapeOp>(
2214- op, op.getType (), activationOp.getResult (), producer.getShape ());
2215+ auto reshapeOp = rewriter.createOrFold <tensorrt::ReshapeOp>(
2216+ producer.getLoc (), op.getType (), activationOp.getResult (),
2217+ producer.getShape ());
2218+ assert (op.getType () == reshapeOp.getType ());
2219+ rewriter.replaceOp (op, reshapeOp);
22152220 return success ();
22162221 }
22172222};
@@ -2231,8 +2236,11 @@ class PushDownReshapeUnaryRewriter
22312236
22322237 auto unaryOp = rewriter.create <tensorrt::UnaryOp>(
22332238 op.getLoc (), producer.getInput (), op.getUnaryOperationAttr ());
2234- rewriter.replaceOpWithNewOp <tensorrt::ReshapeOp>(
2235- op, op.getType (), unaryOp.getResult (), producer.getShape ());
2239+ auto reshapeOp = rewriter.createOrFold <tensorrt::ReshapeOp>(
2240+ producer.getLoc (), op.getType (), unaryOp.getResult (),
2241+ producer.getShape ());
2242+ assert (op.getType () == reshapeOp.getType ());
2243+ rewriter.replaceOp (op, reshapeOp);
22362244 return success ();
22372245 }
22382246};
@@ -2252,12 +2260,13 @@ class PushDownReshapeIdentityRewriter
22522260
22532261 RankedTensorType newIdentityType =
22542262 producer.getInput ().getType ().clone (op.getType ().getElementType ());
2255-
22562263 Value newIdentityResult = rewriter.create <IdentityOp>(
22572264 op.getLoc (), newIdentityType, producer.getInput ());
2258-
2259- rewriter.replaceOpWithNewOp <tensorrt::ReshapeOp>(
2260- op, op.getType (), newIdentityResult, producer.getShape ());
2265+ auto reshapeOp = rewriter.createOrFold <tensorrt::ReshapeOp>(
2266+ producer.getLoc (), op.getType (), newIdentityResult,
2267+ producer.getShape ());
2268+ assert (op.getType () == reshapeOp.getType ());
2269+ rewriter.replaceOp (op, reshapeOp);
22612270 return success ();
22622271 }
22632272};
@@ -2280,8 +2289,11 @@ class PushUpReshapeUnary : public OpRewritePattern<tensorrt::ReshapeOp> {
22802289
22812290 Value newReshapeResult = rewriter.create <tensorrt::ReshapeOp>(
22822291 op.getLoc (), reshapeType, producer.getInput (), op.getShape ());
2283- rewriter.replaceOpWithNewOp <OpType>(op, op.getType (), newReshapeResult,
2284- producer->getAttrs ());
2292+ auto newOp =
2293+ rewriter.createOrFold <OpType>(producer.getLoc (), op.getType (),
2294+ newReshapeResult, producer->getAttrs ());
2295+ assert (op.getType () == newOp.getType ());
2296+ rewriter.replaceOp (op, newOp);
22852297 return success ();
22862298 }
22872299};
@@ -2495,6 +2507,7 @@ class PushDownReshapeElementwise
24952507 PatternRewriter &rewriter) const override {
24962508 bool hasReshapeInput = false ;
24972509 uint64_t currentEstimatedCost = 0 ;
2510+ Location reshapeLoc = op.getLoc ();
24982511 for (Value input : op.getOperands ()) {
24992512 if (!cast<RankedTensorType>(input.getType ()).hasStaticShape ()) {
25002513 return failure ();
@@ -2504,6 +2517,7 @@ class PushDownReshapeElementwise
25042517 return failure ();
25052518 }
25062519 hasReshapeInput = true ;
2520+ reshapeLoc = reshape.getLoc ();
25072521 currentEstimatedCost += estimateShuffleCost (input);
25082522 }
25092523 }
@@ -2565,8 +2579,10 @@ class PushDownReshapeElementwise
25652579 auto newElementwiseOp = rewriter.create <tensorrt::ElementWiseOp>(
25662580 op.getLoc (), elementwiseType, newLhs, newRhs,
25672581 op.getElementwiseOperation ());
2568- rewriter.replaceOpWithNewOp <tensorrt::ReshapeOp>(
2569- op, op.getType (), newElementwiseOp.getResult ());
2582+ auto newReshapeOp = rewriter.createOrFold <tensorrt::ReshapeOp>(
2583+ reshapeLoc, op.getType (), newElementwiseOp.getResult ());
2584+ assert (op.getType () == newReshapeOp.getType ());
2585+ rewriter.replaceOp (op, newReshapeOp);
25702586 return success ();
25712587 }
25722588};
0 commit comments