Skip to content
Open
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,9 @@ class LayoutRematerialization {
void reduceLoopCarriedValues();
// Existing tuples of (value, layout) that needs to be updated when recreating
// scf ops. This prevents keeping track of Values that have been delete when
// rewriting slices.
DenseMap<Value, Attribute> mappedValues;
// rewriting slices. The Value maybe mapped to different attributes in remove
// layout.
DenseMap<Value, SmallVector<Attribute>> mappedValues;
// map of the values remat based on encoding.
DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
// DenseMap<std::pair<Operation*, Attribute>, Operation*>
Expand All @@ -185,7 +186,10 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
Value newV) {
LDBG("addRematValue " << old << " encoding " << encoding << " " << newV);
rematMapping[{old, encoding}] = newV;
mappedValues[old] = encoding;
if (mappedValues.contains(old))
mappedValues[old].push_back(encoding);
else
mappedValues[old] = {encoding};
}

// Remove unneeded values now that we are done with the rematMapping.
Expand Down Expand Up @@ -990,22 +994,27 @@ void LayoutRematerialization::updateRematMapping(
for (auto [old, newV] : values) {
auto it = mappedValues.find(old);
if (it != mappedValues.end()) {
Attribute encoding = it->second;
auto rematIt = rematMapping.find({old, it->second});
assert(rematIt != rematMapping.end());
Value replacedValue = rematIt->second;
rematMapping.erase(rematIt);
mappedValues.erase(it);
// Loop through the replacement value to find the new version of remat
// value. This should be okay as the number of values should be small.
for (auto [before, after] : values) {
if (before == replacedValue) {
replacedValue = after;
break;
SmallVector<Attribute> encodings = it->second;
for (Attribute encoding : encodings) {
auto rematIt = rematMapping.find({old, encoding});
assert(rematIt != rematMapping.end());
Value replacedValue = rematIt->second;
rematMapping.erase(rematIt);
// Loop through the replacement value to find the new version of remat
// value. This should be okay as the number of values should be small.
for (auto [before, after] : values) {
if (before == replacedValue) {
replacedValue = after;
break;
}
}
rematMapping[{newV, encoding}] = replacedValue;
}
rematMapping[{newV, encoding}] = replacedValue;
mappedValues[newV] = encoding;
mappedValues.erase(it);
if (mappedValues.contains(newV))
mappedValues[newV].append(encodings);
else
mappedValues[newV] = std::move(encodings);
}
}
}
Expand Down Expand Up @@ -1199,6 +1208,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
deadOps.push_back(forOp.getOperation());
Block &loopBody = *newForOp.getBody();
for (auto m : argMapping) {
mapping.map(newForOp.getResult(m.first), newForOp.getResult(m.second));
mapping.map(forOp.getResult(m.first), newForOp.getResult(m.second));
int numIndVars = newForOp.getNumInductionVars();
mapping.map(loopBody.getArgument(m.first + numIndVars),
Expand Down Expand Up @@ -1309,8 +1319,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv));
}

for (Operation *op : deadOps)
opToDelete.insert(op);
for (Operation *op : deadOps) {
if (!isa<scf::ForOp>(op))
opToDelete.insert(op);
else
op->erase();
}
}

void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
Expand Down Expand Up @@ -1485,8 +1499,9 @@ void LayoutRematerialization::backwardRematerialization(
// rematerialized.
SetVector<Value> slice;
DenseMap<Value, Attribute> layout;
LogicalResult result = getRematerializableSlice(
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout);
LogicalResult result = getRematerializableSlice(convertOp.getSrcMutable(),
targetType.getEncoding(),
slice, layout, nullptr);
Copy link

Copilot AI Sep 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Adding a nullptr parameter without documentation or context makes the API unclear. Consider adding a comment explaining what this parameter represents or using a named constant instead of nullptr.

Suggested change
LogicalResult result = getRematerializableSlice(convertOp.getSrcMutable(),
targetType.getEncoding(),
slice, layout, nullptr);
// No filter function is provided, so we pass kNoRematerializationFilter (nullptr).
static constexpr auto kNoRematerializationFilter = nullptr;
LogicalResult result = getRematerializableSlice(convertOp.getSrcMutable(),
targetType.getEncoding(),
slice, layout, kNoRematerializationFilter);

Copilot uses AI. Check for mistakes.
if (result.failed()) {
LDBG(" getRematerializableSlice failed");
return;
Expand Down
Loading