This is a new pass designed to replace the Transpose
and Reshape Elimination passes. It adds many new rewrite
rules that push transposes and reshapes around so that
they can be combined and then eliminated.
The motivation for this pass is that shuffles can get inserted
around matrix multiplications and elementwise ops, and these
shuffles break various fusions inside TensorRT.
To accomplish this, the pass uses several rewrite rules that push transposes
and reshapes around until they combine into identity transposes and reshapes, which
can then be eliminated from the program. The rewrite rules are as follows:
1. Canonicalize the network into simpler ops
- `shuffle(x)` -> `reshape(transpose(reshape(x)))`
- `matrix_multiply(x, y)` -> `einsum("ij,jk->ik", x, y)`
- `expand_rank(x)` -> `reshape(x)`
- `collapse_rank(x)` -> `reshape(x)`
2. Push down `reshape` and `transpose` ops as much as possible, merging and eliminating when possible
- `einsum(transpose(x), ...)` -> `einsum(x, ...)`. Merge transpose into einsum
- `einsum(...)` -> `transpose(einsum(...))`. Pull transpose out of einsum (to try to match matrix multiply pattern)
- `einsum(reshape(x), y, ...)` -> `transpose(reshape(einsum(x, reshape(transpose(y)), ...)))`. Push reshape down, adding reshapes and transposes to other inputs as needed. Conditioned on a heuristic checking whether the result is "better"
- `unary(transpose(x))` -> `transpose(unary(x))`
- `activation(transpose(x))` -> `transpose(activation(x))`
- `identity_op(transpose(x))` -> `transpose(identity_op(x))`
- `activation(reshape(x))` -> `reshape(activation(x))`
- `unary(reshape(x))` -> `reshape(unary(x))`
- `identity_op(reshape(x))` -> `reshape(identity_op(x))`
- `reshape(transpose(x))` -> `transpose(reshape(x))` when possible, to put the reshape before the transpose
- `qdq(transpose(x))` -> `transpose(qdq(x))` if the scale is 0-dim
- `qdq(reshape(x))` -> `reshape(qdq(x))` if the scale is 0-dim
- `reshape(reshape(x))` -> `reshape(x)`
- `transpose(transpose(x))` -> `transpose(x)`
- `reshape(x)` -> `x` if `reshape` is identity
- `transpose(x)` -> `x` if `transpose` is identity
- `elementwise(reshape(a), b)` -> `reshape(elementwise(a, reshape(b)))` conditioned on heuristic
- `elementwise(transpose(a), b)` -> `transpose(elementwise(a, transpose(b)))`
- `softmax(transpose(x))` -> `transpose(softmax(x))`
- `softmax(reshape(x))` -> `reshape(softmax(x))`
3. Push up `reshape` and `transpose` ops as much as possible, merging and eliminating when possible
- `transpose(einsum(...))` -> `einsum(...)`. Merge transpose into einsum
- `einsum(...)` -> `einsum(transpose(x), ...)`. Pull transposes out of einsum (to try to match matrix multiply pattern)
- `reshape(einsum(...))` -> `einsum(reshape(transpose(x)), ...)`. Push reshapes up through einsum, adding transposes as needed
- `transpose(activation(x))` -> `activation(transpose(x))`
- `transpose(unary(x))` -> `unary(transpose(x))`
- `transpose(identity_op(x))` -> `identity_op(transpose(x))`
- `reshape(activation(x))` -> `activation(reshape(x))`
- `reshape(unary(x))` -> `unary(reshape(x))`
- `reshape(identity_op(x))` -> `identity_op(reshape(x))`
- `reshape(reshape(x))` -> `reshape(x)`
- `transpose(transpose(x))` -> `transpose(x)`
- `reshape(x)` -> `x` if `reshape` is identity
- `transpose(x)` -> `x` if `transpose` is identity
- `transpose(reshape(x))` -> `reshape(transpose(x))` when possible, to put the transpose before the reshape
- `transpose(qdq(x))` -> `qdq(transpose(x))` if the scale is 0-dim
- `reshape(qdq(x))` -> `qdq(reshape(x))` if the scale is 0-dim
- `reshape(elementwise(a, b))` -> `elementwise(reshape(a), reshape(b))`
- `transpose(elementwise(a, b))` -> `elementwise(transpose(a), transpose(b))`
- `transpose(softmax(x))` -> `softmax(transpose(x))`
- `reshape(softmax(x))` -> `softmax(reshape(x))`
4. Convert back to matrix multiplication form to assist with TRT's pattern matching
- `einsum(x, y)` -> `matrix_multiply(x, y)` if einsum matches a matrix multiply pattern
- `matrix_multiply(transpose(x), y)` -> `matrix_multiply(x, y)` merge transpose if possible
5. Final cleanups: additional merging of transposes/reshapes into leftover einsums
- `einsum(x, y)` -> `matrix_multiply(x, y)` if einsum matches a matrix multiply pattern
- `matrix_multiply(transpose(x), y)` -> `matrix_multiply(x, y)` merge transpose if possible
- `transpose(einsum(...))` -> `einsum(...)`
- `einsum(transpose(x), ...)` -> `einsum(...)`
- `einsum(collapse_rank(x), ...)` -> `einsum(...)`
- `expand_rank(einsum(...))` -> `einsum(...)`
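
A few of the rewrite rules above can be sanity-checked numerically. The following is a minimal NumPy sketch, not the pass itself: it only illustrates why the rewrites preserve values. The helpers `compose_perms` and `invert_perm`, the shapes, and the permutations are all hypothetical and chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)


def compose_perms(first, second):
    """Single permutation equal to transpose(transpose(x, first), second)."""
    return tuple(first[i] for i in second)


def invert_perm(perm):
    """Permutation that undoes `perm`."""
    return tuple(int(i) for i in np.argsort(perm))


x = rng.standard_normal((2, 3, 4))
p1, p2 = (1, 2, 0), (2, 0, 1)

# transpose(transpose(x)) -> transpose(x); here the merged permutation is the
# identity, so the remaining transpose can be eliminated entirely.
merged = compose_perms(p1, p2)
double_transpose_ok = np.array_equal(np.transpose(np.transpose(x, p1), p2),
                                     np.transpose(x, merged))

# reshape(reshape(x)) -> reshape(x)
double_reshape_ok = np.array_equal(x.reshape(6, 4).reshape(2, 12),
                                   x.reshape(2, 12))

# unary(transpose(x)) -> transpose(unary(x))
unary_ok = np.allclose(np.exp(np.transpose(x, p1)),
                       np.transpose(np.exp(x), p1))

# elementwise(transpose(a), b) -> transpose(elementwise(a, transpose(b)))
# The transpose added on `b` uses the inverse permutation.
b = rng.standard_normal((3, 4, 2))  # same shape as transpose(x, p1)
lhs = np.transpose(x, p1) + b
rhs = np.transpose(x + np.transpose(b, invert_perm(p1)), p1)
elementwise_ok = np.allclose(lhs, rhs)

# einsum(transpose(m), n) -> einsum(m, n): relabel the subscripts instead.
m = rng.standard_normal((3, 2))
n = rng.standard_normal((3, 5))
merge_in_ok = np.allclose(np.einsum("ij,jk->ik", m.T, n),
                          np.einsum("ji,jk->ik", m, n))

# transpose(einsum(...)) -> einsum(...): permute the output subscripts.
merge_out_ok = np.allclose(np.einsum("ij,jk->ik", m.T, n).T,
                           np.einsum("ji,jk->ki", m, n))
```

Each `*_ok` flag compares the left-hand and right-hand side of one rule. The einsum cases show why einsum is a convenient canonical form: a transpose on an input or on the output turns into a relabeling of subscripts, with no data movement op left in the graph.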