Skip to content

[HLSL][SPIRV] Add vk::constant_id attribute. #143180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -4993,6 +4993,14 @@ def HLSLVkExtBuiltinInput : InheritableAttr {
let Documentation = [HLSLVkExtBuiltinInputDocs];
}

def HLSLVkConstantId : InheritableAttr {
let Spellings = [CXX11<"vk", "constant_id">];
let Args = [IntArgument<"Id">];
let Subjects = SubjectList<[ExternalGlobalVar]>;
let LangOpts = [HLSL];
let Documentation = [VkConstantIdDocs];
}

def RandomizeLayout : InheritableAttr {
let Spellings = [GCC<"randomize_layout">];
let Subjects = SubjectList<[Record]>;
Expand Down
15 changes: 15 additions & 0 deletions clang/include/clang/Basic/AttrDocs.td
Original file line number Diff line number Diff line change
Expand Up @@ -8247,6 +8247,21 @@ and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
}];
}

def VkConstantIdDocs : Documentation {
let Category = DocCatFunction;
let Content = [{
The ``vk::constant_id`` attribute specify the id for a SPIR-V specialization
constant. The attribute applies to const global scalar variables. The variable must be initialized with a C++11 constexpr.
In SPIR-V, the
variable will be replaced with an `OpSpecConstant` with the given id.
The syntax is:

.. code-block:: text

``[[vk::constant_id(<Id>)]] const T Name = <Init>``
}];
}

def RootSignatureDocs : Documentation {
let Category = DocCatFunction;
let Content = [{
Expand Down
13 changes: 13 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -5059,6 +5059,19 @@ def HLSLGroupMemoryBarrierWithGroupSync: LangBuiltin<"HLSL_LANG"> {
let Prototype = "void()";
}

class HLSLScalarTemplate
: Template<["bool", "char", "short", "int", "long long int",
"unsigned short", "unsigned int", "unsigned long long int",
"__fp16", "float", "double"],
["_bool", "_char", "_short", "_int", "_longlong", "_ushort",
"_uint", "_ulonglong", "_half", "_float", "_double"]>;

def HLSLGetSpirvSpecConstant : LangBuiltin<"HLSL_LANG">, HLSLScalarTemplate {
let Spellings = ["__builtin_get_spirv_spec_constant"];
let Attributes = [NoThrow, Const, Pure];
let Prototype = "T(unsigned int, T)";
}

// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
Expand Down
12 changes: 12 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12904,6 +12904,18 @@ def err_spirv_enum_not_int : Error<
def err_spirv_enum_not_valid : Error<
"invalid value for %select{storage class}0 argument">;

def err_specialization_const_lit_init
: Error<"variable with 'vk::constant_id' attribute cannot have an "
"initializer that is not a constexpr">;
def err_specialization_const_missing_initializer
: Error<
"variable with 'vk::constant_id' attribute must have an initializer">;
def err_specialization_const_missing_const
: Error<"variable with 'vk::constant_id' attribute must be const">;
def err_specialization_const_is_not_int_or_float
: Error<"variable with 'vk::constant_id' attribute must be an enum, bool, "
"integer, or floating point value">;

// errors of expect.with.probability
def err_probability_not_constant_float : Error<
"probability argument to __builtin_expect_with_probability must be constant "
Expand Down
5 changes: 4 additions & 1 deletion clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class SemaHLSL : public SemaBase {
HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
int Min, int Max, int Preferred,
int SpelledArgsCount);
HLSLVkConstantIdAttr *
mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id);
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
llvm::Triple::EnvironmentType ShaderType);
HLSLParamModifierAttr *
Expand All @@ -122,6 +124,7 @@ class SemaHLSL : public SemaBase {
void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL);
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
Expand Down Expand Up @@ -156,7 +159,7 @@ class SemaHLSL : public SemaBase {
QualType getInoutParameterType(QualType Ty);

bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init);

bool handleInitialization(VarDecl *VDecl, Expr *&Init);
void deduceAddressSpace(VarDecl *Decl);

private:
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/Basic/Attributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ getScopeFromNormalizedScopeName(StringRef ScopeName) {
.Case("vk", AttributeCommonInfo::Scope::VK)
.Case("msvc", AttributeCommonInfo::Scope::MSVC)
.Case("omp", AttributeCommonInfo::Scope::OMP)
.Case("riscv", AttributeCommonInfo::Scope::RISCV);
.Case("riscv", AttributeCommonInfo::Scope::RISCV)
.Case("vk", AttributeCommonInfo::Scope::HLSL);
}

unsigned AttributeCommonInfo::calculateAttributeSpellingListIndex() const {
Expand Down
17 changes: 17 additions & 0 deletions clang/lib/CodeGen/CGHLSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,23 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
return EmitRuntimeCall(
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
}
case Builtin::BI__builtin_get_spirv_spec_constant_bool:
case Builtin::BI__builtin_get_spirv_spec_constant_short:
case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
case Builtin::BI__builtin_get_spirv_spec_constant_int:
case Builtin::BI__builtin_get_spirv_spec_constant_uint:
case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
case Builtin::BI__builtin_get_spirv_spec_constant_half:
case Builtin::BI__builtin_get_spirv_spec_constant_float:
case Builtin::BI__builtin_get_spirv_spec_constant_double: {
assert(CGM.getTarget().getTriple().isSPIRV() && "SPIR-V only");
Intrinsic::ID ID = Intrinsic::spv_get_specialization_constant;
llvm::Type *T = CGM.getTypes().ConvertType(E->getType());
auto F = Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID, {T});
return EmitRuntimeCall(
F, {EmitScalarExpr(E->getArg(0)), EmitScalarExpr(E->getArg(1))});
}
}
return nullptr;
}
14 changes: 14 additions & 0 deletions clang/lib/Sema/SemaDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2889,6 +2889,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
WS->getPreferred(),
WS->getSpelledArgsCount());
else if (const auto *CI = dyn_cast<HLSLVkConstantIdAttr>(Attr))
NewAttr = S.HLSL().mergeVkConstantIdAttr(D, *CI, CI->getId());
else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
else if (isa<SuppressAttr>(Attr))
Expand Down Expand Up @@ -13755,6 +13757,10 @@ void Sema::AddInitializerToDecl(Decl *RealDecl, Expr *Init, bool DirectInit) {
return;
}

