Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 63 additions & 27 deletions lib/Conversion/CombToSynth/CombToSynth.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -769,36 +769,63 @@ struct CombAddOpConversion : OpConversionPattern<AddOp> {
matchAndRewrite(AddOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto inputs = adaptor.getInputs();
// Lower only when there are two inputs.
// Variadic operands must be lowered in a different pattern.
if (inputs.size() != 2)
return failure();

auto width = op.getType().getIntOrFloatBitWidth();

// Skip a zero width value.
if (width == 0) {
replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
op.getType(), 0);
return success();
}

// Check if the architecture is specified by an attribute.
// Handle add(a, b, const_1): a + b + 1 = prefix adder with carry_in=1.
// This pattern comes from compress(a, b, const_1) in DatapathToComb, which
// arises from sub -> add(a, ~b, 1). Detect const_1 in any operand position.
if (inputs.size() == 3) {
for (unsigned i = 0; i < 3; ++i) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because comb::AddOp and datapath::CompressOp have commutative trait, usually constant is pushed to the last element. I think it's fine to check only the last element.
https://github.com/llvm/llvm-project/blob/3d421d59ad247afeadf5d4f886c9dea14a5eb229/mlir/docs/Canonicalization.md?plain=1#L117-L118

auto constOp = inputs[i].getDefiningOp<hw::ConstantOp>();
if (!constOp || !constOp.getValue().isOne())
continue;
// Found the carry-in constant. Collect the two main operands.
SmallVector<Value, 2> operands;
for (unsigned j = 0; j < 3; ++j)
if (j != i)
operands.push_back(inputs[j]);
// Create an i1 true value for the carry_in signal in the adder.
Value carryIn =
hw::ConstantOp::create(rewriter, op.getLoc(), APInt(1, 1));
auto arch = determineAdderArch(op, width);
if (arch == AdderArchitecture::RippleCarry)
return lowerRippleCarryAdder(op, operands[0], operands[1], carryIn,
rewriter);
return lowerParallelPrefixAdder(op, operands[0], operands[1], carryIn,
rewriter);
}
return failure(); // 3-input add without const_1 — handled elsewhere
}

// Lower only when there are two inputs.
// Variadic operands must be lowered in a different pattern.
if (inputs.size() != 2)
return failure();

auto arch = determineAdderArch(op, width);
if (arch == AdderArchitecture::RippleCarry)
return lowerRippleCarryAdder(op, inputs, rewriter);
return lowerParallelPrefixAdder(op, inputs, rewriter);
return lowerRippleCarryAdder(op, inputs[0], inputs[1], Value(), rewriter);
return lowerParallelPrefixAdder(op, inputs[0], inputs[1], Value(), rewriter);
Comment on lines +797 to +815
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you create a helper function for this part and share at two places?

    auto arch = determineAdderArch(op, width);
    if (arch == AdderArchitecture::RippleCarry)
      return lowerRippleCarryAdder(op, inputs[0], inputs[1], Value(), rewriter);
    return lowerParallelPrefixAdder(op, inputs[0], inputs[1], Value(), rewriter);

}

