Skip to content

Commit 74a43dd

Browse files
committed
merge amd-develop into amd-staging
2 parents a377466 + 8cb74e2 commit 74a43dd

File tree

986 files changed

+4522
-3576
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

986 files changed

+4522
-3576
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,5 @@ EXT(SPV_INTEL_ternary_bitwise_function)
8282
EXT(SPV_INTEL_int4)
8383
EXT(SPV_INTEL_function_variants)
8484
EXT(SPV_INTEL_shader_atomic_bfloat16)
85+
EXT(SPV_EXT_float8)
8586
EXT(SPV_INTEL_predicated_io)

include/LLVMSPIRVLib.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,11 @@
4343

4444
#include "LLVMSPIRVOpts.h"
4545

46+
#include <cstdint>
4647
#include <iostream>
48+
#include <optional>
4749
#include <string>
50+
#include <vector>
4851

4952
namespace llvm {
5053
// Pass initialization functions need to be declared before inclusion of

include/LLVMSPIRVOpts.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include <map>
4848
#include <optional>
4949
#include <unordered_map>
50+
#include <vector>
5051

5152
namespace llvm {
5253
class IntrinsicInst;

lib/SPIRV/OCLToSPIRV.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ using namespace OCLUtil;
6464
namespace SPIRV {
6565
static size_t getOCLCpp11AtomicMaxNumOps(StringRef Name) {
6666
return StringSwitch<size_t>(Name)
67-
.Cases("load", "flag_test_and_set", "flag_clear", 3)
68-
.Cases("store", "exchange", 4)
67+
.Cases({"load", "flag_test_and_set", "flag_clear"}, 3)
68+
.Cases({"store", "exchange"}, 4)
6969
.StartsWith("compare_exchange", 6)
7070
.StartsWith("fetch", 4)
7171
.Default(0);

lib/SPIRV/SPIRVInternal.h

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,7 @@ const static char ConvertHandleToImageINTEL[] = "ConvertHandleToImageINTEL";
374374
const static char ConvertHandleToSamplerINTEL[] = "ConvertHandleToSamplerINTEL";
375375
const static char ConvertHandleToSampledImageINTEL[] =
376376
"ConvertHandleToSampledImageINTEL";
377+
const static char InternalBuiltinPrefix[] = "__builtin_spirv_";
377378
} // namespace kSPIRVName
378379

379380
namespace kSPIRVPostfix {
@@ -666,7 +667,7 @@ Op getSPIRVFuncOC(StringRef Name, SmallVectorImpl<std::string> *Dec = nullptr);
666667
bool getSPIRVBuiltin(const std::string &Name, spv::BuiltIn &Builtin);
667668

668669
/// \param Name LLVM function name
669-
/// \param DemangledName demanged name of the OpenCL built-in function
670+
/// \param DemangledName demangled name of the OpenCL built-in function
670671
/// \returns true if Name is the name of the OpenCL built-in function,
671672
/// false for other functions
672673
bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp = false);
@@ -729,6 +730,9 @@ CallInst *addCallInst(Module *M, StringRef FuncName, Type *RetTy,
729730
StringRef InstName = SPIR_TEMP_NAME_PREFIX_CALL,
730731
bool TakeFuncName = true);
731732

733+
/// Check if an LLVM type is spirv.CooperativeMatrixKHR.
734+
bool isLLVMCooperativeMatrixType(llvm::Type *Ty);
735+
732736
/// Add a call instruction for SPIR-V builtin function.
733737
CallInst *addCallInstSPIRV(Module *M, StringRef FuncName, Type *RetTy,
734738
ArrayRef<Value *> Args, AttributeList *Attrs,
@@ -1030,6 +1034,84 @@ bool postProcessBuiltinsReturningStruct(Module *M, bool IsCpp = false);
10301034

10311035
bool postProcessBuiltinsWithArrayArguments(Module *M, bool IsCpp = false);
10321036

1037+
/// \param MangledName LLVM function name.
1038+
/// \param DemangledName demangled name of the input function if it is the
1039+
/// translator's internal built-in function.
1040+
/// \returns true if MangledName is the name of the translator's internal
1041+
/// built-in function, false for other functions.
1042+
/// Used for 'mini'-floats conversion functions
1043+
bool isInternalSPIRVBuiltin(StringRef MangledName, StringRef &DemangledName);
1044+
1045+
// Wrapper around SPIR-V 1.6.4 FP Encoding to be used in the conversion
1046+
// descriptor
1047+
enum FPEncodingWrap {
1048+
Integer = FPEncoding::FPEncodingMax - 1,
1049+
IEEE754 = FPEncoding::FPEncodingMax,
1050+
BF16 = FPEncoding::FPEncodingBFloat16KHR,
1051+
E4M3 = FPEncoding::FPEncodingFloat8E4M3EXT,
1052+
E5M2 = FPEncoding::FPEncodingFloat8E5M2EXT,
1053+
};
1054+
1055+
// Structure describing non-trivial conversions (FP8 and int4)
1056+
struct FPConversionDesc {
1057+
FPEncodingWrap SrcEncoding;
1058+
FPEncodingWrap DstEncoding;
1059+
SPIRVWord ConvOpCode;
1060+
1061+
// To use as a key in std::map
1062+
bool operator==(const FPConversionDesc &Other) const {
1063+
return SrcEncoding == Other.SrcEncoding &&
1064+
DstEncoding == Other.DstEncoding && ConvOpCode == Other.ConvOpCode;
1065+
}
1066+
1067+
bool operator<(const FPConversionDesc &Other) const {
1068+
if (ConvOpCode != Other.ConvOpCode)
1069+
return ConvOpCode < Other.ConvOpCode;
1070+
if (SrcEncoding != Other.SrcEncoding)
1071+
return SrcEncoding < Other.SrcEncoding;
1072+
return DstEncoding < Other.DstEncoding;
1073+
}
1074+
};
1075+
1076+
// Maps internal builtin name to conversion descriptor
1077+
typedef SPIRVMap<llvm::StringRef, FPConversionDesc> FPConvertToEncodingMap;
1078+
1079+
// clang-format off
1080+
// Populates the map from internal builtin name to its conversion descriptor
// {source encoding, destination encoding, SPIR-V conversion opcode}.
// Entries are registered in table order.
template <> inline void FPConvertToEncodingMap::init() {
  struct ConvEntry {
    const char *BuiltinName;
    FPConversionDesc Desc;
  };

  const ConvEntry Conversions[] = {
      // 8-bit conversions
      {"ConvertE4M3ToFP16EXT", {FPEncodingWrap::E4M3,    FPEncodingWrap::IEEE754, OpFConvert}},
      {"ConvertE5M2ToFP16EXT", {FPEncodingWrap::E5M2,    FPEncodingWrap::IEEE754, OpFConvert}},
      {"ConvertE4M3ToBF16EXT", {FPEncodingWrap::E4M3,    FPEncodingWrap::BF16,    OpFConvert}},
      {"ConvertE5M2ToBF16EXT", {FPEncodingWrap::E5M2,    FPEncodingWrap::BF16,    OpFConvert}},
      {"ConvertFP16ToE4M3EXT", {FPEncodingWrap::IEEE754, FPEncodingWrap::E4M3,    OpFConvert}},
      {"ConvertFP16ToE5M2EXT", {FPEncodingWrap::IEEE754, FPEncodingWrap::E5M2,    OpFConvert}},
      {"ConvertBF16ToE4M3EXT", {FPEncodingWrap::BF16,    FPEncodingWrap::E4M3,    OpFConvert}},
      {"ConvertBF16ToE5M2EXT", {FPEncodingWrap::BF16,    FPEncodingWrap::E5M2,    OpFConvert}},

      // Conversions with an integer (int4) source or destination
      {"ConvertInt4ToE4M3INTEL", {FPEncodingWrap::Integer, FPEncodingWrap::E4M3,    OpConvertSToF}},
      {"ConvertInt4ToE5M2INTEL", {FPEncodingWrap::Integer, FPEncodingWrap::E5M2,    OpConvertSToF}},
      {"ConvertInt4ToFP16INTEL", {FPEncodingWrap::Integer, FPEncodingWrap::IEEE754, OpConvertSToF}},
      {"ConvertInt4ToBF16INTEL", {FPEncodingWrap::Integer, FPEncodingWrap::BF16,    OpConvertSToF}},
      {"ConvertFP16ToInt4INTEL", {FPEncodingWrap::IEEE754, FPEncodingWrap::Integer, OpConvertFToS}},
      {"ConvertBF16ToInt4INTEL", {FPEncodingWrap::BF16,    FPEncodingWrap::Integer, OpConvertFToS}},
  };

  for (const ConvEntry &Entry : Conversions)
    add(Entry.BuiltinName, Entry.Desc);
}
1112+
1113+
// clang-format on
1114+
10331115
} // namespace SPIRV
10341116

10351117
#endif // SPIRV_SPIRVINTERNAL_H

lib/SPIRV/SPIRVReader.cpp

Lines changed: 91 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,9 @@ std::optional<uint64_t> SPIRVToLLVM::getAlignment(SPIRVValue *V) {
298298

299299
Type *SPIRVToLLVM::transFPType(SPIRVType *T) {
300300
switch (T->getFloatBitWidth()) {
301+
case 8:
302+
// No LLVM IR counterpart for FP8 - map it to i8
303+
return Type::getIntNTy(*Context, 8);
301304
case 16:
302305
if (T->isTypeFloat(16, FPEncodingBFloat16KHR))
303306
return Type::getBFloatTy(*Context);
@@ -1060,6 +1063,22 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10601063
CastInst::CastOps CO = Instruction::BitCast;
10611064
bool IsExt =
10621065
Dst->getScalarSizeInBits() > Src->getType()->getScalarSizeInBits();
1066+
1067+
auto GetFPEncoding = [](SPIRVType *Ty) -> FPEncodingWrap {
1068+
if (Ty->isTypeFloat()) {
1069+
unsigned Enc =
1070+
static_cast<SPIRVTypeFloat *>(Ty)->getFloatingPointEncoding();
1071+
return static_cast<FPEncodingWrap>(Enc);
1072+
}
1073+
if (Ty->isTypeInt())
1074+
return FPEncodingWrap::Integer;
1075+
return FPEncodingWrap::IEEE754;
1076+
};
1077+
1078+
auto IsFP8Encoding = [](FPEncodingWrap Encoding) -> bool {
1079+
return Encoding == FPEncodingWrap::E4M3 || Encoding == FPEncodingWrap::E5M2;
1080+
};
1081+
10631082
switch (BC->getOpCode()) {
10641083
case OpPtrCastToGeneric:
10651084
case OpGenericCastToPtr:
@@ -1081,9 +1100,56 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10811100
case OpUConvert:
10821101
CO = IsExt ? Instruction::ZExt : Instruction::Trunc;
10831102
break;
1084-
case OpFConvert:
1085-
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
1103+
case OpConvertSToF:
1104+
case OpConvertFToS:
1105+
case OpConvertUToF:
1106+
case OpConvertFToU:
1107+
case OpFConvert: {
1108+
const auto OC = BC->getOpCode();
1109+
{
1110+
auto SPVOps = BC->getOperands();
1111+
auto *SPVSrcTy = SPVOps[0]->getType();
1112+
auto *SPVDstTy = BC->getType();
1113+
1114+
auto GetEncodingAndUpdateType =
1115+
[GetFPEncoding](SPIRVType *&SPVTy) -> FPEncodingWrap {
1116+
if (SPVTy->isTypeVector()) {
1117+
SPVTy = SPVTy->getVectorComponentType();
1118+
} else if (SPVTy->isTypeCooperativeMatrixKHR()) {
1119+
auto *MT = static_cast<SPIRVTypeCooperativeMatrixKHR *>(SPVTy);
1120+
SPVTy = MT->getCompType();
1121+
}
1122+
return GetFPEncoding(SPVTy);
1123+
};
1124+
1125+
FPEncodingWrap SrcEnc = GetEncodingAndUpdateType(SPVSrcTy);
1126+
FPEncodingWrap DstEnc = GetEncodingAndUpdateType(SPVDstTy);
1127+
if (IsFP8Encoding(SrcEnc) || IsFP8Encoding(DstEnc) ||
1128+
SPVSrcTy->isTypeInt(4) || SPVDstTy->isTypeInt(4)) {
1129+
FPConversionDesc FPDesc = {SrcEnc, DstEnc,
1130+
static_cast<SPIRVWord>(BC->getOpCode())};
1131+
auto Conv = SPIRV::FPConvertToEncodingMap::rmap(FPDesc);
1132+
std::vector<Value *> Ops = {Src};
1133+
std::vector<Type *> OpsTys = {Src->getType()};
1134+
1135+
std::string BuiltinName =
1136+
kSPIRVName::InternalBuiltinPrefix + std::string(Conv);
1137+
BuiltinFuncMangleInfo Info;
1138+
std::string MangledName = mangleBuiltin(BuiltinName, OpsTys, &Info);
1139+
1140+
FunctionType *FTy = FunctionType::get(Dst, OpsTys, false);
1141+
FunctionCallee Func = M->getOrInsertFunction(MangledName, FTy);
1142+
return CallInst::Create(Func, Ops, "", BB);
1143+
}
1144+
}
1145+
1146+
if (OC == OpFConvert) {
1147+
CO = IsExt ? Instruction::FPExt : Instruction::FPTrunc;
1148+
break;
1149+
}
1150+
CO = static_cast<CastInst::CastOps>(OpCodeMap::rmap(OC));
10861151
break;
1152+
}
10871153
case OpBitcast:
10881154
if (Src->getType()->isPointerTy() && Dst->isPointerTy()) {
10891155
if (M->getTargetTriple().getVendor() == Triple::VendorType::AMD) {
@@ -1093,7 +1159,8 @@ Value *SPIRVToLLVM::transConvertInst(SPIRVValue *BV, Function *F,
10931159
else
10941160
return Src; // Spuriously inserted pointer BC.
10951161
}
1096-
} else if (Src->getType() == Dst) { // Spuriously inserted BC
1162+
} else if ((!Dst->isPointerTy() && Dst == Src->getType()) ||
1163+
(Src->getType() == Dst)) { // Spuriously inserted BC
10971164
return Src;
10981165
} else {
10991166
// OpBitcast need to be handled as a special-case when the source is a
@@ -3037,11 +3104,29 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
30373104
if (isCvtOpCode(OC) && OC != OpGenericCastToPtrExplicit) {
30383105
auto *BI = static_cast<SPIRVInstruction *>(BV);
30393106
Value *Inst = nullptr;
3040-
if (BI->hasFPRoundingMode() || BI->isSaturatedConversion() ||
3041-
BI->getType()->isTypeCooperativeMatrixKHR())
3107+
if (BI->hasFPRoundingMode() || BI->isSaturatedConversion()) {
30423108
Inst = transSPIRVBuiltinFromInst(BI, BB);
3043-
else
3109+
} else if (BI->getType()->isTypeCooperativeMatrixKHR()) {
3110+
// For cooperative matrix conversions generate __builtin_spirv
3111+
// conversions instead of __spirv_FConvert in case of mini-float
3112+
// element type.
3113+
auto *OutMatrixElementTy =
3114+
static_cast<SPIRVTypeCooperativeMatrixKHR *>(BI->getType())
3115+
->getCompType();
3116+
auto *InMatrixElementTy =
3117+
static_cast<SPIRVTypeCooperativeMatrixKHR *>(
3118+
static_cast<SPIRVUnary *>(BI)->getOperand(0)->getType())
3119+
->getCompType();
3120+
if (OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
3121+
OutMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT) ||
3122+
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E4M3EXT) ||
3123+
InMatrixElementTy->isTypeFloat(8, FPEncodingFloat8E5M2EXT))
3124+
Inst = transConvertInst(BV, F, BB);
3125+
else
3126+
Inst = transSPIRVBuiltinFromInst(BI, BB);
3127+
} else {
30443128
Inst = transConvertInst(BV, F, BB);
3129+
}
30453130
return mapValue(BV, Inst);
30463131
}
30473132
return mapValue(

lib/SPIRV/SPIRVToOCL.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,13 @@ void SPIRVToOCLBase::visitCallGenericCastToPtrExplicitBuiltIn(CallInst *CI,
679679

680680
void SPIRVToOCLBase::visitCallSPIRVCvtBuiltin(CallInst *CI, Op OC,
681681
StringRef DemangledName) {
682+
if (auto *TET =
683+
dyn_cast<TargetExtType>(CI->getFunctionType()->getReturnType())) {
684+
// Preserve any cooperative matrix type conversions as SPIR-V calls.
685+
if (TET->getName() == "spirv.CooperativeMatrixKHR") {
686+
return;
687+
}
688+
}
682689
std::string CastBuiltInName;
683690
if (isCvtFromUnsignedOpCode(OC))
684691
CastBuiltInName = "u";

lib/SPIRV/SPIRVTypeScavenger.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,7 @@ bool SPIRVTypeScavenger::typeIntrinsicCall(
453453
case OpAtomicLoad:
454454
case OpAtomicExchange:
455455
case OpAtomicCompareExchange:
456+
case OpAtomicCompareExchangeWeak:
456457
case OpAtomicIAdd:
457458
case OpAtomicISub:
458459
case OpAtomicFAddEXT:

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
// This file needs to be included before anything that declares
4242
// llvm::PointerType to avoid a compilation bug on MSVC.
43+
#include "llvm/Demangle/Demangle.h"
4344
#include "llvm/Demangle/ItaniumDemangle.h"
4445

4546
#include "FunctionDescriptor.h"
@@ -267,6 +268,12 @@ bool isSYCLBfloat16Type(llvm::Type *Ty) {
267268
return false;
268269
}
269270

271+
// Returns true when Ty is a target extension type whose name is
// "spirv.CooperativeMatrixKHR"; any other type yields false.
bool isLLVMCooperativeMatrixType(llvm::Type *Ty) {
  const auto *ExtTy = dyn_cast<TargetExtType>(Ty);
  return ExtTy && ExtTy->getName() == "spirv.CooperativeMatrixKHR";
}
276+
270277
Function *getOrCreateFunction(Module *M, Type *RetTy, ArrayRef<Type *> ArgTypes,
271278
StringRef Name, BuiltinFuncMangleInfo *Mangle,
272279
AttributeList *Attrs, bool TakeName) {
@@ -442,7 +449,7 @@ bool getSPIRVBuiltin(const std::string &OrigName, spv::BuiltIn &B) {
442449
return getByName(R.str(), B);
443450
}
444451

445-
// Demangled name is a substring of the name. The DemangledName is updated only
452+
// DemangledName is a substring of Name. The DemangledName is updated only
446453
// if true is returned
447454
bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp) {
448455
if (Name == "printf") {
@@ -487,6 +494,21 @@ bool oclIsBuiltin(StringRef Name, StringRef &DemangledName, bool IsCpp) {
487494
return false;
488495
}
489496

497+
// DemangledName is a substring of Name. The DemangledName is updated only
498+
// if true is returned.
499+
bool isInternalSPIRVBuiltin(StringRef Name, StringRef &DemangledName) {
500+
if (!Name.starts_with("_Z"))
501+
return false;
502+
constexpr unsigned DemangledNameLenStart = 2;
503+
size_t Start = Name.find_first_not_of("0123456789", DemangledNameLenStart);
504+
if (!Name.substr(Start, Name.size() - 1)
505+
.starts_with(kSPIRVName::InternalBuiltinPrefix))
506+
return false;
507+
DemangledName = llvm::itaniumDemangle(Name.data(), false);
508+
DemangledName.consume_front(kSPIRVName::InternalBuiltinPrefix);
509+
return true;
510+
}
511+
490512
// Check if a mangled type Name is unsigned
491513
bool isMangledTypeUnsigned(char Mangled) {
492514
return Mangled == 'h' /* uchar */
@@ -601,11 +623,11 @@ static std::string demangleBuiltinOpenCLTypeName(StringRef MangledStructName) {
601623
/// floating point type.
602624
static Type *parsePrimitiveType(LLVMContext &Ctx, StringRef Name) {
603625
return StringSwitch<Type *>(Name)
604-
.Cases("char", "signed char", "unsigned char", Type::getInt8Ty(Ctx))
605-
.Cases("short", "unsigned short", Type::getInt16Ty(Ctx))
606-
.Cases("int", "unsigned int", Type::getInt32Ty(Ctx))
607-
.Cases("long", "unsigned long", Type::getInt64Ty(Ctx))
608-
.Cases("long long", "unsigned long long", Type::getInt64Ty(Ctx))
626+
.Cases({"char", "signed char", "unsigned char"}, Type::getInt8Ty(Ctx))
627+
.Cases({"short", "unsigned short"}, Type::getInt16Ty(Ctx))
628+
.Cases({"int", "unsigned int"}, Type::getInt32Ty(Ctx))
629+
.Cases({"long", "unsigned long"}, Type::getInt64Ty(Ctx))
630+
.Cases({"long long", "unsigned long long"}, Type::getInt64Ty(Ctx))
609631
.Case("half", Type::getHalfTy(Ctx))
610632
.Case("float", Type::getFloatTy(Ctx))
611633
.Case("double", Type::getDoubleTy(Ctx))

0 commit comments

Comments
 (0)