Skip to content

[HLSL][SPIRV] Add vk::constant_id attribute. #143544

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

s-perron
Copy link
Contributor

@s-perron s-perron commented Jun 10, 2025

The vk::constant_id attribute is used to indicate that a global const variable
represents a specialization constant in SPIR-V. This PR adds this
attribute to clang.

The documentation for the attribute is here.

The strategy is to to modify the initializer to get the value of a
specialize constant for a builtin defined in the SPIR-V backend.

Implements llvm/wg-hlsl#287

Fixes #142448

The vk::constant_id attribute is used to indicate that a global const variable
represents a specialization constant in SPIR-V. This PR adds this
attribute to clang.

The documetation for the attribute is [here](https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/SPIR-V.rst#specialization-constants).

The strategy is to to modify the initializer to get the value of a
specialize constant for a builtin defined in the SPIR-V backend.
@s-perron s-perron requested a review from Keenuts June 10, 2025 14:45
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:codegen IR generation bugs: mangling, exceptions, etc. HLSL HLSL Language Support labels Jun 10, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 10, 2025

@llvm/pr-subscribers-clang

Author: Steven Perron (s-perron)

Changes

The vk::constant_id attribute is used to indicate that a global const variable
represents a specialization constant in SPIR-V. This PR adds this
attribute to clang.

The documetation for the attribute is here.

The strategy is to to modify the initializer to get the value of a
specialize constant for a builtin defined in the SPIR-V backend.


Patch is 39.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/143544.diff

16 Files Affected:

  • (modified) clang/include/clang/Basic/Attr.td (+8)
  • (modified) clang/include/clang/Basic/AttrDocs.td (+15)
  • (modified) clang/include/clang/Basic/Builtins.td (+13)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+12)
  • (modified) clang/include/clang/Sema/SemaHLSL.h (+4-1)
  • (modified) clang/lib/Basic/Attributes.cpp (+2-1)
  • (modified) clang/lib/CodeGen/CGHLSLBuiltins.cpp (+72)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+11)
  • (modified) clang/lib/Sema/SemaDecl.cpp (+14)
  • (modified) clang/lib/Sema/SemaDeclAttr.cpp (+3)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+119-1)
  • (added) clang/test/AST/HLSL/vk.spec-constant.usage.hlsl (+130)
  • (renamed) clang/test/CodeGenHLSL/vk-features/SpirvType.alignment.hlsl ()
  • (renamed) clang/test/CodeGenHLSL/vk-features/SpirvType.hlsl ()
  • (added) clang/test/CodeGenHLSL/vk-features/vk.spec-constant.hlsl (+210)
  • (added) clang/test/SemaHLSL/vk.spec-constant.error.hlsl (+37)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index f889e41c8699f..d3f39de6a3e85 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4993,6 +4993,14 @@ def HLSLVkExtBuiltinInput : InheritableAttr {
   let Documentation = [HLSLVkExtBuiltinInputDocs];
 }
 
+def HLSLVkConstantId : InheritableAttr {
+  let Spellings = [CXX11<"vk", "constant_id">];
+  let Args = [IntArgument<"Id">];
+  let Subjects = SubjectList<[ExternalGlobalVar]>;
+  let LangOpts = [HLSL];
+  let Documentation = [VkConstantIdDocs];
+}
+
 def RandomizeLayout : InheritableAttr {
   let Spellings = [GCC<"randomize_layout">];
   let Subjects = SubjectList<[Record]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index ea3c43f38d9fe..b3eafb79c5d4a 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -8252,6 +8252,21 @@ and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
   }];
 }
 
+def VkConstantIdDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+The ``vk::constant_id`` attribute specify the id for a SPIR-V specialization
+constant. The attribute applies to const global scalar variables. The variable must be initialized with a C++11 constexpr.
+In SPIR-V, the
+variable will be replaced with an `OpSpecConstant` with the given id.
+The syntax is:
+
+.. code-block:: text
+
+  ``[[vk::constant_id(<Id>)]] const T Name = <Init>``
+}];
+}
+
 def RootSignatureDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 68cd3d790e78a..d65b3a5d2f447 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5065,6 +5065,19 @@ def HLSLGroupMemoryBarrierWithGroupSync: LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void()";
 }
 
+class HLSLScalarTemplate
+    : Template<["bool", "char", "short", "int", "long long int",
+                "unsigned short", "unsigned int", "unsigned long long int",
+                "__fp16", "float", "double"],
+               ["_bool", "_char", "_short", "_int", "_longlong", "_ushort",
+                "_uint", "_ulonglong", "_half", "_float", "_double"]>;
+
+def HLSLGetSpirvSpecConstant : LangBuiltin<"HLSL_LANG">, HLSLScalarTemplate {
+  let Spellings = ["__builtin_get_spirv_spec_constant"];
+  let Attributes = [NoThrow, Const, Pure];
+  let Prototype = "T(unsigned int, T)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 1f283b776a02c..23a490225dd19 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12919,6 +12919,18 @@ def err_spirv_enum_not_int : Error<
 def err_spirv_enum_not_valid : Error<
    "invalid value for %select{storage class}0 argument">;
 
+def err_specialization_const_lit_init
+    : Error<"variable with 'vk::constant_id' attribute cannot have an "
+            "initializer that is not a constexpr">;
+def err_specialization_const_missing_initializer
+    : Error<
+          "variable with 'vk::constant_id' attribute must have an initializer">;
+def err_specialization_const_missing_const
+    : Error<"variable with 'vk::constant_id' attribute must be const">;
+def err_specialization_const_is_not_int_or_float
+    : Error<"variable with 'vk::constant_id' attribute must be an enum, bool, "
+            "integer, or floating point value">;
+
 // errors of expect.with.probability
 def err_probability_not_constant_float : Error<
    "probability argument to __builtin_expect_with_probability must be constant "
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 66d09f49680be..099d9c35684e8 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -98,6 +98,8 @@ class SemaHLSL : public SemaBase {
   HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
                                       int Min, int Max, int Preferred,
                                       int SpelledArgsCount);
+  HLSLVkConstantIdAttr *
+  mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id);
   HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                                   llvm::Triple::EnvironmentType ShaderType);
   HLSLParamModifierAttr *
@@ -122,6 +124,7 @@ class SemaHLSL : public SemaBase {
   void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL);
   void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
   void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
+  void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
@@ -156,7 +159,7 @@ class SemaHLSL : public SemaBase {
   QualType getInoutParameterType(QualType Ty);
 
   bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init);
-
+  bool handleInitialization(VarDecl *VDecl, Expr *&Init);
   void deduceAddressSpace(VarDecl *Decl);
 
 private:
diff --git a/clang/lib/Basic/Attributes.cpp b/clang/lib/Basic/Attributes.cpp
index 905046685934b..5f74365f0ed00 100644
--- a/clang/lib/Basic/Attributes.cpp
+++ b/clang/lib/Basic/Attributes.cpp
@@ -213,7 +213,8 @@ getScopeFromNormalizedScopeName(StringRef ScopeName) {
       .Case("vk", AttributeCommonInfo::Scope::VK)
       .Case("msvc", AttributeCommonInfo::Scope::MSVC)
       .Case("omp", AttributeCommonInfo::Scope::OMP)
-      .Case("riscv", AttributeCommonInfo::Scope::RISCV);
+      .Case("riscv", AttributeCommonInfo::Scope::RISCV)
+      .Case("vk", AttributeCommonInfo::Scope::HLSL);
 }
 
 unsigned AttributeCommonInfo::calculateAttributeSpellingListIndex() const {
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index abebc201808b0..74b5f92638600 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -12,6 +12,7 @@
 
 #include "CGBuiltin.h"
 #include "CGHLSLRuntime.h"
+#include "CodeGenFunction.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -774,6 +775,77 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     return EmitRuntimeCall(
         Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
   }
