Skip to content

Commit dc4f660

Browse files
committed
Remove dependency on numba.core.target_extension for CUDATarget
1 parent d7ac874 commit dc4f660

File tree

11 files changed

+74
-127
lines changed

11 files changed

+74
-127
lines changed

numba_cuda/numba/cuda/compiler.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -741,20 +741,16 @@ def compile_cuda(
741741
flags.max_registers = max_registers
742742
flags.lto = lto
743743

744-
# Run compilation pipeline
745-
from numba.core.target_extension import target_override
746-
747-
with target_override("cuda"):
748-
cres = compile_extra(
749-
typingctx=typingctx,
750-
targetctx=targetctx,
751-
func=pyfunc,
752-
args=args,
753-
return_type=return_type,
754-
flags=flags,
755-
locals={},
756-
pipeline_class=CUDACompiler,
757-
)
744+
cres = compile_extra(
745+
typingctx=typingctx,
746+
targetctx=targetctx,
747+
func=pyfunc,
748+
args=args,
749+
return_type=return_type,
750+
flags=flags,
751+
locals={},
752+
pipeline_class=CUDACompiler,
753+
)
758754

759755
library = cres.library
760756
library.finalize()

numba_cuda/numba/cuda/core/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,10 +212,10 @@ def enable_boundscheck(self, value):
212212
def __init__(self, typing_context, target):
213213
self.address_size = utils.MACHINE_BITS
214214
self.typing_context = typing_context
215-
from numba.core.target_extension import target_registry
215+
from numba.cuda.descriptor import cuda_target
216216

217217
self.target_name = target
218-
self.target = target_registry[target]
218+
self.target = cuda_target
219219

220220
# A mapping of installed registries to their loaders
221221
self._registries = {}

numba_cuda/numba/cuda/cuda_types/functions.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -314,14 +314,10 @@ def get_call_type(self, context, args, kws):
314314
context, self, args, kws, depth=self._depth
315315
)
316316

317-
# get the order in which to try templates
318-
from numba.core.target_extension import (
319-
get_local_target,
320-
) # circular
317+
from numba.cuda.descriptor import cuda_target
321318

322-
target_hw = get_local_target(context)
323319
order = utils.order_by_target_specificity(
324-
target_hw, self.templates, fnkey=self.key[0]
320+
cuda_target, self.templates, fnkey=self.key[0]
325321
)
326322

327323
self._depth += 1

numba_cuda/numba/cuda/descriptor.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,34 @@ def target_context(self):
3333

3434

3535
cuda_target = CUDATarget("cuda")
36+
37+
# Monkey-patch numba's get_local_target and order_by_target_specificity for CUDATarget
38+
try:
39+
from numba.core import target_extension
40+
from numba.cuda.utils import order_by_target_specificity
41+
from numba.core import utils as numba_utils
42+
43+
def _is_cuda_context(obj):
44+
return (
45+
isinstance(obj, CUDATarget)
46+
or (hasattr(obj, "__class__") and "CUDA" in obj.__class__.__name__)
47+
or (hasattr(obj, "target") and isinstance(obj.target, CUDATarget))
48+
)
49+
50+
def _patch_numba_for_cuda_target():
51+
_orig_get_local = target_extension.get_local_target
52+
53+
def get_local_target_cuda(context):
54+
return (
55+
cuda_target
56+
if _is_cuda_context(context)
57+
else _orig_get_local(context)
58+
)
59+
60+
target_extension.get_local_target = get_local_target_cuda
61+
numba_utils.order_by_target_specificity = order_by_target_specificity
62+
63+
_patch_numba_for_cuda_target()
64+
65+
except ImportError:
66+
pass

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -726,13 +726,8 @@ class CUDACache(Cache):
726726
_impl_class = CUDACacheImpl
727727

728728
def load_overload(self, sig, target_context):
729-
# Loading an overload refreshes the context to ensure it is
730-
# initialized. To initialize the correct (i.e. CUDA) target, we need to
731-
# enforce that the current target is the CUDA target.
732-
from numba.core.target_extension import target_override
733-
734-
with target_override("cuda"):
735-
return super().load_overload(sig, target_context)
729+
# Loading an overload refreshes the context to ensure it is initialized.
730+
return super().load_overload(sig, target_context)
736731

737732

738733
class OmittedArg(object):

numba_cuda/numba/cuda/initialize.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,3 @@
55
def initialize_all():
66
# Import models to register them with the data model manager
77
import numba.cuda.models # noqa: F401
8-
9-
from numba.cuda.decorators import jit
10-
from numba.cuda.dispatcher import CUDADispatcher
11-
from numba.core.target_extension import (
12-
target_registry,
13-
dispatcher_registry,
14-
jit_registry,
15-
)
16-
17-
cuda_target = target_registry["cuda"]
18-
jit_registry[cuda_target] = jit
19-
dispatcher_registry[cuda_target] = CUDADispatcher

numba_cuda/numba/cuda/lowering.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,12 +1239,9 @@ def _lower_call_normal(self, fnty, expr, signature):
12391239
)
12401240
tname = expr.target
12411241
if tname is not None:
1242-
from numba.core.target_extension import (
1243-
resolve_dispatcher_from_str,
1244-
)
1242+
from numba.cuda.descriptor import cuda_target
12451243

1246-
disp = resolve_dispatcher_from_str(tname)
1247-
hw_ctx = disp.targetdescr.target_context
1244+
hw_ctx = cuda_target.target_context
12481245
impl = hw_ctx.get_function(fnty, signature)
12491246
else:
12501247
impl = self.context.get_function(fnty, signature)

