Commit 34e1e17
[mlir-tensorrt] Transpose Reshape Elimination pass (#686)
The `TransposeReshapeElimination` pass is designed to subsume the
existing Transpose and Reshape Elimination passes. This pass adds a lot
of new rewrite rules which enable pushing the transposes and reshapes
around so that they can be combined and then eliminated. The rules from
the existing `TransposeElimination` are copied into the
`TransposeReshapeElimination.cpp` file. The rules from the
`ReshapeElimination` pass should be subsumed by the rules added to the
`TransposeReshapeElimination`.
The motivation for this pass is that there are some cases where shuffles
can get inserted around matrix multiplications and element wise ops
which break [various fusions inside of
TensorRT](https://docs.nvidia.com/deeplearning/tensorrt/latest/performance/best-practices.html#types-of-fusions).
The process is as follows:
1. "canonicalize" the network into simpler ops
- `shuffle(x)` -> `reshape(transpose(reshape(x)))`
- `matrix_multiply(x, y)` -> `einsum("ij,jk->ik", x, y)`
- `expand_rank(x)` -> `reshape(x)`
- `collapse_rank(x)` -> `reshape(x)`
2. Push down `reshape` and `transpose` ops as much as possible. Merging
and eliminating when possible
- `einsum(transpose(x), ...)` -> `einsum(x, ...)` Merge transpose into
einsum
- `einsum(...)` -> `transpose(einsum(...))` Pull transpose out of einsum
(to try to match matrix multiply pattern)
- `einsum(reshape(x), y, ...)` -> `transpose(reshape(einsum(x,
reshape(transpose(y)), ...)))` Push reshape down. Possibly add reshape
and transposes to other inputs as needed. Conditioned on heuristic
checking if "better"
- EXISTING `unary(transpose(x))` -> `transpose(unary(x))`
- EXISTING `activation(transpose(x))` -> `transpose(activation(x))`
- EXSITING `identity_op(transpose(x))` -> `transpose(identity_op(x))`
- `activation(reshape(x))` -> `reshape(activation(x))`
- `unary(reshape(x))` -> `reshape(unary(x))`
- `identity_op(reshape(x))` -> `reshape(identity_op(x))`
- `reshape(transpose(x))` -> `transpose(reshape(x))` if possible put
reshape before transpose
- `qdq(transpose(x))` -> `transpose(qdq(x))` if the scale is 0-dim
- `qdq(reshape(x))` -> `reshape(qdq(x))` if the scale is 0-dim
- EXISTING `reshape(reshape(x))` -> `reshape(x)`
- EXISTING `transpose(transpose(x))` -> `transpose(x)`
- EXISTING `reshape(x)` -> `x` if `reshape` is identity
- EXISTING `transpose(x)` -> `x` if `transpose` is identity
- `elementwise(reshape(a), b)` -> `reshape(elementwise(a, reshape(b)))`
conditioned on heuristic
- EXISTING `elementwise(transpose(a), b)` -> `transpose(elementwise(a,
transpose(b)))`
- `softmax(transpose(x))` -> `transpose(softmax(x))`
- `softmax(reshape(x))` -> `reshape(softmax(x))`
3. Push up `reshape` and `transpose` ops as much as possible. Merging
and eliminating when possible
- `transpose(einsum(...))` -> `einsum(...)`. Merge transpose into einsum
- `einsum(...)` -> `einsum(transpose(x), ...)`. Pull transposes out of
einsum (to try to match matrix multiply pattern)
- `reshape(einsum(...))` -> `einsum(reshape(transpose(x)), ...)` Push
reshapes up through einsum. Adding transposes as needed
- EXISTING `transpose(activation(x))` -> `activation(transpose(x))`
- EXISTING `transpose(unary(x))` -> `unary(transpose(x))`
- EXISTING `transpose(identity_op(x))` -> `identity_op(transpose(x))`
- `reshape(activation(x))` -> `activation(reshape(x))`
- `reshape(unary(x))` -> `unary(reshape(x))`
- `reshape(identity_op(x))` -> `identity_op(reshape(x))`
- EXISTING `reshape(reshape(x))` -> `reshape(x)`
- EXISTING `transpose(transpose(x))` -> `transpose(x)`
- EXISTING `reshape(x)` -> `x` if `reshape` is identity
- EXISTING `transpose(x)` -> `x` if `transpose` is identity
- `transpose(reshape(x))` -> `reshape(transpose(x))` if possible put
transpose before reshape
- `transpose(qdq(x))` -> `qdq(transpose(x))` if the scale is 0-dim
- `reshape(qdq(x))` -> `qdq(reshape(x))` if the scale is 0-dim
- `reshape(elementwise(a, b))` -> `elementwise(reshape(a), reshape(b))`
- EXISTING `transpose(elementwise(a, b))` -> `elementwise(transpose(a),
transpose(b))`
- `transpose(softmax(x))` -> `softmax(transpose(x))`
- `reshape(softmax(x))` -> `softmax(reshape(x))`
4. Convert back to matrix multiplication form to assist with TRT's
pattern matching
- `einsum(x, y)` -> `matrix_multiply(x, y)` if einsum matches a matrix
multiply pattern
- `matrix_multiply(transpose(x), y)` -> `matrix_multiply(x, y)` merge
transpose if possible
5. Final clean ups, additional merging of transpose/reshapes into
leftover einsums
- `einsum(x, y)` -> `matrix_multiply(x, y)` if einsum matches a matrix
multiply pattern
- `matrix_multiply(transpose(x), y)` -> `matrix_multiply(x, y)` merge
transpose if possible
- `transpose(einsum(...))` -> `einsum(...)`
- `einsum(tranpose(x), ...)` -> `einsum(...)`
- `einsum(collapse_rank(x), ...)` -> `einsum(...)`
- `expand_rank(einsum(...))` -> `einsum(...)`
__NOTE__: The overarching goal of this PR is to improve the pattern
matching inside of TensorRT (and therefore the quality of kernel's that
TensorRT can generate, and fusion that TensorRT will generate). I have
some empirical evidence that the mlir that is generated seems to be be
an improvement, however I am still not 100% sure what is the _best_ way
to generate mlir in some of these edge cases when it comes to getting
the fastest model out of TensorRT.
Co-authored-by: Matthew Francis-Landau <[email protected]>1 parent b344e42 commit 34e1e17
File tree
12 files changed
+3611
-830
lines changed- mlir-tensorrt/tensorrt
- include/mlir-tensorrt-dialect/TensorRT
- IR
- Transforms
- lib/TensorRT
- IR
- Transforms
- test/Dialect/TensorRT
12 files changed
+3611
-830
lines changedLines changed: 5 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3875 | 3875 | | |
3876 | 3876 | | |
3877 | 3877 | | |
| 3878 | + | |
| 3879 | + | |
| 3880 | + | |
| 3881 | + | |
| 3882 | + | |
3878 | 3883 | | |
3879 | 3884 | | |
3880 | 3885 | | |
| |||
Lines changed: 73 additions & 70 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
176 | 176 | | |
177 | 177 | | |
178 | 178 | | |
179 | | - | |
| 179 | + | |
180 | 180 | | |
181 | | - | |
182 | | - | |
| 181 | + | |
| 182 | + | |
183 | 183 | | |
184 | 184 | | |
185 | | - | |
186 | | - | |
187 | | - | |
188 | | - | |
189 | | - | |
190 | | - | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
191 | 190 | | |
192 | 191 | | |
193 | 192 | | |
194 | | - | |
195 | | - | |
196 | | - | |
197 | | - | |
198 | | - | |
199 | | - | |
200 | | - | |
201 | | - | |
202 | | - | |
203 | | - | |
204 | | - | |
205 | | - | |
206 | | - | |
207 | | - | |
208 | | - | |
209 | | - | |
210 | | - | |
211 | | - | |
212 | | - | |
213 | | - | |
214 | | - | |
215 | | - | |
216 | | - | |
217 | | - | |
218 | | - | |
219 | | - | |
220 | | - | |
221 | | - | |
222 | | - | |
223 | | - | |
224 | | - | |
225 | | - | |
226 | | - | |
227 | | - | |
228 | | - | |
229 | | - | |
230 | | - | |
231 | | - | |
232 | | - | |
233 | | - | |
234 | | - | |
235 | | - | |
236 | | - | |
237 | | - | |
238 | | - | |
239 | | - | |
240 | | - | |
241 | | - | |
242 | | - | |
243 | | - | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
244 | 258 | | |
245 | 259 | | |
246 | 260 | | |
247 | | - | |
248 | | - | |
249 | | - | |
250 | | - | |
251 | | - | |
252 | | - | |
253 | | - | |
254 | | - | |
255 | | - | |
256 | | - | |
257 | | - | |
258 | 261 | | |
259 | 262 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
143 | 143 | | |
144 | 144 | | |
145 | 145 | | |
146 | | - | |
147 | | - | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
148 | 149 | | |
149 | | - | |
| 150 | + | |
150 | 151 | | |
151 | 152 | | |
152 | 153 | | |
| |||
203 | 204 | | |
204 | 205 | | |
205 | 206 | | |
206 | | - | |
207 | | - | |
| 207 | + | |
| 208 | + | |
208 | 209 | | |
209 | 210 | | |
210 | 211 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1850 | 1850 | | |
1851 | 1851 | | |
1852 | 1852 | | |
| 1853 | + | |
| 1854 | + | |
| 1855 | + | |
| 1856 | + | |
| 1857 | + | |
| 1858 | + | |
| 1859 | + | |
1853 | 1860 | | |
1854 | 1861 | | |
1855 | 1862 | | |
| |||
Lines changed: 1 addition & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
19 | 19 | | |
20 | 20 | | |
21 | 21 | | |
22 | | - | |
23 | | - | |
| 22 | + | |
24 | 23 | | |
25 | 24 | | |
26 | 25 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
86 | 86 | | |
87 | 87 | | |
88 | 88 | | |
89 | | - | |
90 | | - | |
91 | | - | |
| 89 | + | |
92 | 90 | | |
93 | 91 | | |
94 | 92 | | |
| |||
0 commit comments