Skip to content

Commit cc17b5d

Browse files
committed
perf: reduce the number of __cuda_array_interface__ accesses
1 parent f6664ab commit cc17b5d

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

numba_cuda/numba/cuda/api.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,11 @@ def as_cuda_array(obj, sync=True):
7575
If ``sync`` is ``True``, then the imported stream (if present) will be
7676
synchronized.
7777
"""
78-
if not is_cuda_array(obj):
79-
raise TypeError("*obj* doesn't implement the cuda array interface.")
80-
else:
81-
return from_cuda_array_interface(
82-
obj.__cuda_array_interface__, owner=obj, sync=sync
83-
)
78+
if (
79+
interface := getattr(obj, "__cuda_array_interface__", None)
80+
) is not None:
81+
return from_cuda_array_interface(interface, owner=obj, sync=sync)
82+
raise TypeError("*obj* doesn't implement the cuda array interface.")
8483

8584

8685
def is_cuda_array(obj):

numba_cuda/numba/cuda/cudadrv/devicearray.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import numpy as np
1717

18-
import numba
1918
from numba.cuda.cext import _devicearray
2019
from numba.cuda.cudadrv import devices, dummyarray
2120
from numba.cuda.cudadrv import driver as _driver
@@ -901,8 +900,12 @@ def auto_device(obj, stream=0, copy=True, user_explicit=False):
901900
"""
902901
if _driver.is_device_memory(obj):
903902
return obj, False
904-
elif hasattr(obj, "__cuda_array_interface__"):
905-
return numba.cuda.as_cuda_array(obj), False
903+
elif (
904+
interface := getattr(obj, "__cuda_array_interface__", None)
905+
) is not None:
906+
from numba.cuda.api import from_cuda_array_interface
907+
908+
return from_cuda_array_interface(interface, owner=obj), False
906909
else:
907910
if isinstance(obj, np.void):
908911
devobj = from_record_like(obj, stream=stream)

0 commit comments

Comments
 (0)