Commit 1da7863

Author: Matthew Francis-Landau (committed)

handle edge case when there are still axes that need to get processed on the output for the reshape groups

1 parent: c3b4d85
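
Context for the fix, with a hedged sketch: `MoveTransposeBeforeReshape` walks the reshape's input axes and pairs them with output axes to build `ReshapeGroup`s. When the reshape appends trailing dimensions that consume no input axes (e.g. `tensorrt.expand_rank` from tensor<12xf32> to tensor<12x1x1xf32>, as in the new test at the bottom of this commit), the input loop finishes while output axes remain unprocessed, and those axes were previously dropped. The minimal standalone sketch below (illustrative only, not the pass itself; the pairing condition is simplified and the real code also tracks the transpose permutation and per-group axis lists) shows the shape of the fix: hoist the output cursor `j` out of the loop and drain the leftover output dims afterwards.

    #include <cassert>
    #include <cstdio>
    #include <vector>

    int main() {
      // tensor<12xf32> expanded to tensor<12x1x1xf32>: the input walk
      // consumes output dim 0 and leaves the trailing 1x1 unprocessed.
      std::vector<long> inShape = {12};
      std::vector<long> outShape = {12, 1, 1};

      size_t inputNumElems = 1, outputNumElems = 1;
      std::vector<long> groupReshapeOut;

      int j = 0; // hoisted out of the loop so the drain below can see it
      for (int i = 0; i < (int)inShape.size(); i++) {
        inputNumElems *= inShape[i];
        // Simplified pairing: consume output dims until element counts match.
        while (j < (int)outShape.size() && outputNumElems < inputNumElems) {
          outputNumElems *= outShape[j];
          groupReshapeOut.push_back(outShape[j++]);
        }
      }
      // The fix: drain output dims the loop above never reached (the two 1s).
      while (j < (int)outShape.size()) {
        outputNumElems *= outShape[j];
        groupReshapeOut.push_back(outShape[j++]);
      }
      assert(inputNumElems == outputNumElems);
      std::printf("grouped output dims:"); // prints: grouped output dims: 12 1 1
      for (long d : groupReshapeOut)
        std::printf(" %ld", d);
      std::printf("\n");
      return 0;
    }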

File tree: 2 files changed (+71 −5 lines)


mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeReshapeElimination.cpp

Lines changed: 56 additions & 5 deletions
@@ -1635,7 +1635,8 @@ class MoveReshapeBeforeTranspose
     SmallVector<int64_t> groupReshapeOut;
     size_t inputNumElems = 1;
     size_t outputNumElems = 1;
-    for (int i = 0, j = 0; i < reshapeInputType.getRank(); i++) {
+    int j = 0;
+    for (int i = 0; i < reshapeInputType.getRank(); i++) {
       inputNumElems *= reshapeInputType.getDimSize(i);
       if (!transposeInAxes.empty() &&
           transposeInAxes.back() + 1 != transposePerm[i]) {
@@ -1740,7 +1741,8 @@ class MoveTransposeBeforeReshape
     SmallVector<int64_t> groupReshapeOut;
     size_t inputNumElems = 1;
     size_t outputNumElems = 1;
-    for (int i = 0, j = 0; i < reshapeInputType.getRank(); i++) {
+    int j = 0;
+    for (int i = 0; i < reshapeInputType.getRank(); i++) {
       inputNumElems *= reshapeInputType.getDimSize(i);
       inputAxes.push_back(i);
       while (j < reshapeOutputType.getRank() &&
@@ -1764,6 +1766,21 @@ class MoveTransposeBeforeReshape
         groupReshapeOut.clear();
       }
     }
+    while (j < reshapeOutputType.getRank()) {
+      outputNumElems *= reshapeOutputType.getDimSize(j);
+      groupReshapeOut.push_back(reshapeOutputType.getDimSize(j));
+      outputAxes.push_back(transposePerm[j++]);
+    }
+
+    assert(inputNumElems == outputNumElems);
+    assert(inputAxes.empty());
+    if (!outputAxes.empty() || !groupReshapeOut.empty()) {
+      reshapeGroups.push_back(ReshapeGroup{
+          .inputAxes = inputAxes,
+          .outputAxes = outputAxes,
+          .reshapeOut = groupReshapeOut,
+      });
+    }
 
     SmallVector<int64_t> newTranspose;
     SmallVector<int64_t> newReshape;
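
Note the two assertions guarding the drained group: `inputNumElems == outputNumElems` confirms the reshape's element count still balances after the trailing axes are consumed, and `inputAxes.empty()` holds because every input axis was already flushed into an earlier group, so the final `ReshapeGroup` carries only output axes and output dims.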
@@ -1776,6 +1793,38 @@ class MoveTransposeBeforeReshape
       return a.outputAxes[0] < b.outputAxes[0];
     });
 
+    // Debug print of reshapeGroups
+    LLVM_DEBUG({
+      std::stringstream out;
+      out << "reshapeGroups:\n";
+      for (size_t idx = 0; idx < reshapeGroups.size(); ++idx) {
+        const auto &group = reshapeGroups[idx];
+        out << "  Group " << idx << ":\n";
+        out << "    inputAxes: [";
+        for (size_t i = 0; i < group.inputAxes.size(); ++i) {
+          out << group.inputAxes[i];
+          if (i + 1 < group.inputAxes.size())
+            out << ", ";
+        }
+        out << "]\n";
+        out << "    outputAxes: [";
+        for (size_t i = 0; i < group.outputAxes.size(); ++i) {
+          out << group.outputAxes[i];
+          if (i + 1 < group.outputAxes.size())
+            out << ", ";
+        }
+        out << "]\n";
+        out << "    reshapeOut: [";
+        for (size_t i = 0; i < group.reshapeOut.size(); ++i) {
+          out << group.reshapeOut[i];
+          if (i + 1 < group.reshapeOut.size())
+            out << ", ";
+        }
+        out << "]\n";
+      }
+      DBGS() << out.str();
+    });
+
     for (auto &group : reshapeGroups) {
       for (int64_t i : group.inputAxes)
         newTranspose.push_back(i);
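
The dump above only fires in debug builds when this pass's debug type is enabled. A sketch of the usual LLVM plumbing such code relies on (the exact DEBUG_TYPE string and the DBGS definition live elsewhere in this file, so the names here are assumptions):

    // Conventional LLVM debug setup, assumed for illustration:
    #include "llvm/Support/Debug.h"
    #define DEBUG_TYPE "transpose-reshape-elimination" // assumed name
    #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")

With that in place, the trace is enabled at runtime with `-debug-only=transpose-reshape-elimination` (or `-debug` for everything) on a debug build.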
@@ -1789,6 +1838,8 @@ class MoveTransposeBeforeReshape
     Value newReshapeOp = rewriter.createOrFold<tensorrt::ReshapeOp>(
         op.getLoc(), reshapeInputType.clone(newReshape), newTransposeOp);
 
+    assert(op.getType() == newReshapeOp.getType());
+
     rewriter.replaceOp(op, newReshapeOp);
     return success();
   }
@@ -2244,15 +2295,15 @@ class MatrixMultiplyTransposedArguments
       return std::make_tuple(arg, operation);
     bool swapsLastTwo = true;
     for (int64_t i = 0; i < rank - 2; ++i) {
-      auto expr = permVec[i].dyn_cast<AffineDimExpr>();
+      auto expr = dyn_cast<AffineDimExpr>(permVec[i]);
       if (!expr || expr.getPosition() != i) {
         swapsLastTwo = false;
         break;
       }
     }
     if (swapsLastTwo) {
-      auto expr1 = permVec[rank - 2].dyn_cast<AffineDimExpr>();
-      auto expr2 = permVec[rank - 1].dyn_cast<AffineDimExpr>();
+      auto expr1 = dyn_cast<AffineDimExpr>(permVec[rank - 2]);
+      auto expr2 = dyn_cast<AffineDimExpr>(permVec[rank - 1]);
       if (!(expr1 && expr2 && expr1.getPosition() == rank - 1 &&
             expr2.getPosition() == rank - 2)) {
         swapsLastTwo = false;
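
These last two hunks are an incidental modernization: MLIR has deprecated the member-function casts (`value.dyn_cast<T>()`) on types, attributes, and affine expressions in favor of the free-function form (`dyn_cast<T>(value)`), so the pattern is updated to the preferred spelling with no behavior change.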

mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-reshape-elimination.mlir

Lines changed: 15 additions & 0 deletions
@@ -187,3 +187,18 @@ func.func @matmul_eliminate_reshape_lhs_2(%arg0: tensor<1x2x3x4x5x6xf16>, %arg1:
   %2 = tensorrt.reshape %1 : tensor<1x2x60x8xf16> to tensor<1x2x3x4x5x8xf16>
   return %2: tensor<1x2x3x4x5x8xf16>
 }
+
+// -----
+
+// CHECK: @elementwise_reshape(%[[arg0:.+]]: tensor<12x3x3xf32>, %[[arg1:.+]]: tensor<12xf32>)
+// CHECK: %[[v0:.+]] = tensorrt.expand_rank %[[arg1]] : tensor<12xf32> to tensor<12x1x1xf32>
+// CHECK: %[[v1:.+]] = tensorrt.element_wise <kDIV>(%[[arg0]], %[[v0]] : tensor<12x3x3xf32>, tensor<12x1x1xf32>) -> tensor<12x3x3xf32>
+// CHECK: %[[v2:.+]] = tensorrt.transpose {permutation = #map} %[[v1]] : tensor<12x3x3xf32> to tensor<12x3x3xf32>
+// CHECK: return %[[v2]]
+#map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+func.func @elementwise_reshape(%arg0: tensor<12x3x3xf32>, %arg1: tensor<12xf32>) -> tensor<12x3x3xf32> {
+  %0 = tensorrt.transpose {permutation = #map} %arg0 : tensor<12x3x3xf32> to tensor<12x3x3xf32>
+  %1 = tensorrt.expand_rank %arg1 : tensor<12xf32> to tensor<12x1x1xf32>
+  %2 = tensorrt.element_wise <kDIV>(%0, %1 : tensor<12x3x3xf32>, tensor<12x1x1xf32>) -> tensor<12x3x3xf32>
+  return %2 : tensor<12x3x3xf32>
+}
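
This regression test appears to exercise the new trailing-axes path: pushing the transpose past the `element_wise` introduces a transpose on the broadcast operand, and swapping that transpose with the `expand_rank` (a reshape from tensor<12xf32> to tensor<12x1x1xf32>) leaves the trailing unit output dims with no input axes to pair against, which is exactly the case the drain loop now handles.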
