@@ -175,97 +175,86 @@ def LegalizeInt8Pass : Pass<"tensorrt-legalize-int8", "func::FuncOp"> {
175175 let dependentDialects = ["::mlir::tensorrt::TensorRTDialect"];
176176}
177177
178- //===----------------------------------------------------------------------===//
179- // TransposeEliminationPass
180- //===----------------------------------------------------------------------===//
181- def TransposeEliminationPass : Pass<"tensorrt-transpose-elimination"> {
182- let summary = "try to eliminate tensorrt.transpose operations";
183-
184- let description = [{
185-
186- It is well-known that excessive number of transpose ops (either
187- "tensorrt.transpose" or "tensorrt.shuffle" operations with identity reshape)
188- can cause performance issues with TensorRT. This commonly occurs when the
189- input source being converted represents convolutions in "NHWC" format vs.
190- TensorRT's preferred "NCHW" format. In the conversion of these types of
191- convolutions, a number of transpose operations must be inserted. These
192- transpose operations can prevent fusions. For example, a transpose operation
193- between a convolution and a pointwise addition can prevent convolution-bias
194- fusion.
195-
196- This pass tries to eliminate transpose operations by applying the following
197- patterns in a greedy manner:
198-
199- 1) rotating `tensorrt.transpose` "forward" certain computational operations,
200- especially `tensorrt.element_wise` ops. This means that the transpose will
201- be applied to the result of the elementwise operation as well as the other
202- branch of the operation. To avoid an infinite ping-pong application of this
203- pattern certain heuristics are applied to determine whether or not this is
204- beneficial. For example:
205-
206- ```
207- func.func @transpose_pushdown_switch(%arg0: tensor<2x2xf32>, %arg1: tensor<1x2xf32>)
208- -> tensor<2x2xf32> {
209- %1 = tensorrt.transpose {
210- permutation = affine_map<(d0, d1)->(d1, d0)>
211- } %arg0 : tensor<2x2xf32> to tensor<2x2xf32>
212- %2 = tensorrt.element_wise <kSUM> (
213- %1, %arg1: tensor<2x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32>
214- return %2 : tensor<2x2xf32>
215- }
216- ```
217-
218- becomes
219-
220- ```
221- func.func @transpose_pushdown_switch(%arg0: tensor<2x2xf32>,
222- %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> {
223- %0 = tensorrt.transpose {permutation = #map}
224- %arg1 : tensor<1x2xf32> to tensor<2x1xf32>
225- %1 = tensorrt.element_wise <kSUM>
226- (%arg0, %0 : tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x2xf32>
227- %2 = tensorrt.transpose {permutation = #map} %1 : tensor<2x2xf32> to tensor<2x2xf32>
228- return %2 : tensor<2x2xf32>
229- }
230- ```
231-
232- In this case, moving the transpose to the other branch results in lower
233- memory cost on the inputs, but higher total memory cost (because a transpose
234- on the result is also added). However, we always prefer to push transpose
235- operations as far forward as possible in this transformation.
236-
237- 2) Const-folding transpose operations. Often, it is undesirable to let
238- weights be transposed at runtime. Instead, weights should be pre-transformed
239- to put them into a form that is suitable for TensorRT convolutions.
240- Therefore, we apply global transpose-const folding. This can be quite
241- expensive for large weights but is important to reduce runtime transpose
242- costs.
243-
244- }];
245- }
246-
247- //===----------------------------------------------------------------------===//
248- // ReshapeEliminationPass
249- //===----------------------------------------------------------------------===//
250- def ReshapeEliminationPass : Pass<"tensorrt-reshape-elimination"> {
251- let summary = "try to eliminate tensorrt.reshape operations";
252-
253- let description = [{
254- Reshape elimination pass captures pattern with un-necessary reshape and
255- simplifies it by eliminating reshape operations.
256- }];
257- }
258-
259178//===----------------------------------------------------------------------===//
260179// TransposeReshapeEliminationPass
261180//===----------------------------------------------------------------------===//
262181def TransposeReshapeEliminationPass : Pass<"tensorrt-transpose-reshape-elimination"> {
263182 let summary = "try to eliminate tensorrt.transpose, tensorrt.reshape, and tensorrt.shuffle operations";
264183
265184 let description = [{
266- Push tensorrt.transpose and tensorrt.reshape operations around to attempt to eleminate them
267- and merge them with other ops such as matrix multiply. The intention is to improve
268- pattern matching and fusion inside of TensorRT.
185+ It is well-known that an excessive number of transpose or reshape ops (either
186+ "tensorrt.transpose", "tensorrt.reshape" or "tensorrt.shuffle")
187+ can cause performance issues with TensorRT. For example, this commonly occurs
188+ when the input source being converted represents convolutions in "NHWC" format
189+ vs. TensorRT's preferred "NCHW" format. In the conversion of these types of
190+ convolutions, a number of transpose operations must be inserted. These
191+ transpose operations can prevent fusions. For example, a transpose operation
192+ between a convolution and a pointwise addition can prevent convolution-bias
193+ fusion. Fusions can also be blocked if a reshape is placed between a
194+ matrix multiplication and an activation.
195+
196+ The pass tries to eliminate transposes and reshapes by pushing the transposes
197+ and reshapes around to combine them into identity transposes and reshapes which
198+ can be eliminated from the program. To accomplish this, this pass uses several
199+ rewrite rules that push transposes and reshapes from the output of an Operation
200+ to the operation's input, and vice-versa. The rewrites currently included in this
201+ pass handle common cases, though they do not currently cover every possible scenario---one
202+ may wish to extend this pass in the future as needed.
203+
204+ The process is as follows:
205+ 1) Normalize transpose, reshape, shuffle and matrix multiply into a common set of ops.
206+ - shuffle(x) -> reshape(transpose(reshape(x)))
207+ - matrix_multiply(x, y) -> einsum("ij,jk->ik", x, y)
208+ - expand_rank(x) -> reshape(x)
209+ - collapse_rank(x) -> reshape(x)
210+ 2) Push down reshape and transpose, eliminating when possible. E.g. op(transpose(x)) -> transpose(op(x))
211+ - einsum(transpose(x), ...) -> einsum(x, ...)
212+ - einsum(...) -> transpose(einsum(...)) Pull transposes out of einsum (to try to match matrix multiply pattern)
213+ - einsum(reshape(x), y, ...) -> transpose(reshape(einsum(x, reshape(transpose(y)), ...)))
214+ - unary(transpose(x)) -> transpose(unary(x))
215+ - activation(transpose(x)) -> transpose(activation(x))
216+ - identity_op(transpose(x)) -> transpose(identity_op(x))
217+ - activation(reshape(x)) -> reshape(activation(x))
218+ - unary(reshape(x)) -> reshape(unary(x))
219+ - identity_op(reshape(x)) -> reshape(identity_op(x))
220+ - reshape(transpose(x)) -> transpose(reshape(x)) if possible put reshape before transpose
221+ - dequantize(quantize(transpose(x))) -> transpose(dequantize(quantize(x))) if the scale is 0-dim
222+ - dequantize(quantize(reshape(x))) -> reshape(dequantize(quantize(x))) if the scale is 0-dim
223+ - reshape(reshape(x)) -> reshape(x)
224+ - transpose(transpose(x)) -> transpose(x)
225+ - reshape(x) -> x if reshape is identity
226+ - transpose(x) -> x if transpose is identity
227+ - elementwise(reshape(a), b) -> reshape(elementwise(a, reshape(b))) conditioned on heuristic
228+ - elementwise(transpose(a), b) -> transpose(elementwise(a, transpose(b)))
229+ 3) Push up reshape and transpose, eliminating when possible. E.g. transpose(op(x)) -> op(transpose(x))
230+ - transpose(einsum(...)) -> einsum(...)
231+ - einsum(...) -> einsum(transpose(x), ...) Pull transposes out of einsum (to try to match matrix multiply pattern)
232+ - reshape(einsum(...)) -> einsum(reshape(transpose(x)), ...)
233+ - transpose(activation(x)) -> activation(transpose(x))
234+ - transpose(unary(x)) -> unary(transpose(x))
235+ - transpose(identity_op(x)) -> identity_op(transpose(x))
236+ - reshape(activation(x)) -> activation(reshape(x))
237+ - reshape(unary(x)) -> unary(reshape(x))
238+ - reshape(identity_op(x)) -> identity_op(reshape(x))
239+ - reshape(reshape(x)) -> reshape(x)
240+ - transpose(transpose(x)) -> transpose(x)
241+ - reshape(x) -> x if reshape is identity
242+ - transpose(x) -> x if transpose is identity
243+ - transpose(reshape(x)) -> reshape(transpose(x)) if possible put transpose before reshape
244+ - transpose(dequantize(quantize(x))) -> dequantize(quantize(transpose(x))) if the scale is 0-dim
245+ - reshape(dequantize(quantize(x))) -> dequantize(quantize(reshape(x))) if the scale is 0-dim
246+ - reshape(elementwise(a, b)) -> elementwise(reshape(a), reshape(b))
247+ - transpose(elementwise(a, b)) -> elementwise(transpose(a), transpose(b))
248+ 4) Final clean up. Fuse leftover transposes and reshapes with other ops.
249+ - einsum("ij,jk->ik", x, y) -> matrix_multiply(x, y) if einsum matches a matrix multiply pattern
250+ - matrix_multiply(transpose(x), y) -> matrix_multiply(x, y) merge transpose if possible
251+ - transpose(einsum(...)) -> einsum(...)
252+ - einsum(transpose(x), ...) -> einsum(...)
253+ - einsum(collapse_rank(x), ...) -> einsum(...)
254+ - expand_rank(einsum(...)) -> einsum(...)
255+
256+ To avoid an infinite ping-pong application of these patterns, heuristics are
257+ applied to determine when a pattern is beneficial.
269258 }];
270259}
271260
0 commit comments