@@ -671,12 +671,50 @@ static StatusOr<int64_t> getFootprintInBytes(llvm::ArrayRef<int64_t> shape,
671671 return sizeBytes;
672672}
673673
674+ static std::vector<int64_t > getCanonicalStride (const llvm::ArrayRef<int64_t >& shape) {
675+ if (shape.empty ())
676+ return {};
677+
678+ std::vector<int64_t > canonicalStride (shape.size (), 1 );
679+ int64_t cumulativeProduct = 1 ;
680+
681+ for (int64_t dimIndex = shape.size () - 1 ; dimIndex >= 0 ; --dimIndex) {
682+ bool isFirstZeroDim = (shape[dimIndex] == 0 && dimIndex != static_cast <int64_t >(shape.size ()) - 1 );
683+ // For dimensions with size 0 or 1, the stride can be arbitrary.
684+ // We set it to 1 here, but other values would also be valid.
685+ if (isFirstZeroDim || shape[dimIndex] == 1 )
686+ canonicalStride[dimIndex] = 1 ;
687+ else
688+ canonicalStride[dimIndex] = cumulativeProduct;
689+ // For zero-sized dimensions (except the last one), we don't update the cumulative product
690+ // This allows for consistent handling of zero-sized dimensions across different frameworks
691+ cumulativeProduct *= isFirstZeroDim ? 1 : shape[dimIndex];
692+ }
693+
694+ return canonicalStride;
695+ }
696+
697+ static bool areStridesEquivalent (llvm::ArrayRef<int64_t > shape,
698+ llvm::ArrayRef<int64_t > stride,
699+ llvm::ArrayRef<int64_t > expectedStride) {
700+ if (shape.size () != stride.size () || shape.size () != expectedStride.size ())
701+ return false ;
702+
703+ for (size_t i = 0 ; i < shape.size (); ++i)
704+ // Allow arbitrary strides for dimensions with size 0 or 1
705+ // This accounts for discrepancies in how different frameworks handle these cases
706+ if (stride[i] != expectedStride[i] && shape[i] != 0 && shape[i] != 1 )
707+ return false ;
708+
709+ return true ;
710+ }
711+
674712StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create (
675713 RuntimeClient *client, mlirtrt::runtime::PointerType addressSpace,
676714 int64_t bitsPerElement, uintptr_t ptr, int64_t offset,
677715 llvm::ArrayRef<int64_t > shape, llvm::ArrayRef<int64_t > strides,
678- std::optional<const Device *> device,
679- std::optional<ScalarType> scalarType ) {
716+ std::optional<const Device *> device, std::optional<ScalarType> scalarType,
717+ std::optional<bool > assertCanonicalStrides ) {
680718 if (!client)
681719 return getInvalidArgStatus (" a valid RuntimeClient must be provided to "
682720 " create a tracked MemRef object" );
@@ -691,6 +729,19 @@ StatusOr<std::unique_ptr<MemRefValue>> MemRefValue::create(
691729 return getInvalidArgStatus (" a specific device must be provided for MemRefs "
692730 " that are device-visible" );
693731
732+ // Check if given strides match canonical stride
733+ if (assertCanonicalStrides && *assertCanonicalStrides) {
734+ std::vector<int64_t > canonicalStride = getCanonicalStride (shape);
735+ if (!strides.empty () &&
736+ !areStridesEquivalent (shape, strides, canonicalStride)) {
737+ std::string errorMsg =
738+ llvm::formatv (" Given strides [{0}] do not match canonical strides "
739+ " [{1}] for shape [{2}]" ,
740+ strides, llvm::ArrayRef (canonicalStride), shape);
741+ return getInvalidArgStatus (errorMsg.c_str ());
742+ }
743+ }
744+
694745 return std::unique_ptr<MemRefValue>(
695746 new MemRefValue (client, addressSpace, bitsPerElement, ptr, offset, shape,
696747 strides, device, scalarType));
@@ -777,7 +828,7 @@ StatusOr<std::unique_ptr<MemRefValue>> RuntimeClient::allocateMemRef(
777828 PointerType addressSpace, int64_t bitsPerElement,
778829 llvm::ArrayRef<int64_t > shape, llvm::ArrayRef<int64_t > strides,
779830 std::optional<const Device *> device, std::optional<CudaStream> stream,
780- std::optional<ScalarType> scalarType) {
831+ std::optional<ScalarType> scalarType, std::optional< bool > assertCanonicalStrides ) {
781832 if (addressSpace == PointerType::device ||
782833 addressSpace == PointerType::unified) {
783834 if (!device || !*device)
@@ -800,7 +851,7 @@ StatusOr<std::unique_ptr<MemRefValue>> RuntimeClient::allocateMemRef(
800851 // Create the descriptor.
801852 StatusOr<std::unique_ptr<MemRefValue>> bufferImpl =
802853 MemRefValue::create (this , addressSpace, bitsPerElement, allocation->ptr ,
803- 0 , shape, strides, device, scalarType);
854+ 0 , shape, strides, device, scalarType, assertCanonicalStrides );
804855 if (bufferImpl.isError ())
805856 return bufferImpl.getStatus ();
806857
@@ -811,11 +862,11 @@ StatusOr<std::unique_ptr<MemRefValue>> RuntimeClient::createExternalMemRef(
811862 PointerType addressSpace, int64_t bitsPerElement, uintptr_t ptr,
812863 int64_t offset, llvm::ArrayRef<int64_t > shape,
813864 llvm::ArrayRef<int64_t > strides, std::optional<const Device *> device,
814- std::optional<ScalarType> scalarType) {
865+ std::optional<ScalarType> scalarType, std::optional< bool > assertCanonicalStrides ) {
815866 // Create the descriptor.
816867 StatusOr<std::unique_ptr<MemRefValue>> memref =
817868 MemRefValue::create (this , addressSpace, bitsPerElement, ptr, offset,
818- shape, strides, device, scalarType);
869+ shape, strides, device, scalarType, assertCanonicalStrides );
819870 if (!memref.isOk ())
820871 return memref.getStatus ();
821872
0 commit comments