+  case Builtin::BI__builtin_get_spirv_spec_constant_bool:
+  case Builtin::BI__builtin_get_spirv_spec_constant_short:
+  case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
+  case Builtin::BI__builtin_get_spirv_spec_constant_int:
+  case Builtin::BI__builtin_get_spirv_spec_constant_uint:
+  case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
+  case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
+  case Builtin::BI__builtin_get_spirv_spec_constant_half:
+  case Builtin::BI__builtin_get_spirv_spec_constant_float:
+  case Builtin::BI__builtin_get_spirv_spec_constant_double: {
+    llvm::Function *SpecConstantFn = getSpecConstantFunction(E->getType());
+    llvm::Value *SpecId = EmitScalarExpr(E->getArg(0));
+    llvm::Value *DefaultVal = EmitScalarExpr(E->getArg(1));
+    llvm::Value *Args[] = {SpecId, DefaultVal};
+    return Builder.CreateCall(SpecConstantFn, Args);
+  }
   }
   return nullptr;
 }
+
+llvm::Function *clang::CodeGen::CodeGenFunction::getSpecConstantFunction(
+    const clang::QualType &SpecConstantType) {
+
+  // Find or create the declaration for the function.
+  llvm::Module *M = &CGM.getModule();
+  std::string MangledName = getSpecConstantFunctionName(SpecConstantType);
+  llvm::Function *SpecConstantFn = M->getFunction(MangledName);
+
+  if (!SpecConstantFn) {
+    llvm::Type *IntType = ConvertType(getContext().IntTy);
+    llvm::Type *RetTy = ConvertType(SpecConstantType);
+    llvm::Type *ArgTypes[] = {IntType, RetTy};
+    llvm::FunctionType *FnTy = llvm::FunctionType::get(RetTy, ArgTypes, false);
+    SpecConstantFn = llvm::Function::Create(
+        FnTy, llvm::GlobalValue::ExternalLinkage, MangledName, M);
+  }
+  return SpecConstantFn;
+}
+
+std::string clang::CodeGen::CodeGenFunction::getSpecConstantFunctionName(
+    const clang::QualType &SpecConstantType) {
+  // The parameter types for our conceptual intrinsic function.
+  ASTContext &Context = getContext();
+  QualType ClangParamTypes[] = {Context.IntTy, SpecConstantType};
+
+  // Create a temporary FunctionDecl for the builtin fuction. It won't be
+  // added to the AST.
+  FunctionProtoType::ExtProtoInfo EPI;
+  QualType FnType =
+      Context.getFunctionType(SpecConstantType, ClangParamTypes, EPI);
+  DeclarationName FuncName = &Context.Idents.get("__spirv_SpecConstant");
+  FunctionDecl *FnDeclForMangling = FunctionDecl::Create(
+      Context, Context.getTranslationUnitDecl(), SourceLocation(),
+      SourceLocation(), FuncName, FnType, /*TSI=*/nullptr, SC_Extern);
+
+  // Attach the created parameter declarations to the function declaration.
+  SmallVector<ParmVarDecl *, 2> ParamDecls;
+  for (QualType ParamType : ClangParamTypes) {
+    ParmVarDecl *PD = ParmVarDecl::Create(
+        Context, FnDeclForMangling, SourceLocation(), SourceLocation(),
+        /*IdentifierInfo*/ nullptr, ParamType, /*TSI*/ nullptr, SC_None,
+        /*DefaultArg*/ nullptr);
+    ParamDecls.push_back(PD);
+  }
+  FnDeclForMangling->setParams(ParamDecls);
+
+  // Get the mangled name.
+  std::string Name;
+  llvm::raw_string_ostream MangledNameStream(Name);
+  MangleContext *Mangler = Context.createMangleContext();
+  Mangler->mangleName(FnDeclForMangling, MangledNameStream);
+  MangledNameStream.flush();
+  return Name;
+}
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index a5ab9df01dba9..badc9feeb8c2b 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4850,6 +4850,17 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
                                    ReturnValueSlot ReturnValue);
+
+  // Returns a builtin function that the SPIR-V backend will expand into a spec
+  // constant.
+  llvm::Function *
+  getSpecConstantFunction(const clang::QualType &SpecConstantType);
+
+  // Returns the mangled name for a builtin function that the SPIR-V backend
+  // will expand into a spec Constant.
+  std::string
+  getSpecConstantFunctionName(const clang::QualType &SpecConstantType);
+
   llvm::Value *EmitDirectXBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitSPIRVBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index bbd63372c168b..9d09eec26f7a2 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2889,6 +2889,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
     NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
                                          WS->getPreferred(),
                                          WS->getSpelledArgsCount());
+  else if (const auto *CI = dyn_cast<HLSLVkConstantIdAttr>(Attr))
+    NewAttr = S.HLSL().mergeVkConstantIdAttr(D, *CI, CI->getId());
   else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
     NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
   else if (isa<SuppressAttr>(Attr))
@@ -13755,6 +13757,10 @@ void Sema::AddInitializerToDecl(Decl *RealDecl, Expr *Init, bool DirectInit) {
     return;
   }
 
+  if (getLangOpts().HLSL)
+    if (!HLSL().handleInitialization(VDecl, Init))
+      return;
+
   // Get the decls type and save a reference for later, since
   // CheckInitializerTypes may change it.
   QualType DclT = VDecl->getType(), SavT = DclT;
@@ -14215,6 +14221,14 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
       }
     }
 
+    // HLSL variable with the `vk::constant_id` attribute must be initialized.
+    if (!Var->isInvalidDecl() && Var->hasAttr<HLSLVkConstantIdAttr>()) {
+      Diag(Var->getLocation(),
+           diag::err_specialization_const_missing_initializer);
+      Var->setInvalidDecl();
+      return;
+    }
+
     if (!Var->isInvalidDecl() && RealDecl->hasAttr<LoaderUninitializedAttr>()) {
       if (Var->getStorageClass() == SC_Extern) {
         Diag(Var->getLocation(), diag::err_loader_uninitialized_extern_decl)
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index da0e3265767d8..e49bdda1a402b 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7560,6 +7560,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_HLSLVkExtBuiltinInput:
     S.HLSL().handleVkExtBuiltinInputAttr(D, AL);
     break;
+  case ParsedAttr::AT_HLSLVkConstantId:
+    S.HLSL().handleVkConstantIdAttr(D, AL);
+    break;
   case ParsedAttr::AT_HLSLSV_GroupThreadID:
     S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
     break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 9065cc5a1d4a5..5764507f35882 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -119,6 +119,40 @@ static ResourceClass getResourceClass(RegisterType RT) {
   llvm_unreachable("unexpected RegisterType value");
 }
 
+static Builtin::ID getSpecConstBuiltinId(QualType Type) {
+  const auto *BT = dyn_cast<BuiltinType>(Type);
+  if (!BT) {
+    if (!Type->isEnumeralType())
+      return Builtin::NotBuiltin;
+    return Builtin::BI__builtin_get_spirv_spec_constant_int;
+  }
+
+  switch (BT->getKind()) {
+  case BuiltinType::Bool:
+    return Builtin::BI__builtin_get_spirv_spec_constant_bool;
+  case BuiltinType::Short:
+    return Builtin::BI__builtin_get_spirv_spec_constant_short;
+  case BuiltinType::Int:
+    return Builtin::BI__builtin_get_spirv_spec_constant_int;
+  case BuiltinType::LongLong:
+    return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
+  case BuiltinType::UShort:
+    return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
+  case BuiltinType::UInt:
+    return Builtin::BI__builtin_get_spirv_spec_constant_uint;
+  case BuiltinType::ULongLong:
+    return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
+  case BuiltinType::Half:
+    return Builtin::BI__builtin_get_spirv_spec_constant_half;
+  case BuiltinType::Float:
+    return Builtin::BI__builtin_get_spirv_spec_constant_float;
+  case BuiltinType::Double:
+    return Builtin::BI__builtin_get_spirv_spec_constant_double;
+  default:
+    return Builtin::NotBuiltin;
+  }
+}
+
 DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
                                                       ResourceClass ResClass) {
   assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
@@ -607,6 +641,41 @@ HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
   return Result;
 }
 
+HLSLVkConstantIdAttr *
+SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
+                                int Id) {
+
+  auto &TargetInfo = getASTContext().getTargetInfo();
+  if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
+    Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
+    return nullptr;
+  }
+
+  auto *VD = cast<VarDecl>(D);
+
+  if (getSpecConstBuiltinId(VD->getType()) == Builtin::NotBuiltin) {
+    Diag(VD->getLocation(), diag::err_specialization_const_is_not_int_or_float);
+    return nullptr;
+  }
+
+  if (!VD->getType().isConstQualified()) {
+    Diag(VD->getLocation(), diag::err_specialization_const_missing_const);
+    return nullptr;
+  }
+
+  if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
+    if (CI->getId() != Id) {
+      Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+
+  HLSLVkConstantIdAttr *Result =
+      ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
+  return Result;
+}
+
 HLSLShaderAttr *
 SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                           llvm::Triple::EnvironmentType ShaderType) {
@@ -1125,6 +1194,15 @@ void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
                  HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
 }
 
