Skip to content

Commit 58ed274

Browse files
Fix double-registration of items from npydecl Historically cudadecl registered a few things from npydecl to support NumPy dtypes in kernels and a small set of ufuncs. Now that the real npydecl is registered, these items should not also be registered in cudadecl, because this leads to an erroneous double-registration.
Co-authored-by: Graham Markall <[email protected]>
1 parent c896826 commit 58ed274

File tree

4 files changed

+6
-30
lines changed

4 files changed

+6
-30
lines changed

numba_cuda/numba/cuda/cudadecl.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,6 @@
66
from numba.cuda.typing.npydecl import (
77
parse_dtype,
88
parse_shape,
9-
register_number_classes,
10-
register_numpy_ufunc,
11-
trigonometric_functions,
12-
comparison_functions,
13-
math_operations,
14-
bit_twiddling_functions,
159
)
1610
from numba.cuda.typing.templates import (
1711
AttributeTemplate,
@@ -29,8 +23,6 @@
2923
register_attr = registry.register_attr
3024
register_global = registry.register_global
3125

32-
register_number_classes(register_global)
33-
3426

3527
class Cuda_array_decl(CallableTemplate):
3628
def generic(self):
@@ -562,19 +554,3 @@ def resolve_local(self, mod):
562554

563555

564556
register_global(cuda, types.Module(cuda))
565-
566-
567-
# NumPy
568-
569-
for func in trigonometric_functions:
570-
register_numpy_ufunc(func, register_global)
571-
572-
for func in comparison_functions:
573-
register_numpy_ufunc(func, register_global)
574-
575-
for func in bit_twiddling_functions:
576-
register_numpy_ufunc(func, register_global)
577-
578-
for func in math_operations:
579-
if func in ("log", "log2", "log10"):
580-
register_numpy_ufunc(func, register_global)

numba_cuda/numba/cuda/device_init.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@
8787
PTXSource,
8888
)
8989

90-
# from numba.cuda.misc.special import literal_unroll
91-
# from numba.cuda.misc import literal
90+
from numba.cuda.misc.special import literal_unroll
91+
from numba.cuda.misc import literal
9292

9393
reduce = Reduce = reduction.Reduce
9494

numba_cuda/numba/cuda/target.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
# Typing
3737

3838

39-
class CUDATypingContext(typing.Context):
39+
class CUDATypingContext(typing.BaseContext):
4040
def load_additional_registries(self):
4141
from . import (
4242
cudadecl,
@@ -46,18 +46,18 @@ def load_additional_registries(self):
4646
libdevicedecl,
4747
vector_types,
4848
)
49-
from numba.cuda.typing import enumdecl, cffi_utils
49+
from numba.cuda.typing import enumdecl, cffi_utils, npydecl
5050

5151
self.install_registry(cudadecl.registry)
5252
self.install_registry(cffi_utils.registry)
5353
self.install_registry(cudamath.registry)
5454
self.install_registry(cmathdecl.registry)
5555
self.install_registry(libdevicedecl.registry)
56+
self.install_registry(npydecl.registry)
5657
self.install_registry(enumdecl.registry)
5758
self.install_registry(vector_types.typing_registry)
5859
self.install_registry(fp16.typing_registry)
5960
self.install_registry(bf16.typing_registry)
60-
super().load_additional_registries()
6161

6262
def resolve_value_type(self, val):
6363
# treat other dispatcher object as another device function

numba_cuda/numba/cuda/typing/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ def is_external(obj):
522522
else:
523523
# A type was already inserted, see if we can add to it
524524
newty = existing.augment(gty)
525-
if newty is None and existing != gty:
525+
if newty is None:
526526
raise TypeError(
527527
"cannot augment %s with %s" % (existing, gty)
528528
)

0 commit comments

Comments
 (0)