@@ -671,6 +671,38 @@ static StatusOr<int64_t> getFootprintInBytes(llvm::ArrayRef<int64_t> shape,
671671 return sizeBytes;
672672}
673673
674+ static std::vector<int64_t > getCanonicalStride (const llvm::ArrayRef<int64_t >& shape) {
675+ if (shape.empty ())
676+ return {};
677+
678+ std::vector<int64_t > canonicalStride (shape.size (), 1 );
679+ int64_t cumulativeProduct = 1 ;
680+
681+ for (int64_t dimIndex = shape.size () - 1 ; dimIndex >= 0 ; --dimIndex) {
682+ bool isFirstZeroDim = (shape[dimIndex] == 0 && dimIndex != static_cast <int64_t >(shape.size ()) - 1 );
683+ if (isFirstZeroDim)
684+ canonicalStride[dimIndex] = 1 ;
685+ else if (shape[dimIndex] != 1 )
686+ canonicalStride[dimIndex] = cumulativeProduct;
687+ cumulativeProduct *= isFirstZeroDim ? 1 : shape[dimIndex];
688+ }
689+
690+ return canonicalStride;
691+ }
692+
693+ static bool areStridesEquivalent (llvm::ArrayRef<int64_t > shape,
694+ llvm::ArrayRef<int64_t > stride,
695+ llvm::ArrayRef<int64_t > expectedStride) {
696+ if (shape.size () != stride.size () || shape.size () != expectedStride.size ())
697+ return false ;
698+
699+ for (size_t i = 0 ; i < shape.size (); ++i)
700+ if (stride[i] != expectedStride[i] && shape[i] != 0 && shape[i] != 1 )
701+ return false ;
702+
703+ return true ;
704+ }
705+
674706StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create (
675707 RuntimeClient *client, mlirtrt::runtime::PointerType addressSpace,
676708 int64_t bitsPerElement, uintptr_t ptr, int64_t offset,
@@ -691,6 +723,17 @@ StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create(
691723 return getInvalidArgStatus (" a specific device must be provided for MemRefs "
692724 " that are device-visible" );
693725
726+ // Check if given strides match canonical stride
727+ if (!strides.empty () && !shape.empty ()) {
728+ std::vector<int64_t > canonicalStride = getCanonicalStride (shape);
729+ if (!areStridesEquivalent (shape, strides, canonicalStride)) {
730+ std::string errorMsg = llvm::formatv (
731+ " Given strides [{0}] do not match canonical strides [{1}] for shape [{2}]" ,
732+ strides, llvm::ArrayRef (canonicalStride), shape);
733+ return getInvalidArgStatus (errorMsg.c_str ());
734+ }
735+ }
736+
694737 return std::unique_ptr<MemRefValue>(
695738 new MemRefValue (client, addressSpace, bitsPerElement, ptr, offset, shape,
696739 strides, device, scalarType));
0 commit comments