Skip to content

Commit d335aad

Browse files
author
Matthew Francis-Landau
committed
slightly better tracking of Location as moving ops around
1 parent 2d49cae commit d335aad

File tree

1 file changed

+32
-16
lines changed

1 file changed

+32
-16
lines changed

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeReshapeElimination.cpp

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1453,7 +1453,7 @@ class PushReshapeUpThroughEinsum
14531453
std::string newEquationStr = newEquation.generateEquation();
14541454

14551455
auto newEinsum = rewriter.create<tensorrt::EinsumOp>(
1456-
op.getLoc(), op.getType(), newInputs, newEquationStr);
1456+
einsum.getLoc(), op.getType(), newInputs, newEquationStr);
14571457
assert(newEquation.rhs.size() == newEinsum.getType().getShape().size());
14581458
assert(op.getType() == newEinsum.getType());
14591459
rewriter.replaceOp(op, newEinsum.getResult());
@@ -1601,6 +1601,7 @@ class PushReshapeDownThroughEinsum
16011601
// get pushed down as reshapes might need to get added to other inputs to
16021602
// make the shapes work
16031603
bool hasReshapeInput = false;
1604+
Location reshapeLoc = op.getLoc();
16041605
for (auto input : op.getInputs()) {
16051606
if (!cast<RankedTensorType>(input.getType()).hasStaticShape()) {
16061607
return failure(/* dynamic input not supported */);
@@ -1609,6 +1610,7 @@ class PushReshapeDownThroughEinsum
16091610
if (!reshape.getInput().getType().hasStaticShape())
16101611
return failure(/* dynamic reshape input not supported */);
16111612
hasReshapeInput = true;
1613+
reshapeLoc = reshape.getLoc();
16121614
}
16131615
}
16141616
if (!hasReshapeInput)
@@ -1865,11 +1867,11 @@ class PushReshapeDownThroughEinsum
18651867
assert(newEquation.rhs.size() == newEinsum.getType().getShape().size());
18661868

18671869
auto newReshape = rewriter.createOrFold<tensorrt::ReshapeOp>(
1868-
op.getLoc(), outputType.clone(afterEinsumReshape),
1870+
reshapeLoc, outputType.clone(afterEinsumReshape),
18691871
newEinsum.getResult());
18701872

18711873
Value newOut = rewriter.createOrFold<tensorrt::TransposeOp>(
1872-
op.getLoc(), newReshape,
1874+
reshapeLoc, newReshape,
18731875
AffineMap::getPermutationMap(afterReshapeTranspose, op.getContext()));
18741876

18751877
assert(op.getType() == newOut.getType());
@@ -2186,7 +2188,7 @@ class MoveTransposeBeforeReshape
21862188
AffineMap::getPermutationMap(newTranspose, op.getContext()));
21872189
}
21882190
Value newReshapeOp = rewriter.createOrFold<tensorrt::ReshapeOp>(
2189-
op.getLoc(), reshapeInputType.clone(newReshape), newTransposeOp);
2191+
reshape.getLoc(), reshapeInputType.clone(newReshape), newTransposeOp);
21902192

21912193
assert(op.getType() == newReshapeOp.getType());
21922194
rewriter.replaceOp(op, newReshapeOp);
@@ -2210,8 +2212,11 @@ class PushDownReshapeActivationRewriter
22102212
auto activationOp = rewriter.create<tensorrt::ActivationOp>(
22112213
op.getLoc(), producer.getInput(), op.getActivationType(),
22122214
op.getAlphaAttr(), op.getBetaAttr());
2213-
rewriter.replaceOpWithNewOp<tensorrt::ReshapeOp>(
2214-
op, op.getType(), activationOp.getResult(), producer.getShape());
2215+
auto reshapeOp = rewriter.createOrFold<tensorrt::ReshapeOp>(
2216+
producer.getLoc(), op.getType(), activationOp.getResult(),
2217+
producer.getShape());
2218+
assert(op.getType() == reshapeOp.getType());
2219+
rewriter.replaceOp(op, reshapeOp);
22152220
return success();
22162221
}
22172222
};
@@ -2231,8 +2236,11 @@ class PushDownReshapeUnaryRewriter
22312236

