Skip to content

Commit 361a788

Browse files
committed
Install only external registrations from numba's builtin_registry
1 parent 802b38f commit 361a788

File tree

3 files changed

+61
-5
lines changed

3 files changed

+61
-5
lines changed

numba_cuda/numba/cuda/core/base.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,63 @@ def install_registry(self, registry):
356356
self._insert_cast_defn(loader.new_registrations("casts"))
357357
self._insert_get_constant_defn(loader.new_registrations("constants"))
358358

359+
def install_external_registry(self, registry):
360+
"""
361+
Install only the registrations that were defined outside
362+
of the "numba." namespace (i.e., in third-party extensions).
363+
This is useful for selectively installing implementations
364+
from the shared builtin_registry without pulling in any CPU-specific
365+
implementations from Numba.
366+
"""
367+
368+
def is_external(impl):
369+
"""Check if implementation is defined outside numba.* namespace."""
370+
try:
371+
module = impl.__module__
372+
return not module.startswith("numba.")
373+
except AttributeError:
374+
# If we can't determine module, conservatively include registration
375+
return True
376+
377+
try:
378+
loader = self._registries[registry]
379+
except KeyError:
380+
loader = RegistryLoader(registry)
381+
self._registries[registry] = loader
382+
383+
# Filter registrations
384+
funcs = [
385+
(impl, func, sig)
386+
for impl, func, sig in loader.new_registrations("functions")
387+
if is_external(impl)
388+
]
389+
getattrs = [
390+
(impl, attr, sig)
391+
for impl, attr, sig in loader.new_registrations("getattrs")
392+
if is_external(impl)
393+
]
394+
setattrs = [
395+
(impl, attr, sig)
396+
for impl, attr, sig in loader.new_registrations("setattrs")
397+
if is_external(impl)
398+
]
399+
casts = [
400+
(impl, sig)
401+
for impl, sig in loader.new_registrations("casts")
402+
if is_external(impl)
403+
]
404+
constants = [
405+
(impl, sig)
406+
for impl, sig in loader.new_registrations("constants")
407+
if is_external(impl)
408+
]
409+
410+
self.insert_func_defn(funcs)
411+
self._insert_getattr_defn(getattrs)
412+
self._insert_setattr_defn(setattrs)
413+
self._insert_cast_defn(casts)
414+
self._insert_get_constant_defn(constants)
415+
359416
def insert_func_defn(self, defns):
360417
for impl, func, sig in defns:
361418
self._defns[func].append(impl, sig)

numba_cuda/numba/cuda/np/arrayobj.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3018,7 +3018,6 @@ def _compatible_view(a, dtype):
30183018

30193019

30203020
@overload(_compatible_view)
3021-
@overload(_compatible_view, target="generic")
30223021
def ol_compatible_view(a, dtype):
30233022
"""Determines if the array and dtype are compatible for forming a view."""
30243023

numba_cuda/numba/cuda/target.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,12 @@ def load_additional_registries(self):
218218
self.install_registry(npdatetime.registry)
219219
self.install_registry(arrayobj.registry)
220220

221+
# Install only implementations that are defined outside of numba (i.e., in third-party extensions)
222+
# from Numba's builtin_registry, i.e. exclude anything in "numba.*" namespace.
221223
if importlib.util.find_spec("numba.core.imputils") is not None:
222-
from numba.core.imputils import (
223-
builtin_registry as upstream_builtin_registry,
224-
)
224+
from numba.core.imputils import builtin_registry
225225

226-
self.install_registry(upstream_builtin_registry)
226+
self.install_external_registry(builtin_registry)
227227

228228
def codegen(self):
229229
return self._internal_codegen

0 commit comments

Comments
 (0)