Skip to content

Commit 5486d22

Browse files
author
Matthew Francis-Landau
committed
Revert "change to using llvm::SplitString for splitting the string"
This reverts commit 353322b.
1 parent c960e12 commit 5486d22

File tree

1 file changed

+39
-34
lines changed

1 file changed

+39
-34
lines changed

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeReshapeElimination.cpp

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
#include "mlir/IR/Matchers.h"
2929
#include "mlir/Pass/Pass.h"
3030
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31-
#include "llvm/ADT/StringExtras.h"
3231
#include "llvm/Support/Debug.h"
3332
#include <numeric>
3433

@@ -568,31 +567,37 @@ class RankChangeToReshape : public OpRewritePattern<OpType> {
568567

569568
namespace {
570569
struct EinsumEquation {
571-
llvm::StringRef equation;
572-
SmallVector<llvm::SmallString<128>> lhsParts;
573-
llvm::SmallString<128> lhs;
574-
llvm::SmallString<128> rhs;
570+
std::string equation;
571+
SmallVector<std::string> lhsParts;
572+
std::string lhs;
573+
std::string rhs;
575574

576575
LogicalResult parse(llvm::StringRef einsumEquation) {
576+
std::string e{einsumEquation};
577+
return parse(e);
578+
}
579+
580+
LogicalResult parse(const std::string &einsumEquation) {
577581
size_t pos = einsumEquation.find("->");
578582
if (pos == std::string::npos)
579583
return failure();
580584
equation = einsumEquation;
581-
lhs = equation.substr(0, pos);
582-
rhs = equation.substr(pos + 2);
583-
SmallVector<llvm::StringRef> parts;
584-
llvm::SplitString(lhs, parts, ",");
585-
for (llvm::StringRef part : parts)
586-
lhsParts.push_back(part); // cast from StringRef to SmallString
585+
lhs = einsumEquation.substr(0, pos);
586+
rhs = einsumEquation.substr(pos + 2);
587+
std::istringstream lhsStream(lhs);
588+
std::string currentPart;
589+
while (std::getline(lhsStream, currentPart, ',')) {
590+
lhsParts.push_back(currentPart);
591+
}
587592
return success();
588593
}
589594

590-
StringRef generateEquation() const {
591-
llvm::SmallString<128> ret = lhsParts[0];
595+
std::string generateEquation() const {
596+
std::string ret = lhsParts[0];
592597
for (size_t i = 1; i < lhsParts.size(); i++) {
593-
ret.append({",", lhsParts[i]});
598+
ret += "," + lhsParts[i];
594599
}
595-
ret.append({"->", rhs});
600+
ret += "->" + rhs;
596601
return ret;
597602
}
598603
};
@@ -648,7 +653,7 @@ class PushDownTransposeToEinsum : public OpRewritePattern<tensorrt::EinsumOp> {
648653
if (!hasTransposeInput)
649654
return failure();
650655

651-
StringRef newEinsumEquation = einsumEquation.generateEquation();
656+
std::string newEinsumEquation = einsumEquation.generateEquation();
652657

653658
rewriter.replaceOpWithNewOp<tensorrt::EinsumOp>(op, op.getType(), newInputs,
654659
newEinsumEquation);
@@ -691,7 +696,7 @@ class PushUpTransposeToEinsum : public OpRewritePattern<tensorrt::TransposeOp> {
691696
for (size_t i = 0; i < einsumRhs.size(); i++)
692697
einsumEquation.rhs += (char)einsumRhs[i];
693698

694-
StringRef newEinsumEquation = einsumEquation.generateEquation();
699+
std::string newEinsumEquation = einsumEquation.generateEquation();
695700

696701
auto newEinsum = rewriter.create<tensorrt::EinsumOp>(
697702
op.getLoc(), op.getType(), einsum.getInputs(), newEinsumEquation);
@@ -730,7 +735,7 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
730735
std::sort(outputAxes.begin(), outputAxes.end(),
731736
[&](const std::pair<char, int64_t> &a,
732737
const std::pair<char, int64_t> &b) {
733-
for (auto &eqLhs : equation.lhsParts) {
738+
for (std::string &eqLhs : equation.lhsParts) {
734739
if (eqLhs.find(a.first) != std::string::npos) {
735740
if (eqLhs.find(b.first) != std::string::npos) {
736741
return eqLhs.find(a.first) < eqLhs.find(b.first);
@@ -746,7 +751,7 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
746751

747752
SmallVector<int64_t> newEinsumShape;
748753
SmallVector<int64_t> outputPerm;
749-
SmallString<128> newEinsumRhs{""};
754+
std::string newEinsumRhs = "";
750755
for (auto &[c, i] : outputAxes) {
751756
newEinsumRhs += c;
752757
newEinsumShape.push_back(op.getType().getDimSize(i));
@@ -755,7 +760,7 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
755760
if (newEinsumRhs == equation.rhs)
756761
return failure(); // no change
757762

758-
StringRef newEinsumEquation = equation.generateEquation();
763+
std::string newEinsumEquation = equation.generateEquation();
759764

760765
auto newEinsum = rewriter.create<tensorrt::EinsumOp>(
761766
op.getLoc(), op.getType().clone(newEinsumShape), op.getInputs(),
@@ -847,7 +852,7 @@ class EinsumPushUpTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
847852
if (!didChange)
848853
return failure();
849854

850-
StringRef newEquation = equation.generateEquation();
855+
std::string newEquation = equation.generateEquation();
851856
rewriter.replaceOpWithNewOp<tensorrt::EinsumOp>(op, op.getType(), newInputs,
852857
newEquation);
853858
return success();
@@ -923,7 +928,7 @@ class EinsumEliminate1Axis : public OpRewritePattern<tensorrt::EinsumOp> {
923928
newOutputShape.push_back(outputType.getDimSize(i));
924929
}
925930
}
926-
StringRef newEquation = newEinsumEquation.generateEquation();
931+
std::string newEquation = newEinsumEquation.generateEquation();
927932

928933
if (changeOutput) {
929934
auto newEinsum = rewriter.create<tensorrt::EinsumOp>(
@@ -999,7 +1004,7 @@ class EinsumMergeDown1Axis : public OpRewritePattern<tensorrt::EinsumOp> {
9991004
if (!madeChange)
10001005
return failure();
10011006

1002-
StringRef newEquation = equation.generateEquation();
1007+
std::string newEquation = equation.generateEquation();
10031008

10041009
rewriter.replaceOpWithNewOp<tensorrt::EinsumOp>(op, op.getType(), newInputs,
10051010
newEquation);
@@ -1059,7 +1064,7 @@ class EinsumMergeUp1Axis : public OpRewritePattern<tensorrt::ExpandRankOp> {
10591064
}
10601065
}
10611066

1062-
SmallString<128> newEquation{equation.lhs, "->", newRhs};
1067+
std::string newEquation = equation.lhs + "->" + newRhs;
10631068
rewriter.replaceOpWithNewOp<tensorrt::EinsumOp>(
10641069
op, op.getType(), einsum.getInputs(), newEquation);
10651070
return success();
@@ -1161,7 +1166,7 @@ class PushReshapeUpThroughEinsum
11611166
// check that all of the inputs are have the right groupping. If this
11621167
// doesn't happen then that means that the reshape can not get pushed
11631168
// through
1164-
for (auto &eqLhs : equation.lhsParts) {
1169+
for (std::string &eqLhs : equation.lhsParts) {
11651170
for (char c : eqLhs) {
11661171
auto it = charToGroup.find(c);
11671172
if (it == charToGroup.end())
@@ -1188,7 +1193,7 @@ class PushReshapeUpThroughEinsum
11881193
for (size_t i = 0; i < einsum.getInputs().size(); i++) {
11891194
Value input = einsum.getInputs()[i];
11901195
auto inputType = cast<RankedTensorType>(input.getType());
1191-
SmallString<128> newInputEquation{""};
1196+
std::string newInputEquation = "";
11921197
SmallVector<int64_t> newInputShape;
11931198
SmallVector<int64_t> newInputTranspose;
11941199
for (int j = 0; j < inputType.getRank(); j++) {
@@ -1224,7 +1229,7 @@ class PushReshapeUpThroughEinsum
12241229
newEquation.lhsParts.push_back(newInputEquation);
12251230
}
12261231

1227-
StringRef newEquationStr = newEquation.generateEquation();
1232+
std::string newEquationStr = newEquation.generateEquation();
12281233

12291234
if (has1OutputShape) {
12301235
SmallVector<int64_t> newShape;
@@ -1414,13 +1419,13 @@ class PushReshapeDownThroughEinsum
14141419
}
14151420
}
14161421

1417-
for (auto &part : equation.lhsParts) {
1422+
for (std::string &part : equation.lhsParts) {
14181423
for (char c : part) {
14191424
auto group = charToGroup.find(c);
14201425
if (group == charToGroup.end())
14211426
continue;
14221427
for (char c2 : group->second) {
1423-
if (part.find(c2) == StringRef::npos)
1428+
if (part.find(c2) == std::string::npos)
14241429
return failure(
14251430
/* Missing dimensions that need to be reshaped together */);
14261431
}
@@ -1432,7 +1437,7 @@ class PushReshapeDownThroughEinsum
14321437
if (group == charToGroup.end())
14331438
continue;
14341439
for (char c2 : group->second) {
1435-
if (equation.rhs.find(c2) == StringRef::npos)
1440+
if (equation.rhs.find(c2) == std::string::npos)
14361441
return failure(
14371442
/* Missing dimensions that need to be reshaped together */);
14381443
}
@@ -1481,7 +1486,7 @@ class PushReshapeDownThroughEinsum
14811486
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
14821487
SmallVector<int64_t> newInputShape;
14831488
SmallVector<int64_t> newInputTranspose;
1484-
SmallString<128> newEinsumStr{""};
1489+
std::string newEinsumStr = "";
14851490
for (int j = 0; j < inputType.getRank(); j++) {
14861491
char c = equation.lhsParts[i][j];
14871492
auto it = charToGroup.find(c);
@@ -1499,7 +1504,7 @@ class PushReshapeDownThroughEinsum
14991504
}
15001505
for (char c2 : group->first) {
15011506
size_t pos = equation.lhsParts[i].find(c2);
1502-
assert(pos != StringRef::npos);
1507+
assert(pos != std::string::npos);
15031508
newInputTranspose.push_back(pos);
15041509
}
15051510
}
@@ -1550,13 +1555,13 @@ class PushReshapeDownThroughEinsum
15501555
}
15511556
for (char c2 : it->second) {
15521557
size_t pos = equation.rhs.find(c2);
1553-
assert(pos != StringRef::npos);
1558+
assert(pos != std::string::npos);
15541559
afterReshapeTranspose.push_back(pos);
15551560
}
15561561
}
15571562
}
15581563

1559-
StringRef newEinsumEquation = newEquation.generateEquation();
1564+
std::string newEinsumEquation = newEquation.generateEquation();
15601565

15611566
auto newEinsum = rewriter.create<tensorrt::EinsumOp>(
15621567
op.getLoc(), outputType.clone(einsumOutputShape), newInputs,

0 commit comments

Comments
 (0)