[MLIR][AArch64] Lower vector.contract with mixed signed/unsigned arguments to Neon FEAT_I8MM #144698
Conversation
@llvm/pr-subscribers-mlir-sve @llvm/pr-subscribers-mlir

Author: Momchil Velikov (momchil-velikov)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/144698.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 5ce3d2b28aeb3..967aff579227b 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -37,6 +37,81 @@ static Type matchContainerType(Type element, Type container) {
return element;
}
+// Get the operand of a `vector.contract`. This function is intended to abstract
+// away from the particular way a value is extended before feeding it into the
+// `vector.contract` - via zero-extend or an explicit or implicit sign-extend
+// (for implicit sign-extension see `vector.contract` documentation).
+//
+// The template parameter `Op` indicates the extension operation (explicit or
+// implicit) for which we are checking.
+//
+// Return success only for extensions from `iN` (N <= 8) to `i32`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+ static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+ "Must be instantiated with either sign- or zero- extension op");
+
+ // If the operand is not defined by an explicit extend operation of the
+ // accepted operation type, allow for an implicit sign-extension.
+ auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+ if (!extOp) {
+ if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+ auto eltTy = cast<VectorType>(v.getType()).getElementType();
+ if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
+ return {};
+ return v;
+ }
+ return {};
+ }
+
+ // If the operand is defined by an explicit extend operation of the accepted
+ // operation type, check it's extended from `iN` (N <= 8) to `i32`.
+ auto inOp = extOp.getIn();
+ auto inTy = dyn_cast<VectorType>(inOp.getType());
+ if (!inTy)
+ return {};
+ auto inEltTy = inTy.getElementType();
+ if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
+ return {};
+
+ auto outTy = dyn_cast<VectorType>(extOp.getType());
+ if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
+ return {};
+
+ return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+ Signed, // smmla
+ Unsigned, // ummla
+ Mixed, // usmmla
+ MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::Type accType, Value acc, Value lhs, Value rhs) {
+ switch (op) {
+ case MMLA::Signed:
+ return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
+ rhs);
+ case MMLA::Unsigned:
+ return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
+ rhs);
+ case MMLA::Mixed:
+ return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
+ rhs);
+ case MMLA::MixedSwapped:
+ // The accumulator comes transposed and the result will be transposed
+ // later, so all we have to do here is swap the operands.
+ return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
+ lhs);
+ }
+}
+
/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
/// any vector.contract into multiple smmla instructions with unrolling so long
/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
@@ -88,39 +163,64 @@ class LowerContractionToSMMLAPattern
return failure();
}
- // Check two extsi inputs Rhs Lhs for contract.
- arith::ExtSIOp origLhsExtOp =
- dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
- arith::ExtSIOp origRhsExtOp =
- dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
- if (!origLhsExtOp || !origRhsExtOp) {
+ // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
+ // values before the extension. All four signed/unsigned combinations for
+ // input operands are supported, but they are lowered to different
+ // operations. Determine which is the appropriate operation to lower to.
+ MMLA mmlaOp = MMLA::Signed;
+ auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
+ if (!maybeLhs) {
+ mmlaOp = MMLA::Unsigned;
+ maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
+ }
+ if (!maybeLhs)
return failure();
+
+ auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
+ if (maybeRhs) {
+ if (mmlaOp == MMLA::Unsigned)
+ mmlaOp = MMLA::Mixed;
+ } else {
+ if (mmlaOp == MMLA::Signed)
+ mmlaOp = MMLA::MixedSwapped;
+ maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
}
+ if (!maybeRhs)
+ return failure();
+
+ Value origLhs = *maybeLhs;
+ Value origRhs = *maybeRhs;
// Match any iX to i32 for X<8 then turn into an i8 output. Feed into
// following neon instruction. Check inputs for extsi are <=i8
- Value extsiLhs;
- Value extsiRhs;
- if (auto lhsExtInType =
- dyn_cast<mlir::VectorType>(origLhsExtOp.getIn().getType())) {
+ Value extLhs;
+ Value extRhs;
+ if (auto lhsExtInType = dyn_cast<mlir::VectorType>(origLhs.getType())) {
if (lhsExtInType.getElementTypeBitWidth() <= 8) {
Type targetLhsExtTy =
matchContainerType(rewriter.getI8Type(), lhsExtInType);
- extsiLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
- origLhsExtOp.getIn());
+ if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::Mixed)
+ extLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
+ origLhs);
+ else
+ extLhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetLhsExtTy,
+ origLhs);
}
}
- if (auto rhsExtInType =
- dyn_cast<mlir::VectorType>(origRhsExtOp.getIn().getType())) {
+ if (auto rhsExtInType = dyn_cast<mlir::VectorType>(origRhs.getType())) {
if (rhsExtInType.getElementTypeBitWidth() <= 8) {
Type targetRhsExtTy =
matchContainerType(rewriter.getI8Type(), rhsExtInType);
- extsiRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
- origRhsExtOp.getIn());
+ if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::Mixed)
+ extRhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetRhsExtTy,
+ origRhs);
+ else
+ extRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
+ origRhs);
}
}
- if (!extsiLhs || !extsiRhs) {
+ if (!extLhs || !extRhs) {
return failure();
}
@@ -155,11 +255,11 @@ class LowerContractionToSMMLAPattern
AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
SmallVector<int64_t> lhsOffsets =
applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
- Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
+ Value tiledLhs = extractOperand(extLhs, lhsPermutationMap, lhsOffsets);
AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
SmallVector<int64_t> rhsOffsets =
applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
- Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
+ Value tiledRhs = extractOperand(extRhs, rhsPermutationMap, rhsOffsets);
AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
SmallVector<int64_t> accOffsets =
applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
@@ -191,6 +291,13 @@ class LowerContractionToSMMLAPattern
tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
}
+ // Transpose ACC if doing signed by unsigned multiplication, because we're
+ // using the instruction for unsigned by signed multiplication with
+ // reversed operands.
+ if (mmlaOp == MMLA::MixedSwapped)
+ tiledAcc = rewriter.create<vector::TransposeOp>(
+ loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
+
// Collapse tiled operands to 1D vectors required by smmla intrinsic
auto collapsedInputType =
VectorType::get(inputExpandedType.getNumElements(), inputElementType);
@@ -211,15 +318,21 @@ class LowerContractionToSMMLAPattern
}
// Insert contract op
- kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
- op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
- collapsedRhs);
+ kAcc = createMMLA(rewriter, mmlaOp, op.getLoc(), collapsedRes.getType(),
+ collapsedRes, collapsedLhs, collapsedRhs);
// Reshape output back to 2D
Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
kAcc.getLoc(), tiledAcc.getType(), kAcc);
- // With vecmat, only one row of tiled ACC can be inserted into file result
+ // Because of the reversed operands, the result is obtained transposed.
+ // Transpose it back.
+ if (mmlaOp == MMLA::MixedSwapped)
+ tiledRes = rewriter.create<vector::TransposeOp>(
+ loc, tiledRes, ArrayRef<int64_t>({1, 0}));
+
+ // With vecmat, only one row of tiled ACC can be inserted into the final
+ // result.
if (isVecmat) {
tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
}
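As an aside, the algebra behind the MixedSwapped case (this restates the comments in the patch; it adds no new behavior): each of smmla/ummla/usmmla updates a 2x2 i32 accumulator sub-tile C from two 2x8 i8 sub-tiles A and B as C += A·Bᵀ, with usmmla fixing the unsigned operand as A. There is no instruction for a signed-by-unsigned product, but transposing the whole update

\[
  \left(C + A B^{\mathsf{T}}\right)^{\mathsf{T}} \;=\; C^{\mathsf{T}} + B A^{\mathsf{T}}
\]

shows it is exactly usmmla applied to the transposed accumulator with the operands swapped (unsigned B as LHS, signed A as RHS), which is why the pattern transposes ACC before the multiplication and the result after it.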
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index e4f7ea150c850..5fc29c6442602 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -17,14 +17,28 @@ func.func @vector_arm_neon_mixed_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi4
// -----
-// CHECK-LABEL: vector_arm_neon_same_types
-// CHECK-SAME: %[[A0:.*]]: vector<2x8xi8>, %[[A1:.*]]: vector<2x8xi8>, %[[A2:.*]]: vector<2x2xi32>
-// CHECK-DAG: %[[D0:.*]] = vector.shape_cast %[[A0]] : vector<2x8xi8> to vector<16xi8>
-// CHECK-DAG: %[[D1:.*]] = vector.shape_cast %[[A1]] : vector<2x8xi8> to vector<16xi8>
-// CHECK-DAG: %[[D2:.*]] = vector.shape_cast %[[A2]] : vector<2x2xi32> to vector<4xi32>
-// CHECK-DAG: %[[D3:.*]] = arm_neon.intr.smmla %[[D2]], %[[D0]], %[[D1]] : vector<16xi8> to vector<4xi32>
-// CHECK-DAG: %[[D4:.*]] = vector.shape_cast %[[D3]] : vector<4xi32> to vector<2x2xi32>
-func.func @vector_arm_neon_same_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+// CHECK-LABEL: vector_arm_neon_implicit_extsi
+// CHECK-SAME: %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
+// CHECK: %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[A:.+]] = vector.shape_cast %[[ACC]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[M:.+]] = arm_neon.intr.smmla %[[A]], %[[L]], %[[R]] : vector<16xi8> to vector<4xi32>
+// CHECK: %{{.+}} = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
+func.func @vector_arm_neon_implicit_extsi(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi8>, vector<2x8xi8> into vector<2x2xi32>
+ return %res : vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: vector_arm_neon_signed_signed
+// CHECK-SAME: %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
+// CHECK: %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[A:.+]] = vector.shape_cast %[[ACC]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[M:.+]] = arm_neon.intr.smmla %[[A]], %[[L]], %[[R]] : vector<16xi8> to vector<4xi32>
+// CHECK: %{{.+}} = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
+func.func @vector_arm_neon_signed_signed(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
%lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
%rhs_extsi = arith.extsi %rhs : vector<2x8xi8> to vector<2x8xi32>
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
@@ -33,11 +47,51 @@ func.func @vector_arm_neon_same_types(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>
// -----
-// CHECK-LABEL: vector_arm_neon_without_extsi
-// CHECK-SAME: %[[A0:.*]]: vector<2x8xi32>, %[[A1:.*]]: vector<2x8xi32>, %[[A2:.*]]: vector<2x2xi32>
-// CHECK-DAG: %[[D0:.*]] = vector.contract
-func.func @vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs: vector<2x8xi32>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
- %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
+// CHECK-LABEL: vector_arm_neon_unsigned_signed
+// CHECK-SAME: %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
+// CHECK: %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[A:.+]] = vector.shape_cast %[[ACC]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[M:.+]] = arm_neon.intr.usmmla %[[A]], %[[L]], %[[R]] : vector<16xi8> to vector<4xi32>
+// CHECK: %{{.+}} = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
+func.func @vector_arm_neon_unsigned_signed(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+ %lhs_extui = arith.extui %lhs : vector<2x8xi8> to vector<2x8xi32>
+ %rhs_extsi = arith.extsi %rhs : vector<2x8xi8> to vector<2x8xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extui, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
+ return %res : vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: vector_arm_neon_unsigned_unsigned
+// CHECK-SAME: %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
+// CHECK: %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[A:.+]] = vector.shape_cast %[[ACC]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[M:.+]] = arm_neon.intr.ummla %[[A]], %[[L]], %[[R]] : vector<16xi8> to vector<4xi32>
+// CHECK: %{{.+}} = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
+func.func @vector_arm_neon_unsigned_unsigned(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+ %lhs_extui = arith.extui %lhs : vector<2x8xi8> to vector<2x8xi32>
+ %rhs_extui = arith.extui %rhs : vector<2x8xi8> to vector<2x8xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extui, %rhs_extui, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
+ return %res : vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: vector_arm_neon_signed_unsigned
+// CHECK-SAME: %[[LHS:.+]]: vector<2x8xi8>, %[[RHS:.+]]: vector<2x8xi8>, %[[ACC:.+]]: vector<2x2xi32>
+// CHECK: %[[ACC_T:.+]] = vector.transpose %[[ACC]], [1, 0] : vector<2x2xi32> to vector<2x2xi32>
+// CHECK: %[[L:.+]] = vector.shape_cast %[[LHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[R:.+]] = vector.shape_cast %[[RHS]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[A:.+]] = vector.shape_cast %[[ACC_T]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[M:.+]] = arm_neon.intr.usmmla %[[A]], %[[R]], %[[L]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[OUT_T:.+]] = vector.shape_cast %[[M]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %{{.+}} = vector.transpose %[[OUT_T]], [1, 0] : vector<2x2xi32> to vector<2x2xi32>
+func.func @vector_arm_neon_signed_unsigned(%lhs: vector<2x8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
+ %lhs_extsi = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
+ %rhs_extui = arith.extui %rhs : vector<2x8xi8> to vector<2x8xi32>
+ %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extui, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
return %res : vector<2x2xi32>
}
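One shape the tests above don't cover, but which the `iN` (N <= 8) check admits: inputs narrower than i8 are first re-extended to i8 before the intrinsic. A minimal sketch, assuming all-i4 inputs (the function name is hypothetical; it is modeled on the existing vector_arm_neon_mixed_types test, which uses a vector<2x8xi4> operand):

// Hypothetical test, not part of the patch: both inputs are i4, explicitly
// sign-extended to i32. The pattern re-extends each operand i4 -> i8,
// flattens with vector.shape_cast and emits a single arm_neon.intr.smmla.
func.func @vector_arm_neon_i4_signed(%lhs: vector<2x8xi4>, %rhs: vector<2x8xi4>, %acc : vector<2x2xi32>) -> vector<2x2xi32> {
  %lhs_extsi = arith.extsi %lhs : vector<2x8xi4> to vector<2x8xi32>
  %rhs_extsi = arith.extsi %rhs : vector<2x8xi4> to vector<2x8xi32>
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
  return %res : vector<2x2xi32>
}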
Changed the title: "... vector.contract with mixed signend/unsigned arguments to Neon FEAT_I8MM" → "... vector.contract with mixed signed/unsigned arguments to Neon FEAT_I8MM"
... since we now generate all of smmla/ummla/usmmla.
Thanks, LGTM, though let's make sure to document the steps required to unify the NEON and SVE implementations.
-// This file implements lowering patterns from vector.contract to
-//   arm_neon.intr.smmla
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the Neon FEAT_I8MM extension.
+//
 //===---
FIXME
This is a lot of duplication with https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp. Is `vecMat` support the only other outstanding difference between NEON and SVE support? If we don't unify the implementations now, it would be good to at least leave some TODOs + a GitHub issue.
I can think of sharing the one duplicated function, and mangling the other in a way that lets it be shared between both patterns, as soon as you suggest where to put them.
It depends on what you mean by "unify". The algorithms for the transformations are completely different; I don't see any unification possible.
> Is `vecMat` support the only other outstanding difference between NEON and SVE support?
The differences in functionality:

a) Neon handles arbitrary permutation maps; SVE handles only the usual "identities + transposed RHS" ones.
b) Neon can handle `iN`, N <= 8; SVE can handle only N == 8.
c) SVE does not handle "vecmat" (it wants an LHS with an even number of rows); see the sketch below.

b) and c) can probably be added to the SVE version without too much trouble (at least as far as `vector.contract` lowering is concerned).

I don't have high expectations for a): the SVE handling of the RHS especially was rather tricky even with the simple maps. Moreover, I'd also question the necessity of support for arbitrary permutation maps - a great chunk (I'll even wave my hands and say the "greatest chunk") of performance comes from the ability to do sequential loads and stores.
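For concreteness, a minimal sketch of the "vecmat" shape from point c) (names and shapes are illustrative, not from the patch; per the pattern's comments, Neon handles dimM = 1 and inserts only one row of the tiled ACC into the final result):

// Hypothetical vecmat contract (dimM = 1). The Neon pattern accepts this
// shape; the SVE pattern currently rejects it.
func.func @vecmat_i8(%lhs: vector<8xi8>, %rhs: vector<2x8xi8>, %acc : vector<2xi32>) -> vector<2xi32> {
  %lhs_extsi = arith.extsi %lhs : vector<8xi8> to vector<8xi32>
  %rhs_extsi = arith.extsi %rhs : vector<2x8xi8> to vector<2x8xi32>
  %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8xi32>, vector<2x8xi32> into vector<2xi32>
  return %res : vector<2xi32>
}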
I'm listening. I don't see opportunities for "unification", but I'm open to suggestions.