2222#include " mlir-tensorrt-dialect/TensorRT/Utils/Utils.h"
2323
2424#include " mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
25+ #include " mlir/IR/Matchers.h"
2526#include " mlir/Interfaces/FunctionInterfaces.h"
2627
2728using namespace mlir ;
@@ -158,3 +159,34 @@ tensorrt::scatterShapeTensor(RewriterBase &b, Location loc,
158159
159160 return b.create <tensorrt::ConcatenationOp>(loc, parts, 0 );
160161}
162+
163+ FailureOr<Attribute> tensorrt::getSplatConstantElementAttribute (Value x) {
164+ while (true ) {
165+ if (auto expandRank = x.getDefiningOp <tensorrt::ExpandRankOp>())
166+ x = expandRank.getInput ();
167+ else if (auto collapseRank = x.getDefiningOp <tensorrt::CollapseRankOp>())
168+ x = collapseRank.getInput ();
169+ else if (auto reshape = x.getDefiningOp <tensorrt::ReshapeOp>())
170+ x = reshape.getInput ();
171+ else if (auto broadcast = x.getDefiningOp <tensorrt::BroadcastOp>())
172+ x = broadcast.getInput ();
173+ else if (auto cast = x.getDefiningOp <tensorrt::CastOp>())
174+ x = cast.getInput ();
175+ else if (auto identity = x.getDefiningOp <tensorrt::IdentityOp>())
176+ x = identity.getInput ();
177+ else if (auto slice = x.getDefiningOp <tensorrt::SliceOp>())
178+ x = slice.getInput ();
179+ else if (auto constant = x.getDefiningOp <tensorrt::ConstantOp>()) {
180+ SplatElementsAttr els{};
181+ if (!matchPattern (x, m_Constant (&els)))
182+ return failure ();
183+ Attribute value = els.getSplatValue <Attribute>();
184+ if (!isa<FloatAttr, IntegerAttr>(value))
185+ return failure ();
186+ return value;
187+ } else {
188+ return failure ();
189+ }
190+ }
191+ return failure ();
192+ }
0 commit comments