
[DAGCombiner] Add combine for vector interleave of splats #151110


Open
wants to merge 3 commits into main

Conversation

@david-arm (Contributor) commented Jul 29, 2025

This patch adds a DAG combine that looks for
concat_vectors(vector_interleave(splat, splat, ...)),
where all the splats are identical.

For fixed-width vectors the DAG combine only occurs for
interleave factors of 3 or more; however, it's not currently
safe to test this for AArch64 since there isn't any lowering
support for fixed-width interleaves. I've only added
fixed-width tests for RISCV.
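
For reference, a minimal sketch of the shape of the combine
(illustrative only, not code from this patch: the helper name, the
exact checks and the scalable-only SPLAT_VECTOR form are assumptions
based on the review snippets further down):

// Sketch: fold concat_vectors(vector_interleave(splat X, ..., splat X))
// into a single wider splat of X. Assumes all results of one
// VECTOR_INTERLEAVE feed the CONCAT_VECTORS in order.
static SDValue combineConcatOfInterleavedSplats(SDNode *N, SelectionDAG &DAG) {
  SDValue FirstOp = N->getOperand(0);
  if (FirstOp.getOpcode() != ISD::VECTOR_INTERLEAVE ||
      FirstOp.getNumOperands() != N->getNumOperands())
    return SDValue();

  // Every interleave operand must be the same splat.
  SDValue Splat = FirstOp.getOperand(0);
  if (Splat.getOpcode() != ISD::SPLAT_VECTOR)
    return SDValue();
  for (SDValue Op : FirstOp->op_values())
    if (Op != Splat)
      return SDValue();

  // Interleaving identical splats just produces a wider splat.
  return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), N->getValueType(0),
                     Splat.getOperand(0));
}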

@llvmbot (Member) commented Jul 29, 2025

@llvm/pr-subscribers-backend-risc-v
@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-selectiondag

Author: David Sherwood (david-arm)

Changes

When lowering vector.interleave we should check if every operand is
the same splat and lower directly to a wider splat if so. This
avoids having to create a VECTOR_INTERLEAVE node, which expects to
return a struct of values that we then have to concatenate. We
could also do this with a DAG combine that looks for
concat_vectors(vector_interleave(splat, splat, ...)), but it seemed
like a lot of unnecessary extra complexity.

While here I fixed up the induction variable names to meet the
coding standard.


Patch is 24.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/151110.diff

5 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+24-11)
  • (modified) llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll (+26-29)
  • (modified) llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-scalable.ll (+25-27)
  • (modified) llvm/test/CodeGen/AArch64/fixed-vector-interleave.ll (+88-1)
  • (modified) llvm/test/CodeGen/AArch64/sve-vector-interleave.ll (+94-1)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 306e068f1c1da..dcdd2cf5e1a11 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -12594,18 +12594,31 @@ void SelectionDAGBuilder::visitVectorDeinterleave(const CallInst &I,
   setValue(&I, Res);
 }
 
-void SelectionDAGBuilder::visitVectorInterleave(const CallInst &I,
+void SelectionDAGBuilder::visitVectorInterleave(const CallInst &CI,
                                                 unsigned Factor) {
   auto DL = getCurSDLoc();
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  EVT InVT = getValue(I.getOperand(0)).getValueType();
-  EVT OutVT = TLI.getValueType(DAG.getDataLayout(), I.getType());
+  EVT InVT = getValue(CI.getOperand(0)).getValueType();
+  EVT OutVT = TLI.getValueType(DAG.getDataLayout(), CI.getType());
 
   SmallVector<SDValue, 8> InVecs(Factor);
-  for (unsigned i = 0; i < Factor; ++i) {
-    InVecs[i] = getValue(I.getOperand(i));
-    assert(InVecs[i].getValueType() == InVecs[0].getValueType() &&
+  bool OperandsAreSame = true;
+  for (unsigned I = 0; I < Factor; ++I) {
+    InVecs[I] = getValue(CI.getOperand(I));
+    assert(InVecs[I].getValueType() == InVecs[0].getValueType() &&
            "Expected VTs to be the same");
+    if (InVecs[I] != InVecs[0])
+      OperandsAreSame = false;
+  }
+
+  if (OperandsAreSame) {
+    if (auto *C = dyn_cast<Constant>(CI.getOperand(0))) {
+      if (auto *SV = C->getSplatValue()) {
+        SDValue Res = DAG.getNode(ISD::SPLAT_VECTOR, DL, OutVT, getValue(SV));
+        setValue(&CI, Res);
+        return;
+      }
+    }
   }
 
   // Use VECTOR_SHUFFLE for fixed-length vectors with factor of 2 to benefit
@@ -12613,8 +12626,8 @@ void SelectionDAGBuilder::visitVectorInterleave(const CallInst &I,
   if (OutVT.isFixedLengthVector() && Factor == 2) {
     unsigned NumElts = InVT.getVectorMinNumElements();
     SDValue V = DAG.getNode(ISD::CONCAT_VECTORS, DL, OutVT, InVecs);
-    setValue(&I, DAG.getVectorShuffle(OutVT, DL, V, DAG.getUNDEF(OutVT),
-                                      createInterleaveMask(NumElts, 2)));
+    setValue(&CI, DAG.getVectorShuffle(OutVT, DL, V, DAG.getUNDEF(OutVT),
+                                       createInterleaveMask(NumElts, 2)));
     return;
   }
 
@@ -12623,11 +12636,11 @@ void SelectionDAGBuilder::visitVectorInterleave(const CallInst &I,
       DAG.getNode(ISD::VECTOR_INTERLEAVE, DL, DAG.getVTList(ValueVTs), InVecs);
 
   SmallVector<SDValue, 8> Results(Factor);
-  for (unsigned i = 0; i < Factor; ++i)
-    Results[i] = Res.getValue(i);
+  for (unsigned I = 0; I < Factor; ++I)
+    Results[I] = Res.getValue(I);
 
   Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, OutVT, Results);
-  setValue(&I, Res);
+  setValue(&CI, Res);
 }
 
 void SelectionDAGBuilder::visitFreeze(const FreezeInst &I) {
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll
index 880bd2904154c..d67aa08125f74 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll
@@ -14,20 +14,19 @@ target triple = "aarch64"
 define %"class.std::complex" @complex_mul_v2f64(ptr %a, ptr %b) {
 ; CHECK-LABEL: complex_mul_v2f64:
 ; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi v0.2d, #0000000000000000
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    mov w8, #100 // =0x64
-; CHECK-NEXT:    cntd x9
 ; CHECK-NEXT:    whilelo p1.d, xzr, x8
+; CHECK-NEXT:    cntd x9
 ; CHECK-NEXT:    rdvl x10, #2
-; CHECK-NEXT:    mov x11, x9
 ; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    zip2 z0.d, z1.d, z1.d
-; CHECK-NEXT:    zip1 z1.d, z1.d, z1.d
+; CHECK-NEXT:    mov x11, x9
 ; CHECK-NEXT:  .LBB0_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    zip2 p2.d, p1.d, p1.d
-; CHECK-NEXT:    mov z6.d, z1.d
-; CHECK-NEXT:    mov z7.d, z0.d
+; CHECK-NEXT:    mov z6.d, z0.d
+; CHECK-NEXT:    mov z7.d, z1.d
 ; CHECK-NEXT:    zip1 p1.d, p1.d, p1.d
 ; CHECK-NEXT:    ld1d { z2.d }, p2/z, [x0, #1, mul vl]
 ; CHECK-NEXT:    ld1d { z4.d }, p2/z, [x1, #1, mul vl]
@@ -39,14 +38,14 @@ define %"class.std::complex" @complex_mul_v2f64(ptr %a, ptr %b) {
 ; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #0
 ; CHECK-NEXT:    fcmla z7.d, p0/m, z4.d, z2.d, #90
 ; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #90
-; CHECK-NEXT:    mov z0.d, p2/m, z7.d
-; CHECK-NEXT:    mov z1.d, p1/m, z6.d
+; CHECK-NEXT:    mov z1.d, p2/m, z7.d
+; CHECK-NEXT:    mov z0.d, p1/m, z6.d
 ; CHECK-NEXT:    whilelo p1.d, x11, x8
 ; CHECK-NEXT:    add x11, x11, x9
 ; CHECK-NEXT:    b.mi .LBB0_1
 ; CHECK-NEXT:  // %bb.2: // %exit.block
-; CHECK-NEXT:    uzp1 z2.d, z1.d, z0.d
-; CHECK-NEXT:    uzp2 z1.d, z1.d, z0.d
+; CHECK-NEXT:    uzp1 z2.d, z0.d, z1.d
+; CHECK-NEXT:    uzp2 z1.d, z0.d, z1.d
 ; CHECK-NEXT:    faddv d0, p0, z2.d
 ; CHECK-NEXT:    faddv d1, p0, z1.d
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
@@ -111,21 +110,20 @@ exit.block:                                     ; preds = %vector.body
 define %"class.std::complex" @complex_mul_predicated_v2f64(ptr %a, ptr %b, ptr %cond) {
 ; CHECK-LABEL: complex_mul_predicated_v2f64:
 ; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi v0.2d, #0000000000000000
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    cntd x9
-; CHECK-NEXT:    mov w11, #100 // =0x64
 ; CHECK-NEXT:    neg x10, x9
+; CHECK-NEXT:    mov w11, #100 // =0x64
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    mov x8, xzr
 ; CHECK-NEXT:    and x10, x10, x11
 ; CHECK-NEXT:    rdvl x11, #2
-; CHECK-NEXT:    zip2 z0.d, z1.d, z1.d
-; CHECK-NEXT:    zip1 z1.d, z1.d, z1.d
 ; CHECK-NEXT:  .LBB1_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ld1w { z2.d }, p0/z, [x2, x8, lsl #2]
-; CHECK-NEXT:    mov z6.d, z1.d
-; CHECK-NEXT:    mov z7.d, z0.d
+; CHECK-NEXT:    mov z6.d, z0.d
+; CHECK-NEXT:    mov z7.d, z1.d
 ; CHECK-NEXT:    add x8, x8, x9
 ; CHECK-NEXT:    cmpne p1.d, p0/z, z2.d, #0
 ; CHECK-NEXT:    cmp x10, x8
@@ -141,12 +139,12 @@ define %"class.std::complex" @complex_mul_predicated_v2f64(ptr %a, ptr %b, ptr %
 ; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #0
 ; CHECK-NEXT:    fcmla z7.d, p0/m, z4.d, z2.d, #90
 ; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #90
-; CHECK-NEXT:    mov z0.d, p2/m, z7.d
-; CHECK-NEXT:    mov z1.d, p1/m, z6.d
+; CHECK-NEXT:    mov z1.d, p2/m, z7.d
+; CHECK-NEXT:    mov z0.d, p1/m, z6.d
 ; CHECK-NEXT:    b.ne .LBB1_1
 ; CHECK-NEXT:  // %bb.2: // %exit.block
-; CHECK-NEXT:    uzp1 z2.d, z1.d, z0.d
-; CHECK-NEXT:    uzp2 z1.d, z1.d, z0.d
+; CHECK-NEXT:    uzp1 z2.d, z0.d, z1.d
+; CHECK-NEXT:    uzp2 z1.d, z0.d, z1.d
 ; CHECK-NEXT:    faddv d0, p0, z2.d
 ; CHECK-NEXT:    faddv d1, p0, z1.d
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
@@ -213,21 +211,20 @@ exit.block:                                     ; preds = %vector.body
 define %"class.std::complex" @complex_mul_predicated_x2_v2f64(ptr %a, ptr %b, ptr %cond) {
 ; CHECK-LABEL: complex_mul_predicated_x2_v2f64:
 ; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi v0.2d, #0000000000000000
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    mov w8, #100 // =0x64
-; CHECK-NEXT:    cntd x9
 ; CHECK-NEXT:    whilelo p1.d, xzr, x8
+; CHECK-NEXT:    cntd x9
 ; CHECK-NEXT:    rdvl x10, #2
-; CHECK-NEXT:    cnth x11
 ; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    cnth x11
 ; CHECK-NEXT:    mov x12, x9
-; CHECK-NEXT:    zip2 z0.d, z1.d, z1.d
-; CHECK-NEXT:    zip1 z1.d, z1.d, z1.d
 ; CHECK-NEXT:  .LBB2_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ld1w { z2.d }, p1/z, [x2]
-; CHECK-NEXT:    mov z6.d, z1.d
-; CHECK-NEXT:    mov z7.d, z0.d
+; CHECK-NEXT:    mov z6.d, z0.d
+; CHECK-NEXT:    mov z7.d, z1.d
 ; CHECK-NEXT:    add x2, x2, x11
 ; CHECK-NEXT:    and z2.d, z2.d, #0xffffffff
 ; CHECK-NEXT:    cmpne p1.d, p1/z, z2.d, #0
@@ -243,14 +240,14 @@ define %"class.std::complex" @complex_mul_predicated_x2_v2f64(ptr %a, ptr %b, pt
 ; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #0
 ; CHECK-NEXT:    fcmla z7.d, p0/m, z4.d, z2.d, #90
 ; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #90
-; CHECK-NEXT:    mov z0.d, p2/m, z7.d
-; CHECK-NEXT:    mov z1.d, p1/m, z6.d
+; CHECK-NEXT:    mov z1.d, p2/m, z7.d
+; CHECK-NEXT:    mov z0.d, p1/m, z6.d
 ; CHECK-NEXT:    whilelo p1.d, x12, x8
 ; CHECK-NEXT:    add x12, x12, x9
 ; CHECK-NEXT:    b.mi .LBB2_1
 ; CHECK-NEXT:  // %bb.2: // %exit.block
-; CHECK-NEXT:    uzp1 z2.d, z1.d, z0.d
-; CHECK-NEXT:    uzp2 z1.d, z1.d, z0.d
+; CHECK-NEXT:    uzp1 z2.d, z0.d, z1.d
+; CHECK-NEXT:    uzp2 z1.d, z0.d, z1.d
 ; CHECK-NEXT:    faddv d0, p0, z2.d
 ; CHECK-NEXT:    faddv d1, p0, z1.d
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-scalable.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-scalable.ll
index 29be231920305..0646ca4948e1d 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-scalable.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-scalable.ll
@@ -14,15 +14,14 @@ target triple = "aarch64"
 define %"class.std::complex" @complex_mul_v2f64(ptr %a, ptr %b) {
 ; CHECK-LABEL: complex_mul_v2f64:
 ; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi v0.2d, #0000000000000000
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    cntd x8
-; CHECK-NEXT:    mov w10, #100 // =0x64
 ; CHECK-NEXT:    neg x9, x8
+; CHECK-NEXT:    mov w10, #100 // =0x64
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    and x9, x9, x10
 ; CHECK-NEXT:    rdvl x10, #2
-; CHECK-NEXT:    zip2 z0.d, z1.d, z1.d
-; CHECK-NEXT:    zip1 z1.d, z1.d, z1.d
 ; CHECK-NEXT:  .LBB0_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ldr z2, [x0, #1, mul vl]
@@ -32,14 +31,14 @@ define %"class.std::complex" @complex_mul_v2f64(ptr %a, ptr %b) {
 ; CHECK-NEXT:    ldr z5, [x1]
 ; CHECK-NEXT:    add x1, x1, x10
 ; CHECK-NEXT:    add x0, x0, x10
-; CHECK-NEXT:    fcmla z1.d, p0/m, z5.d, z3.d, #0
-; CHECK-NEXT:    fcmla z0.d, p0/m, z4.d, z2.d, #0
-; CHECK-NEXT:    fcmla z1.d, p0/m, z5.d, z3.d, #90
-; CHECK-NEXT:    fcmla z0.d, p0/m, z4.d, z2.d, #90
+; CHECK-NEXT:    fcmla z0.d, p0/m, z5.d, z3.d, #0
+; CHECK-NEXT:    fcmla z1.d, p0/m, z4.d, z2.d, #0
+; CHECK-NEXT:    fcmla z0.d, p0/m, z5.d, z3.d, #90
+; CHECK-NEXT:    fcmla z1.d, p0/m, z4.d, z2.d, #90
 ; CHECK-NEXT:    b.ne .LBB0_1
 ; CHECK-NEXT:  // %bb.2: // %exit.block
-; CHECK-NEXT:    uzp1 z2.d, z1.d, z0.d
-; CHECK-NEXT:    uzp2 z1.d, z1.d, z0.d
+; CHECK-NEXT:    uzp1 z2.d, z0.d, z1.d
+; CHECK-NEXT:    uzp2 z1.d, z0.d, z1.d
 ; CHECK-NEXT:    faddv d0, p0, z2.d
 ; CHECK-NEXT:    faddv d1, p0, z1.d
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
@@ -183,17 +182,16 @@ exit.block:                                     ; preds = %vector.body
 define %"class.std::complex" @complex_mul_v2f64_unrolled(ptr %a, ptr %b) {
 ; CHECK-LABEL: complex_mul_v2f64_unrolled:
 ; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi v0.2d, #0000000000000000
 ; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    cntw x8
-; CHECK-NEXT:    mov w10, #1000 // =0x3e8
+; CHECK-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-NEXT:    movi v3.2d, #0000000000000000
 ; CHECK-NEXT:    neg x9, x8
+; CHECK-NEXT:    mov w10, #1000 // =0x3e8
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    and x9, x9, x10
 ; CHECK-NEXT:    rdvl x10, #4
-; CHECK-NEXT:    zip2 z0.d, z1.d, z1.d
-; CHECK-NEXT:    zip1 z1.d, z1.d, z1.d
-; CHECK-NEXT:    mov z2.d, z1.d
-; CHECK-NEXT:    mov z3.d, z0.d
 ; CHECK-NEXT:  .LBB2_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ldr z4, [x0, #1, mul vl]
@@ -207,20 +205,20 @@ define %"class.std::complex" @complex_mul_v2f64_unrolled(ptr %a, ptr %b) {
 ; CHECK-NEXT:    ldr z18, [x1, #3, mul vl]
 ; CHECK-NEXT:    ldr z19, [x1, #2, mul vl]
 ; CHECK-NEXT:    add x1, x1, x10
-; CHECK-NEXT:    fcmla z1.d, p0/m, z16.d, z5.d, #0
-; CHECK-NEXT:    fcmla z0.d, p0/m, z7.d, z4.d, #0
+; CHECK-NEXT:    fcmla z0.d, p0/m, z16.d, z5.d, #0
+; CHECK-NEXT:    fcmla z1.d, p0/m, z7.d, z4.d, #0
 ; CHECK-NEXT:    fcmla z3.d, p0/m, z18.d, z6.d, #0
 ; CHECK-NEXT:    fcmla z2.d, p0/m, z19.d, z17.d, #0
-; CHECK-NEXT:    fcmla z1.d, p0/m, z16.d, z5.d, #90
-; CHECK-NEXT:    fcmla z0.d, p0/m, z7.d, z4.d, #90
+; CHECK-NEXT:    fcmla z0.d, p0/m, z16.d, z5.d, #90
+; CHECK-NEXT:    fcmla z1.d, p0/m, z7.d, z4.d, #90
 ; CHECK-NEXT:    fcmla z3.d, p0/m, z18.d, z6.d, #90
 ; CHECK-NEXT:    fcmla z2.d, p0/m, z19.d, z17.d, #90
 ; CHECK-NEXT:    b.ne .LBB2_1
 ; CHECK-NEXT:  // %bb.2: // %exit.block
 ; CHECK-NEXT:    uzp1 z4.d, z2.d, z3.d
-; CHECK-NEXT:    uzp1 z5.d, z1.d, z0.d
+; CHECK-NEXT:    uzp1 z5.d, z0.d, z1.d
 ; CHECK-NEXT:    uzp2 z2.d, z2.d, z3.d
-; CHECK-NEXT:    uzp2 z0.d, z1.d, z0.d
+; CHECK-NEXT:    uzp2 z0.d, z0.d, z1.d
 ; CHECK-NEXT:    fadd z1.d, z4.d, z5.d
 ; CHECK-NEXT:    fadd z2.d, z2.d, z0.d
 ; CHECK-NEXT:    faddv d0, p0, z1.d
@@ -310,15 +308,15 @@ define dso_local %"class.std::complex" @reduction_mix(ptr %a, ptr %b, ptr noalia
 ; CHECK-LABEL: reduction_mix:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-NEXT:    movi v0.2d, #0000000000000000
 ; CHECK-NEXT:    cntd x9
-; CHECK-NEXT:    mov w11, #100 // =0x64
+; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:    neg x10, x9
+; CHECK-NEXT:    mov w11, #100 // =0x64
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    mov x8, xzr
 ; CHECK-NEXT:    and x10, x10, x11
 ; CHECK-NEXT:    rdvl x11, #2
-; CHECK-NEXT:    zip2 z0.d, z2.d, z2.d
-; CHECK-NEXT:    zip1 z1.d, z2.d, z2.d
 ; CHECK-NEXT:  .LBB3_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ldr z3, [x0]
@@ -327,13 +325,13 @@ define dso_local %"class.std::complex" @reduction_mix(ptr %a, ptr %b, ptr noalia
 ; CHECK-NEXT:    ld1w { z5.d }, p0/z, [x3, x8, lsl #2]
 ; CHECK-NEXT:    add x8, x8, x9
 ; CHECK-NEXT:    cmp x10, x8
-; CHECK-NEXT:    fadd z0.d, z4.d, z0.d
-; CHECK-NEXT:    fadd z1.d, z3.d, z1.d
+; CHECK-NEXT:    fadd z1.d, z4.d, z1.d
+; CHECK-NEXT:    fadd z0.d, z3.d, z0.d
 ; CHECK-NEXT:    add z2.d, z5.d, z2.d
 ; CHECK-NEXT:    b.ne .LBB3_1
 ; CHECK-NEXT:  // %bb.2: // %middle.block
-; CHECK-NEXT:    uzp2 z3.d, z1.d, z0.d
-; CHECK-NEXT:    uzp1 z1.d, z1.d, z0.d
+; CHECK-NEXT:    uzp2 z3.d, z0.d, z1.d
+; CHECK-NEXT:    uzp1 z1.d, z0.d, z1.d
 ; CHECK-NEXT:    uaddv d2, p0, z2.d
 ; CHECK-NEXT:    faddv d0, p0, z3.d
 ; CHECK-NEXT:    faddv d1, p0, z1.d
diff --git a/llvm/test/CodeGen/AArch64/fixed-vector-interleave.ll b/llvm/test/CodeGen/AArch64/fixed-vector-interleave.ll
index a9618fdc2dec3..54f4543c81d4c 100644
--- a/llvm/test/CodeGen/AArch64/fixed-vector-interleave.ll
+++ b/llvm/test/CodeGen/AArch64/fixed-vector-interleave.ll
@@ -131,6 +131,92 @@ define <4 x i64> @interleave2_v4i64(<2 x i64> %vec0, <2 x i64> %vec1) {
   ret <4 x i64> %retval
 }
 
+define <4 x i16> @interleave2_same_const_splat_v4i16() {
+; CHECK-SD-LABEL: interleave2_same_const_splat_v4i16:
+; CHECK-SD:       // %bb.0:
+; CHECK-SD-NEXT:    movi v0.4h, #3
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: interleave2_same_const_splat_v4i16:
+; CHECK-GI:       // %bb.0:
+; CHECK-GI-NEXT:    mov w8, #3 // =0x3
+; CHECK-GI-NEXT:    fmov s0, w8
+; CHECK-GI-NEXT:    mov v0.h[1], w8
+; CHECK-GI-NEXT:    zip1 v0.4h, v0.4h, v0.4h
+; CHECK-GI-NEXT:    ret
+  %retval = call <4 x i16> @llvm.vector.interleave2.v4i16(<2 x i16> splat(i16 3), <2 x i16> splat(i16 3))
+  ret <4 x i16> %retval
+}
+
+define <4 x i16> @interleave2_diff_const_splat_v4i16() {
+; CHECK-SD-LABEL: interleave2_diff_const_splat_v4i16:
+; CHECK-SD:       // %bb.0:
+; CHECK-SD-NEXT:    adrp x8, .LCPI11_0
+; CHECK-SD-NEXT:    ldr d0, [x8, :lo12:.LCPI11_0]
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: interleave2_diff_const_splat_v4i16:
+; CHECK-GI:       // %bb.0:
+; CHECK-GI-NEXT:    mov w8, #3 // =0x3
+; CHECK-GI-NEXT:    mov w9, #4 // =0x4
+; CHECK-GI-NEXT:    fmov s0, w8
+; CHECK-GI-NEXT:    fmov s1, w9
+; CHECK-GI-NEXT:    mov v0.h[1], w8
+; CHECK-GI-NEXT:    mov v1.h[1], w9
+; CHECK-GI-NEXT:    zip1 v0.4h, v0.4h, v1.4h
+; CHECK-GI-NEXT:    ret
+  %retval = call <4 x i16> @llvm.vector.interleave2.v4i16(<2 x i16> splat(i16 3), <2 x i16> splat(i16 4))
+  ret <4 x i16> %retval
+}
+
+define <4 x i16> @interleave2_same_nonconst_splat_v4i16(i16 %a) {
+; CHECK-SD-LABEL: interleave2_same_nonconst_splat_v4i16:
+; CHECK-SD:       // %bb.0:
+; CHECK-SD-NEXT:    dup v0.4h, w0
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: interleave2_same_nonconst_splat_v4i16:
+; CHECK-GI:       // %bb.0:
+; CHECK-GI-NEXT:    dup v0.4h, w0
+; CHECK-GI-NEXT:    zip1 v0.4h, v0.4h, v0.4h
+; CHECK-GI-NEXT:    ret
+  %ins = insertelement <2 x i16> poison, i16 %a, i32 0
+  %splat = shufflevector <2 x i16> %ins, <2 x i16> poison, <2 x i32> <i32 0, i32 0>
+  %retval = call <4 x i16> @llvm.vector.interleave2.v4i16(<2 x i16> %splat, <2 x i16> %splat)
+  ret <4 x i16> %retval
+}
+
+define <4 x i16> @interleave2_diff_nonconst_splat_v4i16(i16 %a, i16 %b) {
+; CHECK-SD-LABEL: interleave2_diff_nonconst_splat_v4i16:
+; CHECK-SD:       // %bb.0:
+; CHECK-SD-NEXT:    fmov s0, w0
+; CHECK-SD-NEXT:    mov v0.h[1], w0
+; CHECK-SD-NEXT:    mov v0.h[2], w1
+; CHECK-SD-NEXT:    mov v0.h[3], w1
+; CHECK-SD-NEXT:    rev32 v1.4h, v0.4h
+; CHECK-SD-NEXT:    uzp1 v0.4h, v0.4h, v1.4h
+; CHECK-SD-NEXT:    ret
+;
+; CHECK-GI-LABEL: interleave2_diff_nonconst_splat_v4i16:
+; CHECK-GI:       // %bb.0:
+; CHECK-GI-NEXT:    dup v0.4h, w0
+; CHECK-GI-NEXT:    dup v1.4h, w1
+; CHECK-GI-NEXT:    zip1 v0.4h, v0.4h, v1.4h
+; CHECK-GI-NEXT:    ret
+  %ins1 = insertelement <2 x i16> poison, i16 %a, i32 0
+  %splat1 = shufflevector <2 x i16> %ins1, <2 x i16> poison, <2 x i32> <i32 0, i32 0>
+  %ins2 = insertelement <2 x i16> poison, i16 %b, i32 0
+  %splat2 = shufflevector <2 x i16> %ins2, <2 x i16> poison, <2 x i32> <i32 0, i32 0>
+  %retval = call <4 x i16> @llvm.vector.interleave2.v4i16(<2 x i16> %splat1, <2 x i16> %splat2)
+  ret <4 x i16> %retval
+}
+
+; FIXME: This test crashes during lowering
+;define <8 x i16> @interleave4_const_splat_v8i16(<2 x i16> %a) {
+;  %retval = call <8 x i16> @llvm.vector.interleave4.v8i16(<2 x i16> splat(i16 3), <2 x i16> splat(i16 3), <2 x i16> splat(i16 3), <2 x i16> splat(i16 3))
+;  ret <8 x i16> %retval
+;}
+
 
 ; Float declarations
 declare <4 x half> @llvm.vector.interleave2.v4f16(<2 x half>, <2 x half>)
@@ -145,4 +231,5 @@ declare <32 x i8> @llvm.vector.interleave2.v32i8(<16 x i8>, <16 x i8>)
 declare <16 x i16> @llvm.vector.interleave2.v16i16(<8 x i16>, <8 x i16>)
 declare <8 x i32> @llvm.vector.interleave2.v8i32(<4 x i32>, <4 x i32>)
 declare <4 x i64> @llvm.vector.interleave2.v4i64(<2 x i64>, <2 x i64>)
-
+declare <4 x i16> @llvm.vector.interleave2.v4i16(<2 x i16>, <2 x i16>)
+declare <8 x i16> @llvm.vector.interleave4.v8i16(<2 x i16>, <2 x i16>, <2 x i16>, <2 x i16>)
diff --git a/llvm/test/CodeGen/AArch64/sve-vector-interleave.ll b/llvm/test/CodeGen/AArch64/sve-vector-interleave.ll
index 52cb2d9ebe343..ab9ddbd78e5e0 100644
--- a/llvm/test/CodeGen/AArch64/sve-vector-interleave.ll
+++ b/llvm/test/CodeGen/AArch64/sve-vector-interleave.ll
@@ -267,7 +267,7 @@ define <vscale x 32 x i16> @interleave4_nxv8i16(<vscale x 8 x i16> %vec0, <vscal
 ; SME2-NEXT:    // kill: def $z0 killed $z0 killed $z0_z1_z2_z3 def $z0_z1_z2_z3
 ; SME2-NEXT:    zip { z0.h - z3.h }, { z0.h - z3.h }
 ; SME2-NEXT:    ret
-  %retval = call <vscale x 32 x i16> @llvm.vector.interleave4.nxv8i16(<vscale x 8 x i16> %vec0, <vscale x 8 x i16> %vec1, <vscale x 8 x i16> %vec2, <vscale x 8 x i16> %vec3)
+  %retval = call <vscale x 32 x i16> @llvm.vector.interleave4.nxv32i16(<vscale x 8 x i16> %vec0, <vscale x 8 x i16> %vec1, <vscale x 8 x i16> %vec2, <vscale x 8 x i16> %vec3)
   ret <vscale x 32 x i16> %retval
 }
 
@@ -540,6 +540,97 @@ define <vscale x 4 x i32> @interleave2_nxv2i32(<vscale x 2 x i32> %vec0, <vscale
   ret <vscale x 4 x i32> %retval
 }
 
+define <vscale x 4 x i16> @interleave2_same_const_splat_nxv4i16() {
+; CHECK-LABEL: interleave2_same_const_splat_nxv4i16:
+;...
[truncated]

@david-arm changed the title from "[CodeGen] Lower vector interleaves of const splats to a wider splat" to "[DAGCombiner] Add combine for vector interleave of splats" on Jul 29, 2025
Comment on lines 25188 to 25191
SDValue InOp0 = FirstOp.getOperand(0);
if (!llvm::all_of(FirstOp->ops(),
                  [&InOp0](SDValue Op) { return Op == InOp0; }))
  return SDValue();
Contributor

I presume the lambda's there because we need to cast from SDUse to SDValue? Does all_equal(FirstOp->op_values()) work?
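
For what it's worth, this is roughly how the suggestion would read
(illustrative only; llvm::all_equal is in llvm/ADT/STLExtras.h and
SDNode::op_values() yields SDValues, so no explicit lambda should be
needed):

// Checks that all operands of the interleave are identical.
if (!llvm::all_equal(FirstOp->op_values()))
  return SDValue();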

Contributor Author

Good suggestion, I'll give it a go!

Comment on lines 25432 to 25434
if (SDValue V = combineConcatVectorInterleave(N, DAG))
  return V;

Contributor

Can we get away with doing this directly on the interleave node?

We could combine it to a wide splat and then use extract_subvector for each resulting vector.

This is just an idea though, so I'm not 100% sure it'll work.
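
A rough sketch of that idea (names are assumed, not from the patch:
SplatVal stands for the common scalar, the identical-splat check is
presumed to have happened already, and N here is the interleave node
rather than the concat):

// Sketch: combine VECTOR_INTERLEAVE(splat X, ..., splat X) itself,
// regardless of how its results are used.
SDLoc DL(N);
EVT NarrowVT = N->getValueType(0);
unsigned Factor = N->getNumValues();
EVT WideVT =
    EVT::getVectorVT(*DAG.getContext(), NarrowVT.getScalarType(),
                     NarrowVT.getVectorElementCount() * Factor);

// One wide splat covering all results...
SDValue WideSplat = DAG.getNode(ISD::SPLAT_VECTOR, DL, WideVT, SplatVal);

// ...then hand back one subvector per interleave result.
SmallVector<SDValue, 8> Results;
for (unsigned I = 0; I < Factor; ++I)
  Results.push_back(DAG.getNode(
      ISD::EXTRACT_SUBVECTOR, DL, NarrowVT, WideSplat,
      DAG.getVectorIdxConstant(I * NarrowVT.getVectorMinNumElements(), DL)));
return DAG.getMergeValues(Results, DL);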

Contributor Author

In that case is it really a DAG combine? It sounds more like a simplification of a node, which might be better done at creation time in SelectionDAG::getNode perhaps? It could work, but I'm not sure how to do it because the node still needs to produce N results. I see what you mean though - if only the first result gets used then this DAG combine doesn't help. I'd expect the common use case at the moment to be the pattern created by SelectionDAGBuilder when lowering the llvm.vector.interleave intrinsic. @paulwalker-arm any thoughts?

Collaborator

Relying on the concat will unnecessarily restrict the combine, making it unlikely to fire after type legalisation, so I'd rather this be an interleave-specific combine that doesn't care about its uses. Which I think will simplify the code as well?

Doing this in SelectionDAG::getNode() sounds reasonable, but the combiner doesn't always create new nodes, so having the combine could catch more cases. These are multi-result nodes, so perhaps just choose the path with the cleanest implementation (I'd guess DAGCombiner will win here).

@lukel97 (Contributor) left a comment

LGTM

    FirstOp.getNumOperands() != N->getNumOperands())
  return SDValue();

for (unsigned I = 0; I < N->getNumOperands(); I++) {
Contributor

Nit, does enumerating op_values also work here?

Suggested change
for (unsigned I = 0; I < N->getNumOperands(); I++) {
for (auto [Idx, Op] : enumerate(N->op_values())) {

Comment on lines +25201 to +25203
EVT WideVecTy =
EVT::getVectorVT(*DAG.getContext(), SubVecTy.getScalarType(),
SubVecTy.getVectorElementCount() * N->getNumOperands());
Contributor

Is WideVecTy the same as the concat_vectors' type? Can we just replace this with N->getValueType(0) instead?
