Skip to content

Commit c4417e6

Browse files
[compiler] Fix bufferization uses of bufferization.materialize_in_destination
In a previous commit, I tried to fix a number of bufferization issues dealing with our use of `bufferization.materialize_in_destination` and `scf.if` ops. However, I discovered a better solution, which I implement in this change. The problem is that `bufferization.materialize_in_destination` is not just a tensor-land copy operation. It is meant to indicate that the target of the copy must be the buffer which will be associated with the `dest` SSA value, and it must be bufferized in-place. Bufferization will raise an error if the bufferization does not occur in-place. This is useful for indicating that e.g. the resulting bufferized IR *must* copy source data into a particular Value with important meaning (e.g. function output argument). However, we were using it in a couple places (namely convert all `tensor.cast` to `tensor.empty` + `bufferization.materialize_in_destination` ops) where the "in place" requirement is not necessary. This was causing bufferization failures in edge cases associated with `scf.if`. To fix this, we just need an alternate copy-like operation that is bufferizable and a DestinationStyleOp. Luckily, there is already `linalg.copy`, which we can use as a drop-in replacement. Then, to recover the original behavior, we just convert the `linalg.copy` to `memref.copy` operations. GitOrigin-RevId: 93ed038f47690d33633db23be5e8f70b3d89d119
1 parent a9800c7 commit c4417e6

File tree

12 files changed

+229
-384
lines changed

12 files changed

+229
-384
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt/Conversion/Passes.td

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,11 @@ def ConvertStablehloToTensorRTPass : Pass<"convert-stablehlo-to-tensorrt"> {
6565
"prefer converting to 'tensorrt.einsum' over 'tensorrt.matrix_multiply'">
6666
];
6767
}
68-
#endif // MLIR_TENSORRT_ENABLE_HLO
6968

7069
//===----------------------------------------------------------------------===//
7170
// ChloToStableHloExt
7271
//===----------------------------------------------------------------------===//
7372

