Skip to content

Commit 21f60d1

Browse files
Matthew Francis-Landaumatthewfl
authored and committed
Address comments on PR for transpose reshape Elimination pass
1 parent cc8d26b commit 21f60d1

File tree

6 files changed

+476
-1354
lines changed

6 files changed

+476
-1354
lines changed

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3790,6 +3790,10 @@ def TensorRT_ReshapeOp : TensorRT_Op<"reshape",
37903790
let extraClassDeclaration = [{
37913791
/// Returns true if created op is valid for TensorRT major version.
37923792
bool isValidForTensorRTVersion(int64_t trtMajorVersion);
3793+
3794+
/// Get canonicalization patterns which rewrite as ReshapeOp and
3795+
/// do NOT include rewrites which get to a different kind of Op
3796+
/// (e.g. ExpandRankOp, CollapseRankOp).
37933797
static void getCanonicalizationPatternsSameOp(RewritePatternSet &results, MLIRContext *context);
37943798
}] # baseClassDeclaration;
37953799
}

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/Transforms/Passes.td

Lines changed: 73 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -175,97 +175,86 @@ def LegalizeInt8Pass : Pass<"tensorrt-legalize-int8", "func::FuncOp"> {
175175
let dependentDialects = ["::mlir::tensorrt::TensorRTDialect"];
176176
}
177177

