
Commit a7d0a6b

Matthew Francis-Landau (matthewfl) authored and committed
[mlir-tensorrt] Transpose Reshape Elimination pass.
This is a new pass designed to replace the Transpose Elimination and Reshape Elimination passes. It adds many new rewrite rules that push transposes and reshapes around so that they can be combined and then eliminated. The motivation for this pass is that there are cases where shuffles get inserted around matrix multiplications and element-wise ops, which breaks various fusions inside of TensorRT. To accomplish this, the pass uses several rewrite rules that push transposes and reshapes around to combine them into identity transposes and reshapes, which can then be eliminated from the program. The rewrite rules are as follows:

1. "Canonicalize" the network into simpler ops:
   - `shuffle(x)` -> `reshape(transpose(reshape(x)))`
   - `matrix_multiply(x, y)` -> `einsum("ij,jk->ik", x, y)`
   - `expand_rank(x)` -> `reshape(x)`
   - `collapse_rank(x)` -> `reshape(x)`
2. Push down `reshape` and `transpose` ops as much as possible, merging and eliminating them when possible:
   - `einsum(transpose(x), ...)` -> `einsum(x, ...)` merge transpose into einsum
   - `einsum(...)` -> `transpose(einsum(...))` pull transpose out of einsum (to try to match the matrix multiply pattern)
   - `einsum(reshape(x), y, ...)` -> `transpose(reshape(einsum(x, reshape(transpose(y)), ...)))` push reshape down, possibly adding reshapes and transposes to other inputs as needed; conditioned on a heuristic checking whether the result is "better"
   - `unary(transpose(x))` -> `transpose(unary(x))`
   - `activation(transpose(x))` -> `transpose(activation(x))`
   - `identity_op(transpose(x))` -> `transpose(identity_op(x))`
   - `activation(reshape(x))` -> `reshape(activation(x))`
   - `unary(reshape(x))` -> `reshape(unary(x))`
   - `identity_op(reshape(x))` -> `reshape(identity_op(x))`
   - `reshape(transpose(x))` -> `transpose(reshape(x))` if possible, put the reshape before the transpose
   - `qdq(transpose(x))` -> `transpose(qdq(x))` if the scale is 0-dim
   - `qdq(reshape(x))` -> `reshape(qdq(x))` if the scale is 0-dim
   - `reshape(reshape(x))` -> `reshape(x)`
   - `transpose(transpose(x))` -> `transpose(x)`
   - `reshape(x)` -> `x` if the `reshape` is an identity
   - `transpose(x)` -> `x` if the `transpose` is an identity
   - `elementwise(reshape(a), b)` -> `reshape(elementwise(a, reshape(b)))` conditioned on a heuristic
   - `elementwise(transpose(a), b)` -> `transpose(elementwise(a, transpose(b)))`
   - `softmax(transpose(x))` -> `transpose(softmax(x))`
   - `softmax(reshape(x))` -> `reshape(softmax(x))`
3. Push up `reshape` and `transpose` ops as much as possible, merging and eliminating them when possible:
   - `transpose(einsum(...))` -> `einsum(...)` merge transpose into einsum
   - `einsum(...)` -> `einsum(transpose(x), ...)` pull transposes out of einsum (to try to match the matrix multiply pattern)
   - `reshape(einsum(...))` -> `einsum(reshape(transpose(x)), ...)` push reshapes up through einsum, adding transposes as needed
   - `transpose(activation(x))` -> `activation(transpose(x))`
   - `transpose(unary(x))` -> `unary(transpose(x))`
   - `transpose(identity_op(x))` -> `identity_op(transpose(x))`
   - `reshape(activation(x))` -> `activation(reshape(x))`
   - `reshape(unary(x))` -> `unary(reshape(x))`
   - `reshape(identity_op(x))` -> `identity_op(reshape(x))`
   - `reshape(reshape(x))` -> `reshape(x)`
   - `transpose(transpose(x))` -> `transpose(x)`
   - `reshape(x)` -> `x` if the `reshape` is an identity
   - `transpose(x)` -> `x` if the `transpose` is an identity
   - `transpose(reshape(x))` -> `reshape(transpose(x))` if possible, put the transpose before the reshape
   - `transpose(qdq(x))` -> `qdq(transpose(x))` if the scale is 0-dim
   - `reshape(qdq(x))` -> `qdq(reshape(x))` if the scale is 0-dim
   - `reshape(elementwise(a, b))` -> `elementwise(reshape(a), reshape(b))`
   - `transpose(elementwise(a, b))` -> `elementwise(transpose(a), transpose(b))`
   - `transpose(softmax(x))` -> `softmax(transpose(x))`
   - `reshape(softmax(x))` -> `softmax(reshape(x))`
4. Convert back to matrix-multiplication form to assist TensorRT's pattern matching:
   - `einsum(x, y)` -> `matrix_multiply(x, y)` if the einsum matches a matrix multiply pattern
   - `matrix_multiply(transpose(x), y)` -> `matrix_multiply(x, y)` merge the transpose if possible
5. Final clean-ups, additional merging of transposes/reshapes into leftover einsums:
   - `einsum(x, y)` -> `matrix_multiply(x, y)` if the einsum matches a matrix multiply pattern
   - `matrix_multiply(transpose(x), y)` -> `matrix_multiply(x, y)` merge the transpose if possible
   - `transpose(einsum(...))` -> `einsum(...)`
   - `einsum(transpose(x), ...)` -> `einsum(...)`
   - `einsum(collapse_rank(x), ...)` -> `einsum(...)`
   - `expand_rank(einsum(...))` -> `einsum(...)`
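The workhorse simplifications in steps 2 and 3 — fusing adjacent transposes and dropping identity ones — come down to permutation arithmetic. A minimal sketch of that arithmetic (illustrative only; the actual pass rewrites MLIR ops, and these helper names are hypothetical, not from the codebase):

```python
# Model a transpose op as a permutation of axis indices (numpy convention:
# result axis i is input axis perm[i]).

def compose(outer, inner):
    """Permutation equivalent to transpose(transpose(x, inner), outer)."""
    return [inner[i] for i in outer]

def is_identity(perm):
    """An identity transpose can simply be removed from the program."""
    return all(i == p for i, p in enumerate(perm))

# transpose(transpose(x)) -> transpose(x): two transposes always fuse into one.
fused = compose([1, 0], [1, 0])
assert fused == [0, 1] and is_identity(fused)  # and here the result is eliminated

# A non-inverse pair still fuses into a single (non-identity) transpose.
assert compose([1, 2, 0], [1, 2, 0]) == [2, 0, 1]
```

Pushing a transpose through an elementwise or unary op leaves the permutation unchanged, which is why the pass can migrate transposes until two of them meet and compose away.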
1 parent b344e42 commit a7d0a6b

File tree

12 files changed: +3611 -830 lines


mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td

Lines changed: 5 additions & 0 deletions
@@ -3875,6 +3875,11 @@ def TensorRT_ReshapeOp : TensorRT_Op<"reshape",
   let extraClassDeclaration = [{
     /// Returns true if created op is valid for TensorRT major version.
     bool isValidForTensorRTVersion(int64_t trtMajorVersion);
+
+    /// Get canonicalization patterns which rewrite as ReshapeOp and
+    /// do NOT include rewrites which convert to a different kind of Op
+    /// (e.g. ExpandRankOp, CollapseRankOp).
+    static void getCanonicalizationPatternsSameOp(RewritePatternSet &results, MLIRContext *context);
   }] # baseClassDeclaration;
 }

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/Transforms/Passes.td

Lines changed: 73 additions & 70 deletions
@@ -176,84 +176,87 @@ def LegalizeInt8Pass : Pass<"tensorrt-legalize-int8", "func::FuncOp"> {
 }
 
 //===----------------------------------------------------------------------===//
-// TransposeEliminationPass
+// TransposeReshapeEliminationPass
 //===----------------------------------------------------------------------===//
-def TransposeEliminationPass : Pass<"tensorrt-transpose-elimination"> {
-  let summary = "try to eliminate tensorrt.transpose operations";
+def TransposeReshapeEliminationPass : Pass<"tensorrt-transpose-reshape-elimination"> {
+  let summary = "try to eliminate tensorrt.transpose, tensorrt.reshape, and tensorrt.shuffle operations";
 
   let description = [{
-
-    It is well-known that excessive number of transpose ops (either
-    "tensorrt.transpose" or "tensorrt.shuffle" operations with identity reshape)
-    can cause performance issues with TensorRT. This commonly occurs when the
-    input source being converted represents convolutions in "NHWC" format vs.
-    TensorRT's preferred "NCHW" format. In the conversion of these types of
+    It is well-known that an excessive number of transpose or reshape ops (either
+    "tensorrt.transpose", "tensorrt.reshape", or "tensorrt.shuffle")
+    can cause performance issues with TensorRT. For example, this commonly occurs
+    when the input source being converted represents convolutions in "NHWC" format
+    vs. TensorRT's preferred "NCHW" format. In the conversion of these types of
     convolutions, a number of transpose operations must be inserted. These
     transpose operations can prevent fusions. For example, a transpose operation
     between a convolution and a pointwise addition can prevent convolution-bias
-    fusion.
-
-    This pass tries to eliminate transpose operations by applying the following
-    patterns in a greedy manner:
-
-    1) rotating `tensorrt.transpose` "forward" certain computational operations,
-    especially `tensorrt.element_wise` ops. This means that the transpose will
-    be applied to the result of the elementwise operation as well as the other
-    branch of the operation. To avoid an infinite ping-pong application of this
-    pattern certain heuristics are applied to determine whether or not this is
-    beneficial. For example:
-
-    ```
-    func.func @transpose_pushdown_switch(%arg0: tensor<2x2xf32>, %arg1: tensor<1x2xf32>)
-        -> tensor<2x2xf32> {
-      %1 = tensorrt.transpose {
-        permutation = affine_map<(d0, d1)->(d1, d0)>
-      } %arg0 : tensor<2x2xf32> to tensor<2x2xf32>
-      %2 = tensorrt.element_wise <kSUM> (
-        %1, %arg1: tensor<2x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
-      return %2 : tensor<2x2xf32>
-    }
-    ```
-
-    becomes
-
-    ```
-    func.func @transpose_pushdown_switch(%arg0: tensor<2x2xf32>,
-        %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> {
-      %0 = tensorrt.transpose {permutation = #map}
-        %arg1 : tensor<1x2xf32> to tensor<2x1xf32>
-      %1 = tensorrt.element_wise <kSUM>
-        (%arg0, %0 : tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
-      %2 = tensorrt.transpose {permutation = #map} %1 : tensor<2x2xf32> to tensor<2x2xf32>
-      return %2 : tensor<2x2xf32>
-    }
-    ```
-
-    In this case, moving the transpose to the other branch results in lower
-    memory cost on the inputs, but higher total memory cost (because a transpose
-    on the result is also added). However, we always prefer to push transpose
-    operations as far forward as possible in this transformation.
-
-    2) Const-folding transpose operations. Often, it is undesirable to let
-    weights be transposed at runtime. Instead, weights should be pre-transformed
-    to put them into a form that is suitable for TensorRT convolutions.
-    Therefore, we apply global transpose-const folding. This can be quite
-    expensive for large weights but is important to reduce runtime transpose
-    costs.
-
+    fusion. Fusions can also be blocked if a reshape is placed between a
+    matrix multiplication and an activation.
+
+    The pass tries to eliminate transposes and reshapes by pushing the transposes
+    and reshapes around to combine them into identity transposes and reshapes which
+    can be eliminated from the program. To accomplish this, this pass uses several
+    rewrite rules that push transposes and reshapes from the output of an operation
+    to the operation's input, and vice-versa. The rewrites currently included in this
+    pass handle common cases, though they do not yet cover every possible scenario---one
+    may wish to extend this pass in the future as needed.
+
+    The process is as follows:
+    1) Normalize transpose, reshape, shuffle, and matrix multiply into a common set of ops.
+      - shuffle(x) -> reshape(transpose(reshape(x)))
+      - matrix_multiply(x, y) -> einsum("ij,jk->ik", x, y)
+      - expand_rank(x) -> reshape(x)
+      - collapse_rank(x) -> reshape(x)
+    2) Push down reshape and transpose, eliminating when possible. E.g. op(transpose(x)) -> transpose(op(x))
+      - einsum(transpose(x), ...) -> einsum(x, ...)
+      - einsum(...) -> transpose(einsum(...)) pull transposes out of einsum (to try to match the matrix multiply pattern)
+      - einsum(reshape(x), y, ...) -> transpose(reshape(einsum(x, reshape(transpose(y)), ...)))
+      - unary(transpose(x)) -> transpose(unary(x))
+      - activation(transpose(x)) -> transpose(activation(x))
+      - identity_op(transpose(x)) -> transpose(identity_op(x))
+      - activation(reshape(x)) -> reshape(activation(x))
+      - unary(reshape(x)) -> reshape(unary(x))
+      - identity_op(reshape(x)) -> reshape(identity_op(x))
+      - reshape(transpose(x)) -> transpose(reshape(x)) if possible, put the reshape before the transpose
+      - dequantize(quantize(transpose(x))) -> transpose(dequantize(quantize(x))) if the scale is 0-dim
+      - dequantize(quantize(reshape(x))) -> reshape(dequantize(quantize(x))) if the scale is 0-dim
+      - reshape(reshape(x)) -> reshape(x)
+      - transpose(transpose(x)) -> transpose(x)
+      - reshape(x) -> x if reshape is identity
+      - transpose(x) -> x if transpose is identity
+      - elementwise(reshape(a), b) -> reshape(elementwise(a, reshape(b))) conditioned on heuristic
+      - elementwise(transpose(a), b) -> transpose(elementwise(a, transpose(b)))
+    3) Push up reshape and transpose, eliminating when possible. E.g. transpose(op(x)) -> op(transpose(x))
+      - transpose(einsum(...)) -> einsum(...)
+      - einsum(...) -> einsum(transpose(x), ...) pull transposes out of einsum (to try to match the matrix multiply pattern)
+      - reshape(einsum(...)) -> einsum(reshape(transpose(x)), ...)
+      - transpose(activation(x)) -> activation(transpose(x))
+      - transpose(unary(x)) -> unary(transpose(x))
+      - transpose(identity_op(x)) -> identity_op(transpose(x))
+      - reshape(activation(x)) -> activation(reshape(x))
+      - reshape(unary(x)) -> unary(reshape(x))
+      - reshape(identity_op(x)) -> identity_op(reshape(x))
+      - reshape(reshape(x)) -> reshape(x)
+      - transpose(transpose(x)) -> transpose(x)
+      - reshape(x) -> x if reshape is identity
+      - transpose(x) -> x if transpose is identity
+      - transpose(reshape(x)) -> reshape(transpose(x)) if possible, put the transpose before the reshape
+      - transpose(dequantize(quantize(x))) -> dequantize(quantize(transpose(x))) if the scale is 0-dim
+      - reshape(dequantize(quantize(x))) -> dequantize(quantize(reshape(x))) if the scale is 0-dim
+      - reshape(elementwise(a, b)) -> elementwise(reshape(a), reshape(b))
+      - transpose(elementwise(a, b)) -> elementwise(transpose(a), transpose(b))
+    4) Final clean up. Fuse leftover transposes and reshapes with other ops.
+      - einsum("ij,jk->ik", x, y) -> matrix_multiply(x, y) if einsum matches a matrix multiply pattern
+      - matrix_multiply(transpose(x), y) -> matrix_multiply(x, y) merge transpose if possible
+      - transpose(einsum(...)) -> einsum(...)
+      - einsum(transpose(x), ...) -> einsum(...)
+      - einsum(collapse_rank(x), ...) -> einsum(...)
+      - expand_rank(einsum(...)) -> einsum(...)
+
+    To avoid an infinite ping-pong application of these patterns, heuristics are
+    applied to determine when a pattern is beneficial.
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// ReshapeEliminationPass
-//===----------------------------------------------------------------------===//
-def ReshapeEliminationPass : Pass<"tensorrt-reshape-elimination"> {
-  let summary = "try to eliminate tensorrt.reshape operations";
-
-  let description = [{
-    Reshape elimination pass captures pattern with un-necessary reshape and
-    simplifies it by eliminating reshape operations.
-  }];
-}
 
 #endif // MLIR_TENSORRT_DIALECT_TENSORRT_TRANSFORMS_PASSES
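Steps 4 and 5 above convert an einsum back to `matrix_multiply` when its equation matches the plain matmul pattern. The structural check can be sketched in a few lines (a hedged illustration: the pass does this matching on `tensorrt.einsum` ops and also handles transposed and batched variants, which this toy function does not):

```python
# Sketch: does a 2-input einsum equation match the matrix-multiply pattern
# "ij,jk->ik", up to renaming of the labels?

def matches_matmul(equation):
    inputs, _, output = equation.partition("->")
    parts = inputs.split(",")
    if len(parts) != 2:
        return False
    a, b = parts
    if len(a) != 2 or len(b) != 2 or len(output) != 2:
        return False
    i, j = a          # rows, contracted dim of first operand
    j2, k = b         # contracted dim, columns of second operand
    # Contracted label must be shared and distinct from the free labels,
    # and the output must be exactly the two free labels in order.
    return j == j2 and len({i, j, k}) == 3 and output == i + k

assert matches_matmul("ij,jk->ik")
assert matches_matmul("ab,bc->ac")      # same pattern with renamed labels
assert not matches_matmul("ij,kj->ik")  # second operand is transposed
assert not matches_matmul("ij,jk->ki")  # output is transposed
```

Equations that fail this check are the "leftover einsums" that step 5 instead tries to shrink by absorbing adjacent transposes and rank changes.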

mlir-tensorrt/tensorrt/lib/TensorRT/IR/EinsumHelper.cpp

Lines changed: 6 additions & 5 deletions
@@ -143,10 +143,11 @@ static LogicalResult validateInputsSubscript(const IOSubscripts &subscripts,
     // match. for example, ('ij,jk->ik', a, b) is valid for a =
     // tensor<4x5xf32>, b = tensor<5x6xf32> but invalid for a =
     // tensor<4x6xf32>, b = tensor<5x6xf32>
-    if (allLabelDims.count(label) == 0) {
-      allLabelDims.insert(std::pair<char, int64_t>(label, dimension));
+    // Einsum also supports broadcasting
+    if (allLabelDims.count(label) == 0 || allLabelDims[label] == 1) {
+      allLabelDims[label] = dimension;
     } else {
-      if (allLabelDims[label] != dimension)
+      if (allLabelDims[label] != dimension && dimension != 1)
         return emitErrorFn(loc, Twine("label `") + Twine(label) +
                                     Twine("` is repeated between inputs but "
                                           "dimensions are not same"));
@@ -203,8 +204,8 @@ static LogicalResult inferOutputShapeImpl(const IOSubscripts &ioSubscripts,
        llvm::zip((ioSubscripts).inputs, inputOperands)) {
     for (const auto &[label, dims] :
          llvm::zip(subscript, cast<RankedTensorType>(operand).getShape()))
-      if (inputLabelsDims.count(label) == 0)
-        inputLabelsDims.insert(std::pair<char, int64_t>(label, dims));
+      if (inputLabelsDims.count(label) == 0 || inputLabelsDims[label] == 1)
+        inputLabelsDims[label] = dims;
   }
 
   for (const auto &label : (ioSubscripts).outputs) {
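The EinsumHelper change above relaxes the repeated-label check so a size-1 dimension broadcasts against the other operand's size rather than being rejected. The same bookkeeping can be sketched in Python (illustrative only; the committed C++ above is the real logic, and this helper name is hypothetical):

```python
# Sketch of the label-dimension merging after the broadcasting change:
# a size-1 dim for a repeated label is allowed and is overridden by the
# concrete size from the other operand (mirrors allLabelDims in the C++).

def merge_label_dims(subscripts, shapes):
    """subscripts: e.g. ["ij", "jk"]; shapes: matching per-operand dim lists.
    Returns {label: dim}, or raises ValueError on a true mismatch."""
    dims = {}
    for labels, shape in zip(subscripts, shapes):
        for label, d in zip(labels, shape):
            if label not in dims or dims[label] == 1:
                dims[label] = d  # first occurrence, or broadcast 1 -> d
            elif dims[label] != d and d != 1:
                raise ValueError(
                    f"label `{label}` is repeated between inputs but "
                    "dimensions are not same")
    return dims

# 'j' appears as 1 in the first operand and 5 in the second: broadcasts to 5.
assert merge_label_dims(["ij", "jk"], [[4, 1], [5, 6]]) == {"i": 4, "j": 5, "k": 6}
```

A genuine conflict (e.g. `j` as 4 in one operand and 7 in another, neither being 1) still produces the "dimensions are not same" error, matching the pre-existing behavior.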

mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp

Lines changed: 7 additions & 0 deletions
@@ -1850,6 +1850,13 @@ void tensorrt::ReshapeOp::getCanonicalizationPatterns(
                  SimplifyReshapeToRankExpandCollapse>(context);
 }
 
+void tensorrt::ReshapeOp::getCanonicalizationPatternsSameOp(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.insert<ConstFoldReshapePattern<ReshapeOp>, SimplifyReshapeReshape
+                 // NOT INCLUDED: SimplifyReshapeToRankExpandCollapse
+                 >(context);
+}
+
 void tensorrt::ReshapeOp::build(OpBuilder &builder, OperationState &state,
                                 Type result, Value input) {
   ReshapeOp::build(builder, state, result, input, Value());

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
@@ -19,8 +19,7 @@ add_mtrtd_library(MLIRTensorRTTransforms
   Passes.cpp
   RaiseActivations.cpp
   RaiseNormalizations.cpp
-  ReshapeElimination.cpp
-  TransposeElimination.cpp
+  TransposeReshapeElimination.cpp
 
   DEPENDS
   MLIRTensorRTTransformsActivationsPdllGen

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/Passes.cpp

Lines changed: 1 addition & 3 deletions
@@ -86,9 +86,7 @@ void tensorrt::buildTensorRTModuleSimplificationPipeline(OpPassManager &pm) {
   // Try to eliminate as many `tensorrt.broadcast` ops as possible.
   pm.addPass(tensorrt::createBroadcastEliminationPass());
   addCleanupPasses(pm);
-  pm.addPass(tensorrt::createTransposeEliminationPass());
-  addCleanupPasses(pm);
-  pm.addPass(tensorrt::createReshapeEliminationPass());
+  pm.addPass(tensorrt::createTransposeReshapeEliminationPass());
   addCleanupPasses(pm);
   pm.addPass(tensorrt::createRaiseNormalizationsPass());
   addCleanupPasses(pm);
