Skip to content

Commit 238e993

Browse files
committed
[API/MemRef] Implement canonical stride validation for MemRefValue
Introduce stride validation in MemRefValue::create to ensure consistency and correctness of memory layouts. Add getCanonicalStride and areStridesEquivalent functions to compute and compare strides, handling special cases for zero-sized and unit-sized dimensions. Return InvalidArgument error for non-canonical strides. Add Python tests for verification. This change improves MemRefValue robustness and prevents issues with incorrect stride representations.
1 parent 59b9536 commit 238e993

File tree

2 files changed

+105
-0
lines changed

2 files changed

+105
-0
lines changed

mlir-tensorrt/executor/lib/Runtime/API/API.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,44 @@ static StatusOr<int64_t> getFootprintInBytes(llvm::ArrayRef<int64_t> shape,
671671
return sizeBytes;
672672
}
673673

674+
static std::vector<int64_t> getCanonicalStride(const llvm::ArrayRef<int64_t>& shape) {
675+
if (shape.empty())
676+
return {};
677+
678+
std::vector<int64_t> canonicalStride(shape.size(), 1);
679+
int64_t cumulativeProduct = 1;
680+
681+
for (int64_t dimIndex = shape.size() - 1; dimIndex >= 0; --dimIndex) {
682+
bool isFirstZeroDim = (shape[dimIndex] == 0 && dimIndex != static_cast<int64_t>(shape.size()) - 1);
683+
// For dimensions with size 0 or 1, the stride can be arbitrary.
684+
// We set it to 1 here, but other values would also be valid.
685+
if (isFirstZeroDim || shape[dimIndex] == 1)
686+
canonicalStride[dimIndex] = 1;
687+
else
688+
canonicalStride[dimIndex] = cumulativeProduct;
689+
// For zero-sized dimensions (except the last one), we don't update the cumulative product
690+
// This allows for consistent handling of zero-sized dimensions across different frameworks
691+
cumulativeProduct *= isFirstZeroDim ? 1 : shape[dimIndex];
692+
}
693+
694+
return canonicalStride;
695+
}
696+
697+
static bool areStridesEquivalent(llvm::ArrayRef<int64_t> shape,
698+
llvm::ArrayRef<int64_t> stride,
699+
llvm::ArrayRef<int64_t> expectedStride) {
700+
if (shape.size() != stride.size() || shape.size() != expectedStride.size())
701+
return false;
702+
703+
for (size_t i = 0; i < shape.size(); ++i)
704+
// Allow arbitrary strides for dimensions with size 0 or 1
705+
// This accounts for discrepancies in how different frameworks handle these cases
706+
if (stride[i] != expectedStride[i] && shape[i] != 0 && shape[i] != 1)
707+
return false;
708+
709+
return true;
710+
}
711+
674712
StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create(
675713
RuntimeClient *client, mlirtrt::runtime::PointerType addressSpace,
676714
int64_t bitsPerElement, uintptr_t ptr, int64_t offset,
@@ -691,6 +729,17 @@ StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create(
691729
return getInvalidArgStatus("a specific device must be provided for MemRefs "
692730
"that are device-visible");
693731

732+
// Check if given strides match canonical stride
733+
if (!strides.empty() && !shape.empty()) {
734+
std::vector<int64_t> canonicalStride = getCanonicalStride(shape);
735+
if (!areStridesEquivalent(shape, strides, canonicalStride)) {
736+
std::string errorMsg = llvm::formatv(
737+
"Given strides [{0}] do not match canonical strides [{1}] for shape [{2}]",
738+
strides, llvm::ArrayRef(canonicalStride), shape);
739+
return getInvalidArgStatus(errorMsg.c_str());
740+
}
741+
}
742+
694743
return std::unique_ptr<MemRefValue>(
695744
new MemRefValue(client, addressSpace, bitsPerElement, ptr, offset, shape,
696745
strides, device, scalarType));

mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_create_memref.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,3 +514,59 @@ def create_dangling_memref():
514514
# CHECK-LABEL: Test memref maintains data's lifetime
515515
# CHECK-NEXT: -- Inner scope: np.from_dlpack(): [1 2]
516516
# CHECK-NEXT: -- Outer scope: np.from_dlpack(): [1 2]
517+
518+
519+
def check_non_canonical_stride(client):
520+
try:
521+
t = cp.arange(12, dtype=cp.float32).reshape(3, 4)
522+
a = cp.transpose(t)
523+
memref = client.create_memref_view_from_dlpack(a.__dlpack__())
524+
print(
525+
"Test failed: Expected exception was not thrown for CuPy non-canonical stride"
526+
)
527+
return False
528+
except Exception as e:
529+
error_message = str(e)
530+
if (
531+
"Given strides do not match canonical stride for the shape.\nShape: 4, 3\nExpected canonical stride: 3, 1\nGiven stride: 1, 4'"
532+
in error_message
533+
):
534+
print(
535+
"Test passed: Correct error message received for non-canonical stride"
536+
)
537+
return True
538+
else:
539+
print(
540+
f"Test failed: Unexpected error message for non-canonical stride: {error_message}"
541+
)
542+
return False
543+
544+
545+
def check_canonical_stride(client):
546+
try:
547+
t = cp.arange(12, dtype=cp.float32).reshape(3, 4)
548+
memref = client.create_memref_view_from_dlpack(t.__dlpack__())
549+
print("Test passed: No exception thrown for canonical stride")
550+
return True
551+
except Exception as e:
552+
print(f"Test failed: Unexpected exception for canonical stride: {str(e)}")
553+
return False
554+
555+
556+
def test_memref_strides():
557+
print("Testing non-canonical stride:")
558+
non_canonical_result = check_non_canonical_stride(client)
559+
560+
print("Testing canonical stride:")
561+
canonical_result = check_canonical_stride(client)
562+
563+
564+
print("Test memref strides")
565+
test_memref_strides()
566+
567+
# CHECK-LABEL: Test memref strides
568+
# CHECK-NEXT: Testing non-canonical stride:
569+
# CHECK-NEXT: Test failed: Unexpected error message for non-canonical stride: InvalidArgument:
570+
# CHECK-SAME: InvalidArgument: Given strides [1, 4] do not match canonical strides [3, 1] for shape [4, 3]
571+
# CHECK-NEXT: Testing canonical stride:
572+
# CHECK-NEXT: Test passed: No exception thrown for canonical stride

0 commit comments

Comments
 (0)