+void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
+  uint32_t Id;
+  if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
+    return;
+  HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
+  if (NewAttr)
+    D->addAttr(NewAttr);
+}
+
 bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
   const auto *VT = T->getAs<VectorType>();
 
@@ -3154,6 +3232,7 @@ static bool IsDefaultBufferConstantDecl(VarDecl *VD) {
   return VD->getDeclContext()->isTranslationUnit() &&
          QT.getAddressSpace() == LangAS::Default &&
          VD->getStorageClass() != SC_Static &&
+         !VD->hasAttr<HLSLVkConstantIdAttr>() &&
          !isInvalidConstantBufferLeafElementType(QT.getTypePtr());
 }
 
@@ -3221,7 +3300,8 @@ void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
     const Type *VarType = VD->getType().getTypePtr();
     while (VarType->isArrayType())
       VarType = VarType->getArrayElementTypeNoTypeQual();
-    if (VarType->isHLSLResourceRecord()) {
+    if (VarType->isHLSLResourceRecord() ||
+        VD->hasAttr<HLSLVkConstantIdAttr>()) {
       // Make the variable for resources static. The global externally visible
       // storage is accessed through the handle, which is a member. The variable
       // itself is not externally visible.
@@ -3644,3 +3724,41 @@ bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
     Init->updateInit(Ctx, I, NewInit->getInit(I));
   return true;
 }
+
+bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
+  const HLSLVkConstantIdAttr *ConstIdAttr =
+      VDecl->getAttr<HLSLVkConstantIdAttr>();
+  if (!ConstIdAttr)
+    return true;
+
+  ASTContext &Context = SemaRef.getASTContext();
+
+  APValue InitValue;
+  if (!Init->isCXX11ConstantExpr(Context, &InitValue)) {
+    Diag(VDecl->getLocation(), diag::err_specialization_const_lit_init);
+    VDecl->setInvalidDecl();
+    return false;
+  }
+
+  Builtin::ID BID = getSpecConstBuiltinId(VDecl->getType());
+
+  // Argument 1: The ID from the attribute
+  int ConstantID = ConstIdAttr->getId();
+  llvm::APInt IDVal(Context.getIntWidth(Context.IntTy), ConstantID);
+  Expr *IdExpr = IntegerLiteral::Create(Context, IDVal, Context.IntTy,
+                                        ConstIdAttr->getLocation());
+
+  SmallVector<Expr *, 2> Args = {IdExpr, Init};
+  Expr *C = SemaRef.BuildBuiltinCallExpr(Init->getExprLoc(), BID, Args);
+  if (C->getType()->getCanonicalTypeUnqualified() !=
+      VDecl->getType()->getCanonicalTypeUnqualified()) {
+    C = SemaRef
+            .BuildCStyleCastExpr(SourceLocation(),
+                                 Context.getTrivialTypeSourceInfo(
+                                     Init->getType(), Init->getExprLoc()),
+                                 SourceLocation(), C)
+            .get();
+  }
+  Init = C;
+  return true;
+}
diff --git a/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl
new file mode 100644
index 0000000000000..c0955c1ea7b43
--- /dev/null
+++ b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl
@@ -0,0 +1,130 @@
+// RUN: %clang_cc1 -finclude-default-header -triple spirv-unknown-vulkan-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+
+// CHECK: VarDecl {{.*}} bool_const 'const hlsl_private bool' static cinit
+// CHECK-NEXT: CallExpr {{.*}} 'bool'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'bool (*)(unsigned int, bool) noexcept' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'bool (unsigned int, bool) noexcept' lvalue Function {{.*}} '__builtin_get_spirv_spec_constant_bool' 'bool (unsigned int, bool) noexcept'
+// CHECK-NEXT: ImplicitCastExpr {{.*}...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 10, 2025

@llvm/pr-subscribers-clang-codegen

Author: Steven Perron (s-perron)

Changes

The vk::constant_id attribute is used to indicate that a global const variable
represents a specialization constant in SPIR-V. This PR adds this
attribute to clang.

The documetation for the attribute is here.

The strategy is to to modify the initializer to get the value of a
specialize constant for a builtin defined in the SPIR-V backend.


Patch is 39.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/143544.diff

16 Files Affected:

  • (modified) clang/include/clang/Basic/Attr.td (+8)
  • (modified) clang/include/clang/Basic/AttrDocs.td (+15)
  • (modified) clang/include/clang/Basic/Builtins.td (+13)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+12)
  • (modified) clang/include/clang/Sema/SemaHLSL.h (+4-1)
  • (modified) clang/lib/Basic/Attributes.cpp (+2-1)
  • (modified) clang/lib/CodeGen/CGHLSLBuiltins.cpp (+72)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+11)
  • (modified) clang/lib/Sema/SemaDecl.cpp (+14)
  • (modified) clang/lib/Sema/SemaDeclAttr.cpp (+3)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+119-1)
  • (added) clang/test/AST/HLSL/vk.spec-constant.usage.hlsl (+130)
  • (renamed) clang/test/CodeGenHLSL/vk-features/SpirvType.alignment.hlsl ()
  • (renamed) clang/test/CodeGenHLSL/vk-features/SpirvType.hlsl ()
  • (added) clang/test/CodeGenHLSL/vk-features/vk.spec-constant.hlsl (+210)
  • (added) clang/test/SemaHLSL/vk.spec-constant.error.hlsl (+37)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index f889e41c8699f..d3f39de6a3e85 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4993,6 +4993,14 @@ def HLSLVkExtBuiltinInput : InheritableAttr {
   let Documentation = [HLSLVkExtBuiltinInputDocs];
 }
 
+def HLSLVkConstantId : InheritableAttr {
+  let Spellings = [CXX11<"vk", "constant_id">];
+  let Args = [IntArgument<"Id">];
+  let Subjects = SubjectList<[ExternalGlobalVar]>;
+  let LangOpts = [HLSL];
+  let Documentation = [VkConstantIdDocs];
+}
+
 def RandomizeLayout : InheritableAttr {
   let Spellings = [GCC<"randomize_layout">];
   let Subjects = SubjectList<[Record]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index ea3c43f38d9fe..b3eafb79c5d4a 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -8252,6 +8252,21 @@ and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
   }];
 }
 
+def VkConstantIdDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+The ``vk::constant_id`` attribute specify the id for a SPIR-V specialization
+constant. The attribute applies to const global scalar variables. The variable must be initialized with a C++11 constexpr.
+In SPIR-V, the
+variable will be replaced with an `OpSpecConstant` with the given id.
+The syntax is:
+
+.. code-block:: text
+
+  ``[[vk::constant_id(<Id>)]] const T Name = <Init>``
+}];
+}
+
 def RootSignatureDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 68cd3d790e78a..d65b3a5d2f447 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5065,6 +5065,19 @@ def HLSLGroupMemoryBarrierWithGroupSync: LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void()";
 }
 
