Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ LogicalResult getConvertBackwardSlice(
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation = nullptr,
std::function<Value(OpOperand &, Attribute)> getExistingConversion =
nullptr);
nullptr,
bool includeForOp = false);

LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name,
ArrayRef<Type> paramTypes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,20 +158,23 @@ class LayoutRematerialization {
getConvertBackwardSlice(OpOperand &root, Attribute rootEncoding,
SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation);
std::function<bool(Operation *)> stopPropagation,
bool includeForOp = false);

LogicalResult getRematerializableSlice(
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation = nullptr);
std::function<bool(Operation *)> stopPropagation = nullptr,
bool includeForOp = false);

private:
void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
void reduceLoopCarriedValues();
// Existing tuples of (value, layout) that need to be updated when recreating
// scf ops. This prevents keeping track of Values that have been deleted when
// rewriting slices.
DenseMap<Value, Attribute> mappedValues;
// rewriting slices. The Value may be mapped to different attributes in
// remove layout.
DenseMap<Value, SmallVector<Attribute>> mappedValues;
// Map from (value, encoding) to the rematerialized value for that encoding.
DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
// DenseMap<std::pair<Operation*, Attribute>, Operation*>
Expand All @@ -185,7 +188,11 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
Value newV) {
LDBG("addRematValue " << old << " encoding " << encoding << " " << newV);
rematMapping[{old, encoding}] = newV;
mappedValues[old] = encoding;
if (mappedValues.contains(old)) {
mappedValues[old].push_back(encoding);
} else {
mappedValues[old] = {encoding};
}
}

// Remove unneeded values now that we are done with the rematMapping.
Expand Down Expand Up @@ -990,22 +997,28 @@ void LayoutRematerialization::updateRematMapping(
for (auto [old, newV] : values) {
auto it = mappedValues.find(old);
if (it != mappedValues.end()) {
Attribute encoding = it->second;
auto rematIt = rematMapping.find({old, it->second});
assert(rematIt != rematMapping.end());
Value replacedValue = rematIt->second;
rematMapping.erase(rematIt);
mappedValues.erase(it);
// Loop through the replacement values to find the new version of the
// rematerialized value. This should be okay as the number of values should
// be small.
for (auto [before, after] : values) {
if (before == replacedValue) {
replacedValue = after;
break;
SmallVector<Attribute> encodings = it->second;
for (auto encoding : encodings) {
auto rematIt = rematMapping.find({old, encoding});
assert(rematIt != rematMapping.end());
Value replacedValue = rematIt->second;
rematMapping.erase(rematIt);
// Loop through the replacement values to find the new version of the
// rematerialized value. This should be okay as the number of values should
// be small.
for (auto [before, after] : values) {
if (before == replacedValue) {
replacedValue = after;
break;
}
}
rematMapping[{newV, encoding}] = replacedValue;
}
mappedValues.erase(it);
if (mappedValues.contains(newV)) {
mappedValues[newV].append(encodings);
} else {
mappedValues[newV] = std::move(encodings);
}
rematMapping[{newV, encoding}] = replacedValue;
mappedValues[newV] = encoding;
}
}
}
Expand Down Expand Up @@ -1167,6 +1180,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
deadOps.push_back(forOp.getOperation());
Block &loopBody = *newForOp.getBody();
for (auto m : argMapping) {
mapping.map(newForOp.getResult(m.first), newForOp.getResult(m.second));
mapping.map(forOp.getResult(m.first), newForOp.getResult(m.second));
int numIndVars = newForOp.getNumInductionVars();
mapping.map(loopBody.getArgument(m.first + numIndVars),
Expand Down Expand Up @@ -1277,8 +1291,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
builder.replaceAllUsesWith(std::get<0>(kv), std::get<1>(kv));
}

for (Operation *op : deadOps)
opToDelete.insert(op);
for (Operation *op : deadOps) {
if (!isa<scf::ForOp>(op))
opToDelete.insert(op);
else
op->erase();
}
}

void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
Expand All @@ -1291,7 +1309,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
LogicalResult LayoutRematerialization::getConvertBackwardSlice(
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation) {
std::function<bool(Operation *)> stopPropagation, bool includeForOp) {
// Allow re-using existing conversions for a value. Check dominance of any
// reusable materializations against the root value. This is sufficient
// because the conversions are processed in post-order.
Expand Down Expand Up @@ -1320,15 +1338,16 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
};

return ttgi::getConvertBackwardSlice(root, slice, rootEncoding, layout,
stopPropagation, getExistingConversion);
stopPropagation, getExistingConversion,
includeForOp);
}

LogicalResult LayoutRematerialization::getRematerializableSlice(
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation) {
LogicalResult result = getConvertBackwardSlice(root, rootEncoding, slice,
layout, stopPropagation);
std::function<bool(Operation *)> stopPropagation, bool includeForOp) {
LogicalResult result = getConvertBackwardSlice(
root, rootEncoding, slice, layout, stopPropagation, includeForOp);
if (result.failed() || slice.empty())
return failure();

Expand Down Expand Up @@ -1453,8 +1472,9 @@ void LayoutRematerialization::backwardRematerialization(
// rematerialized.
SetVector<Value> slice;
DenseMap<Value, Attribute> layout;
LogicalResult result = getRematerializableSlice(
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout);
LogicalResult result = getRematerializableSlice(convertOp.getSrcMutable(),
targetType.getEncoding(),
slice, layout, nullptr, true);
if (result.failed()) {
LDBG(" getRematerializableSlice failed");
return;
Expand Down
24 changes: 19 additions & 5 deletions third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ LogicalResult getConvertBackwardSlice(
OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation,
std::function<Value(OpOperand &, Attribute)> getExistingConversion) {
std::function<Value(OpOperand &, Attribute)> getExistingConversion,
bool includeForOp) {
DenseSet<std::pair<OpOperand *, Attribute>> seen;
SmallVector<std::pair<OpOperand *, Attribute>> queue;

Expand Down Expand Up @@ -216,10 +217,7 @@ LogicalResult getConvertBackwardSlice(
queue.pop_back();
if (!isTensorOrTensorPointerType(currentValue.getType()))
continue;
// Skip propagating through for op results for now.
// TODO: enable this based on needs.
if (currentValue.getDefiningOp<scf::ForOp>())
return failure();

if (failed(updateLayout(currentValue, encoding)))
return failure();

Expand All @@ -231,6 +229,22 @@ LogicalResult getConvertBackwardSlice(
currentValue = existing;
}

if (auto forOp = currentValue.getDefiningOp<scf::ForOp>()) {
if (!includeForOp)
return failure();
if (stopPropagation && stopPropagation(forOp))
continue;
unsigned argIdx = mlir::cast<OpResult>(currentValue).getResultNumber();
int numIndVars = forOp.getNumInductionVars();
Block &loopBody = *forOp.getBody();
auto blockArg = loopBody.getArgument(argIdx + numIndVars);
OpOperand *initOperand = forOp.getTiedLoopInit(blockArg);
OpOperand &yieldOperand = loopBody.getTerminator()->getOpOperand(argIdx);
enqueue(*initOperand, encoding);
enqueue(yieldOperand, encoding);
continue;
}

if (auto ifOp = currentValue.getDefiningOp<scf::IfOp>()) {
if (stopPropagation && stopPropagation(ifOp))
continue;
Expand Down