Skip to content

Commit 798dfde

Browse files
committed
Merge 'main' into vk/target_extension
2 parents fa3331f + 9614920 commit 798dfde

File tree

13 files changed

+307
-36
lines changed

13 files changed

+307
-36
lines changed

docs/source/user/cudapysupported.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ The following functions from the :mod:`math` module are supported:
214214
* :func:`math.erf`
215215
* :func:`math.erfc`
216216
* :func:`math.exp`
217+
* :func:`math.exp2`
217218
* :func:`math.expm1`
218219
* :func:`math.fabs`
219220
* :func:`math.frexp`

numba_cuda/numba/cuda/bf16.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
3+
import sys
34

45
from numba.cuda._internal.cuda_bf16 import (
56
typing_registry,
@@ -191,14 +192,12 @@ def exp_ol(a):
191192
return _make_unary(a, hexp)
192193

193194

194-
try:
195-
from math import exp2
195+
if sys.version_info >= (3, 11):
196196

197-
@overload(exp2, target="cuda")
197+
@overload(math.exp2, target="cuda")
198198
def exp2_ol(a):
199199
return _make_unary(a, hexp2)
200-
except ImportError:
201-
pass
200+
202201

203202
## Public aliases using Numba/Numpy-style type names
204203
# Floating-point

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import importlib
4444
import numpy as np
4545
from collections import namedtuple, deque
46+
from uuid import UUID
4647

4748

4849
from numba.cuda.cext import mviewbuf
@@ -536,11 +537,10 @@ def from_identity(self, identity):
536537
if d.get_device_identity() == identity:
537538
return d
538539
else:
539-
errmsg = (
540-
"No device of {} is found. "
540+
raise RuntimeError(
541+
f"No device of {identity} is found. "
541542
"Target device may not be visible in this process."
542-
).format(identity)
543-
raise RuntimeError(errmsg)
543+
)
544544

545545
def __init__(self, devnum):
546546
result = driver.cuDeviceGet(devnum)
@@ -551,8 +551,6 @@ def __init__(self, devnum):
551551
if devnum != got_devnum:
552552
raise RuntimeError(msg)
553553

554-
self.attributes = {}
555-
556554
# Read compute capability
557555
self.compute_capability = (
558556
self.COMPUTE_CAPABILITY_MAJOR,
@@ -562,20 +560,13 @@ def __init__(self, devnum):
562560
# Read name
563561
bufsz = 128
564562
buf = driver.cuDeviceGetName(bufsz, self.id)
565-
name = buf.split(b"\x00")[0]
563+
name = buf.split(b"\x00", 1)[0]
566564

567565
self.name = name
568566

569567
# Read UUID
570568
uuid = driver.cuDeviceGetUuid(self.id)
571-
uuid_vals = tuple(uuid.bytes)
572-
573-
b = "%02x"
574-
b2 = b * 2
575-
b4 = b * 4
576-
b6 = b * 6
577-
fmt = f"GPU-{b4}-{b2}-{b2}-{b2}-{b6}"
578-
self.uuid = fmt % uuid_vals
569+
self.uuid = f"GPU-{UUID(bytes=uuid.bytes)}"
579570

580571
self.primary_context = None
581572

@@ -587,7 +578,7 @@ def get_device_identity(self):
587578
}
588579

589580
def __repr__(self):
590-
return "<CUDA device %d '%s'>" % (self.id, self.name)
581+
return f"<CUDA device {self.id:d} '{self.name}'>"
591582

592583
def __getattr__(self, attr):
593584
"""Read attributes lazily"""
@@ -603,9 +594,7 @@ def __hash__(self):
603594
return hash(self.id)
604595

605596
def __eq__(self, other):
606-
if isinstance(other, Device):
607-
return self.id == other.id
608-
return False
597+
return isinstance(other, Device) and self.id == other.id
609598

610599
def __ne__(self, other):
611600
return not (self == other)
@@ -615,8 +604,8 @@ def get_primary_context(self):
615604
Returns the primary context for the device.
616605
Note: it is not pushed to the CPU thread.
617606
"""
618-
if self.primary_context is not None:
619-
return self.primary_context
607+
if (ctx := self.primary_context) is not None:
608+
return ctx
620609

621610
met_requirement_for_device(self)
622611
# create primary context
@@ -637,8 +626,8 @@ def release_primary_context(self):
637626

638627
def reset(self):
639628
try:
640-
if self.primary_context is not None:
641-
self.primary_context.reset()
629+
if (ctx := self.primary_context) is not None:
630+
ctx.reset()
642631
self.release_primary_context()
643632
finally:
644633
# reset at the driver level

numba_cuda/numba/cuda/cudamath.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
import math
56
from numba.cuda import types
67
from numba.cuda.typing.templates import ConcreteTemplate, signature, Registry
@@ -58,6 +59,10 @@ class Math_unary_with_fp16(ConcreteTemplate):
5859
]
5960

6061

62+
if sys.version_info >= (3, 11):
63+
Math_unary_with_fp16 = infer_global(math.exp2)(Math_unary_with_fp16)
64+
65+
6166
@infer_global(math.atan2)
6267
class Math_atan2(ConcreteTemplate):
6368
key = math.atan2

numba_cuda/numba/cuda/fp16.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
import numba.cuda.types as types
56
from numba.cuda._internal.cuda_fp16 import (
67
typing_registry,
@@ -190,6 +191,13 @@ def exp_ol(a):
190191
return _make_unary(a, hexp)
191192

192193

194+
if sys.version_info >= (3, 11):
195+
196+
@overload(math.exp2, target="cuda")
197+
def exp2_ol(a):
198+
return _make_unary(a, hexp2)
199+
200+
193201
@overload(math.tanh, target="cuda")
194202
def tanh_ol(a):
195203
return _make_unary(a, htanh)

numba_cuda/numba/cuda/mathimpl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
import math
56
import operator
67
from llvmlite import ir
@@ -25,6 +26,8 @@
2526
unarys += [("floor", "floorf", math.floor)]
2627
unarys += [("fabs", "fabsf", math.fabs)]
2728
unarys += [("exp", "expf", math.exp)]
29+
if sys.version_info >= (3, 11):
30+
unarys += [("exp2", "exp2f", math.exp2)]
2831
unarys += [("expm1", "expm1f", math.expm1)]
2932
unarys += [("erf", "erff", math.erf)]
3033
unarys += [("erfc", "erfcf", math.erfc)]
@@ -330,6 +333,7 @@ def tanhf_impl_fastmath():
330333
impl_unary_int(math.tanh, int64, libdevice.tanh)
331334
impl_unary_int(math.tanh, uint64, libdevice.tanh)
332335

336+
333337
# Complex power implementations - translations of _Py_c_pow from CPython
334338
# https://github.com/python/cpython/blob/a755410e054e1e2390de5830befc08fe80706c66/Objects/complexobject.c#L123-L151
335339
#

numba_cuda/numba/cuda/tests/benchmarks/test_kernel_launch.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
),
2424
id="torch",
2525
),
26+
param(
27+
lambda: pytest.importorskip("cupy").empty(128, dtype=np.float32),
28+
id="cupy",
29+
),
2630
],
2731
)
2832
def test_one_arg(benchmark, array_func):
@@ -58,6 +62,13 @@ def bench(func, arr):
5862
],
5963
id="torch",
6064
),
65+
param(
66+
lambda: [
67+
pytest.importorskip("cupy").empty(128, dtype=np.float32)
68+
for _ in range(len(string.ascii_lowercase))
69+
],
70+
id="cupy",
71+
),
6172
],
6273
)
6374
def test_many_args(benchmark, array_func):

numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
import numpy as np
56
from ml_dtypes import bfloat16 as mldtypes_bf16
67
from numba import cuda
@@ -134,12 +135,8 @@ def test_math_bindings(self):
134135
self.skip_unsupported()
135136

136137
exp_functions = [math.exp]
137-
try:
138-
from math import exp2
139-
140-
exp_functions += [exp2]
141-
except ImportError:
142-
pass
138+
if sys.version_info >= (3, 11):
139+
exp_functions += [math.exp2]
143140

144141
functions = [
145142
math.trunc,

numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
from typing import List
56
from dataclasses import dataclass, field
67
from numba import cuda
@@ -142,6 +143,19 @@ def test_expf(self):
142143
),
143144
)
144145

146+
@unittest.skipUnless(sys.version_info >= (3, 11), "Python 3.11+ required")
147+
def test_exp2f(self):
148+
from math import exp2
149+
150+
self._test_fast_math_unary(
151+
exp2,
152+
FastMathCriterion(
153+
fast_expected=["ex2.approx.ftz.f32 "],
154+
prec_expected=["ex2.approx.f32 "],
155+
prec_unexpected=["ex2.approx.ftz.f32 "],
156+
),
157+
)
158+
145159
def test_logf(self):
146160
# Look for constant used to convert from log base 2 to log base e
147161
self._test_fast_math_unary(

numba_cuda/numba/cuda/tests/cudapy/test_math.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
import numpy as np
56
from numba.cuda.testing import (
67
skip_unless_cc_53,
@@ -84,6 +85,11 @@ def math_exp(A, B):
8485
B[i] = math.exp(A[i])
8586

8687

88+
def math_exp2(A, B):
89+
i = cuda.grid(1)
90+
B[i] = math.exp2(A[i])
91+
92+
8793
def math_erf(A, B):
8894
i = cuda.grid(1)
8995
B[i] = math.erf(A[i])
@@ -401,6 +407,8 @@ def test_math_fp16(self):
401407
self.unary_template_float16(math_sqrt, np.sqrt)
402408
self.unary_template_float16(math_ceil, np.ceil)
403409
self.unary_template_float16(math_floor, np.floor)
410+
if sys.version_info >= (3, 11):
411+
self.unary_template_float16(math_exp2, np.exp2)
404412

405413
@skip_on_cudasim("numpy does not support trunc for float16")
406414
@skip_unless_cc_53
@@ -496,6 +504,16 @@ def test_math_exp(self):
496504
self.unary_template_int64(math_exp, np.exp)
497505
self.unary_template_uint64(math_exp, np.exp)
498506

507+
# ---------------------------------------------------------------------------
508+
# test_math_exp2
509+
510+
@unittest.skipUnless(sys.version_info >= (3, 11), "Python 3.11+ required")
511+
def test_math_exp2(self):
512+
self.unary_template_float32(math_exp2, np.exp2)
513+
self.unary_template_float64(math_exp2, np.exp2)
514+
self.unary_template_int64(math_exp2, np.exp2)
515+
self.unary_template_uint64(math_exp2, np.exp2)
516+
499517
# ---------------------------------------------------------------------------
500518
# test_math_expm1
501519

0 commit comments

Comments
 (0)