|
11 | 11 |
|
12 | 12 | from numba.cuda.core.imputils import Registry |
13 | 13 | from numba.cuda.typing.npydecl import parse_dtype |
14 | | -from numba.cuda.datamodel import models |
| 14 | +from numba.cuda.datamodel.models import StructModel |
15 | 15 | from numba.cuda import types |
16 | 16 | from numba.cuda import cgutils |
17 | 17 | from numba.cuda.np import ufunc_db |
|
21 | 21 | from numba.cuda import nvvmutils, stubs |
22 | 22 | from numba.cuda.types.ext_types import dim3, CUDADispatcher |
23 | 23 |
|
| 24 | +if cuda.HAS_NUMBA: |
| 25 | + from numba.core.datamodel.models import StructModel as CoreStructModel |
| 26 | + from numba.core import types as core_types |
| 27 | + |
24 | 28 | registry = Registry("cudaimpl") |
25 | 29 | lower = registry.lower |
26 | 30 | lower_attr = registry.lower_getattr |
@@ -880,13 +884,19 @@ def _generic_array( |
880 | 884 | raise ValueError("array length <= 0") |
881 | 885 |
|
882 | 886 | # Check that we support the requested dtype |
| 887 | + number_domain = types.number_domain |
| 888 | + struct_model_types = (StructModel,) |
| 889 | + if cuda.HAS_NUMBA: |
| 890 | + number_domain |= core_types.number_domain |
| 891 | + struct_model_types = (StructModel, CoreStructModel) |
| 892 | + |
883 | 893 | data_model = context.data_model_manager[dtype] |
884 | 894 | other_supported_type = ( |
885 | 895 | isinstance(dtype, (types.Record, types.Boolean)) |
886 | | - or isinstance(data_model, models.StructModel) |
| 896 | + or isinstance(data_model, struct_model_types) |
887 | 897 | or dtype == types.float16 |
888 | 898 | ) |
889 | | - if dtype not in types.number_domain and not other_supported_type: |
| 899 | + if dtype not in number_domain and not other_supported_type: |
890 | 900 | raise TypeError("unsupported type: %s" % dtype) |
891 | 901 |
|
892 | 902 | lldtype = context.get_data_type(dtype) |
|
0 commit comments