1+ // RUN: mlir-tensorrt-opt %s -split-input-file -stablehlo-ext-refine-shapes | FileCheck %s
2+
3+ func.func @check_type_refinement () -> tensor <?xf32 > {
4+ %c = stablehlo.constant dense <[1 , 2 , 3 ]> : tensor <3 xi32 >
5+ %c_0 = stablehlo.constant dense <3 > : tensor <i32 >
6+ %c_1 = stablehlo.constant dense <1 > : tensor <1 xi32 >
7+ %c_2 = stablehlo.constant dense <3 > : tensor <1 xi32 >
8+ %c_3 = stablehlo.constant dense <1 > : tensor <i32 >
9+ %c_4 = stablehlo.constant dense <1 > : tensor <1 xi32 >
10+ %c_5 = stablehlo.constant dense <0 > : tensor <i32 >
11+ %c_6 = stablehlo.constant dense <1 > : tensor <i32 >
12+ %c_7 = stablehlo.constant dense <0 > : tensor <1 xi32 >
13+ %c_8 = stablehlo.constant dense <1 > : tensor <1 xi32 >
14+ %0 = stablehlo.compare LE , %c_7 , %c_8 : (tensor <1 xi32 >, tensor <1 xi32 >) -> tensor <1 xi1 >
15+ %1 = stablehlo.select %0 , %c_7 , %c_8 : tensor <1 xi1 >, tensor <1 xi32 >
16+ %c_9 = stablehlo.constant dense <1 > : tensor <1 xi32 >
17+ %2 = stablehlo.real_dynamic_slice %c_4 , %1 , %c_8 , %c_9 : (tensor <1 xi32 >, tensor <1 xi32 >, tensor <1 xi32 >, tensor <1 xi32 >) -> tensor <?xi32 >
18+ %c_10 = stablehlo.constant dense <> : tensor <0 xi32 >
19+ %3 = stablehlo.dynamic_reshape %2 , %c_10 : (tensor <?xi32 >, tensor <0 xi32 >) -> tensor <i32 >
20+ %c_11 = stablehlo.constant dense <-1 > : tensor <i32 >
21+ %c_12 = stablehlo.constant dense <> : tensor <0 xi32 >
22+ %4 = stablehlo.compare EQ , %c_12 , %c_10 : (tensor <0 xi32 >, tensor <0 xi32 >) -> tensor <0 xi1 >
23+ %5 = stablehlo.select %4 , %c_12 , %c_12 : tensor <0 xi1 >, tensor <0 xi32 >
24+ %6 = stablehlo.dynamic_broadcast_in_dim %3 , %5 , dims = [] : (tensor <i32 >, tensor <0 xi32 >) -> tensor <i32 >
25+ %7 = stablehlo.dynamic_broadcast_in_dim %c_11 , %5 , dims = [] : (tensor <i32 >, tensor <0 xi32 >) -> tensor <i32 >
26+ %8 = stablehlo.add %6 , %7 : tensor <i32 >
27+ %c_13 = stablehlo.constant dense <0 > : tensor <1 xi32 >
28+ %c_14 = stablehlo.constant dense <1 > : tensor <1 xi32 >
29+ %9 = stablehlo.compare LE , %c_13 , %c_14 : (tensor <1 xi32 >, tensor <1 xi32 >) -> tensor <1 xi1 >
30+ %10 = stablehlo.select %9 , %c_13 , %c_14 : tensor <1 xi1 >, tensor <1 xi32 >
31+ %c_15 = stablehlo.constant dense <1 > : tensor <1 xi32 >
32+ %11 = stablehlo.real_dynamic_slice %c_4 , %10 , %c_14 , %c_15 : (tensor <1 xi32 >, tensor <1 xi32 >, tensor <1 xi32 >, tensor <1 xi32 >) -> tensor <?xi32 >
33+ %12 = stablehlo.dynamic_reshape %11 , %c_10 : (tensor <?xi32 >, tensor <0 xi32 >) -> tensor <i32 >
34+ %13 = stablehlo.compare EQ , %c_12 , %c_10 : (tensor <0 xi32 >, tensor <0 xi32 >) -> tensor <0 xi1 >
35+ %14 = stablehlo.select %13 , %c_12 , %c_12 : tensor <0 xi1 >, tensor <0 xi32 >
36+ %15 = stablehlo.dynamic_broadcast_in_dim %12 , %14 , dims = [] : (tensor <i32 >, tensor <0 xi32 >) -> tensor <i32 >
37+ %16 = stablehlo.dynamic_broadcast_in_dim %c_11 , %14 , dims = [] : (tensor <i32 >, tensor <0 xi32 >) -> tensor <i32 >
38+ %17 = stablehlo.add %15 , %16 : tensor <i32 >
39+ %18 = stablehlo.compare EQ , %c_12 , %c_10 : (tensor <0 xi32 >, tensor <0 xi32 >) -> tensor <0 xi1 >
40+ %19 = stablehlo.select %18 , %c_12 , %c_12 : tensor <0 xi1 >, tensor <0 xi32 >
41+ %20 = stablehlo.dynamic_broadcast_in_dim %17 , %19 , dims = [] : (tensor <i32 >, tensor <0 xi32 >) -> tensor <i32 >
42+ %21 = stablehlo.dynamic_broadcast_in_dim %c_6 , %19 , dims = [] : (tensor <i32 >, tensor <0 xi32 >) -> tensor <i32 >
43+ %22 = stablehlo.add %20 , %21 : tensor <i32 >
44+ %23 = stablehlo.reshape %8 : (tensor <i32 >) -> tensor <1 xi32 >
45+ %24 = stablehlo.reshape %22 : (tensor <i32 >) -> tensor <1 xi32 >
46+ %25 = stablehlo.compare LE , %23 , %24 : (tensor <1 xi32 >, tensor <1 xi32 >) -> tensor <1 xi1 >
47+ %26 = stablehlo.select %25 , %23 , %24 : tensor <1 xi1 >, tensor <1 xi32 >
48+ %c_16 = stablehlo.constant dense <1 > : tensor <1 xi32 >
49+ %27 = stablehlo.real_dynamic_slice %c_2 , %26 , %24 , %c_16 : (tensor <1 xi32 >, tensor <1 xi32 >, tensor <1 xi32 >, tensor <1 xi32 >) -> tensor <?xi32 >
50+ %28 = stablehlo.dynamic_reshape %27 , %c_10 : (tensor <?xi32 >, tensor <0 xi32 >) -> tensor <i32 >
51+ %29 = stablehlo.dynamic_broadcast_in_dim %28 , %c_1 , dims = [] : (tensor <i32 >, tensor <1 xi32 >) -> tensor <1 xi32 >
52+ %cst = stablehlo.constant dense <1.000000e+00 > : tensor <f32 >
53+ %30 = stablehlo.dynamic_broadcast_in_dim %cst , %29 , dims = [] : (tensor <f32 >, tensor <1 xi32 >) -> tensor <?xf32 >
54+ return %30 : tensor <?xf32 >
55+ }
56+
57+ // CHECK-LABEL: func.func @check_type_refinement
58+ // CHECK-DAG: %[[cst:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
59+ // CHECK-DAG: %[[c:.+]] = stablehlo.constant dense<-1> : tensor<i32>
60+ // CHECK-DAG: %[[c_0:.+]] = stablehlo.constant dense<> : tensor<0xi32>
61+ // CHECK-DAG: %[[c_1:.+]] = stablehlo.constant dense<1> : tensor<1xi32>
62+ // CHECK-DAG: %[[c_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32>
63+ // CHECK-DAG: %[[c_3:.+]] = stablehlo.constant dense<1> : tensor<i32>
64+ // CHECK-DAG: %[[c_4:.+]] = stablehlo.constant dense<0> : tensor<1xi32>
65+ // CHECK-DAG: %[[v0:.+]] = stablehlo.real_dynamic_slice %[[c_1]], %[[c_4]], %[[c_1]], %[[c_1]] : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
66+ // CHECK-DAG: %[[v1:.+]] = stablehlo.dynamic_reshape %[[v0]], %[[c_0]] : (tensor<1xi32>, tensor<0xi32>) -> tensor<i32>
67+ // CHECK-DAG: %[[v2:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v1]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
68+ // CHECK-DAG: %[[v3:.+]] = stablehlo.dynamic_broadcast_in_dim %[[c]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
69+ // CHECK-DAG: %[[v4:.+]] = stablehlo.add %[[v2]], %[[v3]] : tensor<i32>
70+ // CHECK-DAG: %[[v5:.+]] = stablehlo.real_dynamic_slice %[[c_1]], %[[c_4]], %[[c_1]], %[[c_1]] : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
71+ // CHECK-DAG: %[[v6:.+]] = stablehlo.dynamic_reshape %[[v5]], %[[c_0]] : (tensor<1xi32>, tensor<0xi32>) -> tensor<i32>
72+ // CHECK-DAG: %[[v7:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v6]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
73+ // CHECK-DAG: %[[v8:.+]] = stablehlo.dynamic_broadcast_in_dim %[[c]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
74+ // CHECK-DAG: %[[v9:.+]] = stablehlo.add %[[v7]], %[[v8]] : tensor<i32>
75+ // CHECK-DAG: %[[v10:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v9]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
76+ // CHECK-DAG: %[[v11:.+]] = stablehlo.dynamic_broadcast_in_dim %[[c_3]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
77+ // CHECK-DAG: %[[v12:.+]] = stablehlo.add %[[v10]], %[[v11]] : tensor<i32>
78+ // CHECK-DAG: %[[v13:.+]] = stablehlo.reshape %[[v4]] : (tensor<i32>) -> tensor<1xi32>
79+ // CHECK-DAG: %[[v14:.+]] = stablehlo.reshape %[[v12]] : (tensor<i32>) -> tensor<1xi32>
80+ // CHECK-DAG: %[[v15:.+]] = stablehlo.compare LE, %[[v13]], %[[v14]] : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
81+ // CHECK-DAG: %[[v16:.+]] = stablehlo.select %[[v15]], %[[v13]], %[[v14]] : tensor<1xi1>, tensor<1xi32>
82+ // CHECK-DAG: %[[v17:.+]] = stablehlo.real_dynamic_slice %[[c_2]], %[[v16]], %[[v14]], %[[c_1]] : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
83+ // CHECK-DAG: %[[v18:.+]] = stablehlo.dynamic_reshape %[[v17]], %[[c_0]] : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
84+ // CHECK-DAG: %[[v19:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v18]], %[[c_1]], dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
85+ // CHECK-DAG: %[[v20:.+]] = stablehlo.dynamic_broadcast_in_dim %[[cst]], %[[v19]], dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
86+ // CHECK-DAG: return %[[v20]] : tensor<?xf32>
0 commit comments