[mlir][amx] Vector to AMX conversion pass #151121
Conversation
Adds a pass for Vector to AMX operation conversion. Initially, a direct rewrite of vector contraction in packed VNNI layout is supported. Operations are expected to already be in AMX-compatible shapes for the rewrite to occur.
cc: @arun-thmn
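For context, a minimal example of the kind of contraction the pass matches, mirroring the bf16 test case added in this patch (shapes are assumed to fit within a single AMX tile) and exercised with mlir-opt -convert-vector-to-amx:

#map  = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
func.func @contract_vnni_bf16(%A: vector<4x8x2xbf16>, %B: vector<8x16x2xbf16>,
    %C: vector<4x16xf32>) -> vector<4x16xf32> {
  // The innermost dim of A and B carries the VNNI factor (32 / 16 = 2 for bf16).
  %0 = vector.contract
    {kind = #vector.kind<add>,
     indexing_maps = [#map, #map1, #map2],
     iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
    %A, %B, %C : vector<4x8x2xbf16>, vector<8x16x2xbf16> into vector<4x16xf32>
  return %0 : vector<4x16xf32>
}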
@llvm/pr-subscribers-mlir
Author: Adam Siemieniuk (adam-smnk)
Changes
Adds a pass for Vector to AMX operation conversion. Initially, a direct rewrite of vector contraction in packed VNNI layout is supported. Operations are expected to already be in AMX-compatible shapes for the rewrite to occur.
Patch is 26.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/151121.diff
7 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 3dc48b2201cf2..91b2ecf8922a3 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -75,6 +75,7 @@
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
+#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cf7596cc8a928..20ead98acc371 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1515,6 +1515,19 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
];
}
+//===----------------------------------------------------------------------===//
+// VectorToAMX
+//===----------------------------------------------------------------------===//
+
+def ConvertVectorToAMX : Pass<"convert-vector-to-amx"> {
+ let summary = "Lower the operations from the vector dialect into the AMX "
+ "dialect";
+ let dependentDialects = [
+ "affine::AffineDialect", "amx::AMXDialect", "arith::ArithDialect",
+ "memref::MemRefDialect", "scf::SCFDialect", "vector::VectorDialect"
+ ];
+}
+
//===----------------------------------------------------------------------===//
// XeVMToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
new file mode 100644
index 0000000000000..b075ac92990a2
--- /dev/null
+++ b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
@@ -0,0 +1,26 @@
+//===- VectorToAMX.h - Convert vector to AMX dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H
+#define MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTVECTORTOAMX
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Collect a set of patterns to convert from the vector to AMX ops.
+void populateVectorToAMXConversionPatterns(RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOAMX_VECTORTOAMX_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 785cb8293810c..171f7169fd41d 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -68,6 +68,7 @@ add_subdirectory(TosaToSCF)
add_subdirectory(TosaToTensor)
add_subdirectory(UBToLLVM)
add_subdirectory(UBToSPIRV)
+add_subdirectory(VectorToAMX)
add_subdirectory(VectorToArmSME)
add_subdirectory(VectorToGPU)
add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
new file mode 100644
index 0000000000000..2d4b2b6e9283c
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRVectorToAMX
+ VectorToAMX.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToAMX
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRAMXDialect
+ MLIRAffineUtils
+ MLIRArithDialect
+ MLIRLinalgUtils
+ MLIRMemRefDialect
+ MLIRSCFDialect
+ MLIRTransforms
+ MLIRVectorDialect
+ )
diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
new file mode 100644
index 0000000000000..fc24275a1467c
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -0,0 +1,287 @@
+//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of vector operations to AMX dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <numeric>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTVECTORTOAMX
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Return true if vector shape is compatible with AMX tiles.
+/// The validation accounts for VNNI packing.
+static bool verifyAmxShape(VectorType vec) {
+ // Check overall shape:
+ // - 2D for plain layout input or output
+ // - 3D for VNNI packed input
+ if (vec.getRank() != 2 && vec.getRank() != 3)
+ return false;
+
+ ArrayRef<int64_t> shape = vec.getShape();
+ int64_t rows = shape[0];
+ int64_t cols = shape[1];
+ unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth();
+
+ // 3D shape indicates VNNI packed layout.
+ if (vec.getRank() == 3) {
+ int64_t vnniFactor = 32 / elemBitWidth;
+ if (shape.back() != vnniFactor)
+ return false;
+ cols *= vnniFactor;
+ }
+
+ // AMX tile supports up to 16 rows of 64 bytes each.
+ constexpr unsigned maxRows = 16;
+ constexpr unsigned maxBitsPerRow = 64 * 8;
+ return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
+}
+
+/// Checks if contraction operands are in AMX-compatible packed VNNI layout.
+static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter,
+ vector::ContractionOp contractOp) {
+ VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
+ if (!accType || accType.getRank() != 2)
+ return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+
+ // Expect 3D inputs for VNNI packed data.
+ VectorType lhsType = contractOp.getLhs().getType();
+ VectorType rhsType = contractOp.getRhs().getType();
+ if (lhsType.getRank() != 3 || rhsType.getRank() != 3)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects lhs and rhs 3D vectors");
+
+ // Check if shapes are compatible with AMX tile.
+ if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) ||
+ !verifyAmxShape(accType))
+ return rewriter.notifyMatchFailure(contractOp, "Invalid operand shape");
+
+ // Validate affine maps.
+ //
+ // Iterators can be ordered arbitrarily. Indexing map positions are based on
+ // operands' target shapes.
+ // The matrix layouts must match the following:
+ // - matrix A - [M]x[K/vnniFactor]x[vnniFactor]
+ // - matrix B - [K/vnniFactor]x[N]x[vnniFactor]
+ // - matrix C - [M]x[N]
+ SmallVector<AffineMap, 4> indexingMaps = contractOp.getIndexingMapsArray();
+ AffineMap mapA = indexingMaps[0];
+ AffineMap mapB = indexingMaps[1];
+ if (mapA.getNumInputs() != 4 || mapA.getNumResults() != 3 ||
+ mapB.getNumResults() != 3)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Invalid input indexing maps");
+ FailureOr<linalg::ContractionDimensions> dims =
+ linalg::inferContractionDims(indexingMaps);
+ if (failed(dims))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Failed to infer contraction dims");
+ // Two reduction dimensions are expected:
+ // - one for the K dimension
+ // - one for the VNNI factor
+ if (dims->k.size() != 2)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expected two reduction dims");
+ assert(dims->m.size() == 1 && dims->n.size() == 1 &&
+ "Invalid parallel contraction dims");
+
+ SmallVector<vector::IteratorType> iteratorTypes =
+ contractOp.getIteratorTypesArray();
+ // Check VNNI dim maps - the innermost dim for A and B inputs.
+ auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(2));
+ auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(2));
+ if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
+ iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI dim map");
+ // Check K dim maps - non-transposed row-major layout.
+ auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(1));
+ auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(0));
+ if (!redDimA || !redDimB || redDimA != redDimB ||
+ iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid K dim map");
+ // Check M and N dim maps - map to non-transposed output.
+ AffineMap mapC = indexingMaps[2];
+ auto mDimC = dyn_cast<AffineDimExpr>(mapC.getResult(0));
+ auto nDimC = dyn_cast<AffineDimExpr>(mapC.getResult(1));
+ if (!mDimC || !nDimC)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid acc maps");
+ auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.getResult(0));
+ if (!parallelDimA ||
+ iteratorTypes[parallelDimA.getPosition()] !=
+ vector::IteratorType::parallel ||
+ parallelDimA != mDimC)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid M dim map");
+ auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(1));
+ if (!parallelDimB ||
+ iteratorTypes[parallelDimB.getPosition()] !=
+ vector::IteratorType::parallel ||
+ parallelDimB != nDimC)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid N dim map");
+
+ return success();
+}
+
+/// Validate contraction operands for AMX lowering.
+static LogicalResult validateOperands(PatternRewriter &rewriter,
+ vector::ContractionOp contractOp) {
+ VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
+ if (!accType)
+ return rewriter.notifyMatchFailure(contractOp, "Expects vector acc");
+
+ // Check if operand types are compatible with AMX compute ops.
+ bool validElemTypes = false;
+ Type lhsElemType = contractOp.getLhs().getType().getElementType();
+ Type rhsElemType = contractOp.getRhs().getType().getElementType();
+ Type accElemType = accType.getElementType();
+ if (accElemType.isInteger(32)) {
+ validElemTypes = lhsElemType.isInteger(8) && rhsElemType.isInteger(8);
+ } else if (accElemType.isF32()) {
+ validElemTypes = (lhsElemType.isF16() && rhsElemType.isF16()) ||
+ (lhsElemType.isBF16() && rhsElemType.isBF16());
+ }
+ if (!validElemTypes)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Invalid combination of operand types");
+
+ if (failed(isAmxVnniLayout(rewriter, contractOp)))
+ return failure();
+
+ return success();
+}
+
+/// Collapses the two innermost dimensions together.
+static Value collapseLastDim(PatternRewriter &rewriter,
+ TypedValue<MemRefType> memref) {
+ int64_t rank = memref.getType().getRank();
+ SmallVector<ReassociationIndices> reassocIndices;
+ for (auto i : llvm::seq<int64_t>(0, rank - 2))
+ reassocIndices.push_back({i});
+ reassocIndices.push_back({rank - 2, rank - 1});
+ return memref::CollapseShapeOp::create(rewriter, memref.getLoc(), memref,
+ reassocIndices);
+}
+
+/// Loads vector values to an AMX tile.
+static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
+ TypedValue<VectorType> vec) {
+ Location loc = vec.getLoc();
+ Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+
+ // Transfer the vector to a tile through an intermediate buffer.
+ VectorType vecTy = vec.getType();
+ Value buf = memref::AllocaOp::create(
+ rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
+ SmallVector<Value> indices(vecTy.getRank(), zeroIndex);
+ vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);
+
+ // Collapse the VNNI dimension in case of packing.
+ bool isPacked = vecTy.getRank() == 3;
+ if (isPacked)
+ buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf));
+
+ ArrayRef<int64_t> shape = vecTy.getShape();
+ int64_t rows = shape[0];
+ int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), 1,
+ std::multiplies<int64_t>());
+ auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
+
+ return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
+ {zeroIndex, zeroIndex});
+}
+
+/// Stores an AMX tile in a vector.
+static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
+ TypedValue<amx::TileType> tile) {
+ Location loc = tile.getLoc();
+ Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+
+ // Transfer the tile to a vector through an intermediate buffer.
+ amx::TileType tileTy = tile.getType();
+ Value buf = memref::AllocaOp::create(
+ rewriter, loc,
+ MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
+ SmallVector<Value> indices(2, zeroIndex);
+ amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
+
+ auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
+ return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {});
+}
+
+struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = contractOp.getLoc();
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind");
+ if (failed(validateOperands(rewriter, contractOp)))
+ return failure();
+
+ TypedValue<amx::TileType> lhsTile = loadTile(rewriter, contractOp.getLhs());
+ TypedValue<amx::TileType> rhsTile = loadTile(rewriter, contractOp.getRhs());
+ auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
+ assert(acc && "Invalid accumulator type");
+ TypedValue<amx::TileType> accTile = loadTile(rewriter, acc);
+
+ TypedValue<amx::TileType> tileMul;
+ if (acc.getType().getElementType().isFloat()) {
+ tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
+ lhsTile, rhsTile, accTile);
+ } else {
+ tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
+ lhsTile, rhsTile, accTile);
+ }
+
+ Value res = storeTile(rewriter, tileMul);
+ rewriter.replaceOp(contractOp, res);
+
+ return success();
+ }
+};
+
+struct ConvertVectorToAMXPass
+ : public impl::ConvertVectorToAMXBase<ConvertVectorToAMXPass> {
+ void runOnOperation() override {
+ MLIRContext &ctx = getContext();
+ RewritePatternSet patterns(&ctx);
+ populateVectorToAMXConversionPatterns(patterns);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+
+} // namespace
+
+void mlir::populateVectorToAMXConversionPatterns(RewritePatternSet &patterns) {
+ patterns.add<ContractionToAMX>(patterns.getContext());
+}
diff --git a/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir b/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
new file mode 100644
index 0000000000000..ad23964a15dd2
--- /dev/null
+++ b/mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir
@@ -0,0 +1,291 @@
+// RUN: mlir-opt %s -convert-vector-to-amx -split-input-file | FileCheck %s
+
+/// VNNI format is Intel's packed data layout.
+/// For matrix multiplication, elements from the reduction dimension `k`
+/// are packed into 32-bit tuples. Then the appropriate AMX operations can
+/// perform tile multiplication directly on the packed data.
+///
+/// These packed elements are represented in the indexing maps by a separate
+/// reduction dimension `vnni`.
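+///
+/// For example, bf16 elements are 16 bits wide, so the VNNI factor is
+/// 32 / 16 = 2: a 16x16 bf16 operand in [K/vnni]x[N]x[vnni] layout is
+/// carried as a vector<8x16x2xbf16>, with pairs of consecutive `k`
+/// elements placed along the innermost `vnni` dimension.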
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @contract_vnni_f16(%A: vector<4x8x2xf16>, %B: vector<8x16x2xf16>,
+ %C: vector<4x16xf32>) -> vector<4x16xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @contract_vnni_f16(
+// CHECK-SAME: %[[A:.+]]: vector<4x8x2xf16>,
+// CHECK-SAME: %[[B:.+]]: vector<8x16x2xf16>,
+// CHECK-SAME: %[[C:.+]]: vector<4x16xf32>
+
+/// AMX hardware has no direct access to the registers. Thus, data must
+/// be transferred through intermediate buffers.
+///
+/// Load A vector into an AMX tile
+// CHECK: %[[A_BUF:.+]] = memref.alloca() : memref<4x8x2xf16>
+// CHECK: vector.transfer_write %[[A]], %[[A_BUF]]
+// CHECK: %[[A_BUF_2D:.+]] = memref.collapse_shape %[[A_BUF]]
+// CHECK-SAME: {{\[}}[0], [1, 2]] : memref<4x8x2xf16> into memref<4x16xf16>
+// CHECK: %[[A_TILE:.+]] = amx.tile_load %[[A_BUF_2D]]
+
+/// Load B vector into an AMX tile
+// CHECK: %[[B_BUF:.+]] = memref.alloca() : memref<8x16x2xf16>
+// CHECK: vector.transfer_write %[[B]], %[[B_BUF]]
+// CHECK: %[[B_BUF_2D:.+]] = memref.collapse_shape %[[B_BUF]]
+// CHECK-SAME: {{\[}}[0], [1, 2]] : memref<8x16x2xf16> into memref<8x32xf16>
+// CHECK: %[[B_TILE:.+]] = amx.tile_load %[[B_BUF_2D]]
+
+/// Load C vector into an AMX tile
+// CHECK: %[[C_BUF:.+]] = memref.alloca() : memref<4x16xf32>
+// CHECK: vector.transfer_write %[[C]], %[[C_BUF]]
+// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_BUF]]
+
+/// Perform tile multiplication
+// CHECK: %[[RES:.+]] = amx.tile_mulf
+// CHECK-SAME: %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
+
+/// Load the result back into a vector
+// CHECK: %[[RES_BUF:.+]] = memref.alloca() : memref<4x16xf32>
+// CHECK: amx.tile_store %[[RES_BUF]]{{.*}}, %[[RES]]
+// CHECK: %[[RES_VEC:.+]] = vector.transfer_read %[[RES_BUF]]
+
+// CHECK: return %[[RES_VEC]]
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @contract_vnni_bf16(%A: vector<4x8x2xbf16>, %B: vector<8x16x2xbf16>,
+ %C: vector<4x16xf32>) -> vector<4x16xf32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %A, %B, %C : vector<4x8x2xbf16>, vector<8x16x2xbf16> into vector<4x16xf32>
+ return %0 : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @contract_vnni_bf16(
+// CHECK-COUNT-3: amx.tile_load
+// CHECK: amx.tile_mulf
+// CHECK: amx.tile_store
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @contract_vnni_i8(%A: vector<4x16x4xi8>, %B: vector<16x8x4xi8>,
+ %C: vector<4x8xi32>) -> vector<4x8xi32> {
+ %0 = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],...
[truncated]
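As a reading aid for the truncated test file, the structure that the CHECK lines above verify for the f16 case roughly corresponds to the following hand-written sketch; it is not literal pass output, and the SSA names, the zero-padding constant, and any elided attributes are illustrative:

func.func @contract_vnni_f16_lowered(%A: vector<4x8x2xf16>, %B: vector<8x16x2xf16>,
    %C: vector<4x16xf32>) -> vector<4x16xf32> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f32
  // A: spill to a buffer, collapse the VNNI dim, load as a 4x16 f16 tile.
  %a_buf = memref.alloca() : memref<4x8x2xf16>
  vector.transfer_write %A, %a_buf[%c0, %c0, %c0] : vector<4x8x2xf16>, memref<4x8x2xf16>
  %a_2d = memref.collapse_shape %a_buf [[0], [1, 2]] : memref<4x8x2xf16> into memref<4x16xf16>
  %a_tile = amx.tile_load %a_2d[%c0, %c0] : memref<4x16xf16> into !amx.tile<4x16xf16>
  // B: same steps, yielding an 8x32 f16 tile.
  %b_buf = memref.alloca() : memref<8x16x2xf16>
  vector.transfer_write %B, %b_buf[%c0, %c0, %c0] : vector<8x16x2xf16>, memref<8x16x2xf16>
  %b_2d = memref.collapse_shape %b_buf [[0], [1, 2]] : memref<8x16x2xf16> into memref<8x32xf16>
  %b_tile = amx.tile_load %b_2d[%c0, %c0] : memref<8x32xf16> into !amx.tile<8x32xf16>
  // C is already 2D, so it is loaded directly as the accumulator tile.
  %c_buf = memref.alloca() : memref<4x16xf32>
  vector.transfer_write %C, %c_buf[%c0, %c0] : vector<4x16xf32>, memref<4x16xf32>
  %c_tile = amx.tile_load %c_buf[%c0, %c0] : memref<4x16xf32> into !amx.tile<4x16xf32>
  // Tile multiplication accumulates into the C tile.
  %res_tile = amx.tile_mulf %a_tile, %b_tile, %c_tile
    : !amx.tile<4x16xf16>, !amx.tile<8x32xf16>, !amx.tile<4x16xf32>
  // Store the result tile and read it back as a vector.
  %res_buf = memref.alloca() : memref<4x16xf32>
  amx.tile_store %res_buf[%c0, %c0], %res_tile : memref<4x16xf32>, !amx.tile<4x16xf32>
  %res = vector.transfer_read %res_buf[%c0, %c0], %pad : memref<4x16xf32>, vector<4x16xf32>
  return %res : vector<4x16xf32>
}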
LGTM but I'll leave the AVX specifics to Renato :)
Ha! You probably know more about AVX than I do! :D We discussed internally, and the shapes being considered are the bare minimum that we can use today upstream from Linalg and Triton (through vector). Anything beyond that will need some kind of heuristics or auto-tuning/DLTI, which is still in progress. LGTM too, thanks!
For next steps, I'm preparing a follow-up to try to reduce the number of data transfers when possible.