
Commit 70b4167

Merge branch 'main' into vk/types

2 parents 3827cee + 3ebbe29

File tree

8 files changed: +43 -23 lines

numba_cuda/numba/cuda/api.py

Lines changed: 1 addition & 2 deletions

@@ -39,10 +39,9 @@ def from_cuda_array_interface(desc, owner=None, sync=True):
 
     shape = desc["shape"]
     strides = desc.get("strides")
-    dtype = np.dtype(desc["typestr"])
 
     shape, strides, dtype = prepare_shape_strides_dtype(
-        shape, strides, dtype, order="C"
+        shape, strides, desc["typestr"], order="C"
     )
     size = driver.memory_size_from_info(shape, strides, dtype.itemsize)
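
The net effect: from_cuda_array_interface now forwards the raw "typestr" string and lets prepare_shape_strides_dtype construct the np.dtype itself. A minimal sketch of the new call shape, assuming the helper is importable as numba.cuda.api_util.prepare_shape_strides_dtype and using a made-up interface dict:

import numpy as np
from numba.cuda.api_util import prepare_shape_strides_dtype

# An illustrative __cuda_array_interface__-style description (not from the commit)
desc = {"shape": (4, 3), "typestr": "<f4", "data": (0, False), "version": 3}

shape, strides, dtype = prepare_shape_strides_dtype(
    desc["shape"], desc.get("strides"), desc["typestr"], order="C"
)
assert dtype == np.dtype("float32")
assert strides == (12, 4)  # C-order strides filled in for a 4x3 float32 array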

numba_cuda/numba/cuda/api_util.py

Lines changed: 16 additions & 6 deletions

@@ -3,6 +3,8 @@
 
 import numpy as np
 
+import functools
+
 
 def prepare_shape_strides_dtype(shape, strides, dtype, order):
     dtype = np.dtype(dtype)
@@ -14,25 +16,33 @@ def prepare_shape_strides_dtype(shape, strides, dtype, order):
         raise TypeError("shape must be an integer or tuple of integers")
     if isinstance(shape, int):
         shape = (shape,)
+    else:
+        shape = tuple(shape)
     if isinstance(strides, int):
         strides = (strides,)
     else:
-        strides = strides or _fill_stride_by_order(shape, dtype, order)
+        if not strides:
+            strides = _fill_stride_by_order(shape, dtype, order)
+        else:
+            strides = tuple(strides)
     return shape, strides, dtype
 
 
+@functools.cache
 def _fill_stride_by_order(shape, dtype, order):
-    nd = len(shape)
-    if nd == 0:
+    ndims = len(shape)
+    if not ndims:
         return ()
-    strides = [0] * nd
+    strides = [0] * ndims
     if order == "C":
         strides[-1] = dtype.itemsize
-        for d in reversed(range(nd - 1)):
+        # -2 because we subtract one for zero-based indexing and another one
+        # for skipping the already-filled-in last element
+        for d in range(ndims - 2, -1, -1):
             strides[d] = strides[d + 1] * shape[d + 1]
     elif order == "F":
         strides[0] = dtype.itemsize
-        for d in range(1, nd):
+        for d in range(1, ndims):
             strides[d] = strides[d - 1] * shape[d - 1]
     else:
         raise ValueError("must be either C/F order")
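
The tuple(...) normalization here is what makes the new @functools.cache on _fill_stride_by_order safe: cached functions hash their arguments, and np.dtype instances are hashable while lists are not. A standalone sketch of the failure mode the conversion avoids (not code from the commit):

import functools

@functools.cache
def cached_identity(shape):
    return shape

cached_identity((4, 3))        # fine: tuples are hashable
try:
    cached_identity([4, 3])    # lists are unhashable
except TypeError as exc:
    print(exc)                 # unhashable type: 'list'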

numba_cuda/numba/cuda/cudadrv/devicearray.py

Lines changed: 9 additions & 7 deletions

@@ -86,8 +86,13 @@ def __init__(self, shape, strides, dtype, stream=0, gpu_data=None):
         """
         if isinstance(shape, int):
             shape = (shape,)
+        else:
+            shape = tuple(shape)
         if isinstance(strides, int):
             strides = (strides,)
+        else:
+            if strides:
+                strides = tuple(strides)
         dtype = np.dtype(dtype)
         itemsize = dtype.itemsize
         self.ndim = ndim = len(shape)
@@ -96,9 +101,6 @@ def __init__(self, shape, strides, dtype, stream=0, gpu_data=None):
         self._dummy = dummy = dummyarray.Array.from_desc(
             0, shape, strides, itemsize
         )
-        # confirm that all elements of shape are ints
-        if not all(isinstance(dim, (int, np.integer)) for dim in shape):
-            raise TypeError("all elements of shape must be ints")
         self.shape = shape = dummy.shape
         self.strides = strides = dummy.strides
         self.dtype = dtype
@@ -121,17 +123,17 @@ def __init__(self, shape, strides, dtype, stream=0, gpu_data=None):
 
     @property
     def __cuda_array_interface__(self):
-        if self.device_ctypes_pointer.value is not None:
-            ptr = self.device_ctypes_pointer.value
+        if (value := self.device_ctypes_pointer.value) is not None:
+            ptr = value
         else:
             ptr = 0
 
         return {
-            "shape": tuple(self.shape),
+            "shape": self.shape,
             "strides": None if is_contiguous(self) else tuple(self.strides),
             "data": (ptr, False),
             "typestr": self.dtype.str,
-            "stream": int(self.stream) if self.stream != 0 else None,
+            "stream": int(stream) if (stream := self.stream) != 0 else None,
             "version": 3,
         }
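
Both walrus-operator rewrites above follow one pattern: evaluate a property once, bind the result, and reuse the binding instead of performing a second attribute lookup. A standalone sketch of the pattern (the Handle class is illustrative, not from the codebase):

class Handle:
    @property
    def value(self):
        print("property evaluated")  # stands in for the real ctypes lookup
        return 0xDEAD

h = Handle()
if (value := h.value) is not None:   # evaluated exactly once
    ptr = value
else:
    ptr = 0
print(hex(ptr))

In the "stream" entry, the walrus sits in the conditional expression's test, which Python evaluates first, so stream is already bound by the time int(stream) runs.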

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 1 addition & 0 deletions

@@ -3023,6 +3023,7 @@ def host_memory_extents(obj):
     return mviewbuf.memoryview_get_extents(obj)
 
 
+@functools.cache
 def memory_size_from_info(shape, strides, itemsize):
     """Get the byte size of a contiguous memory buffer given the shape, strides
     and itemsize.
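
memory_size_from_info is a pure function of its arguments, and after the api_util changes those arguments arrive as hashable tuples, so memoizing it with functools.cache is safe. As a rough illustration of the kind of computation it performs (an assumed stand-in, not the actual body from driver.py):

import functools

@functools.cache
def contiguous_nbytes(shape, strides, itemsize):
    # The extent of a contiguous buffer is the span of its
    # largest-strided dimension (illustrative logic only).
    if not shape:
        return itemsize
    return max(stride * dim for stride, dim in zip(strides, shape))

print(contiguous_nbytes((4, 3), (12, 4), 4))  # 48 bytes for a 4x3 float32 array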

numba_cuda/numba/cuda/cudadrv/dummyarray.py

Lines changed: 8 additions & 4 deletions

@@ -5,6 +5,7 @@
 import itertools
 import functools
 import operator
+import numpy as np
 
 
 Extent = namedtuple("Extent", ["begin", "end"])
@@ -245,9 +246,12 @@ class Array(object):
     is_array = True
 
     @classmethod
+    @functools.cache
     def from_desc(cls, offset, shape, strides, itemsize):
         dims = []
         for ashape, astride in zip(shape, strides):
+            if not isinstance(ashape, (int, np.integer)):
+                raise TypeError("all elements of shape must be ints")
             dim = Dim(
                 offset, offset + ashape * astride, ashape, astride, single=False
             )
@@ -442,8 +446,8 @@ def reshape(self, *newdims, **kws):
 
         ret = self.from_desc(
             self.extent.begin,
-            shape=newdims,
-            strides=newstrides,
+            shape=tuple(newdims),
+            strides=tuple(newstrides),
             itemsize=self.itemsize,
         )
 
@@ -471,8 +475,8 @@ def squeeze(self, axis=None):
             newstrides.append(stride)
         newarr = self.from_desc(
             self.extent.begin,
-            shape=newshape,
-            strides=newstrides,
+            shape=tuple(newshape),
+            strides=tuple(newstrides),
             itemsize=self.itemsize,
         )
         return newarr, list(self.iter_contiguous_extent())
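
Stacking @functools.cache under @classmethod turns from_desc into a per-arguments construction cache, with the class itself becoming part of the cache key; this is also why reshape and squeeze now pass tuple(...) instead of lists. A standalone sketch of the pattern, with a made-up Grid class:

import functools

class Grid:
    @classmethod
    @functools.cache
    def from_desc(cls, offset, shape, strides, itemsize):
        print("constructing", shape)
        return object.__new__(cls)

a = Grid.from_desc(0, (4, 3), (12, 4), 4)  # prints "constructing (4, 3)"
b = Grid.from_desc(0, (4, 3), (12, 4), 4)  # cache hit: nothing printed
assert a is b                              # same cached instance

One trade-off of this pattern is that the cache keeps every constructed instance alive for the life of the process.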

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 6 additions & 2 deletions

@@ -1628,11 +1628,15 @@ def typeof_pyval(self, val):
         try:
             return typeof(val, Purpose.argument)
         except ValueError:
-            if cuda.is_cuda_array(val):
+            if (
+                interface := getattr(val, "__cuda_array_interface__", None)
+            ) is not None:
                 # When typing, we don't need to synchronize on the array's
                 # stream - this is done when the kernel is launched.
+
                 return typeof(
-                    cuda.as_cuda_array(val, sync=False), Purpose.argument
+                    cuda.from_cuda_array_interface(interface, sync=False),
+                    Purpose.argument,
                 )
             else:
                 raise
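
This change swaps the cuda.is_cuda_array() predicate for a direct protocol probe: anything exposing __cuda_array_interface__ qualifies, and its interface dict goes straight to from_cuda_array_interface rather than through the as_cuda_array wrapper. A standalone sketch of the probe, using an illustrative FakeDeviceArray with a null data pointer (no real device memory is touched):

class FakeDeviceArray:
    __cuda_array_interface__ = {
        "shape": (4, 3),
        "typestr": "<f4",
        "data": (0, False),
        "version": 3,
    }

def probe(val):
    if (interface := getattr(val, "__cuda_array_interface__", None)) is not None:
        return interface["shape"]
    raise TypeError("not a CUDA array")

print(probe(FakeDeviceArray()))  # (4, 3)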

pixi.lock

Lines changed: 1 addition & 1 deletion (generated lockfile; diff not rendered)

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -204,7 +204,7 @@ benchcmp = { cmd = [
     "numba.cuda.tests.benchmarks",
     "--benchmark-only",
     "--benchmark-enable",
-    "--benchmark-group-by=func",
+    "--benchmark-group-by=name",
     "--benchmark-compare",
 ] }
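
If pytest-benchmark's documented semantics apply to this task, group-by=func groups results by the bare function name, while group-by=name uses the full test name including any parametrization, so parametrized benchmark variants are compared individually rather than pooled.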

0 commit comments