Skip to content

Commit 49e4030

Browse files
merge/resolve
2 parents 467f814 + 9614920 commit 49e4030

File tree

222 files changed

+8089
-364
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

222 files changed

+8089
-364
lines changed

docs/source/user/cudapysupported.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ The following functions from the :mod:`math` module are supported:
214214
* :func:`math.erf`
215215
* :func:`math.erfc`
216216
* :func:`math.exp`
217+
* :func:`math.exp2`
217218
* :func:`math.expm1`
218219
* :func:`math.fabs`
219220
* :func:`math.frexp`

numba_cuda/numba/cuda/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77
import warnings
88
import sys
99

10+
# Re-export the types module itself
11+
import numba.cuda.types as types
12+
13+
# Re-export all type names
14+
from numba.cuda.types import *
15+
1016

1117
# Require NVIDIA CUDA bindings at import time
1218
if not (

numba_cuda/numba/cuda/_internal/cuda_bf16.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818

1919
import numba
2020
from llvmlite import ir
21-
from numba import types
22-
from numba.core.datamodel import PrimitiveModel, StructModel
21+
from numba.cuda import types
22+
from numba.cuda.datamodel import PrimitiveModel, StructModel
2323
from numba.cuda.extending import (
2424
lower_cast,
2525
make_attribute_wrapper,
@@ -41,7 +41,7 @@
4141
from numba.cuda import CUSource, declare_device
4242
from numba.cuda.vector_types import vector_types
4343
from numba.cuda.extending import as_numba_type
44-
from numba.types import (
44+
from numba.cuda.types import (
4545
CPointer,
4646
Function,
4747
Number,
@@ -60,7 +60,7 @@
6060
uint64,
6161
void,
6262
)
63-
from numba.cuda.types import bfloat16
63+
from numba.cuda.ext_types import bfloat16
6464

6565
float32x2 = vector_types["float32x2"]
6666
__half = float16

numba_cuda/numba/cuda/_internal/cuda_fp16.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
import numba
2020
from llvmlite import ir
21-
from numba import types
21+
from numba.cuda import types
2222
from numba.cuda.cudadrv.driver import _have_nvjitlink
23-
from numba.core.datamodel import PrimitiveModel, StructModel
23+
from numba.cuda.datamodel import PrimitiveModel, StructModel
2424
from numba.core.errors import NumbaPerformanceWarning
2525
from numba.cuda.extending import (
2626
lower_cast,
@@ -40,7 +40,7 @@
4040
from numba.cuda.typing.templates import Registry as TypingRegistry
4141
from numba.cuda.vector_types import vector_types
4242
from numba.cuda.extending import as_numba_type
43-
from numba.types import (
43+
from numba.cuda.types import (
4444
CPointer,
4545
Function,
4646
Number,
@@ -221,7 +221,7 @@ class _ctor_template_unnamed1362180(ConcreteTemplate):
221221

222222
register_global(unnamed1362180, Function(_ctor_template_unnamed1362180))
223223

224-
__half = _type___half = numba.core.types.float16
224+
__half = _type___half = numba.cuda.types.float16
225225
setattr(__half, "alignof_", 2)
226226
setattr(__half, "align", 2)
227227

numba_cuda/numba/cuda/api.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,9 @@ def from_cuda_array_interface(desc, owner=None, sync=True):
3939

4040
shape = desc["shape"]
4141
strides = desc.get("strides")
42-
dtype = np.dtype(desc["typestr"])
4342

4443
shape, strides, dtype = prepare_shape_strides_dtype(
45-
shape, strides, dtype, order="C"
44+
shape, strides, desc["typestr"], order="C"
4645
)
4746
size = driver.memory_size_from_info(shape, strides, dtype.itemsize)
4847

numba_cuda/numba/cuda/api_util.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import numpy as np
55

6+
import functools
7+
68

79
def prepare_shape_strides_dtype(shape, strides, dtype, order):
810
dtype = np.dtype(dtype)
@@ -14,25 +16,33 @@ def prepare_shape_strides_dtype(shape, strides, dtype, order):
1416
raise TypeError("shape must be an integer or tuple of integers")
1517
if isinstance(shape, int):
1618
shape = (shape,)
19+
else:
20+
shape = tuple(shape)
1721
if isinstance(strides, int):
1822
strides = (strides,)
1923
else:
20-
strides = strides or _fill_stride_by_order(shape, dtype, order)
24+
if not strides:
25+
strides = _fill_stride_by_order(shape, dtype, order)
26+
else:
27+
strides = tuple(strides)
2128
return shape, strides, dtype
2229

2330

31+
@functools.cache
2432
def _fill_stride_by_order(shape, dtype, order):
25-
nd = len(shape)
26-
if nd == 0:
33+
ndims = len(shape)
34+
if not ndims:
2735
return ()
28-
strides = [0] * nd
36+
strides = [0] * ndims
2937
if order == "C":
3038
strides[-1] = dtype.itemsize
31-
for d in reversed(range(nd - 1)):
39+
# -2 because we subtract one for zero-based indexing and another one
40+
# for skipping the already-filled-in last element
41+
for d in range(ndims - 2, -1, -1):
3242
strides[d] = strides[d + 1] * shape[d + 1]
3343
elif order == "F":
3444
strides[0] = dtype.itemsize
35-
for d in range(1, nd):
45+
for d in range(1, ndims):
3646
strides[d] = strides[d - 1] * shape[d - 1]
3747
else:
3848
raise ValueError("must be either C/F order")

numba_cuda/numba/cuda/bf16.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
3+
import sys
34

45
from numba.cuda._internal.cuda_bf16 import (
56
typing_registry,
@@ -191,14 +192,12 @@ def exp_ol(a):
191192
return _make_unary(a, hexp)
192193

193194

194-
try:
195-
from math import exp2
195+
if sys.version_info >= (3, 11):
196196

197-
@overload(exp2, target="cuda")
197+
@overload(math.exp2, target="cuda")
198198
def exp2_ol(a):
199199
return _make_unary(a, hexp2)
200-
except ImportError:
201-
pass
200+
202201

203202
## Public aliases using Numba/Numpy-style type names
204203
# Floating-point

numba_cuda/numba/cuda/cg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4-
from numba.core import types
4+
from numba.cuda import types
55
from numba.cuda.extending import overload, overload_method
66
from numba.cuda.typing import signature
77
from numba.cuda import nvvmutils
88
from numba.cuda.extending import intrinsic
9-
from numba.cuda.types import grid_group, GridGroup as GridGroupClass
9+
from numba.cuda.ext_types import grid_group, GridGroup as GridGroupClass
1010

1111

1212
class GridGroup:

numba_cuda/numba/cuda/cgutils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111

1212
from llvmlite import ir
1313

14-
from numba.core import types
14+
from numba.cuda import types
1515
from numba.cuda import config, utils, debuginfo
16-
import numba.core.datamodel
16+
import numba.cuda.datamodel
1717

1818

1919
bool_t = ir.IntType(1)
@@ -104,7 +104,7 @@ class _StructProxy(object):
104104
def __init__(self, context, builder, value=None, ref=None):
105105
self._context = context
106106
self._datamodel = self._context.data_model_manager[self._fe_type]
107-
if not isinstance(self._datamodel, numba.core.datamodel.StructModel):
107+
if not isinstance(self._datamodel, numba.cuda.datamodel.StructModel):
108108
raise TypeError(
109109
"Not a structure model: {0}".format(self._datamodel)
110110
)

numba_cuda/numba/cuda/compiler.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@
77
import copy
88

99
from numba.core import ir as numba_ir
10-
from numba.core import (
11-
types,
12-
bytecode,
13-
)
10+
from numba.core import bytecode
11+
from numba.cuda import types
1412
from numba.cuda.core.options import ParallelOptions
1513
from numba.core.compiler_lock import global_compiler_lock
1614
from numba.core.errors import NumbaWarning, NumbaInvalidConfigWarning

0 commit comments

Comments
 (0)