Skip to content

Commit fbdff35

Browse files
committed
Fallback to Numba's typeof_impl for third-party typing registrations
1 parent 000a6bf commit fbdff35

File tree

2 files changed

+50
-6
lines changed

2 files changed

+50
-6
lines changed

numba_cuda/numba/cuda/typing/context.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -405,20 +405,34 @@ def _load_builtins(self):
405405
self.install_registry(npdatetime.registry)
406406
self.install_registry(templates.builtin_registry)
407407

408+
# Install only third-party declarations from Numba's typing registry
409+
# if it is available. Exclude Numba's own typing declarations.
408410
if find_spec("numba.core.typing") is not None:
409411
from numba.core.typing import templates as core_templates
410412

411-
self.install_registry(core_templates.builtin_registry)
413+
self.install_registry(
414+
core_templates.builtin_registry, external_defs_only=True
415+
)
412416

413417
def load_additional_registries(self):
414418
"""
415419
Load target-specific registries. Can be overridden by subclasses.
416420
"""
417421

418-
def install_registry(self, registry):
422+
def install_registry(self, registry, external_defs_only=False):
419423
"""
420424
Install a *registry* (a templates.Registry instance) of function,
421425
attribute and global declarations.
426+
427+
Parameters
428+
----------
429+
registry : Registry
430+
The registry to install
431+
external_defs_only : bool, optional
432+
If True, only install registrations for types from outside numba.* namespace.
433+
This is useful when installing third-party registrations from
434+
a shared registry like Numba's typing registry (builtin_registry).
435+
422436
"""
423437
try:
424438
loader = self._registries[registry]
@@ -468,15 +482,35 @@ def is_for_this_target(ftcls):
468482

469483
return current_target.inherits_from(ft_target)
470484

471-
for ftcls in loader.new_registrations("functions"):
472-
if not is_for_this_target(ftcls):
473-
continue
474-
self.insert_function(ftcls(self))
485+
def is_external_type(typ):
486+
"""Check if a type is from outside numba.* namespace."""
487+
try:
488+
return not typ.__module__.startswith("numba.")
489+
except AttributeError:
490+
return True
491+
492+
# Skip functions entirely when external_defs_only=True
493+
if not external_defs_only:
494+
for ftcls in loader.new_registrations("functions"):
495+
if not is_for_this_target(ftcls):
496+
continue
497+
self.insert_function(ftcls(self))
475498
for ftcls in loader.new_registrations("attributes"):
476499
if not is_for_this_target(ftcls):
477500
continue
501+
# If external_defs_only, check if the type being registered is external
502+
if external_defs_only:
503+
key = getattr(ftcls, "key", None)
504+
if key is not None and not is_external_type(key):
505+
continue
478506
self.insert_attributes(ftcls(self))
479507
for gv, gty in loader.new_registrations("globals"):
508+
# If external_defs_only, check the global type's module
509+
if external_defs_only:
510+
if hasattr(gty, "__module__") and gty.__module__.startswith(
511+
"numba."
512+
):
513+
continue
480514
existing = self._lookup_global(gv)
481515
if existing is None:
482516
self.insert_global(gv, gty)

numba_cuda/numba/cuda/typing/typeof.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,16 @@ def typeof_impl(val, c):
6464
if cffi_utils.is_ffi_instance(val):
6565
return types.ffi
6666

67+
# Fallback to Numba's typeof_impl for third-party registrations
68+
from numba.core.typing.typeof import typeof_impl as core_typeof_impl
69+
70+
try:
71+
tp = core_typeof_impl(val, c)
72+
if tp is not None:
73+
return tp
74+
except (ValueError, TypeError, AttributeError):
75+
pass
76+
6777
return None
6878

6979

0 commit comments

Comments
 (0)