if (getLangOpts().HLSL)
if (!HLSL().handleInitialization(VDecl, Init))
return;

// Get the decls type and save a reference for later, since
// CheckInitializerTypes may change it.
QualType DclT = VDecl->getType(), SavT = DclT;
Expand Down Expand Up @@ -14215,6 +14221,14 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
}
}

// HLSL variable with the `vk::constant_id` attribute must be initialized.
if (!Var->isInvalidDecl() && Var->hasAttr<HLSLVkConstantIdAttr>()) {
Diag(Var->getLocation(),
diag::err_specialization_const_missing_initializer);
Var->setInvalidDecl();
return;
}

if (!Var->isInvalidDecl() && RealDecl->hasAttr<LoaderUninitializedAttr>()) {
if (Var->getStorageClass() == SC_Extern) {
Diag(Var->getLocation(), diag::err_loader_uninitialized_extern_decl)
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Sema/SemaDeclAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7560,6 +7560,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
case ParsedAttr::AT_HLSLVkExtBuiltinInput:
S.HLSL().handleVkExtBuiltinInputAttr(D, AL);
break;
case ParsedAttr::AT_HLSLVkConstantId:
S.HLSL().handleVkConstantIdAttr(D, AL);
break;
case ParsedAttr::AT_HLSLSV_GroupThreadID:
S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
break;
Expand Down
120 changes: 119 additions & 1 deletion clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,40 @@ static ResourceClass getResourceClass(RegisterType RT) {
llvm_unreachable("unexpected RegisterType value");
}

static Builtin::ID getSpecConstBuiltinId(QualType Type) {
const auto *BT = dyn_cast<BuiltinType>(Type);
if (!BT) {
if (!Type->isEnumeralType())
return Builtin::NotBuiltin;
return Builtin::BI__builtin_get_spirv_spec_constant_int;
}

switch (BT->getKind()) {
case BuiltinType::Bool:
return Builtin::BI__builtin_get_spirv_spec_constant_bool;
case BuiltinType::Short:
return Builtin::BI__builtin_get_spirv_spec_constant_short;
case BuiltinType::Int:
return Builtin::BI__builtin_get_spirv_spec_constant_int;
case BuiltinType::LongLong:
return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
case BuiltinType::UShort:
return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
case BuiltinType::UInt:
return Builtin::BI__builtin_get_spirv_spec_constant_uint;
case BuiltinType::ULongLong:
return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
case BuiltinType::Half:
return Builtin::BI__builtin_get_spirv_spec_constant_half;
case BuiltinType::Float:
return Builtin::BI__builtin_get_spirv_spec_constant_float;
case BuiltinType::Double:
return Builtin::BI__builtin_get_spirv_spec_constant_double;
default:
return Builtin::NotBuiltin;
}
}

DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
ResourceClass ResClass) {
assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
Expand Down Expand Up @@ -607,6 +641,41 @@ HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
return Result;
}

HLSLVkConstantIdAttr *
SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
int Id) {

auto &TargetInfo = getASTContext().getTargetInfo();
if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
return nullptr;
}