74-
#ifdef MLIR_TENSORRT_ENABLE_HLO
7573
def ConvertChloToStableHloExtPass : Pass<"convert-chlo-to-stablehlo-ext"> {
7674
let summary = "Convert specific CHLO operations to stablehlo";
7775
let description = [{
@@ -89,9 +87,8 @@ def ConvertChloToStableHloExtPass : Pass<"convert-chlo-to-stablehlo-ext"> {
8987
"do not convert chlo.topk ops">,
9088
];
9189
}
92-
#endif // MLIR_TENSORRT_ENABLE_HLO
93-
9490

91+
#endif // MLIR_TENSORRT_ENABLE_HLO
9592

9693
//===----------------------------------------------------------------------===//
9794
// HostToEmitC
@@ -145,6 +142,17 @@ def ConvertTensorRTToEmitCPass : Pass<"convert-tensorrt-to-emitc",
145142
let dependentDialects = ["::mlir::emitc::EmitCDialect"];
146143
}
147144

145+
//===----------------------------------------------------------------------===//
146+
// LowerLinalgCopiesPass
147+
//===----------------------------------------------------------------------===//
148+
149+
def LowerLinalgCopiesPass : Pass<"lower-linalg-copies"> {
150+
let summary = "Lower linalg.copy to memref.copy or other operations";
151+
let description = [{
152+
This pass lowers `linalg.copy` to `memref.copy`.
153+
}];
154+
}
155+
148156
//===----------------------------------------------------------------------===//
149157
// ConvertMemRefToCUDAPass
150158
//===----------------------------------------------------------------------===//

mlir-tensorrt/compiler/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_subdirectory(CUDAToLLVM)
1313
add_subdirectory(HostToEmitC)
1414
add_subdirectory(HostToLLVM)
1515
add_subdirectory(LLVMCommon)
16+
add_subdirectory(LowerLinalgCopies)
1617
add_subdirectory(MemRefToCUDA)
1718
add_subdirectory(PlanToExecutor)
1819
add_subdirectory(PlanToLLVM)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
add_mlir_tensorrt_library(MLIRTensorRTLowerLinalgCopies
2+
LowerLinalgCopies.cpp
3+
4+
LINK_LIBS PUBLIC
5+
MLIRLinalgDialect
6+
MLIRMemRefDialect
7+
MLIRPass
8+
MLIRTransformUtils
9+
)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//===- LowerLinalgCopies.cpp ----------------------------------------------===//
2+
//
3+
// SPDX-FileCopyrightText: Copyright 2025 NVIDIA CORPORATION & AFFILIATES.
4+
// All rights reserved.
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
///
21+
/// Implementation of `lower-linalg-copies` pass.
22+
///
23+
//===----------------------------------------------------------------------===//
24+
#include "mlir-tensorrt/Conversion/Passes.h"
25+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
26+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
27+
#include "mlir/IR/OperationSupport.h"
28+
#include "mlir/Transforms/DialectConversion.h"
29+
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
30+
31+
namespace mlir {
32+
#define GEN_PASS_DEF_LOWERLINALGCOPIESPASS
33+
#include "mlir-tensorrt/Conversion/Passes.h.inc"
34+
} // namespace mlir
35+
36+
using namespace mlir;
37+
38+
namespace {
39+
40+
struct LowerLinalgCopyPattern : public OpRewritePattern<linalg::CopyOp> {
41+
using OpRewritePattern::OpRewritePattern;
42+
LogicalResult matchAndRewrite(linalg::CopyOp op,
43+
PatternRewriter &rewriter) const override {
44+
if (!op.hasPureBufferSemantics())
45+
return rewriter.notifyMatchFailure(op, "expected pure buffer semantics");
46+
rewriter.replaceOpWithNewOp<memref::CopyOp>(op, op.getInputs().front(),
47+
op.getOutputs().front());
48+
return success();
49+
}
50+
};
51+
52+
class LowerLinalgCopiesPass
53+
: public impl::LowerLinalgCopiesPassBase<LowerLinalgCopiesPass> {
54+
using Base::Base;
55+
56+
void runOnOperation() override {
57+
MLIRContext *ctx = &getContext();
58+
RewritePatternSet patterns(ctx);
59+
patterns.insert<LowerLinalgCopyPattern>(ctx);
60+
walkAndApplyPatterns(getOperation(), std::move(patterns));
61+
}
62+
};
63+
64+
} // namespace

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/AllocTensors.cpp

Lines changed: 15 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -63,81 +63,6 @@ using bufferization::OneShotAnalysisState;
6363
using bufferization::func_ext::FuncAnalysisState;
6464
using bufferization::func_ext::FuncOpAnalysisState;
6565

66-
namespace {
67-
68-
/// Simplify a func.return operand produced by
69-
/// `materialize_in_dest(cast(materialize_in_dest(..., %alloc)), %out_arg)` so
70-
/// that only the single `materialize_in_dest` is used directly into the block
71-
/// argument.
72-
struct RemoveRedundantMaterializeInDestPattern
73-
: OpRewritePattern<bufferization::MaterializeInDestinationOp> {
74-
using OpRewritePattern::OpRewritePattern;
75-
LogicalResult matchAndRewrite(bufferization::MaterializeInDestinationOp op,
76-
PatternRewriter &rewriter) const override {
77-
if (!op->hasOneUse() || !isa<func::ReturnOp>(*op->user_begin()))
78-
return failure();
79-
80-
auto dest = dyn_cast<BlockArgument>(op.getDest());
81-
auto castOp = op.getSource().getDefiningOp<tensor::CastOp>();
82-
auto funcOp = op->getParentOfType<func::FuncOp>();
83-
if (!castOp || !dest || !funcOp ||
84-
dest.getOwner() != &funcOp.getBody().front())
85-
return failure();
86-
87-
auto producer =
88-
castOp.getSource()
89-
.getDefiningOp<bufferization::MaterializeInDestinationOp>();
90-
if (!producer || !producer->hasOneUse() ||
91-
!producer.getDest().hasOneUse() ||
92-
!producer.getDest().getDefiningOp<bufferization::AllocTensorOp>())
93-
return failure();
94-
95-
// Replace the returned value with the result of the cast.
96-
Location loc = op->getLoc();
97-
rewriter.replaceOp(op, castOp);
98-
99-
// Create a new cast on the block arg to the type of the producer alloc
100-
// result.
101-
rewriter.setInsertionPoint(producer);
102-
auto blockArgCast = rewriter.create<tensor::CastOp>(
103-
loc, producer.getDest().getType(), dest);
104-
// Update the producer materialization to materialize into the block arg.
105-
rewriter.replaceOp(producer.getDest().getDefiningOp(), blockArgCast);
106-
return success();
107-
}
108-
};
109-
110-
/// Rewrite `tensor.empty` to `bufferization.alloc_tensor` in the `device`
111-
/// memory space.
112-
struct RewriteEmptyTensor : public OpRewritePattern<tensor::EmptyOp> {
113-
using OpRewritePattern::OpRewritePattern;
114-
LogicalResult matchAndRewrite(tensor::EmptyOp op,
115-
PatternRewriter &rewriter) const override {
116-
auto memorySpace =
117-
dyn_cast_or_null<MemorySpaceAttr>(op.getType().getEncoding());
118-
if (!memorySpace)
119-
return failure();
120-
rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
121-
op, op.getType(), op.getDynamicSizes(),
122-
/*copy=*/Value{}, /*size_hint=*/Value{}, memorySpace);
123-
return success();
124-
}
125-
};
126-
127-
/// Drop `bufferization.alloc_tensor` operations that do not have uses.
128-
struct CleanupAllocTensorOps
129-
: public OpRewritePattern<bufferization::AllocTensorOp> {
130-
using OpRewritePattern::OpRewritePattern;
131-
LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
132-
PatternRewriter &rewriter) const override {
133-
if (!op->use_empty())
134-
return failure();
135-
rewriter.eraseOp(op);
136-
return success();
137-
}
138-
};
139-
} // namespace
140-
14166
/// Creates a DPS argument of type `argType` in the first block of `func` by
14267
/// appending to the end of current arguments. It then updates the function
14368
/// type, adds a `executor.result_arg` argument attribute to the new arg, and
@@ -701,7 +626,12 @@ static void uniqueEmptyTensorUses(RewriterBase &rewriter, ModuleLikeOp op) {
701626
return WalkResult::advance();
702627
if (nestedOp->hasOneUse())
703628
return WalkResult::advance();
629+
unsigned firstUse = true;
704630
for (OpOperand &use : llvm::make_early_inc_range(emptyOp->getUses())) {
631+
if (firstUse) {
632+
firstUse = false;
633+
continue;
634+
}
705635
rewriter.setInsertionPoint(use.getOwner());
706636
auto clonedOp = cast<tensor::EmptyOp>(rewriter.clone(*emptyOp));
707637
use.assign(clonedOp);
@@ -748,22 +678,16 @@ class AllocTensorsPass
748678
return signalPassFailure();
749679
}
750680

751-
// Eliminate any straggling `tensor.empty` operations. Only run this on
752-
// functions in the host module.
753-
{
754-
FrozenRewritePatternSet patterns = [&]() {
755-
RewritePatternSet patterns_(ctx);
756-
patterns_.insert<RewriteEmptyTensor, CleanupAllocTensorOps,
757-
RemoveRedundantMaterializeInDestPattern>(ctx);
758-
return patterns_;
759-
}();
760-
for (FunctionOpInterface func : op.getOps<FunctionOpInterface>()) {
761-
if (failed(applyPatternsGreedily(func, patterns))) {
762-
op->emitError() << "failed to run " << getArgument() << " patterns";
763-
return signalPassFailure();
764-
}
765-
}
766-
}
681+
// Remove leftover empty tensors.
682+
op->walk<WalkOrder::PreOrder>([&](Operation *nestedOp) {
683+
if (ModuleLikeOp(nestedOp) && nestedOp != op)
684+
return WalkResult::skip();
685+
auto emptyOp = dyn_cast<tensor::EmptyOp>(nestedOp);
686+
if (!emptyOp || !emptyOp.use_empty())
687+
return WalkResult::advance();
688+
rewriter.eraseOp(emptyOp);
689+
return WalkResult::skip();
690+
});
767691
}
768692
};
769693
} // namespace

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/CMakeLists.txt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ add_mlir_tensorrt_library(MLIRTensorRTPlanTransforms
2828

2929
LINK_LIBS PUBLIC
3030

31+
MLIRBufferizationDialect
3132
MLIRBufferizationPipelines
33+
MLIRBufferizationToMemRef
34+
MLIRBufferizationTransforms
3235
MLIRExecutorGenericClustering
3336
MLIRFuncTransforms
3437
MLIRIR
@@ -37,21 +40,19 @@ add_mlir_tensorrt_library(MLIRTensorRTPlanTransforms
3740
MLIRSCFDialect
3841
MLIRTensorDialect
3942
MLIRTensorRTAnalysis
43+
MLIRTensorRTBufferizationScopeInterface
4044
MLIRTensorRTCUDADialect
4145
MLIRTensorRTDialect
4246
MLIRTensorRTDuplicateFunctionElimination
4347
MLIRTensorRTExecutorDialect
44-
MLIRBufferizationDialect
45-
MLIRBufferizationTransforms
48+
MLIRTensorRTLowerLinalgCopies
4649
MLIRTensorRTMemRefCastElimination
4750
MLIRTensorRTPlanAnalysis
4851
MLIRTensorRTPlanDialect
4952
MLIRTensorRTStableHloExtTransforms
5053
MLIRTensorRTStablehloScalarToArith
5154
MLIRTensorRTStablehloToTensorRT
5255
MLIRTensorRTTensorRTRuntimeDialect
53-
MLIRTensorRTBufferizationScopeInterface
54-
MLIRBufferizationToMemRef
5556
MLIRTransforms
5657
StablehloOps
5758
)

0 commit comments

Comments
 (0)