+class HLSLScalarTemplate
+    : Template<["bool", "char", "short", "int", "long long int",
+                "unsigned short", "unsigned int", "unsigned long long int",
+                "__fp16", "float", "double"],
+               ["_bool", "_char", "_short", "_int", "_longlong", "_ushort",
+                "_uint", "_ulonglong", "_half", "_float", "_double"]>;
+
+def HLSLGetSpirvSpecConstant : LangBuiltin<"HLSL_LANG">, HLSLScalarTemplate {
+  let Spellings = ["__builtin_get_spirv_spec_constant"];
+  let Attributes = [NoThrow, Const, Pure];
+  let Prototype = "T(unsigned int, T)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 1f283b776a02c..23a490225dd19 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12919,6 +12919,18 @@ def err_spirv_enum_not_int : Error<
 def err_spirv_enum_not_valid : Error<
    "invalid value for %select{storage class}0 argument">;
 
+def err_specialization_const_lit_init
+    : Error<"variable with 'vk::constant_id' attribute cannot have an "
+            "initializer that is not a constexpr">;
+def err_specialization_const_missing_initializer
+    : Error<
+          "variable with 'vk::constant_id' attribute must have an initializer">;
+def err_specialization_const_missing_const
+    : Error<"variable with 'vk::constant_id' attribute must be const">;
+def err_specialization_const_is_not_int_or_float
+    : Error<"variable with 'vk::constant_id' attribute must be an enum, bool, "
+            "integer, or floating point value">;
+
 // errors of expect.with.probability
 def err_probability_not_constant_float : Error<
    "probability argument to __builtin_expect_with_probability must be constant "
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 66d09f49680be..099d9c35684e8 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -98,6 +98,8 @@ class SemaHLSL : public SemaBase {
   HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
                                       int Min, int Max, int Preferred,
                                       int SpelledArgsCount);
+  HLSLVkConstantIdAttr *
+  mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id);
   HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                                   llvm::Triple::EnvironmentType ShaderType);
   HLSLParamModifierAttr *
@@ -122,6 +124,7 @@ class SemaHLSL : public SemaBase {
   void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL);
   void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
   void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
+  void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
@@ -156,7 +159,7 @@ class SemaHLSL : public SemaBase {
   QualType getInoutParameterType(QualType Ty);
 
   bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init);
-
+  bool handleInitialization(VarDecl *VDecl, Expr *&Init);
   void deduceAddressSpace(VarDecl *Decl);
 
 private:
diff --git a/clang/lib/Basic/Attributes.cpp b/clang/lib/Basic/Attributes.cpp
index 905046685934b..5f74365f0ed00 100644
--- a/clang/lib/Basic/Attributes.cpp
+++ b/clang/lib/Basic/Attributes.cpp
@@ -213,7 +213,8 @@ getScopeFromNormalizedScopeName(StringRef ScopeName) {
       .Case("vk", AttributeCommonInfo::Scope::VK)
       .Case("msvc", AttributeCommonInfo::Scope::MSVC)
       .Case("omp", AttributeCommonInfo::Scope::OMP)
-      .Case("riscv", AttributeCommonInfo::Scope::RISCV);
+      .Case("riscv", AttributeCommonInfo::Scope::RISCV)
+      .Case("vk", AttributeCommonInfo::Scope::HLSL);
 }
 
 unsigned AttributeCommonInfo::calculateAttributeSpellingListIndex() const {
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index abebc201808b0..74b5f92638600 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -12,6 +12,7 @@
 
 #include "CGBuiltin.h"
 #include "CGHLSLRuntime.h"
+#include "CodeGenFunction.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -774,6 +775,77 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     return EmitRuntimeCall(
         Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
   }
+  case Builtin::BI__builtin_get_spirv_spec_constant_bool:
+  case Builtin::BI__builtin_get_spirv_spec_constant_short:
+  case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
+  case Builtin::BI__builtin_get_spirv_spec_constant_int:
+  case Builtin::BI__builtin_get_spirv_spec_constant_uint:
+  case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
+  case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
+  case Builtin::BI__builtin_get_spirv_spec_constant_half:
+  case Builtin::BI__builtin_get_spirv_spec_constant_float:
+  case Builtin::BI__builtin_get_spirv_spec_constant_double: {
+    llvm::Function *SpecConstantFn = getSpecConstantFunction(E->getType());
+    llvm::Value *SpecId = EmitScalarExpr(E->getArg(0));
+    llvm::Value *DefaultVal = EmitScalarExpr(E->getArg(1));
+    llvm::Value *Args[] = {SpecId, DefaultVal};
+    return Builder.CreateCall(SpecConstantFn, Args);
+  }
   }
   return nullptr;
 }
+
+llvm::Function *clang::CodeGen::CodeGenFunction::getSpecConstantFunction(
+    const clang::QualType &SpecConstantType) {
+
+  // Find or create the declaration for the function.
+  llvm::Module *M = &CGM.getModule();
+  std::string MangledName = getSpecConstantFunctionName(SpecConstantType);
+  llvm::Function *SpecConstantFn = M->getFunction(MangledName);
+
+  if (!SpecConstantFn) {
+    llvm::Type *IntType = ConvertType(getContext().IntTy);
+    llvm::Type *RetTy = ConvertType(SpecConstantType);
+    llvm::Type *ArgTypes[] = {IntType, RetTy};
+    llvm::FunctionType *FnTy = llvm::FunctionType::get(RetTy, ArgTypes, false);
+    SpecConstantFn = llvm::Function::Create(
+        FnTy, llvm::GlobalValue::ExternalLinkage, MangledName, M);
+  }
+  return SpecConstantFn;
+}
+
+std::string clang::CodeGen::CodeGenFunction::getSpecConstantFunctionName(
+    const clang::QualType &SpecConstantType) {
+  // The parameter types for our conceptual intrinsic function.
+  ASTContext &Context = getContext();
+  QualType ClangParamTypes[] = {Context.IntTy, SpecConstantType};
+
+  // Create a temporary FunctionDecl for the builtin fuction. It won't be
+  // added to the AST.
+  FunctionProtoType::ExtProtoInfo EPI;
+  QualType FnType =
+      Context.getFunctionType(SpecConstantType, ClangParamTypes, EPI);
+  DeclarationName FuncName = &Context.Idents.get("__spirv_SpecConstant");
+  FunctionDecl *FnDeclForMangling = FunctionDecl::Create(
+      Context, Context.getTranslationUnitDecl(), SourceLocation(),
+      SourceLocation(), FuncName, FnType, /*TSI=*/nullptr, SC_Extern);
+
+  // Attach the created parameter declarations to the function declaration.
+  SmallVector<ParmVarDecl *, 2> ParamDecls;
+  for (QualType ParamType : ClangParamTypes) {
+    ParmVarDecl *PD = ParmVarDecl::Create(
+        Context, FnDeclForMangling, SourceLocation(), SourceLocation(),
+        /*IdentifierInfo*/ nullptr, ParamType, /*TSI*/ nullptr, SC_None,
+        /*DefaultArg*/ nullptr);
+    ParamDecls.push_back(PD);
+  }
+  FnDeclForMangling->setParams(ParamDecls);
+
+  // Get the mangled name.
+  std::string Name;
+  llvm::raw_string_ostream MangledNameStream(Name);
+  MangleContext *Mangler = Context.createMangleContext();
+  Mangler->mangleName(FnDeclForMangling, MangledNameStream);
+  MangledNameStream.flush();
+  return Name;
+}
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index a5ab9df01dba9..badc9feeb8c2b 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4850,6 +4850,17 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
                                    ReturnValueSlot ReturnValue);
