Skip to content

Commit 0685a68

Browse files
author
Matthew Francis-Landau
committed
make raise activation use same getSplatConstantElementAttribute utils function
Signed-off-by: Matthew Francis-Landau <[email protected]>
1 parent 56afad3 commit 0685a68

File tree

2 files changed

+9
-79
lines changed

2 files changed

+9
-79
lines changed

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseActivations.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
//===----------------------------------------------------------------------===//
2424
#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
2525
#include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h"
26+
#include "mlir-tensorrt-dialect/TensorRT/Utils/Utils.h"
2627
#include "mlir/Dialect/PDL/IR/PDL.h"
2728
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
2829
#include "mlir/IR/Matchers.h"

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/RaiseActivations.pdll

Lines changed: 8 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -122,93 +122,22 @@ Constraint Erf(x: Value) -> Op {
122122
}
123123

124124
Rewrite GetSplatElementAttr(x: Value) -> Attr [{
125-
while(true) {
126-
if(auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
127-
x = expandRank.getInput();
128-
else if(auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
129-
x = reshape.getInput();
130-
else if(auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
131-
x = broadcast.getInput();
132-
else if(auto cast = x.getDefiningOp<tensorrt::CastOp>())
133-
x = cast.getInput();
134-
else if(auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
135-
x = identity.getInput();
136-
else if(auto slice = x.getDefiningOp<tensorrt::SliceOp>())
137-
x = slice.getInput();
138-
else if(auto constant = x.getDefiningOp<tensorrt::ConstantOp>()) {
139-
DenseElementsAttr els{};
140-
if(!matchPattern(x, m_Constant(&els)))
141-
return {};
142-
if(!els.isSplat())
143-
return {};
144-
Attribute value = els.getSplatValue<Attribute>();
145-
return value;
146-
} else
147-
return {};
148-
}
149-
return {};
125+
return *getSplatConstantElementAttribute(x);
150126
}];
151127

152128
Constraint HasSplatElements(x: Value) [{
153-
while(true) {
154-
if(auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
155-
x = expandRank.getInput();
156-
else if(auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
157-
x = reshape.getInput();
158-
else if(auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
159-
x = broadcast.getInput();
160-
else if(auto cast = x.getDefiningOp<tensorrt::CastOp>())
161-
x = cast.getInput();
162-
else if(auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
163-
x = identity.getInput();
164-
else if(auto slice = x.getDefiningOp<tensorrt::SliceOp>())
165-
x = slice.getInput();
166-
else if(auto constant = x.getDefiningOp<tensorrt::ConstantOp>()) {
167-
DenseElementsAttr els{};
168-
if(!matchPattern(x, m_Constant(&els)))
169-
return failure();
170-
if(!els.isSplat())
171-
return failure();
172-
Attribute value = els.getSplatValue<Attribute>();
173-
return success(isa<FloatAttr>(value));
174-
} else
175-
return failure();
176-
}
177-
return failure();
129+
return LogicalResult(getSplatConstantElementAttribute(x));
178130
}];
179131

180132
/// Is true if `x` is a constant op that has a splat constant
181133
/// where splat element is equal to `attr`.
182134
Constraint SplatElements(x: Value, attr: Attr) [{
183-
while(true) {
184-
if(auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
185-
x = expandRank.getInput();
186-
else if(auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
187-
x = reshape.getInput();
188-
else if(auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
189-
x = broadcast.getInput();
190-
else if(auto cast = x.getDefiningOp<tensorrt::CastOp>())
191-
x = cast.getInput();
192-
else if(auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
193-
x = identity.getInput();
194-
else if(auto slice = x.getDefiningOp<tensorrt::SliceOp>())
195-
x = slice.getInput();
196-
else if(auto constant = x.getDefiningOp<tensorrt::ConstantOp>()) {
197-
DenseElementsAttr els{};
198-
if(!matchPattern(x, m_Constant(&els)))
199-
return failure();
200-
if(!els.isSplat())
201-
return failure();
202-
Attribute value = els.getSplatValue<Attribute>();
203-
if(!value) return failure();
204-
if(value == attr) return success();
205-
FloatAttr fvalue = dyn_cast<FloatAttr>(value);
206-
FloatAttr fattr = dyn_cast<FloatAttr>(attr);
207-
return success(fvalue && fattr && std::abs(fvalue.getValueAsDouble() - fattr.getValueAsDouble()) < .001); // handle different floating point type
208-
} else
209-
return failure();
210-
}
211-
return failure();
135+
FailureOr<Attribute> value = getSplatConstantElementAttribute(x);
136+
if(LogicalResult(value).failed()) return failure();
137+
if(*value == attr) return success();
138+
FloatAttr fvalue = dyn_cast<FloatAttr>(*value);
139+
FloatAttr fattr = dyn_cast<FloatAttr>(attr);
140+
return success(fvalue && fattr && std::abs(fvalue.getValueAsDouble() - fattr.getValueAsDouble()) < .001); // handle different floating point type
212141
}];
213142

214143
/// We need a native C++ function since we can't create the right

0 commit comments

Comments
 (0)