Skip to content

Commit 38cfd58

Browse files
[compiler/StableHloExt] Bump max iterations in stablehlo-ext-refine-shapes (#385)
Increases the max number of iterations that the initial dynamic shape refinement pipeline will iterate in order to better reported use cases.
1 parent 49fede3 commit 38cfd58

File tree

2 files changed

+87
-1
lines changed

2 files changed

+87
-1
lines changed

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StableHloExt/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def CanonicalizeShapesPass : Pass<"stablehlo-ext-canonicalize-shapes", "ModuleOp
9595
}];
9696

9797
let options = [
98-
Option<"maxIterations", "max-iterations", "int64_t", "4",
98+
Option<"maxIterations", "max-iterations", "int64_t", "8",
9999
"the maximum number of iterations to run the dynamism simplification and "
100100
"shape refinement if a fixed-point is not reached">
101101
];
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// RUN: mlir-tensorrt-opt %s -split-input-file -stablehlo-ext-refine-shapes | FileCheck %s
2+
3+
func.func @check_type_refinement() -> tensor<?xf32> {
4+
%c = stablehlo.constant dense<[1, 2, 3]> : tensor<3xi32>
5+
%c_0 = stablehlo.constant dense<3> : tensor<i32>
6+
%c_1 = stablehlo.constant dense<1> : tensor<1xi32>
7+
%c_2 = stablehlo.constant dense<3> : tensor<1xi32>
8+
%c_3 = stablehlo.constant dense<1> : tensor<i32>
9+
%c_4 = stablehlo.constant dense<1> : tensor<1xi32>
10+
%c_5 = stablehlo.constant dense<0> : tensor<i32>
11+
%c_6 = stablehlo.constant dense<1> : tensor<i32>
12+
%c_7 = stablehlo.constant dense<0> : tensor<1xi32>
13+
%c_8 = stablehlo.constant dense<1> : tensor<1xi32>
14+
%0 = stablehlo.compare LE, %c_7, %c_8 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
15+
%1 = stablehlo.select %0, %c_7, %c_8 : tensor<1xi1>, tensor<1xi32>
16+
%c_9 = stablehlo.constant dense<1> : tensor<1xi32>
17+
%2 = stablehlo.real_dynamic_slice %c_4, %1, %c_8, %c_9 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
18+
%c_10 = stablehlo.constant dense<> : tensor<0xi32>
19+
%3 = stablehlo.dynamic_reshape %2, %c_10 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
20+
%c_11 = stablehlo.constant dense<-1> : tensor<i32>
21+
%c_12 = stablehlo.constant dense<> : tensor<0xi32>
22+
%4 = stablehlo.compare EQ, %c_12, %c_10 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1>
23+
%5 = stablehlo.select %4, %c_12, %c_12 : tensor<0xi1>, tensor<0xi32>
24+
%6 = stablehlo.dynamic_broadcast_in_dim %3, %5, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
25+
%7 = stablehlo.dynamic_broadcast_in_dim %c_11, %5, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
26+
%8 = stablehlo.add %6, %7 : tensor<i32>
27+
%c_13 = stablehlo.constant dense<0> : tensor<1xi32>
28+
%c_14 = stablehlo.constant dense<1> : tensor<1xi32>
29+
%9 = stablehlo.compare LE, %c_13, %c_14 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
30+
%10 = stablehlo.select %9, %c_13, %c_14 : tensor<1xi1>, tensor<1xi32>
31+
%c_15 = stablehlo.constant dense<1> : tensor<1xi32>
32+
%11 = stablehlo.real_dynamic_slice %c_4, %10, %c_14, %c_15 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
33+
%12 = stablehlo.dynamic_reshape %11, %c_10 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
34+
%13 = stablehlo.compare EQ, %c_12, %c_10 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1>
35+
%14 = stablehlo.select %13, %c_12, %c_12 : tensor<0xi1>, tensor<0xi32>
36+
%15 = stablehlo.dynamic_broadcast_in_dim %12, %14, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
37+
%16 = stablehlo.dynamic_broadcast_in_dim %c_11, %14, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
38+
%17 = stablehlo.add %15, %16 : tensor<i32>
39+
%18 = stablehlo.compare EQ, %c_12, %c_10 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1>
40+
%19 = stablehlo.select %18, %c_12, %c_12 : tensor<0xi1>, tensor<0xi32>
41+
%20 = stablehlo.dynamic_broadcast_in_dim %17, %19, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
42+
%21 = stablehlo.dynamic_broadcast_in_dim %c_6, %19, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
43+
%22 = stablehlo.add %20, %21 : tensor<i32>
44+
%23 = stablehlo.reshape %8 : (tensor<i32>) -> tensor<1xi32>
45+
%24 = stablehlo.reshape %22 : (tensor<i32>) -> tensor<1xi32>
46+
%25 = stablehlo.compare LE, %23, %24 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
47+
%26 = stablehlo.select %25, %23, %24 : tensor<1xi1>, tensor<1xi32>
48+
%c_16 = stablehlo.constant dense<1> : tensor<1xi32>
49+
%27 = stablehlo.real_dynamic_slice %c_2, %26, %24, %c_16 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
50+
%28 = stablehlo.dynamic_reshape %27, %c_10 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
51+
%29 = stablehlo.dynamic_broadcast_in_dim %28, %c_1, dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
52+
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
53+
%30 = stablehlo.dynamic_broadcast_in_dim %cst, %29, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
54+
return %30 : tensor<?xf32>
55+
}
56+
57+
// CHECK-LABEL: func.func @check_type_refinement
58+
// CHECK-DAG: %[[cst:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
59+
// CHECK-DAG: %[[c:.+]] = stablehlo.constant dense<-1> : tensor<i32>
60+
// CHECK-DAG: %[[c_0:.+]] = stablehlo.constant dense<> : tensor<0xi32>
61+
// CHECK-DAG: %[[c_1:.+]] = stablehlo.constant dense<1> : tensor<1xi32>
62+
// CHECK-DAG: %[[c_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32>
63+
// CHECK-DAG: %[[c_3:.+]] = stablehlo.constant dense<1> : tensor<i32>
64+
// CHECK-DAG: %[[c_4:.+]] = stablehlo.constant dense<0> : tensor<1xi32>
65+
// CHECK-DAG: %[[v0:.+]] = stablehlo.real_dynamic_slice %[[c_1]], %[[c_4]], %[[c_1]], %[[c_1]] : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
66+
// CHECK-DAG: %[[v1:.+]] = stablehlo.dynamic_reshape %[[v0]], %[[c_0]] : (tensor<1xi32>, tensor<0xi32>) -> tensor<i32>
67+
// CHECK-DAG: %[[v2:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v1]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
68+
// CHECK-DAG: %[[v3:.+]] = stablehlo.dynamic_broadcast_in_dim %[[c]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
69+
// CHECK-DAG: %[[v4:.+]] = stablehlo.add %[[v2]], %[[v3]] : tensor<i32>
70+
// CHECK-DAG: %[[v5:.+]] = stablehlo.real_dynamic_slice %[[c_1]], %[[c_4]], %[[c_1]], %[[c_1]] : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
71+
// CHECK-DAG: %[[v6:.+]] = stablehlo.dynamic_reshape %[[v5]], %[[c_0]] : (tensor<1xi32>, tensor<0xi32>) -> tensor<i32>
72+
// CHECK-DAG: %[[v7:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v6]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
73+
// CHECK-DAG: %[[v8:.+]] = stablehlo.dynamic_broadcast_in_dim %[[c]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
74+
// CHECK-DAG: %[[v9:.+]] = stablehlo.add %[[v7]], %[[v8]] : tensor<i32>
75+
// CHECK-DAG: %[[v10:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v9]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
76+
// CHECK-DAG: %[[v11:.+]] = stablehlo.dynamic_broadcast_in_dim %[[c_3]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
77+
// CHECK-DAG: %[[v12:.+]] = stablehlo.add %[[v10]], %[[v11]] : tensor<i32>
78+
// CHECK-DAG: %[[v13:.+]] = stablehlo.reshape %[[v4]] : (tensor<i32>) -> tensor<1xi32>
79+
// CHECK-DAG: %[[v14:.+]] = stablehlo.reshape %[[v12]] : (tensor<i32>) -> tensor<1xi32>
80+
// CHECK-DAG: %[[v15:.+]] = stablehlo.compare LE, %[[v13]], %[[v14]] : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
81+
// CHECK-DAG: %[[v16:.+]] = stablehlo.select %[[v15]], %[[v13]], %[[v14]] : tensor<1xi1>, tensor<1xi32>
82+
// CHECK-DAG: %[[v17:.+]] = stablehlo.real_dynamic_slice %[[c_2]], %[[v16]], %[[v14]], %[[c_1]] : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
83+
// CHECK-DAG: %[[v18:.+]] = stablehlo.dynamic_reshape %[[v17]], %[[c_0]] : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
84+
// CHECK-DAG: %[[v19:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v18]], %[[c_1]], dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
85+
// CHECK-DAG: %[[v20:.+]] = stablehlo.dynamic_broadcast_in_dim %[[cst]], %[[v19]], dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
86+
// CHECK-DAG: return %[[v20]] : tensor<?xf32>

0 commit comments

Comments
 (0)