@@ -453,4 +453,35 @@ func.func @push_up_transpose_elementwise_reshape_transpose_neg(%arg0: tensor<10x
453453// CHECK-NEXT: %[[v1:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[arg1]]
454454// CHECK-NEXT: %[[v2:.+]] = tensorrt.element_wise <kDIV>(%[[v1]], %[[v0]] : {{.*}})
455455// CHECK-NEXT: %[[v3:.+]] = tensorrt.transpose {permutation = #[[$map1]]} %[[v2]]
456- // CHECK-NEXT: return %[[v3]]
456+ // CHECK-NEXT: return %[[v3]]
457+
458+ // -----
459+
460+ #map = affine_map <(d0 , d1 , d2 ) -> (d1 , d0 , d2 )>
461+ func.func @transpose_rearrange_loop (%arg0: tensor <512 x7 x24 xf32 >, %arg1: tensor <512 x7 x7 xf32 >) -> tensor <7 x512 x24 xf32 > {
462+ %0 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation <kNONE >, op1 = #tensorrt.matrix_operation <kNONE >} ins (%arg1 , %arg0 : tensor <512 x7 x7 xf32 >, tensor <512 x7 x24 xf32 >) -> tensor <512 x7 x24 xf32 >
463+ %1 = tensorrt.transpose {permutation = #map } %0 : tensor <512 x7 x24 xf32 > to tensor <7 x512 x24 xf32 >
464+ return %1 : tensor <7 x512 x24 xf32 >
465+ }
466+
467+ // CHECK: @transpose_rearrange_loop(%[[arg0:.+]]: tensor<512x7x24xf32>, %[[arg1:.+]]: tensor<512x7x7xf32>)
468+ // CHECK: %[[v0:.+]] = tensorrt.einsum {equation = [[equation:.+]]} ins(%[[arg1]], %[[arg0]] : tensor<512x7x7xf32>, tensor<512x7x24xf32>)
469+ // CHECK: return %[[v0]]
470+
471+ // -----
472+
473+ #map = affine_map <(d0 , d1 , d2 , d3 , d4 ) -> (d0 , d1 , d3 , d4 , d2 )>
474+ func.func @element_wise_with_two_constants () -> tensor <1 x8 x20 x35 x192 xf32 > {
475+ %cst_f32 = tensorrt.constant dense_resource <__elided__ > : tensor <1 x8 x192 x20 x35 xf32 >
476+ %cst_f32_0 = tensorrt.constant dense_resource <__elided__ > : tensor <1 x8 x20 x35 x192 xf32 >
477+ %1 = tensorrt.transpose {permutation = #map } %cst_f32 : tensor <1 x8 x192 x20 x35 xf32 > to tensor <1 x8 x20 x35 x192 xf32 >
478+ %2 = tensorrt.element_wise <kSUM >(%1 , %cst_f32_0 : tensor <1 x8 x20 x35 x192 xf32 >, tensor <1 x8 x20 x35 x192 xf32 >) -> tensor <1 x8 x20 x35 x192 xf32 >
479+ return %2 : tensor <1 x8 x20 x35 x192 xf32 >
480+ }
481+
482+ // CHECK: @element_wise_with_two_constants()
483+ // CHECK: %[[const0:.+]] = tensorrt.constant dense_resource<__elided__> : tensor<1x8x192x20x35xf32>
484+ // CHECK: %[[const1:.+]] = tensorrt.constant dense_resource<__elided__> : tensor<1x8x20x35x192xf32>
485+ // CHECK: %[[v0:.+]] = tensorrt.transpose {permutation = #map} %[[const0]]
486+ // CHECK: %[[v1:.+]] = tensorrt.element_wise <kSUM>(%[[v0]], %[[const1]]
487+ // CHECK: return %[[v1]]
0 commit comments