Skip to content

Commit 56afad3

Browse files
author
Matthew Francis-Landau
committed
add utils function of getSplatConstantElementAttribute
Signed-off-by: Matthew Francis-Landau <[email protected]>
1 parent 83014d4 commit 56afad3

File tree

2 files changed

+37
-0
lines changed
  • mlir-tensorrt/tensorrt

2 files changed

+37
-0
lines changed

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/Utils/Utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ TypedValue<RankedTensorType>
6868
scatterShapeTensor(RewriterBase &b, Location loc, ArrayRef<int64_t> baseShape,
6969
int32_t scatterDim, TypedValue<RankedTensorType> update);
7070

71+
/// Get a splatted constant's attribute by going up a chain of reshape and cast
72+
/// operations to find the original constant. The constant can be a different
73+
/// data type if there is a cast operation in the chain.
74+
FailureOr<Attribute> getSplatConstantElementAttribute(Value x);
75+
7176
} // namespace tensorrt
7277
} // namespace mlir
7378

mlir-tensorrt/tensorrt/lib/TensorRT/Utils/Utils.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir-tensorrt-dialect/TensorRT/Utils/Utils.h"
2323

2424
#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
25+
#include "mlir/IR/Matchers.h"
2526
#include "mlir/Interfaces/FunctionInterfaces.h"
2627

2728
using namespace mlir;
@@ -158,3 +159,34 @@ tensorrt::scatterShapeTensor(RewriterBase &b, Location loc,
158159

159160
return b.create<tensorrt::ConcatenationOp>(loc, parts, 0);
160161
}
162+
163+
FailureOr<Attribute> tensorrt::getSplatConstantElementAttribute(Value x) {
164+
while (true) {
165+
if (auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
166+
x = expandRank.getInput();
167+
else if (auto collapseRank = x.getDefiningOp<tensorrt::CollapseRankOp>())
168+
x = collapseRank.getInput();
169+
else if (auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
170+
x = reshape.getInput();
171+
else if (auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
172+
x = broadcast.getInput();
173+
else if (auto cast = x.getDefiningOp<tensorrt::CastOp>())
174+
x = cast.getInput();
175+
else if (auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
176+
x = identity.getInput();
177+
else if (auto slice = x.getDefiningOp<tensorrt::SliceOp>())
178+
x = slice.getInput();
179+
else if (auto constant = x.getDefiningOp<tensorrt::ConstantOp>()) {
180+
SplatElementsAttr els{};
181+
if (!matchPattern(x, m_Constant(&els)))
182+
return failure();
183+
Attribute value = els.getSplatValue<Attribute>();
184+
if (!isa<FloatAttr, IntegerAttr>(value))
185+
return failure();
186+
return value;
187+
} else {
188+
return failure();
189+
}
190+
}
191+
return failure();
192+
}

0 commit comments

Comments
 (0)