+
+  // Returns a builtin function that the SPIR-V backend will expand into a spec
+  // constant.
+  llvm::Function *
+  getSpecConstantFunction(const clang::QualType &SpecConstantType);
+
+  // Returns the mangled name for a builtin function that the SPIR-V backend
+  // will expand into a spec Constant.
+  std::string
+  getSpecConstantFunctionName(const clang::QualType &SpecConstantType);
+
   llvm::Value *EmitDirectXBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitSPIRVBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index bbd63372c168b..9d09eec26f7a2 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2889,6 +2889,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
     NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
                                          WS->getPreferred(),
                                          WS->getSpelledArgsCount());
+  else if (const auto *CI = dyn_cast<HLSLVkConstantIdAttr>(Attr))
+    NewAttr = S.HLSL().mergeVkConstantIdAttr(D, *CI, CI->getId());
   else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
     NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
   else if (isa<SuppressAttr>(Attr))
@@ -13755,6 +13757,10 @@ void Sema::AddInitializerToDecl(Decl *RealDecl, Expr *Init, bool DirectInit) {
     return;
   }
 
+  if (getLangOpts().HLSL)
+    if (!HLSL().handleInitialization(VDecl, Init))
+      return;
+
   // Get the decls type and save a reference for later, since
   // CheckInitializerTypes may change it.
   QualType DclT = VDecl->getType(), SavT = DclT;
@@ -14215,6 +14221,14 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
       }
     }
 
+    // HLSL variable with the `vk::constant_id` attribute must be initialized.
+    if (!Var->isInvalidDecl() && Var->hasAttr<HLSLVkConstantIdAttr>()) {
+      Diag(Var->getLocation(),
+           diag::err_specialization_const_missing_initializer);
+      Var->setInvalidDecl();
+      return;
+    }
+
     if (!Var->isInvalidDecl() && RealDecl->hasAttr<LoaderUninitializedAttr>()) {
       if (Var->getStorageClass() == SC_Extern) {
         Diag(Var->getLocation(), diag::err_loader_uninitialized_extern_decl)
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index da0e3265767d8..e49bdda1a402b 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7560,6 +7560,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_HLSLVkExtBuiltinInput:
     S.HLSL().handleVkExtBuiltinInputAttr(D, AL);
     break;
+  case ParsedAttr::AT_HLSLVkConstantId:
+    S.HLSL().handleVkConstantIdAttr(D, AL);
+    break;
   case ParsedAttr::AT_HLSLSV_GroupThreadID:
     S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
     break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 9065cc5a1d4a5..5764507f35882 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -119,6 +119,40 @@ static ResourceClass getResourceClass(RegisterType RT) {
   llvm_unreachable("unexpected RegisterType value");
 }
 
+static Builtin::ID getSpecConstBuiltinId(QualType Type) {
+  const auto *BT = dyn_cast<BuiltinType>(Type);
+  if (!BT) {
+    if (!Type->isEnumeralType())
+      return Builtin::NotBuiltin;
+    return Builtin::BI__builtin_get_spirv_spec_constant_int;
+  }
+
+  switch (BT->getKind()) {
+  case BuiltinType::Bool:
+    return Builtin::BI__builtin_get_spirv_spec_constant_bool;
+  case BuiltinType::Short:
+    return Builtin::BI__builtin_get_spirv_spec_constant_short;
+  case BuiltinType::Int:
+    return Builtin::BI__builtin_get_spirv_spec_constant_int;
+  case BuiltinType::LongLong:
+    return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
+  case BuiltinType::UShort:
+    return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
+  case BuiltinType::UInt:
+    return Builtin::BI__builtin_get_spirv_spec_constant_uint;
+  case BuiltinType::ULongLong:
+    return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
+  case BuiltinType::Half:
+    return Builtin::BI__builtin_get_spirv_spec_constant_half;
+  case BuiltinType::Float:
+    return Builtin::BI__builtin_get_spirv_spec_constant_float;
+  case BuiltinType::Double:
+    return Builtin::BI__builtin_get_spirv_spec_constant_double;
+  default:
+    return Builtin::NotBuiltin;
+  }
+}
+
 DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
                                                       ResourceClass ResClass) {
   assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
@@ -607,6 +641,41 @@ HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
   return Result;
 }
 
+HLSLVkConstantIdAttr *
+SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
+                                int Id) {
+
+  auto &TargetInfo = getASTContext().getTargetInfo();
+  if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
+    Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
+    return nullptr;
+  }
+
+  auto *VD = cast<VarDecl>(D);
+
+  if (getSpecConstBuiltinId(VD->getType()) == Builtin::NotBuiltin) {
+    Diag(VD->getLocation(), diag::err_specialization_const_is_not_int_or_float);
+    return nullptr;
+  }
+
+  if (!VD->getType().isConstQualified()) {
+    Diag(VD->getLocation(), diag::err_specialization_const_missing_const);
+    return nullptr;
+  }
+
+  if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
+    if (CI->getId() != Id) {
+      Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+
+  HLSLVkConstantIdAttr *Result =
+      ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
+  return Result;
+}
+
 HLSLShaderAttr *
 SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                           llvm::Triple::EnvironmentType ShaderType) {
@@ -1125,6 +1194,15 @@ void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
                  HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
 }
 
+void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
+  uint32_t Id;
+  if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
+    return;
+  HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
+  if (NewAttr)
+    D->addAttr(NewAttr);
+}
+
 bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
   const auto *VT = T->getAs<VectorType>();
 
@@ -3154,6 +3232,7 @@ static bool IsDefaultBufferConstantDecl(VarDecl *VD) {
   return VD->getDeclContext()->isTranslationUnit() &&
          QT.getAddressSpace() == LangAS::Default &&
          VD->getStorageClass() != SC_Static &&
+         !VD->hasAttr<HLSLVkConstantIdAttr>() &&
          !isInvalidConstantBufferLeafElementType(QT.getTypePtr());
 }
 
@@ -3221,7 +3300,8 @@ void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
     const Type *VarType = VD->getType().getTypePtr();
     while (VarType->isArrayType())
       VarType = VarType->getArrayElementTypeNoTypeQual();
-    if (VarType->isHLSLResourceRecord()) {
+    if (VarType->isHLSLResourceRecord() ||
+        VD->hasAttr<HLSLVkConstantIdAttr>()) {
       // Make the variable for resources static. The global externally visible
       // storage is accessed through the handle, which is a member. The variable
       // itself is not externally visible.
@@ -3644,3 +3724,41 @@ bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
     Init->updateInit(Ctx, I, NewInit->getInit(I));
   return true;
 }
+
+bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
+  const HLSLVkConstantIdAttr *ConstIdAttr =
+      VDecl->getAttr<HLSLVkConstantIdAttr>();
+  if (!ConstIdAttr)
+    return true;
+
+  ASTContext &Context = SemaRef.getASTContext();
+
+  APValue InitValue;
+  if (!Init->isCXX11ConstantExpr(Context, &InitValue)) {
+    Diag(VDecl->getLocation(), diag::err_specialization_const_lit_init);
+    VDecl->setInvalidDecl();
+    return false;
+  }
+
+  Builtin::ID BID = getSpecConstBuiltinId(VDecl->getType());
+
+  // Argument 1: The ID from the attribute
+  int ConstantID = ConstIdAttr->getId();
+  llvm::APInt IDVal(Context.getIntWidth(Context.IntTy), ConstantID);
+  Expr *IdExpr = IntegerLiteral::Create(Context, IDVal, Context.IntTy,
+                                        ConstIdAttr->getLocation());
+
+  SmallVector<Expr *, 2> Args = {IdExpr, Init};
+  Expr *C = SemaRef.BuildBuiltinCallExpr(Init->getExprLoc(), BID, Args);
+  if (C->getType()->getCanonicalTypeUnqualified() !=
+      VDecl->getType()->getCanonicalTypeUnqualified()) {
+    C = SemaRef
+            .BuildCStyleCastExpr(SourceLocation(),
+                                 Context.getTrivialTypeSourceInfo(
+                                     Init->getType(), Init->getExprLoc()),
+                                 SourceLocation(), C)
+            .get();
+  }
+  Init = C;
+  return true;
+}
diff --git a/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl
new file mode 100644
index 0000000000000..c0955c1ea7b43
--- /dev/null
+++ b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl
@@ -0,0 +1,130 @@
+// RUN: %clang_cc1 -finclude-default-header -triple spirv-unknown-vulkan-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+
+// CHECK: VarDecl {{.*}} bool_const 'const hlsl_private bool' static cinit
+// CHECK-NEXT: CallExpr {{.*}} 'bool'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'bool (*)(unsigned int, bool) noexcept' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'bool (unsigned int, bool) noexcept' lvalue Function {{.*}} '__builtin_get_spirv_spec_constant_bool' 'bool (unsigned int, bool) noexcept'
+// CHECK-NEXT: ImplicitCastExpr {{.*}...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 10, 2025

