Skip to content

Commit a0e3c49

Browse files
authored
Merge branch 'main' into vk/typing
2 parents 1de8588 + f737c25 commit a0e3c49

File tree

4 files changed

+77
-163
lines changed

4 files changed

+77
-163
lines changed

ci/test_thirdparty.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@ rapids-logger "Show Numba system info"
4141
python -m numba --sysinfo
4242

4343
rapids-logger "Run Scalar UDF tests"
44-
python -m pytest python/cudf/cudf/tests/dataframe/methods/test_apply.py -W ignore::UserWarning -W ignore::DeprecationWarning:numba.cuda.core.config
44+
python -m pytest python/cudf/cudf/tests/dataframe/methods/test_apply.py -W ignore::UserWarning
4545

4646
rapids-logger "Run GroupBy UDF tests"
47-
python -m pytest python/cudf/cudf/tests/groupby/test_apply.py -k test_groupby_apply_jit -W ignore::UserWarning -W ignore::DeprecationWarning:numba.cuda.core.config
47+
python -m pytest python/cudf/cudf/tests/groupby/test_apply.py -k test_groupby_apply_jit -W ignore::UserWarning
4848

4949
rapids-logger "Run NRT Stats Counting tests"
50-
python -m pytest python/cudf/cudf/tests/private_objects/test_nrt_stats.py -W ignore::UserWarning -W ignore::DeprecationWarning:numba.cuda.core.config
50+
python -m pytest python/cudf/cudf/tests/private_objects/test_nrt_stats.py -W ignore::UserWarning
5151

5252

5353
popd

numba_cuda/numba/cuda/core/config.py

Lines changed: 25 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -124,49 +124,6 @@ def _process_opt_level(opt_level):
124124
return _OptLevel(opt_level)
125125

126126

127-
class _EnvVar(object):
128-
"""Descriptor for configuration values that checks numba.config on access."""
129-
130-
def __init__(self, value, name):
131-
self.name = name
132-
if isinstance(value, _EnvVar):
133-
self.value = value.__get__()
134-
else:
135-
self.value = value
136-
self.check_numba_config()
137-
138-
def check_numba_config(self):
139-
"""Check for conflicting value in numba.config and emit deprecation warning."""
140-
try:
141-
from numba import config as numba_config
142-
143-
if hasattr(numba_config, self.name):
144-
config_value = getattr(numba_config, self.name)
145-
if config_value != self.value:
146-
msg = (
147-
f"Configuration value '{self.name}' is explicitly set "
148-
f"to `{config_value}` in numba.config. "
149-
"numba.config is deprecated for numba-cuda "
150-
"and support for configuration values from it "
151-
"will be removed in a future release. "
152-
"Please use numba.cuda.core.config."
153-
)
154-
warnings.warn(msg, category=DeprecationWarning)
155-
self.value = config_value
156-
else:
157-
# Initialize any missing variables in numba.config
158-
setattr(numba_config, self.name, self.value)
159-
except ImportError:
160-
pass
161-
162-
def __get__(self):
163-
self.check_numba_config()
164-
return self.value
165-
166-
def __set__(self, value):
167-
self.value = value
168-
169-
170127
class _EnvReloader(object):
171128
def __init__(self):
172129
self.reset()
@@ -211,18 +168,7 @@ def update(self, force=False):
211168
self.validate()
212169

213170
def validate(self):
214-
current_module = sys.modules[__name__]
215-
try:
216-
CUDA_USE_NVIDIA_BINDING = current_module.CUDA_USE_NVIDIA_BINDING
217-
except AttributeError:
218-
CUDA_USE_NVIDIA_BINDING = 0
219-
220-
try:
221-
CUDA_PER_THREAD_DEFAULT_STREAM = (
222-
current_module.CUDA_PER_THREAD_DEFAULT_STREAM
223-
)
224-
except AttributeError:
225-
CUDA_PER_THREAD_DEFAULT_STREAM = 0
171+
global CUDA_USE_NVIDIA_BINDING
226172

227173
if CUDA_USE_NVIDIA_BINDING: # noqa: F821
228174
try:
@@ -235,7 +181,7 @@ def validate(self):
235181
)
236182
warnings.warn(msg)
237183

238-
current_module.CUDA_USE_NVIDIA_BINDING = 0
184+
CUDA_USE_NVIDIA_BINDING = False
239185

