Skip to content

Commit 244a38e

Browse files
committed
Add type-mapping for Records
1 parent 6722536 commit 244a38e

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

numba_cuda/numba/cuda/core/sigutils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ def convert_to_cuda_type(ty):
5151
if not isinstance(ty, core_types.Type):
5252
return ty
5353

54+
# External types (from third-party libraries) should be returned as-is
55+
# They have their own typing registrations and shouldn't be converted
56+
if hasattr(ty, "__module__") and not ty.__module__.startswith("numba."):
57+
return ty
58+
5459
if isinstance(ty, core_types.NumberClass):
5560
cuda_inner = convert_to_cuda_type(ty.instance_type)
5661
return types.NumberClass(cuda_inner)
@@ -60,7 +65,21 @@ def convert_to_cuda_type(ty):
6065
return types.TypeRef(cuda_inner)
6166

6267
if isinstance(ty, core_types.Literal):
63-
return ty
68+
return types.literal(ty.literal_value)
69+
70+
if isinstance(ty, core_types.Record):
71+
# Convert field types to CUDA types
72+
cuda_fields = []
73+
for field_name, field_info in ty.fields.items():
74+
cuda_field_type = convert_to_cuda_type(field_info.type)
75+
cuda_fields.append(
76+
(
77+
field_name,
78+
{"type": cuda_field_type, "offset": field_info.offset},
79+
)
80+
)
81+
# Create a cuda.types Record with converted field types
82+
return types.Record(cuda_fields, ty.size, ty.aligned)
6483

6584
if isinstance(ty, core_types.Array):
6685
cuda_dtype = convert_to_cuda_type(ty.dtype)

numba_cuda/numba/cuda/core/typeinfer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
)
5252
from numba.cuda.core.funcdesc import qualifying_prefix
5353
from numba.cuda.typeconv import Conversion
54+
from numba.cuda.core.sigutils import is_numba_type, convert_to_cuda_type
5455

5556
_logger = logging.getLogger(__name__)
5657

@@ -79,15 +80,14 @@ def _ensure_cuda_type(self, tp):
7980
Convert numba.core types to numba.cuda types if necessary.
8081
This ensures cross-compatibility.
8182
"""
82-
from numba.cuda.core.sigutils import is_numba_type, convert_to_cuda_type
8383

8484
if is_numba_type(tp):
8585
tp = convert_to_cuda_type(tp)
8686
return tp
8787

8888
def add_type(self, tp, loc):
8989
tp = self._ensure_cuda_type(tp)
90-
assert isinstance(tp, types.Type), type(tp)
90+
assert isinstance(tp, types.Type) or is_numba_type(tp), type(tp)
9191
# Special case for _undef_var.
9292
# If the typevar is the _undef_var, use the incoming type directly.
9393
if self.type is types._undef_var:
@@ -122,7 +122,7 @@ def add_type(self, tp, loc):
122122

123123
def lock(self, tp, loc, literal_value=NOTSET):
124124
tp = self._ensure_cuda_type(tp)
125-
assert isinstance(tp, types.Type), type(tp)
125+
assert isinstance(tp, types.Type) or is_numba_type(tp), type(tp)
126126

127127
if self.locked:
128128
msg = (

0 commit comments

Comments
 (0)