13
13
#include " PassDetail.h"
14
14
#include " mlir/Dialect/Func/IR/FuncOps.h"
15
15
#include " mlir/IR/PatternMatch.h"
16
+ #include " mlir/IR/ValueRange.h"
16
17
#include " mlir/Support/LogicalResult.h"
17
18
#include " mlir/Transforms/DialectConversion.h"
18
19
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -868,14 +869,6 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
868
869
auto *condBlock = rewriter.getInsertionBlock ();
869
870
auto opPosition = rewriter.getInsertionPoint ();
870
871
auto *remainingOpsBlock = rewriter.splitBlock (condBlock, opPosition);
871
- llvm::SmallVector<mlir::Location, 2 > locs;
872
- // Ternary result is optional, make sure to populate the location only
873
- // when relevant.
874
- if (op->getResultTypes ().size ())
875
- locs.push_back (loc);
876
- auto *continueBlock =
877
- rewriter.createBlock (remainingOpsBlock, op->getResultTypes (), locs);
878
- rewriter.create <cir::BrOp>(loc, remainingOpsBlock);
879
872
880
873
auto &trueRegion = op.getTrueRegion ();
881
874
auto *trueBlock = &trueRegion.front ();
@@ -884,24 +877,29 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
884
877
auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator);
885
878
886
879
rewriter.replaceOpWithNewOp <cir::BrOp>(trueYieldOp, trueYieldOp.getArgs (),
887
- continueBlock );
888
- rewriter.inlineRegionBefore (trueRegion, continueBlock );
880
+ remainingOpsBlock );
881
+ rewriter.inlineRegionBefore (trueRegion, remainingOpsBlock );
889
882
890
- auto *falseBlock = continueBlock;
891
883
auto &falseRegion = op.getFalseRegion ();
884
+ auto *falseBlock = &falseRegion.front ();
892
885
893
- falseBlock = &falseRegion.front ();
894
886
mlir::Operation *falseTerminator = falseRegion.back ().getTerminator ();
895
887
rewriter.setInsertionPointToEnd (&falseRegion.back ());
896
888
auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator);
897
889
rewriter.replaceOpWithNewOp <cir::BrOp>(falseYieldOp, falseYieldOp.getArgs (),
898
- continueBlock );
899
- rewriter.inlineRegionBefore (falseRegion, continueBlock );
890
+ remainingOpsBlock );
891
+ rewriter.inlineRegionBefore (falseRegion, remainingOpsBlock );
900
892
901
893
rewriter.setInsertionPointToEnd (condBlock);
902
894
rewriter.create <cir::BrCondOp>(loc, op.getCond (), trueBlock, falseBlock);
903
895
904
- rewriter.replaceOp (op, continueBlock->getArguments ());
896
+ if (auto rt = op.getResultTypes (); rt.size ()) {
897
+ auto args = remainingOpsBlock->addArguments (rt, op.getLoc ());
898
+ SmallVector<mlir::Value, 2 > values;
899
+ llvm::copy (args, std::back_inserter (values));
900
+ rewriter.replaceOpUsesWithinBlock (op, values, remainingOpsBlock);
901
+ }
902
+ rewriter.eraseOp (op);
905
903
906
904
// Ok, we're done!
907
905
return mlir::success ();
0 commit comments