Skip to content

Commit eeac024

Browse files
authored
[Refactor][NFC][Cleanups] Update imports to upstream numba to use the numba.cuda modules (#561)
This PR updates the vast majority of imports to upstream numba modules, that had already been vendored in for future CUDA-specific changes. After this PR, there should be exactly 3 (not strictly necessary) upstream numba imports remaining - one in testing vectorization for in test_math, one in random.py, and the global compiler_lock. It seems from the existing comments that removing the CPU jit in random.py would result in considerable performance drops, so I've just added a guard for numba being available. Some lowering / overload implementations that depend on numba.core.misc.{mergesort,quicksort} have been removed. These are presently not regarded as supported, and are unlikely to be performant implementations in a CUDA-specific context. There are many tests that rely on CPU jitting, these tests are now guarded to not run in environments where the numba package is not available. This is done using a global importlib check in numba-cuda's `__init__.py`. This cannot be moved into numba.cuda.config since the upstream numba doesn't have this flag, and we intend on forwarding calls to upstream numba when it is available. All previously sporadic guards for this *should* have been updated to use this global flag now. Support for prange, pndindex, and stencils has been removed, likewise for typed.Dict, and typed.List. Also, this removes all mention of the legacy type system flag since we implicitly assume the legacy type system in place.
1 parent 7ce01f4 commit eeac024

File tree

90 files changed

+340
-494
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

90 files changed

+340
-494
lines changed

numba_cuda/numba/cuda/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
import warnings
88
import sys
99

10+
1011
# Re-export types itself
1112
import numba.cuda.types as types
1213

1314
# Re-export all type names
1415
from numba.cuda.types import *
1516

17+
HAS_NUMBA = importlib.util.find_spec("numba") is not None
1618

1719
# Require NVIDIA CUDA bindings at import time
1820
if not (

numba_cuda/numba/cuda/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ def run_pass(self, state):
446446
@register_pass(mutates_CFG=True, analysis_only=False)
447447
class CUDANativeLowering(BaseNativeLowering):
448448
"""Lowering pass for a CUDA native function IR described solely in terms of
449-
Numba's standard `numba.core.ir` nodes."""
449+
Numba's standard `numba.cuda.core.ir` nodes."""
450450

451451
_name = "cuda_native_lowering"
452452

numba_cuda/numba/cuda/core/analysis.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
from collections import namedtuple, defaultdict
55
from numba.cuda import types
66
from numba.cuda.core import ir
7-
from numba.core import errors
7+
from numba.cuda.core import errors
88
from numba.cuda.core import consts
99
import operator
1010
from functools import reduce
1111

1212
from .controlflow import CFGraph
13-
from numba.misc import special
13+
from numba.cuda.misc import special
1414

1515
#
1616
# Analysis related to variable lifetime
@@ -354,7 +354,7 @@ def rewrite_tuple_len(val, func_ir, called_args):
354354
if isinstance(argty, types.BaseTuple):
355355
rewrite_statement(func_ir, stmt, argty.count)
356356

357-
from numba.core.ir_utils import get_definition, guard
357+
from numba.cuda.core.ir_utils import get_definition, guard
358358

359359
for blk in func_ir.blocks.values():
360360
for stmt in blk.body:

numba_cuda/numba/cuda/core/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from numba.cuda.core import imputils, targetconfig, funcdesc
1616
from numba.cuda import cgutils, debuginfo, types, utils, datamodel, config
17-
from numba.core import errors
17+
from numba.cuda.core import errors
1818
from numba.core.compiler_lock import global_compiler_lock
1919
from numba.cuda.core.pythonapi import PythonAPI
2020
from numba.cuda.core.imputils import (

numba_cuda/numba/cuda/core/bytecode.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
from types import CodeType, ModuleType
1010

11-
from numba.core import errors
12-
from numba.core import serialize
11+
from numba.cuda.core import errors
12+
from numba.cuda import serialize
1313
from numba.cuda import utils
1414
from numba.cuda.utils import PYVERSION
1515

numba_cuda/numba/cuda/core/byteflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
_lazy_pformat,
1818
)
1919
from numba.cuda.core.controlflow import NEW_BLOCKERS, CFGraph
20-
from numba.core.ir import Loc
20+
from numba.cuda.core.ir import Loc
2121
from numba.cuda.errors import UnsupportedBytecodeError
2222

2323

numba_cuda/numba/cuda/core/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from numba.cuda.core.tracing import event
55

6-
from numba.core import errors
6+
from numba.cuda.core import errors
77
from numba.cuda.core.errors import CompilerError
88

99
from numba.cuda.core import callconv, config, bytecode

numba_cuda/numba/cuda/core/compiler_machinery.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
from numba.core.compiler_lock import global_compiler_lock
11-
from numba.core import errors
11+
from numba.cuda.core import errors
1212
from numba.cuda.core import config
1313
from numba.cuda import utils
1414
from numba.cuda.core import transforms

numba_cuda/numba/cuda/core/config.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,6 @@ def _readenv(name, ctor, default):
212212
def optional_str(x):
213213
return str(x) if x is not None else None
214214

215-
# Type casting rules selection
216-
USE_LEGACY_TYPE_SYSTEM = _readenv(
217-
"NUMBA_USE_LEGACY_TYPE_SYSTEM", int, 1
218-
)
219-
220215
# developer mode produces full tracebacks, disables help instructions
221216
DEVELOPER_MODE = _readenv("NUMBA_DEVELOPER_MODE", int, 0)
222217

numba_cuda/numba/cuda/core/controlflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import functools
66
import sys
77

8-
from numba.core.ir import Loc
8+
from numba.cuda.core.ir import Loc
99
from numba.cuda.core.errors import UnsupportedError
1010
from numba.cuda.utils import PYVERSION
1111

0 commit comments

Comments
 (0)