
Commit 0b6601c

Redirect to numba.types and numba.core.datamodel if they are available
1 parent 00ec227 commit 0b6601c

37 files changed: +175 -143 lines changed
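The commit title describes a conditional redirect: prefer the upstream numba modules when they can be imported, and fall back to the CUDA-local copies otherwise. A minimal sketch of that pattern, assuming a try/except import guard; the exact guard and the fallback module paths below are illustrative, not taken from this diff:

# Sketch of an availability-guarded redirect (illustrative, not the
# actual implementation in this commit).
try:
    # Prefer the types and data model shipped with numba itself.
    from numba import types
    from numba.core import datamodel
except ImportError:
    # Otherwise fall back to the CUDA-local implementations; these
    # fallback paths are assumptions for illustration only.
    from numba.cuda import cuda_types as types
    from numba.cuda.core import datamodel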

numba_cuda/numba/cuda/_internal/cuda_bf16.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@
     uint64,
     void,
 )
-from numba.cuda.types import bfloat16
+from numba.cuda.ext_types import bfloat16

 float32x2 = vector_types["float32x2"]
 __half = float16

numba_cuda/numba/cuda/cg.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 from numba.cuda.typing import signature
 from numba.cuda import nvvmutils
 from numba.cuda.extending import intrinsic
-from numba.cuda.types import grid_group, GridGroup as GridGroupClass
+from numba.cuda.ext_types import grid_group, GridGroup as GridGroupClass


 class GridGroup:

numba_cuda/numba/cuda/core/typeinfer.py

Lines changed: 2 additions & 15 deletions
@@ -51,7 +51,6 @@
 )
 from numba.cuda.core.funcdesc import qualifying_prefix
 from numba.cuda.typeconv import Conversion
-from numba.cuda.core.sigutils import is_numba_type, convert_to_cuda_type

 _logger = logging.getLogger(__name__)

@@ -75,19 +74,8 @@ def __init__(self, context, var):
         # Qualifiers
         self.literal_value = NOTSET

-    def _ensure_cuda_type(self, tp):
-        """
-        Convert numba.core types to numba.cuda types if necessary.
-        This ensures cross-compatibility.
-        """
-
-        if is_numba_type(tp):
-            tp = convert_to_cuda_type(tp)
-        return tp
-
     def add_type(self, tp, loc):
-        tp = self._ensure_cuda_type(tp)
-        assert isinstance(tp, types.Type) or is_numba_type(tp), type(tp)
+        assert isinstance(tp, types.Type), type(tp)
         # Special case for _undef_var.
         # If the typevar is the _undef_var, use the incoming type directly.
         if self.type is types._undef_var:

@@ -121,8 +109,7 @@ def add_type(self, tp, loc):
         return self.type

     def lock(self, tp, loc, literal_value=NOTSET):
-        tp = self._ensure_cuda_type(tp)
-        assert isinstance(tp, types.Type) or is_numba_type(tp), type(tp)
+        assert isinstance(tp, types.Type), type(tp)

         if self.locked:
             msg = (
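The removed `_ensure_cuda_type` shim coerced `numba.core` type instances into their `numba.cuda` equivalents on every call. If the redirect makes the CUDA type names resolve to numba's own classes (an assumption about this commit's effect, not something shown in the hunk above), the two hierarchies are one and the same and the plain isinstance assertion suffices. A minimal sketch of the invariant the tightened assertions now enforce, under that assumption:

# Hypothetical illustration: with one shared type hierarchy, a core numba
# type passes the isinstance check directly, with no conversion step.
from numba import types

tp = types.int64
assert isinstance(tp, types.Type)  # what add_type()/lock() now assert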

numba_cuda/numba/cuda/types/__init__.py renamed to numba_cuda/numba/cuda/cuda_types/__init__.py

Lines changed: 0 additions & 97 deletions
@@ -165,100 +165,6 @@
 longlong = np_longlong = _make_signed(np.longlong)
 ulonglong = np_ulonglong = _make_unsigned(np.longlong)

-
-class Dim3(types.Type):
-    """
-    A 3-tuple (x, y, z) representing the position of a block or thread.
-    """
-
-    def __init__(self):
-        super().__init__(name="Dim3")
-
-
-class GridGroup(types.Type):
-    """
-    The grid of all threads in a cooperative kernel launch.
-    """
-
-    def __init__(self):
-        super().__init__(name="GridGroup")
-
-
-dim3 = Dim3()
-grid_group = GridGroup()
-
-
-class CUDADispatcher(types.Dispatcher):
-    """The type of CUDA dispatchers"""
-
-    # This type exists (instead of using types.Dispatcher as the type of CUDA
-    # dispatchers) so that we can have an alternative lowering for them to the
-    # lowering of CPU dispatchers - the CPU target lowers all dispatchers as a
-    # constant address, but we need to lower to a dummy value because it's not
-    # generally valid to use the address of CUDA kernels and functions.
-    #
-    # Notes: it may be a bug in the CPU target that it lowers all dispatchers to
-    # a constant address - it should perhaps only lower dispatchers acting as
-    # first-class functions to a constant address. Even if that bug is fixed, it
-    # is still probably a good idea to have a separate type for CUDA
-    # dispatchers, and this type might get other differentiation from the CPU
-    # dispatcher type in future.
-
-
-class Bfloat16(Number):
-    """
-    A bfloat16 type. Has 8 exponent bits and 7 significand bits.
-
-    Conversion rules:
-        Floats:
-            from:
-                fp32, fp64: UNSAFE
-                fp16: UNSAFE (loses precision)
-            to:
-                fp32, fp64: PROMOTE (same exponent, more mantissa)
-                fp16: UNSAFE (loses range)
-
-        Integers:
-            from:
-                int8: SAFE
-                other int: All UNSAFE (bf16 cannot represent all integers in range)
-            to: UNSAFE (loses precision, round to zeros)
-
-    All other conversions are not allowed.
-    """
-
-    def __init__(self):
-        super().__init__(name="__nv_bfloat16")
-
-        self.alignof_ = 2
-        self.bitwidth = 16
-
-    def can_convert_from(self, typingctx, other):
-        if isinstance(other, types.Float):
-            return Conversion.unsafe
-
-        elif isinstance(other, types.Integer):
-            if other.bitwidth == 8:
-                return Conversion.safe
-            else:
-                return Conversion.unsafe
-
-    def can_convert_to(self, typingctx, other):
-        if isinstance(other, types.Float):
-            if other.bitwidth >= 32:
-                return Conversion.safe
-            else:
-                return Conversion.unsafe
-        elif isinstance(other, types.Integer):
-            return Conversion.unsafe
-
-    def unify(self, typingctx, other):
-        if isinstance(other, (types.Float, types.Integer)):
-            return typingctx.unify_pairs(self, other)
-
-
-bfloat16 = Bfloat16()
-
 all_str = """
 int8
 int16

@@ -312,9 +218,6 @@ def unify(self, typingctx, other):
 ffi_forced_object
 ffi
 deferred_type
-dim3
-grid_group
-bfloat16
 """
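For reference, the conversion rules encoded by the removed `Bfloat16` class, restated as a small self-contained sketch. The function and constant names are plain-Python stand-ins, not the numba typing API; note that the removed docstring labels bf16 to fp32/fp64 a PROMOTE while `can_convert_to` reports it as safe, and the sketch follows the code:

# Standalone restatement of the removed Bfloat16 conversion rules.
UNSAFE, SAFE = "unsafe", "safe"

def convert_into_bf16(kind, bitwidth):
    """Rule for converting another type *into* bfloat16."""
    if kind == "float":
        return UNSAFE                              # fp16/fp32/fp64 all lose precision
    if kind == "int":
        return SAFE if bitwidth == 8 else UNSAFE   # only int8 is exactly representable
    return None                                    # all other conversions disallowed

def convert_from_bf16(kind, bitwidth):
    """Rule for converting bfloat16 *to* another type."""
    if kind == "float":
        return SAFE if bitwidth >= 32 else UNSAFE  # fp16 loses range
    if kind == "int":
        return UNSAFE                              # loses precision, rounds to zero
    return None

assert convert_into_bf16("int", 8) == SAFE
assert convert_from_bf16("float", 32) == SAFE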
