Skip to content

Commit 0afebbd

Browse files
committed
feat: add math.exp2
1 parent 65351bc commit 0afebbd

File tree

7 files changed

+34
-13
lines changed

7 files changed

+34
-13
lines changed

numba_cuda/numba/cuda/bf16.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
3+
import sys
34

45
from numba.cuda._internal.cuda_bf16 import (
56
typing_registry,
@@ -191,14 +192,10 @@ def exp_ol(a):
191192
return _make_unary(a, hexp)
192193

193194

194-
try:
195-
from math import exp2
196-
197-
@overload(exp2, target="cuda")
195+
if sys.version_info >= (3, 11):
196+
@overload(math.exp2, target="cuda")
198197
def exp2_ol(a):
199198
return _make_unary(a, hexp2)
200-
except ImportError:
201-
pass
202199

203200
## Public aliases using Numba/Numpy-style type names
204201
# Floating-point

numba_cuda/numba/cuda/cudamath.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
import math
56
from numba.core import types
67
from numba.cuda.typing.templates import ConcreteTemplate, signature, Registry
@@ -57,6 +58,9 @@ class Math_unary_with_fp16(ConcreteTemplate):
5758
signature(types.float16, types.float16),
5859
]
5960

61+
if sys.version_info >= (3, 11):
62+
Math_unary_with_fp16 = infer_global(math.exp2)(Math_unary_with_fp16)
63+
6064

6165
@infer_global(math.atan2)
6266
class Math_atan2(ConcreteTemplate):

numba_cuda/numba/cuda/fp16.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def log10_ol(a):
185185
return _make_unary(a, hlog10)
186186

187187

188-
@overload(math.exp, target="cuda")
188+
@overload(, target="cuda")
189189
def exp_ol(a):
190190
return _make_unary(a, hexp)
191191

numba_cuda/numba/cuda/mathimpl.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
import math
56
import operator
67
from llvmlite import ir
@@ -25,6 +26,8 @@
2526
unarys += [("floor", "floorf", math.floor)]
2627
unarys += [("fabs", "fabsf", math.fabs)]
2728
unarys += [("exp", "expf", math.exp)]
29+
if sys.version_info >= (3, 11):
30+
unarys += [("exp2", "exp2f", math.exp2)]
2831
unarys += [("expm1", "expm1f", math.expm1)]
2932
unarys += [("erf", "erff", math.erf)]
3033
unarys += [("erfc", "erfcf", math.erfc)]

numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
import numpy as np
56
from ml_dtypes import bfloat16 as mldtypes_bf16
67

@@ -135,12 +136,8 @@ def test_math_bindings(self):
135136
self.skip_unsupported()
136137

137138
exp_functions = [math.exp]
138-
try:
139-
from math import exp2
140-
141-
exp_functions += [exp2]
142-
except ImportError:
143-
pass
139+
if sys.version_info >= (3, 11):
140+
exp_functions += [math.exp2]
144141

145142
functions = [
146143
math.trunc,

numba_cuda/numba/cuda/tests/cudapy/test_math.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
import numpy as np
56
from numba.cuda.testing import (
67
skip_unless_cc_53,
@@ -83,6 +84,11 @@ def math_exp(A, B):
8384
B[i] = math.exp(A[i])
8485

8586

87+
def math_exp2(A, B):
88+
i = cuda.grid(1)
89+
B[i] = math.exp2(A[i])
90+
91+
8692
def math_erf(A, B):
8793
i = cuda.grid(1)
8894
B[i] = math.erf(A[i])
@@ -400,6 +406,8 @@ def test_math_fp16(self):
400406
self.unary_template_float16(math_sqrt, np.sqrt)
401407
self.unary_template_float16(math_ceil, np.ceil)
402408
self.unary_template_float16(math_floor, np.floor)
409+
if sys.version_info >= (3, 11):
410+
self.unary_template_float16(math_exp2, np.exp2)
403411

404412
@skip_on_cudasim("numpy does not support trunc for float16")
405413
@skip_unless_cc_53
@@ -494,6 +502,15 @@ def test_math_exp(self):
494502
self.unary_template_float64(math_exp, np.exp)
495503
self.unary_template_int64(math_exp, np.exp)
496504
self.unary_template_uint64(math_exp, np.exp)
505+
# ---------------------------------------------------------------------------
506+
# test_math_exp2
507+
508+
@unittest.skipUnless(sys.version_info >= (3, 11), "Python 3.11+ required")
509+
def test_math_exp2(self):
510+
self.unary_template_float32(math_exp2, np.exp2)
511+
self.unary_template_float64(math_exp2, np.exp2)
512+
self.unary_template_int64(math_exp2, np.exp2)
513+
self.unary_template_uint64(math_exp2, np.exp2)
497514

498515
# ---------------------------------------------------------------------------
499516
# test_math_expm1

numba_cuda/numba/cuda/typing/mathdecl.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: BSD-2-Clause
33

4+
import sys
45
import math
56
from numba.core import types
67
from numba.cuda.typing.templates import ConcreteTemplate, signature, Registry
@@ -43,6 +44,8 @@ class Math_unary(ConcreteTemplate):
4344
signature(types.float64, types.float64),
4445
]
4546

47+
if sys.version_info >= (3, 11):
48+
Math_unary = infer_global(math.exp2)(Math_unary)
4649

4750
@infer_global(math.atan2)
4851
class Math_atan2(ConcreteTemplate):

0 commit comments

Comments
 (0)