/// Get the branch weights of a branch conditioned on b1 || b2, where b1 and b2
/// are two booleans that are the conditions of two branches for which we have
/// the branch weights B1 and B2, respectively.
///
/// In both B1 and B2, index 0 holds the weight of the "true" case and index 1
/// the weight of the "false" case. Each must have exactly two entries; this is
/// asserted below.
///
/// \returns {TrueWeight, FalseWeight} for the combined branch. The two values
/// are the numerators of the respective probabilities over the common
/// denominator (B1[0] + B1[1]) * (B2[0] + B2[1]); the denominator itself is
/// not returned. For example, B1 = {7, 13} and B2 = {3, 11} yield {137, 143}.
inline SmallVector
getDisjunctionWeights(const SmallVector &B1,
                      const SmallVector &B2) {
  // For the first conditional branch, the probability that the "true" case is
  // taken is p(b1) = B1[0] / (B1[0] + B1[1]). The "false" case's probability
  // is p(not b1) = B1[1] / (B1[0] + B1[1]).
  // Similarly for the second conditional branch and B2.
  //
  // The probability of the new branch NOT being taken is the product of the
  // two "false" probabilities (i.e. b1 and b2 are treated as independent):
  //   not P = p((not b1) and (not b2))
  //         = B1[1] / (B1[0] + B1[1]) * B2[1] / (B2[0] + B2[1])
  //         = (B1[1] * B2[1]) / ((B1[0] + B1[1]) * (B2[0] + B2[1]))
  // Then the probability of it being taken is P = 1 - (not P).
  // The denominator is the same as above, and the numerator of P is
  //   (B1[0] + B1[1]) * (B2[0] + B2[1]) - B1[1] * B2[1]
  // which reduces to the sum computed below: of the four terms produced by
  // expanding the product of sums, B1[1] * B2[1] is cancelled by the
  // subtraction.
  assert(B1.size() == 2);
  assert(B2.size() == 2);
  // NOTE(review): `auto` makes these products use the (promoted) element type
  // of B1/B2, which is not visible here. If that type is 32 bits wide, large
  // weights could wrap around -- confirm, and widen before multiplying if so.
  auto FalseWeight = B1[1] * B2[1];
  auto TrueWeight = B1[0] * B2[0] + B1[0] * B2[1] + B1[1] * B2[0];
  return {TrueWeight, FalseWeight};
}
b/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll index e1bd7916b3be0..b1cce4484bbab 100644 --- a/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll +++ b/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll @@ -1,4 +1,4 @@ -; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals ; RUN: opt -passes=simplifycfg,instcombine -simplifycfg-require-and-preserve-domtree=1 < %s -simplifycfg-merge-cond-stores=true -simplifycfg-merge-cond-stores-aggressively=false -phi-node-folding-threshold=2 -S | FileCheck %s ; This test should succeed and end up if-converted. @@ -43,7 +43,7 @@ define void @test_simple_commuted(ptr %p, i32 %a, i32 %b) { ; CHECK-NEXT: [[X2:%.*]] = icmp eq i32 [[B:%.*]], 0 ; CHECK-NEXT: [[X3:%.*]] = icmp eq i32 [[B1:%.*]], 0 ; CHECK-NEXT: [[TMP0:%.*]] = or i1 [[X2]], [[X3]] -; CHECK-NEXT: br i1 [[TMP0]], label [[TMP1:%.*]], label [[TMP2:%.*]] +; CHECK-NEXT: br i1 [[TMP0]], label [[TMP1:%.*]], label [[TMP2:%.*]], !prof [[PROF0:![0-9]+]] ; CHECK: 1: ; CHECK-NEXT: [[SPEC_SELECT:%.*]] = zext i1 [[X3]] to i32 ; CHECK-NEXT: store i32 [[SPEC_SELECT]], ptr [[P:%.*]], align 4 @@ -53,7 +53,7 @@ define void @test_simple_commuted(ptr %p, i32 %a, i32 %b) { ; entry: %x1 = icmp eq i32 %a, 0 - br i1 %x1, label %yes1, label %fallthrough + br i1 %x1, label %yes1, label %fallthrough, !prof !0 yes1: store i32 0, ptr %p @@ -61,7 +61,7 @@ yes1: fallthrough: %x2 = icmp eq i32 %b, 0 - br i1 %x2, label %yes2, label %end + br i1 %x2, label %yes2, label %end, !prof !1 yes2: store i32 1, ptr %p @@ -406,3 +406,9 @@ yes2: end: ret void } + +!0 = !{!"branch_weights", i32 7, i32 13} +!1 = !{!"branch_weights", i32 3, i32 11} +;. +; CHECK: [[PROF0]] = !{!"branch_weights", i32 137, i32 143} +;.