Skip to content

[LLHD] Add llhd.sig.extract support to Mem2Reg #8542

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/circt/Dialect/Comb/CombOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ using llvm::KnownBits;
/// in neither set is unknown.
KnownBits computeKnownBits(Value value);

/// Create the ops to zero-extend a value to an integer of equal or larger type.
Value createZExt(OpBuilder &builder, Location loc, Value value,
unsigned targetWidth);

/// Create a sign extension operation from a value of integer type to an equal
/// or larger integer type.
Value createOrFoldSExt(Location loc, Value value, Type destTy,
Expand All @@ -67,6 +71,20 @@ Value constructMuxTree(OpBuilder &builder, Location loc,
ArrayRef<Value> selectors, ArrayRef<Value> leafNodes,
Value outOfBoundsValue);

/// Extract a range of bits from an integer at a dynamic offset.
Value createDynamicExtract(OpBuilder &builder, Location loc, Value value,
Value offset, unsigned width);

/// Replace a range of bits in an integer at a dynamic offset, and return the
/// updated integer value. Calls `createInject` if the offset is constant.
Value createDynamicInject(OpBuilder &builder, Location loc, Value value,
Value offset, Value replacement,
bool twoState = false);

/// Replace a range of bits in an integer and return the updated integer value.
Value createInject(OpBuilder &builder, Location loc, Value value,
unsigned offset, Value replacement);

} // namespace comb
} // namespace circt

Expand Down
105 changes: 105 additions & 0 deletions lib/Dialect/Comb/CombOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,29 @@
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/FormatVariadic.h"

using namespace circt;
using namespace comb;

Value comb::createZExt(OpBuilder &builder, Location loc, Value value,
unsigned targetWidth) {
assert(value.getType().isSignlessInteger());
auto inputWidth = value.getType().getIntOrFloatBitWidth();
assert(inputWidth <= targetWidth);

// Nothing to do if the width already matches.
if (inputWidth == targetWidth)
return value;

// Create a zero constant for the upper bits.
auto zeros = builder.create<hw::ConstantOp>(
loc, builder.getIntegerType(targetWidth - inputWidth), 0);
return builder.createOrFold<ConcatOp>(loc, zeros, value);
}