auto *VD = cast<VarDecl>(D);

if (getSpecConstBuiltinId(VD->getType()) == Builtin::NotBuiltin) {
Diag(VD->getLocation(), diag::err_specialization_const_is_not_int_or_float);
return nullptr;
}

if (!VD->getType().isConstQualified()) {
Diag(VD->getLocation(), diag::err_specialization_const_missing_const);
return nullptr;
}

if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
if (CI->getId() != Id) {
Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
Diag(AL.getLoc(), diag::note_conflicting_attribute);
}
return nullptr;
}

HLSLVkConstantIdAttr *Result =
::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
return Result;
}

HLSLShaderAttr *
SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
llvm::Triple::EnvironmentType ShaderType) {
Expand Down Expand Up @@ -1125,6 +1194,15 @@ void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
}

void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
uint32_t Id;
if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
return;
HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
if (NewAttr)
D->addAttr(NewAttr);
}

bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
const auto *VT = T->getAs<VectorType>();

Expand Down Expand Up @@ -3154,6 +3232,7 @@ static bool IsDefaultBufferConstantDecl(VarDecl *VD) {
return VD->getDeclContext()->isTranslationUnit() &&
QT.getAddressSpace() == LangAS::Default &&
VD->getStorageClass() != SC_Static &&
!VD->hasAttr<HLSLVkConstantIdAttr>() &&
!isInvalidConstantBufferLeafElementType(QT.getTypePtr());
}

Expand Down Expand Up @@ -3221,7 +3300,8 @@ void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
const Type *VarType = VD->getType().getTypePtr();
while (VarType->isArrayType())
VarType = VarType->getArrayElementTypeNoTypeQual();
if (VarType->isHLSLResourceRecord()) {
if (VarType->isHLSLResourceRecord() ||
VD->hasAttr<HLSLVkConstantIdAttr>()) {
// Make the variable for resources static. The global externally visible
// storage is accessed through the handle, which is a member. The variable
// itself is not externally visible.
Expand Down Expand Up @@ -3644,3 +3724,41 @@ bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
Init->updateInit(Ctx, I, NewInit->getInit(I));
return true;
}

bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
const HLSLVkConstantIdAttr *ConstIdAttr =
VDecl->getAttr<HLSLVkConstantIdAttr>();
if (!ConstIdAttr)
return true;

ASTContext &Context = SemaRef.getASTContext();

APValue InitValue;
if (!Init->isCXX11ConstantExpr(Context, &InitValue)) {
Diag(VDecl->getLocation(), diag::err_specialization_const_lit_init);
VDecl->setInvalidDecl();
return false;
}

Builtin::ID BID = getSpecConstBuiltinId(VDecl->getType());

// Argument 1: The ID from the attribute
int ConstantID = ConstIdAttr->getId();
llvm::APInt IDVal(Context.getIntWidth(Context.IntTy), ConstantID);
Expr *IdExpr = IntegerLiteral::Create(Context, IDVal, Context.IntTy,
ConstIdAttr->getLocation());

SmallVector<Expr *, 2> Args = {IdExpr, Init};
Expr *C = SemaRef.BuildBuiltinCallExpr(Init->getExprLoc(), BID, Args);
if (C->getType()->getCanonicalTypeUnqualified() !=
VDecl->getType()->getCanonicalTypeUnqualified()) {
C = SemaRef
.BuildCStyleCastExpr(SourceLocation(),
Context.getTrivialTypeSourceInfo(
Init->getType(), Init->getExprLoc()),
SourceLocation(), C)
.get();
}
Init = C;
return true;
}
Loading
Loading