Skip to content

Commit 7aecd7e

Browse files
authored
[mlir][Vector] Add vector.to_elements op (#141457)
This PR introduces the `vector.to_elements` op, which decomposes a vector into its scalar elements. This operation is symmetrical to the existing `vector.from_elements`. Examples: ``` // Decompose a 0-D vector. %0 = vector.to_elements %v0 : vector<f32> // %0 = %v0[0] // Decompose a 1-D vector. %0:2 = vector.to_elements %v1 : vector<2xf32> // %0#0 = %v1[0] // %0#1 = %v1[1] // Decompose a 2-D. %0:6 = vector.to_elements %v2 : vector<2x3xf32> // %0#0 = %v2[0, 0] // %0#1 = %v2[0, 1] // %0#2 = %v2[0, 2] // %0#3 = %v2[1, 0] // %0#4 = %v2[1, 1] // %0#5 = %v2[1, 2] ``` This op is aimed at reducing code size when modeling "structured" vector extractions and simplifying canonicalizations of large sequences of `vector.extract` and `vector.insert` ops into `vector.shuffle` and other sophisticated ops that can re-arrange vector elements.
1 parent b85e929 commit 7aecd7e

File tree

6 files changed

+181
-22
lines changed

6 files changed

+181
-22
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 67 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -790,40 +790,89 @@ def Vector_FMAOp :
790790
}];
791791
}
792792

793+
def Vector_ToElementsOp : Vector_Op<"to_elements", [
794+
Pure,
795+
ShapedTypeMatchesElementCountAndTypes<"source", "elements">]> {
796+
let summary = "operation that decomposes a vector into all its scalar elements";
797+
let description = [{
798+
This operation decomposes all the scalar elements from a vector. The
799+
decomposed scalar elements are returned in row-major order. The number of
800+
scalar results must match the number of elements in the input vector type.
801+
All the result elements have the same result type, which must match the
802+
element type of the input vector. Scalable vectors are not supported.
803+
804+
Examples:
805+
806+
```mlir
807+
// Decompose a 0-D vector.
808+
%0 = vector.to_elements %v0 : vector<f32>
809+
// %0 = %v0[0]
810+
811+
// Decompose a 1-D vector.
812+
%0:2 = vector.to_elements %v1 : vector<2xf32>
813+
// %0#0 = %v1[0]
814+
// %0#1 = %v1[1]
815+
816+
// Decompose a 2-D.
817+
%0:6 = vector.to_elements %v2 : vector<2x3xf32>
818+
// %0#0 = %v2[0, 0]
819+
// %0#1 = %v2[0, 1]
820+
// %0#2 = %v2[0, 2]
821+
// %0#3 = %v2[1, 0]
822+
// %0#4 = %v2[1, 1]
823+
// %0#5 = %v2[1, 2]
824+
825+
// Decompose a 3-D vector.
826+
%0:6 = vector.to_elements %v3 : vector<3x1x2xf32>
827+
// %0#0 = %v3[0, 0, 0]
828+
// %0#1 = %v3[0, 0, 1]
829+
// %0#2 = %v3[1, 0, 0]
830+
// %0#3 = %v3[1, 0, 1]
831+
// %0#4 = %v3[2, 0, 0]
832+
// %0#5 = %v3[2, 0, 1]
833+
```
834+
}];
835+
836+
let arguments = (ins AnyVectorOfAnyRank:$source);
837+
let results = (outs Variadic<AnyType>:$elements);
838+
let assemblyFormat = "$source attr-dict `:` type($source)";
839+
}
840+
793841
def Vector_FromElementsOp : Vector_Op<"from_elements", [
794842
Pure,
795-
TypesMatchWith<"operand types match result element type",
796-
"result", "elements", "SmallVector<Type>("
797-
"::llvm::cast<VectorType>($_self).getNumElements(), "
798-
"::llvm::cast<VectorType>($_self).getElementType())">]> {
843+
ShapedTypeMatchesElementCountAndTypes<"dest", "elements">]> {
799844
let summary = "operation that defines a vector from scalar elements";
800845
let description = [{
801846
This operation defines a vector from one or multiple scalar elements. The
802-
number of elements must match the number of elements in the result type.
803-
All elements must have the same type, which must match the element type of
804-
the result vector type.
805-
806-
`elements` are a flattened version of the result vector in row-major order.
847+
scalar elements are arranged in row-major within the vector. The number of
848+
elements must match the number of elements in the result type. All elements
849+
must have the same type, which must match the element type of the result
850+
vector type. Scalable vectors are not supported.
807851

808-
Example:
852+
Examples:
809853

810854
```mlir
811-
// %f1
855+
// Define a 0-D vector.
812856
%0 = vector.from_elements %f1 : vector<f32>
813-
// [%f1, %f2]
857+
// [%f1]
858+
859+
// Define a 1-D vector.
814860
%1 = vector.from_elements %f1, %f2 : vector<2xf32>
815-
// [[%f1, %f2, %f3], [%f4, %f5, %f6]]
861+
// [%f1, %f2]
862+
863+
// Define a 2-D vector.
816864
%2 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<2x3xf32>
817-
// [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
865+
// [[%f1, %f2, %f3], [%f4, %f5, %f6]]
866+
867+
// Define a 3-D vector.
818868
%3 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<3x1x2xf32>
869+
// [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
819870
```
820-
821-
Note, scalable vectors are not supported.
822871
}];
823872

824873
let arguments = (ins Variadic<AnyType>:$elements);
825-
let results = (outs AnyFixedVectorOfAnyRank:$result);
826-
let assemblyFormat = "$elements attr-dict `:` type($result)";
874+
let results = (outs AnyFixedVectorOfAnyRank:$dest);
875+
let assemblyFormat = "$elements attr-dict `:` type($dest)";
827876
let hasCanonicalizer = 1;
828877
}
829878

