Skip to content

Commit 6204f3a

Browse files
committed
[API/MemRef] Implement canonical stride validation for MemRefValue
Add stride validation in MemRefValue::create: compute the canonical strides for the given shape and compare them against the given strides. Zero-sized and unit-sized dimensions need special handling, since frameworks assign them arbitrary strides when converting to the corresponding DLPack tensor. Add Python tests that verify both canonical and non-canonical stride validation.
1 parent 0c73646 commit 6204f3a

File tree

6 files changed

+133
-17
lines changed

6 files changed

+133
-17
lines changed

mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ MLIR_CAPI_EXPORTED MTRT_Status
130130
mtrtMemRefCreate(MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
131131
int64_t bitsPerElement, int64_t rank, const int64_t *shape,
132132
const int64_t *strides, MTRT_Device device, MTRT_Stream stream,
133-
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result);
133+
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result,
134+
bool assertCanonicalStrides = false);
134135

135136
/// Creates an externally managed MemRef value. The caller provides all the
136137
/// metadata for the MemRef including the shape, strides (in elements), pointer,
@@ -142,7 +143,8 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtMemRefCreateExternal(
142143
MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
143144
int64_t bitsPerElement, uintptr_t ptr, int64_t offset, int64_t rank,
144145
const int64_t *shape, const int64_t *strides, MTRT_Device device,
145-
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result);
146+
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result,
147+
bool assertCanonicalStrides = false);
146148

147149
/// Destroys `MTRT_MemRefValue` in a potentially asynchronous manner.
148150
/// If `buffer` is a device buffer, device memory is freed in the stream

mlir-tensorrt/executor/include/mlir-executor/Runtime/API/API.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,8 @@ class MemRefValue : public RuntimeValue {
647647
int64_t bitsPerElement, uintptr_t ptr, int64_t offset,
648648
llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides,
649649
std::optional<const Device *> device,
650-
std::optional<ScalarType> scalarType);
650+
std::optional<ScalarType> scalarType,
651+
std::optional<bool> assertCanonicalStrides = false);
651652

652653
mlirtrt::runtime::PointerType getBufferKind() { return addressSpace; }
653654
int64_t getElementBitWidth() const { return bitsPerElement; }
@@ -917,15 +918,17 @@ class RuntimeClient {
917918
llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides,
918919
std::optional<const Device *> device = {},
919920
std::optional<CudaStream> stream = {},
920-
std::optional<ScalarType> scalarType = {});
921+
std::optional<ScalarType> scalarType = {},
922+
std::optional<bool> assertCanonicalStrides = {});
921923

922924
StatusOr<std::unique_ptr<MemRefValue>>
923925
createExternalMemRef(PointerType addressSpace, int64_t bitsPerElement,
924926
uintptr_t ptr, int64_t offset,
925927
llvm::ArrayRef<int64_t> shape,
926928
llvm::ArrayRef<int64_t> strides,
927929
std::optional<const Device *> device = {},
928-
std::optional<ScalarType> scalarType = {});
930+
std::optional<ScalarType> scalarType = {},
931+
std::optional<bool> assertCanonicalStrides = {});
929932

930933
/// Frees the memory in `value`. The `stream` may optionally be provided
931934
/// for resources that can be deallocated asynchronously.

mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,8 @@ MTRT_Status
231231
mtrtMemRefCreate(MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
232232
int64_t bitsPerElement, int64_t rank, const int64_t *shape,
233233
const int64_t *strides, MTRT_Device device, MTRT_Stream stream,
234-
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result) {
234+
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result,
235+
bool assertCanonicalStrides) {
235236
StatusOr<std::unique_ptr<MemRefValue>> bufferImpl =
236237
unwrap(client)->allocateMemRef(
237238
unwrap(pointerKind), bitsPerElement,
@@ -244,7 +245,8 @@ mtrtMemRefCreate(MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
244245
: std::optional(unwrap(stream)->getRawStream()),
245246
scalarType != MTRT_ScalarTypeCode::MTRT_ScalarTypeCode_unknown
246247
? std::optional(ScalarType(unwrap(scalarType)))
247-
: std::nullopt);
248+
: std::nullopt,
249+
std::optional(assertCanonicalStrides));
248250

249251
if (bufferImpl.isError())
250252
return wrap(bufferImpl.getStatus());
@@ -257,7 +259,8 @@ MTRT_Status mtrtMemRefCreateExternal(
257259
MTRT_RuntimeClient client, MTRT_PointerType pointerKind,
258260
int64_t bitsPerElement, uintptr_t ptr, int64_t offset, int64_t rank,
259261
const int64_t *shape, const int64_t *strides, MTRT_Device device,
260-
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result) {
262+
MTRT_ScalarTypeCode scalarType, MTRT_MemRefValue *result,
263+
bool assertCanonicalStrides) {
261264
StatusOr<std::unique_ptr<MemRefValue>> bufferImpl =
262265
unwrap(client)->createExternalMemRef(
263266
unwrap(pointerKind), bitsPerElement, ptr, offset,
@@ -267,7 +270,8 @@ MTRT_Status mtrtMemRefCreateExternal(
267270
: std::optional(unwrap(device)),
268271
scalarType == MTRT_ScalarTypeCode_unknown
269272
? std::nullopt
270-
: std::optional(ScalarType(unwrap(scalarType))));
273+
: std::optional(ScalarType(unwrap(scalarType))),
274+
std::optional(assertCanonicalStrides));
271275

272276
if (bufferImpl.isError())
273277
return wrap(bufferImpl.getStatus());

mlir-tensorrt/executor/lib/Runtime/API/API.cpp

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -671,12 +671,50 @@ static StatusOr<int64_t> getFootprintInBytes(llvm::ArrayRef<int64_t> shape,
671671
return sizeBytes;
672672
}
673673

674+
/// Builds the reference ("canonical") stride vector, in elements, for a
/// densely packed row-major layout of `shape`.
///
/// Dimensions are visited from innermost to outermost while a running product
/// of the extents seen so far is maintained:
///  - a dimension of extent 1 always gets stride 1;
///  - a zero-extent dimension that is not the innermost one gets stride 1 and
///    leaves the running product untouched;
///  - every other dimension gets the current running product as its stride
///    and then multiplies its extent into the product.
///
/// Note that a zero-extent *innermost* dimension does multiply the running
/// product by zero, so the outer dimensions that take their stride from the
/// product receive stride 0 in that case.
///
/// Returns an empty vector for a rank-0 shape.
static std::vector<int64_t> getCanonicalStride(const llvm::ArrayRef<int64_t>& shape) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  std::vector<int64_t> result(shape.size(), 1);

  int64_t runningExtent = 1;
  for (int64_t dim = rank - 1; dim >= 0; --dim) {
    const int64_t extent = shape[dim];
    // Zero-extent dimensions are skipped entirely unless they are innermost.
    const bool skippedZeroDim = (extent == 0) && (dim != rank - 1);
    if (skippedZeroDim)
      continue;

    // Unit dimensions keep the default stride of 1; the stride of such a
    // dimension is never used to step between distinct elements.
    if (extent != 1)
      result[dim] = runningExtent;
    runningExtent *= extent;
  }

  return result;
}
696+
697+
/// Returns true if `stride` describes the same memory layout as
/// `expectedStride` for a buffer of the given `shape`.
///
/// All three arrays must have the same rank; otherwise the result is false.
///
/// Strides are only compared where they can influence which element is
/// addressed:
///  - If any dimension has extent 0 the buffer contains no elements at all,
///    so every stride assignment is accepted. Producers disagree on what to
///    report here (e.g. plain suffix products yield 0 for the dimensions
///    outside a zero extent, while others clamp each extent to at least 1),
///    so comparing the remaining dimensions would reject valid empty buffers.
///  - A dimension of extent 1 is never stepped along, so its stride is
///    ignored.
static bool areStridesEquivalent(llvm::ArrayRef<int64_t> shape,
                                 llvm::ArrayRef<int64_t> stride,
                                 llvm::ArrayRef<int64_t> expectedStride) {
  if (shape.size() != stride.size() || shape.size() != expectedStride.size())
    return false;

  // An empty buffer has no addressable element; any strides are equivalent.
  for (int64_t extent : shape)
    if (extent == 0)
      return true;

  for (size_t i = 0; i < shape.size(); ++i) {
    // Unit-sized dimensions may carry an arbitrary stride.
    if (shape[i] == 1)
      continue;
    if (stride[i] != expectedStride[i])
      return false;
  }

  return true;
}
711+
674712
StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create(
675713
RuntimeClient *client, mlirtrt::runtime::PointerType addressSpace,
676714
int64_t bitsPerElement, uintptr_t ptr, int64_t offset,
677715
llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides,
678-
std::optional<const Device *> device,
679-
std::optional<ScalarType> scalarType) {
716+
std::optional<const Device *> device, std::optional<ScalarType> scalarType,
717+
std::optional<bool> assertCanonicalStrides) {
680718
if (!client)
681719
return getInvalidArgStatus("a valid RuntimeClient must be provided to "
682720
"create a tracked MemRef object");
@@ -691,6 +729,19 @@ StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create(
691729
return getInvalidArgStatus("a specific device must be provided for MemRefs "
692730
"that are device-visible");
693731

732+
// Check if given strides match canonical stride
733+
if (assertCanonicalStrides && *assertCanonicalStrides) {
734+
std::vector<int64_t> canonicalStride = getCanonicalStride(shape);
735+
if (!strides.empty() &&
736+
!areStridesEquivalent(shape, strides, canonicalStride)) {
737+
std::string errorMsg =
738+
llvm::formatv("Given strides [{0}] do not match canonical strides "
739+
"[{1}] for shape [{2}]",
740+
strides, llvm::ArrayRef(canonicalStride), shape);
741+
return getInvalidArgStatus(errorMsg.c_str());
742+
}
743+
}
744+
694745
return std::unique_ptr<MemRefValue>(
695746
new MemRefValue(client, addressSpace, bitsPerElement, ptr, offset, shape,
696747
strides, device, scalarType));
@@ -777,7 +828,7 @@ StatusOr<std::unique_ptr<MemRefValue>> RuntimeClient::allocateMemRef(
777828
PointerType addressSpace, int64_t bitsPerElement,
778829
llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides,
779830
std::optional<const Device *> device, std::optional<CudaStream> stream,
780-
std::optional<ScalarType> scalarType) {
831+
std::optional<ScalarType> scalarType, std::optional<bool> assertCanonicalStrides) {
781832
if (addressSpace == PointerType::device ||
782833
addressSpace == PointerType::unified) {
783834
if (!device || !*device)
@@ -800,7 +851,7 @@ StatusOr<std::unique_ptr<MemRefValue>> RuntimeClient::allocateMemRef(
800851
// Create the descriptor.
801852
StatusOr<std::unique_ptr<MemRefValue>> bufferImpl =
802853
MemRefValue::create(this, addressSpace, bitsPerElement, allocation->ptr,
803-
0, shape, strides, device, scalarType);
854+
0, shape, strides, device, scalarType, assertCanonicalStrides);
804855
if (bufferImpl.isError())
805856
return bufferImpl.getStatus();
806857

@@ -811,11 +862,11 @@ StatusOr<std::unique_ptr<MemRefValue>> RuntimeClient::createExternalMemRef(
811862
PointerType addressSpace, int64_t bitsPerElement, uintptr_t ptr,
812863
int64_t offset, llvm::ArrayRef<int64_t> shape,
813864
llvm::ArrayRef<int64_t> strides, std::optional<const Device *> device,
814-
std::optional<ScalarType> scalarType) {
865+
std::optional<ScalarType> scalarType, std::optional<bool> assertCanonicalStrides) {
815866
// Create the descriptor.
816867
StatusOr<std::unique_ptr<MemRefValue>> memref =
817868
MemRefValue::create(this, addressSpace, bitsPerElement, ptr, offset,
818-
shape, strides, device, scalarType);
869+
shape, strides, device, scalarType, assertCanonicalStrides);
819870
if (!memref.isOk())
820871
return memref.getStatus();
821872

mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,11 @@ createMemRefViewFromDLPack(PyRuntimeClient &client, py::capsule capsule) {
371371
s = mtrtMemRefCreateExternal(client, addressSpace, bytesPerElement * 8,
372372
reinterpret_cast<uintptr_t>(data), offset,
373373
rank, shape, strides, device, elementType,
374-
&result);
374+
&result, true /*assertCanonicalStrides*/);
375375
} else {
376376
s = mtrtMemRefCreate(client, addressSpace, bytesPerElement * 8, rank, shape,
377377
strides, device, mtrtStreamGetNull(), elementType,
378-
&result);
378+
&result, true /*assertCanonicalStrides*/);
379379
}
380380

381381
THROW_IF_MTRT_ERROR(s);

mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_create_memref.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,3 +514,59 @@ def create_dangling_memref():
514514
# CHECK-LABEL: Test memref maintains data's lifetime
515515
# CHECK-NEXT: -- Inner scope: np.from_dlpack(): [1 2]
516516
# CHECK-NEXT: -- Outer scope: np.from_dlpack(): [1 2]
517+
518+
519+
def check_non_canonical_stride(client):
    """Check that a DLPack view with non-canonical strides is rejected.

    A (3, 4) CuPy array is transposed, giving shape (4, 3) with strides
    (1, 4), and handed to `client.create_memref_view_from_dlpack`. The call
    is expected to raise, and the exception text must contain the stride
    mismatch message produced by the runtime.

    Returns True when the expected error is observed, False otherwise.
    """
    # Must match the message formatted in MemRefValue::create.
    expected_message = (
        "Given strides [1, 4] do not match canonical strides [3, 1] "
        "for shape [4, 3]"
    )
    try:
        t = cp.arange(12, dtype=cp.float32).reshape(3, 4)
        a = cp.transpose(t)
        memref = client.create_memref_view_from_dlpack(a.__dlpack__())
        print(
            "Test failed: Expected exception was not thrown for CuPy non-canonical stride"
        )
        return False
    except Exception as e:
        error_message = str(e)
        if expected_message in error_message:
            print(
                "Test passed: Correct error message received for non-canonical stride"
            )
            return True
        else:
            print(
                f"Test failed: Unexpected error message for non-canonical stride: {error_message}"
            )
            return False


def check_canonical_stride(client):
    """Check that a DLPack view with canonical strides is accepted.

    A contiguous (3, 4) CuPy array is handed to
    `client.create_memref_view_from_dlpack`; no exception is expected.

    Returns True when the view is created, False if any exception is raised.
    """
    try:
        t = cp.arange(12, dtype=cp.float32).reshape(3, 4)
        memref = client.create_memref_view_from_dlpack(t.__dlpack__())
        print("Test passed: No exception thrown for canonical stride")
        return True
    except Exception as e:
        print(f"Test failed: Unexpected exception for canonical stride: {str(e)}")
        return False


def test_memref_strides():
    """Run the non-canonical and canonical stride checks in order.

    The pass/fail outcome is reported through the printed lines, which the
    CHECK directives below verify.
    """
    print("Testing non-canonical stride:")
    non_canonical_result = check_non_canonical_stride(client)

    print("Testing canonical stride:")
    canonical_result = check_canonical_stride(client)


print("Test memref strides")
test_memref_strides()

# CHECK-LABEL: Test memref strides
# CHECK-NEXT: Testing non-canonical stride:
# CHECK-NEXT: Test passed: Correct error message received for non-canonical stride
# CHECK-NEXT: Testing canonical stride:
# CHECK-NEXT: Test passed: No exception thrown for canonical stride

0 commit comments

Comments
 (0)