Skip to content

Commit ff6367b

Browse files
authored
[[mlir][Vector] Add simple folders for vector.from_element/vector.to_elements (#144444)
This PR adds simple folders to remove no-op sequences of `vector.from_elements` and `vector.to_elements`.
1 parent bae48ac commit ff6367b

File tree

3 files changed

+139
-0
lines changed

3 files changed

+139
-0
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,6 +836,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
836836
let arguments = (ins AnyVectorOfAnyRank:$source);
837837
let results = (outs Variadic<AnyType>:$elements);
838838
let assemblyFormat = "$source attr-dict `:` type($source)";
839+
let hasFolder = 1;
839840
}
840841

841842
def Vector_FromElementsOp : Vector_Op<"from_elements", [
@@ -873,6 +874,7 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
873874
let arguments = (ins Variadic<AnyType>:$elements);
874875
let results = (outs AnyFixedVectorOfAnyRank:$dest);
875876
let assemblyFormat = "$elements attr-dict `:` type($dest)";
877+
let hasFolder = 1;
876878
let hasCanonicalizer = 1;
877879
}
878880

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2373,10 +2373,95 @@ std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
23732373
return llvm::to_vector<4>(getVectorType().getShape());
23742374
}
23752375

2376+
//===----------------------------------------------------------------------===//
2377+
// ToElementsOp
2378+
//===----------------------------------------------------------------------===//
2379+
2380+
/// Returns true if all the `operands` are defined by `defOp`.
2381+
/// Otherwise, returns false.
2382+
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
2383+
if (operands.empty())
2384+
return false;
2385+
2386+
return llvm::all_of(operands, [&](Value operand) {
2387+
Operation *currentDef = operand.getDefiningOp();
2388+
return currentDef == defOp;
2389+
});
2390+
}
2391+
2392+
/// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
2393+
/// (%e0, %e1, ...). For example:
2394+
///
2395+
/// %0 = vector.from_elements %a, %b, %c : vector<3xf32>
2396+
/// %1:3 = vector.to_elements %0 : vector<3xf32>
2397+
/// user_op %1#0, %1#1, %1#2
2398+
///
2399+
/// becomes:
2400+
///
2401+
/// user_op %a, %b, %c
2402+
///
2403+
static LogicalResult
2404+
foldToElementsFromElements(ToElementsOp toElementsOp,
2405+
SmallVectorImpl<OpFoldResult> &results) {
2406+
auto fromElementsOp =
2407+
toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2408+
if (!fromElementsOp)
2409+
return failure();
2410+
2411+
llvm::append_range(results, fromElementsOp.getElements());
2412+
return success();
2413+
}
2414+
2415+
LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2416+
SmallVectorImpl<OpFoldResult> &results) {
2417+
return foldToElementsFromElements(*this, results);
2418+
}
2419+
23762420
//===----------------------------------------------------------------------===//
23772421
// FromElementsOp
23782422
//===----------------------------------------------------------------------===//
23792423

2424+
/// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2425+
///
2426+
/// Case #1: Input and output vectors are the same.
2427+
///
2428+
/// %0:3 = vector.to_elements %a : vector<3xf32>
2429+
/// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2430+
/// user_op %1
2431+
///
2432+
/// becomes:
2433+
///
2434+
/// user_op %a
2435+
///
2436+
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2437+
OperandRange fromElemsOperands = fromElementsOp.getElements();
2438+
if (fromElemsOperands.empty())
2439+
return {};
2440+
2441+
auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2442+
if (!toElementsOp)
2443+
return {};
2444+
2445+
if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
2446+
return {};
2447+
2448+
// Case #1: Input and output vectors are the same. Forward the input vector.
2449+
Value toElementsInput = toElementsOp.getSource();
2450+
if (fromElementsOp.getType() == toElementsInput.getType() &&
2451+
llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2452+
return toElementsInput;
2453+
}
2454+
2455+
// TODO: Support cases with different input and output shapes and different
2456+
// number of elements.
2457+
2458+
return {};
2459+
}
2460+
2461+
OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2462+
return foldFromElementsToElements(*this);
2463+
}
2464+
23802465
/// Rewrite a vector.from_elements into a vector.splat if all elements are the
23812466
/// same SSA value. E.g.:
23822467
///

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3023,6 +3023,58 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
30233023

30243024
// -----
30253025

3026+
// CHECK-LABEL: func @to_elements_from_elements_no_op(
3027+
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32
3028+
func.func @to_elements_from_elements_no_op(%a: f32, %b: f32) -> (f32, f32) {
3029+
// CHECK-NOT: vector.from_elements
3030+
// CHECK-NOT: vector.to_elements
3031+
%0 = vector.from_elements %b, %a : vector<2xf32>
3032+
%1:2 = vector.to_elements %0 : vector<2xf32>
3033+
// CHECK: return %[[B]], %[[A]]
3034+
return %1#0, %1#1 : f32, f32
3035+
}
3036+
3037+
// -----
3038+
3039+
// CHECK-LABEL: func @from_elements_to_elements_no_op(
3040+
// CHECK-SAME: %[[A:.*]]: vector<4x2xf32>
3041+
func.func @from_elements_to_elements_no_op(%a: vector<4x2xf32>) -> vector<4x2xf32> {
3042+
// CHECK-NOT: vector.from_elements
3043+
// CHECK-NOT: vector.to_elements
3044+
%0:8 = vector.to_elements %a : vector<4x2xf32>
3045+
%1 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : vector<4x2xf32>
3046+
// CHECK: return %[[A]]
3047+
return %1 : vector<4x2xf32>
3048+
}
3049+
3050+
// -----
3051+
3052+
// CHECK-LABEL: func @from_elements_to_elements_dup_elems(
3053+
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
3054+
func.func @from_elements_to_elements_dup_elems(%a: vector<4xf32>) -> vector<4x2xf32> {
3055+
// CHECK: %[[TO_EL:.*]]:4 = vector.to_elements %[[A]]
3056+
// CHECK: %[[FROM_EL:.*]] = vector.from_elements %[[TO_EL]]#0, %[[TO_EL]]#1, %[[TO_EL]]#2
3057+
%0:4 = vector.to_elements %a : vector<4xf32> // 4 elements
3058+
%1 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#0, %0#1, %0#2, %0#3 : vector<4x2xf32>
3059+
// CHECK: return %[[FROM_EL]]
3060+
return %1 : vector<4x2xf32>
3061+
}
3062+
3063+
// -----
3064+
3065+
// CHECK-LABEL: func @from_elements_to_elements_shuffle(
3066+
// CHECK-SAME: %[[A:.*]]: vector<4x2xf32>
3067+
func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2xf32> {
3068+
// CHECK: %[[TO_EL:.*]]:8 = vector.to_elements %[[A]]
3069+
// CHECK: %[[FROM_EL:.*]] = vector.from_elements %[[TO_EL]]#7, %[[TO_EL]]#0, %[[TO_EL]]#6
3070+
%0:8 = vector.to_elements %a : vector<4x2xf32>
3071+
%1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<4x2xf32>
3072+
// CHECK: return %[[FROM_EL]]
3073+
return %1 : vector<4x2xf32>
3074+
}
3075+
3076+
// -----
3077+
30263078
// CHECK-LABEL: func @vector_insert_const_regression(
30273079
// CHECK: llvm.mlir.undef
30283080
// CHECK: vector.insert

0 commit comments

Comments
 (0)