Skip to content

[HLSL][SPIRV] Add vk::constant_id attribute. #143544

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -4993,6 +4993,14 @@ def HLSLVkExtBuiltinInput : InheritableAttr {
let Documentation = [HLSLVkExtBuiltinInputDocs];
}

def HLSLVkConstantId : InheritableAttr {
let Spellings = [CXX11<"vk", "constant_id">];
let Args = [IntArgument<"Id">];
let Subjects = SubjectList<[ExternalGlobalVar]>;
let LangOpts = [HLSL];
let Documentation = [VkConstantIdDocs];
}

def RandomizeLayout : InheritableAttr {
let Spellings = [GCC<"randomize_layout">];
let Subjects = SubjectList<[Record]>;
Expand Down
15 changes: 15 additions & 0 deletions clang/include/clang/Basic/AttrDocs.td
Original file line number Diff line number Diff line change
Expand Up @@ -8252,6 +8252,21 @@ and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
}];
}

def VkConstantIdDocs : Documentation {
let Category = DocCatFunction;
let Content = [{
The ``vk::constant_id`` attribute specifies the id for a SPIR-V specialization
constant. The attribute applies to const global scalar variables. The variable must be initialized with a C++11 constexpr.
In SPIR-V, the
variable will be replaced with an `OpSpecConstant` with the given id.
The syntax is:

.. code-block:: text

``[[vk::constant_id(<Id>)]] const T Name = <Init>``
}];
}

def RootSignatureDocs : Documentation {
let Category = DocCatFunction;
let Content = [{
Expand Down
13 changes: 13 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -5065,6 +5065,19 @@ def HLSLGroupMemoryBarrierWithGroupSync: LangBuiltin<"HLSL_LANG"> {
let Prototype = "void()";
}

class HLSLScalarTemplate
: Template<["bool", "char", "short", "int", "long long int",
"unsigned short", "unsigned int", "unsigned long long int",
"__fp16", "float", "double"],
["_bool", "_char", "_short", "_int", "_longlong", "_ushort",
"_uint", "_ulonglong", "_half", "_float", "_double"]>;

def HLSLGetSpirvSpecConstant : LangBuiltin<"HLSL_LANG">, HLSLScalarTemplate {
let Spellings = ["__builtin_get_spirv_spec_constant"];
let Attributes = [NoThrow, Const, Pure];
let Prototype = "T(unsigned int, T)";
}

// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
Expand Down
12 changes: 12 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12919,6 +12919,18 @@ def err_spirv_enum_not_int : Error<
def err_spirv_enum_not_valid : Error<
"invalid value for %select{storage class}0 argument">;

def err_specialization_const_lit_init
: Error<"variable with 'vk::constant_id' attribute cannot have an "
"initializer that is not a literal">;
def err_specialization_const_missing_initializer
: Error<
"variable with 'vk::constant_id' attribute must have an initializer">;
def err_specialization_const_missing_const
: Error<"variable with 'vk::constant_id' attribute must be const">;
def err_specialization_const_is_not_int_or_float
: Error<"variable with 'vk::constant_id' attribute must be an enum, bool, "
"integer, or floating point value">;

// errors of expect.with.probability
def err_probability_not_constant_float : Error<
"probability argument to __builtin_expect_with_probability must be constant "
Expand Down
5 changes: 4 additions & 1 deletion clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class SemaHLSL : public SemaBase {
HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
int Min, int Max, int Preferred,
int SpelledArgsCount);
HLSLVkConstantIdAttr *
mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id);
HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
llvm::Triple::EnvironmentType ShaderType);
HLSLParamModifierAttr *
Expand All @@ -122,6 +124,7 @@ class SemaHLSL : public SemaBase {
void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL);
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
Expand Down Expand Up @@ -156,7 +159,7 @@ class SemaHLSL : public SemaBase {
QualType getInoutParameterType(QualType Ty);

bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init);

bool handleInitialization(VarDecl *VDecl, Expr *&Init);
void deduceAddressSpace(VarDecl *Decl);

private:
Expand Down
74 changes: 74 additions & 0 deletions clang/lib/CodeGen/CGHLSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "CGBuiltin.h"
#include "CGHLSLRuntime.h"
#include "CodeGenFunction.h"

using namespace clang;
using namespace CodeGen;
Expand Down Expand Up @@ -214,6 +215,43 @@ static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
}
}

