Skip to content

Commit f7ac1c6

Browse files
committed
porting the correct np_empty_like implementation
1 parent f1f8377 commit f7ac1c6

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

numba_cuda/numba/cuda/tests/nrt/mock_numpy.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
import numpy as np
2-
31
from numba.core import errors, types
42
from numba.core.extending import overload
53
from numba.np.arrayobj import (_check_const_str_dtype, is_nonelike,
6-
ty_parse_dtype, ty_parse_shape, numpy_empty_nd)
4+
ty_parse_dtype, ty_parse_shape, numpy_empty_nd,
5+
numpy_empty_like_nd)
76

87

98
# Typical tests for allocation use array construction (e.g. np.zeros, np.empty,
@@ -48,12 +47,20 @@ def impl(shape, dtype):
4847

4948

5049
@overload(cuda_empty_like)
51-
def ol_cuda_empty_like(a, dtype=None):
52-
_check_const_str_dtype("zeros_like", dtype)
53-
54-
# NumPy uses 'a' as the arg name for the array-like
55-
def impl(a, dtype=None):
56-
arr = np.empty_like(a, dtype=dtype)
57-
arr._zero_fill()
58-
return arr
50+
def ol_cuda_empty_like(arr):
51+
52+
if isinstance(arr, types.Array):
53+
nb_dtype = arr.dtype
54+
else:
55+
nb_dtype = arr
56+
57+
if isinstance(arr, types.Array):
58+
layout = arr.layout if arr.layout != 'A' else 'C'
59+
retty = arr.copy(dtype=nb_dtype, layout=layout, readonly=False)
60+
else:
61+
retty = types.Array(nb_dtype, 0, 'C')
62+
63+
def impl(arr):
64+
dtype = None
65+
return numpy_empty_like_nd(arr, dtype, retty)
5966
return impl

0 commit comments

Comments
 (0)