@@ -164,7 +164,8 @@ struct PushDownTransposeActivationRewriter
164164 auto activationOp = rewriter.create <ActivationOp>(
165165 op.getLoc (), producer.getInput (), op.getActivationType (),
166166 op.getAlphaAttr (), op.getBetaAttr ());
167- auto newTranspose = rewriter.create <TransposeOp>(producer.getLoc (), activationOp.getResult (), permutation);
167+ auto newTranspose = rewriter.create <TransposeOp>(
168+ producer.getLoc (), activationOp.getResult (), permutation);
168169 rewriter.replaceOp (op, newTranspose.getResult ());
169170 return success ();
170171 }
@@ -181,7 +182,8 @@ struct PushDownTransposeUnary : OpRewritePattern<UnaryOp> {
181182 AffineMap permutation = producer.getPermutation ();
182183 auto unary = rewriter.create <UnaryOp>(op.getLoc (), producer.getInput (),
183184 op.getUnaryOperationAttr ());
184- auto newTranspose = rewriter.create <TransposeOp>(producer.getLoc (), unary.getResult (), permutation);
185+ auto newTranspose = rewriter.create <TransposeOp>(
186+ producer.getLoc (), unary.getResult (), permutation);
185187 rewriter.replaceOp (op, newTranspose.getResult ());
186188 return success ();
187189 }
@@ -755,13 +757,23 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
755757 return a.first < b.first ;
756758 });
757759
760+ LLVM_DEBUG ({
761+ std::stringstream out;
762+ out << " outputAxes: [" ;
763+ for (auto x : outputAxes) {
764+ out << x.first << " (" << x.second << " ) " ;
765+ }
766+ out << " ]\n " ;
767+ DBGS () << out.str ();
768+ });
769+
758770 SmallVector<int64_t > newEinsumShape;
759- SmallVector<int64_t > outputPerm ;
771+ SmallVector<int64_t > forwardPerm ;
760772 std::string newEinsumRhs = " " ;
761773 for (auto &[c, i] : outputAxes) {
762774 newEinsumRhs += c;
763775 newEinsumShape.push_back (op.getType ().getDimSize (i));
764- outputPerm .push_back (i);
776+ forwardPerm .push_back (i);
765777 }
766778 if (newEinsumRhs == equation.rhs )
767779 return failure (); // no change
@@ -773,10 +785,13 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
773785 op.getLoc (), op.getType ().clone (newEinsumShape), op.getInputs (),
774786 newEinsumEquation);
775787
788+ auto forwardMap =
789+ AffineMap::getPermutationMap (forwardPerm, op.getLoc ().getContext ());
790+
776791 auto newTranspose = rewriter.create <tensorrt::TransposeOp>(
777- op.getLoc (), newEinsum.getResult (),
778- AffineMap::getPermutationMap (outputPerm, op.getLoc ().getContext ()));
792+ op.getLoc (), newEinsum.getResult (), inversePermutation (forwardMap));
779793
794+ assert (op.getType () == newTranspose.getType ());
780795 rewriter.replaceOp (op, newTranspose.getResult ());
781796
782797 return success ();
@@ -1662,29 +1677,29 @@ class MoveReshapeBeforeTranspose
16621677 }
16631678 }
16641679 assert (inputNumElems == outputNumElems);
1665- while (j < reshapeOutputType.getRank ()) {
1680+ while (j < reshapeOutputType.getRank ()) {
16661681 outputNumElems *= reshapeOutputType.getDimSize (j);
16671682 groupReshapeOut.push_back (reshapeOutputType.getDimSize (j));
16681683 transposeOutAxes.push_back (j++);
16691684 }
16701685 assert (inputNumElems == outputNumElems);
16711686 assert (transposeInAxes.empty ());
1672- if (!transposeOutAxes.empty () || !groupReshapeOut.empty ()) {
1687+ if (!transposeOutAxes.empty () || !groupReshapeOut.empty ()) {
16731688 reshapeGroups.push_back (ReshapeGroup{
16741689 .transposeInAxes = transposeInAxes,
16751690 .transposeOutAxes = transposeOutAxes,
16761691 .reshapeOut = groupReshapeOut,
16771692 .startOutputIdx = -1 , // set later
1678- });
1693+ });
16791694 }
16801695
16811696 SmallVector<int64_t > newTranspose;
16821697 SmallVector<int64_t > newReshape;
16831698
16841699 std::sort (reshapeGroups.begin (), reshapeGroups.end (), [](auto &a, auto &b) {
1685- if (a.transposeInAxes .empty ())
1700+ if (a.transposeInAxes .empty ())
16861701 return false ;
1687- if (b.transposeInAxes .empty ())
1702+ if (b.transposeInAxes .empty ())
16881703 return true ;
16891704 return a.transposeInAxes [0 ] < b.transposeInAxes [0 ];
16901705 });
@@ -1713,28 +1728,29 @@ class MoveReshapeBeforeTranspose
17131728 out << " transposeInAxes: [" ;
17141729 for (size_t i = 0 ; i < group.transposeInAxes .size (); ++i) {
17151730 out << group.transposeInAxes [i];
1716- if (i + 1 < group.transposeInAxes .size ()) out << " , " ;
1731+ if (i + 1 < group.transposeInAxes .size ())
1732+ out << " , " ;
17171733 }
17181734 out << " ]\n " ;
17191735 out << " transposeOutAxes: [" ;
17201736 for (size_t i = 0 ; i < group.transposeOutAxes .size (); ++i) {
17211737 out << group.transposeOutAxes [i];
1722- if (i + 1 < group.transposeOutAxes .size ()) out << " , " ;
1738+ if (i + 1 < group.transposeOutAxes .size ())
1739+ out << " , " ;
17231740 }
17241741 out << " ]\n " ;
17251742 out << " reshapeOut: [" ;
17261743 for (size_t i = 0 ; i < group.reshapeOut .size (); ++i) {
17271744 out << group.reshapeOut [i];
1728- if (i + 1 < group.reshapeOut .size ()) out << " , " ;
1745+ if (i + 1 < group.reshapeOut .size ())
1746+ out << " , " ;
17291747 }
17301748 out << " ]\n " ;
17311749 out << " startOutputIdx: " << group.startOutputIdx << " \n " ;
17321750 }
17331751 DBGS () << out.str ();
17341752 });
17351753
1736-
1737-
17381754 for (auto &group : reshapeGroups) {
17391755 for (size_t i = 0 ; i < group.reshapeOut .size (); i++)
17401756 newTranspose.push_back (group.startOutputIdx + i);
0 commit comments