// Returns the mangled name for a builtin function that the SPIR-V backend
// will expand into a spec Constant.
static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,
ASTContext &Context) {
// The parameter types for our conceptual intrinsic function.
QualType ClangParamTypes[] = {Context.IntTy, SpecConstantType};

// Create a temporary FunctionDecl for the builtin fuction. It won't be
// added to the AST.
FunctionProtoType::ExtProtoInfo EPI;
QualType FnType =
Context.getFunctionType(SpecConstantType, ClangParamTypes, EPI);
DeclarationName FuncName = &Context.Idents.get("__spirv_SpecConstant");
FunctionDecl *FnDeclForMangling = FunctionDecl::Create(
Context, Context.getTranslationUnitDecl(), SourceLocation(),
SourceLocation(), FuncName, FnType, /*TSI=*/nullptr, SC_Extern);

// Attach the created parameter declarations to the function declaration.
SmallVector<ParmVarDecl *, 2> ParamDecls;
for (QualType ParamType : ClangParamTypes) {
ParmVarDecl *PD = ParmVarDecl::Create(
Context, FnDeclForMangling, SourceLocation(), SourceLocation(),
/*IdentifierInfo*/ nullptr, ParamType, /*TSI*/ nullptr, SC_None,
/*DefaultArg*/ nullptr);
ParamDecls.push_back(PD);
}
FnDeclForMangling->setParams(ParamDecls);

// Get the mangled name.
std::string Name;
llvm::raw_string_ostream MangledNameStream(Name);
MangleContext *Mangler = Context.createMangleContext();
Mangler->mangleName(FnDeclForMangling, MangledNameStream);
MangledNameStream.flush();
return Name;
}

Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
const CallExpr *E,
ReturnValueSlot ReturnValue) {
Expand Down Expand Up @@ -774,6 +812,42 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
return EmitRuntimeCall(
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
}
case Builtin::BI__builtin_get_spirv_spec_constant_bool:
case Builtin::BI__builtin_get_spirv_spec_constant_short:
case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
case Builtin::BI__builtin_get_spirv_spec_constant_int:
case Builtin::BI__builtin_get_spirv_spec_constant_uint:
case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
case Builtin::BI__builtin_get_spirv_spec_constant_half:
case Builtin::BI__builtin_get_spirv_spec_constant_float:
case Builtin::BI__builtin_get_spirv_spec_constant_double: {
Comment on lines +815 to +824
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to handle overloading the same way wave_read_lane_at does it?
This way there is a single Builtin::BI to handle here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, but I like this method because it becomes impossible to create calls to the builtin with invalid types. wave_read_lane_at is defined with a prototype: void(...). I believe we can add type checking in other places, but I feel this is more robust.

llvm::Function *SpecConstantFn = getSpecConstantFunction(E->getType());
llvm::Value *SpecId = EmitScalarExpr(E->getArg(0));
llvm::Value *DefaultVal = EmitScalarExpr(E->getArg(1));
llvm::Value *Args[] = {SpecId, DefaultVal};
return Builder.CreateCall(SpecConstantFn, Args);
}
}
return nullptr;
}

llvm::Function *clang::CodeGen::CodeGenFunction::getSpecConstantFunction(
const clang::QualType &SpecConstantType) {

// Find or create the declaration for the function.
llvm::Module *M = &CGM.getModule();
std::string MangledName =
getSpecConstantFunctionName(SpecConstantType, getContext());
llvm::Function *SpecConstantFn = M->getFunction(MangledName);

if (!SpecConstantFn) {
llvm::Type *IntType = ConvertType(getContext().IntTy);
llvm::Type *RetTy = ConvertType(SpecConstantType);
llvm::Type *ArgTypes[] = {IntType, RetTy};
llvm::FunctionType *FnTy = llvm::FunctionType::get(RetTy, ArgTypes, false);
SpecConstantFn = llvm::Function::Create(
FnTy, llvm::GlobalValue::ExternalLinkage, MangledName, M);
}
return SpecConstantFn;
}
6 changes: 6 additions & 0 deletions clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -4850,6 +4850,12 @@ class CodeGenFunction : public CodeGenTypeCache {
llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
ReturnValueSlot ReturnValue);

// Returns a builtin function that the SPIR-V backend will expand into a spec
// constant.
llvm::Function *
getSpecConstantFunction(const clang::QualType &SpecConstantType);

llvm::Value *EmitDirectXBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitSPIRVBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
Expand Down
14 changes: 14 additions & 0 deletions clang/lib/Sema/SemaDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2889,6 +2889,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
WS->getPreferred(),
WS->getSpelledArgsCount());
else if (const auto *CI = dyn_cast<HLSLVkConstantIdAttr>(Attr))
NewAttr = S.HLSL().mergeVkConstantIdAttr(D, *CI, CI->getId());
else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
else if (isa<SuppressAttr>(Attr))
Expand Down Expand Up @@ -13755,6 +13757,10 @@ void Sema::AddInitializerToDecl(Decl *RealDecl, Expr *Init, bool DirectInit) {
return;
}

if (getLangOpts().HLSL)
if (!HLSL().handleInitialization(VDecl, Init))
return;

// Get the decls type and save a reference for later, since
// CheckInitializerTypes may change it.
QualType DclT = VDecl->getType(), SavT = DclT;
Expand Down Expand Up @@ -14215,6 +14221,14 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
}
}

// HLSL variable with the `vk::constant_id` attribute must be initialized.
if (!Var->isInvalidDecl() && Var->hasAttr<HLSLVkConstantIdAttr>()) {
Diag(Var->getLocation(),
diag::err_specialization_const_missing_initializer);
Var->setInvalidDecl();
return;
}

if (!Var->isInvalidDecl() && RealDecl->hasAttr<LoaderUninitializedAttr>()) {
if (Var->getStorageClass() == SC_Extern) {
Diag(Var->getLocation(), diag::err_loader_uninitialized_extern_decl)
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Sema/SemaDeclAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7560,6 +7560,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
case ParsedAttr::AT_HLSLVkExtBuiltinInput:
S.HLSL().handleVkExtBuiltinInputAttr(D, AL);
break;
case ParsedAttr::AT_HLSLVkConstantId:
S.HLSL().handleVkConstantIdAttr(D, AL);
break;
case ParsedAttr::AT_HLSLSV_GroupThreadID:
S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
break;
Expand Down
Loading
Loading