@llvm/pr-subscribers-hlsl

Author: Steven Perron (s-perron)

Changes

The vk::constant_id attribute is used to indicate that a global const variable
represents a specialization constant in SPIR-V. This PR adds this
attribute to clang.

The documetation for the attribute is here.

The strategy is to to modify the initializer to get the value of a
specialize constant for a builtin defined in the SPIR-V backend.


Patch is 39.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/143544.diff

16 Files Affected:

  • (modified) clang/include/clang/Basic/Attr.td (+8)
  • (modified) clang/include/clang/Basic/AttrDocs.td (+15)
  • (modified) clang/include/clang/Basic/Builtins.td (+13)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+12)
  • (modified) clang/include/clang/Sema/SemaHLSL.h (+4-1)
  • (modified) clang/lib/Basic/Attributes.cpp (+2-1)
  • (modified) clang/lib/CodeGen/CGHLSLBuiltins.cpp (+72)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+11)
  • (modified) clang/lib/Sema/SemaDecl.cpp (+14)
  • (modified) clang/lib/Sema/SemaDeclAttr.cpp (+3)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+119-1)
  • (added) clang/test/AST/HLSL/vk.spec-constant.usage.hlsl (+130)
  • (renamed) clang/test/CodeGenHLSL/vk-features/SpirvType.alignment.hlsl ()
  • (renamed) clang/test/CodeGenHLSL/vk-features/SpirvType.hlsl ()
  • (added) clang/test/CodeGenHLSL/vk-features/vk.spec-constant.hlsl (+210)
  • (added) clang/test/SemaHLSL/vk.spec-constant.error.hlsl (+37)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index f889e41c8699f..d3f39de6a3e85 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4993,6 +4993,14 @@ def HLSLVkExtBuiltinInput : InheritableAttr {
   let Documentation = [HLSLVkExtBuiltinInputDocs];
 }
 
+def HLSLVkConstantId : InheritableAttr {
+  let Spellings = [CXX11<"vk", "constant_id">];
+  let Args = [IntArgument<"Id">];
+  let Subjects = SubjectList<[ExternalGlobalVar]>;
+  let LangOpts = [HLSL];
+  let Documentation = [VkConstantIdDocs];
+}
+
 def RandomizeLayout : InheritableAttr {
   let Spellings = [GCC<"randomize_layout">];
   let Subjects = SubjectList<[Record]>;
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index ea3c43f38d9fe..b3eafb79c5d4a 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -8252,6 +8252,21 @@ and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html
   }];
 }
 
+def VkConstantIdDocs : Documentation {
+  let Category = DocCatFunction;
+  let Content = [{
+The ``vk::constant_id`` attribute specify the id for a SPIR-V specialization
+constant. The attribute applies to const global scalar variables. The variable must be initialized with a C++11 constexpr.
+In SPIR-V, the
+variable will be replaced with an `OpSpecConstant` with the given id.
+The syntax is:
+
+.. code-block:: text
+
+  ``[[vk::constant_id(<Id>)]] const T Name = <Init>``
+}];
+}
+
 def RootSignatureDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 68cd3d790e78a..d65b3a5d2f447 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5065,6 +5065,19 @@ def HLSLGroupMemoryBarrierWithGroupSync: LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void()";
 }
 
+class HLSLScalarTemplate
+    : Template<["bool", "char", "short", "int", "long long int",
+                "unsigned short", "unsigned int", "unsigned long long int",
+                "__fp16", "float", "double"],
+               ["_bool", "_char", "_short", "_int", "_longlong", "_ushort",
+                "_uint", "_ulonglong", "_half", "_float", "_double"]>;
+
+def HLSLGetSpirvSpecConstant : LangBuiltin<"HLSL_LANG">, HLSLScalarTemplate {
+  let Spellings = ["__builtin_get_spirv_spec_constant"];
+  let Attributes = [NoThrow, Const, Pure];
+  let Prototype = "T(unsigned int, T)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 1f283b776a02c..23a490225dd19 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12919,6 +12919,18 @@ def err_spirv_enum_not_int : Error<
 def err_spirv_enum_not_valid : Error<
    "invalid value for %select{storage class}0 argument">;
 
+def err_specialization_const_lit_init
+    : Error<"variable with 'vk::constant_id' attribute cannot have an "
+            "initializer that is not a constexpr">;
+def err_specialization_const_missing_initializer
+    : Error<
+          "variable with 'vk::constant_id' attribute must have an initializer">;
+def err_specialization_const_missing_const
+    : Error<"variable with 'vk::constant_id' attribute must be const">;
+def err_specialization_const_is_not_int_or_float
+    : Error<"variable with 'vk::constant_id' attribute must be an enum, bool, "
+            "integer, or floating point value">;
+
 // errors of expect.with.probability
 def err_probability_not_constant_float : Error<
    "probability argument to __builtin_expect_with_probability must be constant "
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index 66d09f49680be..099d9c35684e8 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -98,6 +98,8 @@ class SemaHLSL : public SemaBase {
   HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL,
                                       int Min, int Max, int Preferred,
                                       int SpelledArgsCount);
+  HLSLVkConstantIdAttr *
+  mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL, int Id);
   HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                                   llvm::Triple::EnvironmentType ShaderType);
   HLSLParamModifierAttr *
@@ -122,6 +124,7 @@ class SemaHLSL : public SemaBase {
   void handleRootSignatureAttr(Decl *D, const ParsedAttr &AL);
   void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
   void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
+  void handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL);
   void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
@@ -156,7 +159,7 @@ class SemaHLSL : public SemaBase {
   QualType getInoutParameterType(QualType Ty);
 
   bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init);
-
+  bool handleInitialization(VarDecl *VDecl, Expr *&Init);
   void deduceAddressSpace(VarDecl *Decl);
 
 private:
diff --git a/clang/lib/Basic/Attributes.cpp b/clang/lib/Basic/Attributes.cpp
index 905046685934b..5f74365f0ed00 100644
--- a/clang/lib/Basic/Attributes.cpp
+++ b/clang/lib/Basic/Attributes.cpp
@@ -213,7 +213,8 @@ getScopeFromNormalizedScopeName(StringRef ScopeName) {
       .Case("vk", AttributeCommonInfo::Scope::VK)
       .Case("msvc", AttributeCommonInfo::Scope::MSVC)
       .Case("omp", AttributeCommonInfo::Scope::OMP)
-      .Case("riscv", AttributeCommonInfo::Scope::RISCV);
+      .Case("riscv", AttributeCommonInfo::Scope::RISCV)
+      .Case("vk", AttributeCommonInfo::Scope::HLSL);
 }
 
 unsigned AttributeCommonInfo::calculateAttributeSpellingListIndex() const {
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index abebc201808b0..74b5f92638600 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -12,6 +12,7 @@
 
 #include "CGBuiltin.h"
 #include "CGHLSLRuntime.h"
+#include "CodeGenFunction.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -774,6 +775,77 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     return EmitRuntimeCall(
         Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
   }
