Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ MLIR_CAPI_EXPORTED MTRT_Status
mtrtMemRefCreate(MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
int64_t bitsPerElement, int64_t rank, const int64_t *shape,
const int64_t *strides, MTRT_Device device, MTRT_Stream stream,
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result);
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result,
bool assertCanonicalStrides = false);

/// Creates an externally managed MemRef value. The caller provides all the
/// metadata for the MemRef including the shape, strides (in elements), pointer,
Expand All @@ -142,7 +143,8 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtMemRefCreateExternal(
MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
int64_t bitsPerElement, uintptr_t ptr, int64_t offset, int64_t rank,
const int64_t *shape, const int64_t *strides, MTRT_Device device,
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result);
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result,
bool assertCanonicalStrides = false);

/// Destroys `MTRT_MemRefValue` in a potentially asynchronous manner.
/// If `buffer` is a device buffer, device memory is freed in the stream
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,8 @@ class MemRefValue : public RuntimeValue {
int64_t bitsPerElement, uintptr_t ptr, int64_t offset,
llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides,
std::optional<const Device *> device,
std::optional<ScalarType> scalarType);
std::optional<ScalarType> scalarType,
std::optional<bool> assertCanonicalStrides = {});

mlirtrt::runtime::PointerType getBufferKind() { return addressSpace; }
int64_t getElementBitWidth() const { return bitsPerElement; }
Expand Down Expand Up @@ -917,15 +918,17 @@ class RuntimeClient {
llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides,
std::optional<const Device *> device = {},
std::optional<CudaStream> stream = {},
std::optional<ScalarType> scalarType = {});
std::optional<ScalarType> scalarType = {},
std::optional<bool> assertCanonicalStrides = {});

StatusOr<std::unique_ptr<MemRefValue>>
createExternalMemRef(PointerType addressSpace, int64_t bitsPerElement,
uintptr_t ptr, int64_t offset,
llvm::ArrayRef<int64_t> shape,
llvm::ArrayRef<int64_t> strides,
std::optional<const Device *> device = {},
std::optional<ScalarType> scalarType = {});
std::optional<ScalarType> scalarType = {},
std::optional<bool> assertCanonicalStrides = {});

