Skip to content

Commit 9614920

Browse files
kaeun97gmarkall
andauthored
feat: add support for math.exp2 (#541)
Adding support for `math.exp2`. Follow up of numba/numba#10276. --------- Co-authored-by: Graham Markall <[email protected]>
1 parent 3282e93 commit 9614920

File tree

9 files changed

+62
-11
lines changed

9 files changed

+62
-11
lines changed

docs/source/user/cudapysupported.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ The following functions from the :mod:`math` module are supported:
214214
* :func:`math.erf`
215215
* :func:`math.erfc`
216216
* :func:`math.exp`
217+
* :func:`math.exp2`
217218
* :func:`math.expm1`
218219
* :func:`math.fabs`
219220
* :func:`math.frexp`

numba_cuda/numba/cuda/bf16.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
3+
import sys
34

45
from numba.cuda._internal.cuda_bf16 import (
56
typing_registry,
@@ -191,14 +192,12 @@ def exp_ol(a):
191192
return _make_unary(a, hexp)
192193

193194

194-
try:
195-
from math import exp2
195+
if sys.version_info >= (3, 11):
196196

197-
@overload(exp2, target="cuda")
197+
@overload(math.exp2, target="cuda")
198198
def exp2_ol(a):
199199
return _make_unary(a, hexp2)
200-
except ImportError:
201-
pass
200+
202201

203202
## Public aliases using Numba/Numpy-style type names
204203
# Floating-point

numba_cuda/numba/cuda/cudamath.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
import math
56
from numba.cuda import types
67
from numba.cuda.typing.templates import ConcreteTemplate, signature, Registry
@@ -58,6 +59,10 @@ class Math_unary_with_fp16(ConcreteTemplate):
5859
]
5960

6061

62+
if sys.version_info >= (3, 11):
63+
Math_unary_with_fp16 = infer_global(math.exp2)(Math_unary_with_fp16)
64+
65+
6166
@infer_global(math.atan2)
6267
class Math_atan2(ConcreteTemplate):
6368
key = math.atan2

numba_cuda/numba/cuda/fp16.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
import numba.cuda.types as types
56
from numba.cuda._internal.cuda_fp16 import (
67
typing_registry,
@@ -190,6 +191,13 @@ def exp_ol(a):
190191
return _make_unary(a, hexp)
191192

192193

194+
if sys.version_info >= (3, 11):
195+
196+
@overload(math.exp2, target="cuda")
197+
def exp2_ol(a):
198+
return _make_unary(a, hexp2)
199+
200+
193201
@overload(math.tanh, target="cuda")
194202
def tanh_ol(a):
195203
return _make_unary(a, htanh)

numba_cuda/numba/cuda/mathimpl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
import math
56
import operator
67
from llvmlite import ir
@@ -25,6 +26,8 @@
2526
unarys += [("floor", "floorf", math.floor)]
2627
unarys += [("fabs", "fabsf", math.fabs)]
2728
unarys += [("exp", "expf", math.exp)]
29+
if sys.version_info >= (3, 11):
30+
unarys += [("exp2", "exp2f", math.exp2)]
2831
unarys += [("expm1", "expm1f", math.expm1)]
2932
unarys += [("erf", "erff", math.erf)]
3033
unarys += [("erfc", "erfcf", math.erfc)]
@@ -330,6 +333,7 @@ def tanhf_impl_fastmath():
330333
impl_unary_int(math.tanh, int64, libdevice.tanh)
331334
impl_unary_int(math.tanh, uint64, libdevice.tanh)
332335

336+
333337
# Complex power implementations - translations of _Py_c_pow from CPython
334338
# https://github.com/python/cpython/blob/a755410e054e1e2390de5830befc08fe80706c66/Objects/complexobject.c#L123-L151
335339
#

numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
import numpy as np
56
from ml_dtypes import bfloat16 as mldtypes_bf16
67
from numba import cuda
@@ -134,12 +135,8 @@ def test_math_bindings(self):
134135
self.skip_unsupported()
135136

136137
exp_functions = [math.exp]
137-
try:
138-
from math import exp2
139-
140-
exp_functions += [exp2]
141-
except ImportError:
142-
pass
138+
if sys.version_info >= (3, 11):
139+
exp_functions += [math.exp2]
143140

144141
functions = [
145142
math.trunc,

numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
from typing import List
56
from dataclasses import dataclass, field
67
from numba import cuda
@@ -142,6 +143,19 @@ def test_expf(self):
142143
),
143144
)
144145

146+
@unittest.skipUnless(sys.version_info >= (3, 11), "Python 3.11+ required")
147+
def test_exp2f(self):
148+
from math import exp2
149+
150+
self._test_fast_math_unary(
151+
exp2,
152+
FastMathCriterion(
153+
fast_expected=["ex2.approx.ftz.f32 "],
154+
prec_expected=["ex2.approx.f32 "],
155+
prec_unexpected=["ex2.approx.ftz.f32 "],
156+
),
157+
)
158+
145159
def test_logf(self):
146160
# Look for constant used to convert from log base 2 to log base e
147161
self._test_fast_math_unary(

numba_cuda/numba/cuda/tests/cudapy/test_math.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
import numpy as np
56
from numba.cuda.testing import (
67
skip_unless_cc_53,
@@ -84,6 +85,11 @@ def math_exp(A, B):
8485
B[i] = math.exp(A[i])
8586

8687

88+
def math_exp2(A, B):
89+
i = cuda.grid(1)
90+
B[i] = math.exp2(A[i])
91+
92+
8793
def math_erf(A, B):
8894
i = cuda.grid(1)
8995
B[i] = math.erf(A[i])
@@ -401,6 +407,8 @@ def test_math_fp16(self):
401407
self.unary_template_float16(math_sqrt, np.sqrt)
402408
self.unary_template_float16(math_ceil, np.ceil)
403409
self.unary_template_float16(math_floor, np.floor)
410+
if sys.version_info >= (3, 11):
411+
self.unary_template_float16(math_exp2, np.exp2)
404412

405413
@skip_on_cudasim("numpy does not support trunc for float16")
406414
@skip_unless_cc_53
@@ -496,6 +504,16 @@ def test_math_exp(self):
496504
self.unary_template_int64(math_exp, np.exp)
497505
self.unary_template_uint64(math_exp, np.exp)
498506

507+
# ---------------------------------------------------------------------------
508+
# test_math_exp2
509+
510+
@unittest.skipUnless(sys.version_info >= (3, 11), "Python 3.11+ required")
511+
def test_math_exp2(self):
512+
self.unary_template_float32(math_exp2, np.exp2)
513+
self.unary_template_float64(math_exp2, np.exp2)
514+
self.unary_template_int64(math_exp2, np.exp2)
515+
self.unary_template_uint64(math_exp2, np.exp2)
516+
499517
# ---------------------------------------------------------------------------
500518
# test_math_expm1
501519

numba_cuda/numba/cuda/typing/mathdecl.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
import math
56
from numba.cuda import types
67
from numba.cuda.typing.templates import ConcreteTemplate, signature, Registry
@@ -44,6 +45,10 @@ class Math_unary(ConcreteTemplate):
4445
]
4546

4647

48+
if sys.version_info >= (3, 11):
49+
Math_unary = infer_global(math.exp2)(Math_unary)
50+
51+
4752
@infer_global(math.atan2)
4853
class Math_atan2(ConcreteTemplate):
4954
cases = [

0 commit comments

Comments
 (0)