178-
//===----------------------------------------------------------------------===//
179-
// TransposeEliminationPass
180-
//===----------------------------------------------------------------------===//
181-
def TransposeEliminationPass : Pass<"tensorrt-transpose-elimination"> {
182-
let summary = "try to eliminate tensorrt.transpose operations";
183-
184-
let description = [{
185-
186-
It is well-known that excessive number of transpose ops (either
187-
"tensorrt.transpose" or "tensorrt.shuffle" operations with identity reshape)
188-
can cause performance issues with TensorRT. This commonly occurs when the
189-
input source being converted represents convolutions in "NHWC" format vs.
190-
TensorRT's preferred "NCHW" format. In the conversion of these types of
191-
convolutions, a number of transpose operations must be inserted. These
192-
transpose operations can prevent fusions. For example, a transpose operation
193-
between a convolution and a pointwise addition can prevent convolution-bias
194-
fusion.
195-
196-
This pass tries to eliminate transpose operations by applying the following
197-
patterns in a greedy manner:
198-
199-
1) rotating `tensorrt.transpose` "forward" certain computational operations,
200-
especially `tensorrt.element_wise` ops. This means that the transpose will
201-
be applied to the result of the elementwise operation as well as the other
202-
branch of the operation. To avoid an infinite ping-pong application of this
203-
pattern certain heuristics are applied to determine whether or not this is
204-
beneficial. For example:
205-
206-
```
207-
func.func @transpose_pushdown_switch(%arg0: tensor<2x2xf32>, %arg1: tensor<1x2xf32>)
208-
-> tensor<2x2xf32> {
209-
%1 = tensorrt.transpose {
210-
permutation = affine_map<(d0, d1)->(d1, d0)>
211-
} %arg0 : tensor<2x2xf32> to tensor<2x2xf32>
212-
%2 = tensorrt.element_wise <kSUM> (
213-
%1, %arg1: tensor<2x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
214-
return %2 : tensor<2x2xf32>
215-
}
216-
```
217-
218-
becomes
219-
220-
```
221-
func.func @transpose_pushdown_switch(%arg0: tensor<2x2xf32>,
222-
%arg1: tensor<1x2xf32>) -> tensor<2x2xf32> {
223-
%0 = tensorrt.transpose {permutation = #map}
224-
%arg1 : tensor<1x2xf32> to tensor<2x1xf32>
225-
%1 = tensorrt.element_wise <kSUM>
226-
(%arg0, %0 : tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
227-
%2 = tensorrt.transpose {permutation = #map} %1 : tensor<2x2xf32> to tensor<2x2xf32>
228-
return %2 : tensor<2x2xf32>
229-
}
230-
```
231-
232-
In this case, moving the transpose to the other branch results in lower
233-
memory cost on the inputs, but higher total memory cost (because a transpose
234-
on the result is also added). However, we always prefer to push transpose
235-
operations as far forward as possible in this transformation.
236-
237-
2) Const-folding transpose operations. Often, it is undesirable to let
238-
weights be transposed at runtime. Instead, weights should be pre-transformed
239-
to put them into a form that is suitable for TensorRT convolutions.
240-
Therefore, we apply global transpose-const folding. This can be quite
241-
expensive for large weights but is important to reduce runtime transpose
242-
costs.
243-
244-
}];
245-
}
246-
247-
//===----------------------------------------------------------------------===//
248-
// ReshapeEliminationPass
249-
//===----------------------------------------------------------------------===//
250-
def ReshapeEliminationPass : Pass<"tensorrt-reshape-elimination"> {
251-
let summary = "try to eliminate tensorrt.reshape operations";
252-
253-
let description = [{
254-
Reshape elimination pass captures pattern with un-necessary reshape and
255-
simplifies it by eliminating reshape operations.
256-
}];
257-
}
258-
259178
//===----------------------------------------------------------------------===//
260179
// TransposeReshapeEliminationPass
261180
//===----------------------------------------------------------------------===//
262181
def TransposeReshapeEliminationPass : Pass<"tensorrt-transpose-reshape-elimination"> {
263182
let summary = "try to eliminate tensorrt.transpose, tensorrt.reshape, and tensorrt.shuffle operations";
264183

265184
let description = [{
266-
Push tensorrt.transpose and tensorrt.reshape operations around to attempt to eleminate them
267-
and merge them with other ops such as matrix multiply. The intention is to improve
268-
pattern matching and fusion inside of TensorRT.
185+
It is well-known that an excessive number of transpose or reshape ops (either
186+
"tensorrt.transpose", "tensorrt.reshape" or "tensorrt.shuffle")
187+
can cause performance issues with TensorRT. For example, this commonly occurs
188+
when the input source being converted represents convolutions in "NHWC" format
189+
vs. TensorRT's preferred "NCHW" format. In the conversion of these types of
190+
convolutions, a number of transpose operations must be inserted. These
191+
transpose operations can prevent fusions. For example, a transpose operation
192+
between a convolution and a pointwise addition can prevent convolution-bias
193+
fusion. Fusions can also be blocked if a reshape is placed between a
194+
matrix multiplication and an activation.
195+
196+
The pass tries to eliminate transposes and reshapes by pushing the transposes
197+
and reshapes around to combine them into identity transposes and reshapes which
198+
can be eliminated from the program. To accomplish this, this pass uses several
199+
rewrite rules that push transposes and reshapes from the output of an Operation
200+
to the operation's input, and vice-versa. The rewrites currently included in this
201+
pass handle common cases, though they currently do not cover every possible scenario---one
202+
may wish to extend this pass in the future as needed.
203+
204+
The process is as follows:
205+
1) Normalize transpose, reshape, shuffle and matrix multiply into a common set of ops.
206+
- shuffle(x) -> reshape(transpose(reshape(x)))
207+
- matrix_multiply(x, y) -> einsum("ij,jk->ik", x, y)
208+
- expand_rank(x) -> reshape(x)
209+
- collapse_rank(x) -> reshape(x)
210+
2) Push down reshape and transpose, eliminating when possible. E.g. op(transpose(x)) -> transpose(op(x))
211+
- einsum(transpose(x), ...) -> einsum(x, ...)
212+
- einsum(...) -> transpose(einsum(...)) Pull transposes out of einsum (to try to match matrix multiply pattern)
213+
- einsum(reshape(x), y, ...) -> transpose(reshape(einsum(x, reshape(transpose(y)), ...))
214+
- unary(transpose(x)) -> transpose(unary(x))
215+
- activation(transpose(x)) -> transpose(activation(x))
216+
- identity_op(transpose(x)) -> transpose(identity_op(x))
217+
- activation(reshape(x)) -> reshape(activation(x))
218+
- unary(reshape(x)) -> reshape(unary(x))
219+
- identity_op(reshape(x)) -> reshape(identity_op(x))
220+
- reshape(transpose(x)) -> transpose(reshape(x)) if possible put reshape before transpose
221+
- dequantize(quantize(transpose(x))) -> transpose(dequantize(quantize((x))) if the scale is 0-dim
222+
- dequantize(quantize(reshape(x))) -> reshape(dequantize(quantize(x))) if the scale is 0-dim
223+
- reshape(reshape(x)) -> reshape(x)
224+
- transpose(transpose(x)) -> transpose(x)
225+
- reshape(x) -> x if reshape is identity
226+
- transpose(x) -> x if transpose is identity
227+
- elementwise(reshape(a), b) -> reshape(elementwise(a, reshape(b))) conditioned on heuristic
228+
- elementwise(transpose(a), b) -> transpose(elementwise(a, transpose(b)))
229+
3) Push up reshape and transpose, eliminating when possible. E.g. transpose(op(x)) -> op(transpose(x))
230+
- transpose(einsum(...)) -> einsum(...)
231+
- einsum(...) -> einsum(transpose(x), ...) Pull transposes out of einsum (to try to match matrix multiply pattern)
232+
- reshape(einsum(...)) -> einsum(reshape(transpose(x)), ...)
233+
- transpose(activation(x)) -> activation(transpose(x))
234+
- transpose(unary(x)) -> unary(transpose(x))
235+
- transpose(identity_op(x)) -> identity_op(transpose(x))
236+
- reshape(activation(x)) -> activation(reshape(x))
237+
- reshape(unary(x)) -> unary(reshape(x))
238+
- reshape(identity_op(x)) -> identity_op(reshape(x))
239+
- reshape(reshape(x)) -> reshape(x)
240+
- transpose(transpose(x)) -> transpose(x)
241+
- reshape(x) -> x if reshape is identity
242+
- transpose(x) -> x if transpose is identity
243+
- transpose(reshape(x)) -> reshape(transpose(x)) if possible put transpose before reshape
244+
- transpose(dequantize(quantize(x))) -> dequantize(quantize(transpose(x))) if the scale is 0-dim
245+
- reshape(dequantize(quantize(x))) -> dequantize(quantize(reshape(x))) if the scale is 0-dim
246+
- reshape(elementwise(a, b)) -> elementwise(reshape(a), reshape(b))
247+
- transpose(elementwise(a, b)) -> elementwise(transpose(a), transpose(b))
248+
4) Final clean up. Fuse leftover transposes and reshapes with other ops.
249+
- einsum("ij,jk->ik", x, y) -> matrix_multiply(x, y) if einsum matches a matrix multiply pattern
250+
- matrix_multiply(transpose(x), y) -> matrix_multiply(x, y) merge transpose if possible
251+
- transpose(einsum(...)) -> einsum(...)
252+
- einsum(transpose(x), ...) -> einsum(...)
253+
- einsum(collapse_rank(x), ...) -> einsum(...)
254+
- expand_rank(einsum(...)) -> einsum(...)
255+
256+
To avoid an infinite ping-pong application of these patterns, heuristics are
257+
applied to determine when a pattern is beneficial.
269258
}];
270259
}
271260

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@ add_mtrtd_library(MLIRTensorRTTransforms
1919
Passes.cpp
2020
RaiseActivations.cpp
2121
RaiseNormalizations.cpp
22-
ReshapeElimination.cpp
23-
TransposeElimination.cpp
2422
TransposeReshapeElimination.cpp
2523

2624
DEPENDS

0 commit comments

Comments
 (0)