Skip to content

Commit 07a5667

Browse files
committed
[SLP] Fix PR87477: fix alternate node cast cost/codegen.
The actual bit widths of the source and destination types must be compared in order to pick the proper cast opcode.
1 parent 33992ea commit 07a5667

File tree

2 files changed

+74
-25
lines changed

2 files changed

+74
-25
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9063,25 +9063,35 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
90639063
cast<CmpInst>(E->getAltOp())->getPredicate(), CostKind,
90649064
E->getAltOp());
90659065
} else {
9066-
Type *Src0SclTy = E->getMainOp()->getOperand(0)->getType();
9067-
Type *Src1SclTy = E->getAltOp()->getOperand(0)->getType();
9068-
auto *Src0Ty = FixedVectorType::get(Src0SclTy, VL.size());
9069-
auto *Src1Ty = FixedVectorType::get(Src1SclTy, VL.size());
9070-
if (It != MinBWs.end()) {
9071-
if (!MinBWs.contains(getOperandEntry(E, 0)))
9072-
VecCost =
9073-
TTIRef.getCastInstrCost(Instruction::Trunc, VecTy, Src0Ty,
9074-
TTI::CastContextHint::None, CostKind);
9075-
LLVM_DEBUG({
9076-
dbgs() << "SLP: alternate extension, which should be truncated.\n";
9077-
E->dump();
9078-
});
9079-
return VecCost;
9066+
Type *SrcSclTy = E->getMainOp()->getOperand(0)->getType();
9067+
auto *SrcTy = FixedVectorType::get(SrcSclTy, VL.size());
9068+
if (SrcSclTy->isIntegerTy() && ScalarTy->isIntegerTy()) {
9069+
auto SrcIt = MinBWs.find(getOperandEntry(E, 0));
9070+
unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
9071+
unsigned SrcBWSz =
9072+
DL->getTypeSizeInBits(E->getMainOp()->getOperand(0)->getType());
9073+
if (SrcIt != MinBWs.end()) {
9074+
SrcBWSz = SrcIt->second.first;
9075+
SrcSclTy = IntegerType::get(SrcSclTy->getContext(), SrcBWSz);
9076+
SrcTy = FixedVectorType::get(SrcSclTy, VL.size());
9077+
}
9078+
if (BWSz <= SrcBWSz) {
9079+
if (BWSz < SrcBWSz)
9080+
VecCost =
9081+
TTIRef.getCastInstrCost(Instruction::Trunc, VecTy, SrcTy,
9082+
TTI::CastContextHint::None, CostKind);
9083+
LLVM_DEBUG({
9084+
dbgs()
9085+
<< "SLP: alternate extension, which should be truncated.\n";
9086+
E->dump();
9087+
});
9088+
return VecCost;
9089+
}
90809090
}
9081-
VecCost = TTIRef.getCastInstrCost(E->getOpcode(), VecTy, Src0Ty,
9091+
VecCost = TTIRef.getCastInstrCost(E->getOpcode(), VecTy, SrcTy,
90829092
TTI::CastContextHint::None, CostKind);
90839093
VecCost +=
9084-
TTIRef.getCastInstrCost(E->getAltOpcode(), VecTy, Src1Ty,
9094+
TTIRef.getCastInstrCost(E->getAltOpcode(), VecTy, SrcTy,
90859095
TTI::CastContextHint::None, CostKind);
90869096
}
90879097
SmallVector<int> Mask;
@@ -12591,15 +12601,20 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
1259112601
CmpInst::Predicate AltPred = AltCI->getPredicate();
1259212602
V1 = Builder.CreateCmp(AltPred, LHS, RHS);
1259312603
} else {
12594-
if (It != MinBWs.end()) {
12595-
if (!MinBWs.contains(getOperandEntry(E, 0)))
12596-
LHS = Builder.CreateIntCast(LHS, VecTy, It->second.first);
12597-
assert(LHS->getType() == VecTy && "Expected same type as operand.");
12598-
if (auto *I = dyn_cast<Instruction>(LHS))
12599-
LHS = propagateMetadata(I, E->Scalars);
12600-
E->VectorizedValue = LHS;
12601-
++NumVectorInstructions;
12602-
return LHS;
12604+
if (LHS->getType()->isIntOrIntVectorTy() && ScalarTy->isIntegerTy()) {
12605+
unsigned SrcBWSz = DL->getTypeSizeInBits(
12606+
cast<VectorType>(LHS->getType())->getElementType());
12607+
unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
12608+
if (BWSz <= SrcBWSz) {
12609+
if (BWSz < SrcBWSz)
12610+
LHS = Builder.CreateIntCast(LHS, VecTy, It->second.first);
12611+
assert(LHS->getType() == VecTy && "Expected same type as operand.");
12612+
if (auto *I = dyn_cast<Instruction>(LHS))
12613+
LHS = propagateMetadata(I, E->Scalars);
12614+
E->VectorizedValue = LHS;
12615+
++NumVectorInstructions;
12616+
return LHS;
12617+
}
1260312618
}
1260412619
V0 = Builder.CreateCast(
1260512620
static_cast<Instruction::CastOps>(E->getOpcode()), LHS, VecTy);
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
2+
; RUN: opt -S --passes=slp-vectorizer -mtriple=systemz-unknown -mcpu=z15 < %s -slp-threshold=-10 | FileCheck %s
3+
4+
define i32 @test(ptr %0, ptr %1) {
5+
; CHECK-LABEL: define i32 @test(
6+
; CHECK-SAME: ptr [[TMP0:%.*]], ptr [[TMP1:%.*]]) #[[ATTR0:[0-9]+]] {
7+
; CHECK-NEXT: [[TMP3:%.*]] = load i64, ptr inttoptr (i64 32 to ptr), align 32
8+
; CHECK-NEXT: [[TMP4:%.*]] = load ptr, ptr [[TMP1]], align 8
9+
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds i8, ptr [[TMP4]], i64 32
10+
; CHECK-NEXT: [[TMP6:%.*]] = load i64, ptr [[TMP5]], align 8
11+
; CHECK-NEXT: [[TMP7:%.*]] = insertelement <2 x i64> poison, i64 [[TMP6]], i32 0
12+
; CHECK-NEXT: [[TMP14:%.*]] = insertelement <2 x i64> [[TMP7]], i64 [[TMP3]], i32 1
13+
; CHECK-NEXT: [[TMP9:%.*]] = icmp ne <2 x i64> [[TMP14]], zeroinitializer
14+
; CHECK-NEXT: [[TMP16:%.*]] = sext <2 x i1> [[TMP9]] to <2 x i8>
15+
; CHECK-NEXT: [[TMP11:%.*]] = zext <2 x i1> [[TMP9]] to <2 x i8>
16+
; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <2 x i8> [[TMP16]], <2 x i8> [[TMP11]], <2 x i32> <i32 0, i32 3>
17+
; CHECK-NEXT: [[TMP13:%.*]] = extractelement <2 x i8> [[TMP12]], i32 0
18+
; CHECK-NEXT: [[DOTNEG:%.*]] = sext i8 [[TMP13]] to i32
19+
; CHECK-NEXT: [[TMP15:%.*]] = extractelement <2 x i8> [[TMP12]], i32 1
20+
; CHECK-NEXT: [[TMP8:%.*]] = sext i8 [[TMP15]] to i32
21+
; CHECK-NEXT: [[TMP10:%.*]] = add nsw i32 [[DOTNEG]], [[TMP8]]
22+
; CHECK-NEXT: ret i32 [[TMP10]]
23+
;
24+
%3 = load i64, ptr inttoptr (i64 32 to ptr), align 32
25+
%4 = load ptr, ptr %1, align 8
26+
%5 = getelementptr inbounds i8, ptr %4, i64 32
27+
%6 = load i64, ptr %5, align 8
28+
%7 = icmp ne i64 %3, 0
29+
%8 = zext i1 %7 to i32
30+
%9 = icmp ne i64 %6, 0
31+
%.neg = sext i1 %9 to i32
32+
%10 = add nsw i32 %.neg, %8
33+
ret i32 %10
34+
}

0 commit comments

Comments
 (0)