Skip to content
8 changes: 8 additions & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -5023,6 +5023,14 @@ def HLSLVkExtBuiltinInput : InheritableAttr {
let Documentation = [HLSLVkExtBuiltinInputDocs];
}

def HLSLVkConstantId : InheritableAttr {
let Spellings = [CXX11<"vk", "constant_id">];
let Args = [IntArgument<"Id">];
let Subjects = SubjectList<[ExternalGlobalVar]>;
let LangOpts = [HLSL];
let Documentation = [VkConstantIdDocs];
}

def RandomizeLayout : InheritableAttr {
let Spellings = [GCC<"randomize_layout">];
let Subjects = SubjectList<[Record]>;
Expand Down
15 changes: 15 additions & 0 deletions clang/include/clang/Basic/AttrDocs.td
Original file line number Diff line number Diff line change
Expand Up @@ -8252,6 +8252,21 @@ and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
}];
}

def VkConstantIdDocs : Documentation {
let Category = DocCatFunction;
let Content = [{
The ``vk::constant_id`` attribute specifies the id for a SPIR-V specialization
constant. The attribute applies to const global scalar variables. The variable must be initialized with a C++11 constexpr.
In SPIR-V, the
variable will be replaced with an `OpSpecConstant` with the given id.
The syntax is:

.. code-block:: text

``[[vk::constant_id(<Id>)]] const T Name = <Init>``
}];
}

def RootSignatureDocs : Documentation {
let Category = DocCatFunction;
let Content = [{
Expand Down
13 changes: 13 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -5065,6 +5065,19 @@ def HLSLGroupMemoryBarrierWithGroupSync: LangBuiltin<"HLSL_LANG"> {
let Prototype = "void()";
}

class HLSLScalarTemplate
: Template<["bool", "char", "short", "int", "long long int",
"unsigned short", "unsigned int", "unsigned long long int",
"__fp16", "float", "double"],
["_bool", "_char", "_short", "_int", "_longlong", "_ushort",
"_uint", "_ulonglong", "_half", "_float", "_double"]>;

def HLSLGetSpirvSpecConstant : LangBuiltin<"HLSL_LANG">, HLSLScalarTemplate {
let Spellings = ["__builtin_get_spirv_spec_constant"];
let Attributes = [NoThrow, Const, Pure];
let Prototype = "T(unsigned int, T)";
}

// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
Expand Down
4 changes: 4 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12927,6 +12927,10 @@ def err_spirv_enum_not_int : Error<
def err_spirv_enum_not_valid : Error<
"invalid value for %select{storage class}0 argument">;

def err_specialization_const
: Error<"variable with 'vk::constant_id' attribute must be a const "
"int/float/enum/bool and be initialized with a literal">;

// errors of expect.with.probability
def err_probability_not_constant_float : Error<
"probability argument to __builtin_expect_with_probability must be constant "
Expand Down
5 changes: 4 additions & 1 deletion clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class SemaHLSL : public SemaBase {
HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
int Min, int Max, int Preferred,
int SpelledArgsCount);
HLSLVkConstantIdAttr *
mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id);
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
llvm::Triple::EnvironmentType ShaderType);
HLSLParamModifierAttr *
Expand All @@ -122,6 +124,7 @@ class SemaHLSL : public SemaBase {
void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL);
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
Expand Down Expand Up @@ -158,7 +161,7 @@ class SemaHLSL : public SemaBase {
QualType getInoutParameterType(QualType Ty);

bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init);

bool handleInitialization(VarDecl *VDecl, Expr *&Init);
void deduceAddressSpace(VarDecl *Decl);

private:
Expand Down
74 changes: 74 additions & 0 deletions clang/lib/CodeGen/CGHLSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "CGBuiltin.h"
#include "CGHLSLRuntime.h"
#include "CodeGenFunction.h"

using namespace clang;
using namespace CodeGen;
Expand Down Expand Up @@ -214,6 +215,43 @@ static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
}
}

