diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
index cf42e37cdf..391bffe077 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
@@ -172,8 +172,9 @@ class LayoutRematerialization {
   void reduceLoopCarriedValues();
   // Existing tuples of (value, layout) that needs to be updated when recreating
   // scf ops. This prevents keeping track of Values that have been delete when
-  // rewriting slices.
-  DenseMap<Value, Attribute> mappedValues;
+  // rewriting slices. A Value may be rematerialized with several encodings, so
+  // each Value maps to every encoding it has an entry for in rematMapping.
+  DenseMap<Value, SmallVector<Attribute>> mappedValues;
   // map of the values remat based on encoding.
   DenseMap<std::pair<Value, Attribute>, Value> rematMapping;
   // DenseMap<std::pair<Operation*, Attribute>, Operation*>
@@ -187,7 +188,12 @@ void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
                                             Value newV) {
   LDBG("addRematValue " << old << " encoding " << encoding << " " << newV);
   rematMapping[{old, encoding}] = newV;
-  mappedValues[old] = encoding;
+  // operator[] default-constructs an empty list on first use. Skip encodings
+  // already recorded so that each (value, encoding) pair is listed only once;
+  // updateRematMapping erases the rematMapping entry per listed encoding.
+  SmallVector<Attribute> &encodings = mappedValues[old];
+  if (!llvm::is_contained(encodings, encoding))
+    encodings.push_back(encoding);
 }
 
 // Remove unneeded values now that we are done with the rematMapping.
@@ -992,22 +998,31 @@ void LayoutRematerialization::updateRematMapping(
   for (auto [old, newV] : values) {
     auto it = mappedValues.find(old);
     if (it != mappedValues.end()) {
-      Attribute encoding = it->second;
-      auto rematIt = rematMapping.find({old, it->second});
-      assert(rematIt != rematMapping.end());
-      Value replacedValue = rematIt->second;
-      rematMapping.erase(rematIt);
-      mappedValues.erase(it);
-      // Loop through the replacement value to find the new version of remat
-      // value. This should be okay as the number of values should be small.
-      for (auto [before, after] : values) {
-        if (before == replacedValue) {
-          replacedValue = after;
-          break;
+      // Take the encodings out of the entry before erasing it; `it` is not
+      // used afterwards.
+      SmallVector<Attribute> encodings = std::move(it->second);
+      mappedValues.erase(it);
+      for (Attribute encoding : encodings) {
+        auto rematIt = rematMapping.find({old, encoding});
+        assert(rematIt != rematMapping.end());
+        Value replacedValue = rematIt->second;
+        rematMapping.erase(rematIt);
+        // Loop through the replacement value to find the new version of remat
+        // value. This should be okay as the number of values should be small.
+        for (auto [before, after] : values) {
+          if (before == replacedValue) {
+            replacedValue = after;
+            break;
+          }
         }
+        rematMapping[{newV, encoding}] = replacedValue;
       }
-      rematMapping[{newV, encoding}] = replacedValue;
-      mappedValues[newV] = encoding;
+      // Merge into the list for newV, skipping encodings it already holds so
+      // that no (value, encoding) pair is listed twice.
+      SmallVector<Attribute> &newEncodings = mappedValues[newV];
+      for (Attribute encoding : encodings)
+        if (!llvm::is_contained(newEncodings, encoding))
+          newEncodings.push_back(encoding);
     }
   }
 }