Skip to content

Commit e287ae2

Browse files
committed
refactor: make shape and strides definitely tuples
1 parent f4c15c7 commit e287ae2

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

numba_cuda/numba/cuda/cudadrv/devicearray.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,13 @@ def __init__(self, shape, strides, dtype, stream=0, gpu_data=None):
8686
"""
8787
if isinstance(shape, int):
8888
shape = (shape,)
89+
else:
90+
shape = tuple(shape)
8991
if isinstance(strides, int):
9092
strides = (strides,)
93+
else:
94+
if strides:
95+
strides = tuple(strides)
9196
dtype = np.dtype(dtype)
9297
itemsize = dtype.itemsize
9398
self.ndim = ndim = len(shape)

numba_cuda/numba/cuda/cudadrv/dummyarray.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -446,8 +446,8 @@ def reshape(self, *newdims, **kws):
446446

447447
ret = self.from_desc(
448448
self.extent.begin,
449-
shape=newdims,
450-
strides=newstrides,
449+
shape=tuple(newdims),
450+
strides=tuple(newstrides),
451451
itemsize=self.itemsize,
452452
)
453453

@@ -475,8 +475,8 @@ def squeeze(self, axis=None):
475475
newstrides.append(stride)
476476
newarr = self.from_desc(
477477
self.extent.begin,
478-
shape=newshape,
479-
strides=newstrides,
478+
shape=tuple(newshape),
479+
strides=tuple(newstrides),
480480
itemsize=self.itemsize,
481481
)
482482
return newarr, list(self.iter_contiguous_extent())

0 commit comments

Comments
 (0)