
[mlir][spirv] Add support for SPV_ARM_tensors #144667

Open: wants to merge 1 commit into main

Conversation

davidegrohmann (Contributor) commented Jun 18, 2025

This patch introduces a new custom type !spirv.arm.tensor<> to the MLIR SPIR-V dialect to represent OpTypeTensorARM as defined in the SPV_ARM_tensors extension.

The type models a shaped tensor with an element type and an optional shape, and implements the ShapedType interface to enable reuse of MLIR's existing shape-aware infrastructure.

The type supports serialization to and from SPIR-V binary as OpTypeTensorARM, and emits the required capability (TensorsARM) and extension (SPV_ARM_tensors) declarations automatically.

This addition lays the foundation for supporting structured tensor values natively in SPIR-V and will enable future support for operations defined in the SPV_ARM_tensors extension, such as OpTensorReadARM, OpTensorWriteARM, and OpTensorQuerySizeARM.
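
For illustration, here is roughly what the new type looks like in the dialect's assembly format, based on the parser and printer added in this patch (the concrete shapes and element types below are made up):

```mlir
// Ranked: a 2x3 tensor of 32-bit floats.
!spirv.arm.tensor<2x3xf32>
// Fully dynamic: rank is known, every dimension is dynamic ('?').
!spirv.arm.tensor<?x?xf32>
// Unranked: the rank itself is unknown, printed with a leading '*'.
!spirv.arm.tensor<*xi8>
// Rejected by the parser: rank 0, zero-sized dimensions, and shapes that
// mix static and dynamic dimensions (e.g. 2x?xf32).
```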

Reference: KhronosGroup/SPIRV-Registry#342

Signed-off-by: Davide Grohmann <[email protected]>
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Change-Id: If78909a47417ef3dda710847cfe90c34b984ff09
llvmbot (Member) commented Jun 18, 2025

@llvm/pr-subscribers-mlir-spirv

Author: Davide Grohmann (davidegrohmann)

Changes

Patch is 32.41 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144667.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+31-3)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h (+34-1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (+76-1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+6)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+120-4)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp (+1)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+52)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.h (+2)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+48)
  • (modified) mlir/test/Dialect/SPIRV/IR/types.mlir (+51)
  • (added) mlir/test/Target/SPIRV/tensorARM.mlir (+66)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d2ba76cdad904..d874817e6888d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -422,6 +422,8 @@ def SPV_NV_ray_tracing_motion_blur       : I32EnumAttrCase<"SPV_NV_ray_tracing_m
 
 def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
 
+def SPV_ARM_tensors                      : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;
+
 def SPIRV_ExtensionAttr :
     SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
       SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_device_group,
@@ -445,6 +447,7 @@ def SPIRV_ExtensionAttr :
       SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
       SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
       SPV_EXT_mesh_shader,
+      SPV_ARM_tensors,
       SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
       SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
       SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1311,6 +1314,24 @@ def SPIRV_C_GeometryStreams                             : I32EnumAttrCase<"Geome
 def SPIRV_C_MultiViewport                               : I32EnumAttrCase<"MultiViewport", 57> {
   list<I32EnumAttrCase> implies = [SPIRV_C_Geometry];
 }
+def SPIRV_C_TensorsARM                                  : I32EnumAttrCase<"TensorsARM", 4174> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_Int8];
+  list<Availability> availability = [
+    Extension<[SPV_ARM_tensors]>
+  ];
+}
+def SPIRV_C_StorageTensorArrayDynamicIndexingEXT        : I32EnumAttrCase<"StorageTensorArrayDynamicIndexingEXT", 4175> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader];
+  list<Availability> availability = [
+    Extension<[SPV_ARM_tensors]>
+  ];
+}
+def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT     : I32EnumAttrCase<"StorageTensorArrayNonUniformIndexingEXT", 4176> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_ShaderNonUniform];
+  list<Availability> availability = [
+    Extension<[SPV_ARM_tensors]>
+  ];
+}
 def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR  : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
   list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
   list<Availability> availability = [
@@ -1523,6 +1544,8 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_IntegerFunctions2INTEL, SPIRV_C_TessellationPointSize,
       SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
       SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
+      SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
+      SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
       SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
       SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
       SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
@@ -4179,7 +4202,7 @@ def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
 def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
 def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
 def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;
-
+def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;
 
 // See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
 // for the definition of the following types and type categories.
@@ -4217,6 +4240,8 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
                                 "any SPIR-V struct type">;
 def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
                                 "any SPIR-V sampled image type">;
+def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
+                                 "any SPIR-V tensorArm type">;
 
 def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
 def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
@@ -4228,7 +4253,7 @@ def SPIRV_Type : AnyTypeOf<[
     SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
     SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
-    SPIRV_AnyImage
+    SPIRV_AnyImage, SPIRV_AnyTensorArm
   ]>;
 
 def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4525,6 +4550,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor      : I32EnumAttrCase<"OpGroupNonUnifo
 def SPIRV_OC_OpGroupNonUniformLogicalAnd      : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
 def SPIRV_OC_OpGroupNonUniformLogicalOr       : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
 def SPIRV_OC_OpGroupNonUniformLogicalXor      : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
+def SPIRV_OC_OpTypeTensorARM                  : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
 def SPIRV_OC_OpSubgroupBallotKHR              : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 def SPIRV_OC_OpGroupNonUniformRotateKHR       : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
 def SPIRV_OC_OpSDot                           : I32EnumAttrCase<"OpSDot", 4450>;
@@ -4638,7 +4664,9 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
       SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
       SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
-      SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
+      SPIRV_OC_OpGroupNonUniformLogicalXor,
+      SPIRV_OC_OpTypeTensorARM,
+      SPIRV_OC_OpSubgroupBallotKHR,
       SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
       SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
       SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 787535d0a6bd2..7ffea6e7dba81 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -29,6 +29,7 @@ namespace spirv {
 namespace detail {
 struct ArrayTypeStorage;
 struct CooperativeMatrixTypeStorage;
+struct TensorArmTypeStorage;
 struct ImageTypeStorage;
 struct MatrixTypeStorage;
 struct PointerTypeStorage;
@@ -96,7 +97,8 @@ class ScalarType : public SPIRVType {
   std::optional<int64_t> getSizeInBytes();
 };
 
-// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
+// SPIR-V composite type: TensorArmType, VectorType, SPIR-V ArrayType, or SPIR-V
+// StructType.
 class CompositeType : public SPIRVType {
 public:
   using SPIRVType::SPIRVType;
@@ -477,6 +479,37 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
                        std::optional<StorageClass> storage = std::nullopt);
 };
 
+// SPIR-V TensorARM Type
+class TensorArmType
+    : public Type::TypeBase<TensorArmType, CompositeType,
+                            detail::TensorArmTypeStorage, ShapedType::Trait> {
+public:
+  using Base::Base;
+
+  static constexpr StringLiteral name = "spirv.arm.tensor";
+
+  // TensorArm supports minimum rank of 1, hence an empty shape here means
+  // unranked.
+  static TensorArmType get(ArrayRef<int64_t> shape, Type elementType);
+  TensorArmType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                          Type elementType) const;
+
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   ArrayRef<int64_t> shape, Type elementType);
+
+  Type getElementType() const;
+  ArrayRef<int64_t> getShape() const;
+  unsigned getNumElements() const;
+  bool hasRank() const { return !getShape().empty(); }
+  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+
+  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                     std::optional<StorageClass> storage = std::nullopt);
+  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+                       std::optional<StorageClass> storage = std::nullopt);
+};
+
 } // namespace spirv
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a21acef1c4b43..15002f1d5d16e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -194,6 +194,13 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
           << t.getNumElements();
       return Type();
     }
+  } else if (auto t = llvm::dyn_cast<TensorArmType>(type)) {
+    if (!llvm::isa<ScalarType>(t.getElementType())) {
+      parser.emitError(
+          typeLoc, "only scalar element type allowed in tensor type but found ")
+          << t.getElementType();
+      return Type();
+    }
   } else {
     parser.emitError(typeLoc, "cannot use ")
         << type << " to compose SPIR-V types";
@@ -363,6 +370,54 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
   return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
 }
 
+// tensor-arm-type ::=
+//   `!spirv.arm.tensor` `<` dim0 `x` dim1 `x` ... `x` dimN `x` element-type`>`
+static Type parseTensorArmType(SPIRVDialect const &dialect,
+                               DialectAsmParser &parser) {
+  if (parser.parseLess())
+    return {};
+
+  bool unranked = false;
+  SmallVector<int64_t, 4> dims;
+  SMLoc countLoc = parser.getCurrentLocation();
+
+  if (parser.parseOptionalStar().succeeded()) {
+    unranked = true;
+    if (parser.parseXInDimensionList())
+      return {};
+  } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true))
+    return {};
+
+  if (!unranked && dims.empty()) {
+    parser.emitError(countLoc, "arm.tensors do not support rank zero");
+    return {};
+  }
+
+  if (std::any_of(dims.begin(), dims.end(),
+                  [](int64_t dim) { return dim == 0; })) {
+    parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
+    return {};
+  }
+
+  if (std::any_of(dims.begin(), dims.end(),
+                  [](int64_t dim) { return dim < 0; }) &&
+      std::any_of(dims.begin(), dims.end(),
+                  [](int64_t dim) { return dim > 0; })) {
+    parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
+                               "fully dynamic or completed shaped");
+    return {};
+  }
+
+  auto elementTy = parseAndVerifyType(dialect, parser);
+  if (!elementTy)
+    return {};
+
+  if (parser.parseGreater())
+    return {};
+
+  return TensorArmType::get(dims, elementTy);
+}
+
 // TODO: Reorder methods to be utilities first and parse*Type
 // methods in alphabetical order
 //
@@ -759,6 +814,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
     return parseStructType(*this, parser);
   if (keyword == "matrix")
     return parseMatrixType(*this, parser);
+  if (keyword == "arm.tensor")
+    return parseTensorArmType(*this, parser);
   parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
   return Type();
 }
@@ -855,10 +912,28 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
   os << ">";
 }
 
+static void print(TensorArmType type, DialectAsmPrinter &os) {
+  os << "arm.tensor<";
+
+  llvm::interleave(
+      type.getShape(), os,
+      [&](int64_t dim) {
+        if (ShapedType::isDynamic(dim))
+          os << '?';
+        else
+          os << dim;
+      },
+      "x");
+  if (!type.hasRank()) {
+    os << "*";
+  }
+  os << "x" << type.getElementType() << ">";
+}
+
 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
       .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
-            ImageType, SampledImageType, StructType, MatrixType>(
+            ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
           [&](auto type) { print(type, os); })
       .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 7148027dae78d..eb2974d62fdd1 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -547,6 +547,12 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
       return failure();
   }
 
+  if (llvm::isa<TensorArmType>(type)) {
+    if (parser.parseOptionalColon().succeeded())
+      if (parser.parseType(type))
+        return failure();
+  }
+
   return parser.addTypeToList(type, result.types);
 }
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 93e0c9b33c546..e4eeb0a7f37d5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -18,8 +18,10 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <algorithm>
 #include <cstdint>
 #include <iterator>
+#include <numeric>
 
 using namespace mlir;
 using namespace mlir::spirv;
@@ -96,7 +98,7 @@ bool CompositeType::classof(Type type) {
     return isValid(vectorType);
   return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
                    spirv::MatrixType, spirv::RuntimeArrayType,
-                   spirv::StructType>(type);
+                   spirv::StructType, spirv::TensorArmType>(type);
 }
 
 bool CompositeType::isValid(VectorType type) {
@@ -107,8 +109,8 @@ bool CompositeType::isValid(VectorType type) {
 
 Type CompositeType::getElementType(unsigned index) const {
   return TypeSwitch<Type, Type>(*this)
-      .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType>(
-          [](auto type) { return type.getElementType(); })
+      .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType,
+            TensorArmType>([](auto type) { return type.getElementType(); })
       .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
       .Case<StructType>(
           [index](StructType type) { return type.getElementType(index); })
@@ -125,6 +127,8 @@ unsigned CompositeType::getNumElements() const {
     return structType.getNumElements();
   if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
     return vectorType.getNumElements();
+  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this))
+    return tensorArmType.getNumElements();
   if (llvm::isa<CooperativeMatrixType>(*this)) {
     llvm_unreachable(
         "invalid to query number of elements of spirv Cooperative Matrix type");
@@ -151,6 +155,14 @@ void CompositeType::getExtensions(
         return llvm::cast<ScalarType>(type.getElementType())
             .getExtensions(extensions, storage);
       })
+      .Case<TensorArmType>([&](TensorArmType type) {
+        static const Extension exts[] = {Extension::SPV_ARM_tensors};
+        ArrayRef<Extension> ref(exts, std::size(exts));
+        extensions.push_back(ref);
+        return llvm::cast<ScalarType>(type.getElementType())
+            .getExtensions(extensions, storage);
+      })
+
       .Default([](Type) { llvm_unreachable("invalid composite type"); });
 }
 
@@ -171,6 +183,13 @@ void CompositeType::getCapabilities(
         return llvm::cast<ScalarType>(type.getElementType())
             .getCapabilities(capabilities, storage);
       })
+      .Case<TensorArmType>([&](TensorArmType type) {
+        static const Capability caps[] = {Capability::TensorsARM};
+        ArrayRef<Capability> ref(caps, std::size(caps));
+        capabilities.push_back(ref);
+        return llvm::cast<ScalarType>(type.getElementType())
+            .getCapabilities(capabilities, storage);
+      })
       .Default([](Type) { llvm_unreachable("invalid composite type"); });
 }
 
@@ -186,6 +205,13 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
       return std::nullopt;
     return *elementSize * vectorType.getNumElements();
   }
+  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+    std::optional<int64_t> elementSize =
+        llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes();
+    if (!elementSize)
+      return std::nullopt;
+    return *elementSize * tensorArmType.getNumElements();
+  }
   return std::nullopt;
 }
 
@@ -691,6 +717,9 @@ bool SPIRVType::classof(Type type) {
     return true;
   if (auto vectorType = llvm::dyn_cast<VectorType>(type))
     return CompositeType::isValid(vectorType);
+  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type)) {
+    return llvm::isa<ScalarType>(tensorArmType.getElementType());
+  }
   return false;
 }
 
@@ -712,6 +741,8 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
     matrixType.getExtensions(extensions, storage);
   } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
     ptrType.getExtensions(extensions, storage);
+  } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+    tensorArmType.getExtensions(extensions, storage);
   } else {
     llvm_unreachable("invalid SPIR-V Type to getExtensions");
   }
@@ -732,6 +763,8 @@ void SPIRVType::getCapabilities(
     matrixType.getCapabilities(capabilities, storage);
   } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
     ptrType.getCapabilities(capabilities, storage);
+  } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+    tensorArmType.getCapabilities(capabilities, storage);
   } else {
     llvm_unreachable("invalid SPIR-V Type to getCapabilities");
   }
@@ -1203,11 +1236,94 @@ void MatrixType::getCapabilities(
   llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage);
 }
 
+//===----------------------------------------------------------------------===//
+// TensorArmType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::TensorArmTypeStorage final : TypeStorage {
+  using KeyTy = std::tuple<ArrayRef<int64_t>, Type>;
+
+  static TensorArmTypeStorage *construct(TypeStorageAllocator &allocator,
+                                         const KeyTy &key) {
+    auto shape = std::get<0>(key);
+    auto elementType = std::get<1>(key);
+    shape = allocator.copyInto(shape);
+    return new (allocator.allocate<TensorArmTypeStorage>())
+        TensorArmTypeStorage(std::move(shape), std::move(elementType));
+  }
+
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return llvm::hash_combine(std::get<0>(key), std::get<1>(key));
+  }
+
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(shape, elementType);
+  }
+
+  TensorArmTypeStorage(ArrayRef<int64_t> shape, Type elementType)
+      : shape(std::move(shape)), elementType(std::move(elementType)) {}
+
+  ArrayRef<int64_t> shape;
+  Type elementType;
+};
+
+TensorArmType TensorArmType::get(ArrayRef<int64_t> shape, Type elementType) {
+  return Base::get(elementType.getContext(), shape, elementType);
+}
+
+TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                                       Type elementType) const {
+  return TensorArmType::get(shape.value_or(getShape()), elementType);
+}
+
+Type TensorArmType::getElementType() const { return getImpl()->elementType; }
+ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
+
+unsigned TensorArmType::getNumElements() const {
+  auto shape = getShape();
+  return std::accumulate(shape.begin(), shape.end(), unsigned(1),
+                         std::multiplies<unsigned>());
+}
+
+void TensorArmType::getExtensions(
+    SPIRVType::ExtensionArrayRefVector &extensions,
+    std::optional<StorageClass> storage) {
+
+  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
+  static constexpr Extension exts[] = {Extension::SPV_ARM_tensors};
+  extensions.push_back(exts);
+}
+
+void TensorArmType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    std::optional<StorageClass> storage) {
+  llvm::cast<SPIRVType>(getElementType())
+      .getCapabilities(capabilities, storage);
+  static constexpr Capab...
[truncated]

llvmbot (Member) commented Jun 18, 2025

@llvm/pr-subscribers-mlir


davidegrohmann changed the title from "Add support for SPV_ARM_tensors" to "[mlir][spv] Add support for SPV_ARM_tensors" on Jun 18, 2025
davidegrohmann changed the title from "[mlir][spv] Add support for SPV_ARM_tensors" to "[mlir][spirv] Add support for SPV_ARM_tensors" on Jun 18, 2025
IgWod-IMG (Contributor) left a comment

Thanks for contributing this. I added a few comments.

I pointed it out a few times, but I believe you don't need llvm:: with casts, etc. It's an artifact of refactoring that was done in the past: #141458 (comment)

Also, we probably should add a test in here: https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/SPIRV/IR/availability.mlir

@@ -4179,7 +4202,7 @@ def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;

def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;
Contributor

meta: I know it comes from the spec so probably should stay this way, but ArmTensor sounds more natural than TensorArm.

@@ -477,6 +479,37 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
std::optional<StorageClass> storage = std::nullopt);
};

// SPIR-V TensorARM Type
Contributor

Here and below it probably should be /// instead of //.

@@ -194,6 +194,13 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
<< t.getNumElements();
return Type();
}
} else if (auto t = llvm::dyn_cast<TensorArmType>(type)) {
if (!llvm::isa<ScalarType>(t.getElementType())) {
Contributor

I don't think there is a need for llvm::. Also line above.

@@ -125,6 +127,8 @@ unsigned CompositeType::getNumElements() const {
return structType.getNumElements();
if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
return vectorType.getNumElements();
if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this))
Contributor

No need for llvm::... I think :) Also below.

TensorArmTypeStorage(ArrayRef<int64_t> shape, Type elementType)
: shape(std::move(shape)), elementType(std::move(elementType)) {}

ArrayRef<int64_t> shape;
IgWod-IMG (Contributor) commented Jun 18, 2025

This looks suspicious to me. I don't think an ArrayRef should be a storage type, since the ArrayRef itself doesn't keep any data; it's only a "pointer" to data. I think this should be a SmallVector or something similar. But please wait for someone to confirm that I'm not wrong :)
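
For context, a minimal sketch of the uniquing pattern in question, assuming MLIR's TypeStorageAllocator API: copyInto copies the key's data into context-owned memory and returns an ArrayRef over that stable copy, which is what makes a non-owning ArrayRef member viable here.

```c++
// Sketch only: how the patch keeps an array-valued key in the storage.
static TensorArmTypeStorage *construct(TypeStorageAllocator &allocator,
                                       const KeyTy &key) {
  // copyInto allocates in the MLIRContext and copies the data; the returned
  // ArrayRef aliases that long-lived copy, not the caller's buffer.
  ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
  return new (allocator.allocate<TensorArmTypeStorage>())
      TensorArmTypeStorage(shape, std::get<1>(key));
}
```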

@@ -138,6 +138,7 @@ LogicalResult spirv::Deserializer::processHeader() {
MIN_VERSION_CASE(3);
MIN_VERSION_CASE(4);
MIN_VERSION_CASE(5);
MIN_VERSION_CASE(6);
Contributor

This probably should land as a separate PR.

@@ -1238,6 +1241,55 @@ spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
return success();
}

LogicalResult
spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
Contributor

return success();
}

auto rankAttr = getConstantInt(operands[2]);
Contributor

return success();
}

auto shapeInfo = getConstant(operands[3]);
Contributor

Same as above: auto.

@@ -96,7 +97,8 @@ class ScalarType : public SPIRVType {
std::optional<int64_t> getSizeInBytes();
};

// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
// SPIR-V composite type: TensorArmType, VectorType, SPIR-V ArrayType, or SPIR-V
Member

Can you put the tensor type at the end, since it's going to be the least commonly used one?

unranked = true;
if (parser.parseXInDimensionList())
return {};
} else if (parser.parseDimensionList(dims, /*allowDynamic=*/true))
Member

The LLVM coding style requires that when either branch of an if/else uses braces, the other should too.
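
Applied to the hunk above, the braced form would read:

```c++
if (parser.parseOptionalStar().succeeded()) {
  unranked = true;
  if (parser.parseXInDimensionList())
    return {};
} else if (parser.parseDimensionList(dims, /*allowDynamic=*/true)) {
  return {};
}
```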

Comment on lines +396 to +397
if (std::any_of(dims.begin(), dims.end(),
[](int64_t dim) { return dim == 0; })) {
Member

use llvm::is_contained
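
That is, something like (llvm::is_contained comes from llvm/ADT/STLExtras.h):

```c++
if (llvm::is_contained(dims, 0)) {
  parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
  return {};
}
```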

Comment on lines +402 to +405
if (std::any_of(dims.begin(), dims.end(),
[](int64_t dim) { return dim < 0; }) &&
std::any_of(dims.begin(), dims.end(),
[](int64_t dim) { return dim > 0; })) {
Member

use llvm::any_of
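
That is, roughly (keeping the patch's diagnostic text):

```c++
// llvm::any_of takes the range directly instead of an iterator pair.
bool anyDynamic = llvm::any_of(dims, [](int64_t dim) { return dim < 0; });
bool anyStatic = llvm::any_of(dims, [](int64_t dim) { return dim > 0; });
if (anyDynamic && anyStatic) {
  parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
                             "fully dynamic or completed shaped");
  return {};
}
```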


Type getElementType() const;
ArrayRef<int64_t> getShape() const;
unsigned getNumElements() const;
Member

What happens when this is called on an unranked tensor?
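
For reference, a self-contained sketch of what the implementation shown further down would compute for the empty-shape (unranked) encoding: std::accumulate over an empty range returns its initial value, so the call reports exactly 1 element.

```c++
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> shape; // empty shape, i.e. the unranked encoding
  unsigned numElements = std::accumulate(
      shape.begin(), shape.end(), unsigned(1), std::multiplies<unsigned>());
  // numElements == 1 here, which is arguably misleading for an unranked
  // tensor; dynamic dimensions (negative sentinels) would likewise wrap
  // through the unsigned multiply.
  return static_cast<int>(numElements);
}
```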

ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }

unsigned TensorArmType::getNumElements() const {
auto shape = getShape();
Member

don't use auto here since the type is not obvious based on the RHS alone: https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable
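
Spelled out, the declaration would read:

```c++
ArrayRef<int64_t> shape = getShape();
```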

std::optional<StorageClass> storage) {

llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
static constexpr Extension exts[] = {Extension::SPV_ARM_tensors};
Member

this doesn't have to be an array

std::optional<StorageClass> storage) {
llvm::cast<SPIRVType>(getElementType())
.getCapabilities(capabilities, storage);
static constexpr Capability caps[] = {Capability::TensorsARM};
Member

also here

Type elementTy = getType(operands[1]);
if (!elementTy) {
return emitError(unknownLoc,
"OpTypeTensorARM references undefined element type.")
Member

I don't think we want to print the . before the operand?

if (size < 2 || size > 4) {
return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
"(result_id, element_type, (rank), (shape))")
<< size;
Member

should we add a space before printing the size?
