Skip to content

Commit a1ccbe6

Browse files
Upgrade LLVM-Project, StableHLO dependencies (#414)
This is a combination of the following commits (from older to most recent): ## [compiler/Dialect/Plan] Improve efficiency of TensorRT clustering Runs TensorRT analysis conversion on the container op and uses that result in the clustering analysis instead of checking conversion op-by-op, which will no longer work well in after an upcoming LLVM upgrade. This change also removes the "tensorrt major version" parameter for the `ClusterKindAttrInterface::getClusteringOpts` method arguments since that information can be passed directly as a parameter of the `TensorRTClusterKindAttr`` ## Upgrade all dependencies This upgrades dependencies to the following base commits: ``` LLVM_COMMIT_UPSTREAM = "6c64c8a6f3f77c30745c751d4163ff6bf2fc323b" TORCH_MLIR_COMMIT = "30c519369ed7eabad0282d0f874500a9b41fcbbd" STABLEHLO_COMMIT = "6e403b1aa6a71f5eaa09cc720e4ad42f692745e6" ``` ## Update how cluster kinds and their options are populated This change updates the 'plan-stablehlo-clustering' pass so that it no longer populates default "cluster kinds" if the 'plan.cluster_kinds' attribute is missing on the module. Instead, it is up to the top-level compilation entrypoints or the frontend to populate these key metadata attributes. This allows for significant simplification in how options are passed to various passes and pipelines. For example, the 'stablehlo-clustering' and 'plan-segmentation-pipeline' both no longer need to accept pass options specifically for the TensorRT clustering options. These options are dictated purely by the TensorRT clustering attribute attached to the module. Only the top level compilation entrypoint requires such options, which it can use to populate the clustering backend attributes parameters. GitOrigin-RevId: 66aa8a1e07b4c94195ed27bcd96f7fd2d00c6f3e
1 parent 2c6cc1a commit a1ccbe6

39 files changed

+549
-1151
lines changed

mlir-tensorrt/CMakeLists.txt

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,13 @@ if(PROJECT_IS_TOP_LEVEL)
124124
elseif(MLIR_TRT_LLVM_COMMIT)
125125
mtrt_llvm_project(
126126
NAME llvm_project
127-
VERSION 0.0.20240812
127+
VERSION 0.0.20241126
128128
URL "https://github.com/llvm/llvm-project/archive/${MLIR_TRT_LLVM_COMMIT}.zip"
129129
EXCLUDE_FROM_ALL TRUE
130130
SOURCE_SUBDIR "llvm"
131-
PATCHES "${CMAKE_SOURCE_DIR}/build_tools/llvm-project.patch"
132-
131+
PATCHES
132+
"${CMAKE_CURRENT_LIST_DIR}/build_tools/patches/mlir/000_fix_bufferization_tensor_encoding_memory_spaces.patch"
133+
"${CMAKE_CURRENT_LIST_DIR}/build_tools/patches/mlir/001-mlir-Add-a-null-pointer-check-in-symbol-lookup-11516.patch"
133134
OPTIONS
134135
"LLVM_ENABLE_PROJECTS mlir"
135136
"MLIR_ENABLE_BINDINGS_PYTHON ${MLIR_TRT_ENABLE_PYTHON}"
@@ -180,12 +181,11 @@ include(HandleLLVMOptions)
180181
# Download Stablehlo if it isn't provided by a parent project.
181182
if(MLIR_TRT_ENABLE_HLO AND NOT TARGET StablehloOps)
182183
mtrt_add_stablehlo(
183-
VERSION 1.6.4
184-
GIT_TAG 1456dfa1e1a83aab0cc717714ba3695886f60302
184+
VERSION 1.8.0
185+
GIT_TAG 6e403b1aa6a71f5eaa09cc720e4ad42f692745e6
185186
GIT_REPOSITORY "https://github.com/openxla/stablehlo.git"
186-
PATCHES
187-
"${MLIR_TENSORRT_ROOT_DIR}/build_tools/stablehlo.patch"
188-
"${MLIR_TENSORRT_ROOT_DIR}/build_tools/stablehlo_aggressive_folder.patch"
187+
PATCHES
188+
"${CMAKE_CURRENT_LIST_DIR}/build_tools/patches/stablehlo/0001-transforms-Fix-simplification-patterns-for-stablehlo.patch"
189189
OPTIONS
190190
"STABLEHLO_ENABLE_BINDINGS_PYTHON ${MLIR_TRT_ENABLE_PYTHON}"
191191
"STABLEHLO_BUILD_EMBEDDED ON"
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
set(LLVM_COMMIT "c49770c60f26e449379447109f7d915bd8de0384")
1+
set(LLVM_COMMIT "6c64c8a6f3f77c30745c751d4163ff6bf2fc323b")

mlir-tensorrt/build_tools/llvm-project.patch renamed to mlir-tensorrt/build_tools/patches/mlir/000_fix_bufferization_tensor_encoding_memory_spaces.patch

Lines changed: 19 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,10 @@ index 1c70a4b8df92..c97a1fe819a3 100644
7272

7373
let hasFolder = 1;
7474
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
75-
index a610ddcc9899..18c4e5cedc8c 100644
75+
index afc193b5517d..b5c81a0e73e2 100644
7676
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
7777
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
78-
@@ -533,6 +533,10 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
78+
@@ -526,6 +526,10 @@ def OneShotBufferize : Pass<"one-shot-bufferize"> {
7979
/*default=*/"false",
8080
"The memory space of an memref types must always be inferred. If "
8181
"unset, a default memory space of 0 is used otherwise.">,
@@ -87,10 +87,10 @@ index a610ddcc9899..18c4e5cedc8c 100644
8787
/*default=*/"false",
8888
"Test only: Only run inplaceability analysis and annotate IR">,
8989
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
90-
index d51d63f243ea..550ac7e83b9e 100644
90+
index 85604eef2f28..065739ea8e59 100644
9191
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
9292
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
93-
@@ -719,7 +719,7 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
93+
@@ -718,7 +718,7 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
9494
// loose all of its users and eventually DCE away.
9595
rewriter.setInsertionPointAfter(op);
9696
replacement = rewriter.create<bufferization::ToTensorOp>(
@@ -100,7 +100,7 @@ index d51d63f243ea..550ac7e83b9e 100644
100100
replacements.push_back(replacement);
101101
}
102102
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
103-
index 04a8ff30ee94..d1d5d3b89b3e 100644
103+
index f1841b860ff8..774d1a2ec04b 100644
104104
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
105105
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
106106
@@ -23,6 +23,16 @@ using namespace mlir::bufferization;
@@ -121,7 +121,7 @@ index 04a8ff30ee94..d1d5d3b89b3e 100644
121121
OpBuilder &b, Value value, MemRefType destType,
122122
const BufferizationOptions &options) {
123123
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
124-
index e422086c9fde..d4271c68cbcc 100644
124+
index 429695126a95..911c2e862e6f 100644
125125
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
126126
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
127127
@@ -67,10 +67,14 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
@@ -153,12 +153,12 @@ index e422086c9fde..d4271c68cbcc 100644
153153
+ };
154154
}
155155
opt.printConflicts = printConflicts;
156-
opt.testAnalysisOnly = testAnalysisOnly;
156+
opt.bufferAlignment = bufferAlignment;
157157
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
158-
index 87464ccb7172..c654dbba46e8 100644
158+
index c2b8614148bf..9797b73f534a 100644
159159
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
160160
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
161-
@@ -479,10 +479,6 @@ struct FromElementsOpInterface
161+
@@ -480,10 +480,6 @@ struct FromElementsOpInterface
162162
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
163163
auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
164164

@@ -169,7 +169,7 @@ index 87464ccb7172..c654dbba46e8 100644
169169
// Allocate a buffer for the result.
170170
Location loc = op->getLoc();
171171
auto shape = tensorType.getShape();
172-
@@ -492,10 +488,12 @@ struct FromElementsOpInterface
172+
@@ -493,10 +489,12 @@ struct FromElementsOpInterface
173173
/*copy=*/false);
174174
if (failed(tensorAlloc))
175175
return failure();
@@ -223,7 +223,7 @@ index 5293977fe733..55e086ff0110 100644
223223
}
224224

