|
1 | 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | 2 | # SPDX-License-Identifier: BSD-2-Clause |
3 | 3 |
|
4 | | -""" |
5 | | -Utilities for signature normalization and type conversion. |
6 | | -
|
7 | | -This module also provides type mapping between numba.core types |
8 | | -and numba.cuda types to ensure cross-compatibility. |
9 | | -""" |
10 | | - |
11 | 4 | from numba.cuda import types, typing |
12 | 5 |
|
13 | 6 | try: |
14 | 7 | from numba.core.typing import Signature as CoreSignature |
15 | | - from numba.core import types as core_types |
16 | 8 |
|
17 | 9 | numba_sig_present = True |
18 | | - core_types_available = True |
19 | 10 | except ImportError: |
20 | 11 | numba_sig_present = False |
21 | | - core_types = None |
22 | | - core_types_available = False |
23 | | - |
24 | | - |
25 | | -def is_numba_type(ty): |
26 | | - """ |
27 | | - Check if a type is a numba.core type and not a numba.cuda type. |
28 | | - """ |
29 | | - if not core_types_available: |
30 | | - return False |
31 | | - return isinstance(ty, core_types.Type) and not isinstance(ty, types.Type) |
32 | | - |
33 | | - |
34 | | -def convert_to_cuda_type(ty): |
35 | | - """ |
36 | | - Convert a numba.core type to its numba.cuda type equivalent if possible. |
37 | | -
|
38 | | - This is the main entry point for type conversion. It handles: |
39 | | - - numba.core.types -> numba.cuda.types conversion |
40 | | - - Recursive conversion for container types (arrays, tuples, optionals) |
41 | | - - Special handling for type wrappers like NumberClass |
42 | | - - Pass-through for types that are already numba.cuda types |
43 | | - """ |
44 | | - if not core_types_available: |
45 | | - return ty |
46 | | - |
47 | | - if isinstance(ty, types.Type): |
48 | | - return ty |
49 | | - |
50 | | - # If it's not a core type at all, return as-is |
51 | | - if not isinstance(ty, core_types.Type): |
52 | | - return ty |
53 | | - |
54 | | - # External types (from third-party libraries) should be returned as-is |
55 | | - # They have their own typing registrations and shouldn't be converted |
56 | | - if hasattr(ty, "__module__") and not ty.__module__.startswith("numba."): |
57 | | - return ty |
58 | | - |
59 | | - if isinstance(ty, core_types.NumberClass): |
60 | | - cuda_inner = convert_to_cuda_type(ty.instance_type) |
61 | | - return types.NumberClass(cuda_inner) |
62 | | - |
63 | | - if isinstance(ty, core_types.TypeRef): |
64 | | - cuda_inner = convert_to_cuda_type(ty.instance_type) |
65 | | - return types.TypeRef(cuda_inner) |
66 | | - |
67 | | - if isinstance(ty, core_types.Literal): |
68 | | - return types.literal(ty.literal_value) |
69 | | - |
70 | | - if isinstance(ty, core_types.Record): |
71 | | - # Convert field types to CUDA types |
72 | | - cuda_fields = [] |
73 | | - for field_name, field_info in ty.fields.items(): |
74 | | - cuda_field_type = convert_to_cuda_type(field_info.type) |
75 | | - cuda_fields.append( |
76 | | - ( |
77 | | - field_name, |
78 | | - {"type": cuda_field_type, "offset": field_info.offset}, |
79 | | - ) |
80 | | - ) |
81 | | - # Create a cuda.types Record with converted field types |
82 | | - return types.Record(cuda_fields, ty.size, ty.aligned) |
83 | | - |
84 | | - if isinstance(ty, core_types.Array): |
85 | | - cuda_dtype = convert_to_cuda_type(ty.dtype) |
86 | | - # Reconstruct array with CUDA dtype, only passing attributes that exist |
87 | | - kwargs = {} |
88 | | - if hasattr(ty, "readonly"): |
89 | | - kwargs["readonly"] = ty.readonly |
90 | | - if hasattr(ty, "aligned"): |
91 | | - kwargs["aligned"] = ty.aligned |
92 | | - return types.Array(cuda_dtype, ty.ndim, ty.layout, **kwargs) |
93 | | - |
94 | | - if isinstance(ty, (core_types.BaseTuple, core_types.BaseAnonymousTuple)): |
95 | | - cuda_elements = tuple(convert_to_cuda_type(t) for t in ty.types) |
96 | | - if isinstance(ty, core_types.UniTuple): |
97 | | - return types.UniTuple(cuda_elements[0], ty.count) |
98 | | - else: |
99 | | - return types.Tuple(cuda_elements) |
100 | | - |
101 | | - if isinstance(ty, core_types.Optional): |
102 | | - cuda_inner = convert_to_cuda_type(ty.type) |
103 | | - return types.Optional(cuda_inner) |
104 | | - |
105 | | - # Handle simple types via name lookup |
106 | | - # This includes: Integer, Float, Complex, Boolean, PyObject, etc. |
107 | | - # Note: Built-in Opaques (none, ellipsis) are converted here |
108 | | - if hasattr(ty, "name") and hasattr(types, ty.name): |
109 | | - cuda_type = getattr(types, ty.name) |
110 | | - if isinstance(cuda_type, types.Type): |
111 | | - return cuda_type |
112 | | - |
113 | | - # Handle custom Opaque types that didn't match in name lookup above |
114 | | - # These are user-defined types (e.g., DummyType in numba.cuda.tests) |
115 | | - if isinstance(ty, core_types.Opaque): |
116 | | - # Return as-is. User should have appropriate typeof registration |
117 | | - # for corresponding target (CUDA, cpu) |
118 | | - return ty |
119 | | - |
120 | | - # Fallback: return as-is (Function, Dispatcher, other special types) |
121 | | - return ty |
122 | 12 |
|
123 | 13 |
|
124 | 14 | def is_signature(sig): |
@@ -167,14 +57,8 @@ def normalize_signature(sig): |
167 | 57 | % (sig, sig.__class__.__name__, parsed.__class__.__name__) |
168 | 58 | ) |
169 | 59 |
|
170 | | - # Convert core types to CUDA types transparently |
171 | | - if return_type is not None: |
172 | | - return_type = convert_to_cuda_type(return_type) |
173 | | - args = tuple(convert_to_cuda_type(ty) for ty in args) |
174 | | - |
175 | 60 | def check_type(ty): |
176 | | - # Accept both CUDA types and numba.core types (for cross-compatibility) |
177 | | - if not (isinstance(ty, types.Type) or is_numba_type(ty)): |
| 61 | + if not isinstance(ty, types.Type): |
178 | 62 | raise TypeError( |
179 | 63 | "invalid type in signature: expected a type " |
180 | 64 | "instance, got %r" % (ty,) |
|
0 commit comments