mlir/include/mlir/IR/OpBase.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,25 @@ class AllShapesMatch<list<string> names> :
556556
class AllTypesMatch<list<string> names> :
557557
AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
558558

559+
// A type constraint that verifies that a shaped type matches the size and
560+
// element type of a container with element types. More specifically, it denotes
561+
// shapedArg.getType().getNumElements() == elementsArg.size() &&
562+
// shapedArg.getType().getElementType() == elementsArg[i].getType(), for i in
563+
// [0, elementsArg.size()).
564+
class ShapedTypeMatchesElementCountAndTypes<string shapedArg,
565+
string elementsArg> :
566+
PredOpTrait<"shaped type '" # shapedArg # "' matches '" # elementsArg # "' "
567+
"element count and types",
568+
And<[CPred<ElementCount<shapedArg>.result # " == "
569+
"$" # elementsArg # ".getTypes().size()">,
570+
CPred<"::llvm::all_of($" # elementsArg # ".getTypes(), "
571+
"[&](::mlir::Type t) { return t == "
572+
# ElementType<shapedArg>.result # "; })">]>> {
573+
574+
string shaped = shapedArg;
575+
string elements = elementsArg;
576+
}
577+
559578
// A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`.
560579
// An optional comparator function may be provided that changes the above form
561580
// into: `comparator(transform(lhs.getType()), rhs.getType())`.

mlir/lib/TableGen/Operator.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,37 @@ void Operator::populateTypeInferenceInfo(
468468
continue;
469469
}
470470

