Skip to content

Commit c3b4d85

Browse files
author
Matthew Francis-Landau
committed
Fix an infinite-looping issue in the transpose-elimination pass, which the transpose-reshape-elimination pass builds on
1 parent 5486d22 commit c3b4d85

File tree

3 files changed

+43
-1
lines changed

3 files changed

+43
-1
lines changed

mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeReshapeElimination.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,19 @@ static TransposeOp getLowestTransposeCost(ElementWiseOp consumer,
6464
int64_t cost1 = memoryCost(consumer.getType()) + memoryCost(op2.getType());
6565
int64_t cost2 = memoryCost(consumer.getType()) + memoryCost(op1.getType());
6666
LLVM_DEBUG(DBGS() << "cost1=" << cost1 << ", cost2=" << cost2 << "\n");
67+
if (cost1 == 0 && cost2 == 0)
68+
return {};
6769
return cost1 <= cost2 ? op1 : op2;
6870
}
6971

7072
static std::pair<TransposeOp, TransposeOp>
7173
getTransposeProducers(ElementWiseOp op) {
7274
auto producer1 = op.getInput1().getDefiningOp<TransposeOp>();
7375
auto producer2 = op.getInput2().getDefiningOp<TransposeOp>();
76+
if (producer1 && producer1.getInput().getDefiningOp<ConstantOp>())
77+
producer1 = {};
78+
if (producer2 && producer2.getInput().getDefiningOp<ConstantOp>())
79+
producer2 = {};
7480
return std::make_pair(producer1, producer2);
7581
}
7682

@@ -760,6 +766,7 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
760766
if (newEinsumRhs == equation.rhs)
761767
return failure(); // no change
762768

769+
equation.rhs = newEinsumRhs;
763770
std::string newEinsumEquation = equation.generateEquation();
764771

765772
auto newEinsum = rewriter.create<tensorrt::EinsumOp>(
@@ -771,6 +778,7 @@ class EinsumPushDownTranspose : public OpRewritePattern<tensorrt::EinsumOp> {
771778
AffineMap::getPermutationMap(outputPerm, op.getLoc().getContext()));
772779

773780
rewriter.replaceOp(op, newTranspose.getResult());
781+
774782
return success();
775783
}
776784
};

mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-elimination.mlir

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,4 +453,35 @@ func.func @push_up_transpose_elementwise_reshape_transpose_neg(%arg0: tensor<10x
453453
// CHECK-NEXT: %[[v1:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[arg1]]
454454
// CHECK-NEXT: %[[v2:.+]] = tensorrt.element_wise <kDIV>(%[[v1]], %[[v0]] : {{.*}})
455455
// CHECK-NEXT: %[[v3:.+]] = tensorrt.transpose {permutation = #[[$map1]]} %[[v2]]
456-
// CHECK-NEXT: return %[[v3]]
456+
// CHECK-NEXT: return %[[v3]]
457+
458+
// -----
459+
460+
#map = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
461+
func.func @transpose_rearrange_loop(%arg0: tensor<512x7x24xf32>, %arg1: tensor<512x7x7xf32>) -> tensor<7x512x24xf32> {
462+
%0 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation<kNONE>, op1 = #tensorrt.matrix_operation<kNONE>} ins(%arg1, %arg0 : tensor<512x7x7xf32>, tensor<512x7x24xf32>) -> tensor<512x7x24xf32>
463+
%1 = tensorrt.transpose {permutation = #map} %0 : tensor<512x7x24xf32> to tensor<7x512x24xf32>
464+
return %1 : tensor<7x512x24xf32>
465+
}
466+
467+
// CHECK: @transpose_rearrange_loop(%[[arg0:.+]]: tensor<512x7x24xf32>, %[[arg1:.+]]: tensor<512x7x7xf32>)
468+
// CHECK: %[[v0:.+]] = tensorrt.einsum {equation = [[equation:.+]]} ins(%[[arg1]], %[[arg0]] : tensor<512x7x7xf32>, tensor<512x7x24xf32>)
469+
// CHECK: return %[[v0]]
470+
471+
// -----
472+
473+
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4, d2)>
474+
func.func @element_wise_with_two_constants() -> tensor<1x8x20x35x192xf32> {
475+
%cst_f32 = tensorrt.constant dense_resource<__elided__> : tensor<1x8x192x20x35xf32>
476+
%cst_f32_0 = tensorrt.constant dense_resource<__elided__> : tensor<1x8x20x35x192xf32>
477+
%1 = tensorrt.transpose {permutation = #map} %cst_f32 : tensor<1x8x192x20x35xf32> to tensor<1x8x20x35x192xf32>
478+
%2 = tensorrt.element_wise <kSUM>(%1, %cst_f32_0 : tensor<1x8x20x35x192xf32>, tensor<1x8x20x35x192xf32>) -> tensor<1x8x20x35x192xf32>
479+
return %2 : tensor<1x8x20x35x192xf32>
480+
}
481+
482+
// CHECK: @element_wise_with_two_constants()
483+
// CHECK: %[[const0:.+]] = tensorrt.constant dense_resource<__elided__> : tensor<1x8x192x20x35xf32>
484+
// CHECK: %[[const1:.+]] = tensorrt.constant dense_resource<__elided__> : tensor<1x8x20x35x192xf32>
485+
// CHECK: %[[v0:.+]] = tensorrt.transpose {permutation = #map} %[[const0]]
486+
// CHECK: %[[v1:.+]] = tensorrt.element_wise <kSUM>(%[[v0]], %[[const1]]
487+
// CHECK: return %[[v1]]

mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-reshape-elimination.mlir

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ func.func @reshape_with_one(%arg0: tensor<2x3x4x5xf32>) -> tensor<2x3x4x6xf32> {
177177

178178
// -----
179179

180+
// CHECK: matmul_eliminate_reshape_lhs_2(%[[arg0:.+]]: tensor<1x2x3x4x5x6xf16>, %[[arg1:.+]]: tensor<1x2x6x8xf16>)
181+
// CHECK: %[[v0:.+]] = tensorrt.einsum {equation = [[equation:.+]]} ins(%[[arg0]], %[[arg1]] : tensor<1x2x3x4x5x6xf16>, tensor<1x2x6x8xf16>) -> tensor<1x2x3x4x5x8xf16>
182+
// CHECK: return %[[v0]]
180183
func.func @matmul_eliminate_reshape_lhs_2(%arg0: tensor<1x2x3x4x5x6xf16>, %arg1: tensor<1x2x6x8xf16>) -> tensor<1x2x3x4x5x8xf16>{
181184
%0 = tensorrt.reshape %arg0 : tensor<1x2x3x4x5x6xf16> to tensor<1x2x60x6xf16>
182185
%1 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation<kNONE>, op1 = #tensorrt.matrix_operation<kNONE>}

0 commit comments

Comments
 (0)