@@ -122,93 +122,22 @@ Constraint Erf(x: Value) -> Op {
122122}
123123
124124Rewrite GetSplatElementAttr(x: Value) -> Attr [{
125- while(true) {
126- if(auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
127- x = expandRank.getInput();
128- else if(auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
129- x = reshape.getInput();
130- else if(auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
131- x = broadcast.getInput();
132- else if(auto cast = x.getDefiningOp<tensorrt::CastOp>())
133- x = cast.getInput();
134- else if(auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
135- x = identity.getInput();
136- else if(auto slice = x.getDefiningOp<tensorrt::SliceOp>())
137- x = slice.getInput();
138- else if(auto constant = x.getDefiningOp<tensorrt::ConstantOp>()) {
139- DenseElementsAttr els{};
140- if(!matchPattern(x, m_Constant(&els)))
141- return {};
142- if(!els.isSplat())
143- return {};
144- Attribute value = els.getSplatValue<Attribute>();
145- return value;
146- } else
147- return {};
148- }
149- return {};
125+ return *getSplatConstantElementAttribute(x);
150126}];
151127
152128Constraint HasSplatElements(x: Value) [{
153- while(true) {
154- if(auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
155- x = expandRank.getInput();
156- else if(auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
157- x = reshape.getInput();
158- else if(auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
159- x = broadcast.getInput();
160- else if(auto cast = x.getDefiningOp<tensorrt::CastOp>())
161- x = cast.getInput();
162- else if(auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
163- x = identity.getInput();
164- else if(auto slice = x.getDefiningOp<tensorrt::SliceOp>())
165- x = slice.getInput();
166- else if(auto constant = x.getDefiningOp<tensorrt::ConstantOp>()) {
167- DenseElementsAttr els{};
168- if(!matchPattern(x, m_Constant(&els)))
169- return failure();
170- if(!els.isSplat())
171- return failure();
172- Attribute value = els.getSplatValue<Attribute>();
173- return success(isa<FloatAttr>(value));
174- } else
175- return failure();
176- }
177- return failure();
129+ return LogicalResult(getSplatConstantElementAttribute(x));
178130}];
179131
180132/// Is true if `x` is a constant op that has a splat constant
181133/// where splat element is equal to `attr`.
182134Constraint SplatElements(x: Value, attr: Attr) [{
183- while(true) {
184- if(auto expandRank = x.getDefiningOp<tensorrt::ExpandRankOp>())
185- x = expandRank.getInput();
186- else if(auto reshape = x.getDefiningOp<tensorrt::ReshapeOp>())
187- x = reshape.getInput();
188- else if(auto broadcast = x.getDefiningOp<tensorrt::BroadcastOp>())
189- x = broadcast.getInput();
190- else if(auto cast = x.getDefiningOp<tensorrt::CastOp>())
191- x = cast.getInput();
192- else if(auto identity = x.getDefiningOp<tensorrt::IdentityOp>())
193- x = identity.getInput();
194- else if(auto slice = x.getDefiningOp<tensorrt::SliceOp>())
195- x = slice.getInput();
196- else if(auto constant = x.getDefiningOp<tensorrt::ConstantOp>()) {
197- DenseElementsAttr els{};
198- if(!matchPattern(x, m_Constant(&els)))
199- return failure();
200- if(!els.isSplat())
201- return failure();
202- Attribute value = els.getSplatValue<Attribute>();
203- if(!value) return failure();
204- if(value == attr) return success();
205- FloatAttr fvalue = dyn_cast<FloatAttr>(value);
206- FloatAttr fattr = dyn_cast<FloatAttr>(attr);
207- return success(fvalue && fattr && std::abs(fvalue.getValueAsDouble() - fattr.getValueAsDouble()) < .001); // handle different floating point type
208- } else
209- return failure();
210- }
211- return failure();
135+ FailureOr<Attribute> value = getSplatConstantElementAttribute(x);
136+ if(LogicalResult(value).failed()) return failure();
137+ if(*value == attr) return success();
138+ FloatAttr fvalue = dyn_cast<FloatAttr>(*value);
139+ FloatAttr fattr = dyn_cast<FloatAttr>(attr);
140+ return success(fvalue && fattr && std::abs(fvalue.getValueAsDouble() - fattr.getValueAsDouble()) < .001); // handle different floating point type
212141}];
213142
214143/// We need a native C++ function since we can't create the right
0 commit comments