Skip to content

Commit 211d301

Browse files
committed
[mlir][Vector] Add simple folders for vector.from_element/vector.to_elements
This PR adds simple folders to remove no-op sequences of `vector.from_elements` and `vector.to_elements`.
1 parent 7588419 commit 211d301

File tree

3 files changed

+144
-0
lines changed

3 files changed

+144
-0
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
839839
let arguments = (ins AnyVectorOfAnyRank:$source);
840840
let results = (outs Variadic<AnyType>:$elements);
841841
let assemblyFormat = "$source attr-dict `:` type($source)";
842+
let hasFolder = 1;
842843
}
843844

844845
def Vector_FromElementsOp : Vector_Op<"from_elements", [
@@ -879,6 +880,7 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
879880
let arguments = (ins Variadic<AnyType>:$elements);
880881
let results = (outs AnyFixedVectorOfAnyRank:$dest);
881882
let assemblyFormat = "$elements attr-dict `:` type($dest)";
883+
let hasFolder = 1;
882884
let hasCanonicalizer = 1;
883885
}
884886

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2370,10 +2370,100 @@ std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
23702370
return llvm::to_vector<4>(getVectorType().getShape());
23712371
}
23722372

2373+
//===----------------------------------------------------------------------===//
2374+
// ToElementsOp
2375+
//===----------------------------------------------------------------------===//
2376+
2377+
/// Returns true if all the `operands` are defined by `defOp`.
2378+
/// Otherwise, returns false.
2379+
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
2380+
if (operands.empty())
2381+
return false;
2382+
2383+
return llvm::all_of(operands, [&](Value operand) {
2384+
Operation *currentDef = operand.getDefiningOp();
2385+
return currentDef == defOp;
2386+
});
2387+
}
2388+
2389+
/// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
2390+
/// (%e0, %e1, ...). For example:
2391+
///
2392+
/// %0 = vector.from_elements %a, %b, %c : vector<3xf32>
2393+
/// %1:3 = vector.to_elements %0 : vector<3xf32>
2394+
/// user_op %1#0, %1#1, %1#2
2395+
///
2396+
/// becomes:
2397+
///
2398+
/// user_op %a, %b, %c
2399+
///
2400+
static LogicalResult
2401+
foldToElementsFromElements(ToElementsOp toElementsOp,
2402+
SmallVectorImpl<OpFoldResult> &results) {
2403+
auto fromElementsOp = toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2404+
if (!fromElementsOp)
2405+
return failure();
2406+
2407+
results.append(fromElementsOp.getElements().begin(),
2408+
fromElementsOp.getElements().end());
2409+
return success();
2410+
}
2411+
2412+
LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2413+
SmallVectorImpl<OpFoldResult> &results) {
2414+
if (succeeded(foldToElementsFromElements(*this, results)))
2415+
return success();
2416+
return failure();
2417+
}
2418+
23732419
//===----------------------------------------------------------------------===//
23742420
// FromElementsOp
23752421
//===----------------------------------------------------------------------===//
23762422

2423+
/// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2424+
///
2425+
/// Case #1: Input and output vectors are the same.
2426+
///
2427+
/// %0:3 = vector.to_elements %a : vector<3xf32>
2428+
/// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2429+
/// user_op %1
2430+
///
2431+
/// becomes:
2432+
///
2433+
/// user_op %a
2434+
///
2435+
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2436+
auto fromElemsOperands = fromElementsOp.getElements();
2437+
2438+
if (fromElemsOperands.empty())
2439+
return {};
2440+
2441+
auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2442+
if (!toElementsOp)
2443+
return {};
2444+
2445+
if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
2446+
return {};
2447+
2448+
// Case #1: Input and output vectors are the same. Forward the input vector.
2449+
Value toElementsInput = toElementsOp.getSource();
2450+
if (fromElementsOp.getType() == toElementsInput.getType() &&
2451+
llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2452+
return toElementsInput;
2453+
}
2454+
2455+
// TODO: Support cases with different input and output shapes and different
2456+
// number of elements.
2457+
2458+
return {};
2459+
}
2460+
2461+
OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2462+
if (auto result = foldFromElementsToElements(*this))
2463+
return result;
2464+
return {};
2465+
}
2466+
23772467
/// Rewrite a vector.from_elements into a vector.splat if all elements are the
23782468
/// same SSA value. E.g.:
23792469
///

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3023,6 +3023,58 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
30233023

30243024
// -----
30253025

3026+
// CHECK-LABEL: func @to_elements_from_elements_no_op(
3027+
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32
3028+
func.func @to_elements_from_elements_no_op(%a: f32, %b: f32) -> (f32, f32) {
3029+
// CHECK-NOT: vector.from_elements
3030+
// CHECK-NOT: vector.to_elements
3031+
%0 = vector.from_elements %b, %a : vector<2xf32>
3032+
%1:2 = vector.to_elements %0 : vector<2xf32>
3033+
// CHECK: return %[[B]], %[[A]]
3034+
return %1#0, %1#1 : f32, f32
3035+
}
3036+
3037+
// -----
3038+
3039+
// CHECK-LABEL: func @from_elements_to_elements_no_op(
3040+
// CHECK-SAME: %[[A:.*]]: vector<4x2xf32>
3041+
func.func @from_elements_to_elements_no_op(%a: vector<4x2xf32>) -> vector<4x2xf32> {
3042+
// CHECK-NOT: vector.from_elements
3043+
// CHECK-NOT: vector.to_elements
3044+
%0:8 = vector.to_elements %a : vector<4x2xf32>
3045+
%1 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : vector<4x2xf32>
3046+
// CHECK: return %[[A]]
3047+
return %1 : vector<4x2xf32>
3048+
}
3049+
3050+
// -----
3051+
3052+
// CHECK-LABEL: func @from_elements_to_elements_dup_elems(
3053+
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
3054+
func.func @from_elements_to_elements_dup_elems(%a: vector<4xf32>) -> vector<4x2xf32> {
3055+
// CHECK: %[[TO_EL:.*]]:4 = vector.to_elements %[[A]]
3056+
// CHECK: %[[FROM_EL:.*]] = vector.from_elements %[[TO_EL]]#0, %[[TO_EL]]#1, %[[TO_EL]]#2
3057+
%0:4 = vector.to_elements %a : vector<4xf32> // 4 elements
3058+
%1 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#0, %0#1, %0#2, %0#3 : vector<4x2xf32>
3059+
// CHECK: return %[[FROM_EL]]
3060+
return %1 : vector<4x2xf32>
3061+
}
3062+
3063+
// -----
3064+
3065+
// CHECK-LABEL: func @from_elements_to_elements_shuffle(
3066+
// CHECK-SAME: %[[A:.*]]: vector<4x2xf32>
3067+
func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2xf32> {
3068+
// CHECK: %[[TO_EL:.*]]:8 = vector.to_elements %[[A]]
3069+
// CHECK: %[[FROM_EL:.*]] = vector.from_elements %[[TO_EL]]#7, %[[TO_EL]]#0, %[[TO_EL]]#6
3070+
%0:8 = vector.to_elements %a : vector<4x2xf32>
3071+
%1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<4x2xf32>
3072+
// CHECK: return %[[FROM_EL]]
3073+
return %1 : vector<4x2xf32>
3074+
}
3075+
3076+
// -----
3077+
30263078
// CHECK-LABEL: func @vector_insert_const_regression(
30273079
// CHECK: llvm.mlir.undef
30283080
// CHECK: vector.insert

0 commit comments

Comments
 (0)