@@ -89,26 +89,27 @@ def __init__(self, shape, strides, dtype, stream=0, gpu_data=None):
8989 if isinstance (strides , int ):
9090 strides = (strides ,)
9191 dtype = np .dtype (dtype )
92- self .ndim = len (shape )
93- if len (strides ) != self .ndim :
92+ itemsize = dtype .itemsize
93+ self .ndim = ndim = len (shape )
94+ if len (strides ) != ndim :
9495 raise ValueError ("strides not match ndim" )
95- self ._dummy = dummyarray .Array .from_desc (
96- 0 , shape , strides , dtype . itemsize
96+ self ._dummy = dummy = dummyarray .Array .from_desc (
97+ 0 , shape , strides , itemsize
9798 )
9899 # confirm that all elements of shape are ints
99100 if not all (isinstance (dim , (int , np .integer )) for dim in shape ):
100101 raise TypeError ("all elements of shape must be ints" )
101- self .shape = tuple ( shape )
102- self .strides = tuple ( strides )
102+ self .shape = shape = dummy . shape
103+ self .strides = strides = dummy . strides
103104 self .dtype = dtype
104- self .size = int ( functools . reduce ( operator . mul , self . shape , 1 ))
105+ self .size = size = dummy . size
105106 # prepare gpu memory
106- if self . size > 0 :
107- self .alloc_size = _driver .memory_size_from_info (
108- self . shape , self . strides , self . dtype . itemsize
107+ if size :
108+ self .alloc_size = alloc_size = _driver .memory_size_from_info (
109+ shape , strides , itemsize
109110 )
110111 if gpu_data is None :
111- gpu_data = devices .get_context ().memalloc (self . alloc_size )
112+ gpu_data = devices .get_context ().memalloc (alloc_size )
112113 else :
113114 # Make NULL pointer for empty allocation
114115 null = _driver .binding .CUdeviceptr (0 )
0 commit comments