471+
// The `ShapedTypeMatchesElementCountAndTypes` trait represents a 1 -> 1
472+
// type inference edge where a shaped type matches element count and types
473+
// of variadic elements.
474+
if (def.isSubClassOf("ShapedTypeMatchesElementCountAndTypes")) {
475+
StringRef shapedArg = def.getValueAsString("shaped");
476+
StringRef elementsArg = def.getValueAsString("elements");
477+
478+
int shapedIndex = argumentsAndResultsIndex.lookup(shapedArg);
479+
int elementsIndex = argumentsAndResultsIndex.lookup(elementsArg);
480+
481+
// Handle result type inference from shaped type to variadic elements.
482+
if (InferredResultType::isResultIndex(elementsIndex) &&
483+
InferredResultType::isArgIndex(shapedIndex)) {
484+
int resultIndex = InferredResultType::unmapResultIndex(elementsIndex);
485+
ResultTypeInference &infer = inference[resultIndex];
486+
if (!infer.inferred) {
487+
infer.sources.emplace_back(
488+
shapedIndex,
489+
"::llvm::SmallVector<::mlir::Type>(::llvm::cast<::mlir::"
490+
"ShapedType>($_self).getNumElements(), "
491+
"::llvm::cast<::mlir::ShapedType>($_self).getElementType())");
492+
infer.inferred = true;
493+
}
494+
}
495+
496+
// Type inference in the opposite direction is not possible as the actual
497+
// shaped type can't be inferred from the variadic elements.
498+
499+
continue;
500+
}
501+
471502
if (!def.isSubClassOf("AllTypesMatch"))
472503
continue;
473504

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1896,7 +1896,24 @@ func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
18961896

18971897
// -----
18981898

