Skip to content

Commit 8cd336b

Browse files
author
Matthew Francis-Landau
committed
fix permutation from transpose-reshape reordering
1 parent 903e995 commit 8cd336b

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeReshapeElimination.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1726,8 +1726,8 @@ class MoveTransposeBeforeReshape
17261726
for (int i = 0; i < reshapeOutputType.getRank(); i++) {
17271727
transposePerm.push_back(i);
17281728
}
1729-
transposePerm = op.getPermutation().compose(
1730-
transposePerm); // TODO: check if this is correct
1729+
transposePerm =
1730+
inversePermutation(op.getPermutation()).compose(transposePerm);
17311731

17321732
struct ReshapeGroup {
17331733
SmallVector<int64_t> inputAxes;

mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-reshape-elimination.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,28 @@ func.func @matmul_argument_swap(%arg0: tensor<1x2x4x3x561xf32>, %arg1: tensor<1x
221221
%4 = tensorrt.transpose {permutation = #map} %3 : tensor<1x2x4x3x561xf32> to tensor<1x2x4x561x3xf32>
222222
return %4 : tensor<1x2x4x561x3xf32>
223223
}
224+
225+
// -----
226+
227+
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4, d3)>
228+
func.func @olympus_v0_2(%arg0: tensor<1x2x6x3x3xf32>, %arg1: tensor<1x2x6x3x2048xf32>) -> tensor<1x2x6x2048x3xf32> {
229+
%0 = tensorrt.reshape %arg0 : tensor<1x2x6x3x3xf32> to tensor<12x3x3xf32>
230+
%1 = tensorrt.reshape %arg1 : tensor<1x2x6x3x2048xf32> to tensor<12x3x2048xf32>
231+
%2 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation<kNONE>, op1 = #tensorrt.matrix_operation<kNONE>} ins(%0, %1 : tensor<12x3x3xf32>, tensor<12x3x2048xf32>) -> tensor<12x3x2048xf32>
232+
%3 = tensorrt.reshape %2 : tensor<12x3x2048xf32> to tensor<1x2x6x3x2048xf32>
233+
%4 = tensorrt.transpose {permutation = #map} %3 : tensor<1x2x6x3x2048xf32> to tensor<1x2x6x2048x3xf32>
234+
return %4 : tensor<1x2x6x2048x3xf32>
235+
}
236+
237+
// -----
238+
239+
// CHECK: @transpose_reshape_reorder(%[[arg0:.+]]: tensor<12x256x8x8x16x8xf32>)
240+
// CHECK: %[[v0:.+]] = tensorrt.transpose {permutation = #map} %[[arg0]] : tensor<12x256x8x8x16x8xf32> to tensor<12x8x8x16x8x256xf32>
241+
// CHECK: %[[v1:.+]] = tensorrt.reshape %[[v0]] : tensor<12x8x8x16x8x256xf32> to tensor<12x64x128x256xf32>
242+
// CHECK: return %[[v1]]
243+
#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
244+
func.func @transpose_reshape_reorder(%arg0: tensor<12x256x8x8x16x8xf32>) -> tensor<12x64x128x256xf32> {
245+
%0 = tensorrt.reshape %arg0 : tensor<12x256x8x8x16x8xf32> to tensor<12x256x64x128xf32>
246+
%1 = tensorrt.transpose {permutation = #map} %0 : tensor<12x256x64x128xf32> to tensor<12x64x128x256xf32>
247+
return %1 : tensor<12x64x128x256xf32>
248+
}

0 commit comments

Comments
 (0)