22322237
auto unaryOp = rewriter.create<tensorrt::UnaryOp>(
22332238
op.getLoc(), producer.getInput(), op.getUnaryOperationAttr());
2234-
rewriter.replaceOpWithNewOp<tensorrt::ReshapeOp>(
2235-
op, op.getType(), unaryOp.getResult(), producer.getShape());
2239+
auto reshapeOp = rewriter.createOrFold<tensorrt::ReshapeOp>(
2240+
producer.getLoc(), op.getType(), unaryOp.getResult(),
2241+
producer.getShape());
2242+
assert(op.getType() == reshapeOp.getType());
2243+
rewriter.replaceOp(op, reshapeOp);
22362244
return success();
22372245
}
22382246
};
@@ -2252,12 +2260,13 @@ class PushDownReshapeIdentityRewriter
22522260

22532261
RankedTensorType newIdentityType =
22542262
producer.getInput().getType().clone(op.getType().getElementType());
2255-
22562263
Value newIdentityResult = rewriter.create<IdentityOp>(
22572264
op.getLoc(), newIdentityType, producer.getInput());
2258-
2259-
rewriter.replaceOpWithNewOp<tensorrt::ReshapeOp>(
2260-
op, op.getType(), newIdentityResult, producer.getShape());
2265+
auto reshapeOp = rewriter.createOrFold<tensorrt::ReshapeOp>(
2266+
producer.getLoc(), op.getType(), newIdentityResult,
2267+
producer.getShape());
2268+
assert(op.getType() == reshapeOp.getType());
2269+
rewriter.replaceOp(op, reshapeOp);
22612270
return success();
22622271
}
22632272
};
@@ -2280,8 +2289,11 @@ class PushUpReshapeUnary : public OpRewritePattern<tensorrt::ReshapeOp> {
22802289

22812290
Value newReshapeResult = rewriter.create<tensorrt::ReshapeOp>(
22822291
op.getLoc(), reshapeType, producer.getInput(), op.getShape());
2283-
rewriter.replaceOpWithNewOp<OpType>(op, op.getType(), newReshapeResult,
2284-
producer->getAttrs());
2292+
auto newOp =
2293+
rewriter.createOrFold<OpType>(producer.getLoc(), op.getType(),
2294+
newReshapeResult, producer->getAttrs());
2295+
assert(op.getType() == newOp.getType());
2296+
rewriter.replaceOp(op, newOp);
22852297
return success();
22862298
}
22872299
};
@@ -2495,6 +2507,7 @@ class PushDownReshapeElementwise
24952507
PatternRewriter &rewriter) const override {
24962508
bool hasReshapeInput = false;
24972509
uint64_t currentEstimatedCost = 0;
2510+
Location reshapeLoc = op.getLoc();
24982511
for (Value input : op.getOperands()) {
24992512
if (!cast<RankedTensorType>(input.getType()).hasStaticShape()) {
25002513
return failure();
@@ -2504,6 +2517,7 @@ class PushDownReshapeElementwise
25042517
return failure();
25052518
}
25062519
hasReshapeInput = true;
2520+
reshapeLoc = reshape.getLoc();
25072521
currentEstimatedCost += estimateShuffleCost(input);
25082522
}
25092523
}
@@ -2565,8 +2579,10 @@ class PushDownReshapeElementwise
25652579
auto newElementwiseOp = rewriter.create<tensorrt::ElementWiseOp>(
25662580
op.getLoc(), elementwiseType, newLhs, newRhs,
25672581
op.getElementwiseOperation());
2568-
rewriter.replaceOpWithNewOp<tensorrt::ReshapeOp>(
2569-
op, op.getType(), newElementwiseOp.getResult());
2582+
auto newReshapeOp = rewriter.createOrFold<tensorrt::ReshapeOp>(
2583+
reshapeLoc, op.getType(), newElementwiseOp.getResult());
2584+
assert(op.getType() == newReshapeOp.getType());
2585+
rewriter.replaceOp(op, newReshapeOp);
25702586
return success();
25712587
}
25722588
};

0 commit comments

Comments
 (0)