/// Create a sign extension operation from a value of integer type to an equal
/// or larger integer type.
Value comb::createOrFoldSExt(Location loc, Value value, Type destTy,
Expand Down Expand Up @@ -111,6 +128,94 @@ Value comb::constructMuxTree(OpBuilder &builder, Location loc,
return constructTreeHelper(0, llvm::Log2_64_Ceil(leafNodes.size()));
}

Value comb::createDynamicExtract(OpBuilder &builder, Location loc, Value value,
Value offset, unsigned width) {
assert(value.getType().isSignlessInteger());
auto valueWidth = value.getType().getIntOrFloatBitWidth();
assert(width <= valueWidth);

// Handle the special case where the offset is constant.
APInt constOffset;
if (matchPattern(offset, mlir::m_ConstantInt(&constOffset)))
if (constOffset.getActiveBits() < 32)
return builder.createOrFold<comb::ExtractOp>(
loc, value, constOffset.getZExtValue(), width);

// Zero-extend the offset, shift the value down, and extract the requested
// number of bits.
offset = createZExt(builder, loc, offset, valueWidth);
value = builder.createOrFold<comb::ShrUOp>(loc, value, offset);
return builder.createOrFold<comb::ExtractOp>(loc, value, 0, width);
}

Value comb::createDynamicInject(OpBuilder &builder, Location loc, Value value,
Value offset, Value replacement,
bool twoState) {
assert(value.getType().isSignlessInteger());
assert(replacement.getType().isSignlessInteger());
auto largeWidth = value.getType().getIntOrFloatBitWidth();
auto smallWidth = replacement.getType().getIntOrFloatBitWidth();
assert(smallWidth <= largeWidth);

// If we're inserting a zero-width value there's nothing to do.
if (smallWidth == 0)
return value;

// Handle the special case where the offset is constant.
APInt constOffset;
if (matchPattern(offset, mlir::m_ConstantInt(&constOffset)))
if (constOffset.getActiveBits() < 32)
return createInject(builder, loc, value, constOffset.getZExtValue(),
replacement);

// Zero-extend the offset and clear the value bits we are replacing.
offset = createZExt(builder, loc, offset, largeWidth);
Value mask = builder.create<hw::ConstantOp>(
loc, APInt::getLowBitsSet(largeWidth, smallWidth));
mask = builder.createOrFold<comb::ShlOp>(loc, mask, offset);
mask = createOrFoldNot(loc, mask, builder, true);
value = builder.createOrFold<comb::AndOp>(loc, value, mask, twoState);

// Zero-extend the replacement value, shift it up to the offset, and merge it
// with the value that has the corresponding bits cleared.
replacement = createZExt(builder, loc, replacement, largeWidth);
replacement = builder.createOrFold<comb::ShlOp>(loc, replacement, offset);
return builder.createOrFold<comb::OrOp>(loc, value, replacement, twoState);
}

Value comb::createInject(OpBuilder &builder, Location loc, Value value,
unsigned offset, Value replacement) {
assert(value.getType().isSignlessInteger());
assert(replacement.getType().isSignlessInteger());
auto largeWidth = value.getType().getIntOrFloatBitWidth();
auto smallWidth = replacement.getType().getIntOrFloatBitWidth();
assert(smallWidth <= largeWidth);

// If the offset is outside the value there's nothing to do.
if (offset >= largeWidth)
return value;

// If we're inserting a zero-width value there's nothing to do.
if (smallWidth == 0)
return value;

// Assemble the pieces of the injection as everything below the offset, the
// replacement value, and everything above the replacement value.
SmallVector<Value, 3> fragments;
auto end = offset + smallWidth;
if (end < largeWidth)
fragments.push_back(
builder.create<comb::ExtractOp>(loc, value, end, largeWidth - end));
if (end <= largeWidth)
fragments.push_back(replacement);
else
fragments.push_back(builder.create<comb::ExtractOp>(loc, replacement, 0,
largeWidth - offset));
if (offset > 0)
fragments.push_back(builder.create<comb::ExtractOp>(loc, value, 0, offset));
return builder.createOrFold<comb::ConcatOp>(loc, fragments);
}

//===----------------------------------------------------------------------===//
// ICmpOp
//===----------------------------------------------------------------------===//
Expand Down
24 changes: 18 additions & 6 deletions lib/Dialect/LLHD/Transforms/Mem2Reg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,11 +567,18 @@ static Value unpackProjections(OpBuilder &builder, Value value,
ProjectionStack &projections) {
for (auto &projection : llvm::reverse(projections)) {
projection.into = value;
value = TypeSwitch<Operation *, Value>(projection.op)
.Case<SigArrayGetOp>([&](auto op) {
return builder.createOrFold<hw::ArrayGetOp>(
op.getLoc(), value, op.getIndex());
});
value =
TypeSwitch<Operation *, Value>(projection.op)
.Case<SigArrayGetOp>([&](auto op) {
return builder.createOrFold<hw::ArrayGetOp>(op.getLoc(), value,
op.getIndex());
})
.Case<SigExtractOp>([&](auto op) {
auto type = cast<hw::InOutType>(op.getType()).getElementType();
auto width = type.getIntOrFloatBitWidth();
return comb::createDynamicExtract(builder, op.getLoc(), value,
op.getLowBit(), width);
});
}
return value;
}
Expand All @@ -595,6 +602,11 @@ static Value packProjections(OpBuilder &builder, Value value,
.Case<SigArrayGetOp>([&](auto op) {
return builder.createOrFold<hw::ArrayInjectOp>(
op.getLoc(), projection.into, op.getIndex(), value);
})
.Case<SigExtractOp>([&](auto op) {
return comb::createDynamicInject(builder, op.getLoc(),
projection.into,
op.getLowBit(), value);
});
}
return value;
Expand Down Expand Up @@ -753,7 +765,7 @@ void Promoter::findPromotableSlots() {
return true;
// Projection operations are okay as long as they are in the same block
// as any of their users.
if (isa<SigArrayGetOp>(user)) {
if (isa<SigArrayGetOp, SigExtractOp>(user)) {
for (auto *projectionUser : user->getUsers()) {
if (projectionUser->getBlock() != user->getBlock())
return false;
Expand Down
120 changes: 120 additions & 0 deletions test/Dialect/LLHD/Transforms/mem2reg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,126 @@ hw.module @NestedArrayGet3D(
}
}

// CHECK-LABEL: @BasicSigExtract
hw.module @BasicSigExtract(in %u: i42, in %v: i10, in %i: i6, in %q: i1) {
%0 = llhd.constant_time <0ns, 0d, 1e>
%a = llhd.sig %u : i42
// CHECK: llhd.process
llhd.process {
// CHECK-NOT: llhd.drv
llhd.drv %a, %u after %0 : !hw.inout<i42>
// CHECK-NOT: llhd.sig.extract
%1 = llhd.sig.extract %a from %i : (!hw.inout<i42>) -> !hw.inout<i10>
// CHECK-NOT: llhd.drv
// CHECK-NEXT: [[EXT1:%.+]] = hw.constant 0 : i36
// CHECK-NEXT: [[EXT2:%.+]] = comb.concat [[EXT1]], %i : i36, i6
// CHECK-NEXT: [[EXT3:%.+]] = comb.shru %u, [[EXT2]] : i42
// CHECK-NEXT: [[EXT4:%.+]] = comb.extract [[EXT3]] from 0 : (i42) -> i10
// CHECK-NEXT: [[MUX:%.+]] = comb.mux %q, %v, [[EXT4]] : i10
// CHECK-NEXT: [[INJ1:%.+]] = hw.constant 0 : i36
// CHECK-NEXT: [[INJ2:%.+]] = comb.concat [[INJ1]], %i : i36, i6
// CHECK-NEXT: [[INJ3:%.+]] = hw.constant 1023 : i42
// CHECK-NEXT: [[INJ4:%.+]] = comb.shl [[INJ3]], [[INJ2]] : i42
// CHECK-NEXT: [[INJ5:%.+]] = hw.constant -1 : i42
// CHECK-NEXT: [[INJ6:%.+]] = comb.xor bin [[INJ4]], [[INJ5]] : i42
// CHECK-NEXT: [[INJ7:%.+]] = comb.and %u, [[INJ6]] : i42
// CHECK-NEXT: [[INJ8:%.+]] = hw.constant 0 : i32
// CHECK-NEXT: [[INJ9:%.+]] = comb.concat [[INJ8]], [[MUX]] : i32, i10
// CHECK-NEXT: [[INJ10:%.+]] = comb.shl [[INJ9]], [[INJ2]] : i42
// CHECK-NEXT: [[INJ11:%.+]] = comb.or [[INJ7]], [[INJ10]] : i42
llhd.drv %1, %v after %0 if %q : !hw.inout<i10>
// CHECK-NOT: llhd.prb
%2 = llhd.prb %a : !hw.inout<i42>
// CHECK-NEXT: call @use_i42([[INJ11]])
func.call @use_i42(%2) : (i42) -> ()
// CHECK-NEXT: llhd.constant_time
// CHECK-NEXT: llhd.drv %a, [[INJ11]]
// CHECK-NEXT: llhd.halt
llhd.halt
}
}

// CHECK-LABEL: @CombCreateDynamicInject
hw.module @CombCreateDynamicInject(in %u: i42, in %v: i10, in %q: i1) {
%0 = llhd.constant_time <0ns, 0d, 1e>
%a = llhd.sig %u : i42

// offset = 0
// CHECK: llhd.process
llhd.process {
// CHECK-NEXT: [[TMP1:%.+]] = comb.extract %u from 10 : (i42) -> i32
// CHECK-NEXT: [[TMP2:%.+]] = comb.concat [[TMP1]], %v : i32, i10
// CHECK-NEXT: llhd.constant_time
// CHECK-NEXT: llhd.drv %a, [[TMP2]]
// CHECK-NEXT: llhd.halt
%c0_i6 = hw.constant 0 : i6
%1 = llhd.sig.extract %a from %c0_i6 : (!hw.inout<i42>) -> !hw.inout<i10>
llhd.drv %a, %u after %0 : !hw.inout<i42>
llhd.drv %1, %v after %0 : !hw.inout<i10>
llhd.halt
}

// offset > 0, end < 42
// CHECK: llhd.process
llhd.process {
// CHECK-NEXT: [[TMP1:%.+]] = comb.extract %u from 30 : (i42) -> i12
// CHECK-NEXT: [[TMP2:%.+]] = comb.extract %u from 0 : (i42) -> i20
// CHECK-NEXT: [[TMP3:%.+]] = comb.concat [[TMP1]], %v, [[TMP2]] : i12, i10, i20
// CHECK-NEXT: llhd.constant_time
// CHECK-NEXT: llhd.drv %a, [[TMP3]]
// CHECK-NEXT: llhd.halt
%c20_i6 = hw.constant 20 : i6
%1 = llhd.sig.extract %a from %c20_i6 : (!hw.inout<i42>) -> !hw.inout<i10>
llhd.drv %a, %u after %0 : !hw.inout<i42>
llhd.drv %1, %v after %0 : !hw.inout<i10>
llhd.halt
}

// end = 42
// CHECK: llhd.process
llhd.process {
// CHECK-NEXT: [[TMP1:%.+]] = comb.extract %u from 0 : (i42) -> i32
// CHECK-NEXT: [[TMP2:%.+]] = comb.concat %v, [[TMP1]] : i10, i32
// CHECK-NEXT: llhd.constant_time
// CHECK-NEXT: llhd.drv %a, [[TMP2]]
// CHECK-NEXT: llhd.halt
%c32_i6 = hw.constant 32 : i6
%1 = llhd.sig.extract %a from %c32_i6 : (!hw.inout<i42>) -> !hw.inout<i10>
llhd.drv %a, %u after %0 : !hw.inout<i42>
llhd.drv %1, %v after %0 : !hw.inout<i10>
llhd.halt
}

// offset < 42, end > 42
// CHECK: llhd.process
llhd.process {
// CHECK-NEXT: [[TMP1:%.+]] = comb.extract %v from 0 : (i10) -> i5
// CHECK-NEXT: [[TMP2:%.+]] = comb.extract %u from 0 : (i42) -> i37
// CHECK-NEXT: [[TMP3:%.+]] = comb.concat [[TMP1]], [[TMP2]] : i5, i37
// CHECK-NEXT: llhd.constant_time
// CHECK-NEXT: llhd.drv %a, [[TMP3]]
// CHECK-NEXT: llhd.halt
%c37_i6 = hw.constant 37 : i6
%1 = llhd.sig.extract %a from %c37_i6 : (!hw.inout<i42>) -> !hw.inout<i10>
llhd.drv %a, %u after %0 : !hw.inout<i42>
llhd.drv %1, %v after %0 : !hw.inout<i10>
llhd.halt
}

// offset >= 42
// CHECK: llhd.process
llhd.process {
// CHECK-NEXT: llhd.constant_time
// CHECK-NEXT: llhd.drv %a, %u
// CHECK-NEXT: llhd.halt
%c42_i6 = hw.constant 42 : i6
%1 = llhd.sig.extract %a from %c42_i6 : (!hw.inout<i42>) -> !hw.inout<i10>
llhd.drv %a, %u after %0 : !hw.inout<i42>
llhd.drv %1, %v after %0 : !hw.inout<i10>
llhd.halt
}
}

func.func private @use_i42(%arg0: i42)
func.func private @use_inout_i42(%arg0: !hw.inout<i42>)
func.func private @use_array_i42(%arg0: !hw.array<4xi42>)