Skip to content

Commit 7c6d1d2

Browse files
committed
[API/MemRefValue] Validate strides against canonical strides for non-empty shapes
1 parent 59b9536 commit 7c6d1d2

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

mlir-tensorrt/executor/lib/Runtime/API/API.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,38 @@ static StatusOr<int64_t> getFootprintInBytes(llvm::ArrayRef<int64_t> shape,
671671
return sizeBytes;
672672
}
673673

674+
static std::vector<int64_t> getCanonicalStride(const llvm::ArrayRef<int64_t>& shape) {
675+
if (shape.empty())
676+
return {};
677+
678+
std::vector<int64_t> canonicalStride(shape.size(), 1);
679+
int64_t cumulativeProduct = 1;
680+
681+
for (int64_t dimIndex = shape.size() - 1; dimIndex >= 0; --dimIndex) {
682+
bool isFirstZeroDim = (shape[dimIndex] == 0 && dimIndex != static_cast<int64_t>(shape.size()) - 1);
683+
if (isFirstZeroDim)
684+
canonicalStride[dimIndex] = 1;
685+
else if (shape[dimIndex] != 1)
686+
canonicalStride[dimIndex] = cumulativeProduct;
687+
cumulativeProduct *= isFirstZeroDim ? 1 : shape[dimIndex];
688+
}
689+
690+
return canonicalStride;
691+
}
692+
693+
static bool areStridesEquivalent(llvm::ArrayRef<int64_t> shape,
694+
llvm::ArrayRef<int64_t> stride,
695+
llvm::ArrayRef<int64_t> expectedStride) {
696+
if (shape.size() != stride.size() || shape.size() != expectedStride.size())
697+
return false;
698+
699+
for (size_t i = 0; i < shape.size(); ++i)
700+
if (stride[i] != expectedStride[i] && shape[i] != 0 && shape[i] != 1)
701+
return false;
702+
703+
return true;
704+
}
705+
674706
StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create(
675707
RuntimeClient *client, mlirtrt::runtime::PointerType addressSpace,
676708
int64_t bitsPerElement, uintptr_t ptr, int64_t offset,
@@ -691,6 +723,17 @@ StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create(
691723
return getInvalidArgStatus("a specific device must be provided for MemRefs "
692724
"that are device-visible");
693725

726+
// Check if given strides match canonical stride
727+
if (!strides.empty() && !shape.empty()) {
728+
std::vector<int64_t> canonicalStride = getCanonicalStride(shape);
729+
if (!areStridesEquivalent(shape, strides, canonicalStride)) {
730+
std::string errorMsg = llvm::formatv(
731+
"Given strides [{0}] do not match canonical strides [{1}] for shape [{2}]",
732+
strides, llvm::ArrayRef(canonicalStride), shape);
733+
return getInvalidArgStatus(errorMsg.c_str());
734+
}
735+
}
736+
694737
return std::unique_ptr<MemRefValue>(
695738
new MemRefValue(client, addressSpace, bitsPerElement, ptr, offset, shape,
696739
strides, device, scalarType));

mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_create_memref.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,3 +514,59 @@ def create_dangling_memref():
514514
# CHECK-LABEL: Test memref maintains data's lifetime
515515
# CHECK-NEXT: -- Inner scope: np.from_dlpack(): [1 2]
516516
# CHECK-NEXT: -- Outer scope: np.from_dlpack(): [1 2]
517+
518+
519+
def check_non_canonical_stride(client):
520+
try:
521+
t = cp.arange(12, dtype=cp.float32).reshape(3, 4)
522+
a = cp.transpose(t)
523+
memref = client.create_memref_view_from_dlpack(a.__dlpack__())
524+
print(
525+
"Test failed: Expected exception was not thrown for CuPy non-canonical stride"
526+
)
527+
return False
528+
except Exception as e:
529+
error_message = str(e)
530+
if (
531+
"Given strides do not match canonical stride for the shape.\nShape: 4, 3\nExpected canonical stride: 3, 1\nGiven stride: 1, 4'"
532+
in error_message
533+
):
534+
print(
535+
"Test passed: Correct error message received for non-canonical stride"
536+
)
537+
return True
538+
else:
539+
print(
540+
f"Test failed: Unexpected error message for non-canonical stride: {error_message}"
541+
)
542+
return False
543+
544+
545+
def check_canonical_stride(client):
546+
try:
547+
t = cp.arange(12, dtype=cp.float32).reshape(3, 4)
548+
memref = client.create_memref_view_from_dlpack(t.__dlpack__())
549+
print("Test passed: No exception thrown for canonical stride")
550+
return True
551+
except Exception as e:
552+
print(f"Test failed: Unexpected exception for canonical stride: {str(e)}")
553+
return False
554+
555+
556+
def test_memref_strides():
557+
print("Testing non-canonical stride:")
558+
non_canonical_result = check_non_canonical_stride(client)
559+
560+
print("Testing canonical stride:")
561+
canonical_result = check_canonical_stride(client)
562+
563+
564+
print("Test memref strides")
565+
test_memref_strides()
566+
567+
# CHECK-LABEL: Test memref strides
568+
# CHECK-NEXT: Testing non-canonical stride:
569+
# CHECK-NEXT: Test failed: Unexpected error message for non-canonical stride: InvalidArgument:
570+
# CHECK-SAME: InvalidArgument: Given strides [1, 4] do not match canonical strides [3, 1] for shape [4, 3]
571+
# CHECK-NEXT: Testing canonical stride:
572+
# CHECK-NEXT: Test passed: No exception thrown for canonical stride

0 commit comments

Comments
 (0)