Commit 903e995

Author: Matthew Francis-Landau (committed)
einsum -> matmul: reorder arguments if the pattern matches
1 parent 1da7863 commit 903e995

File tree: 2 files changed, +32 / -2 lines


mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/TransposeReshapeElimination.cpp

Lines changed: 13 additions & 2 deletions
@@ -2043,6 +2043,8 @@ class EinsumToMatrixMultiply : public OpRewritePattern<tensorrt::EinsumOp> {
     char multipliedAxis = 0;
     std::string batchAxes = "";
 
+    Value inputs[2] = {op.getInputs()[0], op.getInputs()[1]};
+
     for (size_t i = 0; i < equation.rhs.size(); i++) {
       char c = equation.rhs[i];
       size_t lhsPos = equation.lhsParts[0].find(c);
@@ -2085,6 +2087,15 @@ class EinsumToMatrixMultiply : public OpRewritePattern<tensorrt::EinsumOp> {
       return failure(/* no multiplied axis */);
     }
 
+    if (matrixAxis[0] != 0 && matrixAxis[1] != 0 &&
+        equation.rhs.find(matrixAxis[0]) > equation.rhs.find(matrixAxis[1])) {
+      // The order of the arguments needs to be swapped, as a matrix multiply
+      // requires that the first operand's matrix axis appear first in the
+      // result.
+      std::swap(equation.lhsParts[0], equation.lhsParts[1]);
+      std::swap(matrixAxis[0], matrixAxis[1]);
+      std::swap(inputs[0], inputs[1]);
+    }
+
     MatrixOperation opType[2];
     for (int i = 0; i < 2; i++) {
       std::string e = batchAxes;
@@ -2121,8 +2132,8 @@ class EinsumToMatrixMultiply : public OpRewritePattern<tensorrt::EinsumOp> {
     }
 
     rewriter.replaceOpWithNewOp<tensorrt::MatrixMultiplyOp>(
-        op, op.getResult().getType(), op.getInputs()[0], op.getInputs()[1],
-        opType[0], opType[1]);
+        op, op.getResult().getType(), inputs[0], inputs[1], opType[0],
+        opType[1]);
 
     return success();
   }
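
To see the new swap condition in isolation: the pattern can only lower an einsum to a matrix multiply when the first operand's free ("matrix") axis is the one that appears first in the result subscripts; otherwise it now exchanges the two operands along with their subscripts and matrix axes. Below is a minimal, self-contained sketch of that decision, using hypothetical names (EinsumEquation, shouldSwapOperands) that are not part of the pass's actual API.

// Minimal sketch of the swap decision (hypothetical names, not the pass's API).
#include <array>
#include <cassert>
#include <string>
#include <utility>

struct EinsumEquation {
  std::array<std::string, 2> lhsParts; // subscripts of the two inputs
  std::string rhs;                     // subscripts of the result
};

// The matrix multiply emitted by the pattern needs the operand whose free
// ("matrix") axis appears first in the result to be its first operand.
static bool shouldSwapOperands(const EinsumEquation &eq,
                               const std::array<char, 2> &matrixAxis) {
  return matrixAxis[0] != 0 && matrixAxis[1] != 0 &&
         eq.rhs.find(matrixAxis[0]) > eq.rhs.find(matrixAxis[1]);
}

int main() {
  // "bkm,bkn->bnm": input 0 owns free axis 'm', input 1 owns 'n', but 'n'
  // comes first in the result, so the inputs must be exchanged. After the
  // swap the equation reads "bkn,bkm->bnm", which maps onto a batched
  // matrix multiply (with the appropriate transpose flag on each operand).
  EinsumEquation eq;
  eq.lhsParts = {"bkm", "bkn"};
  eq.rhs = "bnm";
  std::array<char, 2> matrixAxis = {'m', 'n'};

  if (shouldSwapOperands(eq, matrixAxis)) {
    std::swap(eq.lhsParts[0], eq.lhsParts[1]);
    std::swap(matrixAxis[0], matrixAxis[1]);
    // The real pattern also swaps the corresponding SSA operands here.
  }
  assert(eq.lhsParts[0] == "bkn" && matrixAxis[0] == 'n');
  return 0;
}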

mlir-tensorrt/tensorrt/test/Dialect/TensorRT/transpose-reshape-elimination.mlir

Lines changed: 19 additions & 0 deletions
@@ -202,3 +202,22 @@ func.func @elementwise_reshape(%arg0: tensor<12x3x3xf32>, %arg1: tensor<12xf32>)
   %2 = tensorrt.element_wise <kDIV>(%0, %1 : tensor<12x3x3xf32>, tensor<12x1x1xf32>) -> tensor<12x3x3xf32>
   return %2 : tensor<12x3x3xf32>
 }
+
+// -----
+
+// CHECK: @matmul_argument_swap(%[[arg0:.+]]: tensor<1x2x4x3x561xf32>, %[[arg1:.+]]: tensor<1x2x4x3x3xf32>) -> tensor<1x2x4x561x3xf32>
+// CHECK-DAG: %[[v0:.+]] = tensorrt.collapse_rank %[[arg1]] : tensor<1x2x4x3x3xf32> to tensor<2x4x3x3xf32>
+// CHECK-DAG: %[[v1:.+]] = tensorrt.transpose {permutation = #map} %[[arg0]] : tensor<1x2x4x3x561xf32> to tensor<2x4x561x3x1xf32>
+// CHECK-DAG: %[[v2:.+]] = tensorrt.collapse_rank %[[v1]] : tensor<2x4x561x3x1xf32> to tensor<2x4x561x3xf32>
+// CHECK: %[[v3:.+]] = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation<kNONE>, op1 = #tensorrt.matrix_operation<kTRANSPOSE>} ins(%2, %0 : tensor<2x4x561x3xf32>, tensor<2x4x3x3xf32>) -> tensor<2x4x561x3xf32>
+// CHECK: %[[v4:.+]] = tensorrt.expand_rank %[[v3]] : tensor<2x4x561x3xf32> to tensor<1x2x4x561x3xf32>
+// CHECK: return %[[v4]]
+#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d4, d3)>
+func.func @matmul_argument_swap(%arg0: tensor<1x2x4x3x561xf32>, %arg1: tensor<1x2x4x3x3xf32>) -> tensor<1x2x4x561x3xf32> {
+  %0 = tensorrt.reshape %arg1 : tensor<1x2x4x3x3xf32> to tensor<8x3x3xf32>
+  %1 = tensorrt.reshape %arg0 : tensor<1x2x4x3x561xf32> to tensor<8x3x561xf32>
+  %2 = tensorrt.matrix_multiply {op0 = #tensorrt.matrix_operation<kNONE>, op1 = #tensorrt.matrix_operation<kNONE>} ins(%0, %1 : tensor<8x3x3xf32>, tensor<8x3x561xf32>) -> tensor<8x3x561xf32>
+  %3 = tensorrt.reshape %2 : tensor<8x3x561xf32> to tensor<1x2x4x3x561xf32>
+  %4 = tensorrt.transpose {permutation = #map} %3 : tensor<1x2x4x3x561xf32> to tensor<1x2x4x561x3xf32>
+  return %4 : tensor<1x2x4x561x3xf32>
+}