1899-
func.func @invalid_from_elements(%a: f32) {
1899+
func.func @to_elements_wrong_num_results(%a: vector<1x1x2xf32>) {
1900+
// expected-error @+1 {{operation defines 2 results but was provided 4 to bind}}
1901+
%0:4 = vector.to_elements %a : vector<1x1x2xf32>
1902+
return
1903+
}
1904+
1905+
// -----
1906+
1907+
func.func @to_elements_wrong_result_type(%a: vector<2xf32>) -> i32 {
1908+
// expected-error @+3 {{use of value '%0' expects different type than prior uses: 'i32'}}
1909+
// expected-note @+1 {{prior use here}}
1910+
%0:2 = vector.to_elements %a : vector<2xf32>
1911+
return %0#0 : i32
1912+
}
1913+
1914+
// -----
1915+
1916+
func.func @from_elements_wrong_num_operands(%a: f32) {
19001917
// expected-error @+1 {{'vector.from_elements' number of operands and types do not match: got 1 operands and 2 types}}
19011918
vector.from_elements %a : vector<2xf32>
19021919
return
@@ -1905,16 +1922,15 @@ func.func @invalid_from_elements(%a: f32) {
19051922
// -----
19061923

19071924
// expected-note @+1 {{prior use here}}
1908-
func.func @invalid_from_elements(%a: f32, %b: i32) {
1925+
func.func @from_elements_wrong_operand_type(%a: f32, %b: i32) {
19091926
// expected-error @+1 {{use of value '%b' expects different type than prior uses: 'f32' vs 'i32'}}
19101927
vector.from_elements %a, %b : vector<2xf32>
19111928
return
19121929
}
1913-
19141930
// -----
19151931

19161932
func.func @invalid_from_elements_scalable(%a: f32, %b: i32) {
1917-
// expected-error @+1 {{'result' must be fixed-length vector of any type values, but got 'vector<[2]xf32>'}}
1933+
// expected-error @+1 {{'dest' must be fixed-length vector of any type values, but got 'vector<[2]xf32>'}}
19181934
vector.from_elements %a, %b : vector<[2]xf32>
19191935
return
19201936
}

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,6 +1175,24 @@ func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4
11751175
return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
11761176
}
11771177

1178+
// CHECK-LABEL: func @to_elements(
1179+
// CHECK-SAME: %[[A_VEC:.*]]: vector<f32>, %[[B_VEC:.*]]: vector<1xf32>,
1180+
// CHECK-SAME: %[[C_VEC:.*]]: vector<1x2xf32>, %[[D_VEC:.*]]: vector<2x2xf32>)
1181+
func.func @to_elements(%a_vec : vector<f32>, %b_vec : vector<1xf32>,
1182+
%c_vec : vector<1x2xf32>, %d_vec : vector<2x2xf32>)
1183+
-> (f32, f32, f32, f32, f32, f32, f32, f32) {
1184+
// CHECK: %[[A_ELEMS:.*]] = vector.to_elements %[[A_VEC]] : vector<f32>
1185+
%0 = vector.to_elements %a_vec : vector<f32>
1186+
// CHECK: %[[B_ELEMS:.*]] = vector.to_elements %[[B_VEC]] : vector<1xf32>
1187+
%1 = vector.to_elements %b_vec : vector<1xf32>
1188+
// CHECK: %[[C_ELEMS:.*]]:2 = vector.to_elements %[[C_VEC]] : vector<1x2xf32>
1189+
%2:2 = vector.to_elements %c_vec : vector<1x2xf32>
1190+
// CHECK: %[[D_ELEMS:.*]]:4 = vector.to_elements %[[D_VEC]] : vector<2x2xf32>
1191+
%3:4 = vector.to_elements %d_vec : vector<2x2xf32>
1192+
// CHECK: return %[[A_ELEMS]], %[[B_ELEMS]], %[[C_ELEMS]]#0, %[[C_ELEMS]]#1, %[[D_ELEMS]]#0, %[[D_ELEMS]]#1, %[[D_ELEMS]]#2, %[[D_ELEMS]]#3
1193+
return %0, %1, %2#0, %2#1, %3#0, %3#1, %3#2, %3#3: f32, f32, f32, f32, f32, f32, f32, f32
1194+
}
1195+
11781196
// CHECK-LABEL: func @from_elements(
11791197
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
11801198
func.func @from_elements(%a: f32, %b: f32) -> (vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>) {

mlir/tools/mlir-tblgen/OpFormatGen.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2787,6 +2787,11 @@ class OpFormatParser : public FormatParser {
27872787
void handleTypesMatchConstraint(
27882788
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);
27892789

2790+
/// Check for inferable type resolution based on
2791+
/// `ShapedTypeMatchesElementCountAndTypes` constraint.
2792+
void handleShapedTypeMatchesElementCountAndTypesConstraint(
2793+
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);
2794+
27902795
/// Returns an argument or attribute with the given name that has been seen
27912796
/// within the format.
27922797
ConstArgument findSeenArg(StringRef name);
@@ -2850,6 +2855,9 @@ LogicalResult OpFormatParser::verify(SMLoc loc,
28502855
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
28512856
} else if (def.isSubClassOf("TypesMatchWith")) {
28522857
handleTypesMatchConstraint(variableTyResolver, def);
2858+
} else if (def.isSubClassOf("ShapedTypeMatchesElementCountAndTypes")) {
2859+
handleShapedTypeMatchesElementCountAndTypesConstraint(variableTyResolver,
2860+
def);
28532861
} else if (!op.allResultTypesKnown()) {
28542862
// This doesn't check the name directly to handle
28552863
// DeclareOpInterfaceMethods<InferTypeOpInterface>
@@ -3289,6 +3297,24 @@ void OpFormatParser::handleTypesMatchConstraint(
32893297
variableTyResolver[rhsName] = {arg, transformer};
32903298
}
32913299

3300+
void OpFormatParser::handleShapedTypeMatchesElementCountAndTypesConstraint(
3301+
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) {
3302+
StringRef shapedArg = def.getValueAsString("shaped");
3303+
StringRef elementsArg = def.getValueAsString("elements");
3304+
3305+
// Check if the 'shaped' argument is seen, then we can infer the 'elements'
3306+
// types.
3307+
if (ConstArgument arg = findSeenArg(shapedArg)) {
3308+
variableTyResolver[elementsArg] = {
3309+
arg, "::llvm::SmallVector<::mlir::Type>(::llvm::cast<::mlir::"
3310+
"ShapedType>($_self).getNumElements(), "
3311+
"::llvm::cast<::mlir::ShapedType>($_self).getElementType())"};
3312+
}
3313+
3314+
// Type inference in the opposite direction is not possible as the actual
3315+
// shaped type can't be inferred from the variadic elements.
3316+
}
3317+
32923318
ConstArgument OpFormatParser::findSeenArg(StringRef name) {
32933319
if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
32943320
return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;

0 commit comments

Comments
 (0)