Skip to content

Commit 911b178

Browse files
committed
Merge remote-tracking branch 'gmarkall/vk/target_extension' into vk/target_extension
2 parents f2be476 + d3e1677 commit 911b178

File tree

7 files changed

+38
-53
lines changed

7 files changed

+38
-53
lines changed

numba_cuda/numba/cuda/compiler.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -741,16 +741,17 @@ def compile_cuda(
741741
flags.max_registers = max_registers
742742
flags.lto = lto
743743

744-
cres = compile_extra(
745-
typingctx=typingctx,
746-
targetctx=targetctx,
747-
func=pyfunc,
748-
args=args,
749-
return_type=return_type,
750-
flags=flags,
751-
locals={},
752-
pipeline_class=CUDACompiler,
753-
)
744+
with utils.numba_target_override():
745+
cres = compile_extra(
746+
typingctx=typingctx,
747+
targetctx=targetctx,
748+
func=pyfunc,
749+
args=args,
750+
return_type=return_type,
751+
flags=flags,
752+
locals={},
753+
pipeline_class=CUDACompiler,
754+
)
754755

755756
library = cres.library
756757
library.finalize()

numba_cuda/numba/cuda/core/base.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from collections import defaultdict
55
import copy
6+
import importlib
67
import sys
78
from itertools import permutations, takewhile
89
from contextlib import contextmanager
@@ -212,10 +213,15 @@ def enable_boundscheck(self, value):
212213
def __init__(self, typing_context, target):
213214
self.address_size = utils.MACHINE_BITS
214215
self.typing_context = typing_context
215-
from numba.cuda.descriptor import cuda_target
216-
217216
self.target_name = target
218-
self.target = cuda_target
217+
218+
if importlib.util.find_spec("numba"):
219+
from numba.core.target_extension import CUDA
220+
221+
# Used only in Numba's target_extension implementation.
222+
# Numba-CUDA has the target_extension implementation removed, and
223+
# references to it hardcoded to values specific to the CUDA target.
224+
self.target = CUDA
219225

220226
# A mapping of installed registries to their loaders
221227
self._registries = {}

numba_cuda/numba/cuda/descriptor.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -33,34 +33,3 @@ def target_context(self):
3333

3434

3535
cuda_target = CUDATarget("cuda")
36-
37-
# Monkey-patch numba's get_local_target and order_by_target_specificity for CUDATarget
38-
try:
39-
from numba.core import target_extension
40-
from numba.cuda.utils import order_by_target_specificity
41-
from numba.core import utils as numba_utils
42-
43-
def _is_cuda_context(obj):
44-
return (
45-
isinstance(obj, CUDATarget)
46-
or (hasattr(obj, "__class__") and "CUDA" in obj.__class__.__name__)
47-
or (hasattr(obj, "target") and isinstance(obj.target, CUDATarget))
48-
)
49-
50-
def _patch_numba_for_cuda_target():
51-
_orig_get_local = target_extension.get_local_target
52-
53-
def get_local_target_cuda(context):
54-
return (
55-
cuda_target
56-
if _is_cuda_context(context)
57-
else _orig_get_local(context)
58-
)
59-
60-
target_extension.get_local_target = get_local_target_cuda
61-
numba_utils.order_by_target_specificity = order_by_target_specificity
62-
63-
_patch_numba_for_cuda_target()
64-
65-
except ImportError:
66-
pass

numba_cuda/numba/cuda/dispatcher.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,8 @@ class CUDACache(Cache):
727727

728728
def load_overload(self, sig, target_context):
729729
# Loading an overload refreshes the context to ensure it is initialized.
730-
return super().load_overload(sig, target_context)
730+
with utils.numba_target_override():
731+
return super().load_overload(sig, target_context)
731732

732733

733734
class OmittedArg(object):

numba_cuda/numba/cuda/types/cuda_functions.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -314,10 +314,8 @@ def get_call_type(self, context, args, kws):
314314
context, self, args, kws, depth=self._depth
315315
)
316316

317-
from numba.cuda.descriptor import cuda_target
318-
319317
order = utils.order_by_target_specificity(
320-
cuda_target, self.templates, fnkey=self.key[0]
318+
self.templates, fnkey=self.key[0]
321319
)
322320

323321
self._depth += 1

numba_cuda/numba/cuda/typing/context.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,7 @@ def core(typ):
290290
def find_matching_getattr_template(self, typ, attr):
291291
templates = list(self._get_attribute_templates(typ))
292292

293-
from numba.cuda.descriptor import cuda_target
294-
295-
order = order_by_target_specificity(cuda_target, templates, fnkey=attr)
293+
order = order_by_target_specificity(templates, fnkey=attr)
296294

297295
for template in order:
298296
return_type = template.resolve(typ, attr)

numba_cuda/numba/cuda/utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import atexit
1010
import builtins
11+
import importlib
1112
import inspect
1213
import operator
1314
import timeit
@@ -311,7 +312,7 @@ def __hash__(self):
311312
return hash(tuple(sorted(self._values.items())))
312313

313314

314-
def order_by_target_specificity(target, templates, fnkey=""):
315+
def order_by_target_specificity(templates, fnkey=""):
315316
"""This orders the given templates from most to least specific against the
316317
current "target". "fnkey" is an indicative typing key for use in the
317318
exception message in the case that there's no usable templates for the
@@ -345,7 +346,7 @@ def key(x):
345346
if not order:
346347
msg = (
347348
f"Function resolution cannot find any matches for function "
348-
f"'{fnkey}' for the current target: '{target}'."
349+
f"'{fnkey}'."
349350
)
350351
from numba.core.errors import UnsupportedError
351352

@@ -710,3 +711,14 @@ def _readenv(name, ctor, default):
710711
def cached_file_read(filepath, how="r"):
711712
with open(filepath, how) as f:
712713
return f.read()
714+
715+
716+
@contextlib.contextmanager
717+
def numba_target_override():
718+
if importlib.util.find_spec("numba"):
719+
from numba.core.target_extension import target_override
720+
721+
with target_override("cuda"):
722+
yield
723+
else:
724+
yield

0 commit comments

Comments
 (0)