Skip to content

Commit f4441cb

Browse files
committed
[SimplifyCFG] Set branch weights when merging conditional store to address
1 parent 7be4626 commit f4441cb

File tree

2 files changed

+48
-13
lines changed

2 files changed

+48
-13
lines changed

llvm/include/llvm/IR/ProfDataUtils.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef LLVM_IR_PROFDATAUTILS_H
1616
#define LLVM_IR_PROFDATAUTILS_H
1717

18+
#include "llvm/ADT/STLExtras.h"
1819
#include "llvm/ADT/SmallVector.h"
1920
#include "llvm/ADT/Twine.h"
2021
#include "llvm/IR/Metadata.h"
@@ -186,5 +187,26 @@ LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);
186187
/// Scaling the profile data attached to 'I' using the ratio of S/T.
187188
LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
188189

190+
/// get the branch weights of a branch conditioned on b1 || b2, where b1 and b2
191+
/// are 2 booleans that are the condition of 2 branches for which we have the
192+
/// branch weights B1 and B2, respectivelly.
193+
inline SmallVector<uint64_t, 2>
194+
getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
195+
const SmallVector<uint32_t, 2> &B2) {
196+
// the probability of the new branch being taken is:
197+
// P = p(b1) + p(b2) - p (b1 and b2)
198+
// not P = p((not b1) and (not b2)) =
199+
// = B1[1] / (B1[0]+B1[1]) * B2[1] / (B2[0]+B2[1]) =
200+
// = B1[1] * B2[1] / (B1[0] + B1[1]) * (B2[0] + B2[1])
201+
// P = 1 - (not P)
202+
// The numerator of P will be (B1[0] + B1[1]) * (B2[0] + B2[1]) - B1[1]*B2[1]
203+
// ... which becomes what's shown below.
204+
// We don't need the denominators, they are the same
205+
assert(B1.size() == 2);
206+
assert(B2.size() == 2);
207+
auto FalseWeight = B1[1] * B2[1];
208+
auto TrueWeight = B1[0] * B2[0] + B1[0] * B2[1] + B1[1] * B2[0];
209+
return {TrueWeight, FalseWeight};
210+
}
189211
} // namespace llvm
190212
#endif

llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,7 +1182,7 @@ static void cloneInstructionsIntoPredecessorBlockAndUpdateSSAUses(
11821182
// only given the branch precondition.
11831183
// Similarly strip attributes on call parameters that may cause UB in
11841184
// location the call is moved to.
1185-
NewBonusInst->dropUBImplyingAttrsAndMetadata();
1185+
NewBonusInst->dropUBImplyingAttrsAndMetadata({LLVMContext::MD_prof});
11861186

11871187
NewBonusInst->insertInto(PredBlock, PTI->getIterator());
11881188
auto Range = NewBonusInst->cloneDebugInfoFrom(&BonusInst);
@@ -1808,7 +1808,8 @@ static void hoistConditionalLoadsStores(
18081808
// !annotation: Not impact semantics. Keep it.
18091809
if (const MDNode *Ranges = I->getMetadata(LLVMContext::MD_range))
18101810
MaskedLoadStore->addRangeRetAttr(getConstantRangeFromMetadata(*Ranges));
1811-
I->dropUBImplyingAttrsAndUnknownMetadata({LLVMContext::MD_annotation});
1811+
I->dropUBImplyingAttrsAndUnknownMetadata(
1812+
{LLVMContext::MD_annotation, LLVMContext::MD_prof});
18121813
// FIXME: DIAssignID is not supported for masked store yet.
18131814
// (Verifier::visitDIAssignIDMetadata)
18141815
at::deleteAssignmentMarkers(I);
@@ -3366,7 +3367,7 @@ bool SimplifyCFGOpt::speculativelyExecuteBB(BranchInst *BI,
33663367
if (!SpeculatedStoreValue || &I != SpeculatedStore) {
33673368
I.setDebugLoc(DebugLoc::getDropped());
33683369
}
3369-
I.dropUBImplyingAttrsAndMetadata();
3370+
I.dropUBImplyingAttrsAndMetadata({LLVMContext::MD_prof});
33703371

33713372
// Drop ephemeral values.
33723373
if (EphTracker.contains(&I)) {
@@ -4404,10 +4405,12 @@ static bool mergeConditionalStoreToAddress(
44044405

44054406
// OK, we're going to sink the stores to PostBB. The store has to be
44064407
// conditional though, so first create the predicate.
4407-
Value *PCond = cast<BranchInst>(PFB->getSinglePredecessor()->getTerminator())
4408-
->getCondition();
4409-
Value *QCond = cast<BranchInst>(QFB->getSinglePredecessor()->getTerminator())
4410-
->getCondition();
4408+
BranchInst *const PBranch =
4409+
cast<BranchInst>(PFB->getSinglePredecessor()->getTerminator());
4410+
BranchInst *const QBranch =
4411+
cast<BranchInst>(QFB->getSinglePredecessor()->getTerminator());
4412+
Value *const PCond = PBranch->getCondition();
4413+
Value *const QCond = QBranch->getCondition();
44114414

44124415
Value *PPHI = ensureValueAvailableInSuccessor(PStore->getValueOperand(),
44134416
PStore->getParent());
@@ -4418,19 +4421,29 @@ static bool mergeConditionalStoreToAddress(
44184421
IRBuilder<> QB(PostBB, PostBBFirst);
44194422
QB.SetCurrentDebugLocation(PostBBFirst->getStableDebugLoc());
44204423

4421-
Value *PPred = PStore->getParent() == PTB ? PCond : QB.CreateNot(PCond);
4422-
Value *QPred = QStore->getParent() == QTB ? QCond : QB.CreateNot(QCond);
4424+
InvertPCond = (PStore->getParent() == PTB) ^ InvertPCond;
4425+
InvertQCond = (QStore->getParent() == QTB) ^ InvertQCond;
4426+
Value *const PPred = InvertPCond ? PCond : QB.CreateNot(PCond);
4427+
Value *const QPred = InvertQCond ? QCond : QB.CreateNot(QCond);
44234428

4424-
if (InvertPCond)
4425-
PPred = QB.CreateNot(PPred);
4426-
if (InvertQCond)
4427-
QPred = QB.CreateNot(QPred);
44284429
Value *CombinedPred = QB.CreateOr(PPred, QPred);
44294430

44304431
BasicBlock::iterator InsertPt = QB.GetInsertPoint();
44314432
auto *T = SplitBlockAndInsertIfThen(CombinedPred, InsertPt,
44324433
/*Unreachable=*/false,
44334434
/*BranchWeights=*/nullptr, DTU);
4435+
if (hasBranchWeightMD(*PBranch) && hasBranchWeightMD(*QBranch)) {
4436+
SmallVector<uint32_t, 2> PWeights, QWeights;
4437+
extractBranchWeights(*PBranch, PWeights);
4438+
extractBranchWeights(*QBranch, QWeights);
4439+
if (InvertPCond)
4440+
std::swap(PWeights[0], PWeights[1]);
4441+
if (InvertQCond)
4442+
std::swap(QWeights[0], QWeights[1]);
4443+
auto CombinedWeights = getDisjunctionWeights(PWeights, QWeights);
4444+
setBranchWeights(T, CombinedWeights[0], CombinedWeights[1],
4445+
/*IsExpected=*/false);
4446+
}
44344447

44354448
QB.SetInsertPoint(T);
44364449
StoreInst *SI = cast<StoreInst>(QB.CreateStore(QPHI, Address));

0 commit comments

Comments
 (0)