Skip to content

Commit 6e5fa09

Browse files
authored
[CIR] Skip generation of a continue block when flattening TernaryOp (#1651)
We used to insert a continue Block at the end of a flattened ternary op that only contained a branch to the remaing operation of the remaining Block. This patch removes that continue block and changes the true/false blocks to directly jump to the remaining ops. With this patch the CIR now generates exactly the same LLVM IR as the original codegen.
1 parent 6e116b5 commit 6e5fa09

File tree

3 files changed

+13
-21
lines changed

3 files changed

+13
-21
lines changed

clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "PassDetail.h"
1414
#include "mlir/Dialect/Func/IR/FuncOps.h"
1515
#include "mlir/IR/PatternMatch.h"
16+
#include "mlir/IR/ValueRange.h"
1617
#include "mlir/Support/LogicalResult.h"
1718
#include "mlir/Transforms/DialectConversion.h"
1819
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -868,14 +869,6 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
868869
auto *condBlock = rewriter.getInsertionBlock();
869870
auto opPosition = rewriter.getInsertionPoint();
870871
auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
871-
llvm::SmallVector<mlir::Location, 2> locs;
872-
// Ternary result is optional, make sure to populate the location only
873-
// when relevant.
874-
if (op->getResultTypes().size())
875-
locs.push_back(loc);
876-
auto *continueBlock =
877-
rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
878-
rewriter.create<cir::BrOp>(loc, remainingOpsBlock);
879872

880873
auto &trueRegion = op.getTrueRegion();
881874
auto *trueBlock = &trueRegion.front();
@@ -884,24 +877,29 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
884877
auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator);
885878

886879
rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
887-
continueBlock);
888-
rewriter.inlineRegionBefore(trueRegion, continueBlock);
880+
remainingOpsBlock);
881+
rewriter.inlineRegionBefore(trueRegion, remainingOpsBlock);
889882

890-
auto *falseBlock = continueBlock;
891883
auto &falseRegion = op.getFalseRegion();
884+
auto *falseBlock = &falseRegion.front();
892885

893-
falseBlock = &falseRegion.front();
894886
mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
895887
rewriter.setInsertionPointToEnd(&falseRegion.back());
896888
auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator);
897889
rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(),
898-
continueBlock);
899-
rewriter.inlineRegionBefore(falseRegion, continueBlock);
890+
remainingOpsBlock);
891+
rewriter.inlineRegionBefore(falseRegion, remainingOpsBlock);
900892

901893
rewriter.setInsertionPointToEnd(condBlock);
902894
rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock);
903895

904-
rewriter.replaceOp(op, continueBlock->getArguments());
896+
if (auto rt = op.getResultTypes(); rt.size()) {
897+
auto args = remainingOpsBlock->addArguments(rt, op.getLoc());
898+
SmallVector<mlir::Value, 2> values;
899+
llvm::copy(args, std::back_inserter(values));
900+
rewriter.replaceOpUsesWithinBlock(op, values, remainingOpsBlock);
901+
}
902+
rewriter.eraseOp(op);
905903

906904
// Ok, we're done!
907905
return mlir::success();

clang/test/CIR/Lowering/ternary.cir

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ cir.func @_Z1xi(%arg0: !s32i) -> !s32i {
4141
// MLIR-NEXT: %8 = llvm.mlir.constant(5 : i32) : i32
4242
// MLIR-NEXT: llvm.br ^bb3(%8 : i32)
4343
// MLIR-NEXT: ^bb3(%9: i32): // 2 preds: ^bb1, ^bb2
44-
// MLIR-NEXT: llvm.br ^bb4
45-
// MLIR-NEXT: ^bb4: // pred: ^bb3
4644
// MLIR-NEXT: llvm.store %9, %3 {{.*}}: i32, !llvm.ptr
4745
// MLIR-NEXT: %10 = llvm.load %3 {alignment = 4 : i64} : !llvm.ptr -> i32
4846
// MLIR-NEXT: llvm.return %10 : i32

clang/test/CIR/Transforms/ternary.cir

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ module {
3737
// CHECK: %6 = cir.const #cir.int<5> : !s32i
3838
// CHECK: cir.br ^bb3(%6 : !s32i)
3939
// CHECK: ^bb3(%7: !s32i): // 2 preds: ^bb1, ^bb2
40-
// CHECK: cir.br ^bb4
41-
// CHECK: ^bb4: // pred: ^bb3
4240
// CHECK: cir.store %7, %1 : !s32i, !cir.ptr<!s32i>
4341
// CHECK: %8 = cir.load %1 : !cir.ptr<!s32i>, !s32i
4442
// CHECK: cir.return %8 : !s32i
@@ -60,8 +58,6 @@ module {
6058
// CHECK: ^bb2: // pred: ^bb0
6159
// CHECK: cir.br ^bb3
6260
// CHECK: ^bb3: // 2 preds: ^bb1, ^bb2
63-
// CHECK: cir.br ^bb4
64-
// CHECK: ^bb4: // pred: ^bb3
6561
// CHECK: cir.return
6662
// CHECK: }
6763

0 commit comments

Comments
 (0)