Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -656,16 +656,6 @@ bool AnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
return false;
}

// bufferization.to_buffer is not allowed to change the rank.
static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG
auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
rankedTensorType.getRank()) &&
"to_buffer would be invalid: mismatching ranks");
#endif
}

FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options) {
#ifndef NDEBUG
Expand All @@ -683,7 +673,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
FailureOr<BufferLikeType> bufferType = getBufferType(value, options);
if (failed(bufferType))
return failure();
ensureToBufferOpIsValid(value, *bufferType);

return rewriter
.create<bufferization::ToBufferOp>(value.getLoc(), *bufferType, value)
.getResult();
Expand Down
3 changes: 0 additions & 3 deletions mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,6 @@ struct BuiltinTensorExternalModel
mlir::LogicalResult verifyCompatibleBufferType(
mlir::Type tensor, BufferLikeType bufferType,
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
assert(isa<TensorType>(tensor) && "expected tensor type");
assert(isa<BaseMemRefType>(bufferType) && "expected memref type");

auto tensorType = cast<ShapedType>(tensor);
auto memrefType = cast<ShapedType>(bufferType);

Expand Down
60 changes: 60 additions & 0 deletions mlir/test/Dialect/Bufferization/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,63 @@ func.func @invalid_manual_deallocation() {
// expected-error @below{{op attribute 'bufferization.manual_deallocation' can be used only on ops that have an allocation and/or free side effect}}
arith.constant {bufferization.manual_deallocation} 0 : index
}

// -----

func.func @invalid_rank_to_buffer(%t: tensor<1x2x3x4xf32>) {
// expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
// expected-error @below{{shapes do not match}}
%b = bufferization.to_buffer %t
: tensor<1x2x3x4xf32> to memref<1x2x3xf32>
return
}

// -----

func.func @invalid_rank_to_tensor(%b: memref<1x2x3xf32>) {
// expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
// expected-error @below{{shapes do not match}}
%t = bufferization.to_tensor %b
: memref<1x2x3xf32> to tensor<1x2x3x4xf32>
return
}

// -----

func.func @invalid_shape_to_buffer(%t: tensor<1x2x3x4xf32>) {
// expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
// expected-error @below{{shapes do not match}}
%b = bufferization.to_buffer %t
: tensor<1x2x3x4xf32> to memref<1x2x4x3xf32>
return
}

// -----

func.func @invalid_shape_to_tensor(%b: memref<1x2x4x3xf32>) {
// expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
// expected-error @below{{shapes do not match}}
%t = bufferization.to_tensor %b
: memref<1x2x4x3xf32> to tensor<1x2x3x4xf32>
return
}

// -----

func.func @invalid_type_to_buffer(%t: tensor<1x2x3x4xf32>) {
// expected-error @below{{'bufferization.to_buffer' op failed to verify that specified tensor and buffer types match}}
// expected-error @below{{element types do not match}}
%b = bufferization.to_buffer %t
: tensor<1x2x3x4xf32> to memref<1x2x3x4xf16>
return
}

// -----

func.func @invalid_type_to_tensor(%b: memref<1x2x3x4xf16>) {
// expected-error @below{{'bufferization.to_tensor' op failed to verify that specified tensor and buffer types match}}
// expected-error @below{{element types do not match}}
%t2 = bufferization.to_tensor %b
: memref<1x2x3x4xf16> to tensor<1x2x3x4xf32>
return
}
37 changes: 37 additions & 0 deletions mlir/test/Dialect/Bufferization/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,40 @@ func.func @test_dealloc_op(%arg0: memref<2xf32>, %arg1: memref<4xi32>,
bufferization.dealloc
return %0#0, %0#1 : i1, i1
}

// CHECK: func.func @test_builtin_custom_builtin_type_conversion
// CHECK-SAME: (%[[t:.*]]: tensor<42xf32>) -> tensor<42xf32>
func.func @test_builtin_custom_builtin_type_conversion(%t: tensor<42xf32>)
-> tensor<42xf32> {
// CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]]
// CHECK-SAME: to !test.test_memref<[42], f32>
%buffer = bufferization.to_buffer %t
: tensor<42xf32> to !test.test_memref<[42], f32>

// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]]
// CHECK-SAME: to tensor<42xf32>
%tensor = bufferization.to_tensor %buffer
: !test.test_memref<[42], f32> to tensor<42xf32>

// CHECK: return %[[tensor]]
return %tensor : tensor<42xf32>
}

// CHECK: func.func @test_custom_builtin_custom_type_conversion
// CHECK-SAME: (%[[t:.*]]: !test.test_tensor<[42], f32>)
// CHECK-SAME: -> !test.test_tensor<[42], f32>
func.func @test_custom_builtin_custom_type_conversion(%t: !test.test_tensor<[42], f32>)
-> !test.test_tensor<[42], f32> {
// CHECK: %[[buffer:.*]] = bufferization.to_buffer %[[t]]
// CHECK-SAME: to memref<42xf32>
%buffer = bufferization.to_buffer %t
: !test.test_tensor<[42], f32> to memref<42xf32>

// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[buffer]]
// CHECK-SAME: to !test.test_tensor<[42], f32>
%tensor = bufferization.to_tensor %buffer
: memref<42xf32> to !test.test_tensor<[42], f32>

// CHECK: return %[[tensor]]
return %tensor : !test.test_tensor<[42], f32>
}
18 changes: 12 additions & 6 deletions mlir/test/lib/Dialect/Test/TestTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,11 +584,17 @@ TestTensorType::getBufferType(
::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
::mlir::bufferization::BufferLikeType bufferType,
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
auto testMemref = llvm::dyn_cast<TestMemrefType>(bufferType);
if (!testMemref)
return emitError() << "expected TestMemrefType";
if (auto testMemref = llvm::dyn_cast<TestMemrefType>(bufferType)) {
const bool valid = getShape() == testMemref.getShape() &&
getElementType() == testMemref.getElementType();
return mlir::success(valid);
}

if (auto builtinMemref = llvm::dyn_cast<MemRefType>(bufferType)) {
const bool valid = getShape() == builtinMemref.getShape() &&
getElementType() == builtinMemref.getElementType();
return mlir::success(valid);
}

const bool valid = getShape() == testMemref.getShape() &&
getElementType() == testMemref.getElementType();
return mlir::success(valid);
return emitError() << "expected MemRefType or TestMemrefType";
}
Loading