
Commit 35bc5fd

Matthew Francis-Landau committed:
fix issue with PushReshapeUpThroughEinsum when there is a broadcast of one of the axes
1 parent d335aad commit 35bc5fd
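
The PushReshapeUpThroughEinsum pattern hoists a trailing tensorrt.reshape above a tensorrt.einsum by regrouping the einsum's output axes and reshaping each input instead. Previously it recorded only the post-reshape axes and extents for each fused axis group, so when one einsum input broadcasts an axis inside such a group (extent 1 where another input carries the full extent), the rewrite could fire and compute per-input reshapes against the wrong extents. The commit additionally records each group's pre-reshape shape (oldShape) and bails out of the rewrite whenever an input's actual shape disagrees with it; a distilled standalone sketch of this guard follows the C++ diff below.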

File tree: 2 files changed, 104 insertions(+), 10 deletions(-)

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeReshapeElimination.cpp

Lines changed: 92 additions & 10 deletions
@@ -1327,16 +1327,21 @@ class PushReshapeUpThroughEinsum
     auto reshapeInShape = op.getInput().getType().getShape();
     auto reshapeOutShape = op.getResult().getType().getShape();
 
+    struct ReshapeInfo {
+      std::string newAxes;
+      SmallVector<int64_t> newShape;
+      SmallVector<int64_t> oldShape;
+    };
+
     bool hasNonTrivalReshape = false;
-    std::unordered_map<std::string,
-                       std::pair<std::string, SmallVector<int64_t>>>
-        inputToReshapedMap;
+    std::unordered_map<std::string, ReshapeInfo> inputToReshapedMap;
     size_t inputNumElems = 1;
     size_t outputNumElems = 1;
     std::string inAxes = "";
     std::string outAxes = "";
     std::string prevInAxes = "";
     SmallVector<int64_t> outShape;
+    SmallVector<int64_t> inShape;
     for (size_t i = 0, j = 0; i < reshapeOutShape.size(); i++) {
       if (reshapeOutShape[i] == 0) {
         return failure(/* 0-shape not supported */);
@@ -1349,26 +1354,30 @@ class PushReshapeUpThroughEinsum
       outAxes += c;
       while (j < reshapeInShape.size() && inputNumElems < outputNumElems) {
         inputNumElems *= reshapeInShape[j];
+        inShape.push_back(reshapeInShape[j]);
         inAxes += equation.rhs[j++];
       }
       if (inputNumElems == outputNumElems) {
         if (inAxes.empty()) {
           if (!prevInAxes.empty() && reshapeOutShape[i] == 1 &&
               outAxes.size() == 1) {
             auto &p = inputToReshapedMap[prevInAxes];
-            p.first.push_back(c);
-            p.second.push_back(1);
-            if (prevInAxes.size() != p.first.size())
+            p.newAxes.push_back(c);
+            p.newShape.push_back(1);
+            if (prevInAxes.size() != p.newAxes.size())
               hasNonTrivalReshape = true;
             outAxes = "";
             outShape.clear();
+            inShape.clear();
           }
           continue;
         }
         if (inAxes.size() != outAxes.size())
           hasNonTrivalReshape = true;
-        inputToReshapedMap[inAxes] = {outAxes, outShape};
+        inputToReshapedMap[inAxes] = ReshapeInfo{
+            .newAxes = outAxes, .newShape = outShape, .oldShape = inShape};
         outShape.clear();
+        inShape.clear();
         prevInAxes = inAxes;
         inAxes = "";
         outAxes = "";
@@ -1405,12 +1414,60 @@ class PushReshapeUpThroughEinsum
     for (char c : equation.rhs) {
       assert(charToGroup.count(c));
       if (charToGroup[c][0] == c)
-        newEquation.rhs += inputToReshapedMap[charToGroup[c]].first;
+        newEquation.rhs += inputToReshapedMap[charToGroup[c]].newAxes;
     }
 
     // generate a `x` -> `reshape(transpose(x))` if necessary
     SmallVector<Value> newInputs;
     newEquation.lhsParts.clear();
+
+    LLVM_DEBUG({
+      std::stringstream out;
+      out << "==== Einsum Reshape/Transpose Pushdown Debug ====\n";
+      for (const auto &entry : charToGroup) {
+        out << "  charToGroup[" << entry.first << "] = " << entry.second
+            << "\n";
+      }
+      for (const auto &entry : inputToReshapedMap) {
+        out << "  inputToReshapedMap[" << entry.first
+            << "]: axes = " << entry.second.newAxes << ", shape = [";
+        for (size_t si = 0; si < entry.second.newShape.size(); ++si) {
+          out << entry.second.newShape[si];
+          if (si + 1 < entry.second.newShape.size())
+            out << ", ";
+        }
+        out << "], old shape = [";
+        for (size_t si = 0; si < entry.second.oldShape.size(); ++si) {
+          out << entry.second.oldShape[si];
+          if (si + 1 < entry.second.oldShape.size())
+            out << ", ";
+        }
+        out << "]";
+        out << "\n";
+      }
+      DBGS() << out.str();
+    });
+
+    // check that the input shape for all of the inputs match (that there are
+    // no broadcasts happening on some inputs)
+    for (auto &[inputAxes, reshapeInfo] : inputToReshapedMap) {
+      // this is a single axis, so broadcasting is allowed in this case, hence
+      // do not check
+      if (inputAxes.size() == 1 && reshapeInfo.newAxes.size() == 1)
+        continue;
+
+      for (size_t i = 0; i < einsum.getInputs().size(); i++) {
+        auto inputShape =
+            cast<RankedTensorType>(einsum.getInputs()[i].getType()).getShape();
+        for (size_t j = 0; j < inputAxes.size(); j++) {
+          size_t pos = equation.lhsParts[i].find(inputAxes[j]);
+          if (pos != std::string::npos &&
+              inputShape[pos] != reshapeInfo.oldShape[j])
+            return failure(/* input shape does not match output shape */);
+        }
+      }
+    }
+
     for (size_t i = 0; i < einsum.getInputs().size(); i++) {
       Value input = einsum.getInputs()[i];
       auto inputType = cast<RankedTensorType>(input.getType());
@@ -1434,11 +1491,34 @@ class PushReshapeUpThroughEinsum
           // group
           for (char c : group->second)
            newInputTranspose.push_back(equation.lhsParts[i].find(c));
-          newInputEquation += inputToReshapedMap[group->second].first;
-          for (int64_t v : inputToReshapedMap[group->second].second)
+          newInputEquation += inputToReshapedMap[group->second].newAxes;
+          for (int64_t v : inputToReshapedMap[group->second].newShape)
            newInputShape.push_back(v);
        }
      }
+
+      // Debug print for this input's result
+      LLVM_DEBUG({
+        std::stringstream out;
+        out << "Input #" << i << " orig eq: " << equation.lhsParts[i]
+            << " new eq: " << newInputEquation << "\n";
+        out << "  newInputTranspose: [";
+        for (size_t ti = 0; ti < newInputTranspose.size(); ++ti) {
+          out << newInputTranspose[ti];
+          if (ti + 1 < newInputTranspose.size())
+            out << ", ";
+        }
+        out << "]\n";
+        out << "  newInputShape: [";
+        for (size_t si = 0; si < newInputShape.size(); ++si) {
+          out << newInputShape[si];
+          if (si + 1 < newInputShape.size())
+            out << ", ";
+        }
+        out << "]\n";
+        DBGS() << out.str() << "\n";
+      });
+
       auto newTranspose = rewriter.createOrFold<tensorrt::TransposeOp>(
           op.getLoc(), input,
           AffineMap::getPermutationMap(newInputTranspose,
@@ -1449,6 +1529,8 @@ class PushReshapeUpThroughEinsum
       newInputs.push_back(newReshape);
       newEquation.lhsParts.push_back(newInputEquation);
     }
+    LLVM_DEBUG(
+        { DBGS() << "===============================================\n"; });
 
     std::string newEquationStr = newEquation.generateEquation();
 
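To make the new guard concrete, here is a minimal standalone C++ sketch of the broadcast check added at new line 1453, with plain STL containers standing in for the MLIR/LLVM types. The function name groupIsBroadcastFree and its parameters are hypothetical names for this illustration only; in the actual pattern the same loop lives inline and returns failure().

#include <cassert>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// groupAxes/groupShape correspond to one inputToReshapedMap entry: the
// einsum-output axes fused by the reshape and their pre-reshape (oldShape)
// extents. lhsParts[i] is input i's einsum subscripts, inputShapes[i] its shape.
bool groupIsBroadcastFree(
    const std::string &groupAxes, const std::vector<int64_t> &groupShape,
    size_t numNewAxes, const std::vector<std::string> &lhsParts,
    const std::vector<std::vector<int64_t>> &inputShapes) {
  // Mirrors the commit's exemption: a group that maps one old axis to one new
  // axis is never split by the reshape, so broadcasting there is harmless.
  if (groupAxes.size() == 1 && numNewAxes == 1)
    return true;
  assert(groupAxes.size() == groupShape.size());
  for (size_t i = 0; i < lhsParts.size(); ++i) {
    for (size_t j = 0; j < groupAxes.size(); ++j) {
      size_t pos = lhsParts[i].find(groupAxes[j]);
      // An input whose extent differs from the recorded pre-reshape extent is
      // broadcasting that axis; the hoisted per-input reshape would be wrong.
      if (pos != std::string::npos && inputShapes[i][pos] != groupShape[j])
        return false;
    }
  }
  return true;
}

int main() {
  // The commit's regression test: equation "acbd,abcd->abd" followed by a
  // reshape that fuses b (12) and d (64) into one 768-element axis. Input #1
  // has extent 1 for d (a broadcast), so the group "bd" fails the check.
  std::vector<std::string> lhsParts = {"acbd", "abcd"};
  std::vector<std::vector<int64_t>> shapes = {{2, 20, 12, 64}, {2, 12, 20, 1}};
  std::cout << groupIsBroadcastFree("bd", {12, 64}, /*numNewAxes=*/1, lhsParts,
                                    shapes)
            << "\n"; // prints 0: the pattern must not push the reshape up
}
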
mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-reshape-elimination.mlir
Lines changed: 12 additions & 0 deletions
@@ -336,3 +336,15 @@ func.func @einsum_multiply_two_axis(%arg0: tensor<10x11x12xf32>, %arg1: tensor<1
   %0 = tensorrt.einsum {equation = "acd,bcd->ab"} ins(%arg0, %arg1: tensor<10x11x12xf32>, tensor<13x11x12xf32>) -> tensor<10x13xf32>
   return %0 : tensor<10x13xf32>
 }
+
+// -----
+
+// CHECK: @can_not_push_reshape_through_einsum(%[[arg0:.+]]: tensor<2x20x12x64xf32>, %[[arg1:.+]]: tensor<2x12x20x1xf32>)
+// CHECK: %[[v0:.+]] = tensorrt.einsum {{{.*}}} ins(%[[arg0]], %[[arg1]] : {{.*}})
+// CHECK: %[[v1:.+]] = tensorrt.reshape %[[v0]] : tensor<2x12x64xf32> to tensor<2x1x768xf32>
+// CHECK: return %[[v1]]
+func.func @can_not_push_reshape_through_einsum(%arg0: tensor<2x20x12x64xf32>, %arg1: tensor<2x12x20x1xf32>) -> tensor<2x1x768xf32>{
+  %0 = tensorrt.einsum {equation = "acbd,abcd->abd"} ins(%arg0, %arg1 : tensor<2x20x12x64xf32>, tensor<2x12x20x1xf32>) -> tensor<2x12x64xf32>
+  %1 = tensorrt.reshape %0 : tensor<2x12x64xf32> to tensor<2x1x768xf32>
+  return %1 : tensor<2x1x768xf32>
+}
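
These CHECK lines assert that the pattern now declines to fire: the einsum is emitted unchanged and the reshape stays after it. The reshape from tensor<2x12x64xf32> to tensor<2x1x768xf32> fuses the b (12) and d (64) axes into a single 768 = 12 * 64 axis, but %arg1's subscripts abcd give d an extent of 1 (a broadcast against %arg0's 64), which is exactly the mismatch the new oldShape check rejects.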
