Skip to content

Commit 4b513c3

Browse files
committed
Consider core Numba number domain and struct model in array type check
1 parent f28dfe2 commit 4b513c3

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

numba_cuda/numba/cuda/cudaimpl.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from numba.cuda.core.imputils import Registry
1313
from numba.cuda.typing.npydecl import parse_dtype
14-
from numba.cuda.datamodel import models
14+
from numba.cuda.datamodel.models import StructModel
1515
from numba.cuda import types
1616
from numba.cuda import cgutils
1717
from numba.cuda.np import ufunc_db
@@ -21,6 +21,10 @@
2121
from numba.cuda import nvvmutils, stubs
2222
from numba.cuda.types.ext_types import dim3, CUDADispatcher
2323

24+
if cuda.HAS_NUMBA:
25+
from numba.core.datamodel.models import StructModel as CoreStructModel
26+
from numba.core import types as core_types
27+
2428
registry = Registry("cudaimpl")
2529
lower = registry.lower
2630
lower_attr = registry.lower_getattr
@@ -880,13 +884,19 @@ def _generic_array(
880884
raise ValueError("array length <= 0")
881885

882886
# Check that we support the requested dtype
887+
number_domain = types.number_domain
888+
struct_model_types = (StructModel,)
889+
if cuda.HAS_NUMBA:
890+
number_domain |= core_types.number_domain
891+
struct_model_types = (StructModel, CoreStructModel)
892+
883893
data_model = context.data_model_manager[dtype]
884894
other_supported_type = (
885895
isinstance(dtype, (types.Record, types.Boolean))
886-
or isinstance(data_model, models.StructModel)
896+
or isinstance(data_model, struct_model_types)
887897
or dtype == types.float16
888898
)
889-
if dtype not in types.number_domain and not other_supported_type:
899+
if dtype not in number_domain and not other_supported_type:
890900
raise TypeError("unsupported type: %s" % dtype)
891901

892902
lldtype = context.get_data_type(dtype)

0 commit comments

Comments
 (0)