Skip to content

[SimplifyCFG] Set branch weights when merging conditional store to address #154841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: users/mtrofin/08-20-_local_preserve_md_prof_in_hoistallinstructionsinto_
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions llvm/include/llvm/IR/ProfDataUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef LLVM_IR_PROFDATAUTILS_H
#define LLVM_IR_PROFDATAUTILS_H

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Metadata.h"
Expand Down Expand Up @@ -186,5 +187,31 @@ LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);
/// Scaling the profile data attached to 'I' using the ratio of S/T.
LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);

/// get the branch weights of a branch conditioned on b1 || b2, where b1 and b2
/// are 2 booleans that are the condition of 2 branches for which we have the
/// branch weights B1 and B2, respectivelly.
inline SmallVector<uint64_t, 2>
getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
const SmallVector<uint32_t, 2> &B2) {
// for the first conditional branch, the probability the "true" case is taken
// is p(b1) = B1[0] / (B1[0] + B2[0]). The "false" case's probability is
// p(not b1) = B1[1] / (B1[0] + B1[1]).
// Similarly for the second conditional branch and B2.
//
// the probability of the new branch NOT being taken is:
// not P = p((not b1) and (not b2)) =
// = B1[1] / (B1[0]+B1[1]) * B2[1] / (B2[0]+B2[1]) =
// = B1[1] * B2[1] / (B1[0] + B1[1]) * (B2[0] + B2[1])
// then the probability of it being taken is: P = 1 - (not P).
// The denominator will be the same as above, and the numerator of P will be
// (B1[0] + B1[1]) * (B2[0] + B2[1]) - B1[1]*B2[1]
// Which then reduces to what's shown below (out of the 4 terms coming out of
// the product of sums, the subtracted one cancels out)
assert(B1.size() == 2);
assert(B2.size() == 2);
auto FalseWeight = B1[1] * B2[1];
auto TrueWeight = B1[0] * B2[0] + B1[0] * B2[1] + B1[1] * B2[0];
return {TrueWeight, FalseWeight};
}
} // namespace llvm
#endif
40 changes: 27 additions & 13 deletions llvm/lib/Transforms/Utils/SimplifyCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1182,7 +1182,7 @@ static void cloneInstructionsIntoPredecessorBlockAndUpdateSSAUses(
// only given the branch precondition.
// Similarly strip attributes on call parameters that may cause UB in
// location the call is moved to.
NewBonusInst->dropUBImplyingAttrsAndMetadata();
NewBonusInst->dropUBImplyingAttrsAndMetadata({LLVMContext::MD_prof});

NewBonusInst->insertInto(PredBlock, PTI->getIterator());
auto Range = NewBonusInst->cloneDebugInfoFrom(&BonusInst);
Expand Down Expand Up @@ -1808,7 +1808,8 @@ static void hoistConditionalLoadsStores(
// !annotation: Not impact semantics. Keep it.
if (const MDNode *Ranges = I->getMetadata(LLVMContext::MD_range))
MaskedLoadStore->addRangeRetAttr(getConstantRangeFromMetadata(*Ranges));
I->dropUBImplyingAttrsAndUnknownMetadata({LLVMContext::MD_annotation});
I->dropUBImplyingAttrsAndUnknownMetadata(
{LLVMContext::MD_annotation, LLVMContext::MD_prof});
// FIXME: DIAssignID is not supported for masked store yet.
// (Verifier::visitDIAssignIDMetadata)
at::deleteAssignmentMarkers(I);
Expand Down Expand Up @@ -3366,7 +3367,7 @@ bool SimplifyCFGOpt::speculativelyExecuteBB(BranchInst *BI,
if (!SpeculatedStoreValue || &I != SpeculatedStore) {
I.setDebugLoc(DebugLoc::getDropped());
}
I.dropUBImplyingAttrsAndMetadata();
I.dropUBImplyingAttrsAndMetadata({LLVMContext::MD_prof});

// Drop ephemeral values.
if (EphTracker.contains(&I)) {
Expand Down Expand Up @@ -4404,10 +4405,12 @@ static bool mergeConditionalStoreToAddress(

// OK, we're going to sink the stores to PostBB. The store has to be
// conditional though, so first create the predicate.
Value *PCond = cast<BranchInst>(PFB->getSinglePredecessor()->getTerminator())
->getCondition();
Value *QCond = cast<BranchInst>(QFB->getSinglePredecessor()->getTerminator())
->getCondition();
BranchInst *const PBranch =
cast<BranchInst>(PFB->getSinglePredecessor()->getTerminator());
BranchInst *const QBranch =
cast<BranchInst>(QFB->getSinglePredecessor()->getTerminator());
Value *const PCond = PBranch->getCondition();
Value *const QCond = QBranch->getCondition();

Value *PPHI = ensureValueAvailableInSuccessor(PStore->getValueOperand(),
PStore->getParent());
Expand All @@ -4418,19 +4421,30 @@ static bool mergeConditionalStoreToAddress(
IRBuilder<> QB(PostBB, PostBBFirst);
QB.SetCurrentDebugLocation(PostBBFirst->getStableDebugLoc());

Value *PPred = PStore->getParent() == PTB ? PCond : QB.CreateNot(PCond);
Value *QPred = QStore->getParent() == QTB ? QCond : QB.CreateNot(QCond);
InvertPCond = (PStore->getParent() == PTB) ^ InvertPCond;
InvertQCond = (QStore->getParent() == QTB) ^ InvertQCond;
Value *const PPred = InvertPCond ? PCond : QB.CreateNot(PCond);
Value *const QPred = InvertQCond ? QCond : QB.CreateNot(QCond);

if (InvertPCond)
PPred = QB.CreateNot(PPred);
if (InvertQCond)
QPred = QB.CreateNot(QPred);
Value *CombinedPred = QB.CreateOr(PPred, QPred);

BasicBlock::iterator InsertPt = QB.GetInsertPoint();
auto *T = SplitBlockAndInsertIfThen(CombinedPred, InsertPt,
/*Unreachable=*/false,
/*BranchWeights=*/nullptr, DTU);
if (hasBranchWeightMD(*PBranch) && hasBranchWeightMD(*QBranch)) {
SmallVector<uint32_t, 2> PWeights, QWeights;
extractBranchWeights(*PBranch, PWeights);
extractBranchWeights(*QBranch, QWeights);
if (InvertPCond)
std::swap(PWeights[0], PWeights[1]);
if (InvertQCond)
std::swap(QWeights[0], QWeights[1]);
auto CombinedWeights = getDisjunctionWeights(PWeights, QWeights);
setBranchWeights(PostBB->getTerminator(), CombinedWeights[0],
CombinedWeights[1],
/*IsExpected=*/false);
}

QB.SetInsertPoint(T);
StoreInst *SI = cast<StoreInst>(QB.CreateStore(QPHI, Address));
Expand Down
Loading