240186
if CUDA_PER_THREAD_DEFAULT_STREAM: # noqa: F821
241187
warnings.warn(
@@ -250,23 +196,18 @@ def process_environ(self, environ):
250196
def _readenv(name, ctor, default):
251197
value = environ.get(name)
252198
if value is None:
253-
result = default() if callable(default) else default
254-
else:
255-
try:
256-
result = ctor(value)
257-
except Exception:
258-
warnings.warn(
259-
f"Environment variable '{name}' is defined but "
260-
f"its associated value '{value}' could not be "
261-
"parsed.\nThe parse failed with exception:\n"
262-
f"{traceback.format_exc()}",
263-
RuntimeWarning,
264-
)
265-
result = default() if callable(default) else default
266-
var_name = name
267-
if name.startswith("NUMBA_"):
268-
var_name = name[6:]
269-
return _EnvVar(result, var_name)
199+
return default() if callable(default) else default
200+
try:
201+
return ctor(value)
202+
except Exception:
203+
warnings.warn(
204+
f"Environment variable '{name}' is defined but "
205+
f"its associated value '{value}' could not be "
206+
"parsed.\nThe parse failed with exception:\n"
207+
f"{traceback.format_exc()}",
208+
RuntimeWarning,
209+
)
210+
return default
270211

271212
def optional_str(x):
272213
return str(x) if x is not None else None
@@ -348,12 +289,6 @@ def optional_str(x):
348289
# Enable NRT statistics counters
349290
NRT_STATS = _readenv("NUMBA_NRT_STATS", int, 0)
350291

351-
# Enable NRT statistics
352-
CUDA_NRT_STATS = _readenv("NUMBA_CUDA_NRT_STATS", int, 0)
353-
354-
# Enable NRT
355-
CUDA_ENABLE_NRT = _readenv("NUMBA_CUDA_ENABLE_NRT", int, 0)
356-
357292
# How many recently deserialized functions to retain regardless
358293
# of external references
359294
FUNCTION_CACHE_SIZE = _readenv("NUMBA_FUNCTION_CACHE_SIZE", int, 128)
@@ -695,53 +630,26 @@ def which_gdb(path_or_bin):
695630
0,
696631
)
697632

698-
# Inject the configuration values into _descriptors
699-
if not hasattr(self, "_descriptors"):
700-
self._descriptors = {}
701-
633+
# Inject the configuration values into the module globals
702634
for name, value in locals().copy().items():
703635
if name.isupper():
704-
self._descriptors[name] = value
636+
globals()[name] = value
705637

706638

707639
_env_reloader = _EnvReloader()
708640

709641

710-
def __getattr__(name):
711-
"""Module-level __getattr__ provides dynamic behavior for _EnvVar descriptors."""
712-
# Fetch non-descriptor globals directly
713-
if name in globals():
714-
return globals()[name]
715-
716-
if (
717-
hasattr(_env_reloader, "_descriptors")
718-
and name in _env_reloader._descriptors
719-
):
720-
return _env_reloader._descriptors[name].__get__()
721-
722-
raise AttributeError(f"module {__name__} has no attribute {name}")
723-
724-
725-
def __setattr__(name, value):
726-
"""Module-level __setattr__ provides dynamic behavior for _EnvVar descriptors."""
727-
# Update non-descriptor globals
728-
if name in globals():
729-
globals()[name] = value
730-
return
731-
732-
if (
733-
hasattr(_env_reloader, "_descriptors")
734-
and name in _env_reloader._descriptors
735-
):
736-
_env_reloader._descriptors[name].__set__(value)
737-
else:
738-
if not hasattr(_env_reloader, "_descriptors"):
739-
_env_reloader._descriptors = {}
740-
_env_reloader._descriptors[name] = _EnvVar(value, name)
741-
742-
743642
def reload_config():
744643
"""
745644
Reload the configuration from environment variables, if necessary.
746645
"""
747646
_env_reloader.update()
647+
648+
649+
# use numba.core.config if available, otherwise use numba.cuda.core.config
650+
try:
651+
import numba.core.config as _config
652+
653+
sys.modules[__name__] = _config
654+
except ImportError:
655+
pass

numba_cuda/numba/cuda/memory_management/nrt.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020
from numba.cuda.cudadrv import devices
2121
from numba.cuda.api import get_current_device
22-
from numba.cuda.utils import cached_file_read
22+
from numba.cuda.utils import _readenv, cached_file_read
2323
from numba.cuda.cudadrv.linkable_code import CUSource
2424
from numba.cuda.typing.templates import signature
2525

@@ -28,6 +28,22 @@
2828
_nrt_mstats = namedtuple("nrt_mstats", ["alloc", "free", "mi_alloc", "mi_free"])
2929

3030

31+
# Check environment variable or config for NRT statistics enablement
32+
NRT_STATS = _readenv("NUMBA_CUDA_NRT_STATS", bool, False) or getattr(
33+
config, "NUMBA_CUDA_NRT_STATS", False
34+
)
35+
if not hasattr(config, "NUMBA_CUDA_NRT_STATS"):
36+
config.CUDA_NRT_STATS = NRT_STATS
37+
38+
39+
# Check environment variable or config for NRT enablement
40+
ENABLE_NRT = _readenv("NUMBA_CUDA_ENABLE_NRT", bool, False) or getattr(
41+
config, "NUMBA_CUDA_ENABLE_NRT", False
42+
)
43+
if not hasattr(config, "NUMBA_CUDA_ENABLE_NRT"):
44+
config.CUDA_ENABLE_NRT = ENABLE_NRT
45+
46+
3147
def get_include():
3248
"""Return the include path for the NRT header"""
3349
return os.path.dirname(os.path.abspath(__file__))

