Skip to content
This repository was archived by the owner on Nov 27, 2025. It is now read-only.

Commit 6ccc228

Browse files
[mlir][bufferization] Refine tensor-buffer compatibility checks (#167705)
Generally, to_tensor and to_buffer already perform sufficient verification. However, there are some unnecessarily strict constraints: * builtin tensor requires its buffer counterpart to always be memref * to_buffer on ranked tensor requires to always return memref These checks are assertions (i.e. preconditions), however, they actually prevent an apparently useful bufferization where builtin tensors could become custom buffers. Lift these assertions, maintaining the verification procedure unchanged, to allow builtin -> custom bufferizations at operation boundary level.
1 parent dd7c13f commit 6ccc228

File tree

5 files changed

+110
-20
lines changed

5 files changed

+110
-20
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -656,16 +656,6 @@ bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
656656
return false;
657657
}
658658

659-
// bufferization.to_buffer is not allowed to change the rank.
660-
static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
661-
#ifndef NDEBUG
662-
auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
663-
assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
664-
rankedTensorType.getRank()) &&
665-
"to_buffer would be invalid: mismatching ranks");
666-
#endif
667-
}
668-
669659
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
670660
const BufferizationOptions &options) {
671661
#ifndef NDEBUG
@@ -683,7 +673,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
683673
FailureOr<BufferLikeType> bufferType = getBufferType(value, options);
684674
if (failed(bufferType))
685675
return failure();
686-
ensureToBufferOpIsValid(value, *bufferType);
676+
687677
return rewriter
688678
.create<bufferization::ToBufferOp>(value.getLoc(), *bufferType, value)
689679
.getResult();

mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,6 @@ struct BuiltinTensorExternalModel
7373
mlir::LogicalResult verifyCompatibleBufferType(
7474
mlir::Type tensor, BufferLikeType bufferType,
7575
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
76-
assert(isa<TensorType>(tensor) && "expected tensor type");
77-
assert(isa<BaseMemRefType>(bufferType) && "expected memref type");
78-
7976
auto tensorType = cast<ShapedType>(tensor);
8077
auto memrefType = cast<ShapedType>(bufferType);
8178

mlir/test/Dialect/Bufferization/invalid.mlir

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,63 @@ func.func @invalid_manual_deallocation() {
127127
// expected-error @below{{op attribute 'bufferization.manual_deallocation' can be used only on ops that have an allocation and/or free side effect}}
128128
arith.constant {bufferization.manual_deallocation} 0 : index
129129
}
130+
131+
// -----
132+
133+
func.func @invalid_rank_to_buffer(%t: tensor<1x2x3x4xf32>) {
134+
// expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
135+
// expected-error @below{{shapes do not match}}
136+
%b = bufferization.to_buffer %t
137+
: tensor<1x2x3x4xf32> to memref<1x2x3xf32>
138+
return
139+
}
140+
141+
// -----
142+
143+
func.func @invalid_rank_to_tensor(%b: memref<1x2x3xf32>) {
144+
// expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
145+
// expected-error @below{{shapes do not match}}
146+
%t = bufferization.to_tensor %b
147+
: memref<1x2x3xf32> to tensor<1x2x3x4xf32>
148+
return
149+
}
150+
151+
// -----
152+
153+
func.func @invalid_shape_to_buffer(%t: tensor<1x2x3x4xf32>) {
154+
// expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
155+
// expected-error @below{{shapes do not match}}
156+
%b = bufferization.to_buffer %t
157+
: tensor<1x2x3x4xf32> to memref<1x2x4x3xf32>
158+
return
159+
}
160+
161+
// -----
162+
163+
func.func @invalid_shape_to_tensor(%b: memref<1x2x4x3xf32>) {
164+
// expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
165+
// expected-error @below{{shapes do not match}}
166+
%t = bufferization.to_tensor %b
167+
: memref<1x2x4x3xf32> to tensor<1x2x3x4xf32>
168+
return
169+
}
170+
171+
// -----
172+
173+
func.func @invalid_type_to_buffer(%t: tensor<1x2x3x4xf32>) {
174+
// expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
175+
// expected-error @below{{element types do not match}}
176+
%b = bufferization.to_buffer %t
177+
: tensor<1x2x3x4xf32> to memref<1x2x3x4xf16>
178+
return
179+
}
180+
181+
// -----
182+
183+
func.func @invalid_type_to_tensor(%b: memref<1x2x3x4xf16>) {
184+
// expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
185+
// expected-error @below{{element types do not match}}
186+
%t2 = bufferization.to_tensor %b
187+
: memref<1x2x3x4xf16> to tensor<1x2x3x4xf32>
188+
return
189+
}

mlir/test/Dialect/Bufferization/ops.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,40 @@ func.func @test_dealloc_op(%arg0: memref<2xf32>, %arg1: memref<4xi32>,
8383
bufferization.dealloc
8484
return %0#0, %0#1 : i1, i1
8585
}
86+
87+
// CHECK: func.func @test_builtin_custom_builtin_type_conversion
88+
// CHECK-SAME: (%[[t:.*]]: tensor<42xf32>) -> tensor<42xf32>
89+
func.func @test_builtin_custom_builtin_type_conversion(%t: tensor<42xf32>)
90+
-> tensor<42xf32> {
91+
// CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]]
92+
// CHECK-SAME: to !test.test_memref<[42], f32>
93+
%buffer = bufferization.to_buffer %t
94+
: tensor<42xf32> to !test.test_memref<[42], f32>
95+
96+
// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]]
97+
// CHECK-SAME: to tensor<42xf32>
98+
%tensor = bufferization.to_tensor %buffer
99+
: !test.test_memref<[42], f32> to tensor<42xf32>
100+
101+
// CHECK: return %[[tensor]]
102+
return %tensor : tensor<42xf32>
103+
}
104+
105+
// CHECK: func.func @test_custom_builtin_custom_type_conversion
106+
// CHECK-SAME: (%[[t:.*]]: !test.test_tensor<[42], f32>)
107+
// CHECK-SAME: -> !test.test_tensor<[42], f32>
108+
func.func @test_custom_builtin_custom_type_conversion(%t: !test.test_tensor<[42], f32>)
109+
-> !test.test_tensor<[42], f32> {
110+
// CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]]
111+
// CHECK-SAME: to memref<42xf32>
112+
%buffer = bufferization.to_buffer %t
113+
: !test.test_tensor<[42], f32> to memref<42xf32>
114+
115+
// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]]
116+
// CHECK-SAME: to !test.test_tensor<[42], f32>
117+
%tensor = bufferization.to_tensor %buffer
118+
: memref<42xf32> to !test.test_tensor<[42], f32>
119+
120+
// CHECK: return %[[tensor]]
121+
return %tensor : !test.test_tensor<[42], f32>
122+
}

mlir/test/lib/Dialect/Test/TestTypes.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -584,11 +584,17 @@ TestTensorType::getBufferType(
584584
::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
585585
::mlir::bufferization::BufferLikeType bufferType,
586586
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
587-
auto testMemref = llvm::dyn_cast<TestMemrefType>(bufferType);
588-
if (!testMemref)
589-
return emitError() << "expected TestMemrefType";
587+
if (auto testMemref = llvm::dyn_cast<TestMemrefType>(bufferType)) {
588+
const bool valid = getShape() == testMemref.getShape() &&
589+
getElementType() == testMemref.getElementType();
590+
return mlir::success(valid);
591+
}
592+
593+
if (auto builtinMemref = llvm::dyn_cast<MemRefType>(bufferType)) {
594+
const bool valid = getShape() == builtinMemref.getShape() &&
595+
getElementType() == builtinMemref.getElementType();
596+
return mlir::success(valid);
597+
}
590598

591-
const bool valid = getShape() == testMemref.getShape() &&
592-
getElementType() == testMemref.getElementType();
593-
return mlir::success(valid);
599+
return emitError() << "expected MemRefType or TestMemrefType";
594600
}

0 commit comments

Comments
 (0)