|
1 | | -import numpy as np |
2 | | - |
3 | 1 | from numba.core import errors, types |
4 | 2 | from numba.core.extending import overload |
5 | 3 | from numba.np.arrayobj import (_check_const_str_dtype, is_nonelike, |
6 | | - ty_parse_dtype, ty_parse_shape, numpy_empty_nd) |
| 4 | + ty_parse_dtype, ty_parse_shape, numpy_empty_nd, |
| 5 | + numpy_empty_like_nd) |
7 | 6 |
|
8 | 7 |
|
9 | 8 | # Typical tests for allocation use array construction (e.g. np.zeros, np.empty, |
@@ -48,12 +47,20 @@ def impl(shape, dtype): |
48 | 47 |
|
49 | 48 |
|
50 | 49 | @overload(cuda_empty_like) |
51 | | -def ol_cuda_empty_like(a, dtype=None): |
52 | | - _check_const_str_dtype("zeros_like", dtype) |
53 | | - |
54 | | - # NumPy uses 'a' as the arg name for the array-like |
55 | | - def impl(a, dtype=None): |
56 | | - arr = np.empty_like(a, dtype=dtype) |
57 | | - arr._zero_fill() |
58 | | - return arr |
| 50 | +def ol_cuda_empty_like(arr): |
| 51 | + |
| 52 | + if isinstance(arr, types.Array): |
| 53 | + nb_dtype = arr.dtype |
| 54 | + else: |
| 55 | + nb_dtype = arr |
| 56 | + |
| 57 | + if isinstance(arr, types.Array): |
| 58 | + layout = arr.layout if arr.layout != 'A' else 'C' |
| 59 | + retty = arr.copy(dtype=nb_dtype, layout=layout, readonly=False) |
| 60 | + else: |
| 61 | + retty = types.Array(nb_dtype, 0, 'C') |
| 62 | + |
| 63 | + def impl(arr): |
| 64 | + dtype = None |
| 65 | + return numpy_empty_like_nd(arr, dtype, retty) |
59 | 66 | return impl |
0 commit comments