33
44import numpy as np
55
6+ import functools
7+
68
79def prepare_shape_strides_dtype (shape , strides , dtype , order ):
810 dtype = np .dtype (dtype )
@@ -14,25 +16,33 @@ def prepare_shape_strides_dtype(shape, strides, dtype, order):
1416 raise TypeError ("shape must be an integer or tuple of integers" )
1517 if isinstance (shape , int ):
1618 shape = (shape ,)
19+ else :
20+ shape = tuple (shape )
1721 if isinstance (strides , int ):
1822 strides = (strides ,)
1923 else :
20- strides = strides or _fill_stride_by_order (shape , dtype , order )
24+ if not strides :
25+ strides = _fill_stride_by_order (shape , dtype , order )
26+ else :
27+ strides = tuple (strides )
2128 return shape , strides , dtype
2229
2330
31+ @functools .cache
2432def _fill_stride_by_order (shape , dtype , order ):
25- nd = len (shape )
26- if nd == 0 :
33+ ndims = len (shape )
34+ if not ndims :
2735 return ()
28- strides = [0 ] * nd
36+ strides = [0 ] * ndims
2937 if order == "C" :
3038 strides [- 1 ] = dtype .itemsize
31- for d in reversed (range (nd - 1 )):
39+ # -2 because we subtract one for zero-based indexing and another one
40+ # for skipping the already-filled-in last element
41+ for d in range (ndims - 2 , - 1 , - 1 ):
3242 strides [d ] = strides [d + 1 ] * shape [d + 1 ]
3343 elif order == "F" :
3444 strides [0 ] = dtype .itemsize
35- for d in range (1 , nd ):
45+ for d in range (1 , ndims ):
3646 strides [d ] = strides [d - 1 ] * shape [d - 1 ]
3747 else :
3848 raise ValueError ("must be either C/F order" )
0 commit comments