@@ -1635,7 +1635,8 @@ class MoveReshapeBeforeTranspose
16351635 SmallVector<int64_t > groupReshapeOut;
16361636 size_t inputNumElems = 1 ;
16371637 size_t outputNumElems = 1 ;
1638- for (int i = 0 , j = 0 ; i < reshapeInputType.getRank (); i++) {
1638+ int j = 0 ;
1639+ for (int i = 0 ; i < reshapeInputType.getRank (); i++) {
16391640 inputNumElems *= reshapeInputType.getDimSize (i);
16401641 if (!transposeInAxes.empty () &&
16411642 transposeInAxes.back () + 1 != transposePerm[i]) {
@@ -1740,7 +1741,8 @@ class MoveTransposeBeforeReshape
17401741 SmallVector<int64_t > groupReshapeOut;
17411742 size_t inputNumElems = 1 ;
17421743 size_t outputNumElems = 1 ;
1743- for (int i = 0 , j = 0 ; i < reshapeInputType.getRank (); i++) {
1744+ int j = 0 ;
1745+ for (int i = 0 ; i < reshapeInputType.getRank (); i++) {
17441746 inputNumElems *= reshapeInputType.getDimSize (i);
17451747 inputAxes.push_back (i);
17461748 while (j < reshapeOutputType.getRank () &&
@@ -1764,6 +1766,21 @@ class MoveTransposeBeforeReshape
17641766 groupReshapeOut.clear ();
17651767 }
17661768 }
1769+ while (j < reshapeOutputType.getRank ()) {
1770+ outputNumElems *= reshapeOutputType.getDimSize (j);
1771+ groupReshapeOut.push_back (reshapeOutputType.getDimSize (j));
1772+ outputAxes.push_back (transposePerm[j++]);
1773+ }
1774+
1775+ assert (inputNumElems == outputNumElems);
1776+ assert (inputAxes.empty ());
1777+ if (!outputAxes.empty () || !groupReshapeOut.empty ()) {
1778+ reshapeGroups.push_back (ReshapeGroup{
1779+ .inputAxes = inputAxes,
1780+ .outputAxes = outputAxes,
1781+ .reshapeOut = groupReshapeOut,
1782+ });
1783+ }
17671784
17681785 SmallVector<int64_t > newTranspose;
17691786 SmallVector<int64_t > newReshape;
@@ -1776,6 +1793,38 @@ class MoveTransposeBeforeReshape
17761793 return a.outputAxes [0 ] < b.outputAxes [0 ];
17771794 });
17781795
1796+ // Debug print of reshapeGroups
1797+ LLVM_DEBUG ({
1798+ std::stringstream out;
1799+ out << " reshapeGroups:\n " ;
1800+ for (size_t idx = 0 ; idx < reshapeGroups.size (); ++idx) {
1801+ const auto &group = reshapeGroups[idx];
1802+ out << " Group " << idx << " :\n " ;
1803+ out << " inputAxes: [" ;
1804+ for (size_t i = 0 ; i < group.inputAxes .size (); ++i) {
1805+ out << group.inputAxes [i];
1806+ if (i + 1 < group.inputAxes .size ())
1807+ out << " , " ;
1808+ }
1809+ out << " ]\n " ;
1810+ out << " outputAxes: [" ;
1811+ for (size_t i = 0 ; i < group.outputAxes .size (); ++i) {
1812+ out << group.outputAxes [i];
1813+ if (i + 1 < group.outputAxes .size ())
1814+ DBGS () << " , " ;
1815+ }
1816+ out << " ]\n " ;
1817+ out << " reshapeOut: [" ;
1818+ for (size_t i = 0 ; i < group.reshapeOut .size (); ++i) {
1819+ out << group.reshapeOut [i];
1820+ if (i + 1 < group.reshapeOut .size ())
1821+ DBGS () << " , " ;
1822+ }
1823+ out << " ]\n " ;
1824+ }
1825+ DBGS () << out.str ();
1826+ });
1827+
17791828 for (auto &group : reshapeGroups) {
17801829 for (int64_t i : group.inputAxes )
17811830 newTranspose.push_back (i);
@@ -1789,6 +1838,8 @@ class MoveTransposeBeforeReshape
17891838 Value newReshapeOp = rewriter.createOrFold <tensorrt::ReshapeOp>(
17901839 op.getLoc (), reshapeInputType.clone (newReshape), newTransposeOp);
17911840
1841+ assert (op.getType () == newReshapeOp.getType ());
1842+
17921843 rewriter.replaceOp (op, newReshapeOp);
17931844 return success ();
17941845 }
@@ -2244,15 +2295,15 @@ class MatrixMultiplyTransposedArguments
22442295 return std::make_tuple (arg, operation);
22452296 bool swapsLastTwo = true ;
22462297 for (int64_t i = 0 ; i < rank - 2 ; ++i) {
2247- auto expr = permVec[i]. dyn_cast <AffineDimExpr>();
2298+ auto expr = dyn_cast<AffineDimExpr>(permVec[i] );
22482299 if (!expr || expr.getPosition () != i) {
22492300 swapsLastTwo = false ;
22502301 break ;
22512302 }
22522303 }
22532304 if (swapsLastTwo) {
2254- auto expr1 = permVec[rank - 2 ]. dyn_cast <AffineDimExpr>( );
2255- auto expr2 = permVec[rank - 1 ]. dyn_cast <AffineDimExpr>( );
2305+ auto expr1 = dyn_cast<AffineDimExpr>( permVec[rank - 2 ]);
2306+ auto expr2 = dyn_cast<AffineDimExpr>( permVec[rank - 1 ]);
22562307 if (!(expr1 && expr2 && expr1.getPosition () == rank - 1 &&
22572308 expr2.getPosition () == rank - 2 )) {
22582309 swapsLastTwo = false ;
0 commit comments