From c17d62d6be3d035cc82383b4633cb26227009274 Mon Sep 17 00:00:00 2001 From: Adam Siemieniuk Date: Mon, 28 Jul 2025 14:25:51 +0200 Subject: [PATCH 1/4] [mlir][amx] Vector to AMX conversion pass Adds a pass for Vector to AMX operation conversion. Initially, a direct rewrite for vector contraction in packed VNNI layout is supported. Operations are expected to already be in shapes which are AMX-compatible for the rewriting to occur. --- mlir/include/mlir/Conversion/Passes.h | 1 + mlir/include/mlir/Conversion/Passes.td | 13 + .../mlir/Conversion/VectorToAMX/VectorToAMX.h | 26 ++ mlir/lib/Conversion/CMakeLists.txt | 1 + .../lib/Conversion/VectorToAMX/CMakeLists.txt | 19 ++ .../Conversion/VectorToAMX/VectorToAMX.cpp | 287 +++++++++++++++++ .../VectorToAMX/contract-to-amx.mlir | 291 ++++++++++++++++++ 7 files changed, 638 insertions(+) create mode 100644 mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h create mode 100644 mlir/lib/Conversion/VectorToAMX/CMakeLists.txt create mode 100644 mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp create mode 100644 mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 3dc48b2201cf2..91b2ecf8922a3 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -75,6 +75,7 @@ #include "mlir/Conversion/TosaToTensor/TosaToTensor.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h" +#include "mlir/Conversion/VectorToAMX/VectorToAMX.h" #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h" #include "mlir/Conversion/VectorToGPU/VectorToGPU.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index cf7596cc8a928..20ead98acc371 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1515,6 +1515,19 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> { ]; } +//===----------------------------------------------------------------------===// +// VectorToAMX +//===----------------------------------------------------------------------===// + +def ConvertVectorToAMX : Pass<"convert-vector-to-amx"> { + let summary = "Lower the operations from the vector dialect into the AMX " + "dialect"; + let dependentDialects = [ + "affine::AffineDialect", "amx::AMXDialect", "arith::ArithDialect", + "memref::MemRefDialect", "scf::SCFDialect", "vector::VectorDialect" + ]; +} + //===----------------------------------------------------------------------===// // XeVMToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h new file mode 100644 index 0000000000000..b075ac92990a2 --- /dev/null +++ b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h @@ -0,0 +1,26 @@ +//===- VectorToAMX.h - Convert vector to AMX dialect ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H
+#define MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTVECTORTOAMX
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Collect a set of patterns to convert from the vector to AMX ops.
+void populateVectorToAMXConversionPatterns(RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 785cb8293810c..171f7169fd41d 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -68,6 +68,7 @@ add_subdirectory(TosaToSCF)
 add_subdirectory(TosaToTensor)
 add_subdirectory(UBToLLVM)
 add_subdirectory(UBToSPIRV)
+add_subdirectory(VectorToAMX)
 add_subdirectory(VectorToArmSME)
 add_subdirectory(VectorToGPU)
 add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
new file mode 100644
index 0000000000000..2d4b2b6e9283c
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRVectorToAMX
+  VectorToAMX.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToAMX
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRAMXDialect
+  MLIRAffineUtils
+  MLIRArithDialect
+  MLIRLinalgUtils
+  MLIRMemRefDialect
+  MLIRSCFDialect
+  MLIRTransforms
+  MLIRVectorDialect
+  )
diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
new file mode 100644
index 0000000000000..fc24275a1467c
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -0,0 +1,287 @@
+//===- VectorToXeGPU.cpp - Convert vector to AMX dialect --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of vector operations to XeGPU dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <numeric>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTVECTORTOAMX
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Return true if vector shape is compatible with AMX tiles.
+/// The validation accounts for VNNI packing.
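+///
+/// For example (illustrative values only, not taken from this patch): a
+/// vector<4x8x2xf16> input has vnniFactor = 32 / 16 = 2 and is treated as a
+/// 4x16 tile of f16 elements, which fits within the 16-row by 64-byte AMX
+/// tile limit checked below.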
+static bool verifyAmxShape(VectorType vec) {
+  // Check overall shape:
+  //   - 2D for plain layout input or output
+  //   - 3D for VNNI packed input
+  if (vec.getRank() != 2 && vec.getRank() != 3)
+    return false;
+
+  ArrayRef<int64_t> shape = vec.getShape();
+  int64_t rows = shape[0];
+  int64_t cols = shape[1];
+  unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth();
+
+  // 3D shape indicates VNNI packed layout.
+  if (vec.getRank() == 3) {
+    int64_t vnniFactor = 32 / elemBitWidth;
+    if (shape.back() != vnniFactor)
+      return false;
+    cols *= vnniFactor;
+  }
+
+  // AMX tile supports up to 16 rows of 64 bytes each.
+  constexpr unsigned maxRows = 16;
+  constexpr unsigned maxBitsPerRow = 64 * 8;
+  return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
+}
+
+/// Checks if contraction operands are in AMX-compatible packed VNNI layout.
+static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter,
+                                     vector::ContractionOp contractOp) {
+  VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
+  if (!accType || accType.getRank() != 2)
+    return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+
+  // Expect 3D inputs for VNNI packed data.
+  VectorType lhsType = contractOp.getLhs().getType();
+  VectorType rhsType = contractOp.getRhs().getType();
+  if (lhsType.getRank() != 3 || rhsType.getRank() != 3)
+    return rewriter.notifyMatchFailure(contractOp,
+                                       "Expects lhs and rhs 3D vectors");
+
+  // Check if shapes are compatible with AMX tile.
+  if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) ||
+      !verifyAmxShape(accType))
+    return rewriter.notifyMatchFailure(contractOp, "Invalid operand shape");
+
+  // Validate affine maps.
+  //
+  // Iterators can be ordered arbitrarily. Indexing map positions are based on
+  // operands' target shapes.
+  // The matrix layouts must match the following:
+  //   - matrix A - [M]x[K/vnniFactor]x[vnniFactor]
+  //   - matrix B - [K/vnniFactor]x[N]x[vnniFactor]
+  //   - matrix C - [M]x[N]
+  SmallVector<AffineMap, 4> indexingMaps = contractOp.getIndexingMapsArray();
+  AffineMap mapA = indexingMaps[0];
+  AffineMap mapB = indexingMaps[1];
+  if (mapA.getNumInputs() != 4 || mapA.getNumResults() != 3 ||
+      mapB.getNumResults() != 3)
+    return rewriter.notifyMatchFailure(contractOp,
+                                       "Invalid input indexing maps");
+  FailureOr<linalg::ContractionDimensions> dims =
+      linalg::inferContractionDims(indexingMaps);
+  if (failed(dims))
+    return rewriter.notifyMatchFailure(contractOp,
+                                       "Failed to infer contraction dims");
+  // Two reduction dimensions are expected:
+  //   - one for the K dimension
+  //   - one for the VNNI factor
+  if (dims->k.size() != 2)
+    return rewriter.notifyMatchFailure(contractOp,
+                                       "Expected two reduction dims");
+  assert(dims->m.size() == 1 && dims->n.size() == 1 &&
+         "Invalid parallel contraction dims");
+
+  SmallVector<vector::IteratorType> iteratorTypes =
+      contractOp.getIteratorTypesArray();
+  // Check VNNI dim maps - the innermost dim for A and B inputs.
+  auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(2));
+  auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(2));
+  if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
+      iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction)
+    return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI dim map");
+  // Check K dim maps - non-transposed row-major layout.
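+  // For example, with the maps used in the accompanying tests (one possible
+  // iterator order, not the only accepted one):
+  //   A: (m, n, k, vnni) -> (m, k, vnni)
+  //   B: (m, n, k, vnni) -> (k, n, vnni)
+  // the K dimension has to appear as result 1 of A and result 0 of B.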
+  auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(1));
+  auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(0));
+  if (!redDimA || !redDimB || redDimA != redDimB ||
+      iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction)
+    return rewriter.notifyMatchFailure(contractOp, "Invalid K dim map");
+  // Check M and N dim maps - map to non-transposed output.
+  AffineMap mapC = indexingMaps[2];
+  auto mDimC = dyn_cast<AffineDimExpr>(mapC.getResult(0));
+  auto nDimC = dyn_cast<AffineDimExpr>(mapC.getResult(1));
+  if (!mDimC || !nDimC)
+    return rewriter.notifyMatchFailure(contractOp, "Invalid acc maps");
+  auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.getResult(0));
+  if (!parallelDimA ||
+      iteratorTypes[parallelDimA.getPosition()] !=
+          vector::IteratorType::parallel ||
+      parallelDimA != mDimC)
+    return rewriter.notifyMatchFailure(contractOp, "Invalid M dim map");
+  auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(1));
+  if (!parallelDimB ||
+      iteratorTypes[parallelDimB.getPosition()] !=
+          vector::IteratorType::parallel ||
+      parallelDimB != nDimC)
+    return rewriter.notifyMatchFailure(contractOp, "Invalid N dim map");
+
+  return success();
+}
+
+/// Validate contraction operands for AMX lowering.
+static LogicalResult validateOperands(PatternRewriter &rewriter,
+                                      vector::ContractionOp contractOp) {
+  VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
+  if (!accType)
+    return rewriter.notifyMatchFailure(contractOp, "Expects vector acc");
+
+  // Check if operand types are compatible with AMX compute ops.
+  bool validElemTypes = false;
+  Type lhsElemType = contractOp.getLhs().getType().getElementType();
+  Type rhsElemType = contractOp.getRhs().getType().getElementType();
+  Type accElemType = accType.getElementType();
+  if (accElemType.isInteger(32)) {
+    validElemTypes = lhsElemType.isInteger(8) && rhsElemType.isInteger(8);
+  } else if (accElemType.isF32()) {
+    validElemTypes = (lhsElemType.isF16() && rhsElemType.isF16()) ||
+                     (lhsElemType.isBF16() && rhsElemType.isBF16());
+  }
+  if (!validElemTypes)
+    return rewriter.notifyMatchFailure(contractOp,
+                                       "Invalid combination of operand types");
+
+  if (failed(isAmxVnniLayout(rewriter, contractOp)))
+    return failure();
+
+  return success();
+}
+
+/// Collapses the two innermost dimensions together.
+static Value collapseLastDim(PatternRewriter &rewriter,
+                             TypedValue<MemRefType> memref) {
+  int64_t rank = memref.getType().getRank();
+  SmallVector<ReassociationIndices> reassocIndices;
+  for (auto i : llvm::seq<int64_t>(0, rank - 2))
+    reassocIndices.push_back({i});
+  reassocIndices.push_back({rank - 2, rank - 1});
+  return memref::CollapseShapeOp::create(rewriter, memref.getLoc(), memref,
+                                         reassocIndices);
+}
+
+/// Loads vector values to an AMX tile.
+static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
+                                          TypedValue<VectorType> vec) {
+  Location loc = vec.getLoc();
+  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+
+  // Transfer the vector to a tile through an intermediate buffer.
+  VectorType vecTy = vec.getType();
+  Value buf = memref::AllocaOp::create(
+      rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
+  SmallVector<Value> indices(vecTy.getRank(), zeroIndex);
+  vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);
+
+  // Collapse the VNNI dimension in case of packing.
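+  // For example, a packed buffer of type memref<4x8x2xf16> is collapsed
+  // into memref<4x16xf16> so it can be loaded as a single 2D AMX tile.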
+  bool isPacked = vecTy.getRank() == 3;
+  if (isPacked)
+    buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf));
+
+  ArrayRef<int64_t> shape = vecTy.getShape();
+  int64_t rows = shape[0];
+  int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), 1,
+                                 std::multiplies<int64_t>());
+  auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
+
+  return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
+                                 {zeroIndex, zeroIndex});
+}
+
+/// Stores an AMX tile in a vector.
+static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
+                                        TypedValue<amx::TileType> tile) {
+  Location loc = tile.getLoc();
+  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+
+  // Transfer the tile to a vector through an intermediate buffer.
+  amx::TileType tileTy = tile.getType();
+  Value buf = memref::AllocaOp::create(
+      rewriter, loc,
+      MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
+  SmallVector<Value> indices(2, zeroIndex);
+  amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
+
+  auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
+  return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {});
+}
+
+struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = contractOp.getLoc();
+
+    if (contractOp.getKind() != vector::CombiningKind::ADD)
+      return rewriter.notifyMatchFailure(contractOp,
+                                         "Expects add combining kind");
+    if (failed(validateOperands(rewriter, contractOp)))
+      return failure();
+
+    TypedValue<amx::TileType> lhsTile = loadTile(rewriter, contractOp.getLhs());
+    TypedValue<amx::TileType> rhsTile = loadTile(rewriter, contractOp.getRhs());
+    auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
+    assert(acc && "Invalid accumulator type");
+    TypedValue<amx::TileType> accTile = loadTile(rewriter, acc);
+
+    TypedValue<amx::TileType> tileMul;
+    if (acc.getType().getElementType().isFloat()) {
+      tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
+                                        lhsTile, rhsTile, accTile);
+    } else {
+      tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
+                                        lhsTile, rhsTile, accTile);
+    }
+
+    Value res = storeTile(rewriter, tileMul);
+    rewriter.replaceOp(contractOp, res);
+
+    return success();
+  }
+};
+
+struct ConvertVectorToAMXPass
+    : public impl::ConvertVectorToAMXBase<ConvertVectorToAMXPass> {
+  void runOnOperation() override {
+    MLIRContext &ctx = getContext();
+    RewritePatternSet patterns(&ctx);
+    populateVectorToAMXConversionPatterns(patterns);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+void mlir::populateVectorToAMXConversionPatterns(RewritePatternSet &patterns) {
+  patterns.add<ContractionToAMX>(patterns.getContext());
+}
diff --git a/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir b/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
new file mode 100644
index 0000000000000..ad23964a15dd2
--- /dev/null
+++ b/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
@@ -0,0 +1,291 @@
+// RUN: mlir-opt %s -convert-vector-to-amx -split-input-file | FileCheck %s
+
+/// VNNI format is Intel's packed data layout.
+/// For matrix multiplication, elements from the reduction dimension `k`
+/// are packed into 32-bit tuples. Then the appropriate AMX operations can
+/// perform tile multiplication directly on the packed data.
+///
+/// These packed elements are represented in the indexing maps by a separate
+/// reduction dimension `vnni`.
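+///
+/// For example (illustrative values), f16 elements are packed in pairs
+/// (vnniFactor = 32 / 16 = 2), so a logical 4x16xf16 A matrix is passed as
+/// vector<4x8x2xf16> and a logical 16x16xf16 B matrix as vector<8x16x2xf16>,
+/// as in the first test below.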
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @contract_vnni_f16(%A: vector<4x8x2xf16>, %B: vector<8x16x2xf16>,
+    %C: vector<4x16xf32>) -> vector<4x16xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @contract_vnni_f16(
+// CHECK-SAME: %[[A:.+]]: vector<4x8x2xf16>,
+// CHECK-SAME: %[[B:.+]]: vector<8x16x2xf16>,
+// CHECK-SAME: %[[C:.+]]: vector<4x16xf32>
+
+/// AMX hardware has no direct access to vector registers. Thus, data must
+/// be transferred through intermediate buffers.
+///
+/// Load A vector into an AMX tile
+// CHECK: %[[A_BUF:.+]] = memref.alloca() : memref<4x8x2xf16>
+// CHECK: vector.transfer_write %[[A]], %[[A_BUF]]
+// CHECK: %[[A_BUF_2D:.+]] = memref.collapse_shape %[[A_BUF]]
+// CHECK-SAME: {{\[}}[0], [1, 2]] : memref<4x8x2xf16> into memref<4x16xf16>
+// CHECK: %[[A_TILE:.+]] = amx.tile_load %[[A_BUF_2D]]
+
+/// Load B vector into an AMX tile
+// CHECK: %[[B_BUF:.+]] = memref.alloca() : memref<8x16x2xf16>
+// CHECK: vector.transfer_write %[[B]], %[[B_BUF]]
+// CHECK: %[[B_BUF_2D:.+]] = memref.collapse_shape %[[B_BUF]]
+// CHECK-SAME: {{\[}}[0], [1, 2]] : memref<8x16x2xf16> into memref<8x32xf16>
+// CHECK: %[[B_TILE:.+]] = amx.tile_load %[[B_BUF_2D]]
+
+/// Load C vector into an AMX tile
+// CHECK: %[[C_BUF:.+]] = memref.alloca() : memref<4x16xf32>
+// CHECK: vector.transfer_write %[[C]], %[[C_BUF]]
+// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_BUF]]
+
+/// Perform tile multiplication
+// CHECK: %[[RES:.+]] = amx.tile_mulf
+// CHECK-SAME: %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
+
+/// Load the result back into a vector
+// CHECK: %[[RES_BUF:.+]] = memref.alloca() : memref<4x16xf32>
+// CHECK: amx.tile_store %[[RES_BUF]]{{.*}}, %[[RES]]
+// CHECK: %[[RES_VEC:.+]] = vector.transfer_read %[[RES_BUF]]
+
+// CHECK: return %[[RES_VEC]]
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @contract_vnni_bf16(%A: vector<4x8x2xbf16>, %B: vector<8x16x2xbf16>,
+    %C: vector<4x16xf32>) -> vector<4x16xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x8x2xbf16>, vector<8x16x2xbf16> into vector<4x16xf32>
+  return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @contract_vnni_bf16(
+// CHECK-COUNT-3: amx.tile_load
+// CHECK: amx.tile_mulf
+// CHECK: amx.tile_store
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @contract_vnni_i8(%A: vector<4x16x4xi8>, %B: vector<16x8x4xi8>,
+    %C: vector<4x8xi32>) -> vector<4x8xi32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x16x4xi8>, vector<16x8x4xi8> into vector<4x8xi32>
+  return %0 : vector<4x8xi32>
+}
+
+// CHECK-LABEL: @contract_vnni_i8(
+// CHECK-COUNT-3: amx.tile_load
+// CHECK: amx.tile_muli
+// CHECK: amx.tile_store
+
+// -----
+
+#map = affine_map<(vnni, m, k, n) -> (m, k, vnni)>
+#map1 = affine_map<(vnni, m, k, n) -> (k, n, vnni)>
+#map2 = affine_map<(vnni, m, k, n) -> (m, n)>
+func.func @contract_shuffled_iterators(%A: vector<4x16x4xi8>, %B: vector<16x8x4xi8>,
+    %C: vector<4x8xi32>) -> vector<4x8xi32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "parallel", "reduction", "parallel"]}
+    %A, %B, %C : vector<4x16x4xi8>, vector<16x8x4xi8> into vector<4x8xi32>
+  return %0 : vector<4x8xi32>
+}
+
+// CHECK-LABEL: @contract_shuffled_iterators(
+// CHECK-COUNT-3: amx.tile_load
+// CHECK: amx.tile_muli
+// CHECK: amx.tile_store
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_invalid_kind(%A: vector<4x8x2xf16>, %B: vector<8x16x2xf16>,
+    %C: vector<4x16xf32>) -> vector<4x16xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<mul>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_invalid_kind(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, k, vnni) -> (k, m, vnni)>
+#map2 = affine_map<(m, k, vnni) -> ()>
+func.func @negative_non_vector_acc(%A: vector<4x8x2xf16>, %B: vector<8x4x2xf16>,
+    %C: f32) -> f32 {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["reduction", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x8x2xf16>, vector<8x4x2xf16> into f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @negative_non_vector_acc(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_invalid_operand_types(%A: vector<4x8x2xf32>, %B: vector<8x16x2xf32>,
+    %C: vector<4x16xf32>) -> vector<4x16xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x8x2xf32>, vector<8x16x2xf32> into vector<4x16xf32>
+  return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_invalid_operand_types(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k) -> (m, k)>
+#map1 = affine_map<(m, n, k) -> (k, n)>
+#map2 = affine_map<(m, n, k) -> (m, n)>
+func.func @negative_non_packed_layout(%A: vector<4x16xf16>, %B: vector<16x16xf16>,
+    %C: vector<4x16xf32>) -> vector<4x16xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"]}
+    %A, %B, %C : vector<4x16xf16>, vector<16x16xf16> into vector<4x16xf32>
+  return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_non_packed_layout(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_invalid_vnni_factor(%A: vector<4x2x4xf16>, %B: vector<2x2x4xf16>,
+    %C: vector<4x2xf32>) -> vector<4x2xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x2x4xf16>, vector<2x2x4xf16> into vector<4x2xf32>
+  return %0 : vector<4x2xf32>
+}
+
+// CHECK-LABEL: @negative_invalid_vnni_factor(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_too_many_rows(%A: vector<32x8x2xf16>, %B: vector<8x16x2xf16>,
+    %C: vector<32x16xf32>) -> vector<32x16xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<32x8x2xf16>, vector<8x16x2xf16> into vector<32x16xf32>
+  return %0 : vector<32x16xf32>
}
+
+// CHECK-LABEL: @negative_too_many_rows(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_too_wide_rows(%A: vector<4x32x2xf16>, %B: vector<32x16x2xf16>,
+    %C: vector<4x16xf32>) -> vector<4x16xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x32x2xf16>, vector<32x16x2xf16> into vector<4x16xf32>
+  return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_too_wide_rows(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (k, vnni, m)>
+#map1 = affine_map<(m, n, k, vnni) -> (n, k, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_input_dim_permutation(%A: vector<2x2x2xf16>,
+    %B: vector<2x2x2xf16>, %C: vector<2x2xf32>) -> vector<2x2xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<2x2x2xf16>, vector<2x2x2xf16> into vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
+
+// CHECK-LABEL: @negative_input_dim_permutation(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (n, m)>
+func.func @negative_output_dim_permutation(%A: vector<4x8x2xf16>,
+    %B: vector<8x16x2xf16>, %C: vector<16x4xf32>) -> vector<16x4xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<16x4xf32>
+  return %0 : vector<16x4xf32>
+}
+
+// CHECK-LABEL: @negative_output_dim_permutation(
+// CHECK-NOT: amx
+// CHECK: vector.contract
From c909b55c24c7d01cc5c4d4732cfd1a5972797926 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk
Date: Tue, 29 Jul 2025 12:46:21 +0200
Subject: [PATCH 2/4] Extra test case

---
 .../VectorToAMX/contract-to-amx.mlir | 23 +++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir b/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
index ad23964a15dd2..4fb88dd165126 100644
--- a/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
+++ b/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
@@ -162,8 +162,8 @@ func.func @negative_non_vector_acc(%A: vector<4x8x2xf16>, %B: vector<8x4x2xf16>,
 #map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
 #map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
 #map2 = affine_map<(m, n, k, vnni) -> (m, n)>
-func.func @negative_invalid_operand_types(%A: vector<4x8x2xf32>, %B: vector<8x16x2xf32>,
-    %C: vector<4x16xf32>) -> vector<4x16xf32> {
+func.func @negative_invalid_operand_types(%A: vector<4x8x2xf32>,
+    %B: vector<8x16x2xf32>, %C: vector<4x16xf32>) -> vector<4x16xf32> {
   %0 = vector.contract
     {kind = #vector.kind<add>,
     indexing_maps = [#map, #map1, #map2],
@@ -216,6 +216,25 @@ func.func @negative_invalid_vnni_factor(%A: vector<4x2x4xf16>, %B: vector<2x2x4xf16>,
 
 // -----
 
+#map = affine_map<(batch, m, n, k, vnni) -> (batch, m, k, vnni)>
+#map1 = affine_map<(batch, m, n, k, vnni) -> (batch, k, n, vnni)>
+#map2 = affine_map<(batch, m, n, k, vnni) -> (batch, m, n)>
+func.func @negative_invalid_operands_shapes(%A: vector<1x4x8x2xf16>,
+    %B: vector<1x8x16x2xf16>, %C: vector<1x4x16xf32>) -> vector<1x4x16xf32> {
+  %0 = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
+    %A, %B, %C : vector<1x4x8x2xf16>, vector<1x8x16x2xf16> into vector<1x4x16xf32>
+  return %0 : vector<1x4x16xf32>
+}
+
+// CHECK-LABEL: @negative_invalid_operands_shapes(
+// CHECK-NOT: amx
+// CHECK: vector.contract
+
+// -----
+
 #map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
 #map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
 #map2 = affine_map<(m, n, k, vnni) -> (m, n)>
From 6f32572c2de010052d9f771a012437ff9d6c6b1a Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk
Date: Tue, 29 Jul 2025 21:50:26 +0200
Subject: [PATCH 3/4] Fix typo

---
 mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
index fc24275a1467c..bc22a62a788a0 100644
--- a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -1,4 +1,4 @@
-//===- VectorToXeGPU.cpp - Convert vector to AMX dialect --------*- C++ -*-===//
+//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements lowering of vector operations to XeGPU dialect ops.
+// This file implements lowering of vector operations to AMX dialect ops.
 //
 //===----------------------------------------------------------------------===//
From cd999e481b10d6426b581b33bedc68c9f8d6b3cb Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk
Date: Tue, 29 Jul 2025 21:56:30 +0200
Subject: [PATCH 4/4] Simplify docs

---
 mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
index bc22a62a788a0..a11e9b2624300 100644
--- a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -5,10 +5,6 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-//
-// This file implements lowering of vector operations to AMX dialect ops.
-// -//===----------------------------------------------------------------------===// #include "mlir/Conversion/VectorToAMX/VectorToAMX.h"