@@ -221,3 +221,28 @@ func.func @matmul_argument_swap(%arg0: tensor<1x2x4x3x561xf32>, %arg1: tensor<1x
221221 %4 = tensorrt.transpose {permutation = #map } %3 : tensor <1 x2 x4 x3 x561 xf32 > to tensor <1 x2 x4 x561 x3 xf32 >
222222 return %4 : tensor <1 x2 x4 x561 x3 xf32 >
223223}
224+
225+ // -----
226+
227+ #map = affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d0 , d1 , d2 , d4 , d3 )>
228+ func.func @olympus_v0_2 (%arg0: tensor <1 x2 x6 x3 x3 xf32 >, %arg1: tensor <1 x2 x6 x3 x2048 xf32 >) -> tensor <1 x2 x6 x2048 x3 xf32 > {
229+ %0 = tensorrt.reshape %arg0 : tensor <1 x2 x6 x3 x3 xf32 > to tensor <12 x3 x3 xf32 >
230+ %1 = tensorrt.reshape %arg1 : tensor <1 x2 x6 x3 x2048 xf32 > to tensor <12 x3 x2048 xf32 >
231+ %2 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation <kNONE >, op1 = #tensorrt.matrix_operation <kNONE >} ins (%0 , %1 : tensor <12 x3 x3 xf32 >, tensor <12 x3 x2048 xf32 >) -> tensor <12 x3 x2048 xf32 >
232+ %3 = tensorrt.reshape %2 : tensor <12 x3 x2048 xf32 > to tensor <1 x2 x6 x3 x2048 xf32 >
233+ %4 = tensorrt.transpose {permutation = #map } %3 : tensor <1 x2 x6 x3 x2048 xf32 > to tensor <1 x2 x6 x2048 x3 xf32 >
234+ return %4 : tensor <1 x2 x6 x2048 x3 xf32 >
235+ }
236+
237+ // -----
238+
239+ // CHECK: @transpose_reshape_reorder(%[[arg0:.+]]: tensor<12x256x8x8x16x8xf32>)
240+ // CHECK: %[[v0:.+]] = tensorrt.transpose {permutation = #map} %[[arg0]] : tensor<12x256x8x8x16x8xf32> to tensor<12x8x8x16x8x256xf32>
241+ // CHECK: %[[v1:.+]] = tensorrt.reshape %0 : tensor<12x8x8x16x8x256xf32> to tensor<12x64x128x256xf32>
242+ // CHECK: return %[[v1]]
243+ #map = affine_map <(d0 , d1 , d2 , d3 ) -> (d0 , d2 , d3 , d1 )>
244+ func.func @transpose_reshape_reorder (%arg0: tensor <12 x256 x8 x8 x16 x8 xf32 >) -> tensor <12 x64 x128 x256 xf32 > {
245+ %0 = tensorrt.reshape %arg0 : tensor <12 x256 x8 x8 x16 x8 xf32 > to tensor <12 x256 x64 x128 xf32 >
246+ %1 = tensorrt.transpose {permutation = #map } %0 : tensor <12 x256 x64 x128 xf32 > to tensor <12 x64 x128 x256 xf32 >
247+ return %1 : tensor <12 x64 x128 x256 xf32 >
248+ }
0 commit comments