@@ -1327,16 +1327,21 @@ class PushReshapeUpThroughEinsum
13271327 auto reshapeInShape = op.getInput ().getType ().getShape ();
13281328 auto reshapeOutShape = op.getResult ().getType ().getShape ();
13291329
1330+ struct ReshapeInfo {
1331+ std::string newAxes;
1332+ SmallVector<int64_t > newShape;
1333+ SmallVector<int64_t > oldShape;
1334+ };
1335+
13301336 bool hasNonTrivalReshape = false ;
1331- std::unordered_map<std::string,
1332- std::pair<std::string, SmallVector<int64_t >>>
1333- inputToReshapedMap;
1337+ std::unordered_map<std::string, ReshapeInfo> inputToReshapedMap;
13341338 size_t inputNumElems = 1 ;
13351339 size_t outputNumElems = 1 ;
13361340 std::string inAxes = " " ;
13371341 std::string outAxes = " " ;
13381342 std::string prevInAxes = " " ;
13391343 SmallVector<int64_t > outShape;
1344+ SmallVector<int64_t > inShape;
13401345 for (size_t i = 0 , j = 0 ; i < reshapeOutShape.size (); i++) {
13411346 if (reshapeOutShape[i] == 0 ) {
13421347 return failure (/* 0-shape not supported */ );
@@ -1349,26 +1354,30 @@ class PushReshapeUpThroughEinsum
13491354 outAxes += c;
13501355 while (j < reshapeInShape.size () && inputNumElems < outputNumElems) {
13511356 inputNumElems *= reshapeInShape[j];
1357+ inShape.push_back (reshapeInShape[j]);
13521358 inAxes += equation.rhs [j++];
13531359 }
13541360 if (inputNumElems == outputNumElems) {
13551361 if (inAxes.empty ()) {
13561362 if (!prevInAxes.empty () && reshapeOutShape[i] == 1 &&
13571363 outAxes.size () == 1 ) {
13581364 auto &p = inputToReshapedMap[prevInAxes];
1359- p.first .push_back (c);
1360- p.second .push_back (1 );
1361- if (prevInAxes.size () != p.first .size ())
1365+ p.newAxes .push_back (c);
1366+ p.newShape .push_back (1 );
1367+ if (prevInAxes.size () != p.newAxes .size ())
13621368 hasNonTrivalReshape = true ;
13631369 outAxes = " " ;
13641370 outShape.clear ();
1371+ inShape.clear ();
13651372 }
13661373 continue ;
13671374 }
13681375 if (inAxes.size () != outAxes.size ())
13691376 hasNonTrivalReshape = true ;
1370- inputToReshapedMap[inAxes] = {outAxes, outShape};
1377+ inputToReshapedMap[inAxes] = ReshapeInfo{
1378+ .newAxes = outAxes, .newShape = outShape, .oldShape = inShape};
13711379 outShape.clear ();
1380+ inShape.clear ();
13721381 prevInAxes = inAxes;
13731382 inAxes = " " ;
13741383 outAxes = " " ;
@@ -1405,12 +1414,60 @@ class PushReshapeUpThroughEinsum
14051414 for (char c : equation.rhs ) {
14061415 assert (charToGroup.count (c));
14071416 if (charToGroup[c][0 ] == c)
1408- newEquation.rhs += inputToReshapedMap[charToGroup[c]].first ;
1417+ newEquation.rhs += inputToReshapedMap[charToGroup[c]].newAxes ;
14091418 }
14101419
14111420 // generate a `x` -> `reshape(transpose(x))` if necessary
14121421 SmallVector<Value> newInputs;
14131422 newEquation.lhsParts .clear ();
1423+
1424+ LLVM_DEBUG ({
1425+ std::stringstream out;
1426+ out << " ==== Einsum Reshape/Transpose Pushdown Debug ====\n " ;
1427+ for (const auto &entry : charToGroup) {
1428+ out << " charToGroup[" << entry.first << " ] = " << entry.second
1429+ << " \n " ;
1430+ }
1431+ for (const auto &entry : inputToReshapedMap) {
1432+ out << " inputToReshapedMap[" << entry.first
1433+ << " ]: axes = " << entry.second .newAxes << " , shape = [" ;
1434+ for (size_t si = 0 ; si < entry.second .newShape .size (); ++si) {
1435+ out << entry.second .newShape [si];
1436+ if (si + 1 < entry.second .newShape .size ())
1437+ out << " , " ;
1438+ }
1439+ out << " ], old shape = [" ;
1440+ for (size_t si = 0 ; si < entry.second .oldShape .size (); ++si) {
1441+ out << entry.second .oldShape [si];
1442+ if (si + 1 < entry.second .oldShape .size ())
1443+ out << " , " ;
1444+ }
1445+ out << " ]" ;
1446+ out << " \n " ;
1447+ }
1448+ DBGS () << out.str ();
1449+ });
1450+
1451+ // check that the input shape for all of the inputs match (that there are no
1452+ // broadcasts happening on some inputs)
1453+ for (auto &[inputAxes, reshapeInfo] : inputToReshapedMap) {
1454+ // this is a single axis, so broadcasting is allowed in this case, hence
1455+ // do not check
1456+ if (inputAxes.size () == 1 && reshapeInfo.newAxes .size () == 1 )
1457+ continue ;
1458+
1459+ for (size_t i = 0 ; i < einsum.getInputs ().size (); i++) {
1460+ auto inputShape =
1461+ cast<RankedTensorType>(einsum.getInputs ()[i].getType ()).getShape ();
1462+ for (size_t j = 0 ; j < inputAxes.size (); j++) {
1463+ size_t pos = equation.lhsParts [i].find (inputAxes[j]);
1464+ if (pos != std::string::npos &&
1465+ inputShape[pos] != reshapeInfo.oldShape [j])
1466+ return failure (/* input shape does not match output shape*/ );
1467+ }
1468+ }
1469+ }
1470+
14141471 for (size_t i = 0 ; i < einsum.getInputs ().size (); i++) {
14151472 Value input = einsum.getInputs ()[i];
14161473 auto inputType = cast<RankedTensorType>(input.getType ());
@@ -1434,11 +1491,34 @@ class PushReshapeUpThroughEinsum
14341491 // group
14351492 for (char c : group->second )
14361493 newInputTranspose.push_back (equation.lhsParts [i].find (c));
1437- newInputEquation += inputToReshapedMap[group->second ].first ;
1438- for (int64_t v : inputToReshapedMap[group->second ].second )
1494+ newInputEquation += inputToReshapedMap[group->second ].newAxes ;
1495+ for (int64_t v : inputToReshapedMap[group->second ].newShape )
14391496 newInputShape.push_back (v);
14401497 }
14411498 }
1499+
1500+ // Debug print for this input's result
1501+ LLVM_DEBUG ({
1502+ std::stringstream out;
1503+ out << " Input #" << i << " orig eq: " << equation.lhsParts [i]
1504+ << " new eq: " << newInputEquation << " \n " ;
1505+ out << " newInputTranspose: [" ;
1506+ for (size_t ti = 0 ; ti < newInputTranspose.size (); ++ti) {
1507+ out << newInputTranspose[ti];
1508+ if (ti + 1 < newInputTranspose.size ())
1509+ out << " , " ;
1510+ }
1511+ out << " ]\n " ;
1512+ out << " newInputShape: [" ;
1513+ for (size_t si = 0 ; si < newInputShape.size (); ++si) {
1514+ out << newInputShape[si];
1515+ if (si + 1 < newInputShape.size ())
1516+ out << " , " ;
1517+ }
1518+ out << " ]\n " ;
1519+ DBGS () << out.str () << " \n " ;
1520+ });
1521+
14421522 auto newTranspose = rewriter.createOrFold <tensorrt::TransposeOp>(
14431523 op.getLoc (), input,
14441524 AffineMap::getPermutationMap (newInputTranspose,
@@ -1449,6 +1529,8 @@ class PushReshapeUpThroughEinsum
14491529 newInputs.push_back (newReshape);
14501530 newEquation.lhsParts .push_back (newInputEquation);
14511531 }
1532+ LLVM_DEBUG (
1533+ { DBGS () << " ===============================================\n " ; });
14521534
14531535 std::string newEquationStr = newEquation.generateEquation ();
14541536
0 commit comments