
Commit 3ebbe29

perf: cache dimension computations (#542)
This PR adds `functools.cache` decorators to a few dimension-computing functions that are called repeatedly with the same inputs. This trades memory for shorter conversion time: the cache grows with the number of unique combinations of shape, strides, dtype, and memory order. If a workload has many arrays with different shapes, strides, dtypes, and memory orderings, the cache could use a lot of memory. I'm not sure that's a likely scenario, but it's worth pointing out as a potential issue.
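To illustrate the trade-off, here is a minimal sketch with hypothetical names (not code from this PR): `functools.cache` memoizes a pure function keyed on its arguments, so repeated calls with the same shape/dtype become dictionary lookups, and the cache can be inspected or cleared if its growth becomes a concern.

```python
import functools

import numpy as np


@functools.cache
def row_major_strides(shape, itemsize):
    # Recomputed only on a cache miss; subsequent identical calls are O(1) lookups.
    strides = [0] * len(shape)
    acc = itemsize
    for d in reversed(range(len(shape))):
        strides[d] = acc
        acc *= shape[d]
    return tuple(strides)


print(row_major_strides((2, 3), np.dtype(np.float64).itemsize))  # (24, 8)
print(row_major_strides.cache_info())   # hits, misses, and current cache size
row_major_strides.cache_clear()         # frees the cache if memory becomes a concern
```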
1 parent 2567b28 commit 3ebbe29

File tree

8 files changed: +43 −23 lines changed

numba_cuda/numba/cuda/api.py

Lines changed: 1 addition & 2 deletions
@@ -39,10 +39,9 @@ def from_cuda_array_interface(desc, owner=None, sync=True):
 
     shape = desc["shape"]
     strides = desc.get("strides")
-    dtype = np.dtype(desc["typestr"])
 
     shape, strides, dtype = prepare_shape_strides_dtype(
-        shape, strides, dtype, order="C"
+        shape, strides, desc["typestr"], order="C"
     )
     size = driver.memory_size_from_info(shape, strides, dtype.itemsize)
 

numba_cuda/numba/cuda/api_util.py

Lines changed: 16 additions & 6 deletions
@@ -3,6 +3,8 @@
 
 import numpy as np
 
+import functools
+
 
 def prepare_shape_strides_dtype(shape, strides, dtype, order):
     dtype = np.dtype(dtype)
@@ -14,25 +16,33 @@ def prepare_shape_strides_dtype(shape, strides, dtype, order):
         raise TypeError("shape must be an integer or tuple of integers")
     if isinstance(shape, int):
         shape = (shape,)
+    else:
+        shape = tuple(shape)
     if isinstance(strides, int):
         strides = (strides,)
     else:
-        strides = strides or _fill_stride_by_order(shape, dtype, order)
+        if not strides:
+            strides = _fill_stride_by_order(shape, dtype, order)
+        else:
+            strides = tuple(strides)
     return shape, strides, dtype
 
 
+@functools.cache
 def _fill_stride_by_order(shape, dtype, order):
-    nd = len(shape)
-    if nd == 0:
+    ndims = len(shape)
+    if not ndims:
         return ()
-    strides = [0] * nd
+    strides = [0] * ndims
     if order == "C":
         strides[-1] = dtype.itemsize
-        for d in reversed(range(nd - 1)):
+        # -2 because we subtract one for zero-based indexing and another one
+        # for skipping the already-filled-in last element
+        for d in range(ndims - 2, -1, -1):
             strides[d] = strides[d + 1] * shape[d + 1]
     elif order == "F":
         strides[0] = dtype.itemsize
-        for d in range(1, nd):
+        for d in range(1, ndims):
             strides[d] = strides[d - 1] * shape[d - 1]
     else:
         raise ValueError("must be either C/F order")
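The extra `tuple(...)` conversions in `prepare_shape_strides_dtype` exist because `functools.cache` keys its cache on the arguments, so every argument must be hashable. A small standalone sketch (hypothetical function, not from this file) of what goes wrong otherwise:

```python
import functools


@functools.cache
def ndim(shape):
    # The shape value itself becomes part of the cache key.
    return len(shape)


print(ndim((16, 16)))   # fine: tuples are hashable
try:
    ndim([16, 16])      # lists are not hashable, so they cannot be cache keys
except TypeError as exc:
    print(exc)          # unhashable type: 'list'
```

With hashable inputs, the cached `_fill_stride_by_order` behaves as before; for example, a `(2, 3)` float64 array yields strides `(24, 8)` in C order and `(8, 16)` in F order.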

numba_cuda/numba/cuda/cudadrv/devicearray.py

Lines changed: 9 additions & 7 deletions
@@ -86,8 +86,13 @@ def __init__(self, shape, strides, dtype, stream=0, gpu_data=None):
         """
         if isinstance(shape, int):
             shape = (shape,)
+        else:
+            shape = tuple(shape)
         if isinstance(strides, int):
             strides = (strides,)
+        else:
+            if strides:
+                strides = tuple(strides)
         dtype = np.dtype(dtype)
         itemsize = dtype.itemsize
         self.ndim = ndim = len(shape)
@@ -96,9 +101,6 @@ def __init__(self, shape, strides, dtype, stream=0, gpu_data=None):
         self._dummy = dummy = dummyarray.Array.from_desc(
             0, shape, strides, itemsize
         )
-        # confirm that all elements of shape are ints
-        if not all(isinstance(dim, (int, np.integer)) for dim in shape):
-            raise TypeError("all elements of shape must be ints")
         self.shape = shape = dummy.shape
         self.strides = strides = dummy.strides
         self.dtype = dtype
@@ -121,17 +123,17 @@ def __init__(self, shape, strides, dtype, stream=0, gpu_data=None):
 
     @property
     def __cuda_array_interface__(self):
-        if self.device_ctypes_pointer.value is not None:
-            ptr = self.device_ctypes_pointer.value
+        if (value := self.device_ctypes_pointer.value) is not None:
+            ptr = value
         else:
            ptr = 0
 
         return {
-            "shape": tuple(self.shape),
+            "shape": self.shape,
             "strides": None if is_contiguous(self) else tuple(self.strides),
             "data": (ptr, False),
             "typestr": self.dtype.str,
-            "stream": int(self.stream) if self.stream != 0 else None,
+            "stream": int(stream) if (stream := self.stream) != 0 else None,
             "version": 3,
         }
 
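For reference, the walrus expressions above only read `device_ctypes_pointer.value` and `self.stream` once instead of twice; the structure of the returned dict is unchanged, and `"shape"` can be returned directly now that `__init__` guarantees it is a tuple. Roughly (illustrative values only, not produced by real hardware), the property returns something like:

```python
# Illustrative only: what __cuda_array_interface__ might look like for a
# hypothetical 2x3 float32, C-contiguous device array on the default stream.
example_interface = {
    "shape": (2, 3),                  # guaranteed to be a tuple by __init__
    "strides": None,                  # None signals a contiguous layout
    "data": (0x7F0000000000, False),  # (device pointer, read-only flag)
    "typestr": "<f4",
    "stream": None,                   # stream 0 is reported as None
    "version": 3,
}
```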

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 1 addition & 0 deletions
@@ -3023,6 +3023,7 @@ def host_memory_extents(obj):
     return mviewbuf.memoryview_get_extents(obj)
 
 
+@functools.cache
 def memory_size_from_info(shape, strides, itemsize):
     """Get the byte size of a contiguous memory buffer given the shape, strides
     and itemsize.
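A hedged sketch of the quantity being cached here (not necessarily how the driver computes it): for positive strides, the byte extent of the region covering every element of the buffer.

```python
def buffer_extent_bytes(shape, strides, itemsize):
    # Offset of the last element plus one item gives the total span in bytes.
    return sum((dim - 1) * stride for dim, stride in zip(shape, strides)) + itemsize


print(buffer_extent_bytes((2, 3), (24, 8), 8))  # 48 bytes for a C-contiguous 2x3 float64
```

Note that `memory_size_from_info` is called with the tuples produced by `prepare_shape_strides_dtype`, which is what makes the `functools.cache` decoration safe: all of its arguments are hashable.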

numba_cuda/numba/cuda/cudadrv/dummyarray.py

Lines changed: 8 additions & 4 deletions
@@ -5,6 +5,7 @@
 import itertools
 import functools
 import operator
+import numpy as np
 
 
 Extent = namedtuple("Extent", ["begin", "end"])
@@ -245,9 +246,12 @@ class Array(object):
     is_array = True
 
     @classmethod
+    @functools.cache
     def from_desc(cls, offset, shape, strides, itemsize):
         dims = []
         for ashape, astride in zip(shape, strides):
+            if not isinstance(ashape, (int, np.integer)):
+                raise TypeError("all elements of shape must be ints")
             dim = Dim(
                 offset, offset + ashape * astride, ashape, astride, single=False
             )
@@ -442,8 +446,8 @@ def reshape(self, *newdims, **kws):
 
         ret = self.from_desc(
             self.extent.begin,
-            shape=newdims,
-            strides=newstrides,
+            shape=tuple(newdims),
+            strides=tuple(newstrides),
             itemsize=self.itemsize,
         )
 
@@ -471,8 +475,8 @@ def squeeze(self, axis=None):
                 newstrides.append(stride)
         newarr = self.from_desc(
             self.extent.begin,
-            shape=newshape,
-            strides=newstrides,
+            shape=tuple(newshape),
+            strides=tuple(newstrides),
             itemsize=self.itemsize,
         )
         return newarr, list(self.iter_contiguous_extent())
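Because `from_desc` is now cached, its arguments become cache keys, which is why `reshape` and `squeeze` above pass `tuple(...)` rather than lists, and why the per-element integer check moved inside it. A small sketch (hypothetical class, not `Array` itself) of stacking `functools.cache` under `classmethod`:

```python
import functools


class Grid:
    @classmethod
    @functools.cache
    def from_desc(cls, shape, strides):
        print("computing", shape, strides)   # printed only on a cache miss
        return cls()


a = Grid.from_desc((4, 4), (32, 8))   # miss: computes and caches
b = Grid.from_desc((4, 4), (32, 8))   # hit: no print, returns the cached object
print(a is b)                         # True
```

One consequence worth noting: the cache holds a reference to every object it has returned, which is the memory growth described in the commit message above.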

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 6 additions & 2 deletions
@@ -1629,11 +1629,15 @@ def typeof_pyval(self, val):
         try:
             return typeof(val, Purpose.argument)
         except ValueError:
-            if cuda.is_cuda_array(val):
+            if (
+                interface := getattr(val, "__cuda_array_interface__", None)
+            ) is not None:
                 # When typing, we don't need to synchronize on the array's
                 # stream - this is done when the kernel is launched.
+
                 return typeof(
-                    cuda.as_cuda_array(val, sync=False), Purpose.argument
+                    cuda.from_cuda_array_interface(interface, sync=False),
+                    Purpose.argument,
                 )
             else:
                 raise
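In isolation, the pattern being adopted above, as a standalone sketch with a fake array class (not the dispatcher itself): read the `__cuda_array_interface__` attribute once and reuse the resulting dict, instead of probing with a helper and then converting the object separately.

```python
class FakeDeviceArray:
    # Any object exposing this attribute advertises the CUDA Array Interface.
    __cuda_array_interface__ = {
        "shape": (8,), "strides": None, "data": (0, False),
        "typestr": "<f4", "stream": None, "version": 3,
    }


val = FakeDeviceArray()
if (interface := getattr(val, "__cuda_array_interface__", None)) is not None:
    print(interface["shape"])   # the dict can be handed straight to a converter
else:
    raise TypeError("not a CUDA array")
```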

pixi.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ benchcmp = { cmd = [
     "numba.cuda.tests.benchmarks",
     "--benchmark-only",
     "--benchmark-enable",
-    "--benchmark-group-by=func",
+    "--benchmark-group-by=name",
     "--benchmark-compare",
 ] }
 
