@@ -86,8 +86,13 @@ def __init__(self, shape, strides, dtype, stream=0, gpu_data=None):
8686 """
8787 if isinstance (shape , int ):
8888 shape = (shape ,)
89+ else :
90+ shape = tuple (shape )
8991 if isinstance (strides , int ):
9092 strides = (strides ,)
93+ else :
94+ if strides :
95+ strides = tuple (strides )
9196 dtype = np .dtype (dtype )
9297 itemsize = dtype .itemsize
9398 self .ndim = ndim = len (shape )
@@ -96,9 +101,6 @@ def __init__(self, shape, strides, dtype, stream=0, gpu_data=None):
96101 self ._dummy = dummy = dummyarray .Array .from_desc (
97102 0 , shape , strides , itemsize
98103 )
99- # confirm that all elements of shape are ints
100- if not all (isinstance (dim , (int , np .integer )) for dim in shape ):
101- raise TypeError ("all elements of shape must be ints" )
102104 self .shape = shape = dummy .shape
103105 self .strides = strides = dummy .strides
104106 self .dtype = dtype
@@ -121,17 +123,17 @@ def __init__(self, shape, strides, dtype, stream=0, gpu_data=None):
121123
122124 @property
123125 def __cuda_array_interface__ (self ):
124- if self .device_ctypes_pointer .value is not None :
125- ptr = self . device_ctypes_pointer . value
126+ if ( value := self .device_ctypes_pointer .value ) is not None :
127+ ptr = value
126128 else :
127129 ptr = 0
128130
129131 return {
130- "shape" : tuple ( self .shape ) ,
132+ "shape" : self .shape ,
131133 "strides" : None if is_contiguous (self ) else tuple (self .strides ),
132134 "data" : (ptr , False ),
133135 "typestr" : self .dtype .str ,
134- "stream" : int (self . stream ) if self .stream != 0 else None ,
136+ "stream" : int (stream ) if ( stream := self .stream ) != 0 else None ,
135137 "version" : 3 ,
136138 }
137139
0 commit comments