diff --git a/lib/Conversion/CombToSynth/CombToSynth.cpp b/lib/Conversion/CombToSynth/CombToSynth.cpp index 31d6582f8b36..5c14ced44410 100644 --- a/lib/Conversion/CombToSynth/CombToSynth.cpp +++ b/lib/Conversion/CombToSynth/CombToSynth.cpp @@ -769,12 +769,8 @@ struct CombAddOpConversion : OpConversionPattern<AddOp> { matchAndRewrite(AddOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto inputs = adaptor.getInputs(); - // Lower only when there are two inputs. - // Variadic operands must be lowered in a different pattern. - if (inputs.size() != 2) - return failure(); - auto width = op.getType().getIntOrFloatBitWidth(); + // Skip a zero width value. if (width == 0) { replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op, @@ -782,23 +778,54 @@ struct CombAddOpConversion : OpConversionPattern<AddOp> { return success(); } - // Check if the architecture is specified by an attribute. + // Handle add(a, b, const_1): a + b + 1 = prefix adder with carry_in=1. + // This pattern comes from compress(a, b, const_1) in DatapathToComb, which + // arises from sub → add(a, ~b, 1). Detect const_1 in any operand position. + if (inputs.size() == 3) { + for (unsigned i = 0; i < 3; ++i) { + auto constOp = inputs[i].getDefiningOp<hw::ConstantOp>(); + if (!constOp || !constOp.getValue().isOne()) + continue; + // Found the carry-in constant. Collect the two main operands. + SmallVector<Value> operands; + for (unsigned j = 0; j < 3; ++j) + if (j != i) + operands.push_back(inputs[j]); + // Create an i1 true value for the carry_in signal in the adder. + Value carryIn = + hw::ConstantOp::create(rewriter, op.getLoc(), APInt(1, 1)); + auto arch = determineAdderArch(op, width); + if (arch == AdderArchitecture::RippleCarry) + return lowerRippleCarryAdder(op, operands[0], operands[1], carryIn, + rewriter); + return lowerParallelPrefixAdder(op, operands[0], operands[1], carryIn, + rewriter); + } + return failure(); // 3-input add without const_1 — handled elsewhere + } + + // Lower only when there are two inputs.
+ // Variadic operands must be lowered in a different pattern. + if (inputs.size() != 2) + return failure(); + auto arch = determineAdderArch(op, width); if (arch == AdderArchitecture::RippleCarry) - return lowerRippleCarryAdder(op, inputs, rewriter); - return lowerParallelPrefixAdder(op, inputs, rewriter); + return lowerRippleCarryAdder(op, inputs[0], inputs[1], Value(), rewriter); + return lowerParallelPrefixAdder(op, inputs[0], inputs[1], Value(), rewriter); } // Implement a basic ripple-carry adder for small bitwidths. + // carryIn may be a null Value (meaning carry_in = 0). LogicalResult - lowerRippleCarryAdder(comb::AddOp op, ValueRange inputs, + lowerRippleCarryAdder(comb::AddOp op, Value inputA, Value inputB, + Value carryIn, ConversionPatternRewriter &rewriter) const { auto width = op.getType().getIntOrFloatBitWidth(); - // Implement a naive Ripple-carry full adder. - Value carry; + Value carry = carryIn; // null Value → carry_in = 0 - auto aBits = extractBits(rewriter, inputs[0]); - auto bBits = extractBits(rewriter, inputs[1]); + auto aBits = extractBits(rewriter, inputA); + auto bBits = extractBits(rewriter, inputB); SmallVector<Value> results; results.resize(width); for (int64_t i = 0; i < width; ++i) { @@ -817,7 +844,6 @@ struct CombAddOpConversion : OpConversionPattern<AddOp> { // carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i]) if (!carry) { - // This is the first bit, so the carry is the next carry. carry = comb::AndOp::create(rewriter, op.getLoc(), ValueRange{aBits[i], bBits[i]}, true); continue; } @@ -835,14 +861,16 @@ struct CombAddOpConversion : OpConversionPattern<AddOp> { // Implement a parallel prefix adder - with Kogge-Stone or Brent-Kung trees // Will introduce unused signals for the carry bits but these will be removed - // by the AIG pass. + // by the AIG pass. carryIn may be a null Value (meaning carry_in = 0).
LogicalResult - lowerParallelPrefixAdder(comb::AddOp op, ValueRange inputs, + lowerParallelPrefixAdder(comb::AddOp op, Value inputA, Value inputB, + Value carryIn, ConversionPatternRewriter &rewriter) const { auto width = op.getType().getIntOrFloatBitWidth(); + auto loc = op.getLoc(); - auto aBits = extractBits(rewriter, inputs[0]); - auto bBits = extractBits(rewriter, inputs[1]); + auto aBits = extractBits(rewriter, inputA); + auto bBits = extractBits(rewriter, inputB); // Construct propagate (p) and generate (g) signals SmallVector<Value> p, g; @@ -851,9 +879,16 @@ struct CombAddOpConversion : OpConversionPattern<AddOp> { for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) { // p_i = a_i XOR b_i - p.push_back(comb::XorOp::create(rewriter, op.getLoc(), aBit, bBit)); + p.push_back(comb::XorOp::create(rewriter, loc, aBit, bBit)); // g_i = a_i AND b_i - g.push_back(comb::AndOp::create(rewriter, op.getLoc(), aBit, bBit)); + g.push_back(comb::AndOp::create(rewriter, loc, aBit, bBit)); + } + + // With carry_in, adjust g[0]: g[0] = (a[0] AND b[0]) OR (p[0] AND carry_in) + // This bakes the carry_in into the prefix tree, avoiding a separate adder.
+ if (carryIn) { + Value pAndC = comb::AndOp::create(rewriter, loc, p[0], carryIn); + g[0] = comb::OrOp::create(rewriter, loc, g[0], pAndC); } LLVM_DEBUG({ @@ -880,28 +915,29 @@ struct CombAddOpConversion : OpConversionPattern<AddOp> { llvm_unreachable("Ripple-Carry should be handled separately"); break; case AdderArchitecture::Sklanskey: - lowerSklanskeyPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix); + lowerSklanskeyPrefixTree(rewriter, loc, pPrefix, gPrefix); break; case AdderArchitecture::KoggeStone: - lowerKoggeStonePrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix); + lowerKoggeStonePrefixTree(rewriter, loc, pPrefix, gPrefix); break; case AdderArchitecture::BrentKung: - lowerBrentKungPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix); + lowerBrentKungPrefixTree(rewriter, loc, pPrefix, gPrefix); break; } - // Generate result sum bits + // Generate result sum bits. // NOTE: The result is stored in reverse order. SmallVector<Value> results; results.resize(width); - // Sum bit 0 is just p[0] since carry_in = 0 - results[width - 1] = p[0]; + // sum[0] = p[0] XOR carry_in (carry_in = 0 when null → just p[0]) + results[width - 1] = + carryIn ?
comb::XorOp::create(rewriter, loc, p[0], carryIn) : p[0]; // For remaining bits, sum_i = p_i XOR g_(i-1) // The carry into position i is the group generate from position i-1 for (int64_t i = 1; i < width; ++i) results[width - 1 - i] = - comb::XorOp::create(rewriter, op.getLoc(), p[i], gPrefix[i - 1]); + comb::XorOp::create(rewriter, loc, p[i], gPrefix[i - 1]); replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results); diff --git a/lib/Conversion/DatapathToComb/DatapathToComb.cpp b/lib/Conversion/DatapathToComb/DatapathToComb.cpp index e04c6d5a0ef3..79101f75955c 100644 --- a/lib/Conversion/DatapathToComb/DatapathToComb.cpp +++ b/lib/Conversion/DatapathToComb/DatapathToComb.cpp @@ -75,15 +75,51 @@ struct DatapathCompressOpConversion : mlir::OpRewritePattern<CompressOp> { mlir::PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto inputs = op.getOperands(); + auto width = inputs[0].getType().getIntOrFloatBitWidth(); - SmallVector<SmallVector<Value>> addends; - for (auto input : inputs) { - addends.push_back( - extractBits(rewriter, input)); // Extract bits from each input + // Special case: compress(a, b, const_1) with exactly 3 inputs where one + // input is the constant 1. This pattern arises from sub → add(a, ~b, 1). + // Check that both results feed a single comb.add(carry, save), then + // replace compress + add together with comb.add(a, b, const_1) so the + // downstream prefix adder can bake carry_in=1 directly into g[0]. + if (inputs.size() == 3) { + for (unsigned i = 0; i < 3; ++i) { + auto constOp = inputs[i].getDefiningOp<hw::ConstantOp>(); + if (!constOp || !constOp.getValue().isOne()) + continue; + // Both compress results must have exactly one use and share the same + // downstream comb.add(carry, save).
+ Value carry = op->getResult(0), save = op->getResult(1); + if (!carry.hasOneUse() || !save.hasOneUse()) + break; + auto *carryUser = carry.getUses().begin()->getOwner(); + auto *saveUser = save.getUses().begin()->getOwner(); + if (carryUser != saveUser) + break; + auto addOp = dyn_cast<comb::AddOp>(carryUser); + if (!addOp || addOp.getInputs().size() != 2) + break; + // Replace compress(a, b, 1) + add(carry, save) with add(a, b, const_1). + // The const_1 tells CombToSynth to synthesize a prefix adder with + // carry_in=1 baked into g[0], avoiding a separate Wallace tree stage. + SmallVector<Value> operands; + for (unsigned j = 0; j < 3; ++j) + if (j != i) + operands.push_back(inputs[j]); + Value sum = comb::AddOp::create( + rewriter, loc, + ValueRange{operands[0], operands[1], inputs[i]}, true); + rewriter.replaceOp(addOp, sum); + rewriter.eraseOp(op); + return success(); + } } - // Compressor tree reduction - auto width = inputs[0].getType().getIntOrFloatBitWidth(); + // General case: Wallace tree reduction. + SmallVector<SmallVector<Value>> addends; + for (auto input : inputs) + addends.push_back(extractBits(rewriter, input)); + auto targetAddends = op.getNumResults(); datapath::CompressorTree comp(width, addends, loc);