Commit 4bc28e2
authored
[optimize-dot-operands]: Fuse load and trans operations - part 2 (#4468)
This PR enhances the new transformation pass aimed at fusing `tt.load`
and `tt.trans` operations. Specifically it adds support for loop carried
arguments used (possibly transitively) by the candidate `tt.load` that
should be fused with a `tt.trans`.
Example:
```
%10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #linear>>
%13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<256x32xbf16, #linear>>) : i32 {
%17 = tt.advance %9, [%c256_i32, %arg5] : <tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
%18 = tt.load %17 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
%19 = tt.advance %arg6, [%c16_i32, %arg5] : <tensor<256x32xbf16, #linear>>
%20 = tt.load %19 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<256x32xbf16, #linear>>
%21 = tt.trans %20 {order = array<i32: 1, 0>} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%22 = tt.dot %18, %21, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
%23 = arith.addi %arg5, %c32_i32 : i32
scf.yield %22, %23, %19 : tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<256x32xbf16, #linear>>
}
```
Here the load `%20` is a candidate for fusion with the `tt.trans`
operation. The pointer argument used by the candidate load (`%19`) is
produced by a `tt.advance` operation which uses the loop carried pointer
`%arg6`.
---------
Signed-off-by: Tiotto, Ettore <[email protected]>1 parent c10bebc commit 4bc28e2
File tree
7 files changed
+496
-130
lines changed- test/TritonIntelGPU
- third_party/intel
- backend
- include/Utils
- lib
- Dialect/Triton/Transforms
- TritonIntelGPUTransforms
- Utils
7 files changed
+496
-130
lines changedLarge diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
281 | 281 | | |
282 | 282 | | |
283 | 283 | | |
| 284 | + | |
284 | 285 | | |
285 | 286 | | |
286 | 287 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
29 | 29 | | |
30 | 30 | | |
31 | 31 | | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
32 | 35 | | |
33 | 36 | | |
34 | 37 | | |
Lines changed: 3 additions & 26 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
95 | 95 | | |
96 | 96 | | |
97 | 97 | | |
98 | | - | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
99 | 101 | | |
100 | 102 | | |
101 | 103 | | |
| |||
267 | 269 | | |
268 | 270 | | |
269 | 271 | | |
270 | | - | |
271 | | - | |
272 | | - | |
273 | | - | |
274 | | - | |
275 | | - | |
276 | | - | |
277 | | - | |
278 | | - | |
279 | | - | |
280 | | - | |
281 | | - | |
282 | | - | |
283 | | - | |
284 | | - | |
285 | | - | |
286 | | - | |
287 | | - | |
288 | | - | |
289 | | - | |
290 | | - | |
291 | | - | |
292 | | - | |
293 | | - | |
294 | | - | |
295 | 272 | | |
296 | 273 | | |
297 | 274 | | |
| |||
0 commit comments