+  case Builtin::BI__builtin_get_spirv_spec_constant_bool:
+  case Builtin::BI__builtin_get_spirv_spec_constant_short:
+  case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
+  case Builtin::BI__builtin_get_spirv_spec_constant_int:
+  case Builtin::BI__builtin_get_spirv_spec_constant_uint:
+  case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
+  case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
+  case Builtin::BI__builtin_get_spirv_spec_constant_half:
+  case Builtin::BI__builtin_get_spirv_spec_constant_float:
+  case Builtin::BI__builtin_get_spirv_spec_constant_double: {
+    llvm::Function *SpecConstantFn = getSpecConstantFunction(E->getType());
+    llvm::Value *SpecId = EmitScalarExpr(E->getArg(0));
+    llvm::Value *DefaultVal = EmitScalarExpr(E->getArg(1));
+    llvm::Value *Args[] = {SpecId, DefaultVal};
+    return Builder.CreateCall(SpecConstantFn, Args);
+  }
   }
   return nullptr;
 }
+
+llvm::Function *clang::CodeGen::CodeGenFunction::getSpecConstantFunction(
+    const clang::QualType &SpecConstantType) {
+
+  // Find or create the declaration for the function.
+  llvm::Module *M = &CGM.getModule();
+  std::string MangledName = getSpecConstantFunctionName(SpecConstantType);
+  llvm::Function *SpecConstantFn = M->getFunction(MangledName);
+
+  if (!SpecConstantFn) {
+    llvm::Type *IntType = ConvertType(getContext().IntTy);
+    llvm::Type *RetTy = ConvertType(SpecConstantType);
+    llvm::Type *ArgTypes[] = {IntType, RetTy};
+    llvm::FunctionType *FnTy = llvm::FunctionType::get(RetTy, ArgTypes, false);
+    SpecConstantFn = llvm::Function::Create(
+        FnTy, llvm::GlobalValue::ExternalLinkage, MangledName, M);
+  }
+  return SpecConstantFn;
+}
+
+std::string clang::CodeGen::CodeGenFunction::getSpecConstantFunctionName(
+    const clang::QualType &SpecConstantType) {
+  // The parameter types for our conceptual intrinsic function.
+  ASTContext &Context = getContext();
+  QualType ClangParamTypes[] = {Context.IntTy, SpecConstantType};
+
+  // Create a temporary FunctionDecl for the builtin fuction. It won't be
+  // added to the AST.
+  FunctionProtoType::ExtProtoInfo EPI;
+  QualType FnType =
+      Context.getFunctionType(SpecConstantType, ClangParamTypes, EPI);
+  DeclarationName FuncName = &Context.Idents.get("__spirv_SpecConstant");
+  FunctionDecl *FnDeclForMangling = FunctionDecl::Create(
+      Context, Context.getTranslationUnitDecl(), SourceLocation(),
+      SourceLocation(), FuncName, FnType, /*TSI=*/nullptr, SC_Extern);
+
+  // Attach the created parameter declarations to the function declaration.
+  SmallVector<ParmVarDecl *, 2> ParamDecls;
+  for (QualType ParamType : ClangParamTypes) {
+    ParmVarDecl *PD = ParmVarDecl::Create(
+        Context, FnDeclForMangling, SourceLocation(), SourceLocation(),
+        /*IdentifierInfo*/ nullptr, ParamType, /*TSI*/ nullptr, SC_None,
+        /*DefaultArg*/ nullptr);
+    ParamDecls.push_back(PD);
+  }
+  FnDeclForMangling->setParams(ParamDecls);
+
+  // Get the mangled name.
+  std::string Name;
+  llvm::raw_string_ostream MangledNameStream(Name);
+  MangleContext *Mangler = Context.createMangleContext();
+  Mangler->mangleName(FnDeclForMangling, MangledNameStream);
+  MangledNameStream.flush();
+  return Name;
+}
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index a5ab9df01dba9..badc9feeb8c2b 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4850,6 +4850,17 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *EmitAMDGPUBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitHLSLBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
                                    ReturnValueSlot ReturnValue);
+
+  // Returns a builtin function that the SPIR-V backend will expand into a spec
+  // constant.
+  llvm::Function *
+  getSpecConstantFunction(const clang::QualType &SpecConstantType);
+
+  // Returns the mangled name for a builtin function that the SPIR-V backend
+  // will expand into a spec Constant.
+  std::string
+  getSpecConstantFunctionName(const clang::QualType &SpecConstantType);
+
   llvm::Value *EmitDirectXBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitSPIRVBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   llvm::Value *EmitScalarOrConstFoldImmArg(unsigned ICEArguments, unsigned Idx,
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index bbd63372c168b..9d09eec26f7a2 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2889,6 +2889,8 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D,
     NewAttr = S.HLSL().mergeWaveSizeAttr(D, *WS, WS->getMin(), WS->getMax(),
                                          WS->getPreferred(),
                                          WS->getSpelledArgsCount());
+  else if (const auto *CI = dyn_cast<HLSLVkConstantIdAttr>(Attr))
+    NewAttr = S.HLSL().mergeVkConstantIdAttr(D, *CI, CI->getId());
   else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr))
     NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType());
   else if (isa<SuppressAttr>(Attr))
@@ -13755,6 +13757,10 @@ void Sema::AddInitializerToDecl(Decl *RealDecl, Expr *Init, bool DirectInit) {
     return;
   }
 
+  if (getLangOpts().HLSL)
+    if (!HLSL().handleInitialization(VDecl, Init))
+      return;
+
   // Get the decls type and save a reference for later, since
   // CheckInitializerTypes may change it.
   QualType DclT = VDecl->getType(), SavT = DclT;
@@ -14215,6 +14221,14 @@ void Sema::ActOnUninitializedDecl(Decl *RealDecl) {
       }
     }
 