225225
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
226-
index ab18ce05e355..148908536d6c 100644
226+
index ab18ce05e355..97a69e153e39 100644
227227
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
228228
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
229229
@@ -4,8 +4,8 @@
@@ -287,32 +287,17 @@ index ab18ce05e355..148908536d6c 100644
287287
return %1 : memref<?xf32>
288288
}
289289

290-
@@ -77,19 +77,20 @@ func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset
291-
// TODO: to_memref with layout maps not supported yet. This should fold to a
292-
// memref.cast.
293-
func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, strided<[1], offset: ?>> {
294-
- %0 = bufferization.to_tensor %m : memref<?xf32>
295-
+ %0 = bufferization.to_tensor %m : memref<?xf32> -> tensor<?xf32>
296-
// expected-error @+1 {{failed to legalize unresolved materialization from ('memref<?xf32>') to 'memref<?xf32, strided<[1], offset: ?>>' that remained live after conversion}}
297-
- %1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
298-
- // expected-note @below{{see existing live user here}}
299-
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32, strided<[1], offset: ?>>
300-
+ // expected-note @below {{see existing live user here:}}
301-
return %1 : memref<?xf32, strided<[1], offset: ?>>
302-
}
303-
290+
@@ -87,9 +87,8 @@ func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, stri
304291
// -----
305292

