Skip to content

Commit 5aeb63c

Browse files
Remove dependencies on target_extension for CUDA target (#555)
This PR removes the dependency on `numba.core.target_extension`. This import was primarily used to get the local target, which is `CUDA` in our case. --------- Co-authored-by: Graham Markall <[email protected]>
1 parent 76205bc commit 5aeb63c

File tree

10 files changed

+53
-121
lines changed

10 files changed

+53
-121
lines changed

numba_cuda/numba/cuda/compiler.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -741,10 +741,7 @@ def compile_cuda(
741741
flags.max_registers = max_registers
742742
flags.lto = lto
743743

744-
# Run compilation pipeline
745-
from numba.core.target_extension import target_override
746-
747-
with target_override("cuda"):
744+
with utils.numba_target_override():
748745
cres = compile_extra(
749746
typingctx=typingctx,
750747
targetctx=targetctx,

numba_cuda/numba/cuda/core/base.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from collections import defaultdict
55
import copy
6+
import importlib
67
import sys
78
from itertools import permutations, takewhile
89
from contextlib import contextmanager
@@ -212,10 +213,15 @@ def enable_boundscheck(self, value):
212213
def __init__(self, typing_context, target):
213214
self.address_size = utils.MACHINE_BITS
214215
self.typing_context = typing_context
215-
from numba.core.target_extension import target_registry
216-
217216
self.target_name = target
218-
self.target = target_registry[target]
217+
218+
if importlib.util.find_spec("numba"):
219+
from numba.core.target_extension import CUDA
220+
221+
# Used only in Numba's target_extension implementation.
222+
# Numba-CUDA has the target_extension implementation removed, and
223+
# references to it hardcoded to values specific to the CUDA target.
224+
self.target = CUDA
219225

220226
# A mapping of installed registries to their loaders
221227
self._registries = {}

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -726,12 +726,8 @@ class CUDACache(Cache):
726726
_impl_class = CUDACacheImpl
727727

728728
def load_overload(self, sig, target_context):
729-
# Loading an overload refreshes the context to ensure it is
730-
# initialized. To initialize the correct (i.e. CUDA) target, we need to
731-
# enforce that the current target is the CUDA target.
732-
from numba.core.target_extension import target_override
733-
734-
with target_override("cuda"):
729+
# Loading an overload refreshes the context to ensure it is initialized.
730+
with utils.numba_target_override():
735731
return super().load_overload(sig, target_context)
736732

737733

numba_cuda/numba/cuda/initialize.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,3 @@
55
def initialize_all():
66
# Import models to register them with the data model manager
77
import numba.cuda.models # noqa: F401
8-
9-
from numba.cuda.decorators import jit
10-
from numba.cuda.dispatcher import CUDADispatcher
11-
from numba.core.target_extension import (
12-
target_registry,
13-
dispatcher_registry,
14-
jit_registry,
15-
)
16-
17-
cuda_target = target_registry["cuda"]
18-
jit_registry[cuda_target] = jit
19-
dispatcher_registry[cuda_target] = CUDADispatcher

numba_cuda/numba/cuda/lowering.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,12 +1239,9 @@ def _lower_call_normal(self, fnty, expr, signature):
12391239
)
12401240
tname = expr.target
12411241
if tname is not None:
1242-
from numba.core.target_extension import (
1243-
resolve_dispatcher_from_str,
1244-
)
1242+
from numba.cuda.descriptor import cuda_target
12451243

1246-
disp = resolve_dispatcher_from_str(tname)
1247-
hw_ctx = disp.targetdescr.target_context
1244+
hw_ctx = cuda_target.target_context
12481245
impl = hw_ctx.get_function(fnty, signature)
12491246
else:
12501247
impl = self.context.get_function(fnty, signature)

numba_cuda/numba/cuda/tests/core/test_serialize.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,18 @@
1313
import numba
1414
from numba.core.errors import TypingError
1515
from numba.cuda.tests.support import TestCase
16-
from numba.core.target_extension import resolve_dispatcher_from_str
1716
from numba.cuda.cloudpickle import dumps, loads
1817

18+
try:
19+
from numba.core.target_extension import resolve_dispatcher_from_str
20+
except ImportError:
21+
resolve_dispatcher_from_str = None
1922

23+
24+
@unittest.skipIf(
25+
resolve_dispatcher_from_str is None,
26+
"numba.core.target_extension not available",
27+
)
2028
class TestDispatcherPickling(TestCase):
2129
def run_with_protocols(self, meth, *args, **kwargs):
2230
for proto in range(pickle.HIGHEST_PROTOCOL + 1):

numba_cuda/numba/cuda/types/cuda_functions.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -314,14 +314,8 @@ def get_call_type(self, context, args, kws):
314314
context, self, args, kws, depth=self._depth
315315
)
316316

317-
# get the order in which to try templates
318-
from numba.core.target_extension import (
319-
get_local_target,
320-
) # circular
321-
322-
target_hw = get_local_target(context)
323317
order = utils.order_by_target_specificity(
324-
target_hw, self.templates, fnkey=self.key[0]
318+
self.templates, fnkey=self.key[0]
325319
)
326320

327321
self._depth += 1

numba_cuda/numba/cuda/typing/context.py

Lines changed: 5 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -290,11 +290,7 @@ def core(typ):
290290
def find_matching_getattr_template(self, typ, attr):
291291
templates = list(self._get_attribute_templates(typ))
292292

293-
# get the order in which to try templates
294-
from numba.core.target_extension import get_local_target
295-
296-
target_hw = get_local_target(self)
297-
order = order_by_target_specificity(target_hw, templates, fnkey=attr)
293+
order = order_by_target_specificity(templates, fnkey=attr)
298294

299295
for template in order:
300296
return_type = template.resolve(typ, attr)
@@ -446,13 +442,6 @@ def install_registry(self, registry, external_defs_only=False):
446442
loader = templates.RegistryLoader(registry)
447443
self._registries[registry] = loader
448444

449-
from numba.core.target_extension import (
450-
get_local_target,
451-
resolve_target_str,
452-
)
453-
454-
current_target = get_local_target(self)
455-
456445
def is_for_this_target(ftcls):
457446
metadata = getattr(ftcls, "metadata", None)
458447
if metadata is None:
@@ -462,31 +451,11 @@ def is_for_this_target(ftcls):
462451
if target_str is None:
463452
return True
464453

465-
# There may be pending registrations for nonexistent targets.
466-
# Ideally it would be impossible to leave a registration pending
467-
# for an invalid target, but in practice this is exceedingly
468-
# difficult to guard against - many things are registered at import
469-
# time, and eagerly reporting an error when registering for invalid
470-
# targets would require that all target registration code is
471-
# executed prior to all typing registrations during the import
472-
# process; attempting to enforce this would impose constraints on
473-
# execution order during import that would be very difficult to
474-
# resolve and maintain in the presence of typical code maintenance.
475-
# Furthermore, these constraints would be imposed not only on
476-
# Numba internals, but also on its dependents.
477-
#
478-
# Instead of that enforcement, we simply catch any occurrences of
479-
# registrations for targets that don't exist, and report that
480-
# they're not for this target. They will then not be encountered
481-
# again during future typing context refreshes (because the
482-
# loader's new registrations are a stream_list that doesn't yield
483-
# previously-yielded items).
484-
try:
485-
ft_target = resolve_target_str(target_str)
486-
except errors.NonexistentTargetError:
487-
return False
454+
# Accept both "cuda" and "generic" targets
455+
if target_str in ("cuda", "generic"):
456+
return True
488457

489-
return current_target.inherits_from(ft_target)
458+
return False
490459

491460
def is_external(obj):
492461
"""Check if obj is from outside numba.* namespace."""

numba_cuda/numba/cuda/typing/templates.py

Lines changed: 4 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -778,37 +778,9 @@ def _get_impl(self, args, kws):
778778

779779
def _get_jit_decorator(self):
780780
"""Gets a jit decorator suitable for the current target"""
781+
from numba.cuda.decorators import jit
781782

782-
from numba.core.target_extension import (
783-
target_registry,
784-
get_local_target,
785-
jit_registry,
786-
)
787-
788-
jitter_str = self.metadata.get("target", "generic")
789-
jitter = jit_registry.get(jitter_str, None)
790-
791-
if jitter is None:
792-
# No JIT known for target string, see if something is
793-
# registered for the string and report if not.
794-
target_class = target_registry.get(jitter_str, None)
795-
if target_class is None:
796-
msg = ("Unknown target '{}', has it been ", "registered?")
797-
raise ValueError(msg.format(jitter_str))
798-
799-
target_hw = get_local_target(self.context)
800-
801-
# check that the requested target is in the hierarchy for the
802-
# current frame's target.
803-
if not issubclass(target_hw, target_class):
804-
msg = "No overloads exist for the requested target: {}."
805-
806-
jitter = jit_registry[target_hw]
807-
808-
if jitter is None:
809-
raise ValueError("Cannot find a suitable jit decorator")
810-
811-
return jitter
783+
return jit
812784

813785
def _build_impl(self, cache_key, args, kws):
814786
"""Build and cache the implementation.
@@ -988,16 +960,9 @@ def _get_target_registry(self, reason):
988960
-------
989961
reg : a registry suitable for the current target.
990962
"""
991-
from numba.core.target_extension import (
992-
_get_local_target_checked,
993-
dispatcher_registry,
994-
)
963+
from numba.cuda.descriptor import cuda_target
995964

996-
hwstr = self.metadata.get("target", "generic")
997-
target_hw = _get_local_target_checked(self.context, hwstr, reason)
998-
# Get registry for the current hardware
999-
disp = dispatcher_registry[target_hw]
1000-
tgtctx = disp.targetdescr.target_context
965+
tgtctx = cuda_target.target_context
1001966

1002967
# ---------------------------------------------------------------------
1003968
# XXX: In upstream Numba, this function would prefer the builtin

numba_cuda/numba/cuda/utils.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import atexit
1010
import builtins
11+
import importlib
1112
import inspect
1213
import operator
1314
import timeit
@@ -311,7 +312,7 @@ def __hash__(self):
311312
return hash(tuple(sorted(self._values.items())))
312313

313314

314-
def order_by_target_specificity(target, templates, fnkey=""):
315+
def order_by_target_specificity(templates, fnkey=""):
315316
"""This orders the given templates from most to least specific against the
316317
current "target". "fnkey" is an indicative typing key for use in the
317318
exception message in the case that there's no usable templates for the
@@ -321,8 +322,6 @@ def order_by_target_specificity(target, templates, fnkey=""):
321322
if templates == []:
322323
return []
323324

324-
from numba.core.target_extension import target_registry
325-
326325
# fish out templates that are specific to the target if a target is
327326
# specified
328327
DEFAULT_TARGET = "generic"
@@ -332,20 +331,22 @@ def order_by_target_specificity(target, templates, fnkey=""):
332331
md = getattr(temp_cls, "metadata", {})
333332
hw = md.get("target", DEFAULT_TARGET)
334333
if hw is not None:
335-
hw_clazz = target_registry[hw]
336-
if target.inherits_from(hw_clazz):
337-
usable.append((temp_cls, hw_clazz, ix))
334+
if hw in ("generic", "cuda"):
335+
usable.append((temp_cls, ix))
338336

339337
# sort templates based on target specificity
338+
# cuda-specific templates get priority before generic ones
340339
def key(x):
341-
return target.__mro__.index(x[1])
340+
md = getattr(x[0], "metadata", {})
341+
hw = md.get("target", DEFAULT_TARGET)
342+
return (0 if hw == "cuda" else 1, x[1])
342343

343344
order = [x[0] for x in sorted(usable, key=key)]
344345

345346
if not order:
346347
msg = (
347348
f"Function resolution cannot find any matches for function "
348-
f"'{fnkey}' for the current target: '{target}'."
349+
f"'{fnkey}'."
349350
)
350351
from numba.core.errors import UnsupportedError
351352

@@ -710,3 +711,14 @@ def _readenv(name, ctor, default):
710711
def cached_file_read(filepath, how="r"):
711712
with open(filepath, how) as f:
712713
return f.read()
714+
715+
716+
@contextlib.contextmanager
717+
def numba_target_override():
718+
if importlib.util.find_spec("numba"):
719+
from numba.core.target_extension import target_override
720+
721+
with target_override("cuda"):
722+
yield
723+
else:
724+
yield

0 commit comments

Comments
 (0)