/// Frees the memory in `value`. The `stream` may optionally be provided
/// for resources that can be deallocated asynchronously.
Expand Down
12 changes: 8 additions & 4 deletions mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ MTRT_Status
mtrtMemRefCreate(MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
int64_t bitsPerElement, int64_t rank, const int64_t *shape,
const int64_t *strides, MTRT_Device device, MTRT_Stream stream,
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result) {
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result,
bool assertCanonicalStrides) {
StatusOr<std::unique_ptr<MemRefValue>> bufferImpl =
unwrap(client)->allocateMemRef(
unwrap(pointerKind), bitsPerElement,
Expand All @@ -244,7 +245,8 @@ mtrtMemRefCreate(MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
: std::optional(unwrap(stream)->getRawStream()),
scalarType != MTRT_ScalarTypeCode::MTRT_ScalarTypeCode_unknown
? std::optional(ScalarType(unwrap(scalarType)))
: std::nullopt);
: std::nullopt,
std::optional(assertCanonicalStrides));

if (bufferImpl.isError())
return wrap(bufferImpl.getStatus());
Expand All @@ -257,7 +259,8 @@ MTRT_Status mtrtMemRefCreateExternal(
MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
int64_t bitsPerElement, uintptr_t ptr, int64_t offset, int64_t rank,
const int64_t *shape, const int64_t *strides, MTRT_Device device,
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result) {
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result,
bool assertCanonicalStrides) {
StatusOr<std::unique_ptr<MemRefValue>> bufferImpl =
unwrap(client)->createExternalMemRef(
unwrap(pointerKind), bitsPerElement, ptr, offset,
Expand All @@ -267,7 +270,8 @@ MTRT_Status mtrtMemRefCreateExternal(
: std::optional(unwrap(device)),
scalarType == MTRT_ScalarTypeCode_unknown
? std::nullopt
: std::optional(ScalarType(unwrap(scalarType))));
: std::optional(ScalarType(unwrap(scalarType))),
std::optional(assertCanonicalStrides));

if (bufferImpl.isError())
return wrap(bufferImpl.getStatus());
Expand Down
63 changes: 57 additions & 6 deletions mlir-tensorrt/executor/lib/Runtime/API/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,12 +671,50 @@ static StatusOr<int64_t> getFootprintInBytes(llvm::ArrayRef<int64_t> shape,
return sizeBytes;
}

static llvm::SmallVector<int64_t> getCanonicalStride(const llvm::ArrayRef<int64_t>& shape) {
if (shape.empty())
return {};

llvm::SmallVector<int64_t> canonicalStride(shape.size(), 1);
int64_t cumulativeProduct = 1;

for (int64_t dimIndex = shape.size() - 1; dimIndex >= 0; --dimIndex) {
bool isFirstZeroDim = (shape[dimIndex] == 0 && dimIndex != static_cast<int64_t>(shape.size()) - 1);
// For dimensions with size 0 or 1, the stride can be arbitrary.
// We set it to 1 here, but other values would also be valid.
if (isFirstZeroDim || shape[dimIndex] == 1)
canonicalStride[dimIndex] = 1;
else
canonicalStride[dimIndex] = cumulativeProduct;
// For zero-sized dimensions (except the last one), we don't update the cumulative product
// This allows for consistent handling of zero-sized dimensions across different frameworks
cumulativeProduct *= isFirstZeroDim ? 1 : shape[dimIndex];
}

return canonicalStride;
}

static bool areStridesEquivalent(llvm::ArrayRef<int64_t> shape,
llvm::ArrayRef<int64_t> stride,
llvm::ArrayRef<int64_t> expectedStride) {
if (shape.size() != stride.size() || shape.size() != expectedStride.size())
return false;

for (size_t i = 0; i < shape.size(); ++i)
// Allow arbitrary strides for dimensions with size 0 or 1
// This accounts for discrepancies in how different frameworks handle these cases
if (stride[i] != expectedStride[i] && shape[i] != 0 && shape[i] != 1)
return false;

return true;
}

StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create(
RuntimeClient *client, mlirtrt::runtime::PointerType addressSpace,
int64_t bitsPerElement, uintptr_t ptr, int64_t offset,
llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides,
std::optional<const Device *> device,
std::optional<ScalarType> scalarType) {
std::optional<const Device *> device, std::optional<ScalarType> scalarType,
std::optional<bool> assertCanonicalStrides) {
if (!client)
return getInvalidArgStatus("a valid RuntimeClient must be provided to "
"create a tracked MemRef object");
Expand All @@ -691,6 +729,19 @@ StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create(
return getInvalidArgStatus("a specific device must be provided for MemRefs "
"that are device-visible");

// Check if given strides match canonical stride
if (assertCanonicalStrides && *assertCanonicalStrides) {
llvm::SmallVector<int64_t> canonicalStride = getCanonicalStride(shape);
if (!strides.empty() &&
!areStridesEquivalent(shape, strides, canonicalStride)) {
std::string errorMsg =
llvm::formatv("Given strides [{0}] do not match canonical strides "
"[{1}] for shape [{2}]",
strides, canonicalStride, shape);
return getInvalidArgStatus(errorMsg.c_str());
}
}

return std::unique_ptr<MemRefValue>(
new MemRefValue(client, addressSpace, bitsPerElement, ptr, offset, shape,
strides, device, scalarType));
Expand Down Expand Up @@ -777,7 +828,7 @@ StatusOr<std::unique_ptr<MemRefValue>> RuntimeClient::allocateMemRef(
PointerType addressSpace, int64_t bitsPerElement,
llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides,
std::optional<const Device *> device, std::optional<CudaStream> stream,
std::optional<ScalarType> scalarType) {
std::optional<ScalarType> scalarType, std::optional<bool> assertCanonicalStrides) {
if (addressSpace == PointerType::device ||
addressSpace == PointerType::unified) {
if (!device || !*device)
Expand All @@ -800,7 +851,7 @@ StatusOr<std::unique_ptr<MemRefValue>> RuntimeClient::allocateMemRef(
// Create the descriptor.
StatusOr<std::unique_ptr<MemRefValue>> bufferImpl =
MemRefValue::create(this, addressSpace, bitsPerElement, allocation->ptr,
0, shape, strides, device, scalarType);
0, shape, strides, device, scalarType, assertCanonicalStrides);
if (bufferImpl.isError())
return bufferImpl.getStatus();

Expand All @@ -811,11 +862,11 @@ StatusOr<std::unique_ptr<MemRefValue>> RuntimeClient::createExternalMemRef(
PointerType addressSpace, int64_t bitsPerElement, uintptr_t ptr,
int64_t offset, llvm::ArrayRef<int64_t> shape,
llvm::ArrayRef<int64_t> strides, std::optional<const Device *> device,
std::optional<ScalarType> scalarType) {
std::optional<ScalarType> scalarType, std::optional<bool> assertCanonicalStrides) {
// Create the descriptor.
StatusOr<std::unique_ptr<MemRefValue>> memref =
MemRefValue::create(this, addressSpace, bitsPerElement, ptr, offset,
shape, strides, device, scalarType);
shape, strides, device, scalarType, assertCanonicalStrides);
if (!memref.isOk())
return memref.getStatus();

Expand Down
31 changes: 19 additions & 12 deletions mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,8 @@ static std::unique_ptr<PyMemRefValue> createMemRef(
}

static std::unique_ptr<PyMemRefValue>
createMemRefViewFromDLPack(PyRuntimeClient &client, py::capsule capsule) {
createMemRefViewFromDLPack(PyRuntimeClient &client, py::capsule capsule,
std::optional<bool> assertCanonicalStrides) {
DLManagedTensor *managedTensor = static_cast<DLManagedTensor *>(
PyCapsule_GetPointer(capsule.ptr(), "dltensor"));

Expand Down Expand Up @@ -368,14 +369,16 @@ createMemRefViewFromDLPack(PyRuntimeClient &client, py::capsule capsule) {
}

if (data) {
s = mtrtMemRefCreateExternal(client, addressSpace, bytesPerElement * 8,
reinterpret_cast<uintptr_t>(data), offset,
rank, shape, strides, device, elementType,
&result);
s = mtrtMemRefCreateExternal(
client, addressSpace, bytesPerElement * 8,
reinterpret_cast<uintptr_t>(data), offset, rank, shape, strides, device,
elementType, &result,
assertCanonicalStrides ? *assertCanonicalStrides : false);
} else {
s = mtrtMemRefCreate(client, addressSpace, bytesPerElement * 8, rank, shape,
strides, device, mtrtStreamGetNull(), elementType,
&result);
s = mtrtMemRefCreate(
client, addressSpace, bytesPerElement * 8, rank, shape, strides, device,
mtrtStreamGetNull(), elementType, &result,
assertCanonicalStrides ? *assertCanonicalStrides : false);
}

THROW_IF_MTRT_ERROR(s);
Expand Down Expand Up @@ -788,11 +791,15 @@ PYBIND11_MODULE(_api, m) {
"returns a new memref and allocates uninitialized backing storage")
.def(
"create_memref_view_from_dlpack",
[](PyRuntimeClient &self, py::capsule capsule) {
return createMemRefViewFromDLPack(self, capsule).release();
[](PyRuntimeClient &self, py::capsule capsule,
std::optional<bool> assertCanonicalStrides) {
return createMemRefViewFromDLPack(self, capsule,
assertCanonicalStrides)
.release();
},
py::arg("dltensor") = py::none(), py::keep_alive<0, 1>(),
py::keep_alive<0, 2>())
py::arg("dltensor") = py::none(),
py::arg("assert_canonical_strides") = py::none(),
py::keep_alive<0, 1>(), py::keep_alive<0, 2>())
.def(
"create_device_memref_view",
[](PyRuntimeClient &self, uintptr_t ptr, std::vector<int64_t> shape,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,3 +514,54 @@ def create_dangling_memref():
# CHECK-LABEL: Test memref maintains data's lifetime
# CHECK-NEXT: -- Inner scope: np.from_dlpack(): [1 2]
# CHECK-NEXT: -- Outer scope: np.from_dlpack(): [1 2]


def check_non_canonical_stride(client, assert_canonical_strides):
try:
t = cp.arange(12, dtype=cp.float32).reshape(3, 4)
a = cp.transpose(t)
memref = client.create_memref_view_from_dlpack(
a.__dlpack__(), assert_canonical_strides
)
except Exception as e:
print(f"Received error message: {str(e)}")


def check_canonical_stride(client, assert_canonical_strides):
try:
t = cp.arange(12, dtype=cp.float32).reshape(3, 4)
memref = client.create_memref_view_from_dlpack(
t.__dlpack__(), assert_canonical_strides
)
except Exception as e:
print(f"Received error message: {str(e)}")


def test_memref_strides():
print("Testing non-canonical stride: assert_canonical_strides = True")
non_canonical_result = check_non_canonical_stride(
client, assert_canonical_strides=True
)

print("Testing non-canonical stride: assert_canonical_strides = False")
non_canonical_result = check_non_canonical_stride(
client, assert_canonical_strides=False
)

print("Testing canonical stride: assert_canonical_strides = True")
canonical_result = check_canonical_stride(client, assert_canonical_strides=True)

print("Testing canonical stride: assert_canonical_strides = False")
canonical_result = check_canonical_stride(client, assert_canonical_strides=False)


print("Test memref strides")
test_memref_strides()

# CHECK-LABEL: Test memref strides
# CHECK-NEXT: Testing non-canonical stride: assert_canonical_strides = True
# CHECK-NEXT: Received error message: InvalidArgument: InvalidArgument:
# CHECK-SAME: Given strides [1, 4] do not match canonical strides [3, 1] for shape [4, 3]
# CHECK-NEXT: Testing non-canonical stride: assert_canonical_strides = False
# CHECK-NEXT: Testing canonical stride: assert_canonical_strides = True
# CHECK-NEXT: Testing canonical stride: assert_canonical_strides = False