Skip to content

Commit 7df599f

Browse files
committed
alternative module layout for ext_types
1 parent 5aa6b2e commit 7df599f

File tree

12 files changed

+27
-19
lines changed

12 files changed

+27
-19
lines changed

numba_cuda/numba/cuda/_internal/cuda_bf16.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
uint64,
6161
void,
6262
)
63-
from numba.cuda.ext_types import bfloat16
63+
from numba.cuda.types.ext_types import bfloat16
6464

6565
float32x2 = vector_types["float32x2"]
6666
__half = float16

numba_cuda/numba/cuda/cg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from numba.cuda.typing import signature
77
from numba.cuda import nvvmutils
88
from numba.cuda.extending import intrinsic
9-
from numba.cuda.ext_types import grid_group, GridGroup as GridGroupClass
9+
from numba.cuda.types.ext_types import grid_group, GridGroup as GridGroupClass
1010

1111

1212
class GridGroup:

numba_cuda/numba/cuda/cudadecl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
signature,
1616
Registry,
1717
)
18-
from numba.cuda.ext_types import dim3
18+
from numba.cuda.types.ext_types import dim3
1919
from numba import cuda
2020

2121
registry = Registry()

numba_cuda/numba/cuda/cudaimpl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .cudadrv import nvvm
2020
from numba import cuda
2121
from numba.cuda import nvvmutils, stubs
22-
from numba.cuda.ext_types import dim3, CUDADispatcher
22+
from numba.cuda.types.ext_types import dim3, CUDADispatcher
2323

2424
registry = Registry("cudaimpl")
2525
lower = registry.lower

numba_cuda/numba/cuda/debuginfo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from numba.cuda.core import config
1111
from numba.cuda import cgutils
1212
from numba.cuda.datamodel.models import ComplexModel, UnionModel, UniTupleModel
13-
from numba.cuda.ext_types import GridGroup
13+
from numba.cuda.types.ext_types import GridGroup
1414

1515

1616
@contextmanager

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
from numba.cuda.typing.templates import fold_arguments
2525
from numba.cuda.typing.typeof import Purpose, typeof
2626

27-
from numba.cuda import typing, types, ext_types
27+
from numba.cuda import typing, types
28+
from numba.cuda.types import ext_types
2829
from numba.cuda.api import get_current_device
2930
from numba.cuda.args import wrap_arg
3031
from numba.cuda.core.bytecode import get_code_object

numba_cuda/numba/cuda/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from numba.cuda.datamodel.models import StructModel
1111
from numba.cuda.extending import core_models as models
1212
from numba.cuda import types
13-
from numba.cuda.ext_types import Dim3, GridGroup, CUDADispatcher, Bfloat16
13+
from numba.cuda.types.ext_types import Dim3, GridGroup, CUDADispatcher, Bfloat16
1414

1515

1616
cuda_data_manager = DataModelManager()

numba_cuda/numba/cuda/printimpl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from numba.cuda.core.errors import NumbaWarning
99
from numba.cuda.core.imputils import Registry
1010
from numba.cuda import nvvmutils
11-
from numba.cuda.ext_types import Dim3, Bfloat16
11+
from numba.cuda.types.ext_types import Dim3, Bfloat16
1212
from warnings import warn
1313

1414
registry = Registry("printimpl")

numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,9 @@ def kernel(out):
575575
_bf16_ulp_distance(raw[4:], f8_expected), 2
576576
)
577577

578+
def test_bfloat16_type_import(self):
579+
self.skip_unsupported()
580+
578581

579582
def _bf16_ulp_rank(bits_int16: np.ndarray) -> np.ndarray:
580583
"""

numba_cuda/numba/cuda/tests/cudapy/test_print.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def print_too_many(r):
110110
def print_bfloat16():
111111
# 0.9375 is a dyadic rational, it's integer significand can expand within 7 digits.
112112
# printing this should not give any rounding error.
113-
a = cuda.ext_types.bfloat16(0.9375)
113+
a = cuda.bfloat16(0.9375)
114114
print(a, a, a)
115115
116116
print_bfloat16[1, 1]()

0 commit comments

Comments
 (0)