Skip to content

Commit 4a91333

Browse files
committed
[API/MemRef] Implement canonical stride validation for MemRefValue
Add optional stride validation in `MemRefValue::create` to compute canonical stride and compare against given strides while creaing a memref view from DLPack tensors. We need to handle special cases for zero-sized and unit-sized dimensions since frameworks deal with them arbitrarily while converting to the corresponding DLPack tensor. Add Python tests to verify both canonical and non-canonical stride validation.
1 parent 223cc67 commit 4a91333

File tree

6 files changed

+145
-27
lines changed

6 files changed

+145
-27
lines changed

mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ MLIR_CAPI_EXPORTED MTRT_Status
130130
mtrtMemRefCreate(MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
131131
int64_t bitsPerElement, int64_t rank, const int64_t *shape,
132132
const int64_t *strides, MTRT_Device device, MTRT_Stream stream,
133-
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result);
133+
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result,
134+
bool assertCanonicalStrides = false);
134135

135136
/// Creates an externally managed MemRef value. The caller provides all the
136137
/// metadata for the MemRef including the shape, strides (in elements), pointer,
@@ -142,7 +143,8 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtMemRefCreateExternal(
142143
MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
143144
int64_t bitsPerElement, uintptr_t ptr, int64_t offset, int64_t rank,
144145
const int64_t *shape, const int64_t *strides, MTRT_Device device,
145-
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result);
146+
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result,
147+
bool assertCanonicalStrides = false);
146148

147149
/// Destroys `MTRT_MemRefValue` in a potentially asynchronous manner.
148150
/// If `buffer` is a device buffer, device memory is freed in the stream

mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,8 @@ class MemRefValue : public RuntimeValue {
647647
int64_t bitsPerElement, uintptr_t ptr, int64_t offset,
648648
llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides,
649649
std::optional<const Device *> device,
650-
std::optional<ScalarType> scalarType);
650+
std::optional<ScalarType> scalarType,
651+
std::optional<bool> assertCanonicalStrides = {});
651652

652653
mlirtrt::runtime::PointerType getBufferKind() { return addressSpace; }
653654
int64_t getElementBitWidth() const { return bitsPerElement; }
@@ -917,15 +918,17 @@ class RuntimeClient {
917918
llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides,
918919
std::optional<const Device *> device = {},
919920
std::optional<CudaStream> stream = {},
920-
std::optional<ScalarType> scalarType = {});
921+
std::optional<ScalarType> scalarType = {},
922+
std::optional<bool> assertCanonicalStrides = {});
921923

922924
StatusOr<std::unique_ptr<MemRefValue>>
923925
createExternalMemRef(PointerType addressSpace, int64_t bitsPerElement,
924926
uintptr_t ptr, int64_t offset,
925927
llvm::ArrayRef<int64_t> shape,
926928
llvm::ArrayRef<int64_t> strides,
927929
std::optional<const Device *> device = {},
928-
std::optional<ScalarType> scalarType = {});
930+
std::optional<ScalarType> scalarType = {},
931+
std::optional<bool> assertCanonicalStrides = {});
929932

930933
/// Frees the memory in `value`. The `stream` may optionally be provided
931934
/// for resources that can be deallocated asynchronously.

mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,8 @@ MTRT_Status
231231
mtrtMemRefCreate(MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
232232
int64_t bitsPerElement, int64_t rank, const int64_t *shape,
233233
const int64_t *strides, MTRT_Device device, MTRT_Stream stream,
234-
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result) {
234+
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result,
235+
bool assertCanonicalStrides) {
235236
StatusOr<std::unique_ptr<MemRefValue>> bufferImpl =
236237
unwrap(client)->allocateMemRef(
237238
unwrap(pointerKind), bitsPerElement,
@@ -244,7 +245,8 @@ mtrtMemRefCreate(MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
244245
: std::optional(unwrap(stream)->getRawStream()),
245246
scalarType != MTRT_ScalarTypeCode::MTRT_ScalarTypeCode_unknown
246247
? std::optional(ScalarType(unwrap(scalarType)))
247-
: std::nullopt);
248+
: std::nullopt,
249+
std::optional(assertCanonicalStrides));
248250

249251
if (bufferImpl.isError())
250252
return wrap(bufferImpl.getStatus());
@@ -257,7 +259,8 @@ MTRT_Status mtrtMemRefCreateExternal(
257259
MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
258260
int64_t bitsPerElement, uintptr_t ptr, int64_t offset, int64_t rank,
259261
const int64_t *shape, const int64_t *strides, MTRT_Device device,
260-
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result) {
262+
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result,
263+
bool assertCanonicalStrides) {
261264
StatusOr<std::unique_ptr<MemRefValue>> bufferImpl =
262265
unwrap(client)->createExternalMemRef(
263266
unwrap(pointerKind), bitsPerElement, ptr, offset,
@@ -267,7 +270,8 @@ MTRT_Status mtrtMemRefCreateExternal(
267270
: std::optional(unwrap(device)),
268271
scalarType == MTRT_ScalarTypeCode_unknown
269272
? std::nullopt
270-
: std::optional(ScalarType(unwrap(scalarType))));
273+
: std::optional(ScalarType(unwrap(scalarType))),
274+
std::optional(assertCanonicalStrides));
271275

272276
if (bufferImpl.isError())
273277
return wrap(bufferImpl.getStatus());

mlir-tensorrt/executor/lib/Runtime/API/API.cpp

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -671,12 +671,50 @@ static StatusOr<int64_t> getFootprintInBytes(llvm::ArrayRef<int64_t> shape,
671671
return sizeBytes;
672672
}
673673

674+
static llvm::SmallVector<int64_t> getCanonicalStride(const llvm::ArrayRef<int64_t>& shape) {
675+
if (shape.empty())
676+
return {};
677+
678+
llvm::SmallVector<int64_t> canonicalStride(shape.size(), 1);
679+
int64_t cumulativeProduct = 1;
680+
681+
for (int64_t dimIndex = shape.size() - 1; dimIndex >= 0; --dimIndex) {
682+
bool isFirstZeroDim = (shape[dimIndex] == 0 && dimIndex != static_cast<int64_t>(shape.size()) - 1);
683+
// For dimensions with size 0 or 1, the stride can be arbitrary.
684+
// We set it to 1 here, but other values would also be valid.
685+
if (isFirstZeroDim || shape[dimIndex] == 1)
686+
canonicalStride[dimIndex] = 1;
687+
else
688+
canonicalStride[dimIndex] = cumulativeProduct;
689+
// For zero-sized dimensions (except the last one), we don't update the cumulative product
690+
// This allows for consistent handling of zero-sized dimensions across different frameworks
691+
cumulativeProduct *= isFirstZeroDim ? 1 : shape[dimIndex];
692+
}
693+
694+
return canonicalStride;
695+
}
696+
697+
static bool areStridesEquivalent(llvm::ArrayRef<int64_t> shape,
698+
llvm::ArrayRef<int64_t> stride,
699+
llvm::ArrayRef<int64_t> expectedStride) {
700+
if (shape.size() != stride.size() || shape.size() != expectedStride.size())
701+
return false;
702+
703+
for (size_t i = 0; i < shape.size(); ++i)
704+
// Allow arbitrary strides for dimensions with size 0 or 1
705+
// This accounts for discrepancies in how different frameworks handle these cases
706+
if (stride[i] != expectedStride[i] && shape[i] != 0 && shape[i] != 1)
707+
return false;
708+
709+
return true;
710+
}
711+
674712
StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create(
675713
RuntimeClient *client, mlirtrt::runtime::PointerType addressSpace,
676714
int64_t bitsPerElement, uintptr_t ptr, int64_t offset,
677715
llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides,
678-
std::optional<const Device *> device,
679-
std::optional<ScalarType> scalarType) {
716+
std::optional<const Device *> device, std::optional<ScalarType> scalarType,
717+
std::optional<bool> assertCanonicalStrides) {
680718
if (!client)
681719
return getInvalidArgStatus("a valid RuntimeClient must be provided to "
682720
"create a tracked MemRef object");
@@ -691,6 +729,19 @@ StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create(
691729
return getInvalidArgStatus("a specific device must be provided for MemRefs "
692730
"that are device-visible");
693731

732+
// Check if given strides match canonical stride
733+
if (assertCanonicalStrides && *assertCanonicalStrides) {
734+
llvm::SmallVector<int64_t> canonicalStride = getCanonicalStride(shape);
735+
if (!strides.empty() &&
736+
!areStridesEquivalent(shape, strides, canonicalStride)) {
737+
std::string errorMsg =
738+
llvm::formatv("Given strides [{0}] do not match canonical strides "
739+
"[{1}] for shape [{2}]",
740+
strides, canonicalStride, shape);
741+
return getInvalidArgStatus(errorMsg.c_str());
742+
}
743+
}
744+
694745
return std::unique_ptr<MemRefValue>(
695746
new MemRefValue(client, addressSpace, bitsPerElement, ptr, offset, shape,
696747
strides, device, scalarType));
@@ -777,7 +828,7 @@ StatusOr<std::unique_ptr<MemRefValue>> RuntimeClient::allocateMemRef(
777828
PointerType addressSpace, int64_t bitsPerElement,
778829
llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides,
779830
std::optional<const Device *> device, std::optional<CudaStream> stream,
780-
std::optional<ScalarType> scalarType) {
831+
std::optional<ScalarType> scalarType, std::optional<bool> assertCanonicalStrides) {
781832
if (addressSpace == PointerType::device ||
782833
addressSpace == PointerType::unified) {
783834
if (!device || !*device)
@@ -800,7 +851,7 @@ StatusOr<std::unique_ptr<MemRefValue>> RuntimeClient::allocateMemRef(
800851
// Create the descriptor.
801852
StatusOr<std::unique_ptr<MemRefValue>> bufferImpl =
802853
MemRefValue::create(this, addressSpace, bitsPerElement, allocation->ptr,
803-
0, shape, strides, device, scalarType);
854+
0, shape, strides, device, scalarType, assertCanonicalStrides);
804855
if (bufferImpl.isError())
805856
return bufferImpl.getStatus();
806857

@@ -811,11 +862,11 @@ StatusOr<std::unique_ptr<MemRefValue>> RuntimeClient::createExternalMemRef(
811862
PointerType addressSpace, int64_t bitsPerElement, uintptr_t ptr,
812863
int64_t offset, llvm::ArrayRef<int64_t> shape,
813864
llvm::ArrayRef<int64_t> strides, std::optional<const Device *> device,
814-
std::optional<ScalarType> scalarType) {
865+
std::optional<ScalarType> scalarType, std::optional<bool> assertCanonicalStrides) {
815866
// Create the descriptor.
816867
StatusOr<std::unique_ptr<MemRefValue>> memref =
817868
MemRefValue::create(this, addressSpace, bitsPerElement, ptr, offset,
818-
shape, strides, device, scalarType);
869+
shape, strides, device, scalarType, assertCanonicalStrides);
819870
if (!memref.isOk())
820871
return memref.getStatus();
821872

mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,8 @@ static std::unique_ptr<PyMemRefValue> createMemRef(
313313
}
314314

315315
static std::unique_ptr<PyMemRefValue>
316-
createMemRefViewFromDLPack(PyRuntimeClient &client, py::capsule capsule) {
316+
createMemRefViewFromDLPack(PyRuntimeClient &client, py::capsule capsule,
317+
std::optional<bool> assertCanonicalStrides) {
317318
DLManagedTensor *managedTensor = static_cast<DLManagedTensor *>(
318319
PyCapsule_GetPointer(capsule.ptr(), "dltensor"));
319320

@@ -368,14 +369,16 @@ createMemRefViewFromDLPack(PyRuntimeClient &client, py::capsule capsule) {
368369
}
369370

370371
if (data) {
371-
s = mtrtMemRefCreateExternal(client, addressSpace, bytesPerElement * 8,
372-
reinterpret_cast<uintptr_t>(data), offset,
373-
rank, shape, strides, device, elementType,
374-
&result);
372+
s = mtrtMemRefCreateExternal(
373+
client, addressSpace, bytesPerElement * 8,
374+
reinterpret_cast<uintptr_t>(data), offset, rank, shape, strides, device,
375+
elementType, &result,
376+
assertCanonicalStrides ? *assertCanonicalStrides : false);
375377
} else {
376-
s = mtrtMemRefCreate(client, addressSpace, bytesPerElement * 8, rank, shape,
377-
strides, device, mtrtStreamGetNull(), elementType,
378-
&result);
378+
s = mtrtMemRefCreate(
379+
client, addressSpace, bytesPerElement * 8, rank, shape, strides, device,
380+
mtrtStreamGetNull(), elementType, &result,
381+
assertCanonicalStrides ? *assertCanonicalStrides : false);
379382
}
380383

381384
THROW_IF_MTRT_ERROR(s);
@@ -788,11 +791,15 @@ PYBIND11_MODULE(_api, m) {
788791
"returns a new memref and allocates uninitialized backing storage")
789792
.def(
790793
"create_memref_view_from_dlpack",
791-
[](PyRuntimeClient &self, py::capsule capsule) {
792-
return createMemRefViewFromDLPack(self, capsule).release();
794+
[](PyRuntimeClient &self, py::capsule capsule,
795+
std::optional<bool> assertCanonicalStrides) {
796+
return createMemRefViewFromDLPack(self, capsule,
797+
assertCanonicalStrides)
798+
.release();
793799
},
794-
py::arg("dltensor") = py::none(), py::keep_alive<0, 1>(),
795-
py::keep_alive<0, 2>())
800+
py::arg("dltensor") = py::none(),
801+
py::arg("assert_canonical_strides") = py::none(),
802+
py::keep_alive<0, 1>(), py::keep_alive<0, 2>())
796803
.def(
797804
"create_device_memref_view",
798805
[](PyRuntimeClient &self, uintptr_t ptr, std::vector<int64_t> shape,

mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_create_memref.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,3 +514,54 @@ def create_dangling_memref():
514514
# CHECK-LABEL: Test memref maintains data's lifetime
515515
# CHECK-NEXT: -- Inner scope: np.from_dlpack(): [1 2]
516516
# CHECK-NEXT: -- Outer scope: np.from_dlpack(): [1 2]
517+
518+
519+
def check_non_canonical_stride(client, assert_canonical_strides):
520+
try:
521+
t = cp.arange(12, dtype=cp.float32).reshape(3, 4)
522+
a = cp.transpose(t)
523+
memref = client.create_memref_view_from_dlpack(
524+
a.__dlpack__(), assert_canonical_strides
525+
)
526+
except Exception as e:
527+
print(f"Received error message: {str(e)}")
528+
529+
530+
def check_canonical_stride(client, assert_canonical_strides):
531+
try:
532+
t = cp.arange(12, dtype=cp.float32).reshape(3, 4)
533+
memref = client.create_memref_view_from_dlpack(
534+
t.__dlpack__(), assert_canonical_strides
535+
)
536+
except Exception as e:
537+
print(f"Received error message: {str(e)}")
538+
539+
540+
def test_memref_strides():
541+
print("Testing non-canonical stride: assert_canonical_strides = True")
542+
non_canonical_result = check_non_canonical_stride(
543+
client, assert_canonical_strides=True
544+
)
545+
546+
print("Testing non-canonical stride: assert_canonical_strides = False")
547+
non_canonical_result = check_non_canonical_stride(
548+
client, assert_canonical_strides=False
549+
)
550+
551+
print("Testing canonical stride: assert_canonical_strides = True")
552+
canonical_result = check_canonical_stride(client, assert_canonical_strides=True)
553+
554+
print("Testing canonical stride: assert_canonical_strides = False")
555+
canonical_result = check_canonical_stride(client, assert_canonical_strides=False)
556+
557+
558+
print("Test memref strides")
559+
test_memref_strides()
560+
561+
# CHECK-LABEL: Test memref strides
562+
# CHECK-NEXT: Testing non-canonical stride: assert_canonical_strides = True
563+
# CHECK-NEXT: Received error message: InvalidArgument: InvalidArgument:
564+
# CHECK-SAME: Given strides [1, 4] do not match canonical strides [3, 1] for shape [4, 3]
565+
# CHECK-NEXT: Testing non-canonical stride: assert_canonical_strides = False
566+
# CHECK-NEXT: Testing canonical stride: assert_canonical_strides = True
567+
# CHECK-NEXT: Testing canonical stride: assert_canonical_strides = False

0 commit comments

Comments
 (0)