numba_cuda/numba/cuda/tests/core/test_serialize.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,18 @@
1313
import numba
1414
from numba.core.errors import TypingError
1515
from numba.cuda.tests.support import TestCase
16-
from numba.core.target_extension import resolve_dispatcher_from_str
1716
from numba.cuda.cloudpickle import dumps, loads
1817

18+
try:
19+
from numba.core.target_extension import resolve_dispatcher_from_str
20+
except ImportError:
21+
resolve_dispatcher_from_str = None
1922

23+
24+
@unittest.skipIf(
25+
resolve_dispatcher_from_str is None,
26+
"numba.core.target_extension not available",
27+
)
2028
class TestDispatcherPickling(TestCase):
2129
def run_with_protocols(self, meth, *args, **kwargs):
2230
for proto in range(pickle.HIGHEST_PROTOCOL + 1):

numba_cuda/numba/cuda/typing/context.py

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -290,11 +290,9 @@ def core(typ):
290290
def find_matching_getattr_template(self, typ, attr):
291291
templates = list(self._get_attribute_templates(typ))
292292

293-
# get the order in which to try templates
294-
from numba.core.target_extension import get_local_target
293+
from numba.cuda.descriptor import cuda_target
295294

296-
target_hw = get_local_target(self)
297-
order = order_by_target_specificity(target_hw, templates, fnkey=attr)
295+
order = order_by_target_specificity(cuda_target, templates, fnkey=attr)
298296

299297
for template in order:
300298
return_type = template.resolve(typ, attr)
@@ -446,13 +444,6 @@ def install_registry(self, registry, external_defs_only=False):
446444
loader = templates.RegistryLoader(registry)
447445
self._registries[registry] = loader
448446

449-
from numba.core.target_extension import (
450-
get_local_target,
451-
resolve_target_str,
452-
)
453-
454-
current_target = get_local_target(self)
455-
456447
def is_for_this_target(ftcls):
457448
metadata = getattr(ftcls, "metadata", None)
458449
if metadata is None:
@@ -462,31 +453,11 @@ def is_for_this_target(ftcls):
462453
if target_str is None:
463454
return True
464455

465-
# There may be pending registrations for nonexistent targets.
466-
# Ideally it would be impossible to leave a registration pending
467-
# for an invalid target, but in practice this is exceedingly
468-
# difficult to guard against - many things are registered at import
469-
# time, and eagerly reporting an error when registering for invalid
470-
# targets would require that all target registration code is
471-
# executed prior to all typing registrations during the import
472-
# process; attempting to enforce this would impose constraints on
473-
# execution order during import that would be very difficult to
474-
# resolve and maintain in the presence of typical code maintenance.
475-
# Furthermore, these constraints would be imposed not only on
476-
# Numba internals, but also on its dependents.
477-
#
478-
# Instead of that enforcement, we simply catch any occurrences of
479-
# registrations for targets that don't exist, and report that
480-
# they're not for this target. They will then not be encountered
481-
# again during future typing context refreshes (because the
482-
# loader's new registrations are a stream_list that doesn't yield
483-
# previously-yielded items).
484-
try:
485-
ft_target = resolve_target_str(target_str)
486-
except errors.NonexistentTargetError:
487-
return False
456+
# Accept both "cuda" and "generic" targets
457+
if target_str in ("cuda", "generic"):
458+
return True
488459

489-
return current_target.inherits_from(ft_target)
460+
return False
490461

491462
def is_external(obj):
492463
"""Check if obj is from outside numba.* namespace."""

numba_cuda/numba/cuda/typing/templates.py

Lines changed: 4 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -778,37 +778,9 @@ def _get_impl(self, args, kws):
778778

779779
def _get_jit_decorator(self):
780780
"""Gets a jit decorator suitable for the current target"""
781+
from numba.cuda.decorators import jit
781782

782-
from numba.core.target_extension import (
783-
target_registry,
784-
get_local_target,
785-
jit_registry,
786-
)
787-
788-
jitter_str = self.metadata.get("target", "generic")
789-
jitter = jit_registry.get(jitter_str, None)
790-
791-
if jitter is None:
792-
# No JIT known for target string, see if something is
793-
# registered for the string and report if not.
794-
target_class = target_registry.get(jitter_str, None)
795-
if target_class is None:
796-
msg = ("Unknown target '{}', has it been ", "registered?")
797-
raise ValueError(msg.format(jitter_str))
798-
799-
target_hw = get_local_target(self.context)
800-
801-
# check that the requested target is in the hierarchy for the
802-
# current frame's target.
803-
if not issubclass(target_hw, target_class):
804-
msg = "No overloads exist for the requested target: {}."
805-
806-
jitter = jit_registry[target_hw]
807-
808-
if jitter is None:
809-
raise ValueError("Cannot find a suitable jit decorator")
810-
811-
return jitter
783+
return jit
812784

813785
def _build_impl(self, cache_key, args, kws):
814786
"""Build and cache the implementation.
@@ -988,16 +960,9 @@ def _get_target_registry(self, reason):
988960
-------
989961
reg : a registry suitable for the current target.
990962
"""
991-
from numba.core.target_extension import (
992-
_get_local_target_checked,
993-
dispatcher_registry,
994-
)
963+
from numba.cuda.descriptor import cuda_target
995964

996-
hwstr = self.metadata.get("target", "generic")
997-
target_hw = _get_local_target_checked(self.context, hwstr, reason)
998-
# Get registry for the current hardware
999-
disp = dispatcher_registry[target_hw]
1000-
tgtctx = disp.targetdescr.target_context
965+
tgtctx = cuda_target.target_context
1001966

1002967
# ---------------------------------------------------------------------
1003968
# XXX: In upstream Numba, this function would prefer the builtin

0 commit comments

Comments
 (0)