Commit deb8e73
Author: Matthew Francis-Landau
Message: fix issues with push down transpose from einsum
Parent: 9eb2d2d

2 files changed (+56, -16 lines)

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeReshapeElimination.cpp

Lines changed: 32 additions & 16 deletions
@@ -164,7 +164,8 @@ struct PushDownTransposeActivationRewriter
     auto activationOp = rewriter.create<ActivationOp>(
         op.getLoc(), producer.getInput(), op.getActivationType(),
         op.getAlphaAttr(), op.getBetaAttr());
-    auto newTranspose = rewriter.create<TransposeOp>(producer.getLoc(), activationOp.getResult(), permutation);
+    auto newTranspose = rewriter.create<TransposeOp>(
+        producer.getLoc(), activationOp.getResult(), permutation);
     rewriter.replaceOp(op, newTranspose.getResult());
     return success();
   }
@@ -181,7 +182,8 @@ struct PushDownTransposeUnary : OpRewritePattern<UnaryOp> {
     AffineMap permutation = producer.getPermutation();
     auto unary = rewriter.create<UnaryOp>(op.getLoc(), producer.getInput(),
                                           op.getUnaryOperationAttr());
-    auto newTranspose = rewriter.create<TransposeOp>(producer.getLoc(), unary.getResult(), permutation);
+    auto newTranspose = rewriter.create<TransposeOp>(
+        producer.getLoc(), unary.getResult(), permutation);
     rewriter.replaceOp(op, newTranspose.getResult());
     return success();
   }
@@ -755,13 +757,23 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
       return a.first < b.first;
     });

+    LLVM_DEBUG({
+      std::stringstream out;
+      out << "outputAxes: [";
+      for (auto x : outputAxes) {
+        out << x.first << "(" << x.second << ") ";
+      }
+      out << "]\n";
+      DBGS() << out.str();
+    });
+
     SmallVector<int64_t> newEinsumShape;
-    SmallVector<int64_t> outputPerm;
+    SmallVector<int64_t> forwardPerm;
     std::string newEinsumRhs = "";
     for (auto &[c, i] : outputAxes) {
       newEinsumRhs += c;
       newEinsumShape.push_back(op.getType().getDimSize(i));
-      outputPerm.push_back(i);
+      forwardPerm.push_back(i);
     }
     if (newEinsumRhs == equation.rhs)
       return failure(); // no change
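
The hunk above renames outputPerm to forwardPerm, which better describes what the loop builds: each output subscript is paired with its axis index, the pairs are sorted by subscript, and forwardPerm[k] records which original output axis lands at position k of the sorted output. A minimal standalone sketch of that derivation, using std::vector and std::string in place of the pass's SmallVector, with the equation taken from the new test below:

#include <algorithm>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Output spec of the test's first einsum, "bcd,bec->ebd".
  std::string rhs = "ebd";

  // Pair each output subscript with its axis index, then sort by subscript,
  // mirroring the std::sort on outputAxes in the hunk above.
  std::vector<std::pair<char, int64_t>> outputAxes;
  for (int64_t i = 0; i < (int64_t)rhs.size(); ++i)
    outputAxes.push_back({rhs[i], i});
  std::sort(outputAxes.begin(), outputAxes.end(),
            [](auto &a, auto &b) { return a.first < b.first; });

  // forwardPerm[k] = original output axis that lands at position k.
  std::string newRhs;
  std::vector<int64_t> forwardPerm;
  for (auto &[c, i] : outputAxes) {
    newRhs += c;              // "bde"
    forwardPerm.push_back(i); // {1, 2, 0}
  }
  return (newRhs == "bde" &&
          forwardPerm == std::vector<int64_t>{1, 2, 0}) ? 0 : 1;
}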
@@ -773,10 +785,13 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
         op.getLoc(), op.getType().clone(newEinsumShape), op.getInputs(),
         newEinsumEquation);

+    auto forwardMap =
+        AffineMap::getPermutationMap(forwardPerm, op.getLoc().getContext());
+
     auto newTranspose = rewriter.create<tensorrt::TransposeOp>(
-        op.getLoc(), newEinsum.getResult(),
-        AffineMap::getPermutationMap(outputPerm, op.getLoc().getContext()));
+        op.getLoc(), newEinsum.getResult(), inversePermutation(forwardMap));

+    assert(op.getType() == newTranspose.getType());
     rewriter.replaceOp(op, newTranspose.getResult());

     return success();
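
This second hunk carries the actual fix: the old code built the trailing TransposeOp from the forward permutation, but forwardPerm maps sorted-output positions back to original axes, so restoring the original layout requires its inverse, now obtained via inversePermutation(forwardMap). A standalone sketch of why, on plain index vectors rather than mlir::AffineMap:

#include <cassert>
#include <cstdint>
#include <vector>

// Invert a permutation: if perm[k] = i (axis i moved to position k),
// then inv[i] = k moves it back. This mirrors, on plain vectors, what
// inversePermutation(forwardMap) computes on an AffineMap.
std::vector<int64_t> invertPermutation(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inv(perm.size());
  for (int64_t k = 0; k < (int64_t)perm.size(); ++k)
    inv[perm[k]] = k;
  return inv;
}

int main() {
  // forwardPerm for rhs "ebd" sorted to "bde" (see the sketch above).
  std::vector<int64_t> forwardPerm = {1, 2, 0};
  std::vector<int64_t> inverse = invertPermutation(forwardPerm);
  assert((inverse == std::vector<int64_t>{2, 0, 1}));
  // Forward followed by inverse is the identity, so transposing the new
  // einsum's result by the inverse reproduces the original output order.
  for (size_t k = 0; k < forwardPerm.size(); ++k)
    assert(inverse[forwardPerm[k]] == (int64_t)k);
  return 0;
}

The new assert(op.getType() == newTranspose.getType()) encodes the same invariant in the pass itself: after the inverse transpose, the result type must match that of the einsum being replaced.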
@@ -1662,29 +1677,29 @@ class MoveReshapeBeforeTranspose
       }
     }
     assert(inputNumElems == outputNumElems);
-    while(j < reshapeOutputType.getRank()) {
+    while (j < reshapeOutputType.getRank()) {
       outputNumElems *= reshapeOutputType.getDimSize(j);
       groupReshapeOut.push_back(reshapeOutputType.getDimSize(j));
       transposeOutAxes.push_back(j++);
     }
     assert(inputNumElems == outputNumElems);
     assert(transposeInAxes.empty());
-    if(!transposeOutAxes.empty() || !groupReshapeOut.empty()) {
+    if (!transposeOutAxes.empty() || !groupReshapeOut.empty()) {
       reshapeGroups.push_back(ReshapeGroup{
           .transposeInAxes = transposeInAxes,
           .transposeOutAxes = transposeOutAxes,
           .reshapeOut = groupReshapeOut,
           .startOutputIdx = -1, // set later
-          });
+      });
     }

     SmallVector<int64_t> newTranspose;
     SmallVector<int64_t> newReshape;

     std::sort(reshapeGroups.begin(), reshapeGroups.end(), [](auto &a, auto &b) {
-      if(a.transposeInAxes.empty())
+      if (a.transposeInAxes.empty())
         return false;
-      if(b.transposeInAxes.empty())
+      if (b.transposeInAxes.empty())
         return true;
       return a.transposeInAxes[0] < b.transposeInAxes[0];
     });
@@ -1713,28 +1728,29 @@ class MoveReshapeBeforeTranspose
       out << " transposeInAxes: [";
       for (size_t i = 0; i < group.transposeInAxes.size(); ++i) {
         out << group.transposeInAxes[i];
-        if (i + 1 < group.transposeInAxes.size()) out << ", ";
+        if (i + 1 < group.transposeInAxes.size())
+          out << ", ";
       }
       out << "]\n";
       out << " transposeOutAxes: [";
       for (size_t i = 0; i < group.transposeOutAxes.size(); ++i) {
         out << group.transposeOutAxes[i];
-        if (i + 1 < group.transposeOutAxes.size()) out << ", ";
+        if (i + 1 < group.transposeOutAxes.size())
+          out << ", ";
       }
       out << "]\n";
       out << " reshapeOut: [";
       for (size_t i = 0; i < group.reshapeOut.size(); ++i) {
         out << group.reshapeOut[i];
-        if (i + 1 < group.reshapeOut.size()) out << ", ";
+        if (i + 1 < group.reshapeOut.size())
+          out << ", ";
       }
       out << "]\n";
       out << " startOutputIdx: " << group.startOutputIdx << "\n";
     }
     DBGS() << out.str();
   });

-
-
   for (auto &group : reshapeGroups) {
     for (size_t i = 0; i < group.reshapeOut.size(); i++)
       newTranspose.push_back(group.startOutputIdx + i);

mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-reshape-elimination.mlir

Lines changed: 24 additions & 0 deletions
@@ -287,4 +287,28 @@ func.func @reshape_transpose_reorder_ones_dim(%arg0: tensor<2x1x1x1x1xf32>, %arg
   %3 = tensorrt.reshape %2 : tensor<2x1x1x1x1xf32> to tensor<2x1x1x1xf32>
   %4 = tensorrt.deconvolution {dilation = array<i64: 1, 1>, num_groups = 2 : ui32, post_padding = array<i64: 0, 0>, pre_padding = array<i64: 0, 0>, stride = array<i64: 1, 2>} in(%arg1 : tensor<1x2x3x3xf32>) kernelWeights(%3 : tensor<2x1x1x1xf32>) -> tensor<1x2x3x5xf32>
   return %4 : tensor<1x2x3x5xf32>
+}
+
+// -----
+
+
+// CHECK: @push_down_transpose_einsum(%[[arg0:.+]]: tensor<1x6x1500x64xf32>, %[[arg1:.+]]: tensor<1x6x1500x1500xf32>) -> tensor<1x1500x384xf32>
+// CHECK-DAG: %[[const0:.+]] = tensorrt.constant dense<1.000000e+00> : tensor<384x6x64xf32>
+// CHECK-DAG: %[[v0:.+]] = tensorrt.collapse_rank %[[arg0]] : tensor<1x6x1500x64xf32> to tensor<6x1500x64xf32>
+// CHECK-DAG: %[[v1:.+]] = tensorrt.collapse_rank %[[arg1]] : tensor<1x6x1500x1500xf32> to tensor<6x1500x1500xf32>
+// CHECK: %[[v2:.+]] = tensorrt.matrix_multiply [[params:.+]] ins(%[[v0]], %[[v1]] : tensor<6x1500x64xf32>, tensor<6x1500x1500xf32>) -> tensor<6x64x1500xf32>
+// CHECK: %[[v3:.+]] = tensorrt.einsum [[params2:.+]] ins(%[[v2]], %[[const0]] : tensor<6x64x1500xf32>, tensor<384x6x64xf32>) -> tensor<1500x384xf32>
+// CHECK: %[[v4:.+]] = tensorrt.expand_rank %[[v3:.+]] : tensor<1500x384xf32> to tensor<1x1500x384xf32>
+// CHECK: return %[[v4]]
+func.func @push_down_transpose_einsum(%arg0: tensor<1x6x1500x64xf32>, %arg1: tensor<1x6x1500x1500xf32>) -> tensor<1x1500x384xf32> {
+  %cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<384x384xf32>
+  %0 = tensorrt.reshape %arg0 : tensor<1x6x1500x64xf32> to tensor<6x1500x64xf32>
+  %1 = tensorrt.reshape %arg1 : tensor<1x6x1500x1500xf32> to tensor<6x1500x1500xf32>
+  %2 = tensorrt.einsum {equation = "bcd,bec->ebd"} ins(%0, %1 : tensor<6x1500x64xf32>, tensor<6x1500x1500xf32>) -> tensor<1500x6x64xf32>
+  %3 = tensorrt.reshape %2 : tensor<1500x6x64xf32> to tensor<1x1500x6x64xf32>
+  %4 = tensorrt.reshape %2 : tensor<1500x6x64xf32> to tensor<1500x384xf32>
+  %cst_f32_0 = tensorrt.constant dense<1.000000e+00> : tensor<384x6x64xf32>
+  %5 = tensorrt.einsum {equation = "bde,cde->bc"} ins(%2, %cst_f32_0 : tensor<1500x6x64xf32>, tensor<384x6x64xf32>) -> tensor<1500x384xf32>
+  %6 = tensorrt.reshape %5 : tensor<1500x384xf32> to tensor<1x1500x384xf32>
+  return %6 : tensor<1x1500x384xf32>
 }
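
Tying the test to the fix: the first einsum's output spec "ebd" (tensor<1500x6x64xf32>) sorts to "bde", so the rewritten einsum yields a 6x64x1500 result (visible as the matrix_multiply output in the CHECK lines once other patterns fuse it), and the inverse permutation {2, 0, 1} maps that back to the original layout. A quick shape trace, with permute as an illustrative helper:

#include <cassert>
#include <cstdint>
#include <vector>

// Forward-apply a permutation to a shape: out[k] = shape[perm[k]].
std::vector<int64_t> permute(const std::vector<int64_t> &shape,
                             const std::vector<int64_t> &perm) {
  std::vector<int64_t> out;
  for (int64_t p : perm)
    out.push_back(shape[p]);
  return out;
}

int main() {
  std::vector<int64_t> original = {1500, 6, 64}; // "ebd": e=1500, b=6, d=64
  std::vector<int64_t> forwardPerm = {1, 2, 0};  // "ebd" sorted to "bde"
  std::vector<int64_t> inversePerm = {2, 0, 1};

  // Shape of the rewritten einsum's output ("bde").
  auto sorted = permute(original, forwardPerm);
  assert((sorted == std::vector<int64_t>{6, 64, 1500}));

  // Transposing by the inverse restores the original "ebd" layout,
  // which is what the added assert in the pass verifies.
  assert(permute(sorted, inversePerm) == original);
  return 0;
}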