numba_cuda/numba/cuda/typeconv/rules.py

Lines changed: 32 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import itertools
55
from .typeconv import TypeManager, TypeCastingRules
66
from numba.core import types
7-
from numba.cuda import config
87

98

109
default_type_manager = TypeManager()
@@ -16,58 +15,49 @@ def dump_number_rules():
1615
print(a, "->", b, tm.check_compatible(a, b))
1716

1817

19-
if config.USE_LEGACY_TYPE_SYSTEM: # Old type system
18+
def _init_casting_rules(tm):
19+
tcr = TypeCastingRules(tm)
20+
tcr.safe_unsafe(types.boolean, types.int8)
21+
tcr.safe_unsafe(types.boolean, types.uint8)
2022

21-
def _init_casting_rules(tm):
22-
tcr = TypeCastingRules(tm)
23-
tcr.safe_unsafe(types.boolean, types.int8)
24-
tcr.safe_unsafe(types.boolean, types.uint8)
23+
tcr.promote_unsafe(types.int8, types.int16)
24+
tcr.promote_unsafe(types.uint8, types.uint16)
2525

26-
tcr.promote_unsafe(types.int8, types.int16)
27-
tcr.promote_unsafe(types.uint8, types.uint16)
26+
tcr.promote_unsafe(types.int16, types.int32)
27+
tcr.promote_unsafe(types.uint16, types.uint32)
2828

29-
tcr.promote_unsafe(types.int16, types.int32)
30-
tcr.promote_unsafe(types.uint16, types.uint32)
29+
tcr.promote_unsafe(types.int32, types.int64)
30+
tcr.promote_unsafe(types.uint32, types.uint64)
3131

32-
tcr.promote_unsafe(types.int32, types.int64)
33-
tcr.promote_unsafe(types.uint32, types.uint64)
32+
tcr.safe_unsafe(types.uint8, types.int16)
33+
tcr.safe_unsafe(types.uint16, types.int32)
34+
tcr.safe_unsafe(types.uint32, types.int64)
3435

35-
tcr.safe_unsafe(types.uint8, types.int16)
36-
tcr.safe_unsafe(types.uint16, types.int32)
37-
tcr.safe_unsafe(types.uint32, types.int64)
36+
tcr.safe_unsafe(types.int8, types.float16)
37+
tcr.safe_unsafe(types.int16, types.float32)
38+
tcr.safe_unsafe(types.int32, types.float64)
3839

39-
tcr.safe_unsafe(types.int8, types.float16)
40-
tcr.safe_unsafe(types.int16, types.float32)
41-
tcr.safe_unsafe(types.int32, types.float64)
40+
tcr.unsafe_unsafe(types.int16, types.float16)
41+
tcr.unsafe_unsafe(types.int32, types.float32)
42+
# XXX this is inconsistent with the above; but we want to prefer
43+
# float64 over int64 when typing a heterogeneous operation,
44+
# e.g. `float64 + int64`. Perhaps we need more granularity in the
45+
# conversion kinds.
46+
tcr.safe_unsafe(types.int64, types.float64)
47+
tcr.safe_unsafe(types.uint64, types.float64)
4248

43-
tcr.unsafe_unsafe(types.int16, types.float16)
44-
tcr.unsafe_unsafe(types.int32, types.float32)
45-
# XXX this is inconsistent with the above; but we want to prefer
46-
# float64 over int64 when typing a heterogeneous operation,
47-
# e.g. `float64 + int64`. Perhaps we need more granularity in the
48-
# conversion kinds.
49-
tcr.safe_unsafe(types.int64, types.float64)
50-
tcr.safe_unsafe(types.uint64, types.float64)
49+
tcr.promote_unsafe(types.float16, types.float32)
50+
tcr.promote_unsafe(types.float32, types.float64)
5151

52-
tcr.promote_unsafe(types.float16, types.float32)
53-
tcr.promote_unsafe(types.float32, types.float64)
52+
tcr.safe(types.float32, types.complex64)
53+
tcr.safe(types.float64, types.complex128)
5454

55-
tcr.safe(types.float32, types.complex64)
56-
tcr.safe(types.float64, types.complex128)
55+
tcr.promote_unsafe(types.complex64, types.complex128)
5756

58-
tcr.promote_unsafe(types.complex64, types.complex128)
57+
# Allow integers to cast ot void*
58+
tcr.unsafe_unsafe(types.uintp, types.voidptr)
5959

60-
# Allow integers to cast ot void*
61-
tcr.unsafe_unsafe(types.uintp, types.voidptr)
62-
63-
return tcr
64-
else: # New type system
65-
# Currently left as empty
66-
# If no casting rules are required we may opt to remove
67-
# this framework upon deprecation
68-
def _init_casting_rules(tm):
69-
tcr = TypeCastingRules(tm)
70-
return tcr
60+
return tcr
7161

7262

7363
default_casting_rules = _init_casting_rules(default_type_manager)

0 commit comments

Comments
 (0)