// Returns the mangled name for a builtin function that the SPIR-V backend
// will expand into a spec Constant.
static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,
ASTContext &Context) {
// The parameter types for our conceptual intrinsic function.
QualType ClangParamTypes[] = {Context.IntTy, SpecConstantType};

// Create a temporary FunctionDecl for the builtin fuction. It won't be
// added to the AST.
FunctionProtoType::ExtProtoInfo EPI;
QualType FnType =
Context.getFunctionType(SpecConstantType, ClangParamTypes, EPI);
DeclarationName FuncName = &Context.Idents.get("__spirv_SpecConstant");
FunctionDecl *FnDeclForMangling = FunctionDecl::Create(
Context, Context.getTranslationUnitDecl(), SourceLocation(),
SourceLocation(), FuncName, FnType, /*TSI=*/nullptr, SC_Extern);

// Attach the created parameter declarations to the function declaration.
SmallVector<ParmVarDecl *, 2> ParamDecls;
for (QualType ParamType : ClangParamTypes) {
ParmVarDecl *PD = ParmVarDecl::Create(
Context, FnDeclForMangling, SourceLocation(), SourceLocation(),
/*IdentifierInfo*/ nullptr, ParamType, /*TSI*/ nullptr, SC_None,
/*DefaultArg*/ nullptr);
ParamDecls.push_back(PD);
}
FnDeclForMangling->setParams(ParamDecls);

// Get the mangled name.
std::string Name;
llvm::raw_string_ostream MangledNameStream(Name);
MangleContext *Mangler = Context.createMangleContext();
Mangler->mangleName(FnDeclForMangling, MangledNameStream);
MangledNameStream.flush();
return Name;
}

Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
const CallExpr *E,
ReturnValueSlot ReturnValue) {
Expand Down Expand Up @@ -773,6 +811,42 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
return EmitRuntimeCall(
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
}
case Builtin::BI__builtin_get_spirv_spec_constant_bool:
case Builtin::BI__builtin_get_spirv_spec_constant_short:
case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
case Builtin::BI__builtin_get_spirv_spec_constant_int:
case Builtin::BI__builtin_get_spirv_spec_constant_uint:
case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
case Builtin::BI__builtin_get_spirv_spec_constant_half:
case Builtin::BI__builtin_get_spirv_spec_constant_float:
case Builtin::BI__builtin_get_spirv_spec_constant_double: {
llvm::Function *SpecConstantFn = getSpecConstantFunction(E->getType());
llvm::Value *SpecId = EmitScalarExpr(E->getArg(0));
llvm::Value *DefaultVal = EmitScalarExpr(E->getArg(1));
llvm::Value *Args[] = {SpecId, DefaultVal};
return Builder.CreateCall(SpecConstantFn, Args);
}
}
return nullptr;
}

llvm::Function *clang::CodeGen::CodeGenFunction::getSpecConstantFunction(
const clang::QualType &SpecConstantType) {

// Find or create the declaration for the function.
llvm::Module *M = &CGM.getModule();
std::string MangledName =
getSpecConstantFunctionName(SpecConstantType, getContext());
llvm::Function *SpecConstantFn = M->getFunction(MangledName);

if (!SpecConstantFn) {
llvm::Type *IntType = ConvertType(getContext().IntTy);
llvm::Type *RetTy = ConvertType(SpecConstantType);
llvm::Type *ArgTypes[] = {IntType, RetTy};
llvm::FunctionType *FnTy = llvm::FunctionType::get(RetTy, ArgTypes, false);
SpecConstantFn = llvm::Function::Create(
FnTy, llvm::GlobalValue::ExternalLinkage, MangledName, M);
}
return SpecConstantFn;
}
6 changes: 6 additions & 0 deletions clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -4850,6 +4850,12 @@ class CodeGenFunction : public CodeGenTypeCache {
llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
ReturnValueSlot ReturnValue);

// Returns a builtin function that the SPIR-V backend will expand into a spec
// constant.
llvm::Function *
getSpecConstantFunction(const clang::QualType &SpecConstantType);

llvm::Value *EmitDirectXBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitSPIRVBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
Expand Down
13 changes: 13 additions & 0 deletions clang/lib/Sema/SemaDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2889,6 +2889,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
WS->getPreferred(),
WS->getSpelledArgsCount());
else if (const auto *CI = dyn_cast<HLSLVkConstantIdAttr>(Attr))
NewAttr = S.HLSL().mergeVkConstantIdAttr(D, *CI, CI->getId());
else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
else if (isa<SuppressAttr>(Attr))
Expand Down Expand Up @@ -13756,6 +13758,10 @@ void Sema::AddInitializerToDecl(Decl *RealDecl, Expr *Init, bool DirectInit) {
return;
}

if (getLangOpts().HLSL)
if (!HLSL().handleInitialization(VDecl, Init))
return;

// Get the decls type and save a reference for later, since
// CheckInitializerTypes may change it.
QualType DclT = VDecl->getType(), SavT = DclT;
Expand Down Expand Up @@ -14199,6 +14205,13 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
}
}

// HLSL variable with the `vk::constant_id` attribute must be initialized.
if (!Var->isInvalidDecl() && Var->hasAttr<HLSLVkConstantIdAttr>()) {
Diag(Var->getLocation(), diag::err_specialization_const);
Var->setInvalidDecl();
return;
}

if (!Var->isInvalidDecl() && RealDecl->hasAttr<LoaderUninitializedAttr>()) {
if (Var->getStorageClass() == SC_Extern) {
Diag(Var->getLocation(), diag::err_loader_uninitialized_extern_decl)
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Sema/SemaDeclAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7590,6 +7590,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
case ParsedAttr::AT_HLSLVkExtBuiltinInput:
S.HLSL().handleVkExtBuiltinInputAttr(D, AL);
break;
case ParsedAttr::AT_HLSLVkConstantId:
S.HLSL().handleVkConstantIdAttr(D, AL);
break;
case ParsedAttr::AT_HLSLSV_GroupThreadID:
S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
break;
Expand Down
Loading
Loading