2828#include " mlir/IR/Matchers.h"
2929#include " mlir/Pass/Pass.h"
3030#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
31- #include " llvm/ADT/StringExtras.h"
3231#include " llvm/Support/Debug.h"
3332#include < numeric>
3433
@@ -568,31 +567,37 @@ class RankChangeToReshape : public OpRewritePattern<OpType> {
568567
569568namespace {
570569struct EinsumEquation {
571- llvm::StringRef equation;
572- SmallVector<llvm::SmallString< 128 > > lhsParts;
573- llvm::SmallString< 128 > lhs;
574- llvm::SmallString< 128 > rhs;
570+ std::string equation;
571+ SmallVector<std::string > lhsParts;
572+ std::string lhs;
573+ std::string rhs;
575574
576575 LogicalResult parse (llvm::StringRef einsumEquation) {
576+ std::string e{einsumEquation};
577+ return parse (e);
578+ }
579+
580+ LogicalResult parse (const std::string &einsumEquation) {
577581 size_t pos = einsumEquation.find (" ->" );
578582 if (pos == std::string::npos)
579583 return failure ();
580584 equation = einsumEquation;
581- lhs = equation.substr (0 , pos);
582- rhs = equation.substr (pos + 2 );
583- SmallVector<llvm::StringRef> parts;
584- llvm::SplitString (lhs, parts, " ," );
585- for (llvm::StringRef part : parts)
586- lhsParts.push_back (part); // cast from StringRef to SmallString
585+ lhs = einsumEquation.substr (0 , pos);
586+ rhs = einsumEquation.substr (pos + 2 );
587+ std::istringstream lhsStream (lhs);
588+ std::string currentPart;
589+ while (std::getline (lhsStream, currentPart, ' ,' )) {
590+ lhsParts.push_back (currentPart);
591+ }
587592 return success ();
588593 }
589594
590- StringRef generateEquation () const {
591- llvm::SmallString< 128 > ret = lhsParts[0 ];
595+ std::string generateEquation () const {
596+ std::string ret = lhsParts[0 ];
592597 for (size_t i = 1 ; i < lhsParts.size (); i++) {
593- ret. append ({ " ," , lhsParts[i]}) ;
598+ ret += " ," + lhsParts[i];
594599 }
595- ret. append ({ " ->" , rhs}) ;
600+ ret += " ->" + rhs;
596601 return ret;
597602 }
598603};
@@ -648,7 +653,7 @@ class PushDownTransposeToEinsum : public OpRewritePattern<tensorrt::EinsumOp> {
648653 if (!hasTransposeInput)
649654 return failure ();
650655
651- StringRef newEinsumEquation = einsumEquation.generateEquation ();
656+ std::string newEinsumEquation = einsumEquation.generateEquation ();
652657
653658 rewriter.replaceOpWithNewOp <tensorrt::EinsumOp>(op, op.getType (), newInputs,
654659 newEinsumEquation);
@@ -691,7 +696,7 @@ class PushUpTransposeToEinsum : public OpRewritePattern<tensorrt::TransposeOp> {
691696 for (size_t i = 0 ; i < einsumRhs.size (); i++)
692697 einsumEquation.rhs += (char )einsumRhs[i];
693698
694- StringRef newEinsumEquation = einsumEquation.generateEquation ();
699+ std::string newEinsumEquation = einsumEquation.generateEquation ();
695700
696701 auto newEinsum = rewriter.create <tensorrt::EinsumOp>(
697702 op.getLoc (), op.getType (), einsum.getInputs (), newEinsumEquation);
@@ -730,7 +735,7 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
730735 std::sort (outputAxes.begin (), outputAxes.end (),
731736 [&](const std::pair<char , int64_t > &a,
732737 const std::pair<char , int64_t > &b) {
733- for (auto &eqLhs : equation.lhsParts ) {
738+ for (std::string &eqLhs : equation.lhsParts ) {
734739 if (eqLhs.find (a.first ) != std::string::npos) {
735740 if (eqLhs.find (b.first ) != std::string::npos) {
736741 return eqLhs.find (a.first ) < eqLhs.find (b.first );
@@ -746,7 +751,7 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
746751
747752 SmallVector<int64_t > newEinsumShape;
748753 SmallVector<int64_t > outputPerm;
749- SmallString< 128 > newEinsumRhs{ " " } ;
754+ std::string newEinsumRhs = " " ;
750755 for (auto &[c, i] : outputAxes) {
751756 newEinsumRhs += c;
752757 newEinsumShape.push_back (op.getType ().getDimSize (i));
@@ -755,7 +760,7 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
755760 if (newEinsumRhs == equation.rhs )
756761 return failure (); // no change
757762
758- StringRef newEinsumEquation = equation.generateEquation ();
763+ std::string newEinsumEquation = equation.generateEquation ();
759764
760765 auto newEinsum = rewriter.create <tensorrt::EinsumOp>(
761766 op.getLoc (), op.getType ().clone (newEinsumShape), op.getInputs (),
@@ -847,7 +852,7 @@ class EinsumPushUpTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
847852 if (!didChange)
848853 return failure ();
849854
850- StringRef newEquation = equation.generateEquation ();
855+ std::string newEquation = equation.generateEquation ();
851856 rewriter.replaceOpWithNewOp <tensorrt::EinsumOp>(op, op.getType (), newInputs,
852857 newEquation);
853858 return success ();
@@ -923,7 +928,7 @@ class EinsumEliminate1Axis : public OpRewritePattern<tensorrt::EinsumOp> {
923928 newOutputShape.push_back (outputType.getDimSize (i));
924929 }
925930 }
926- StringRef newEquation = newEinsumEquation.generateEquation ();
931+ std::string newEquation = newEinsumEquation.generateEquation ();
927932
928933 if (changeOutput) {
929934 auto newEinsum = rewriter.create <tensorrt::EinsumOp>(
@@ -999,7 +1004,7 @@ class EinsumMergeDown1Axis : public OpRewritePattern<tensorrt::EinsumOp> {
9991004 if (!madeChange)
10001005 return failure ();
10011006
1002- StringRef newEquation = equation.generateEquation ();
1007+ std::string newEquation = equation.generateEquation ();
10031008
10041009 rewriter.replaceOpWithNewOp <tensorrt::EinsumOp>(op, op.getType (), newInputs,
10051010 newEquation);
@@ -1059,7 +1064,7 @@ class EinsumMergeUp1Axis : public OpRewritePattern<tensorrt::ExpandRankOp> {
10591064 }
10601065 }
10611066
1062- SmallString< 128 > newEquation{ equation.lhs , " ->" , newRhs} ;
1067+ std::string newEquation = equation.lhs + " ->" + newRhs;
10631068 rewriter.replaceOpWithNewOp <tensorrt::EinsumOp>(
10641069 op, op.getType (), einsum.getInputs (), newEquation);
10651070 return success ();
@@ -1161,7 +1166,7 @@ class PushReshapeUpThroughEinsum
11611166 // check that all of the inputs are have the right groupping. If this
11621167 // doesn't happen then that means that the reshape can not get pushed
11631168 // through
1164- for (auto &eqLhs : equation.lhsParts ) {
1169+ for (std::string &eqLhs : equation.lhsParts ) {
11651170 for (char c : eqLhs) {
11661171 auto it = charToGroup.find (c);
11671172 if (it == charToGroup.end ())
@@ -1188,7 +1193,7 @@ class PushReshapeUpThroughEinsum
11881193 for (size_t i = 0 ; i < einsum.getInputs ().size (); i++) {
11891194 Value input = einsum.getInputs ()[i];
11901195 auto inputType = cast<RankedTensorType>(input.getType ());
1191- SmallString< 128 > newInputEquation{ " " } ;
1196+ std::string newInputEquation = " " ;
11921197 SmallVector<int64_t > newInputShape;
11931198 SmallVector<int64_t > newInputTranspose;
11941199 for (int j = 0 ; j < inputType.getRank (); j++) {
@@ -1224,7 +1229,7 @@ class PushReshapeUpThroughEinsum
12241229 newEquation.lhsParts .push_back (newInputEquation);
12251230 }
12261231
1227- StringRef newEquationStr = newEquation.generateEquation ();
1232+ std::string newEquationStr = newEquation.generateEquation ();
12281233
12291234 if (has1OutputShape) {
12301235 SmallVector<int64_t > newShape;
@@ -1414,13 +1419,13 @@ class PushReshapeDownThroughEinsum
14141419 }
14151420 }
14161421
1417- for (auto &part : equation.lhsParts ) {
1422+ for (std::string &part : equation.lhsParts ) {
14181423 for (char c : part) {
14191424 auto group = charToGroup.find (c);
14201425 if (group == charToGroup.end ())
14211426 continue ;
14221427 for (char c2 : group->second ) {
1423- if (part.find (c2) == StringRef ::npos)
1428+ if (part.find (c2) == std::string ::npos)
14241429 return failure (
14251430 /* Missing dimensions that need to be reshaped together */ );
14261431 }
@@ -1432,7 +1437,7 @@ class PushReshapeDownThroughEinsum
14321437 if (group == charToGroup.end ())
14331438 continue ;
14341439 for (char c2 : group->second ) {
1435- if (equation.rhs .find (c2) == StringRef ::npos)
1440+ if (equation.rhs .find (c2) == std::string ::npos)
14361441 return failure (
14371442 /* Missing dimensions that need to be reshaped together */ );
14381443 }
@@ -1481,7 +1486,7 @@ class PushReshapeDownThroughEinsum
14811486 RankedTensorType inputType = cast<RankedTensorType>(input.getType ());
14821487 SmallVector<int64_t > newInputShape;
14831488 SmallVector<int64_t > newInputTranspose;
1484- SmallString< 128 > newEinsumStr{ " " } ;
1489+ std::string newEinsumStr = " " ;
14851490 for (int j = 0 ; j < inputType.getRank (); j++) {
14861491 char c = equation.lhsParts [i][j];
14871492 auto it = charToGroup.find (c);
@@ -1499,7 +1504,7 @@ class PushReshapeDownThroughEinsum
14991504 }
15001505 for (char c2 : group->first ) {
15011506 size_t pos = equation.lhsParts [i].find (c2);
1502- assert (pos != StringRef ::npos);
1507+ assert (pos != std::string ::npos);
15031508 newInputTranspose.push_back (pos);
15041509 }
15051510 }
@@ -1550,13 +1555,13 @@ class PushReshapeDownThroughEinsum
15501555 }
15511556 for (char c2 : it->second ) {
15521557 size_t pos = equation.rhs .find (c2);
1553- assert (pos != StringRef ::npos);
1558+ assert (pos != std::string ::npos);
15541559 afterReshapeTranspose.push_back (pos);
15551560 }
15561561 }
15571562 }
15581563
1559- StringRef newEinsumEquation = newEquation.generateEquation ();
1564+ std::string newEinsumEquation = newEquation.generateEquation ();
15601565
15611566 auto newEinsum = rewriter.create <tensorrt::EinsumOp>(
15621567 op.getLoc (), outputType.clone (einsumOutputShape), newInputs,
0 commit comments