-
Notifications
You must be signed in to change notification settings - Fork 446
[Datapath] Implement Add with Carry-In to Improve Subtraction Circuit Implementation #9949
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -769,36 +769,63 @@ struct CombAddOpConversion : OpConversionPattern<AddOp> { | |||||
| matchAndRewrite(AddOp op, OpAdaptor adaptor, | ||||||
| ConversionPatternRewriter &rewriter) const override { | ||||||
| auto inputs = adaptor.getInputs(); | ||||||
| // Lower only when there are two inputs. | ||||||
| // Variadic operands must be lowered in a different pattern. | ||||||
| if (inputs.size() != 2) | ||||||
| return failure(); | ||||||
|
|
||||||
| auto width = op.getType().getIntOrFloatBitWidth(); | ||||||
|
|
||||||
| // Skip a zero width value. | ||||||
| if (width == 0) { | ||||||
| replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op, | ||||||
| op.getType(), 0); | ||||||
| return success(); | ||||||
| } | ||||||
|
|
||||||
| // Check if the architecture is specified by an attribute. | ||||||
| // Handle add(a, b, const_1): a + b + 1 = prefix adder with carry_in=1. | ||||||
| // This pattern comes from compress(a, b, const_1) in DatapathToComb, which | ||||||
| // arises from sub → add(a, ~b, 1). Detect const_1 in any operand position. | ||||||
| if (inputs.size() == 3) { | ||||||
| for (unsigned i = 0; i < 3; ++i) { | ||||||
| auto constOp = inputs[i].getDefiningOp<hw::ConstantOp>(); | ||||||
| if (!constOp || !constOp.getValue().isOne()) | ||||||
| continue; | ||||||
| // Found the carry-in constant. Collect the two main operands. | ||||||
| SmallVector<Value, 2> operands; | ||||||
| for (unsigned j = 0; j < 3; ++j) | ||||||
| if (j != i) | ||||||
| operands.push_back(inputs[j]); | ||||||
| // Create an i1 true value for the carry_in signal in the adder. | ||||||
| Value carryIn = | ||||||
| hw::ConstantOp::create(rewriter, op.getLoc(), APInt(1, 1)); | ||||||
| auto arch = determineAdderArch(op, width); | ||||||
| if (arch == AdderArchitecture::RippleCarry) | ||||||
| return lowerRippleCarryAdder(op, operands[0], operands[1], carryIn, | ||||||
| rewriter); | ||||||
| return lowerParallelPrefixAdder(op, operands[0], operands[1], carryIn, | ||||||
| rewriter); | ||||||
| } | ||||||
| return failure(); // 3-input add without const_1 — handled elsewhere | ||||||
| } | ||||||
|
|
||||||
| // Lower only when there are two inputs. | ||||||
| // Variadic operands must be lowered in a different pattern. | ||||||
| if (inputs.size() != 2) | ||||||
| return failure(); | ||||||
|
|
||||||
| auto arch = determineAdderArch(op, width); | ||||||
| if (arch == AdderArchitecture::RippleCarry) | ||||||
| return lowerRippleCarryAdder(op, inputs, rewriter); | ||||||
| return lowerParallelPrefixAdder(op, inputs, rewriter); | ||||||
| return lowerRippleCarryAdder(op, inputs[0], inputs[1], Value(), rewriter); | ||||||
| return lowerParallelPrefixAdder(op, inputs[0], inputs[1], Value(), rewriter); | ||||||
|
Comment on lines
+797
to
+815
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you create a helper function for this part and share at two places? auto arch = determineAdderArch(op, width);
if (arch == AdderArchitecture::RippleCarry)
return lowerRippleCarryAdder(op, inputs[0], inputs[1], Value(), rewriter);
return lowerParallelPrefixAdder(op, inputs[0], inputs[1], Value(), rewriter); |
||||||
| } | ||||||
|
|
||||||
| // Implement a basic ripple-carry adder for small bitwidths. | ||||||
| // carryIn may be a null Value (meaning carry_in = 0). | ||||||
| LogicalResult | ||||||
| lowerRippleCarryAdder(comb::AddOp op, ValueRange inputs, | ||||||
| lowerRippleCarryAdder(comb::AddOp op, Value inputA, Value inputB, | ||||||
| Value carryIn, | ||||||
| ConversionPatternRewriter &rewriter) const { | ||||||
| auto width = op.getType().getIntOrFloatBitWidth(); | ||||||
| // Implement a naive Ripple-carry full adder. | ||||||
| Value carry; | ||||||
| Value carry = carryIn; // null Value → carry_in = 0 | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Generally please use ASCII
Suggested change
|
||||||
|
|
||||||
| auto aBits = extractBits(rewriter, inputs[0]); | ||||||
| auto bBits = extractBits(rewriter, inputs[1]); | ||||||
| auto aBits = extractBits(rewriter, inputA); | ||||||
| auto bBits = extractBits(rewriter, inputB); | ||||||
| SmallVector<Value> results; | ||||||
| results.resize(width); | ||||||
| for (int64_t i = 0; i < width; ++i) { | ||||||
|
|
@@ -817,7 +844,6 @@ struct CombAddOpConversion : OpConversionPattern<AddOp> { | |||||
|
|
||||||
| // carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i]) | ||||||
| if (!carry) { | ||||||
| // This is the first bit, so the carry is the next carry. | ||||||
| carry = comb::AndOp::create(rewriter, op.getLoc(), | ||||||
| ValueRange{aBits[i], bBits[i]}, true); | ||||||
| continue; | ||||||
|
|
@@ -835,14 +861,16 @@ struct CombAddOpConversion : OpConversionPattern<AddOp> { | |||||
|
|
||||||
| // Implement a parallel prefix adder - with Kogge-Stone or Brent-Kung trees | ||||||
| // Will introduce unused signals for the carry bits but these will be removed | ||||||
| // by the AIG pass. | ||||||
| // by the AIG pass. carryIn may be a null Value (meaning carry_in = 0). | ||||||
| LogicalResult | ||||||
| lowerParallelPrefixAdder(comb::AddOp op, ValueRange inputs, | ||||||
| lowerParallelPrefixAdder(comb::AddOp op, Value inputA, Value inputB, | ||||||
| Value carryIn, | ||||||
| ConversionPatternRewriter &rewriter) const { | ||||||
| auto width = op.getType().getIntOrFloatBitWidth(); | ||||||
| auto loc = op.getLoc(); | ||||||
|
|
||||||
| auto aBits = extractBits(rewriter, inputs[0]); | ||||||
| auto bBits = extractBits(rewriter, inputs[1]); | ||||||
| auto aBits = extractBits(rewriter, inputA); | ||||||
| auto bBits = extractBits(rewriter, inputB); | ||||||
|
|
||||||
| // Construct propagate (p) and generate (g) signals | ||||||
| SmallVector<Value> p, g; | ||||||
|
|
@@ -851,9 +879,16 @@ struct CombAddOpConversion : OpConversionPattern<AddOp> { | |||||
|
|
||||||
| for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) { | ||||||
| // p_i = a_i XOR b_i | ||||||
| p.push_back(comb::XorOp::create(rewriter, op.getLoc(), aBit, bBit)); | ||||||
| p.push_back(comb::XorOp::create(rewriter, loc, aBit, bBit)); | ||||||
| // g_i = a_i AND b_i | ||||||
| g.push_back(comb::AndOp::create(rewriter, op.getLoc(), aBit, bBit)); | ||||||
| g.push_back(comb::AndOp::create(rewriter, loc, aBit, bBit)); | ||||||
| } | ||||||
|
|
||||||
| // With carry_in, adjust g[0]: g[0] = (a[0] AND b[0]) OR (p[0] AND carry_in) | ||||||
| // This bakes the carry_in into the prefix tree, avoiding a separate adder. | ||||||
| if (carryIn) { | ||||||
| Value pAndC = comb::AndOp::create(rewriter, loc, p[0], carryIn); | ||||||
| g[0] = comb::OrOp::create(rewriter, loc, g[0], pAndC); | ||||||
| } | ||||||
|
|
||||||
| LLVM_DEBUG({ | ||||||
|
|
@@ -880,28 +915,29 @@ struct CombAddOpConversion : OpConversionPattern<AddOp> { | |||||
| llvm_unreachable("Ripple-Carry should be handled separately"); | ||||||
| break; | ||||||
| case AdderArchitecture::Sklanskey: | ||||||
| lowerSklanskeyPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix); | ||||||
| lowerSklanskeyPrefixTree(rewriter, loc, pPrefix, gPrefix); | ||||||
| break; | ||||||
| case AdderArchitecture::KoggeStone: | ||||||
| lowerKoggeStonePrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix); | ||||||
| lowerKoggeStonePrefixTree(rewriter, loc, pPrefix, gPrefix); | ||||||
| break; | ||||||
| case AdderArchitecture::BrentKung: | ||||||
| lowerBrentKungPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix); | ||||||
| lowerBrentKungPrefixTree(rewriter, loc, pPrefix, gPrefix); | ||||||
| break; | ||||||
| } | ||||||
|
|
||||||
| // Generate result sum bits | ||||||
| // Generate result sum bits. | ||||||
| // NOTE: The result is stored in reverse order. | ||||||
| SmallVector<Value> results; | ||||||
| results.resize(width); | ||||||
| // Sum bit 0 is just p[0] since carry_in = 0 | ||||||
| results[width - 1] = p[0]; | ||||||
| // sum[0] = p[0] XOR carry_in (carry_in = 0 when null → just p[0]) | ||||||
| results[width - 1] = | ||||||
| carryIn ? comb::XorOp::create(rewriter, loc, p[0], carryIn) : p[0]; | ||||||
|
|
||||||
| // For remaining bits, sum_i = p_i XOR g_(i-1) | ||||||
| // The carry into position i is the group generate from position i-1 | ||||||
| for (int64_t i = 1; i < width; ++i) | ||||||
| results[width - 1 - i] = | ||||||
| comb::XorOp::create(rewriter, op.getLoc(), p[i], gPrefix[i - 1]); | ||||||
| comb::XorOp::create(rewriter, loc, p[i], gPrefix[i - 1]); | ||||||
|
|
||||||
| replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results); | ||||||
|
|
||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -75,15 +75,51 @@ struct DatapathCompressOpConversion : mlir::OpRewritePattern<CompressOp> { | |
| mlir::PatternRewriter &rewriter) const override { | ||
| Location loc = op.getLoc(); | ||
| auto inputs = op.getOperands(); | ||
| auto width = inputs[0].getType().getIntOrFloatBitWidth(); | ||
|
|
||
| SmallVector<SmallVector<Value>> addends; | ||
| for (auto input : inputs) { | ||
| addends.push_back( | ||
| extractBits(rewriter, input)); // Extract bits from each input | ||
| // Special case: compress(a, b, const_1) with exactly 3 inputs where one | ||
| // input is the constant 1. This pattern arises from sub → add(a, ~b, 1). | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use ASCII |
||
| // Check that both results feed a single comb.add(carry, save), then | ||
| // replace compress + add together with comb.add(a, b, const_1) so the | ||
| // downstream prefix adder can bake carry_in=1 directly into g[0]. | ||
| if (inputs.size() == 3) { | ||
| for (unsigned i = 0; i < 3; ++i) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can remove this outer loop when we know where the final argument is |
||
| auto constOp = inputs[i].getDefiningOp<hw::ConstantOp>(); | ||
| if (!constOp || !constOp.getValue().isOne()) | ||
| continue; | ||
|
Comment on lines
+87
to
+89
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As above constOp should be found as the final argument |
||
| // Both compress results must have exactly one use and share the same | ||
| // downstream comb.add(carry, save). | ||
| Value carry = op->getResult(0), save = op->getResult(1); | ||
| if (!carry.hasOneUse() || !save.hasOneUse()) | ||
| break; | ||
| auto *carryUser = carry.getUses().begin()->getOwner(); | ||
| auto *saveUser = save.getUses().begin()->getOwner(); | ||
| if (carryUser != saveUser) | ||
| break; | ||
| auto addOp = dyn_cast<comb::AddOp>(carryUser); | ||
| if (!addOp || addOp.getInputs().size() != 2) | ||
| break; | ||
| // Replace compress(a, b, 1) + add(carry, save) with add(a, b, const_1). | ||
| // The const_1 tells CombToSynth to synthesize a prefix adder with | ||
| // carry_in=1 baked into g[0], avoiding a separate Wallace tree stage. | ||
| SmallVector<Value> operands; | ||
| for (unsigned j = 0; j < 3; ++j) | ||
| if (j != i) | ||
| operands.push_back(inputs[j]); | ||
| Value sum = comb::AddOp::create( | ||
| rewriter, loc, | ||
| ValueRange{operands[0], operands[1], inputs[i]}, true); | ||
| rewriter.replaceOp(addOp, sum); | ||
| rewriter.eraseOp(op); | ||
| return success(); | ||
| } | ||
| } | ||
|
|
||
| // Compressor tree reduction | ||
| auto width = inputs[0].getType().getIntOrFloatBitWidth(); | ||
| // General case: Wallace tree reduction. | ||
| SmallVector<SmallVector<Value>> addends; | ||
| for (auto input : inputs) | ||
| addends.push_back(extractBits(rewriter, input)); | ||
|
|
||
| auto targetAddends = op.getNumResults(); | ||
| datapath::CompressorTree comp(width, addends, loc); | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because comb::AddOp and datapath::CompressOp have commutative trait, usually constant is pushed to the last element. I think it's fine to check only the last element.
https://github.com/llvm/llvm-project/blob/3d421d59ad247afeadf5d4f886c9dea14a5eb229/mlir/docs/Canonicalization.md?plain=1#L117-L118