
Commit 9c9893c

Merge remote-tracking branch 'origin/main' into sam2
2 parents 18cda08 + d37e4f8 commit 9c9893c

38 files changed: +618, -177 lines

mlir-tensorrt/compiler/include/mlir-tensorrt/Dialect/StableHloExt/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ def CanonicalizeShapesPass : Pass<"stablehlo-ext-canonicalize-shapes", "ModuleOp
   }];

   let options = [
-    Option<"maxIterations", "max-iterations", "int64_t", "4",
+    Option<"maxIterations", "max-iterations", "int64_t", "8",
            "the maximum number of iterations to run the dynamism simplification and "
            "shape refinement if a fixed-point is not reached">
   ];
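The new default of 8 can still be overridden per invocation via the standard MLIR pass-option syntax; a minimal sketch (input.mlir is a hypothetical stand-in for the module being processed):

  mlir-tensorrt-opt input.mlir -stablehlo-ext-canonicalize-shapes=max-iterations=8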

mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/MaterializeShapeCalculations.cpp

Lines changed: 0 additions & 1 deletion
@@ -361,7 +361,6 @@ struct SimplifyExtractOfReshape : public OpRewritePattern<tensor::ExtractOp> {
     if (!reshapeOp)
       return failure();

-    // Skip if either shape has dynamic dimensions
     if (!reshapeOp.getOperand().getType().hasStaticShape())
       return failure();
367366

mlir-tensorrt/compiler/lib/Dialect/StableHloExt/Transforms/ConstantFolding.cpp

Lines changed: 0 additions & 1 deletion
@@ -1065,7 +1065,6 @@ struct AbsorbTensorCastProducer : public RewritePattern {
 };
 } // namespace

-
 /// Populates patterns that are temporarily reproduced here from upstream
 /// commits we have not yet integrated.
 static void populateFutureUpstreamPatterns(RewritePatternSet &patterns);
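For context on this declaration: a populate-style function like this is normally consumed inside a pass body together with MLIR's greedy rewrite driver. A minimal sketch, assuming the usual driver from mlir/Transforms/GreedyPatternRewriteDriver.h (the surrounding pass is hypothetical, not the file's actual pass):

  void runOnOperation() override {
    // Collect the locally-reproduced upstream patterns.
    RewritePatternSet patterns(&getContext());
    populateFutureUpstreamPatterns(patterns);
    // Apply them until a fixed point (or the driver's iteration cap) is reached.
    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }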

mlir-tensorrt/tensorrt/test/Dialect/TensorRT/broadcast-elimination.mlir

Lines changed: 39 additions & 0 deletions
@@ -236,3 +236,42 @@ func.func @broadcast_elim_matmul_vector(%arg0: tensor<?x?x128xf32>, %arg1: tenso
 // CHECK: return %[[v0]] : tensor<?x?x100xf32>


+// -----
+
+func.func @broadcast_dynamic_expand_shape_regression(%arg0: tensor<?x?x1x1xi1>, %arg1: tensor<?x1xf16>, %arg2: tensor<?x?x256x256xf16>, %arg3: tensor<4xi32>) -> tensor<?x?x256x256xf16> {
+  %0 = tensorrt.broadcast %arg0 broadcast_dims<0, 1, 2, 3> shape(%arg3 : tensor<4xi32>) : tensor<?x?x1x1xi1> to tensor<?x?x256x256xi1>
+  %1 = tensorrt.broadcast %arg1 broadcast_dims<2, 3> shape(%arg3 : tensor<4xi32>) : tensor<?x1xf16> to tensor<?x?x256x256xf16>
+  %2 = tensorrt.select ins(%0, %arg2, %1 : tensor<?x?x256x256xi1>, tensor<?x?x256x256xf16>, tensor<?x?x256x256xf16>)
+      -> tensor<?x?x256x256xf16>
+  return %2 : tensor<?x?x256x256xf16>
+}
+
+// CHECK-LABEL: func.func @broadcast_dynamic_expand_shape_regression
+// CHECK-SAME: (%[[arg0:.+]]: tensor<?x?x1x1xi1>, %[[arg1:.+]]: tensor<?x1xf16>, %[[arg2:.+]]: tensor<?x?x256x256xf16>, %[[arg3:.+]]: tensor<4xi32>) -> tensor<?x?x256x256xf16> {
+// CHECK: %[[v0:.+]] = tensorrt.reshape %[[arg1]] : tensor<?x1xf16> to tensor<1x1x?x1xf16>
+// CHECK: %[[v1:.+]] = tensorrt.select ins(%[[arg0]], %[[arg2]], %[[v0]] : tensor<?x?x1x1xi1>, tensor<?x?x256x256xf16>, tensor<1x1x?x1xf16>) -> tensor<?x?x256x256xf16>
+// CHECK: return %[[v1]] : tensor<?x?x256x256xf16>
+
+// -----
+
+func.func @broadcast_dynamic_expand_shape_regression(%arg0: tensor<?x?x1x1xi1>, %arg1: tensor<?x1x?xf16>, %arg2: tensor<?x?x256x256xf16>, %arg3: tensor<4xi32>) -> tensor<?x?x256x256xf16> {
+  %0 = tensorrt.broadcast %arg0 broadcast_dims<0, 1, 2, 3> shape(%arg3 : tensor<4xi32>) : tensor<?x?x1x1xi1> to tensor<?x?x256x256xi1>
+  %1 = tensorrt.broadcast %arg1 broadcast_dims<3, 2, 1> shape(%arg3 : tensor<4xi32>) : tensor<?x1x?xf16> to tensor<?x?x256x256xf16>
+  %2 = tensorrt.select ins(%0, %arg2, %1 : tensor<?x?x256x256xi1>, tensor<?x?x256x256xf16>, tensor<?x?x256x256xf16>)
+      -> tensor<?x?x256x256xf16>
+  return %2 : tensor<?x?x256x256xf16>
+}
+
+// CHECK: #[[$map:.+]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)>
+// CHECK: module {
+// CHECK-LABEL: func.func @broadcast_dynamic_expand_shape_regression
+// CHECK-SAME: (%[[arg0:.+]]: tensor<?x?x1x1xi1>, %[[arg1:.+]]: tensor<?x1x?xf16>, %[[arg2:.+]]: tensor<?x?x256x256xf16>, %[[arg3:.+]]: tensor<4xi32>) -> tensor<?x?x256x256xf16> {
+// CHECK: %[[cst_i32:.+]] = tensorrt.constant dense<1> : tensor<1xi32>
+// CHECK: %[[v0:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[arg1]] : tensor<?x1x?xf16> to tensor<?x1x?xf16>
+// CHECK: %[[v1:.+]] = tensorrt.shape %[[v0]] : tensor<?x1x?xf16> -> tensor<3xi32>
+// CHECK: %[[v2:.+]] = tensorrt.slice %[[v1]][0][1][1] : tensor<3xi32> to tensor<1xi32>
+// CHECK: %[[v3:.+]] = tensorrt.slice %[[v1]][2][1][1] : tensor<3xi32> to tensor<1xi32>
+// CHECK: %[[v4:.+]] = tensorrt.concatenation {axis = 0 : i32} ins(%[[cst_i32]], %[[v2]], %[[cst_i32]], %[[v3]] : tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
+// CHECK: %[[v5:.+]] = tensorrt.reshape %[[v0]] shape(%[[v4]]: tensor<4xi32>) : tensor<?x1x?xf16> to tensor<1x?x1x?xf16>
+// CHECK: %[[v6:.+]] = tensorrt.select ins(%[[arg0]], %[[arg2]], %[[v5]] : tensor<?x?x1x1xi1>, tensor<?x?x256x256xf16>, tensor<1x?x1x?xf16>) -> tensor<?x?x256x256xf16>
+// CHECK: return %[[v6]] : tensor<?x?x256x256xf16>

mlir-tensorrt/test/Dialect/Plan/materialize-shape-calculations.mlir

Lines changed: 24 additions & 0 deletions
@@ -1088,3 +1088,27 @@ func.func @reduce_window_dynamic_input(%arg0: tensor<?x?x?x?xf32> {tensorrt.shap
 // CHECK-DAG: %[[v2:.+]] = arith.maxsi %[[dim]], %[[c0]] : index
 // CHECK-DAG: %[[v3:.+]] = plan.with_shape %[[v1]](%[[v2]], %[[c3]], %[[c512]], %[[c512]]) :
 // CHECK-DAG: return %[[v3]]
+
+// -----
+
+func.func @simplify_extract_of_reshape_negative(%arg0: tensor<1x?x3x4xf32>) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %1 = stablehlo.reshape %arg0 : (tensor<1x?x3x4xf32>) -> tensor<1x6x4xf32>
+  %2 = tensor.extract %1[%c0, %c1, %c2] : tensor<1x6x4xf32>
+  return %2 : f32
+}
+
+// CHECK-LABEL: simplify_extract_of_reshape_negative
+// CHECK-SAME: (%[[arg0:.+]]: tensor<1x?x3x4xf32>)
+// CHECK-NEXT: %[[c4:.+]] = arith.constant 4 : index
+// CHECK-NEXT: %[[c3:.+]] = arith.constant 3 : index
+// CHECK-NEXT: %[[c2:.+]] = arith.constant 2 : index
+// CHECK-NEXT: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-NEXT: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-NEXT: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c1]] : tensor<1x?x3x4xf32>
+// CHECK-NEXT: %[[v0:.+]] = plan.with_shape %[[arg0]](%[[c1]], %[[dim]], %[[c3]], %[[c4]])
+// CHECK-NEXT: %[[v1:.+]] = stablehlo.reshape %[[v0]]
+// CHECK-NEXT: %[[extracted:.+]] = tensor.extract %[[v1]][%[[c0]], %[[c1]], %[[c2]]]
+// CHECK-NEXT: return %[[extracted]]

mlir-tensorrt/test/Dialect/StableHloExt/constant-folding.mlir

Lines changed: 16 additions & 0 deletions
@@ -402,6 +402,22 @@ func.func @concat_simplify_single_operand_requires_cast(%arg0: tensor<4xi32>) ->

 // -----

+func.func @concat_slice_concat(%arg0: tensor<1xi32>, %arg1: tensor<3xi32>, %arg2: tensor<1xi32>) -> tensor<5xi32> {
+  %0 = stablehlo.concatenate %arg0, %arg1, %arg2, dim = 0 : (tensor<1xi32>, tensor<3xi32>, tensor<1xi32>) -> tensor<5xi32>
+  %1 = stablehlo.slice %0 [1:5] : (tensor<5xi32>) -> tensor<4xi32>
+  %2 = stablehlo.constant dense<1> : tensor<1xi32>
+  %3 = stablehlo.concatenate %2, %1, dim = 0 : (tensor<1xi32>, tensor<4xi32>) -> tensor<5xi32>
+  return %3 : tensor<5xi32>
+}
+
+// CHECK-LABEL: func.func @concat_slice_concat
+// CHECK-SAME: (%[[arg0:.+]]: tensor<1xi32>, %[[arg1:.+]]: tensor<3xi32>, %[[arg2:.+]]: tensor<1xi32>) -> tensor<5xi32>
+// CHECK: %[[c:.+]] = stablehlo.constant dense<1> : tensor<1xi32>
+// CHECK: %[[v0:.+]] = stablehlo.concatenate %[[c]], %[[arg1]], %[[arg2]], dim = 0
+// CHECK: return %[[v0]] : tensor<5xi32>
+
+// -----
+
 func.func @bitwise_or_fold_lhs(%arg0: tensor<5xi8>, %arg1: tensor<5xi1>, %arg2: tensor<5xi32>) -> (tensor<5xi8>, tensor<5xi1>, tensor<5xi32>, tensor<5xi32>){
   %0 = stablehlo.constant dense<[255, 255, 255, 255, 255]> : tensor<5xi8>
   %1 = stablehlo.or %0, %arg0 : tensor<5xi8>
Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+// RUN: mlir-tensorrt-opt %s -split-input-file -stablehlo-ext-refine-shapes | FileCheck %s
+
+func.func @check_type_refinement() -> tensor<?xf32> {
+  %c = stablehlo.constant dense<[1, 2, 3]> : tensor<3xi32>
+  %c_0 = stablehlo.constant dense<3> : tensor<i32>
+  %c_1 = stablehlo.constant dense<1> : tensor<1xi32>
+  %c_2 = stablehlo.constant dense<3> : tensor<1xi32>
+  %c_3 = stablehlo.constant dense<1> : tensor<i32>
+  %c_4 = stablehlo.constant dense<1> : tensor<1xi32>
+  %c_5 = stablehlo.constant dense<0> : tensor<i32>
+  %c_6 = stablehlo.constant dense<1> : tensor<i32>
+  %c_7 = stablehlo.constant dense<0> : tensor<1xi32>
+  %c_8 = stablehlo.constant dense<1> : tensor<1xi32>
+  %0 = stablehlo.compare LE, %c_7, %c_8 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %1 = stablehlo.select %0, %c_7, %c_8 : tensor<1xi1>, tensor<1xi32>
+  %c_9 = stablehlo.constant dense<1> : tensor<1xi32>
+  %2 = stablehlo.real_dynamic_slice %c_4, %1, %c_8, %c_9 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
+  %c_10 = stablehlo.constant dense<> : tensor<0xi32>
+  %3 = stablehlo.dynamic_reshape %2, %c_10 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
+  %c_11 = stablehlo.constant dense<-1> : tensor<i32>
+  %c_12 = stablehlo.constant dense<> : tensor<0xi32>
+  %4 = stablehlo.compare EQ, %c_12, %c_10 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1>
+  %5 = stablehlo.select %4, %c_12, %c_12 : tensor<0xi1>, tensor<0xi32>
+  %6 = stablehlo.dynamic_broadcast_in_dim %3, %5, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
+  %7 = stablehlo.dynamic_broadcast_in_dim %c_11, %5, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
+  %8 = stablehlo.add %6, %7 : tensor<i32>
+  %c_13 = stablehlo.constant dense<0> : tensor<1xi32>
+  %c_14 = stablehlo.constant dense<1> : tensor<1xi32>
+  %9 = stablehlo.compare LE, %c_13, %c_14 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %10 = stablehlo.select %9, %c_13, %c_14 : tensor<1xi1>, tensor<1xi32>
+  %c_15 = stablehlo.constant dense<1> : tensor<1xi32>
+  %11 = stablehlo.real_dynamic_slice %c_4, %10, %c_14, %c_15 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
+  %12 = stablehlo.dynamic_reshape %11, %c_10 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
+  %13 = stablehlo.compare EQ, %c_12, %c_10 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1>
+  %14 = stablehlo.select %13, %c_12, %c_12 : tensor<0xi1>, tensor<0xi32>
+  %15 = stablehlo.dynamic_broadcast_in_dim %12, %14, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
+  %16 = stablehlo.dynamic_broadcast_in_dim %c_11, %14, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
+  %17 = stablehlo.add %15, %16 : tensor<i32>
+  %18 = stablehlo.compare EQ, %c_12, %c_10 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1>
+  %19 = stablehlo.select %18, %c_12, %c_12 : tensor<0xi1>, tensor<0xi32>
+  %20 = stablehlo.dynamic_broadcast_in_dim %17, %19, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
+  %21 = stablehlo.dynamic_broadcast_in_dim %c_6, %19, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
+  %22 = stablehlo.add %20, %21 : tensor<i32>
+  %23 = stablehlo.reshape %8 : (tensor<i32>) -> tensor<1xi32>
+  %24 = stablehlo.reshape %22 : (tensor<i32>) -> tensor<1xi32>
+  %25 = stablehlo.compare LE, %23, %24 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %26 = stablehlo.select %25, %23, %24 : tensor<1xi1>, tensor<1xi32>
+  %c_16 = stablehlo.constant dense<1> : tensor<1xi32>
+  %27 = stablehlo.real_dynamic_slice %c_2, %26, %24, %c_16 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
+  %28 = stablehlo.dynamic_reshape %27, %c_10 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
+  %29 = stablehlo.dynamic_broadcast_in_dim %28, %c_1, dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
+  %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
+  %30 = stablehlo.dynamic_broadcast_in_dim %cst, %29, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
+  return %30 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @check_type_refinement
+// CHECK-DAG: %[[cst:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
+// CHECK-DAG: %[[c:.+]] = stablehlo.constant dense<-1> : tensor<i32>
+// CHECK-DAG: %[[c_0:.+]] = stablehlo.constant dense<> : tensor<0xi32>
+// CHECK-DAG: %[[c_1:.+]] = stablehlo.constant dense<1> : tensor<1xi32>
+// CHECK-DAG: %[[c_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32>
+// CHECK-DAG: %[[c_3:.+]] = stablehlo.constant dense<1> : tensor<i32>
+// CHECK-DAG: %[[c_4:.+]] = stablehlo.constant dense<0> : tensor<1xi32>
+// CHECK-DAG: %[[v0:.+]] = stablehlo.real_dynamic_slice %[[c_1]], %[[c_4]], %[[c_1]], %[[c_1]] : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+// CHECK-DAG: %[[v1:.+]] = stablehlo.dynamic_reshape %[[v0]], %[[c_0]] : (tensor<1xi32>, tensor<0xi32>) -> tensor<i32>
+// CHECK-DAG: %[[v2:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v1]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
+// CHECK-DAG: %[[v3:.+]] = stablehlo.dynamic_broadcast_in_dim %[[c]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
+// CHECK-DAG: %[[v4:.+]] = stablehlo.add %[[v2]], %[[v3]] : tensor<i32>
+// CHECK-DAG: %[[v5:.+]] = stablehlo.real_dynamic_slice %[[c_1]], %[[c_4]], %[[c_1]], %[[c_1]] : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+// CHECK-DAG: %[[v6:.+]] = stablehlo.dynamic_reshape %[[v5]], %[[c_0]] : (tensor<1xi32>, tensor<0xi32>) -> tensor<i32>
+// CHECK-DAG: %[[v7:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v6]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
+// CHECK-DAG: %[[v8:.+]] = stablehlo.dynamic_broadcast_in_dim %[[c]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
+// CHECK-DAG: %[[v9:.+]] = stablehlo.add %[[v7]], %[[v8]] : tensor<i32>
+// CHECK-DAG: %[[v10:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v9]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
+// CHECK-DAG: %[[v11:.+]] = stablehlo.dynamic_broadcast_in_dim %[[c_3]], %[[c_0]], dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
+// CHECK-DAG: %[[v12:.+]] = stablehlo.add %[[v10]], %[[v11]] : tensor<i32>
+// CHECK-DAG: %[[v13:.+]] = stablehlo.reshape %[[v4]] : (tensor<i32>) -> tensor<1xi32>
+// CHECK-DAG: %[[v14:.+]] = stablehlo.reshape %[[v12]] : (tensor<i32>) -> tensor<1xi32>
+// CHECK-DAG: %[[v15:.+]] = stablehlo.compare LE, %[[v13]], %[[v14]] : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+// CHECK-DAG: %[[v16:.+]] = stablehlo.select %[[v15]], %[[v13]], %[[v14]] : tensor<1xi1>, tensor<1xi32>
+// CHECK-DAG: %[[v17:.+]] = stablehlo.real_dynamic_slice %[[c_2]], %[[v16]], %[[v14]], %[[c_1]] : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
+// CHECK-DAG: %[[v18:.+]] = stablehlo.dynamic_reshape %[[v17]], %[[c_0]] : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
+// CHECK-DAG: %[[v19:.+]] = stablehlo.dynamic_broadcast_in_dim %[[v18]], %[[c_1]], dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
+// CHECK-DAG: %[[v20:.+]] = stablehlo.dynamic_broadcast_in_dim %[[cst]], %[[v19]], dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
+// CHECK-DAG: return %[[v20]] : tensor<?xf32>

mlir-tensorrt/test/models/bert.stablehlo.elided.mlir

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+module @bert attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
   func.func public @main(%arg0: tensor<32x8xi32> {mhlo.layout_mode = "default"}) -> (tensor<32x8x768xf16> {mhlo.layout_mode = "default"}, tensor<32x768xf16> {mhlo.layout_mode = "default"}) {
     %0 = stablehlo.constant dense_resource<__elided__> : tensor<30522x768xf32>
     %1 = stablehlo.constant dense_resource<__elided__> : tensor<512x768xf32>

mlir-tensorrt/test/models/gpt2.stablehlo.bs2.elided.mlir

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-module @jit_generate attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+module @gpt2_bs2 attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
   func.func public @main(%arg0: tensor<2x6xi32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<2x6xi32> {mhlo.sharding = "{replicated}"}) -> (tensor<2x20xi32> {jax.result_info = ""}) {
     %0 = stablehlo.constant dense<0> : tensor<1xi32>
     %1 = stablehlo.constant dense<768> : tensor<i32>

mlir-tensorrt/test/models/gpt2.stablehlo.elided.mlir

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-module @jit_generate attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+module @gpt_bs1 attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
   func.func public @main(%arg0: tensor<1x7xi32> {jax.arg_info = "inputs['attention_mask']", mhlo.sharding = "{replicated}"}, %arg1: tensor<1x7xi32> {jax.arg_info = "inputs['input_ids']", mhlo.sharding = "{replicated}"}) -> (tensor<1x20xi32> {jax.result_info = ""}) {
     %0 = stablehlo.constant dense_resource<__elided__> : tensor<50257x768xf16>
     %1 = stablehlo.constant dense_resource<__elided__> : tensor<1024x768xf16>