// Implement a basic ripple-carry adder for small bitwidths.
// carryIn may be a null Value (meaning carry_in = 0).
LogicalResult
lowerRippleCarryAdder(comb::AddOp op, ValueRange inputs,
lowerRippleCarryAdder(comb::AddOp op, Value inputA, Value inputB,
Value carryIn,
ConversionPatternRewriter &rewriter) const {
auto width = op.getType().getIntOrFloatBitWidth();
// Implement a naive Ripple-carry full adder.
Value carry;
Value carry = carryIn; // null Value -> carry_in = 0
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally please use ASCII

Suggested change
Value carry = carryIn; // null Value carry_in = 0
Value carry = carryIn; // null Value -> carry_in = 0


auto aBits = extractBits(rewriter, inputs[0]);
auto bBits = extractBits(rewriter, inputs[1]);
auto aBits = extractBits(rewriter, inputA);
auto bBits = extractBits(rewriter, inputB);
SmallVector<Value> results;
results.resize(width);
for (int64_t i = 0; i < width; ++i) {
Expand All @@ -817,7 +844,6 @@ struct CombAddOpConversion : OpConversionPattern<AddOp> {

// carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i])
if (!carry) {
// This is the first bit, so the carry is the next carry.
carry = comb::AndOp::create(rewriter, op.getLoc(),
ValueRange{aBits[i], bBits[i]}, true);
continue;
Expand All @@ -835,14 +861,16 @@ struct CombAddOpConversion : OpConversionPattern<AddOp> {

// Implement a parallel prefix adder - with Kogge-Stone or Brent-Kung trees
// Will introduce unused signals for the carry bits but these will be removed
// by the AIG pass.
// by the AIG pass. carryIn may be a null Value (meaning carry_in = 0).
LogicalResult
lowerParallelPrefixAdder(comb::AddOp op, ValueRange inputs,
lowerParallelPrefixAdder(comb::AddOp op, Value inputA, Value inputB,
Value carryIn,
ConversionPatternRewriter &rewriter) const {
auto width = op.getType().getIntOrFloatBitWidth();
auto loc = op.getLoc();

auto aBits = extractBits(rewriter, inputs[0]);
auto bBits = extractBits(rewriter, inputs[1]);
auto aBits = extractBits(rewriter, inputA);
auto bBits = extractBits(rewriter, inputB);

// Construct propagate (p) and generate (g) signals
SmallVector<Value> p, g;
Expand All @@ -851,9 +879,16 @@ struct CombAddOpConversion : OpConversionPattern<AddOp> {

for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
// p_i = a_i XOR b_i
p.push_back(comb::XorOp::create(rewriter, op.getLoc(), aBit, bBit));
p.push_back(comb::XorOp::create(rewriter, loc, aBit, bBit));
// g_i = a_i AND b_i
g.push_back(comb::AndOp::create(rewriter, op.getLoc(), aBit, bBit));
g.push_back(comb::AndOp::create(rewriter, loc, aBit, bBit));
}

// With carry_in, adjust g[0]: g[0] = (a[0] AND b[0]) OR (p[0] AND carry_in)
// This bakes the carry_in into the prefix tree, avoiding a separate adder.
if (carryIn) {
Value pAndC = comb::AndOp::create(rewriter, loc, p[0], carryIn);
g[0] = comb::OrOp::create(rewriter, loc, g[0], pAndC);
}

LLVM_DEBUG({
Expand All @@ -880,28 +915,29 @@ struct CombAddOpConversion : OpConversionPattern<AddOp> {
llvm_unreachable("Ripple-Carry should be handled separately");
break;
case AdderArchitecture::Sklanskey:
lowerSklanskeyPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
lowerSklanskeyPrefixTree(rewriter, loc, pPrefix, gPrefix);
break;
case AdderArchitecture::KoggeStone:
lowerKoggeStonePrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
lowerKoggeStonePrefixTree(rewriter, loc, pPrefix, gPrefix);
break;
case AdderArchitecture::BrentKung:
lowerBrentKungPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
lowerBrentKungPrefixTree(rewriter, loc, pPrefix, gPrefix);
break;
}

// Generate result sum bits
// Generate result sum bits.
// NOTE: The result is stored in reverse order.
SmallVector<Value> results;
results.resize(width);
// Sum bit 0 is just p[0] since carry_in = 0
results[width - 1] = p[0];
// sum[0] = p[0] XOR carry_in (carry_in = 0 when null -> just p[0])
results[width - 1] =
carryIn ? comb::XorOp::create(rewriter, loc, p[0], carryIn) : p[0];

// For remaining bits, sum_i = p_i XOR g_(i-1)
// The carry into position i is the group generate from position i-1
for (int64_t i = 1; i < width; ++i)
results[width - 1 - i] =
comb::XorOp::create(rewriter, op.getLoc(), p[i], gPrefix[i - 1]);
comb::XorOp::create(rewriter, loc, p[i], gPrefix[i - 1]);

replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);

Expand Down
48 changes: 42 additions & 6 deletions lib/Conversion/DatapathToComb/DatapathToComb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,51 @@ struct DatapathCompressOpConversion : mlir::OpRewritePattern<CompressOp> {
mlir::PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto inputs = op.getOperands();
auto width = inputs[0].getType().getIntOrFloatBitWidth();

SmallVector<SmallVector<Value>> addends;
for (auto input : inputs) {
addends.push_back(
extractBits(rewriter, input)); // Extract bits from each input
// Special case: compress(a, b, const_1) with exactly 3 inputs where one
// input is the constant 1. This pattern arises from sub -> add(a, ~b, 1).
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use ASCII

// Check that both results feed a single comb.add(carry, save), then
// replace compress + add together with comb.add(a, b, const_1) so the
// downstream prefix adder can bake carry_in=1 directly into g[0].
if (inputs.size() == 3) {
for (unsigned i = 0; i < 3; ++i) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can remove this outer loop when we know where the final argument is

auto constOp = inputs[i].getDefiningOp<hw::ConstantOp>();
if (!constOp || !constOp.getValue().isOne())
continue;
Comment on lines +87 to +89
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above constOp should be found as the final argument

// Both compress results must have exactly one use and share the same
// downstream comb.add(carry, save).
Value carry = op->getResult(0), save = op->getResult(1);
if (!carry.hasOneUse() || !save.hasOneUse())
break;
auto *carryUser = carry.getUses().begin()->getOwner();
auto *saveUser = save.getUses().begin()->getOwner();
if (carryUser != saveUser)
break;
auto addOp = dyn_cast<comb::AddOp>(carryUser);
if (!addOp || addOp.getInputs().size() != 2)
break;
// Replace compress(a, b, 1) + add(carry, save) with add(a, b, const_1).
// The const_1 tells CombToSynth to synthesize a prefix adder with
// carry_in=1 baked into g[0], avoiding a separate Wallace tree stage.
SmallVector<Value> operands;
for (unsigned j = 0; j < 3; ++j)
if (j != i)
operands.push_back(inputs[j]);
Value sum = comb::AddOp::create(
rewriter, loc,
ValueRange{operands[0], operands[1], inputs[i]}, true);
rewriter.replaceOp(addOp, sum);
rewriter.eraseOp(op);
return success();
}
}

// Compressor tree reduction
auto width = inputs[0].getType().getIntOrFloatBitWidth();
// General case: Wallace tree reduction.
SmallVector<SmallVector<Value>> addends;
for (auto input : inputs)
addends.push_back(extractBits(rewriter, input));

auto targetAddends = op.getNumResults();
datapath::CompressorTree comp(width, addends, loc);

Expand Down
Loading