Commit 50e9d57
Matthew Francis-Landau
[mlir-tensorrt] Transpose Reshape Elimination pass.
This is a new pass that is designed to replace the Transpose
and Reshape Elimination passes. It adds many new rewrite
rules that push transposes and reshapes around so that
they can be combined and then eliminated.
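As a rough illustration of "combine, then eliminate", the sketch below models a transpose purely by its permutation and shows how two adjacent transposes fold into one that can be dropped when it is the identity. The helper names (`compose_perms`, `is_identity`, `transpose_shape`) are hypothetical and are not taken from the pass; the convention assumed is that output axis `i` reads input axis `perm[i]`.

```python
from typing import List, Sequence


def transpose_shape(shape: Sequence[int], perm: Sequence[int]) -> List[int]:
    """Shape after a transpose, assuming output axis i reads input axis perm[i]."""
    return [shape[p] for p in perm]


def compose_perms(inner: Sequence[int], outer: Sequence[int]) -> List[int]:
    """Single permutation equivalent to transpose(transpose(x, inner), outer)."""
    # Output axis i of the outer transpose reads axis outer[i] of the inner
    # result, which in turn reads axis inner[outer[i]] of x.
    return [inner[j] for j in outer]


def is_identity(perm: Sequence[int]) -> bool:
    """True when the transpose leaves every axis in place and can be removed."""
    return list(perm) == list(range(len(perm)))


# transpose(transpose(x)) -> transpose(x): two permutations fold into one.
merged = compose_perms([1, 2, 0], [1, 2, 0])

# transpose(x) -> x: a permutation followed by its inverse is the identity.
cancelled = compose_perms([1, 2, 0], [2, 0, 1])
```

Here `merged` is a single non-identity permutation, while `cancelled` is the identity, so in the second case both transposes could be removed from the program.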
The motivation for this pass is that, in some cases, shuffles
get inserted around matrix multiplications and elementwise ops
and break various fusions inside of TensorRT.
To accomplish this, this pass uses several rewrite rules that push transposes
and reshapes around to combine them into identity transposes and reshapes which
can be eliminated from the program. The rewrite rules are as follows:
1. "canonicalize" the network into simpler ops
- `shuffle(x)` -> `reshape(transpose(reshape(x)))`
- `matrix_multiply(x, y)` -> `einsum("ij,jk->ik", x, y)`
- `expand_rank(x)` -> `reshape(x)`
- `collapse_rank(x)` -> `reshape(x)`
2. Push down `reshape` and `transpose` ops as much as possible, merging and eliminating when possible
- `einsum(transpose(x), ...)` -> `einsum(x, ...)` Merge transpose into einsum
- `einsum(...)` -> `transpose(einsum(...))` Pull transpose out of einsum (to try to match matrix multiply pattern)
- `einsum(reshape(x), y, ...)` -> `transpose(reshape(einsum(x, reshape(transpose(y)), ...)))` Push reshape down, possibly adding reshapes and transposes to other inputs as needed. Conditioned on a heuristic that checks whether the result is "better"
- `unary(transpose(x))` -> `transpose(unary(x))`
- `activation(transpose(x))` -> `transpose(activation(x))`
- `identity_op(transpose(x))` -> `transpose(identity_op(x))`
- `activation(reshape(x))` -> `reshape(activation(x))`
- `unary(reshape(x))` -> `reshape(unary(x))`
- `identity_op(reshape(x))` -> `reshape(identity_op(x))`
- `reshape(transpose(x))` -> `transpose(reshape(x))` if possible put reshape before transpose
- `qdq(transpose(x))` -> `transpose(qdq(x))` if the scale is 0-dim
- `qdq(reshape(x))` -> `reshape(qdq(x))` if the scale is 0-dim
- `reshape(reshape(x))` -> `reshape(x)`
- `transpose(transpose(x))` -> `transpose(x)`
- `reshape(x)` -> `x` if `reshape` is identity
- `transpose(x)` -> `x` if `transpose` is identity
- `elementwise(reshape(a), b)` -> `reshape(elementwise(a, reshape(b)))` conditioned on heuristic
- `elementwise(transpose(a), b)` -> `transpose(elementwise(a, transpose(b)))`
- `softmax(transpose(x))` -> `transpose(softmax(x))`
- `softmax(reshape(x))` -> `reshape(softmax(x))`
3. Push up `reshape` and `transpose` ops as much as possible, merging and eliminating when possible
- `transpose(einsum(...))` -> `einsum(...)`. Merge transpose into einsum
- `einsum(...)` -> `einsum(transpose(x), ...)`. Pull transposes out of einsum (to try to match matrix multiply pattern)
- `reshape(einsum(...))` -> `einsum(reshape(transpose(x)), ...)` Push reshapes up through einsum, adding transposes as needed
- `transpose(activation(x))` -> `activation(transpose(x))`
- `transpose(unary(x))` -> `unary(transpose(x))`
- `transpose(identity_op(x))` -> `identity_op(transpose(x))`
- `reshape(activation(x))` -> `activation(reshape(x))`
- `reshape(unary(x))` -> `unary(reshape(x))`
- `reshape(identity_op(x))` -> `identity_op(reshape(x))`
- `reshape(reshape(x))` -> `reshape(x)`
- `transpose(transpose(x))` -> `transpose(x)`
- `reshape(x)` -> `x` if `reshape` is identity
- `transpose(x)` -> `x` if `transpose` is identity
- `transpose(reshape(x))` -> `reshape(transpose(x))` if possible put transpose before reshape
- `transpose(qdq(x))` -> `qdq(transpose(x))` if the scale is 0-dim
- `reshape(qdq(x))` -> `qdq(reshape(x))` if the scale is 0-dim
- `reshape(elementwise(a, b))` -> `elementwise(reshape(a), reshape(b))`
- `transpose(elementwise(a, b))` -> `elementwise(transpose(a), transpose(b))`
- `transpose(softmax(x))` -> `softmax(transpose(x))`
- `reshape(softmax(x))` -> `softmax(reshape(x))`
4. Convert back to matrix multiplication form to assist with TRT's pattern matching
- `einsum(x, y)` -> `matrix_multiply(x, y)` if einsum matches a matrix multiply pattern
- `matrix_multiply(transpose(x), y)` -> `matrix_multiply(x, y)` merge transpose if possible
5. Final cleanups, additional merging of transposes/reshapes into leftover einsums
- `einsum(x, y)` -> `matrix_multiply(x, y)` if einsum matches a matrix multiply pattern
- `matrix_multiply(transpose(x), y)` -> `matrix_multiply(x, y)` merge transpose if possible
- `transpose(einsum(...))` -> `einsum(...)`
- `einsum(transpose(x), ...)` -> `einsum(...)`
- `einsum(collapse_rank(x), ...)` -> `einsum(...)`
- `expand_rank(einsum(...))` -> `einsum(...)`
1 parent b344e42 commit 50e9d57
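The rules above that merge a transpose into an einsum can be pictured as a relabeling of the einsum equation. The following is a minimal sketch under that assumption, not the pass's implementation: the function names are hypothetical, the equation is assumed to use explicit single-letter subscripts with a `->`, and a transpose is assumed to map output axis `i` to input axis `perm[i]`.

```python
from typing import Sequence


def merge_input_transpose(equation: str, operand_index: int, perm: Sequence[int]) -> str:
    """Rewrite einsum(eq, ..., transpose(x, perm), ...) as einsum(new_eq, ..., x, ...)."""
    inputs, output = equation.split("->")
    operands = inputs.split(",")
    labels = operands[operand_index]
    if sorted(perm) != list(range(len(labels))):
        raise ValueError("perm must be a permutation of the operand's axes")
    # Axis i of the transposed operand carries labels[i] and reads axis perm[i]
    # of x, so axis perm[i] of x must carry that same label.
    new_labels = [""] * len(labels)
    for i, p in enumerate(perm):
        new_labels[p] = labels[i]
    operands[operand_index] = "".join(new_labels)
    return ",".join(operands) + "->" + output


def merge_output_transpose(equation: str, perm: Sequence[int]) -> str:
    """Rewrite transpose(einsum(eq, ...), perm) as einsum(new_eq, ...)."""
    inputs, output = equation.split("->")
    if sorted(perm) != list(range(len(output))):
        raise ValueError("perm must be a permutation of the output axes")
    # Axis i of the transposed result reads axis perm[i] of the einsum output.
    return inputs + "->" + "".join(output[p] for p in perm)


# einsum(transpose(x), y) -> einsum(x, y) with relabeled subscripts.
eq_in = merge_input_transpose("ij,jk->ik", 0, [1, 0])

# transpose(einsum(x, y)) -> einsum(x, y) with a permuted output.
eq_out = merge_output_transpose("ij,jk->ik", [1, 0])
```

Running the relabeling in the opposite direction corresponds to pulling a transpose back out of an einsum, which is what lets a leftover einsum be matched against the plain `"ij,jk->ik"` matrix multiply pattern.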
File tree
12 files changed
+3611
-830
lines changed
- mlir-tensorrt/tensorrt
  - include/mlir-tensorrt-dialect/TensorRT
    - IR
    - Transforms
  - lib/TensorRT
    - IR
    - Transforms
  - test/Dialect/TensorRT
Lines changed: 5 additions & 0 deletions
Lines changed: 73 additions & 70 deletions
Lines changed: 1 addition & 2 deletions