306293
func.func @illegal_unranked_to_rank(%m: memref<*xf32>) -> memref<?xf32> {
307294
- // expected-note @+1 {{prior use here}}
308295
- %0 = bufferization.to_tensor %m : memref<*xf32>
309296
- // expected-error @+1 {{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<*xf32>'}}
310297
- %1 = bufferization.to_memref %0 : memref<?xf32>
311-
+
312-
+ %0 = bufferization.to_tensor %m : memref<*xf32> -> tensor<*xf32>
298+
+ %0 = bufferization.to_tensor %m : memref<*xf32> -> tensor<?xf32>
313299
+ // expected-error @+1 {{failed to legalize unresolved materialization from ('memref<*xf32>') to 'memref<?xf32>' that remained live after conversion}}
314-
+ %1 = bufferization.to_memref %0 : tensor<*xf32> -> memref<?xf32>
315-
+ // expected-note @below {{see existing live user here:}}
300+
+ %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
316301
return %1 : memref<?xf32>
317302
}
318303
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -800,10 +785,10 @@ index 4bc2ed140da9..af5e745cb3a9 100644
800785
return %pack_18 : tensor<1x1x8x4x4x8xi32>
801786
}
802787
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
803-
index c7af033a22a2..1c80a9f6024d 100644
788+
index 11114bcf2b1a..76f6ee511093 100644
804789
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
805790
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
806-
@@ -356,10 +356,10 @@ func.func @neg_map() -> memref<2x3xf32, #neg> {
791+
@@ -360,11 +360,11 @@ func.func @neg_map() -> memref<2x3xf32, #neg> {
807792
// CHECK-LABEL: func @memref_with_strided_offset
808793
func.func @memref_with_strided_offset(%arg0: tensor<128x512xf32>, %arg1: index, %arg2: index) -> tensor<16x512xf32> {
809794
%c0 = arith.constant 0 : index
@@ -816,6 +801,7 @@ index c7af033a22a2..1c80a9f6024d 100644
816801
+ %1 = bufferization.to_tensor %cast : memref<16x512xf32, strided<[?, ?], offset: ?>> -> tensor<16x512xf32>
817802
return %1 : tensor<16x512xf32>
818803
}
804+
819805
diff --git a/mlir/test/Dialect/SCF/bufferize.mlir b/mlir/test/Dialect/SCF/bufferize.mlir
820806
index ff1612310255..c7ca1dcb031b 100644
821807
--- a/mlir/test/Dialect/SCF/bufferize.mlir
@@ -2072,19 +2058,6 @@ index 78e29979ca1a..6332b35ef6c0 100644
20722058
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_11]][] : memref<f32>
20732059
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
20742060
// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
2075-
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
2076-
index f819458e0385..a8398deb7a3b 100644
2077-
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
2078-
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir
2079-
@@ -85,7 +85,7 @@ func.func @sqsum(%arg0: tensor<?x?x?x?xi32, #COO>) -> tensor<i32> {
2080-
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
2081-
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i32
2082-
// CHECK: %[[VAL_5:.*]] = arith.constant dense<0> : tensor<10xi32>
2083-
-// CHECK: %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_5]] : memref<10xi32>
2084-
+// CHECK: %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_5]] : tensor<10xi32> -> memref<10xi32>
2085-
// CHECK: linalg.fill ins(%[[VAL_4]] : i32) outs(%[[VAL_6]] : memref<10xi32>)
2086-
// CHECK: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
2087-
// CHECK: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
20882061
diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
20892062
index c27df0078552..62ebc7ef3d96 100644
20902063
--- a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir
@@ -2739,10 +2712,10 @@ index 3a3c8af15e6e..13be67d0562e 100644
27392712
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
27402713
// CHECK-DAG: %[[dim0:.*]] = memref.dim %[[m1]], %[[c0]]
27412714
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
2742-
index e2169fe1404c..c0e6b415e1a6 100644
2715+
index dc4306b8316a..bd71970cfb08 100644
27432716
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
27442717
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
2745-
@@ -387,7 +387,7 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> {
2718+
@@ -402,7 +402,7 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> {
27462719
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>, 3>,
27472720
// CHECK-SAME: %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>,
27482721
func.func @reshape_with_non_identity_layout(%arg0: memref<2x2xf32, strided<[?, ?], offset: ?>, 3>, %arg1: tensor<2xi32>, %idx: index) -> f32 {
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
From efdb53b115f5ce7233482a344bfc0c26f9e47041 Mon Sep 17 00:00:00 2001
2+
From: Christopher Bate <[email protected]>
3+
Date: Fri, 15 Nov 2024 05:17:54 +0000
4+
Subject: [PATCH] [mlir] Add a null pointer check in symbol lookup #115165
5+
6+
---
7+
mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp | 2 ++
8+
1 file changed, 2 insertions(+)
9+
10+
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
11+
index 3c190d4e9919..e805e21d878b 100644
12+
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
13+
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
14+
@@ -186,6 +186,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
15+
// If a callable symbol has a non-call use, then we can't be guaranteed to
16+
// know all callsites.
17+
Operation *symbol = symbolTable.lookupSymbolIn(top, use.getSymbolRef());
18+
+ if (!symbol)
19+
+ continue;
20+
auto *state = getOrCreate<PredecessorState>(getProgramPointAfter(symbol));
21+
propagateIfChanged(state, state->setHasUnknownPredecessors());
22+
}
23+
--
24+
2.47.0
25+
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
From 00a3c5e6b9207bae81c6d401fa368fdfe270122b Mon Sep 17 00:00:00 2001
2+
From: Christopher Bate <[email protected]>
3+
Date: Fri, 22 Nov 2024 22:43:47 +0000
4+
Subject: [PATCH] [transforms] Fix simplification patterns for
5+
`stablehlo.(and|or)`
6+
7+
Fixes an issue in `stablehlo-aggressive-simplification` where `%1` in
8+
the below would get replaced by `%arg0`:
9+
10+
```
11+
%0 = stablehlo.constant dense<1> : tensor<2xi32>
12+
%1 = stablehlo.and %0, %arg0 : tensor<2xi32>
13+
```
14+
15+
The pattern was checking whether `%0` is equal to `0b1` and was
16+
only tested on bools. A similar bug existed for `stablehlo.and`. Fixed
17+
by just making sure the constant is integer with all bits set to 1.
18+
---
19+
.../stablehlo_aggressive_simplification.mlir | 38 +++++++++++++++++++
20+
...ablehloAggressiveSimplificationPatterns.td | 14 ++++++-
21+
2 files changed, 50 insertions(+), 2 deletions(-)
22+
23+
diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
24+
index 809c0700..b2d05de3 100644
25+
--- a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
26+
+++ b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
27+
@@ -63,6 +63,25 @@ func.func @and_one(%arg0: tensor<2xi1>) -> tensor<2xi1> {
28+
return %1 : tensor<2xi1>
29+
}
30+
31+
+// CHECK-LABEL: @and_i32_one
32+
+func.func @and_i32_one(%arg0: tensor<2xi32>) -> tensor<2xi32> {
33+
+ %0 = stablehlo.constant dense<1> : tensor<2xi32>
34+
+ %1 = stablehlo.and %0, %arg0 : tensor<2xi32>
35+
+ // CHECK: %[[AND:.+]] = stablehlo.and
36+
+ // CHECK: return %[[AND]]
37+
+ return %1 : tensor<2xi32>
38+
+}
39+
+
40+
+// CHECK-LABEL: @and_i32_neg_one
41+
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<2xi32>)
42+
+func.func @and_i32_neg_one(%arg0: tensor<2xi32>) -> tensor<2xi32> {
43+
+ %0 = stablehlo.constant dense<-1> : tensor<2xi32>
44+
+ %1 = stablehlo.and %0, %arg0 : tensor<2xi32>
45+
+ // CHECK-NOT: stablehlo.and
46+
+ // CHECK: return %[[ARG0]]
47+
+ return %1 : tensor<2xi32>
48+
+}
49+
+
50+
// -----
51+
52+
/////////
53+
@@ -540,6 +559,25 @@ func.func @or_one(%arg0: tensor<2xi1>) -> tensor<2xi1> {
54+
return %1 : tensor<2xi1>
55+
}
56+
57+
+// CHECK-LABEL: @or_i32_one
58+
+func.func @or_i32_one(%arg0: tensor<2xi32>) -> tensor<2xi32> {
59+
+ %0 = stablehlo.constant dense<1> : tensor<2xi32>
60+
+ %1 = stablehlo.or %0, %arg0 : tensor<2xi32>
61+
+ // CHECK: %[[OR:.+]] = stablehlo.or
62+
+ // CHECK: return %[[OR]]
63+
+ return %1 : tensor<2xi32>
64+
+}
65+
+
66+
+// CHECK-LABEL: @or_i32_neg_one
67+
+func.func @or_i32_neg_one(%arg0: tensor<2xi32>) -> tensor<2xi32> {
68+
+ %0 = stablehlo.constant dense<-1> : tensor<2xi32>
69+
+ %1 = stablehlo.or %0, %arg0 : tensor<2xi32>
70+
+ // CHECK-NOT: stablehlo.or
71+
+ // CHECK: [[NEG_ONE:%.+]] = stablehlo.constant dense<-1> : tensor<2xi32>
72+
+ // CHECK: return [[NEG_ONE]]
73+
+ return %1 : tensor<2xi32>
74+
+}
75+
+
76+
// -----
77+
78+
/////////
79+
diff --git a/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td b/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td
80+
index 31f1f475..cef9f303 100644
81+
--- a/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td
82+
+++ b/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td
83+
@@ -41,6 +41,16 @@ def AnySplat : AttrConstraint<CPred<"$_self.isSplat()">, "is any splat">;
84+
def AnyZero : AttrConstraint<
85+
CPred<"::mlir::matchPattern($_self, m_AnyAttrOf(m_Zero(), m_AnyZeroFloat()))">, "is int or float zero">;
86+
87+
+def IntAllOnes : AttrConstraint<
88+
+ CPred<[{
89+
+ ::mlir::matchPattern($_self,
90+
+ ::mlir::detail::constant_int_predicate_matcher{
91+
+ [](const llvm::APInt &val) {
92+
+ return val.isAllOnes();
93+
+ }})
94+
+ }]>,
95+
+ "is integer with all bits set to 1">;
96+
+
97+
def IntZero : AttrConstraint<
98+
CPred<"::mlir::matchPattern($_self, m_Zero())">, "is integer zero">;
99+
100+
@@ -101,7 +111,7 @@ def : Pat<(StableHLO_AndOp $lhs, (StableHLO_ConstantOp:$zero IntZero:$value)),
101+
(replaceWithValue $zero)>;
102+
103+
// Pattern: and(X, 1) -> X
104+
-def : Pat<(StableHLO_AndOp $lhs, (StableHLO_ConstantOp:$one IntOne:$value)),
105+
+def : Pat<(StableHLO_AndOp $lhs, (StableHLO_ConstantOp:$one IntAllOnes:$value)),
106+
(replaceWithValue $lhs)>;
107+
108+
////////
109+
@@ -208,7 +218,7 @@ def : Pat<(StableHLO_MulOp $lhs, (StableHLO_ConstantOp IntOne:$value)),
110+
def : CanonicalizeConstantToRhs<StableHLO_OrOp>;
111+
112+
// Pattern: or(X, 1) -> 1
113+
-def : Pat<(StableHLO_OrOp $lhs, (StableHLO_ConstantOp:$one IntOne:$value)),
114+
+def : Pat<(StableHLO_OrOp $lhs, (StableHLO_ConstantOp:$one IntAllOnes:$value)),
115+
(replaceWithValue $one)>;
116+
117+
// Pattern: or(X, 0) -> X
118+
--
119+
2.47.0
120+

0 commit comments

Comments
 (0)