Skip to content

Commit 13395d4

Browse files
committed
[Testing] Add some numpy array testing for the CUDA target
1 parent f6664ab commit 13395d4

File tree

9 files changed

+1368
-10
lines changed

9 files changed

+1368
-10
lines changed

numba_cuda/numba/cuda/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,9 @@
6060
)
6161

6262
from numba.cuda.np.ufunc import vectorize, guvectorize
63+
64+
# Re-export typeof
65+
from numba.cuda.misc.special import (
66+
literally,
67+
literal_unroll,
68+
)

numba_cuda/numba/cuda/np/arrayobj.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616

1717
import numpy as np
1818

19-
from numba import pndindex, literal_unroll
20-
from numba.core import types, errors
21-
from numba.cuda import typing
19+
from numba import pndindex
20+
from numba.cuda import literal_unroll
21+
from numba.core import types, typing, errors
2222
from numba.cuda import cgutils, extending
2323
from numba.cuda.np.numpy_support import (
2424
as_dtype,

numba_cuda/numba/cuda/simulator/api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,7 @@ def jitwrapper(fn):
161161
def defer_cleanup():
162162
# No effect for simulator
163163
yield
164+
165+
166+
class grid(object):
167+
pass

numba_cuda/numba/cuda/target.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
# Typing
3737

3838

39-
class CUDATypingContext(typing.BaseContext):
39+
class CUDATypingContext(typing.Context):
4040
def load_additional_registries(self):
4141
from . import (
4242
cudadecl,
@@ -46,7 +46,7 @@ def load_additional_registries(self):
4646
libdevicedecl,
4747
vector_types,
4848
)
49-
from numba.cuda.typing import enumdecl, cffi_utils
49+
from numba.cuda.typing import enumdecl, cffi_utils, npydecl
5050

5151
self.install_registry(cudadecl.registry)
5252
self.install_registry(cffi_utils.registry)
@@ -57,6 +57,7 @@ def load_additional_registries(self):
5757
self.install_registry(vector_types.typing_registry)
5858
self.install_registry(fp16.typing_registry)
5959
self.install_registry(bf16.typing_registry)
60+
self.install_registry(npydecl.registry)
6061

6162
def resolve_value_type(self, val):
6263
# treat other dispatcher object as another device function
@@ -182,6 +183,8 @@ def load_additional_registries(self):
182183
arrayobj,
183184
npdatetime,
184185
polynomial,
186+
arraymath,
187+
npyimpl,
185188
)
186189
from . import (
187190
cudaimpl,
@@ -222,6 +225,8 @@ def load_additional_registries(self):
222225
self.install_registry(polynomial.registry)
223226
self.install_registry(npdatetime.registry)
224227
self.install_registry(arrayobj.registry)
228+
self.install_registry(arraymath.registry)
229+
self.install_registry(npyimpl.registry)
225230

226231
# Install only implementations that are defined outside of numba (i.e.,
227232
# in third-party extensions) from Numba's builtin_registry.

0 commit comments

Comments
 (0)