@@ -202,3 +202,22 @@ func.func @elementwise_reshape(%arg0: tensor<12x3x3xf32>, %arg1: tensor<12xf32>)
202202 %2 = tensorrt.element_wise <kDIV >(%0 , %1 : tensor <12 x3 x3 xf32 >, tensor <12 x1 x1 xf32 >) -> tensor <12 x3 x3 xf32 >
203203 return %2 : tensor <12 x3 x3 xf32 >
204204}
205+
206+ // -----
207+
208+ // CHECK: @matmul_argument_swap(%[[arg0:.+]]: tensor<1x2x4x3x561xf32>, %[[arg1:.+]]: tensor<1x2x4x3x3xf32>) -> tensor<1x2x4x561x3xf32>
209+ // CHECK-DAG: %[[v0:.+]] = tensorrt.collapse_rank %[[arg1]] : tensor<1x2x4x3x3xf32> to tensor<2x4x3x3xf32>
210+ // CHECK-DAG: %[[v1:.+]] = tensorrt.transpose {permutation = #map} %[[arg0]] : tensor<1x2x4x3x561xf32> to tensor<2x4x561x3x1xf32>
211+ // CHECK-DAG: %[[v2:.+]] = tensorrt.collapse_rank %[[v1]] : tensor<2x4x561x3x1xf32> to tensor<2x4x561x3xf32>
212+ // CHECK: %[[v3:.+]] = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation<kNONE>, op1 = #tensorrt.matrix_operation<kTRANSPOSE>} ins(%[[v2]], %[[v0]] : tensor<2x4x561x3xf32>, tensor<2x4x3x3xf32>) -> tensor<2x4x561x3xf32>
213+ // CHECK: %[[v4:.+]] = tensorrt.expand_rank %[[v3]] : tensor<2x4x561x3xf32> to tensor<1x2x4x561x3xf32>
214+ // CHECK: return %[[v4]]
// Test input for the matmul argument-swap case.
// Both arguments are reshaped to a batch of 8 (1x2x4 -> 8), multiplied as
// (8x3x3) x (8x3x561) with no per-operand transpose, and the product is reshaped
// back to rank 5. The final transpose uses #map, which exchanges the last two
// dimensions (d3 <-> d4), turning 1x2x4x3x561 into 1x2x4x561x3.
// The expectations preceding this function require a single matrix_multiply whose
// operands are exchanged: the value derived from %arg0 comes first and the value
// derived from %arg1 comes second with kTRANSPOSE applied.
215+ #map = affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d0 , d1 , d2 , d4 , d3 )>
216+ func.func @matmul_argument_swap (%arg0: tensor <1 x2 x4 x3 x561 xf32 >, %arg1: tensor <1 x2 x4 x3 x3 xf32 >) -> tensor <1 x2 x4 x561 x3 xf32 > {
217+ %0 = tensorrt.reshape %arg1 : tensor <1 x2 x4 x3 x3 xf32 > to tensor <8 x3 x3 xf32 >
218+ %1 = tensorrt.reshape %arg0 : tensor <1 x2 x4 x3 x561 xf32 > to tensor <8 x3 x561 xf32 >
219+ %2 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation <kNONE >, op1 = #tensorrt.matrix_operation <kNONE >} ins (%0 , %1 : tensor <8 x3 x3 xf32 >, tensor <8 x3 x561 xf32 >) -> tensor <8 x3 x561 xf32 >
220+ %3 = tensorrt.reshape %2 : tensor <8 x3 x561 xf32 > to tensor <1 x2 x4 x3 x561 xf32 >
221+ %4 = tensorrt.transpose {permutation = #map } %3 : tensor <1 x2 x4 x3 x561 xf32 > to tensor <1 x2 x4 x561 x3 xf32 >
222+ return %4 : tensor <1 x2 x4 x561 x3 xf32 >
223+ }
0 commit comments