+    // HLSL variable with the `vk::constant_id` attribute must be initialized.
+    if (!Var->isInvalidDecl() && Var->hasAttr<HLSLVkConstantIdAttr>()) {
+      Diag(Var->getLocation(),
+           diag::err_specialization_const_missing_initializer);
+      Var->setInvalidDecl();
+      return;
+    }
+
     if (!Var->isInvalidDecl() && RealDecl->hasAttr<LoaderUninitializedAttr>()) {
       if (Var->getStorageClass() == SC_Extern) {
         Diag(Var->getLocation(), diag::err_loader_uninitialized_extern_decl)
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index da0e3265767d8..e49bdda1a402b 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7560,6 +7560,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_HLSLVkExtBuiltinInput:
     S.HLSL().handleVkExtBuiltinInputAttr(D, AL);
     break;
+  case ParsedAttr::AT_HLSLVkConstantId:
+    S.HLSL().handleVkConstantIdAttr(D, AL);
+    break;
   case ParsedAttr::AT_HLSLSV_GroupThreadID:
     S.HLSL().handleSV_GroupThreadIDAttr(D, AL);
     break;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 9065cc5a1d4a5..5764507f35882 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -119,6 +119,40 @@ static ResourceClass getResourceClass(RegisterType RT) {
   llvm_unreachable("unexpected RegisterType value");
 }
 
+static Builtin::ID getSpecConstBuiltinId(QualType Type) {
+  const auto *BT = dyn_cast<BuiltinType>(Type);
+  if (!BT) {
+    if (!Type->isEnumeralType())
+      return Builtin::NotBuiltin;
+    return Builtin::BI__builtin_get_spirv_spec_constant_int;
+  }
+
+  switch (BT->getKind()) {
+  case BuiltinType::Bool:
+    return Builtin::BI__builtin_get_spirv_spec_constant_bool;
+  case BuiltinType::Short:
+    return Builtin::BI__builtin_get_spirv_spec_constant_short;
+  case BuiltinType::Int:
+    return Builtin::BI__builtin_get_spirv_spec_constant_int;
+  case BuiltinType::LongLong:
+    return Builtin::BI__builtin_get_spirv_spec_constant_longlong;
+  case BuiltinType::UShort:
+    return Builtin::BI__builtin_get_spirv_spec_constant_ushort;
+  case BuiltinType::UInt:
+    return Builtin::BI__builtin_get_spirv_spec_constant_uint;
+  case BuiltinType::ULongLong:
+    return Builtin::BI__builtin_get_spirv_spec_constant_ulonglong;
+  case BuiltinType::Half:
+    return Builtin::BI__builtin_get_spirv_spec_constant_half;
+  case BuiltinType::Float:
+    return Builtin::BI__builtin_get_spirv_spec_constant_float;
+  case BuiltinType::Double:
+    return Builtin::BI__builtin_get_spirv_spec_constant_double;
+  default:
+    return Builtin::NotBuiltin;
+  }
+}
+
 DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD,
                                                       ResourceClass ResClass) {
   assert(getDeclBindingInfo(VD, ResClass) == nullptr &&
@@ -607,6 +641,41 @@ HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D,
   return Result;
 }
 
+HLSLVkConstantIdAttr *
+SemaHLSL::mergeVkConstantIdAttr(Decl *D, const AttributeCommonInfo &AL,
+                                int Id) {
+
+  auto &TargetInfo = getASTContext().getTargetInfo();
+  if (TargetInfo.getTriple().getArch() != llvm::Triple::spirv) {
+    Diag(AL.getLoc(), diag::warn_attribute_ignored) << AL;
+    return nullptr;
+  }
+
+  auto *VD = cast<VarDecl>(D);
+
+  if (getSpecConstBuiltinId(VD->getType()) == Builtin::NotBuiltin) {
+    Diag(VD->getLocation(), diag::err_specialization_const_is_not_int_or_float);
+    return nullptr;
+  }
+
+  if (!VD->getType().isConstQualified()) {
+    Diag(VD->getLocation(), diag::err_specialization_const_missing_const);
+    return nullptr;
+  }
+
+  if (HLSLVkConstantIdAttr *CI = D->getAttr<HLSLVkConstantIdAttr>()) {
+    if (CI->getId() != Id) {
+      Diag(CI->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL;
+      Diag(AL.getLoc(), diag::note_conflicting_attribute);
+    }
+    return nullptr;
+  }
+
+  HLSLVkConstantIdAttr *Result =
+      ::new (getASTContext()) HLSLVkConstantIdAttr(getASTContext(), AL, Id);
+  return Result;
+}
+
 HLSLShaderAttr *
 SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL,
                           llvm::Triple::EnvironmentType ShaderType) {
@@ -1125,6 +1194,15 @@ void SemaHLSL::handleVkExtBuiltinInputAttr(Decl *D, const ParsedAttr &AL) {
                  HLSLVkExtBuiltinInputAttr(getASTContext(), AL, ID));
 }
 
+void SemaHLSL::handleVkConstantIdAttr(Decl *D, const ParsedAttr &AL) {
+  uint32_t Id;
+  if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Id))
+    return;
+  HLSLVkConstantIdAttr *NewAttr = mergeVkConstantIdAttr(D, AL, Id);
+  if (NewAttr)
+    D->addAttr(NewAttr);
+}
+
 bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) {
   const auto *VT = T->getAs<VectorType>();
 
@@ -3154,6 +3232,7 @@ static bool IsDefaultBufferConstantDecl(VarDecl *VD) {
   return VD->getDeclContext()->isTranslationUnit() &&
          QT.getAddressSpace() == LangAS::Default &&
          VD->getStorageClass() != SC_Static &&
+         !VD->hasAttr<HLSLVkConstantIdAttr>() &&
          !isInvalidConstantBufferLeafElementType(QT.getTypePtr());
 }
 
@@ -3221,7 +3300,8 @@ void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) {
     const Type *VarType = VD->getType().getTypePtr();
     while (VarType->isArrayType())
       VarType = VarType->getArrayElementTypeNoTypeQual();
-    if (VarType->isHLSLResourceRecord()) {
+    if (VarType->isHLSLResourceRecord() ||
+        VD->hasAttr<HLSLVkConstantIdAttr>()) {
       // Make the variable for resources static. The global externally visible
       // storage is accessed through the handle, which is a member. The variable
       // itself is not externally visible.
@@ -3644,3 +3724,41 @@ bool SemaHLSL::transformInitList(const InitializedEntity &Entity,
     Init->updateInit(Ctx, I, NewInit->getInit(I));
   return true;
 }
+
+bool SemaHLSL::handleInitialization(VarDecl *VDecl, Expr *&Init) {
+  const HLSLVkConstantIdAttr *ConstIdAttr =
+      VDecl->getAttr<HLSLVkConstantIdAttr>();
+  if (!ConstIdAttr)
+    return true;
+
+  ASTContext &Context = SemaRef.getASTContext();
+
+  APValue InitValue;
+  if (!Init->isCXX11ConstantExpr(Context, &InitValue)) {
+    Diag(VDecl->getLocation(), diag::err_specialization_const_lit_init);
+    VDecl->setInvalidDecl();
+    return false;
+  }
+
+  Builtin::ID BID = getSpecConstBuiltinId(VDecl->getType());
+
+  // Argument 1: The ID from the attribute
+  int ConstantID = ConstIdAttr->getId();
+  llvm::APInt IDVal(Context.getIntWidth(Context.IntTy), ConstantID);
+  Expr *IdExpr = IntegerLiteral::Create(Context, IDVal, Context.IntTy,
+                                        ConstIdAttr->getLocation());
+
+  SmallVector<Expr *, 2> Args = {IdExpr, Init};
+  Expr *C = SemaRef.BuildBuiltinCallExpr(Init->getExprLoc(), BID, Args);
+  if (C->getType()->getCanonicalTypeUnqualified() !=
+      VDecl->getType()->getCanonicalTypeUnqualified()) {
+    C = SemaRef
+            .BuildCStyleCastExpr(SourceLocation(),
+                                 Context.getTrivialTypeSourceInfo(
+                                     Init->getType(), Init->getExprLoc()),
+                                 SourceLocation(), C)
+            .get();
+  }
+  Init = C;
+  return true;
+}
diff --git a/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl
new file mode 100644
index 0000000000000..c0955c1ea7b43
--- /dev/null
+++ b/clang/test/AST/HLSL/vk.spec-constant.usage.hlsl
@@ -0,0 +1,130 @@
+// RUN: %clang_cc1 -finclude-default-header -triple spirv-unknown-vulkan-compute -x hlsl -ast-dump -o - %s | FileCheck %s
+
+// CHECK: VarDecl {{.*}} bool_const 'const hlsl_private bool' static cinit
+// CHECK-NEXT: CallExpr {{.*}} 'bool'
+// CHECK-NEXT: ImplicitCastExpr {{.*}} 'bool (*)(unsigned int, bool) noexcept' <FunctionToPointerDecay>
+// CHECK-NEXT: DeclRefExpr {{.*}} 'bool (unsigned int, bool) noexcept' lvalue Function {{.*}} '__builtin_get_spirv_spec_constant_bool' 'bool (unsigned int, bool) noexcept'
+// CHECK-NEXT: ImplicitCastExpr {{.*}...
[truncated]

Comment on lines +778 to +787
case Builtin::BI__builtin_get_spirv_spec_constant_bool:
case Builtin::BI__builtin_get_spirv_spec_constant_short:
case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
case Builtin::BI__builtin_get_spirv_spec_constant_int:
case Builtin::BI__builtin_get_spirv_spec_constant_uint:
case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
case Builtin::BI__builtin_get_spirv_spec_constant_half:
case Builtin::BI__builtin_get_spirv_spec_constant_float:
case Builtin::BI__builtin_get_spirv_spec_constant_double: {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to handle overloading the same way wave_read_lane_at does it?
This way there is a single Builtin::BI to handle here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, but I like this method because it becomes impossible to create calls to the builtin with invalid types. wave_read_lane_at is defined with a prototype: void(...). I believe we can add type checking in other places, but I feel this is more robust.

@s-perron s-perron requested a review from Keenuts June 11, 2025 15:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[HLSL][SPIRV] Add the vk::constant_id to clang
3 participants